# Import and setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import few

from few.trajectory.inspiral import EMRIInspiral
from few.trajectory.ode import SchwarzEccFlux, KerrEccEqFlux
# from few.amplitude.romannet import RomanAmplitude
from few.amplitude.ampinterp2d import AmpInterpKerrEccEq
from few.summation.interpolatedmodesum import InterpolatedModeSum


from few.utils.ylm import GetYlms
from few.utils.modeselector import ModeSelector
from few.summation.interpolatedmodesum import CubicSplineInterpolant
from few import get_file_manager

from few.waveform import (
    FastKerrEccentricEquatorialFlux,
    FastSchwarzschildEccentricFlux, 
    SlowSchwarzschildEccentricFlux, 
    Pn5AAKWaveform,
    GenerateEMRIWaveform
)

from few.utils.geodesic import get_fundamental_frequencies

import os
import sys

# Change to the desired directory
os.chdir('/nfs/home/svu/e1498138/localgit/FEWNEW/work/')

# Add it to Python path
sys.path.insert(0, '/nfs/home/svu/e1498138/localgit/FEWNEW/work/')

import GWfuncs
import gc
import pickle
import os
import cupy as cp
import multiprocessing as mp
from multiprocessing import Queue, Process
from functools import partial
from SNR_tutorial_utils import LISA_Noise
from lisatools.sensitivity import *

from few.utils.constants import YRSID_SI, Gpc, MRSUN_SI



# import pandas as pd
# Tune the FEW configuration: obtain a config setter (created with reset=True)
# and set the log level to "info".
cfg_set = few.get_config_setter(reset=True)
# The trailing semicolon stops the notebook from echoing the call's return value.
cfg_set.set_log_level("info");

In [None]:
# Report which FEW compute backends are usable in this environment.
for backend in ["cpu", "cuda11x", "cuda12x", "cuda", "gpu"]:
    # Work out the status outside the f-string: reusing double quotes inside a
    # double-quoted f-string is only valid syntax on Python 3.12+ (PEP 701).
    status = "available" if few.has_backend(backend) else "unavailable"
    print(f" - Backend '{backend}': {status}")

In [None]:
# Number of trajectory points; handed to EMRIInspiral as `npoints` when the
# trajectory module is constructed.
N_traj = 5000

In [None]:
# Parameters
m1 = 1e6  # M, the larger mass (presumably in solar masses — confirm against the FEW docs)
m2 = 1e1  # mu, the smaller mass
a = 0.5  # spin parameter passed to the trajectory and amplitude modules
p0 = 9.5  # initial p passed to the trajectory
e0 = 0.2  # initial e passed to the trajectory
theta = np.pi / 3.0  # first angle passed to the Ylm generator
phi = np.pi / 4.0  # second angle passed to the Ylm generator
dt = 10.0  # fine time step used when sampling the waveform
T = 1  # duration; multiplied by YRSID_SI elsewhere in the notebook, i.e. expressed in years
xI0 = 1.0
# In the paper xI0 = 0.866, but that would be the non-equatorial case.

# Single switch for the compute backend, so every module below is guaranteed
# to be built for the same one.
BACKEND = "cuda12x"
use_gpu = True

# Note: EMRIInspiral accepts an `npoints` flag; it is set from N_traj here.
traj = EMRIInspiral(func=KerrEccEqFlux, force_backend=BACKEND, use_gpu=use_gpu, npoints=N_traj)
amp = AmpInterpKerrEccEq(force_backend=BACKEND)  # default lmax=10, nmax=55
interpolate_mode_sum = InterpolatedModeSum(force_backend=BACKEND)
ylm_gen = GetYlms(include_minus_m=False, force_backend=BACKEND)

# Generate waveform

In [None]:
# Coarse sampling interval in seconds: the duration T (converted with YRSID_SI)
# split into 5000 equal steps.
delta_T = T*YRSID_SI/5000 # change the number of points here
# 5000 is used as an extreme example to show the NaNs and non-zero values of the resulting waveform;
# for a lower number of points, the NaNs dominate.
# NOTE(review): 5000 duplicates the initial value of N_traj — consider reusing it.
delta_T 

