In [1]:
import numpy as np
import matplotlib.pyplot as plt
from numba import cuda, float64, complex128
from numba.cuda import jit as cuda_jit
import math

import few

from few.trajectory.inspiral import EMRIInspiral
from few.trajectory.ode import KerrEccEqFlux
from few.amplitude.ampinterp2d import AmpInterpKerrEccEq
from few.summation.interpolatedmodesum import InterpolatedModeSum 


from few.utils.ylm import GetYlms

from few import get_file_manager

from few.waveform import FastKerrEccentricEquatorialFlux

from few.utils.geodesic import get_fundamental_frequencies

import os
import sys

# Working directory that is also put on the import path (presumably where the
# local GWfuncs module imported below lives -- confirm). Set the FEW_WORK_DIR
# environment variable to run the notebook elsewhere; the default is the
# original cluster location.
WORK_DIR = os.environ.get('FEW_WORK_DIR', '/nfs/home/svu/e1498138/localgit/FEWNEW/work/')

# Change to the desired directory
os.chdir(WORK_DIR)

# Add it to Python path (once only, so re-running the cell does not stack duplicates)
if WORK_DIR not in sys.path:
    sys.path.insert(0, WORK_DIR)

import GWfuncs
# import gc
# import pickle
import cupy as cp

# tune few configuration
# Request the configuration setter with reset=True and set the log level to
# "info". The last call is the cell's final expression, so its return value
# (a ConfigurationSetter, per the cell output) is what gets displayed.
cfg_set = few.get_config_setter(reset=True)
cfg_set.set_log_level("info")

# print("Importing dynesty...")
# import dynesty

<few.utils.globals.ConfigurationSetter at 0x1554cca7d400>

In [2]:
# Report which FEW backends are usable in this environment.
# The status string is computed first so the f-string does not nest double
# quotes inside double quotes (a SyntaxError before Python 3.12).
for backend in ["cpu", "cuda11x", "cuda12x", "cuda", "gpu"]:
    status = "available" if few.has_backend(backend) else "unavailable"
    print(f" - Backend '{backend}': {status}")

 - Backend 'cpu': available
 - Backend 'cuda11x': unavailable
 - Backend 'cuda12x': available
 - Backend 'cuda': available
 - Backend 'gpu': available


In [3]:
import numpy as np
try:
    import cupy as cp
except ImportError:
    cp = None
from lisatools.sensitivity import get_sensitivity, LISASens

In [4]:
from few.utils.constants import YRSID_SI, Gpc, MRSUN_SI


In [5]:
cp

<module 'cupy' from '/scratch/e1498138/anaconda3/envs/fewsm/lib/python3.12/site-packages/cupy/__init__.py'>

In [6]:
import numpy as np
try:
    import cupy as cp
except ImportError:
    cp = None
from lisatools.sensitivity import get_sensitivity, LISASens