In [None]:
# %%time 
# # Calc trajectory
# (t, p, e, x, Phi_phi, Phi_theta, Phi_r) = traj(m1, m2, a, p0, e0, xI0, T=T, dt=delta_T, upsample=True)#upsampling=True, fix_t=True

# # Get amplitudes along trajectory
# teuk_modes = amp(a, p, e, x)

# # Get Ylms
# ylms = ylm_gen(amp.unique_l, amp.unique_m, theta, phi).copy()[amp.inverse_lm]

# cp.cuda.Stream.null.synchronize()

In [None]:
%%time 
# Compute the trajectory twice (upsampled and default), then the mode
# amplitudes along each trajectory and the Ylms.

# With upsampling (dt=delta_T, upsample=True) -> creates the effect of dense stepping.
# TODO: consider whether fix_T is needed.
(t_u, p_u, e_u, x_u, Phi_phi_u, Phi_theta_u, Phi_r_u) = traj(m1, m2, a, p0, e0, xI0, T=T, dt=delta_T, upsample=True) 
# Copy the integrator spline data right away; presumably the next traj(...) call
# replaces these attributes — confirm.
spline_t_u = traj.integrator_spline_t.copy()  
# Only columns 0 and 2 of the phase coefficients are kept
# (presumably Phi_phi and Phi_r, following the order of the returned tuple — confirm).
spline_coeff_u = traj.integrator_spline_phase_coeff[:, [0, 2]].copy()

# Without upsampling (default case), using the fine step dt.
(t_f, p_f, e_f, x_f, Phi_phi_f, Phi_theta_f, Phi_r_f) = traj(m1, m2, a, p0, e0, xI0, T=T, dt=dt)
spline_t_f = traj.integrator_spline_t.copy() 
spline_coeff_f = traj.integrator_spline_phase_coeff[:, [0, 2]].copy()

# Amplitudes along each trajectory (one call per trajectory).
teuk_modes_u = amp(a, p_u, e_u, x_u)
teuk_modes_f = amp(a, p_f, e_f, x_f)

# Ylms for the unique (l, m) pairs, expanded to one entry per mode via amp.inverse_lm.
ylms = ylm_gen(amp.unique_l, amp.unique_m, theta, phi).copy()[amp.inverse_lm]

# Wait for outstanding GPU work on the default stream so that %%time is meaningful.
cp.cuda.Stream.null.synchronize()

In [None]:
# Inspect the data cached on the trajectory object.
# NOTE(review): these attributes are read after both traj(...) calls, so they
# presumably describe the most recent (non-upsampled) call — confirm.

# Trajectory points: expected shape (N_points, 7) - t, p, e, x, Phi_phi, Phi_theta, Phi_r
trajectory_points = traj.trajectory  

# Spline time cache: expected shape (N_points,)
spline_times = traj.integrator_spline_t

# Spline coefficients: expected shape (N_points-1, 6, 8)
spline_coeffs = traj.integrator_spline_coeff

In [None]:
trajectory_points.shape, spline_times.shape, spline_coeffs.shape

In [None]:
plt.plot(np.diff(t_u))

In [None]:
np.sum(np.isnan(teuk_modes_u[:,0]))

In [None]:
teuk_modes_u[:,0]

In [None]:
teuk_modes_f[:,0]

In [None]:
# %%time

# t_gpu = cp.asarray(t)

# # need to prepare arrays for sum with all modes due to +/- m setup
# ls = amp.l_arr[: teuk_modes.shape[1]]
# ms = amp.m_arr[: teuk_modes.shape[1]]
# ns = amp.n_arr[: teuk_modes.shape[1]]

# keep_modes = np.arange(teuk_modes.shape[1])
# temp2 = keep_modes * (keep_modes < amp.num_m0) + (keep_modes + amp.num_m_1_up) * (
#     keep_modes >= amp.num_m0
# ) # amp.num_m0 gives number of modes with m == 0, amp.num_m_1_up gives number of modes with m > 0

# ylmkeep = np.concatenate([keep_modes, temp2])
# ylms_in = ylms[ylmkeep]
# teuk_modes_in = teuk_modes

# cp.cuda.Stream.null.synchronize()

In [None]:
%%time

# Trajectory time grids as GPU arrays.
t_u_gpu = cp.asarray(t_u)
t_f_gpu = cp.asarray(t_f)


def _mode_sum_inputs(teuk_modes):
    """Collect the index arrays the mode summation needs for one amplitude set.

    The sum runs over all modes, so because of the +/- m setup each column of
    ``teuk_modes`` needs two Ylm entries: one at its own position and one at
    the position of its -m counterpart.

    Returns, in order: the l, m and n index slices matching the columns of
    ``teuk_modes``, the +m positions, the -m positions, their concatenation,
    and ``ylms`` gathered at the concatenated positions.
    """
    n_cols = teuk_modes.shape[1]
    l_part = amp.l_arr[:n_cols]
    m_part = amp.m_arr[:n_cols]
    n_part = amp.n_arr[:n_cols]

    plus_m = np.arange(n_cols)
    # amp.num_m0 is the number of modes with m == 0 and amp.num_m_1_up the
    # number with m > 0: entries in the m == 0 block map to themselves, the
    # others are shifted by num_m_1_up.
    minus_m = plus_m * (plus_m < amp.num_m0) + (plus_m + amp.num_m_1_up) * (
        plus_m >= amp.num_m0
    )
    both = np.concatenate([plus_m, minus_m])
    return l_part, m_part, n_part, plus_m, minus_m, both, ylms[both]


# Upsampled trajectory.
(ls_u, ms_u, ns_u, keep_modes_u, temp2_u, ylmkeep_u, ylms_in_u) = _mode_sum_inputs(teuk_modes_u)
# Default (non-upsampled) trajectory.
(ls_f, ms_f, ns_f, keep_modes_f, temp2_f, ylmkeep_f, ylms_in_f) = _mode_sum_inputs(teuk_modes_f)

# All modes are kept, so the amplitudes are passed through unchanged.
teuk_modes_in_u = teuk_modes_u
teuk_modes_in_f = teuk_modes_f

cp.cuda.Stream.null.synchronize()

In [None]:
# %%time

# # perform summation
# waveform1 = interpolate_mode_sum(
#     t_gpu,
#     teuk_modes_in,
#     ylms_in,
#     traj.integrator_spline_t,
#     traj.integrator_spline_phase_coeff[:, [0, 2]],
#     ls,
#     ms,
#     ns,
#     dt=delta_T,
#     T=T,
# )

# cp.cuda.Stream.null.synchronize()

In [None]:
%%time

# Sum all modes for the upsampled trajectory, sampling the output with the fine step dt.
# NOTE(review): t_u_gpu and the amplitudes live on the upsampled (delta_T) grid,
# while spline_t_u / spline_coeff_u come from the integrator's own spline —
# confirm that InterpolatedModeSum accepts the two on different grids, since
# this may be related to the NaNs seen in the output.
waveform1_u = interpolate_mode_sum(
    t_u_gpu,
    teuk_modes_in_u,
    ylms_in_u,
    spline_t_u,
    spline_coeff_u,
    ls_u,
    ms_u,
    ns_u,
    dt=dt, # use the finer (small) dt for the waveform
    T=T,
)

# Wait for the GPU so that %%time covers the whole summation.
cp.cuda.Stream.null.synchronize()

In [None]:
%%time

# Same summation as for waveform1_u, restricted to a single mode (column index 1).
# NOTE(review): the full sum passes ylms for both +m and -m positions
# (ylms[ylmkeep_u]); here only one entry is passed via ylms_in_u[1:2] —
# confirm that this is what InterpolatedModeSum expects for a single mode.
waveform1_u_1 = interpolate_mode_sum(
    t_u_gpu,
    teuk_modes_in_u[:, 1:2],
    ylms_in_u[1:2],
    spline_t_u,
    spline_coeff_u,
    ls_u[1:2],
    ms_u[1:2],
    ns_u[1:2],
    dt=dt,
    T=T,
)

cp.cuda.Stream.null.synchronize()