class ModeSelector:
    """Group the strongest waveform modes into weakly correlated sets.

    Modes are ranked by power and then greedily either accepted as a new group
    or merged into the most correlated existing group. Correlation is measured
    with the approximate inner product of Eqn. 18-20 in arXiv:2109.14254.

    The trajectory, Teukolsky amplitudes, spherical harmonics, per-mode
    frequencies/phases and the distance factor are computed lazily on first
    access and cached on the instance.
    """

    def __init__(self, params, traj, amp, ylm_gen, delta_T, gwf):
        """ Initialize the ModeSelector with the provided parameters. 
        Parameters and some notes:
        - params: List of parameters 
                  [NOTE same order as loglike] 
                  [m1, m2, a, p0, e0, xI0, theta, phi, dist]
        - traj: The trajectory module with delta_T (NOT the finer dt)
                NOTE: easier way would be to access it from waveform_gen
                      but in this case we specify diff args for traj module
        - amp: Amplitude module 
        - ylm_gen: Spherical-harmonic generator, called as
                   ylm_gen(amp.unique_l, amp.unique_m, theta, phi)
        - delta_T: Time step used to generate the Teukolsky modes. 
                    NOT the same as dt used for interpolation
        - gwf: GravWaveAnalysis object (provides T, xp, dist_factor, calc_power)

        TODO: calc gw_freqs and gw_phases inside this class? 
        maybe pass parameter set as a vector instead to simplify things...
        """
        self.params = params
        self.traj = traj
        self.amp = amp
        self.ylm_gen = ylm_gen
        self.delta_T = delta_T
        self.gwf = gwf

        # Lazily filled caches; None means "not computed yet".
        self._traj_data = None
        self._teuk_modes = None
        self._ylms = None
        self._gw_freqs = None
        self._gw_phases = None
        self._factor = None

    @staticmethod
    def _to_host(value):
        """Return ``value`` in host memory.

        Objects exposing a callable ``get()`` (as CuPy arrays do) are copied
        to the host through it; anything else is returned unchanged. This
        avoids touching the ``cp`` module, which is ``None`` when CuPy is not
        installed.
        """
        getter = getattr(value, "get", None)
        return getter() if callable(getter) else value

    def _calc_traj(self):
        """ Calculate trajectory (cached after the first call) """
        if self._traj_data is None:
            m1, m2, a, p0, e0, xI0, _, _, _ = self.params
            self._traj_data = self.traj(m1, m2, a, p0, e0, xI0, T=self.gwf.T, dt=self.delta_T, upsample=True)
        
        return self._traj_data

    @property
    def teuk_modes(self):
        """ Calculate Teukolsky modes along the cached trajectory """
        if self._teuk_modes is None:
            _, p, e, x, _, _, _ = self._calc_traj()
            _, _, a, _, _, _, _, _, _ = self.params
            self._teuk_modes = self.amp(a, p, e, x)
        return self._teuk_modes
    
    @property
    def ylms(self):
        """ Calculate spherical harmonics, expanded per mode via amp.inverse_lm """
        if self._ylms is None:
            _, _, _, _, _, _, theta, phi, _ = self.params
            self._ylms = self.ylm_gen(self.amp.unique_l, self.amp.unique_m, theta, phi).copy()[self.amp.inverse_lm]
        return self._ylms

    def _host_mode_numbers(self):
        """Return (m_arr, n_arr) in host memory, transferred once per array."""
        return self._to_host(self.amp.m_arr), self._to_host(self.amp.n_arr)

    @property
    def gw_freqs(self):
        """ Calculate GW frequencies: one array m*OmegaPhi + n*OmegaR per mode """
        if self._gw_freqs is None:
            _, p, e, x, _, _, _ = self._calc_traj()
            _, _, a, _, _, _, _, _, _ = self.params

            # Get fundamental frequencies
            # TODO make it GPU compatible (i always get a problem???)
            # NOTE(review): these values are passed straight to get_sensitivity
            # in inner_approx. Confirm they are already in Hz; if
            # get_fundamental_frequencies returns dimensionless angular
            # frequencies they need converting first. Modes with m < 0 also
            # give negative values here -- confirm the PSD handles those.
            OmegaPhi, _, OmegaR = get_fundamental_frequencies(a, p, e, x)

            m_host, n_host = self._host_mode_numbers()
            self._gw_freqs = [m * OmegaPhi + n * OmegaR for m, n in zip(m_host, n_host)]
        return self._gw_freqs

    @property
    def gw_phases(self):
        """ Calculate GW phases: one array m*Phi_phi + n*Phi_r per mode """
        if self._gw_phases is None:
            _, _, _, _, Phi_phi, _, Phi_r = self._calc_traj()

            m_host, n_host = self._host_mode_numbers()
            # k * Phi_theta = 0 for equatorial
            self._gw_phases = [m * Phi_phi + n * Phi_r for m, n in zip(m_host, n_host)]
        return self._gw_phases
    
    @property
    def factor(self):
        """ Distance factor gwf.dist_factor(dist, m2), cached """
        if self._factor is None:
            _, m2, _, _, _, _, _, _, dist = self.params
            self._factor = self.gwf.dist_factor(dist, m2)
        return self._factor

    def _mode_amplitude(self, idx):
        """Return the Teukolsky amplitude time series of mode ``idx``.

        For m >= 0 this is the stored column. For m < 0 the (l, -m, -n)
        column is looked up and (-1)**l times its complex conjugate is used.
        """
        l = self.amp.l_arr[idx]
        m = self.amp.m_arr[idx]
        n = self.amp.n_arr[idx]

        if m >= 0:
            return self.teuk_modes[:, idx]

        mirror_mask = (self.amp.l_arr == l) \
                        & (self.amp.m_arr == -m) \
                        & (self.amp.n_arr == -n)
        mirror_idx = self.gwf.xp.where(mirror_mask)[0][0]
        return (-1)**l * self.gwf.xp.conj(self.teuk_modes[:, mirror_idx])

    def _noise_weighted_amplitude(self, idx):
        """Return the host-side amplitude of mode ``idx`` divided by sqrt(PSD)."""
        Sn = get_sensitivity(self.gw_freqs[idx], 
                             sens_fn=LISASens, 
                             return_type="PSD"
                             )
        return self._to_host(self._mode_amplitude(idx)) / np.sqrt(Sn)

    def inner_approx(self, mode_i, mode_j): 
        """ Calculate the approximate inner product of two modes.
        Based on Eqn. 18-20 in arxiv 2109.14254 [non-local parameter degeneracy paper]

        Parameters:
        mode_i, mode_j: Lists of mode indices (so can be multiple, 
                        this is so that we can calculate
                        the inner product of combinations of modes)

        Returns the sum over all (i, j) pairs of
        sum(Re(conj(A_i) * A_j)) * delta_T / factor**2, where A are the
        noise-weighted amplitudes and only samples whose phase difference is
        below 1.0 contribute. Empty input lists give 0.0.
        """
        # Each mode's noise-weighted amplitude is needed for many pairs;
        # compute it once per call instead of once per pair.
        weighted = {}

        def bar_A(idx):
            key = int(idx)
            if key not in weighted:
                weighted[key] = self._noise_weighted_amplitude(idx)
            return weighted[key]

        total_inner = 0.0

        for idx_i in mode_i:
            bar_A_i = bar_A(idx_i)
            for idx_j in mode_j:
                bar_A_j = bar_A(idx_j)

                # Define mask where the phase difference is small
                phase_mask = np.abs(self.gw_phases[idx_i] - self.gw_phases[idx_j]) < 1.0 

                # Calculate the product of the two waveforms w/ the phase mask
                prod = np.conj(bar_A_i[phase_mask]) * bar_A_j[phase_mask]
            
                # Add this pair's contribution to the total inner product
                total_inner += np.sum(np.real(prod)) * self.delta_T * 1/(self.factor**2)

        return total_inner
            
    def SNR_approx(self, mode_idx):
        """ Calculate approximate SNR: sqrt of the mode set's self inner product. """
        return np.sqrt(self.inner_approx(mode_idx, mode_idx))
    
    def overlap_approx(self, mode_i, mode_j):
        """ Calculate approximate overlap between two modes.

        No guard is applied for a vanishing SNR of either mode set.
        """
        return self.inner_approx(mode_i, mode_j) / (self.SNR_approx(mode_i) * self.SNR_approx(mode_j))
    
    def select_modes(self, 
                     M_init = 100, 
                     M_sel = 5, 
                     threshold=0.01
                    ):
        """ Select modes based on a given threshold.
        M_init = number of modes to select initially (for the power sorting)
        M_sel = number of modes to select in the end (default: 5)
        threshold = for accept/reject cond. of inner products between modes

        Returns (selected_modes, selected_labels): two parallel lists of
        groups. Each group in selected_modes is a list of original mode
        indices; the matching group in selected_labels holds their (l, m, n)
        tuples.

        Raises ValueError if M_init is smaller than 1. If fewer than M_init
        modes exist, all of them are considered.
        """
        if M_init < 1:
            raise ValueError("M_init must be a positive integer")

        ###### Step 0: Initialization and Setup ######

        # (l, m, n) labels as plain Python numbers; one host transfer per array
        host_l, host_m, host_n = (
            np.asarray(self._to_host(arr)).tolist()
            for arr in (self.amp.l_arr, self.amp.m_arr, self.amp.n_arr)
        )
        mode_labels = list(zip(host_l, host_m, host_n))

        # TODO: add mode info here during search 

        # Calculate power and sort 
        m0mask = self.amp.m_arr_no_mask != 0
        total_power = self.gwf.calc_power(self.teuk_modes, self.ylms, m0mask)

        # Top M_init indices in descending order (based on power)
        top_indices = self.gwf.xp.argsort(total_power)[-M_init:][::-1] 
        top_indices = self._to_host(top_indices)
        # May be smaller than M_init when fewer modes are available
        n_candidates = len(top_indices)

        # Get sorted mode labels and power values
        # TODO: noise-weighted power mode selection?
        mp_modes = [mode_labels[idx] for idx in top_indices]
        mp_power = total_power[top_indices]

        ### Initialize selected set S with h0

        # Using the original index of the mode for the teuk_modes, amp..
        selected_modes = [[top_indices[0]]] 

        # Label of h0 (first entry of the SORTED list)
        selected_labels = [[mp_modes[0]]]

        # Keep track of all processed modes (using ORIGINAL indices)
        processed_ori_indices = [top_indices[0]]  

        print(f"Initial mode selected: {mp_modes[0]} with power {mp_power[0]}")

        ###### Step 1, 2, ... N
        # Iterate through remaining modes on the sorted list ######
        
        # Do note the idx i runs through the sorted indices
        for i in range(1, n_candidates):
            # Break if M_sel is reached 
            if len(selected_modes) >= M_sel:
                break

            print(f"Considering mode {i} / {M_init} : {mp_modes[i]} with power {mp_power[i]}")
        
            # Get next candidate mode h_j'
            hj_prime_idx = top_indices[i]
            hj_prime_label = mp_modes[i]
            
            # Keep track of processed indices
            processed_ori_indices.append(hj_prime_idx)

            max_inner = 0 
            max_inner_idx = -1 

            for k, selected_mode in enumerate(selected_modes):
                # Calculate with each selected mode |<h_sel|h_j'>/(<h_sel|h_sel>*<h_j'|h_j'>)^(1/2)|
                # Basically abs of overlap 
                calc_inner = abs(self.overlap_approx(selected_mode, [hj_prime_idx]))
                print(f" - Inner product with selected mode {k}: {calc_inner}")

                # Check if this is the maximum inner product found so far
                if calc_inner > max_inner:
                    max_inner = calc_inner
                    max_inner_idx = k

            if max_inner < threshold:
                # Fulfill cond: Accept the mode as a new group
                selected_modes.append([hj_prime_idx])
                selected_labels.append([hj_prime_label])
            else:
                # Doesn't fulfill cond: Reject the mode and add to the most correlated mode 
                selected_modes[max_inner_idx].append(hj_prime_idx)
                selected_labels[max_inner_idx].append(hj_prime_label)

        ###### Step N+1: Handle remaining modes as h_M ######
        
        # Candidates are processed in sorted order, so whatever was not
        # processed is the tail of top_indices (kept in descending power order).
        remaining_ori_indices = list(top_indices[len(processed_ori_indices):])

        # Continue only if there are remaining modes
        if remaining_ori_indices:

            # Check condition 1 : <h_M|h_M> > 1
            hM_inner = self.inner_approx(remaining_ori_indices, remaining_ori_indices)
            cond_one = hM_inner > 1

            # Check condition 2 : <h_sel|h_M> << threshold
            # NOTE(review): this compares the signed, unnormalised inner
            # product with the threshold, whereas the loop above uses the
            # absolute overlap -- confirm that this difference is intended.
            max_inner_with_sel = 0
            max_inner_with_sel_idx = -1

            for k, selected_mode in enumerate(selected_modes):
                selM_inner = self.inner_approx(selected_mode, remaining_ori_indices)

                if selM_inner > max_inner_with_sel:
                    max_inner_with_sel = selM_inner
                    max_inner_with_sel_idx = k

            cond_two = max_inner_with_sel < threshold

            remaining_labels = [mode_labels[idx] for idx in remaining_ori_indices]

            if cond_one and cond_two:
                # Both conditions hold: accept h_M as its own group
                selected_modes.append(remaining_ori_indices)
                selected_labels.append(remaining_labels)
            elif cond_one:
                # Cond one fulfilled, cond two violated:
                # add h_M to the most correlated mode
                selected_modes[max_inner_with_sel_idx].extend(remaining_ori_indices)
                selected_labels[max_inner_with_sel_idx].extend(remaining_labels)
            # Otherwise cond one is violated: h_M is rejected (becomes error term)
            
        return selected_modes, selected_labels