In [None]:
len(waveform1_u)

In [None]:
waveform1_u_1[~np.isnan(waveform1_u_1)]

In [None]:
waveform1_u[~np.isnan(waveform1_u)]

In [None]:
np.sum(np.isnan(waveform1_u_1))

In [None]:
np.sum(~np.isnan(waveform1_u))

In [None]:
# see if the nans appear for the same points
np.array_equal(~np.isnan(waveform1_u), ~np.isnan(waveform1_u_1))

In [None]:
%%time
# Experiment: sample the output with the coarse step delta_T instead of dt,
# i.e. with the same step that was used to generate the upsampled trajectory.

# Sum all modes for the upsampled trajectory.
waveform1_u_alt = interpolate_mode_sum(
    t_u_gpu,
    teuk_modes_in_u,
    ylms_in_u,
    spline_t_u,
    spline_coeff_u,
    ls_u,
    ms_u,
    ns_u,
    dt=delta_T,
    T=T,
)

cp.cuda.Stream.null.synchronize()

In [None]:
len(waveform1_u_alt)

In [None]:
np.sum(np.isnan(waveform1_u_alt))

# Reference values for the waveform interpmodesum

In [None]:
%%time

# Reference waveform: sum all modes for the default (non-upsampled) trajectory,
# using the spline data saved right after that trajectory call and the fine step dt.
waveform1_f = interpolate_mode_sum(
    t_f_gpu,
    teuk_modes_in_f,
    ylms_in_f,
    spline_t_f,
    spline_coeff_f,
    ls_f,
    ms_f,
    ns_f,
    dt=dt,
    T=T,
)

cp.cuda.Stream.null.synchronize()

In [None]:
%%time
# Reference waveform for a single mode on the default trajectory.
# NOTE(review): despite the `_0` suffix this selects column index 1 (slice 1:2),
# the same column as waveform1_u_1 — confirm the intended mode.
waveform1_f_0 = interpolate_mode_sum(
    t_f_gpu,
    teuk_modes_in_f[:, 1:2],
    ylms_in_f[1:2],
    spline_t_f,
    spline_coeff_f,
    ls_f[1:2],
    ms_f[1:2],
    ns_f[1:2],
    dt=dt,
    T=T,
)

cp.cuda.Stream.null.synchronize()

In [None]:
np.sum(np.isnan(waveform1_f_0))

# Factors, SNRs etc

In [None]:
# The reference values are derived from the finely sampled (dt) waveform.
waveform1 = waveform1_f
N = len(waveform1)

# Analysis helper sized to the reference waveform and its time step.
gwf = GWfuncs.GravWaveAnalysis(dt=dt, N=N)

# Dimensionless distance factor for a source at `dist` (in Gpc).
dist = 1.0
factor = gwf.dist_factor(dist, m2)

# Copy the waveform to the host and apply the distance scaling.
waveform1_scaled = waveform1.get() / factor

In [None]:
# Reference SNR: transform the scaled reference waveform with gwf.freq_wave
# and pass the result to gwf.SNR.
hfull_f = gwf.freq_wave(waveform1_scaled)

SNR_ref = gwf.SNR(hfull_f)
print("SNR:", SNR_ref)
print("SNR squared:", SNR_ref**2)

In [None]:
# From here on the upsampled version is used.
# N_traj is overwritten with the number of rows of the upsampled amplitudes,
# which replaces the value it was given at the top of the notebook.
N_traj = teuk_modes_u.shape[0]  # number of trajectory points
print("Number of trajectory points:", N_traj)
# delta_T = T_sd / N_traj  # time step in seconds
print("Time step in seconds", delta_T)

In [None]:
# Build one "(l,m,n)" label per mode index.
# The index arrays are copied to host memory once up front: iterating the
# backend arrays directly would format one 0-d array per element, which on a
# GPU backend means one tiny device-to-host transfer per label field.
# cp.asnumpy also accepts arrays that already live on the host.
mode_labels = [
    f"({l_val},{m_val},{n_val})"
    for l_val, m_val, n_val in zip(
        cp.asnumpy(amp.l_arr), cp.asnumpy(amp.m_arr), cp.asnumpy(amp.n_arr)
    )
]

# Generate mode frequencies 

Here the frequencies are obtained with *get_fundamental_frequencies* instead. TODO: settle on a single backend (CPU or GPU) for this step — which one is better in this case?

In [None]:
# Fundamental frequencies along the upsampled trajectory.
OmegaPhi, OmegaTheta, OmegaR = get_fundamental_frequencies(a, p_u, e_u, x_u) # could use the GPU here, but I am running into mismatch problems
# Open question: what is the benefit of computing on the CPU and converting to the GPU, versus running everything on the GPU in the first place?

In [None]:
isinstance(amp.m_arr, cp.ndarray)

In [None]:
OmegaPhi.shape

In [None]:
# Host copies of the mode index arrays, fetched once with .get() so the loops
# below can work with plain host values.
l_cpu = amp.l_arr.get()
m_cpu = amp.m_arr.get()
n_cpu = amp.n_arr.get()

In [None]:
# Frequency of every mode along the upsampled trajectory:
#     f_gw = m * OmegaPhi + n * OmegaR
# There is no polar term because k = 0 in the equatorial case.
# The host copies m_cpu / n_cpu are used so everything stays on the CPU.
# TODO: decide whether this should all live on the CPU or on the GPU.
# NOTE(review): these are built directly from the output of
# get_fundamental_frequencies — confirm the units are what the sensitivity
# function expects before passing them on.
gw_frequencies_per_mode = [
    m_val * OmegaPhi + n_val * OmegaR for m_val, n_val in zip(m_cpu, n_cpu)
]

In [None]:
# Phase of every mode along the upsampled trajectory:
#     phi_mode = m * Phi_phi + n * Phi_r
# The host copies m_cpu / n_cpu are used so everything stays on the CPU.
gw_phase_per_mode = [
    m_val * Phi_phi_u + n_val * Phi_r_u for m_val, n_val in zip(m_cpu, n_cpu)
]

# Calculate inner product

In [None]:
# Positions of the two modes to compare; the labels echoed below confirm
# which (l,m,n) each index refers to.
# NOTE(review): the indices are hard-coded — presumably they are only valid
# for the default lmax / nmax mode layout; confirm if those change.
idx_i = 1165 # (2,2,0)
idx_j = 1166 # (2,2,1)
mode_labels[idx_i], mode_labels[idx_j]

In [None]:
# Get complex amplitudes for the two modes
A0 = teuk_modes_u[:, idx_i]
A1 = teuk_modes_u[:, idx_j]
print("A_0:", A0)
print("A_1:", A1)

In [None]:
# PSD of the LISA sensitivity evaluated at each mode's frequency track.
# NOTE(review): gw_frequencies_per_mode is built directly from the output of
# get_fundamental_frequencies — confirm it is in the units get_sensitivity
# expects (presumably Hz).
Sn0 = get_sensitivity(gw_frequencies_per_mode[idx_i], sens_fn=LISASens, return_type="PSD")
Sn1 = get_sensitivity(gw_frequencies_per_mode[idx_j], sens_fn=LISASens, return_type="PSD")

In [None]:
barA0 = A0.get() / np.sqrt(Sn0) #TODO: do convert this in C/GPU only?
barA0 

In [None]:
barA1 = A1.get() / np.sqrt(Sn1)
barA1

In [None]:
phase01 = np.abs(gw_phase_per_mode[idx_i] - gw_phase_per_mode[idx_j]) < 1.0 
phase01

In [None]:
np.sum(phase01)

## Cross-term inner product of (2,2,0) & (2,2,1)

In [None]:
crossprod01 = np.conj(barA0[phase01]) * barA1[phase01]
crossprod01

In [None]:
inner_contrib_01 = np.sum(crossprod01) * delta_T * 1/(factor**2)
np.real(inner_contrib_01)

## Self-term inner product of (2,2,0), (2,2,1)

In [None]:
# Squared magnitude of the noise-weighted amplitude at every sample.
selfprod00 = np.conj(barA0)*barA0 # |barA0|**2, which is not the same as barA0**2 for complex values
selfprod00