In [7]:
# Coarse sampling used for the trajectory / Teukolsky amplitudes.
N_traj = 500 # number of trajectory points; change amount of points here 
T = 1 # observation time in years
# Spacing between trajectory points: T converted with YRSID_SI (presumably
# seconds per sidereal year -- confirm) and split into N_traj equal steps.
delta_T = T*YRSID_SI/N_traj 

In [8]:
delta_T, T

(63116.29952709119, 1)

In [9]:
# Parameters
# NOTE(review): units are not stated in this notebook; masses are presumably
# in solar masses, dist in Gpc and dt in seconds -- confirm.
m1 = 1e6 # M (primary mass)
m2 = 1e1 # mu (secondary mass)
a = 0.5
p0 = 9.5
e0 = 0.2
theta = np.pi / 3.0 
phi = np.pi / 4.0  
dt = 10.0
# in the paper xI0 = 0.866, but that would be non-equatorial case
xI0 = 1.0 
dist= 1

# use_gpu is not referenced by the constructors in this cell; each backend
# below is fixed through its own force_backend argument.
use_gpu = True 
traj = EMRIInspiral(func=KerrEccEqFlux, npoints=N_traj) # theres npoints flag here
amp = AmpInterpKerrEccEq(force_backend="cuda12x") # default lmax=10, nmax=55
interpolate_mode_sum = InterpolatedModeSum(force_backend="cuda12x")
ylm_gen = GetYlms(include_minus_m=False, force_backend="cuda12x")
# Analysis helper from the local GWfuncs module, built with the fine time step dt.
gwf = GWfuncs.GravWaveAnalysis(T=T, dt=dt)