In [None]:
inner_contrib_00 = np.sum(np.real(selfprod00)) * delta_T * 1/(factor**2)
inner_contrib_00

In [None]:
inner_contrib_11 = np.sum(np.real(np.conj(barA1)*barA1)) * delta_T * 1/(factor**2)
inner_contrib_11

In [None]:
np.abs(inner_contrib_01 / np.sqrt(inner_contrib_00 * inner_contrib_11) )

# Creating Inner Product Function

In [None]:
def _mode_amplitude(idx, teuk_modes, amp):
    """Return the 1-D amplitude time series of mode ``idx``."""
    m_val = amp.m_arr[idx]
    if m_val >= 0:
        # Modes with m >= 0 are stored directly as columns of teuk_modes.
        return teuk_modes[:, idx]

    # Negative m: rebuild the amplitude from the stored +m partner as
    # (-1)**l * conj(A_partner).
    # NOTE(review): the partner is looked up as (l, -m, n); confirm whether
    # the symmetry should pair it with n or with -n.
    l_val = amp.l_arr[idx]
    n_val = amp.n_arr[idx]
    partner_mask = (amp.l_arr == l_val) & (amp.m_arr == -m_val) & (amp.n_arr == n_val)
    partner_idx = np.where(partner_mask)[0]
    if len(partner_idx) != 1:
        raise ValueError(
            f"expected exactly one +m partner for mode index {idx}, "
            f"found {len(partner_idx)}"
        )
    # Index with a scalar so the result stays 1-D. Indexing with the index
    # array would give shape (N, 1), which broadcasts against the (N,) PSD
    # into an (N, N) array further down.
    return (-1) ** l_val * np.conj(teuk_modes[:, int(partner_idx[0])])


def _noise_weighted_amplitude(idx, teuk_modes, amp, freqs):
    """Return mode ``idx``'s amplitude on the host, divided by sqrt(PSD)."""
    amplitude = _mode_amplitude(idx, teuk_modes, amp)
    psd = get_sensitivity(freqs[idx], sens_fn=LISASens, return_type="PSD")
    # Device arrays expose .get(); arrays already on the host are used as they are.
    host_amplitude = amplitude.get() if hasattr(amplitude, "get") else np.asarray(amplitude)
    return host_amplitude / np.sqrt(psd)


def calc_inner(idx_i, idx_j, teuk_modes, amp, freqs, phases, delta_T, factor):
    """Approximate the noise-weighted inner product between two modes.

    Each mode amplitude is divided by the square root of the LISA PSD
    evaluated on that mode's frequency track. The product
    conj(bar_A_i) * bar_A_j is then summed over the samples where the two
    mode phases differ by less than 1.0, and the real part of the sum is
    scaled by ``delta_T / factor**2``.

    Args:
        idx_i: Index of the first mode in ``amp.l_arr`` / ``m_arr`` / ``n_arr``.
        idx_j: Index of the second mode.
        teuk_modes: 2-D amplitude array with one row per trajectory sample and
            one column per stored (m >= 0) mode.
        amp: Amplitude module providing ``l_arr``, ``m_arr`` and ``n_arr``.
        freqs: Per-mode frequency tracks, indexed by mode index and passed
            straight to ``get_sensitivity``.
            NOTE(review): presumably these must be in Hz — confirm at the caller.
        phases: Per-mode phase tracks, indexed by mode index.
        delta_T: Spacing between trajectory samples, used as the sum's weight.
        factor: Distance factor; the result is divided by ``factor**2``.

    Returns:
        The real-valued inner-product estimate.

    Raises:
        ValueError: If a negative-m mode does not have exactly one +m partner.
    """
    bar_A_i = _noise_weighted_amplitude(idx_i, teuk_modes, amp, freqs)
    bar_A_j = _noise_weighted_amplitude(idx_j, teuk_modes, amp, freqs)

    # Only samples where the two modes are nearly in phase contribute.
    phase_mask = np.abs(phases[idx_i] - phases[idx_j]) < 1.0

    prod = np.conj(bar_A_i[phase_mask]) * bar_A_j[phase_mask]
    return np.sum(np.real(prod)) * delta_T / (factor**2)

In [None]:
# (2,2,0) & (2,2,0)
calc_inner(1165,1165, teuk_modes_u, amp, gw_frequencies_per_mode, gw_phase_per_mode, delta_T, factor)

In [None]:
# (2,2,0) & (2,2,1)
calc_inner(1165,1166, teuk_modes_u, amp, gw_frequencies_per_mode, gw_phase_per_mode, delta_T, factor)

In [None]:
mode_labels[7159]

In [None]:
# (2,2,0) & (2, -2, 0)
calc_inner(1165,7159, teuk_modes_u, amp, gw_frequencies_per_mode, gw_phase_per_mode, delta_T, factor)

# Reference values

In [None]:
indices = [1165, 1166, 7159]

In [None]:
# Time-domain waveform of each selected mode on the default (non-upsampled)
# trajectory, scaled by the distance factor. These are the reference values
# the calc_inner results are compared against.
# NOTE(review): the full-sum cells pass ylms for both the +m and -m positions
# (ylms[ylmkeep_f]), whereas a single ylm entry is passed per mode here —
# confirm this is what InterpolatedModeSum expects for modes with m != 0.
waveform_per_mode = []
for idx in indices:
    l = amp.l_arr[idx]
    m = amp.m_arr[idx]
    n = amp.n_arr[idx]
    print('Mode: ', mode_labels[idx])

    if m >= 0:
        # For m >= 0, use the stored mode directly (kept 2-D with one column).
        teuk_modes_single = teuk_modes_f[:, [idx]]
        ylms_single = ylms[[idx]]
        m_arr = amp.m_arr[[idx]]
    else:
        # Negative m: find the corresponding m > 0 mode instead of mapping indices.
        print('NEGATIVE M MODE')
        pos_m_mask = (amp.l_arr == l) & (amp.m_arr == -m) & (amp.n_arr == n)
        print(amp.l_arr[pos_m_mask], amp.m_arr[pos_m_mask], amp.n_arr[pos_m_mask])
        pos_m_idx = np.where(pos_m_mask)[0]
        print(pos_m_idx)
        if len(pos_m_idx) != 1:
            raise ValueError(
                f"expected exactly one +m partner for mode index {idx}, "
                f"found {len(pos_m_idx)}"
            )

        # pos_m_idx is already a one-element index array, so index with it
        # directly to get an (N, 1) array like the m >= 0 branch. Wrapping it
        # in another list would add a third axis.
        teuk_modes_single = (-1)**l * np.conj(teuk_modes_f[:, pos_m_idx])
        print(teuk_modes_single)
        # ylms_single = (-1)**(-m) * np.conj(ylms[[pos_m_idx]])
        ylms_single = ylms[[idx]]
        print(ylms_single)
        m_arr = np.abs(amp.m_arr[[idx]])  # pass a positive m to the summation

    # Use the spline data saved right after the non-upsampled trajectory call.
    # It always matches t_f_gpu / teuk_modes_f, unlike the live attributes on
    # `traj`, which reflect whichever trajectory was computed last.
    waveform = interpolate_mode_sum(
        t_f_gpu,
        teuk_modes_single,
        ylms_single,
        spline_t_f,
        spline_coeff_f,
        amp.l_arr[[idx]],
        m_arr,
        amp.n_arr[[idx]],
        dt=dt,
        T=T
    )
    waveform_per_mode.append(waveform/factor)

In [None]:
# Convert each per-mode waveform to the frequency domain: copy it to the host
# with .get() and pass it through gwf.freq_wave, keeping the order of `indices`.
hf_per_mode = [gwf.freq_wave(waveform.get()) for waveform in waveform_per_mode]

In [None]:
# (2,2,0) & (2,2,0)
gwf.inner(hf_per_mode[0], hf_per_mode[0])

In [None]:
# (2,2,0) & (2,2,1)
gwf.inner(hf_per_mode[0], hf_per_mode[1])

In [None]:
# (2,2,0) & (2,-2,0)
gwf.inner(hf_per_mode[0], hf_per_mode[2])