In [10]:
# Build the selector. params must follow the order documented on
# ModeSelector.__init__: [m1, m2, a, p0, e0, xI0, theta, phi, dist].
# delta_T is the coarse trajectory step, not the fine dt given to gwf.
modesel = ModeSelector(params=[m1, m2, a, p0, e0, xI0, theta, phi, dist], 
    traj=traj, 
    amp=amp, 
    ylm_gen=ylm_gen,
    delta_T=delta_T,
    gwf=gwf
)

In [11]:
# Rank the 20 strongest modes and group them into at most 10 sets.
# sel_modes holds original mode indices, sel_labs the matching (l, m, n) labels.
sel_modes, sel_labs = modesel.select_modes(M_init=20, M_sel=10, threshold=0.01)

Initial mode selected: (2, 2, 0) with power 16.46865454281919
Considering mode 1 / 20 : (2, 2, 1) with power 4.7659650803333635
 - Inner product with selected mode 0: 0.0021594627596660167
Considering mode 2 / 20 : (2, 2, -1) with power 1.6149074449428453
 - Inner product with selected mode 0: 0.0021440787661995133
 - Inner product with selected mode 1: 0.0023143486553815805
Considering mode 3 / 20 : (3, 3, 0) with power 0.8744987496064003
 - Inner product with selected mode 0: 0.000357858701810566
 - Inner product with selected mode 1: 0.0003166677914570845
 - Inner product with selected mode 2: 0.0004930362813332653
Considering mode 4 / 20 : (3, 3, 1) with power 0.7455900716447382
 - Inner product with selected mode 0: 0.0005054667492400171
 - Inner product with selected mode 1: 0.00046937457666803477
 - Inner product with selected mode 2: 0.000662271464781077
 - Inner product with selected mode 3: 0.0019133078605032521
Considering mode 5 / 20 : (2, 2, 2) with power 0.592139497647797

In [12]:
sel_modes

[[1165],
 [1166],
 [1164, 7048, 1168, 1941, 1942, 1943, 1944, 7158, 7160, 1501, 1055],
 [1498],
 [1499],
 [1167],
 [1497],
 [1054],
 [7159],
 [1500]]

In [13]:
sel_labs

[[(2, 2, 0)],
 [(2, 2, 1)],
 [(2, 2, -1),
  (2, -1, 0),
  (2, 2, 3),
  (4, 4, -1),
  (4, 4, 0),
  (4, 4, 1),
  (4, 4, 2),
  (2, -2, -1),
  (2, -2, 1),
  (3, 3, 3),
  (2, 1, 1)],
 [(3, 3, 0)],
 [(3, 3, 1)],
 [(2, 2, 2)],
 [(3, 3, -1)],
 [(2, 1, 0)],
 [(2, -2, 0)],
 [(3, 3, 2)]]

# Calculate inner product

In [None]:
# Approximate self inner product <h|h> of every selected mode group.
inner_modes = []
for group in sel_modes:
    group_inner = modesel.inner_approx(group, group)
    # The result may live on the GPU; copy it to the host only when it
    # actually exposes .get(), so a host-side value does not raise.
    inner_modes.append(group_inner.get() if hasattr(group_inner, "get") else group_inner)

In [None]:
inner_modes

In [None]:
np.sqrt(np.sum(inner_modes))

## Generating waveforms for each mode

In [None]:
# Backend selection shared by every sub-module of the waveform generator.
use_gpu = True
force_backend = "cuda12x"  
print("Setting up waveform generator...")
# keyword arguments for inspiral generator 
inspiral_kwargs={
        "func": 'KerrEccEqFlux',
        "DENSE_STEPPING": 0, #change to 1/True for uniform sampling
        "include_minus_m": False, 
        "use_gpu" : use_gpu,
        "force_backend":force_backend
}

# keyword arguments for amplitude generator
amplitude_kwargs = {
    "force_backend": force_backend,
    # "use_gpu" : use_gpu
}

# keyword arguments for Ylm generator (GetYlms)
Ylm_kwargs = {
    "force_backend": force_backend,
    # "assume_positive_m": True  # if we assume positive m, it will generate negative m for all m>0
}

# keyword arguments for summation generator (InterpolatedModeSum)
sum_kwargs = {
    "force_backend":force_backend,
    "pad_output": True
    # "separate_modes": True
    # "use_gpu" : use_gpu
}

print("Creating FastKerrEccentricEquatorialFlux...")
# Kerr eccentric flux
waveform_gen = FastKerrEccentricEquatorialFlux(
    inspiral_kwargs=inspiral_kwargs,
    amplitude_kwargs=amplitude_kwargs,
    Ylm_kwargs=Ylm_kwargs,
    sum_kwargs=sum_kwargs,
    use_gpu=use_gpu,
)

In [None]:
def _generate_selected_waveforms(self, params):
    """Build one waveform per selected mode group.

    ``self`` is accepted but not used. ``params`` is unpacked as
    (m1, m2, a, p0, e0, xI0, theta, phi, dist). The mode groups are read from
    the notebook-level ``sel_labs``; ``waveform_gen``, ``dt`` and ``T`` also
    come from the notebook namespace. Each group is printed before its
    waveform is generated.

    Returns the list of generated waveforms, in the order of ``sel_labs``.
    """
    m1, m2, a, p0, e0, xI0, theta, phi, dist = params

    def _waveform_for(labels):
        # Echo the group so progress is visible while the generator runs.
        print(labels)
        return waveform_gen(
            m1, m2, a, p0, e0, xI0,
            theta, phi,
            dist=dist,
            dt=dt,
            T=T,
            mode_selection=labels,
            include_minus_mkn=False,
        )

    return [_waveform_for(labels) for labels in sel_labs]

In [None]:
%%time
# modesel is passed in the `self` slot, which the function does not use; the
# mode groups themselves are taken from the notebook-level sel_labs.
waveform_per_mode = _generate_selected_waveforms(modesel, [m1, m2, a, p0, e0, xI0, theta, phi, dist])

In [None]:
len(waveform_per_mode)

In [None]:
waveform_per_mode

In [None]:
# Waveform of the last selected group, copied to the host for plotting.
hM_mode = waveform_per_mode[-1].get()

fig, ax = plt.subplots()
ax.plot(hM_mode.real, label='real')
ax.plot(hM_mode.imag, label='imag')
ax.set_xlim(0, 1000)
ax.set_xlabel('sample index')
ax.set_ylabel('waveform')
ax.set_title('Last selected mode group (first 1000 samples)')
ax.legend(loc='upper right')
plt.show()

In [None]:
# Look the two single-mode groups up by label instead of hard-coding their
# positions, so the curves cannot be mislabelled if the selection changes.
# list.index raises ValueError when a mode is not a group on its own.
mode_pos = waveform_per_mode[sel_labs.index([(2, 2, 0)])].get()
mode_neg = waveform_per_mode[sel_labs.index([(2, -2, 0)])].get()

fig, ax = plt.subplots()
ax.plot(mode_pos.real, label='(2,2,0) real')
ax.plot(mode_neg.real, label='(2,-2,0) real')
ax.set_xlim(0, 1000)
ax.set_xlabel('sample index')
ax.set_ylabel('waveform (real part)')
ax.set_title('(2,2,0) vs (2,-2,0): first 1000 samples')
ax.legend(loc='upper right')
plt.show()

In [None]:
# Same comparison on a late window of the waveform (samples 3155000-3155815).
fig, ax = plt.subplots()
ax.plot(mode_pos.real, label='(2,2,0) real')
ax.plot(mode_neg.real, label='(2,-2,0) real')
ax.set_xlim(3155000, 3155815)
ax.set_xlabel('sample index')
ax.set_ylabel('waveform (real part)')
ax.set_title('(2,2,0) vs (2,-2,0): late samples')
ax.legend(loc='upper right')
plt.show()

In [None]:
# Imaginary parts of the two modes over the first 1000 samples.
fig, ax = plt.subplots()
ax.plot(mode_pos.imag, label='(2,2,0) imag')
ax.plot(mode_neg.imag, label='(2,-2,0) imag')
ax.set_xlim(0, 1000)
# ax.set_xlim(3155000, 3155815)
ax.set_xlabel('sample index')
ax.set_ylabel('waveform (imaginary part)')
ax.set_title('(2,2,0) vs (2,-2,0): first 1000 samples')
ax.legend(loc='upper right')
plt.show()

In [None]:
# Apply gwf.freq_wave to every group waveform, preserving group order.
hf_per_mode = [gwf.freq_wave(group_waveform) for group_waveform in waveform_per_mode]

In [None]:
# Self inner product gwf.inner(h, h) of every transformed group waveform.
inner_per_mode = [gwf.inner(group_hf, group_hf) for group_hf in hf_per_mode]

In [None]:
inner_per_mode

In [None]:
cp.sqrt(cp.sum(cp.array(inner_per_mode)))

In [None]:
np.abs(gwf.overlap(hf_per_mode[0], hf_per_mode[8]))

# Reference

In [None]:
# Reference waveform: same parameters, but no mode_selection argument, so the
# generator's default mode content is used.
h_true = waveform_gen(m1, m2, a, p0, e0, xI0, theta, phi, dist=dist, dt=dt, T=T)


In [None]:
# Reference SNR of the full waveform, for comparison with the per-group
# values computed above.
h_true_f = gwf.freq_wave(h_true)
snr_ref = gwf.SNR(h_true_f)
snr_ref

TODO: Understand why the (2,-2,0) mode does not overlap with the (2,2,0) mode. Check graphically: is this caused by the evolution, or are the two modes dephased?