In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# package imports
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # set available gpus before importing torch

# utils
import time
import json

from pathlib import Path
from functools import partial
from tqdm import tqdm

# typing
from typing import Optional, Union, List, Dict, Tuple

# modelling
import math
import numpy as np
import pandas as pd

import torch
from torch.distributions.uniform import Uniform

import pytorch_lightning as pl

# multiprocessing
import multiprocessing
import concurrent.futures

# visualisation
import seaborn as sns
import matplotlib.pyplot as plt

import plotly.express as px
import plotly.figure_factory as ff

In [3]:
# astronomy
import pycbc.psd
from pycbc.detector import Detector
from pycbc.waveform import get_waveform_filter_length_in_time, get_td_waveform, get_fd_waveform

from lal import MSUN_SI, REARTH_SI, C_SI, PC_SI
from lalsimulation import (
    SimInspiralTransformPrecessingNewInitialConditions,
    SimInspiralChooseFDWaveform,
    SimInspiralFD,
    SimInspiralImplementedFDApproximants,
    GetApproximantFromString
)

# astronomy - typing
# from gwpy.frequencyseries import FrequencySeries
# from gwpy.timeseries import TimeSeries
# from pycbc.types import FrequencySeries, TimeSeries

In [4]:
# local imports
from lfigw.waveform_generator import source_frame_to_radiation, is_fd_waveform

In [5]:
# import scipy
# import numpy as np
# import h5py
# from pathlib import Path
# from sklearn.utils.extmath import randomized_svd
# from tqdm import tqdm


# class SVDBasis(object):

#     def __init__(self):
#         self.whitening_dict = {}
#         self.standardization_dict = {}
#         self.T_matrices = None
#         self.T_matrices_deriv = None

#     def generate_basis(self, training_data, n, method='random'):
#         """Generate the SVD basis from training data and store it.

#         The SVD decomposition takes

#         training_data = U @ diag(s) @ Vh

#         where U and Vh are unitary.

#         Arguments:
#             training_data {array} -- waveforms in frequency domain

#         Keyword Arguments:
#             n {int} -- number of basis elements to keep.
#                        n=0 keeps all basis elements. (default: {0})
#         """
#         print(f'training_data: {training_data}')
#         print(f'n: {n}')

#         if method == 'random':
#             U, s, Vh = randomized_svd(training_data, n)

#             self.Vh = Vh.astype(np.complex64)
#             self.V = self.Vh.T.conj()

#             self.n = n

#         elif method == 'scipy':
#             # Code below uses scipy's svd tool. Likely slower.

#             U, s, Vh = scipy.linalg.svd(training_data, full_matrices=False)
#             V = Vh.T.conj()

#             if (n == 0) or (n > len(V)):
#                 self.V = V
#                 self.Vh = Vh
#             else:
#                 self.V = V[:, :n]
#                 self.Vh = Vh[:n, :]

#             self.n = len(self.Vh)

#     def basis_coefficients_to_fseries(self, coefficients):
#         """Convert from basis coefficients to frequency series.

#         Arguments:
#             coefficients {array} -- basis coefficients

#         Returns:
#             array -- frequency series
#         """

#         return coefficients @ self.Vh

#     def fseries_to_basis_coefficients(self, fseries):
#         """Convert from frequency series to basis coefficients.

#         Arguments:
#             fseries {array} -- frequency series

#         Returns:
#             array -- basis coefficients
#         """

#         return fseries @ self.V

#     #
#     # Time translation
#     #

#     def init_time_translation(self, t_min, t_max, Nt, f_grid):
#         """Initialize the time translation matrices.

#         The time translation in frequency domain corresponds to multiplication
#         by e^{ - 2 pi i f dt }. If we only have waveforms in terms of basis
#         coefficients, however, this is quite expensive: first one must
#         transform to frequency domain, then time translate, then transform
#         back to the reduced basis domain. Generally the dimensionality of
#         FD waveforms will be much higher than the dimension of the reduced
#         basis, so this is very costly.

#         This function pre-computes N x N matrices in the reduced basis domain,
#         where N is the dimension of the reduced basis. Matrices are computed
#         at a discrete set of dt's. Later, interpolation is used to compute time
#         translated coefficients away from these discrete points.

#         Arguments:
#             t_min {float} -- minimum value of dt
#             t_max {float} -- maximum value of dt
#             Nt {int} -- number of discrete points at which to compute matrices
#             f_grid {array} -- frequencies at which FD waveforms are evaluated
#         """

#         self.t_grid = np.linspace(t_min, t_max, num=Nt, endpoint=True,
#                                   dtype=np.float32)

#         self.T_matrices = np.empty((Nt, self.n, self.n),
#                                    dtype=np.complex64)
#         self.T_matrices_deriv = np.empty((Nt, self.n, self.n),
#                                          dtype=np.complex64)

#         print('Building time translation matrices.')
#         for i in tqdm(range(Nt)):

#             # Translation by dt in FD is multiplication by e^{- 2 pi i f dt}
#             T_fd = np.exp(- 2j * np.pi * self.t_grid[i] * f_grid)
#             T_deriv_fd = - 2j * np.pi * f_grid * T_fd

#             # Convert to FD, apply t translation, convert to reduced basis
#             T_basis = (self.Vh * T_fd) @ self.V
#             T_deriv_basis = (self.Vh * T_deriv_fd) @ self.V

#             self.T_matrices[i] = T_basis
#             self.T_matrices_deriv[i] = T_deriv_basis

#     def time_translate(self, coefficients, dt, interpolation='linear'):
#         """Calculate basis coefficients for a time-translated waveform.

#         The new waveform h_new(t) = h_old(t - dt). In other words, if the
#         original merger time is t=0, then the new merger time is t=dt.

#         In frequency domain, this corresponds to multiplication by
#         e^{ - 2 pi i f dt }.

#         This method is capable of linear or cubic interpolation.

#         Arguments:
#             coefficients {array} -- basis coefficients of initial waveform
#             dt {float} -- time translation

#         Keyword Arguments:
#             interpolation {str} -- 'linear' or 'cubic' interpolation
#                                    (default: {'linear'})

#         Returns:
#             array -- basis coefficients of time-translated waveform
#         """

#         pos = np.searchsorted(self.t_grid, dt, side='right') - 1

#         if self.t_grid[pos] == dt:

#             # No interpolation needed
#             translated = coefficients @ self.T_matrices[pos]

#         else:
#             t_left = self.t_grid[pos]
#             t_right = self.t_grid[pos+1]

#             # Interpolation parameter u(dt) defined so that:
#             #           u(t_left) = 0
#             #           u(t_right) = 1

#             u = (dt - t_left) / (t_right - t_left)

#             # Require coefficients evaluated on boundaries of interval
#             y_left = coefficients @ self.T_matrices[pos]
#             y_right = coefficients @ self.T_matrices[pos+1]

#             if interpolation == 'linear':

#                 translated = y_left * (1 - u) + y_right * u

#             elif interpolation == 'cubic':

#                 # Also require derivative of coefficients wrt dt
#                 dydt_left = coefficients @ self.T_matrices_deriv[pos]
#                 dydt_right = coefficients @ self.T_matrices_deriv[pos+1]

#                 # Cubic interpolation over interval
#                 # See https://en.wikipedia.org/wiki/Cubic_Hermite_spline

#                 h00 = 2*(u**3) - 3*(u**2) + 1
#                 h10 = u**3 - 2*(u**2) + u
#                 h01 = -2*(u**3) + 3*(u**2)
#                 h11 = u**3 - u**2

#                 translated = (y_left * h00
#                               + dydt_left * h10 * (t_right - t_left)
#                               + y_right * h01
#                               + dydt_right * h11 * (t_right - t_left))

#         return translated

#     #
#     # Whitening
#     #
#     # At present, we must know the fiducial and new noise PSD in advance, in
#     # order to prepare the transformation matrices for reduced basis
#     # coefficients. This is needed for dealing with detectors with different
#     # PSDs.
#     #
#     # In the future, when we draw PSDs at random at train time, this will need
#     # to be modified.
#     #

#     def init_whitening(self, ref_psd_name, ref_psd,
#                        new_psd_name, new_psd):
#         """Initialize whitening.

#         Constructs and saves the whitening matrix for changing from a reference
#         to a new noise PSD. This matrix acts on vectors of reduced basis
#         coefficients.

#         Arguments:
#             ref_psd_name {str} -- label for fiducial PSD
#             ref_psd {array} -- frequency series for fiducial PSd
#             new_psd_name {str} -- label for new PSD
#             new_psd {array} -- frequency series for new PSD
#         """

#         if ((new_psd_name != ref_psd_name)
#                 and (new_psd_name not in self.whitening_dict.keys())):

#             # ref_psd = np.array(ref_psd)
#             # new_psd = np.array(new_psd)

#             whitening_FD = (ref_psd / new_psd) ** 0.5

#             # Convert to float32 *after* dividing. PSDs can have very small
#             # numbers.
#             whitening_FD = whitening_FD.astype(np.float32)

#             # Convert to RB representation
#             whitening_RB = (self.Vh * whitening_FD) @ self.V

#             whitening_RB = whitening_RB.astype(np.complex64)

#             self.ref_psd_name = ref_psd_name
#             self.whitening_dict[new_psd_name] = whitening_RB

#     def whiten(self, coefficients, psd_name):
#         """Whiten a waveform, given as a vector of reduced-basis coefficients.
#         Waveform is assumed to already be white wrt reference PSD.

#         Whitening must be first initialized with with init_whitening method.

#         Arguments:
#             coefficients {array} -- basis coefficients of initial waveform
#             psd_name {str} -- label for new PSD

#         Returns:
#             array -- basis coefficients for whitened waveform
#         """

#         if psd_name != self.ref_psd_name:
#             return coefficients @ self.whitening_dict[psd_name]

#         else:
#             return coefficients

#     #
#     # Truncation
#     #

#     def truncate(self, n):

#         self.V = self.V[:, :n]
#         self.Vh = self.Vh[:n, :]

#         for ifo in self.standardization_dict.keys():
#             self.standardization_dict[ifo] = self.standardization_dict[ifo][:n]

#         for psd in self.whitening_dict.keys():
#             self.whitening_dict[psd] = self.whitening_dict[psd][:n, :n]

#         if self.T_matrices is not None:
#             self.T_matrices = self.T_matrices[:, :n, :n]
#             self.T_matrices_deriv = self.T_matrices_deriv[:, :n, :n]

#         self.n = n

#     #
#     # Standardization
#     #
#     # Given a whitened noisy waveform, we want to rescale each component to
#     # have unit variance. This is to improve neural network training. The mean
#     # should already be zero.
#     #

#     def init_standardization(self, ifo, h_array, noise_std):

#         # Standard deviation of data. Divide by sqrt(2) because we want real
#         # and imaginary parts to have unit standard deviation.
#         std = np.std(h_array, axis=0) / np.sqrt(2)

#         # Total standard deviation
#         std_total = np.sqrt(std**2 + noise_std**2)

#         self.standardization_dict[ifo] = 1.0 / std_total

#     def standardize(self, h, ifo):

#         return h * self.standardization_dict[ifo]

#     #
#     # File I/O
#     #

#     def save(self, directory='.', filename='reduced_basis.hdf5'):

#         p = Path(directory)
#         p.mkdir(parents=True, exist_ok=True)

#         f = h5py.File(p / filename, 'w')

#         f.create_dataset('V', data=self.V,
#                          compression='gzip', compression_opts=9)

#         if self.standardization_dict != {}:
#             std_group = f.create_group('std')
#             for ifo, std in self.standardization_dict.items():
#                 std_group.create_dataset(ifo, data=std,
#                                          compression='gzip',
#                                          compression_opts=9)

#         f.close()

#     def load(self, directory='.', filename='reduced_basis.hdf5'):

#         p = Path(directory)

#         f = h5py.File(p / filename, 'r')
#         self.V = f['V'][:, :]

#         if 'std' in f.keys():
#             std_group = f['std']
#             for ifo in std_group.keys():
#                 self.standardization_dict[ifo] = std_group[ifo][:]

#         f.close()

#         self.Vh = self.V.T.conj()
#         self.n = len(self.Vh)


In [6]:
# %%timeit
# author's numpy version of prior sampling
def lfigw_sample_prior(self, n):
    # transform prior ranges corresponding to uniformly sampled parameters
    uniform_priors = np.zeros((self.nparams, 2))
    for idx, param in enumerate(self.parameters):
        uniform_priors[idx] = self.priors[param]

        # transform prior domain in order to sample form uniform distribution
        if param in ('theta_jn', 'tilt_1', 'tilt_2'):
            uniform_priors[idx] = np.cos(uniform_priors[idx])  # we apply arccos after sampling
        elif param == 'dec':
            uniform_priors[idx] = np.sin(uniform_priors[idx])  # we apply arcsin after sampling
        elif param == 'distance':
            uniform_priors[idx] = uniform_priors[idx] ** 3.0  # we apply **(1/3) after sampling

    # Draw uniform samples
    draw = np.random.random((n, self.nparams))

    samples = np.apply_along_axis(
        lambda x: x*(uniform_priors[:, 1]- uniform_priors[:, 0]) + uniform_priors[:, 0],
        axis=1, 
        arr=draw
    )

    def M_q_from_m1_m2(m1, m2):

        M = m1 + m2
        q = m2 / m1

        return M, q

    m1i = self.param_idx['mass_1']
    m2i = self.param_idx['mass_2']

    if ('M' in self.priors.keys()) and ('q' in self.priors.keys()):
        M_min, M_max = self.priors['M']
        q_min, q_max = self.priors['q']
        m1_min, m1_max = self.priors['mass_1']
        m2_min, m2_max = self.priors['mass_2']
        for i in range(n):
            m1, m2 = samples[i, [m1i, m2i]]
            while True:
                M, q = M_q_from_m1_m2(m1, m2)
                if (m1 >= m2 and M >= M_min and M <= M_max
                        and q >= q_min and q <= q_max):
                    samples[i, [m1i, m2i]] = (m1, m2)
                    break
                else:
                    m1 = m1_min + (m1_max - m1_min) * np.random.random()
                    m2 = m2_min + (m2_max - m2_min) * np.random.random()
    else:
        # ONLY VALID OF M1, M2 HAVE THE SAME RANGES
        samples[:, [m2i, m1i]] = np.sort(samples[:, [m1i, m2i]])


    # Undo uniformity transformations
    for idx, param in enumerate(self.parameters):
        if param in ('theta_jn', 'tilt_1', 'tilt_2'):
            samples[:, idx] = np.arccos(samples[:, idx])
        elif param == 'dec':
            samples[:, idx] = np.arcsin(samples[:, idx])
        elif param == 'distance':
            samples[:, idx] = samples[:, idx] ** (1.0/3.0)
            
    return samples

In [7]:
# utility functions
def save_pair_plot(df: pd.DataFrame, filename: str, theme: str='ticks'):
    #https://stackoverflow.com/questions/37612434/what-are-ways-to-speed-up-seaborns-pairplot
    sns.set_style(theme)
    sns_plot = sns.pairplot(df, corner=True, diag_kind='kde', kind='kde')
    sns_plot.savefig(filename)
    plt.clf()  # clean pairplot figure from sns

In [8]:
def generate_ordered_parameters(spins: bool, spins_aligned: bool, inclination: bool, mass_ratio: bool=False) -> Tuple[str]:
    """Function genereates an ordered tuple of parameters.
    
    The index positions of each parameter in this list must be kept static for downstream tasks.
    """
    if mass_ratio:
        parameters = ['M', 'q']
    else:
        parameters = ['mass_1', 'mass_2']
        
    parameters.extend(['phase', 'time', 'distance'])

    if spins:
        if spins_aligned:
            parameters.extend(['chi_1', 'chi_2'])
        else:
            if not inclination:
                raise Exception('Precession requires nonzero inclination.')
            parameters.extend(['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl'])

    if inclination:
        parameters.extend(['theta_jn', 'psi'])

    parameters.extend(['ra', 'dec'])
    
    return tuple(parameters)

def generate_psd(
    ifo: str,
    delta_f: float,
    f_max: float,
    f_min: float,
    event_dir: Optional[Union[Path, str]]=None,
) -> pycbc.types.FrequencySeries:
    """Generate a power spectral density (PSD) as a Frequency Series given provided interferometer detector.
    
    Arguments:
        ifo: {str} -- Interferometer name according to LALSimulation (minus SimNoisePSD prefix) - prefixed with "PSD_"
        delta_f: {float} -- Frequency spacing (resolution) of the frequency series for the PSD
        f_max: {float} -- The maximum frequency for the PSD - should be half the sampling rate (see: Nyquist frequency)
        f_min_psd: {float} -- The minimum frequency for the PSD. May be different to f_min used when processing signals.
        event_dir {pathlib.Path | str} -- The directory of the event, e.g. 'data/events/GW150914'. 
            If None, this implies the PSD is not loaded and so we get from PyCBC; else we load from file.

    Returns:
        psd -- The power spectral density generated by PyCBC given our arguments.
    """
    # "The PSD length should be the same as the length of Frequency Domain (FD) waveforms,
    # which is determined from delta_f and f_max." - Green
    psd_length = int(f_max / delta_f) + 1
    
    if event_dir is None:
        psd = pycbc.psd.from_string(
            psd_name=f'PSD_{ifo}',  # PSD name according to LALSimulation (minus SimNoisePSD prefix)
            length=psd_length,
            delta_f=delta_f,
            low_freq_cutoff=f_min,  # freq below this value are set to zero.
        )
    else:
        psd_filepath = event_dir / f'PSD_{ifo}.txt'
        assert psd_filepath.is_file(), f'{psd_filepath} does not exist.'
        psd = pycbc.psd.from_txt(
            filename=psd_filepath,
            length=psd_length,
            delta_f=delta_f,
            low_freq_cutoff=f_min,  # freq below this value are set to zero.
            is_asd_file=False,
        )
        
    # To avoid division by 0 when whitening, set the PSD boundary values to satisfy [f_min, f_max].
    lower = int(f_min / delta_f)
    psd[:lower] = psd[lower]
    psd[-1:] = psd[-2]

    return psd

# get_psd = psd_dict['H1'].get(1 // delta_f, generate_psd(ifo, delta_f, f_max, f_min, event_dir))


In [9]:
def generate_whitened_waveform(
    sample: Dict[str, float],
    inclination: bool,
    spins: bool,
    spins_aligned: bool,
    domain: str='RB',
    intrinsic_only: bool=False
):
    # Convert from source frame to Cartesian parameters; Optional parameters have default values
    mass_1 = sample['mass_1']
    mass_2 = sample['mass_2']
    phase = sample['phase']
    coalesce_time = sample['time']
    distance = sample['distance']
    ra = sample['ra']
    dec = sample['dec']
    
    if inclination:
        theta_jn = sample['theta_jn']
        psi = sample['psi']
    else:
        theta_jn = 0.0
        psi = 0.0

    if spins:
        if spins_aligned:
            spin_1x, spin_1y, spin_1z = 0., 0., sample['chi_1']
            spin_2x, spin_2y, spin_2z = 0., 0., sample['chi_2']
            iota = theta_jn
        else:
            a_1, a_2 = sample['a_1'], sample['a_2']
            tilt_1, tilt_2 = sample['tilt_1'], sample['tilt_2']
            phi_jl, phi_12 = sample['phi_jl'], sample['phi_12']

            # use bilby/LAL to simulate an inspiral given intrinsic parameters
            (iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z) = source_frame_to_radiation(
                theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, mass_2, f_ref, phase
            )
    else:
        spin_1x = 0.0
        spin_1y = 0.0
        spin_1z = 0.0
        spin_2x = 0.0
        spin_2y = 0.0
        spin_2z = 0.0
        iota = theta_jn

    if domain == 'TD':
        # "Start with a TD waveform generated from pycbc. If the
        # approximant is in FD, then this suitably tapers the low
        # frequencies in order to have a finite-length TD waveform
        # without wraparound effects. If we started with an FD
        # waveform, then we would have to do these manipulations
        # ourselves" -- Stephen Green.

        # Make sure f_min is low enough
        if (
            time_duration > get_waveform_filter_length_in_time(
                mass1=mass_1, mass2=mass_2,
                spin1x=spin_1x, spin2x=spin_2x,
                spin1y=spin_1y, spin2y=spin_2y,
                spin1z=spin_1z, spin2z=spin_2z,
                inclination=iota,
                f_lower=f_min,
                f_ref=f_ref,
                approximant=approximant
            )
        ):
            print('Warning: f_min not low enough for given waveform duration')

        hp_TD, hc_TD = get_td_waveform(
            mass1=mass_1, mass2=mass_2,
            spin1x=spin_1x, spin2x=spin_2x,
            spin1y=spin_1y, spin2y=spin_2y,
            spin1z=spin_1z, spin2z=spin_2z,
            distance=distance,
            coa_phase=phase,
            inclination=iota,  # CHECK THIS!!!
            delta_t=delta_t,
            f_lower=f_min,
            f_ref=f_ref,
            approximant=approximant
        )
        hp = hp_TD.to_frequencyseries()
        hc = hc_TD.to_frequencyseries()

    elif domain in ('FD', 'RB'):
        # LAL refers to approximants by an index 
        if is_fd_waveform(approximant):  # bool(SimInspiralImplementedFDApproximants(GetApproximantFromString(approximant)))
            # Use the pycbc waveform generator; change this later (says Author)
            # returns plus and cross phases of the waveform in frequency domain
            hp, hc = get_fd_waveform(
                mass1=mass_1, mass2=mass_2,
                spin1x=spin_1x, spin2x=spin_2x,
                spin1y=spin_1y, spin2y=spin_2y,
                spin1z=spin_1z, spin2z=spin_2z,
                distance=distance,
                coa_phase=phase,
                inclination=iota,
                f_lower=f_min,
                f_final=f_max,
                delta_f=delta_f,
                f_ref=f_ref,
                approximant=approximant,
            )
        else:
            # "Use SimInspiralFD. This converts automatically
            # from the TD to FD waveform, but it requires a timeshift to be
            # applied. Approach mimics bilby treatment." - Stephen Green

            # Require SI units
            mass_1_SI = mass_1 * MSUN_SI
            mass_2_SI = mass_2 * MSUN_SI
            distance_SI = distance * PC_SI * 1e6

            lal_approximant = GetApproximantFromString(approximant)

            h_p, h_c = SimInspiralFD(
                mass_1_SI, mass_2_SI,
                spin_1x, spin_1y, spin_1z,
                spin_2x, spin_2y, spin_2z,
                distance_SI, iota, phase,
                0.0, 0.0, 0.0,  # should have keyword args here?
                delta_f, f_min, f_max,
                f_ref, None,
                lal_approximant,
            )

            # If f_max/delta_f is not a power of 2, SimInspiralFD increases f_max
            # to make this a power of 2. Take only components running up to f_max.
            hp = np.zeros_like(sample_frequencies, dtype=np.complex)
            hc = np.zeros_like(sample_frequencies, dtype=np.complex)
            hp[:] = h_p.data.data[:len(hp)]
            hc[:] = h_c.data.data[:len(hp)]

            # Zero the strain for frequencies below f_min
            hp *= frequency_mask.numpy()
            hc *= frequency_mask.numpy()

            # SimInspiralFD sets the merger time so the waveform can be
            # transformed to TD without wrapping the end of the waveform to
            # the beginning. Bring the time of coalescence to 0.
            dt = 1. / delta_f + (h_p.epoch.gpsSeconds + h_p.epoch.gpsNanoSeconds * 1e-9)
            hp *= np.exp(- 1j * 2 * np.pi * dt * sample_frequencies)
            hc *= np.exp(- 1j * 2 * np.pi * dt * sample_frequencies)

            # Convert to pycbc frequencyseries. Later, get rid of pycbc functions.
            hp = FrequencySeries(hp, delta_f=delta_f, epoch=-time_duration)
            hc = FrequencySeries(hc, delta_f=delta_f, epoch=-time_duration)

    if intrinsic_only:
        # Whiten with reference noise PSD and return hp, hc
        hp = hp / (_get_psd('H1', psd, hp.delta_f, f_max, f_min_psd, event_dir) ** 0.5)
        hc = hc / (_get_psd('H1', psd, hc.delta_f, f_max, f_min_psd, event_dir) ** 0.5)
#         hp = hp / (psd['H1'].get(1 // delta_f, generate_psd('H1', hp.delta_f, f_max, f_min_psd, event_dir)) ** 0.5)
#         hc = hc / (psd['H1'].get(1 // delta_f, generate_psd('H1', hc.delta_f, f_max, f_min_psd, event_dir)) ** 0.5)
    
        # Convert to TD if necessary, ensure correct length
        if domain == 'TD':
            hp = hp.to_timeseries().time_slice(-time_duration, 0.0)
            hc = hc.to_timeseries().time_slice(-time_duration, 0.0)

        out = (hp.data.astype(np.complex64), hc.data.astype(np.complex64))

    else:
        # Project waveform onto detectors
        h_d_dict = {}
        for ifo, detector in detectors.items():

            # Project onto antenna pattern
            fp, fc = detector.antenna_pattern(ra, dec, psi, ref_time)  # fp and fc are plus and cross polarisations
            
            # Apply time delay relative to Earth center
            dt = detector.time_delay_from_earth_center(ra, dec, ref_time)  # should ref_time be coalesce_time?
            time_d = coalesce_time + dt
            
            # transform each plus/cross phase according to antenna pattern function
            h_d = fp * hp + fc * hc  
            
            # Author's Notes: Merger is currently at time 0. Shift it.
            # time_shift = - (self.time_duration - time_d)  # NOT SURE IF THIS LINE IS RIGHT / NEEDED. COMMENTED.
            time_shift = time_d
            
            h_d = h_d.cyclic_time_shift(time_shift)
            h_d.start_time = h_d.start_time + time_shift

            # whiten
#             h_d = h_d / (_get_psd(ifo, psd, h_d.delta_f, f_max, f_min_psd, event_dir) ** 0.5)
            h_d = h_d / (psd[ifo].get(1 // delta_f, generate_psd(ifo, h_d.delta_f, f_max, f_min_psd, event_dir)) ** 0.5)

            # Convert to TD if necessary, and ensure waveform is of correct length
            if domain == 'TD':
                h_d = h_d.to_timeseries().time_slice( -time_duration, 0.0)

            h_d_dict[ifo] = h_d.data

        out = np.stack([val.astype(np.complex64) for val in h_d_dict.values()])
        
    return out

In [23]:
class WaveformGenerator:

    def __init__(
        self,
        spins: bool=True,
        inclination: bool=True,
        spins_aligned: bool=True,
        mass_ratio: bool=False,
        detectors: List[str]=['H1', 'L1', 'V1'],
        domain: str='TD',
        extrinsic_at_train: bool=False,
        num_workers: int=1,
    ):
        "Contains a database of waveforms from which to train a model."
        # multiprocessing
        assert num_workers <= multiprocessing.cpu_count()
        self.num_workers = num_workers
        
        # parameterisation
        self.spins = spins
        self.spins_aligned = spins_aligned
        self.inclination = inclination
        self.mass_ratio = mass_ratio
        self.domain = domain
        
        # whether to apply extrinsic parameters at train or data prep time
        self.extrinsic_at_train = extrinsic_at_train
        self.extrinsic_params = ['time', 'distance', 'psi', 'ra', 'dec']
        
        # Set up indices for parameters
        self.parameters = generate_ordered_parameters(spins, spins_aligned, inclination, mass_ratio)
        self.param_idx = {param: i for i, param in enumerate(self.parameters)}
        self.nparams = len(self.parameters)
        
        # Default prior ranges
        self.priors = dict(
            mass_1=[10.0, 80.0],  # solar masses
            mass_2=[10.0, 80.0],
            M=[25.0, 100.0],
            q=[0.125, 1.0],
            phase=[0.0, 2*math.pi],
            time=[-0.1, 0.1],  # seconds
            distance=[100.0, 4000.0],  # Mpc
            chi_1=[-1.0, 1.0],
            chi_2=[-1.0, 1.0],
            a_1=[0.0, 0.99],
            a_2=[0.0, 0.99],
            tilt_1=[0.0, math.pi],
            tilt_2=[0.0, math.pi],
            phi_12=[0.0, 2*math.pi],
            phi_jl=[0.0, 2*math.pi],
            theta_jn=[0.0, math.pi],
            psi=[0.0, math.pi],
            ra=[0.0, 2*math.pi],
            dec=[-math.pi/2.0, math.pi/2.0]
        )
        
        self.priors = {key: value for key, value in self.priors.items() if key in self.parameters}
        
        self.latex = dict(
            mass_1=r'$m_1$',
            mass_2=r'$m_2$',
            M=r'$M$',
            q=r'$q$',
            phase=r'$\phi_c$',
            time=r'$t_c$',
            distance=r'$d_L$',
            chi_1=r'$\chi_1$',
            chi_2=r'$\chi_2$',
            a_1=r'$a_1$',
            a_2=r'$a_2$',
            tilt_1=r'$t_1$',
            tilt_2=r'$t_2$',
            phi_12=r'$\phi_{12}$',
            phi_jl=r'$\phi_{jl}$',
            theta_jn=r'$\theta_{JN}$',
            psi=r'$\psi$',
            ra=r'$\alpha$',
            dec=r'$\delta$'
        )
        
    @property
    def parameter_labels(self):
        labels = []
        for param in self.param_idx.keys():
            labels.append(self.latex[param])
        return labels
    
#     def _sample_numpy_prior(self, n: int) -> Dict[str, List[float]]:
        #
   
    def _sample_prior(self, n):
        # create dictionary of prior ranges and instantiate them as torch tensors
        bounds = {
            param: torch.tensor(self.priors[param], dtype=torch.float64) # device=device
            for param in self.parameters
        }

        # transform prior domain in order to sample form uniform (must be sorted as [low, high])
        bounds = {
            param: (
                value.pow(3).sort().values if param == 'distance'
                else value.sin().sort().values if param == 'dec'
                else value.cos().sort().values if param in ['theta_jn', 'tilt_1', 'tilt_2']
                else value.sort().values
            ) for param, value in bounds.items()
        }

        # create a torch.distributions.uniform.Uniform object for each parameter
        uniform_priors = { param: Uniform(*value, validate_args=True) for param, value in bounds.items()}

        # sample from uniform distributions
        samples = {parameter: distribution.sample([n]) for parameter, distribution in uniform_priors.items()}

        # undo uniformity transformations
        samples = {
            param: (
                value.pow(1/3) if param == 'distance'
                else value.arcsin() if param == 'dec'
                else value.arccos() if param in ['theta_jn', 'tilt_1', 'tilt_2']
                else value
            ) for param, value in samples.items()
        }

        # handle mass and mass ratios
        if ('M' in samples.keys()) and ('q' in samples.keys()):
            # reparameterise M and Q to be component masses (unordered)
            samples['mass_1'] = samples['M'] * samples['q']
            samples['mass_2'] = samples['M'] * (1 - samples['q'])

            # recreate samples dictinoary without M and Q (and inserts mass_1 and mass_2 at the front)
            samples = {
                'mass_1': samples['M'] * samples['q'],
                'mass_2': samples['M'] * (1 - samples['q']),
                **{key: val for key, val in samples.items() if key not in ('M','q')}
            }

        # uphold constraint that mass_1 >= mass_2 by sorting along the concatenated dimension then splitting
        # warning: this approach may have some unintended consequences regarding the prior bounds of m1 and m2
        samples['mass_1'], samples['mass_2'] = torch.stack([samples['mass_1'], samples['mass_2']]).sort(dim=0).values

        return samples

In [24]:
waveforms = WaveformGenerator(
    inclination=False,
    spins_aligned=True,
    mass_ratio=False,
    extrinsic_at_train=False,
    domain='RB',
)

### Waveform Data Generation

In [11]:
class WaveformDataModule(pl.LightningDataModule):
    
    def __init__(
        self,
        batch_size: int=512,
        data_dir: Union[Path, str]=None,
        detectors: List[str]=['H1','L1','V1'],
        domain: str='RB',  # only configured for RB basis
        extrinsic_at_train: bool=False,
    ):
        """A PyTorch LightningDataModule used to manage waveform datasets."""
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = Path(data_dir)
        self.detectors = {ifo: Detector(ifo) for ifo in detectors}
        
        self.basis = None
        
    def prepare_data(self, load: bool=True):
        # only called on one GPU/TPU in distributed training (the "head" node?)
        # prepared waveform dataset takes n (a long time) hours and is 8.5GB - how to handle this
        assert data_dir.exists(), f'Provided data_dir {data_dir} is not a valid directory.'
        
#         if self.domain == 'RB' and self.basis is None:
#             self.generate_reduced_basis()
    
    def setup(self):
        # make assignments here (val/train/test split) - called on every process in DDP
        pass
    
    def train_dataloader(self):
#         train_split = DataSet(...)
#         return DataLoader(train_split)
        pass

    def val_dataloader(self):
#         val_split = DataSet(...)
#         return DataLoader(val_split)
        pass

    def test_dataloader(self):
#         test_split = DataSet(...)
#         return DataLoader(test_split)
        pass

    def teardown(self):
        # clean up after fit or test - called on every process in DDP
        pass

In [12]:
device = torch.device('cpu')

In [13]:
approximant = 'IMRPhenomPv2' # LAL refers to approximants by idx from GetApproximantFromString(approximant)

In [14]:
# load (and overwrite) data
with open(Path('data/events/GW150914') / 'event_info.json', 'r') as f:
    event_info = json.load(f)
    
# waveform settings
with open(Path('waveforms') / 'GW150914' / 'settings.json', 'r') as f:
    settings = json.load(f)

In [15]:
# psd_names = dict(
#     H1='aLIGODesignSensitivityP1200087',
#     L1='aLIGODesignSensitivityP1200087',
#     V1='AdVDesignSensitivityP1200087',
#     ref='aLIGODesignSensitivityP1200087'
# )

# psd_dict = dict(
#     H1={},
#     L1={},
#     V1={},
#     ref={}
# )

In [16]:
# input data (event and data store)
event = 'GW150914'  # gravitational wave event label
data_dir = Path('data')  # data store


In [17]:
# specify directories
events_dir = data_dir / 'events'  # gravitational wave events
waveforms_dir = data_dir / 'waveforms'  # generated waveforms
event_dir = events_dir / event  # specific event data

# Load event info
with open(event_dir / 'event_info.json', 'r') as f:
    event_info = json.load(f)

if 'f_min_psd' not in event_info:
    event_info['f_min_psd'] = event_info['f_min']

# rename keys with pointers - need to unify names between saving files and attributes
event_info['duration'] = event_info['T']
event_info['ref_time'] = event_info['t_event']

detectors = {ifo: Detector(ifo) for ifo in event_info['detectors']}

psd = {}
psd_names = {}
for ifo in detectors:
    psd[ifo] = {}
    psd_names[ifo] = f'PSD_{ifo}'
psd['ref'] = {}
psd_names['ref'] = psd_names['H1']  # this is hardcoded by author -- reference ifo must be Hanford

In [18]:
# do we want these as entries in a dictionary? standalone variables? computed on the fly?
event_info['sampling_rate'] = 2*event_info['f_max']
event_info['delta_t'] = 1.0 / event_info['sampling_rate']
event_info['delta_f'] = (1.0 / event_info['duration'])
event_info['Nf'] = int(event_info['f_max'] / event_info['delta_f']) + 1

In [19]:
# psd_name=f'PSD_H1'
# delta_f=event_info['delta_f']
# f_min=event_info['f_min_psd']
# f_max=event_info['f_max']
# psd_length = int(f_max / delta_f) + 1

In [20]:
waveforms = WaveformGenerator(inclination=False, spins_aligned=True, mass_ratio=False, domain='RB', extrinsic_at_train=False)

# manually specified by author
waveforms.priors['distance'] = [100.0, 1000.0]
waveforms.priors['a_1'][1] = 0.88
waveforms.priors['a_2'][1] = 0.88

prior_df = pd.DataFrame([
    {'name': key, 'lower': value[0], 'upper': value[1]}
    for key, value in waveforms.priors.items()
])

# prior_df.insert(1, 'distribution', [
#     'cos' if key in ('theta_jn', 'tilt_1', 'tilt_2')
#     else 'sin' if key == 'dec'
#     else 'uniform'
#     for key in priors.name
# ])

prior_df.insert(1, 'latex', list(waveforms.latex.values()))

prior_df['transform'] = [
    'pow(3)' if param == 'distance'
    else 'sin' if param == 'dec'
    else 'cos' if param in ['theta_jn', 'tilt_1', 'tilt_2']
    else 'identity'
    for param in prior_df.name
]

prior_df['inverse'] = [
    'pow(1/3)' if param == 'distance'
    else 'arcsin' if param == 'dec'
    else 'arccos' if param in ['theta_jn', 'tilt_1', 'tilt_2']
    else 'identity'
    for param in prior_df.name
]

prior_df

Unnamed: 0,name,latex,lower,upper,transform,inverse
0,mass_1,$m_1$,10.0,80.0,identity,identity
1,mass_2,$m_2$,10.0,80.0,identity,identity
2,M,$M$,25.0,100.0,identity,identity
3,q,$q$,0.125,1.0,identity,identity
4,phase,$\phi_c$,0.0,6.283185,identity,identity
5,time,$t_c$,-0.1,0.1,identity,identity
6,distance,$d_L$,100.0,1000.0,pow(3),pow(1/3)
7,chi_1,$\chi_1$,-1.0,1.0,identity,identity
8,chi_2,$\chi_2$,-1.0,1.0,identity,identity
9,a_1,$a_1$,0.0,0.88,identity,identity


In [21]:
def load_event(event_dir: Union[Path, str]):
    event_dir = Path(event_dir)

    # Load event info
    with open(event_dir / 'event_info.json', 'r') as f:
        event_info = json.load(f)
        event_info['f_min_psd'] = event_info['f_min_psd'] or event_info['f_min']
        
        self.event = event_info['event']
        self.f_min = event_info['f_min']
        self.f_min_psd = event_info['f_min']  # copy f_min
        self.f_max = event_info['f_max']
        self.time_duration = event_info['T']
        self.ref_time = event_info['t_event']
        self.window_factor = event_info['window_factor']
        ifo_list = event_info['detectors']

    # Initialize detectors
    self.init_detectors(detectors)
    
    detectors = {ifo: Detector(ifo) for ifo in ifo_list}
    detectors['ref'] = detectors['H1']  # author uses this "ref" detector and is hard coded as H1

    # Set up PSD
    self.psd = {}
    self.psd_names = {}
    for ifo in detectors:
        self.psd[ifo] = {}
        self.psd_names[ifo] = 'PSD_{}'.format(ifo)
    self.psd['ref'] = {}
    self.psd_names['ref'] = self.psd_names['H1']
    
try:
    event = settings['event']
    event_dir = settings['event_dir']
    if event_dir != 'None':
        event_dir = Path(event_dir)
    load_event(event_dir)
except:
    event = None
    event_dir = None

In [22]:
# Load event info (default settings)
# time_duration = 4.0  # seconds
# f_min = 8.0 if waveforms.domain == 'TD' else 20.0  # Hz
# ifo_list = ['H1','L1','V1']
# f_ref = 20.0  # frequency at which source frame spin parameters are defined
    
# why is this data split across two JSONs like this? data integrity??
time_duration = event_info['T']
f_max = event_info['f_max']

event = settings['event']
event_dir = Path(settings['event_dir'])

ifo_list = settings['detectors']
f_min = settings['f_min']
f_min_psd = settings['f_min_psd']
f_ref = settings['f_ref']
ref_time = settings['ref_time']

sampling_rate = 2*f_max
delta_t = 1.0 / sampling_rate
delta_f = (1.0 / time_duration)
Nf = int(f_max / delta_f) + 1

# methods used with lru_cache
sample_frequencies = torch.linspace(0.0, f_max, steps=Nf)
frequency_mask = sample_frequencies >= f_min

In [23]:
# def generate_reduced_basis(self, n_train=10000, n_test=10000):
n_train=50000 # 100000 # #228885 #2000000 #50000
n_test=1000
# detectors = {ifo: Detector(ifo) for ifo in ifo_list}

In [24]:
data = {key: value.numpy() for key, value in waveforms._sample_prior(n_train).items()}

# To generate reduced basis, fix all waveforms to same fiducial distance.  need justification here imo
data['distance'] = np.ones_like(data['distance'])*settings['fiducial_params']['distance']

# convert into 'records' format as list of dictionary with each key a parameter
samples = [
    {key: value for key, value in zip(data.keys(), sample)}
    for sample in np.array(tuple(data.values())).T
]

# df = pd.DataFrame(samples)
# fig = px.density_contour(df, x='mass_1', y='mass_2', marginal_x='histogram', marginal_y='histogram')
# fig.show()

In [25]:
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

def parallel_process(function, array, n_jobs: int=16, use_kwargs: bool=False, front_num: int=3):
    """
        A parallel version of the map function with a progress bar. 

        Args:
            array (array-like): An array to iterate over.
            function (function): A python function to apply to the elements of array
            n_jobs (int, default=16): The number of cores to use
            use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of 
                keyword arguments to function 
            front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job. 
                Useful for catching bugs
        Returns:
            [function(array[0]), function(array[1]), ...]
    """
    #We run the first few iterations serially to catch bugs
    if front_num > 0:
        front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]]
        
    #If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
    if n_jobs==1:
        return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
    
    #Assemble the workers
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        #Pass the elements of array into function
        if use_kwargs:
            futures = [pool.submit(function, **a) for a in array[front_num:]]
        else:
            futures = [pool.submit(function, a) for a in array[front_num:]]
            
        kwargs = {
            'total': len(futures),
            'unit': 'it',
            'unit_scale': True,
            'leave': True
        }
        
        #Print out the progress as tasks complete
        for f in tqdm(as_completed(futures), **kwargs):
            pass
        
    out = []
    
    #Get the results from the futures. 
    for i, future in tqdm(enumerate(futures)):
        try:
            out.append(future.result())
        except Exception as e:
            out.append(e)
            
    return front + out

# results = parallel_process(
#     function=generate_waveforms,
#     array=samples,
#     use_kwargs=False,
#     n_jobs=num_workers,
# )

In [26]:
# setup 
num_workers = multiprocessing.cpu_count()-1
# num_workers = 2

generate_waveforms = partial(
    generate_whitened_waveform,
    inclination=waveforms.inclination,
    spins=waveforms.spins,
    spins_aligned=waveforms.spins_aligned,
    domain=waveforms.domain,
    intrinsic_only=True,
)

print('waveform dataset configuration:')
print(f'inclination: {waveforms.inclination}')
print(f'spins: {waveforms.spins}')
print(f'spins_aligned: {waveforms.spins_aligned}')
print(f'domain: {waveforms.domain}')
print(f'num_workers: {num_workers}')

waveform dataset configuration:
inclination: False
spins: True
spins_aligned: True
domain: RB
num_workers: 23


In [27]:
detector_waveforms = {
    ifo: np.empty((len(samples), Nf), dtype=np.complex64)
    for ifo in detectors.keys()
}

In [28]:
# process data to be used for generating reduced basis via SVD
with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
    training_array = np.concatenate(list(tqdm(
        executor.map(generate_waveforms, samples, chunksize=75),
        total=len(samples),
        desc=f'Generating whitened waveforms on {num_workers} processes',
    )))
    
for key in parameters:
    parameters[key] = parameters[key].astype(np.float32)

Generating whitened waveforms on 23 processes:   3%|▎         | 1725/50000 [00:45<21:21, 37.67it/s] 


KeyboardInterrupt: 

In [None]:
Nrb = 200  # hardcoded number of reduced basis elements as per author
U, s, Vh = randomized_svd(training_array, Nrb)

In [34]:
from sklearn.utils.extmath import randomized_svd
import scipy

In [32]:
%%timeit -r1 -n1
Nrb = 200  # hardcoded number of reduced basis elements as per author
U, s, Vh = randomized_svd(training_array, Nrb)

3min 34s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [35]:
%%timeit -r1 -n1
U, s, Vh = scipy.linalg.svd(training_array, full_matrices=False)
# V = Vh.T.conj()

# if (n == 0) or (n > len(V)):
#     self.V = V
#     self.Vh = Vh
# else:
#     self.V = V[:, :n]
#     self.Vh = Vh[:n, :]

# self.n = len(self.Vh)

8min 47s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [36]:
training_array.shape

(100000, 8193)

In [42]:
parameters

{'mass_1': array([10.947892, 15.541207, 63.409836, ..., 30.754145, 32.50648 ,
        33.672577], dtype=float32),
 'mass_2': array([36.354824, 79.62963 , 68.202065, ..., 38.730396, 55.61173 ,
        63.18906 ], dtype=float32),
 'phase': array([3.5613995, 6.25151  , 3.7510014, ..., 4.346566 , 1.8440422,
        5.7286015], dtype=float32),
 'time': array([-0.09021521,  0.08645985, -0.0951781 , ..., -0.00942908,
         0.01464454, -0.0817539 ], dtype=float32),
 'distance': array([1000., 1000., 1000., ..., 1000., 1000., 1000.], dtype=float32),
 'chi_1': array([-0.81193763, -0.4034527 ,  0.67301285, ...,  0.18957698,
         0.6633359 ,  0.06395798], dtype=float32),
 'chi_2': array([ 0.292476  , -0.04284127, -0.32827464, ..., -0.9876351 ,
         0.7571041 , -0.14255877], dtype=float32),
 'ra': array([1.2162216 , 0.20546114, 5.0787644 , ..., 1.6396039 , 1.1205654 ,
        0.36769477], dtype=float32),
 'dec': array([-0.5135797 ,  0.706079  ,  0.2885482 , ..., -1.1187202 ,
         0.22

In [38]:
32772000000*1e-9*10/25

13.1088

In [34]:
arrays.nbytes

26217600000

In [36]:
26217600000*1e-9

26.2176

In [42]:
import gc
gc.collect()

30

In [30]:
results[0]['H1'].shape

(8193,)

In [32]:
# chunksizes = [50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150]
# runtimes = []

# with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
#     for chunksize in chunksizes:
#         start = time.perf_counter()
#         results = list(tqdm(
#             executor.map(generate_waveforms, samples, chunksize=chunksize),
#             total=len(samples),
#             desc=f'Generating whitened waveforms utilising {num_workers} workers (chunksize={chunksize})',
#         ))
#         finish = time.perf_counter()
#         runtimes.append(round(finish - start, 6))
    

Generating whitened waveforms utilising 23 workers (chunksize=50): 100%|██████████| 50000/50000 [00:25<00:00, 1975.63it/s]
Generating whitened waveforms utilising 23 workers (chunksize=60): 100%|██████████| 50000/50000 [00:26<00:00, 1893.70it/s]
Generating whitened waveforms utilising 23 workers (chunksize=70): 100%|██████████| 50000/50000 [00:22<00:00, 2206.05it/s]
Generating whitened waveforms utilising 23 workers (chunksize=80): 100%|██████████| 50000/50000 [00:22<00:00, 2204.33it/s]
Generating whitened waveforms utilising 23 workers (chunksize=90): 100%|██████████| 50000/50000 [00:23<00:00, 2142.19it/s]
Generating whitened waveforms utilising 23 workers (chunksize=100): 100%|██████████| 50000/50000 [00:23<00:00, 2170.43it/s]
Generating whitened waveforms utilising 23 workers (chunksize=110): 100%|██████████| 50000/50000 [00:23<00:00, 2120.07it/s]
Generating whitened waveforms utilising 23 workers (chunksize=120): 100%|██████████| 50000/50000 [00:22<00:00, 2199.76it/s]
Generating wh

In [28]:
# for sample in tqdm(samples):
#     mass_1 = sample['mass_1']
#     mass_2 = sample['mass_2']
#     phase = sample['phase']
#     coalesce_time = sample['time']
#     distance = sample['distance']
#     ra = sample['ra']
#     dec = sample['dec']

#     # Convert from source frame to Cartesian parameters; Optional parameters have default values
#     if prior_data.inclination:
#         theta_jn = sample['theta_jn']
#         psi = sample['psi']
#     else:
#         theta_jn = 0.0
#         psi = 0.0

#     if prior_data.spins:
#         if prior_data.spins_aligned:
#             spin_1x, spin_1y, spin_1z = 0., 0., sample['chi_1']
#             spin_2x, spin_2y, spin_2z = 0., 0., sample['chi_2']
#             iota = theta_jn
#         else:
#             a_1, a_2 = sample['a_1'], sample['a_2']
#             tilt_1, tilt_2 = sample['tilt_1'], sample['tilt_2']
#             phi_jl, phi_12 = sample['phi_jl'], sample['phi_12']

#             # use bilby/LAL to simulate an inspiral given intrinsic parameters
#             (iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z) = source_frame_to_radiation(
#                 theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, mass_2, f_ref, phase
#             )
#     else:
#         spin_1x = 0.0
#         spin_1y = 0.0
#         spin_1z = 0.0
#         spin_2x = 0.0
#         spin_2y = 0.0
#         spin_2z = 0.0
#         iota = theta_jn
        

#     if prior_data.domain == 'TD':
#         # "Start with a TD waveform generated from pycbc. If the
#         # approximant is in FD, then this suitably tapers the low
#         # frequencies in order to have a finite-length TD waveform
#         # without wraparound effects. If we started with an FD
#         # waveform, then we would have to do these manipulations
#         # ourselves" -- Stephen Green.

#         # Make sure f_min is low enough
#         if (
#             time_duration > get_waveform_filter_length_in_time(
#                 mass1=mass_1, mass2=mass_2,
#                 spin1x=spin_1x, spin2x=spin_2x,
#                 spin1y=spin_1y, spin2y=spin_2y,
#                 spin1z=spin_1z, spin2z=spin_2z,
#                 inclination=iota,
#                 f_lower=f_min,
#                 f_ref=f_ref,
#                 approximant=approximant
#             )
#         ):
#             print('Warning: f_min not low enough for given waveform duration')

#         hp_TD, hc_TD = get_td_waveform(
#             mass1=mass_1, mass2=mass_2,
#             spin1x=spin_1x, spin2x=spin_2x,
#             spin1y=spin_1y, spin2y=spin_2y,
#             spin1z=spin_1z, spin2z=spin_2z,
#             distance=distance,
#             coa_phase=phase,
#             inclination=iota,  # CHECK THIS!!!
#             delta_t=delta_t,
#             f_lower=f_min,
#             f_ref=f_ref,
#             approximant=approximant
#         )
#         hp = hp_TD.to_frequencyseries()
#         hc = hc_TD.to_frequencyseries()

#     elif prior_data.domain in ('FD', 'RB'):
#             # LAL refers to approximants by an index 
#         if is_fd_waveform(approximant):  # bool(SimInspiralImplementedFDApproximants(GetApproximantFromString(approximant)))
#             # Use the pycbc waveform generator; change this later (says Author)
#             # returns plus and cross phases of the waveform in frequency domain
#             hp, hc = get_fd_waveform(
#                 mass1=mass_1, mass2=mass_2,
#                 spin1x=spin_1x, spin2x=spin_2x,
#                 spin1y=spin_1y, spin2y=spin_2y,
#                 spin1z=spin_1z, spin2z=spin_2z,
#                 distance=distance,
#                 coa_phase=phase,
#                 inclination=iota,
#                 f_lower=f_min,
#                 f_final=f_max,
#                 delta_f=delta_f,
#                 f_ref=f_ref,
#                 approximant=approximant,
#             )
#         else:
#             # "Use SimInspiralFD. This converts automatically
#             # from the TD to FD waveform, but it requires a timeshift to be
#             # applied. Approach mimics bilby treatment." - Stephen Green

#             # Require SI units
#             mass_1_SI = mass_1 * MSUN_SI
#             mass_2_SI = mass_2 * MSUN_SI
#             distance_SI = distance * PC_SI * 1e6

#             lal_approximant = GetApproximantFromString(approximant)

#             h_p, h_c = SimInspiralFD(
#                 mass_1_SI, mass_2_SI,
#                 spin_1x, spin_1y, spin_1z,
#                 spin_2x, spin_2y, spin_2z,
#                 distance_SI, iota, phase,
#                 0.0, 0.0, 0.0,  # should have keyword args here?
#                 delta_f, f_min, f_max,
#                 f_ref, None,
#                 lal_approximant,
#             )

#             # If f_max/delta_f is not a power of 2, SimInspiralFD increases f_max
#             # to make this a power of 2. Take only components running up to f_max.
#             hp = np.zeros_like(sample_frequencies, dtype=np.complex)
#             hc = np.zeros_like(sample_frequencies, dtype=np.complex)
#             hp[:] = h_p.data.data[:len(hp)]
#             hc[:] = h_c.data.data[:len(hp)]

#             # Zero the strain for frequencies below f_min
#             hp *= frequency_mask.numpy()
#             hc *= frequency_mask.numpy()

#             # SimInspiralFD sets the merger time so the waveform can be
#             # transformed to TD without wrapping the end of the waveform to
#             # the beginning. Bring the time of coalescence to 0.
#             dt = 1. / delta_f + (h_p.epoch.gpsSeconds + h_p.epoch.gpsNanoSeconds * 1e-9)
#             hp *= np.exp(- 1j * 2 * np.pi * dt * sample_frequencies)
#             hc *= np.exp(- 1j * 2 * np.pi * dt * sample_frequencies)

#             # Convert to pycbc frequencyseries. Later, get rid of pycbc functions.
#             hp = FrequencySeries(hp, delta_f=delta_f, epoch=-time_duration)
#             hc = FrequencySeries(hc, delta_f=delta_f, epoch=-time_duration)

#     intrinsic_only=False
#     if intrinsic_only:  # don't have specific detector data at this point, it would seem
#         # Whiten with reference noise PSD and return hp, hc
#         hp = hp / (_get_psd(detectors['ref'], psd, hp.delta_f, f_max, f_min_psd, event_dir) ** 0.5)
#         hc = hc / (_get_psd(detectors['ref'], psd, hc.delta_f, f_max, f_min_psd, event_dir) ** 0.5)

#         # Convert to TD if necessary, ensure correct length
#         if self.domain == 'TD':
#             hp = hp.to_timeseries().time_slice(-time_duration, 0.0)
#             hc = hc.to_timeseries().time_slice(-time_duration, 0.0)

#         out = (hp.data, hc.data)

#     else:
#         # Project waveform onto detectors
#         h_d_dict = {}
#         for ifo, detector in detectors.items():

#             # Project onto antenna pattern
#             fp, fc = detector.antenna_pattern(ra, dec, psi, ref_time)  # fp and fc are plus and cross polarisations
#             h_d = fp * hp + fc * hc  # transform each plus/cross phase according to antenna pattern function

#             # Apply time delay relative to Earth center
#             dt = detector.time_delay_from_earth_center(ra, dec, ref_time)  # should ref_time be coalesce_time?
#             time_d = coalesce_time + dt

#             # Author's Notes: Merger is currently at time 0. Shift it.
#             # NOT SURE NEXT LINE IS RIGHT / NEEDED. COMMENTED.
#             # time_shift = - (self.time_duration - time_d)
#             time_shift = time_d
#             h_d = h_d.cyclic_time_shift(time_shift)
#             h_d.start_time = h_d.start_time + time_shift

#             # whiten
#             h_d = h_d / (_get_psd(ifo, psd, h_d.delta_f, f_max, f_min_psd, event_dir) ** 0.5)

#             # Convert to TD if necessary, and ensure waveform is of correct length
#             if prior_data.domain == 'TD':
#                 h_d = h_d.to_timeseries().time_slice( -time_duration, 0.0)

#             h_d_dict[ifo] = h_d.data

#         out = h_d_dict

In [None]:
import inspect
inspect.getmembers(d)

In [None]:
d.time_delay_from_earth_center

In [None]:
d.antenna_pattern

In [None]:
from gwpy.timeseries import TimeSeries
from gwpy.frequencyseries import FrequencySeries

In [None]:
gwpy_h_d = FrequencySeries.from_pycbc(h_d)#.abs()
gwpy_h_d.plot()

In [None]:
gwpy_h_d_td = TimeSeries.from_pycbc(h_d.to_timeseries())
gwpy_h_d_td.plot()

In [None]:
10*10

In [None]:
2.23*(10e-3)*(10e6)/60/60

In [None]:
key not in psd_dict[ifo]

In [None]:
detectors['ref']

In [None]:
# instantiate arrays to store data - but we should never really need to do this, right?
# h_detector = {ifo: torch.zeros((n_train, Nf), dtype=torch.complex64) for ifo in detectors}  # ???

In [None]:
# class FlowModel(pl.LightningModule):
    
#     def __init__(
#         self,
#         model_dir: Union[Path, str]=None,
#         data_dir: Union[Path, str]=None,
#         device: Union[torch.device, str]='cuda',
#     ):
#         """A FlowModel is a PyTorch LightningModule that we use for gravitational wave event modelling."""
#         super().__init__()
#         self.device = torch.device(device)
#         self.data_dir = data_dir
#         self.model_dir = model_dir
        

In [None]:
# x1 = np.random.randn(200)
# x2 = np.random.randn(200) + 2

# group_labels = ['Group 1', 'Group 2']

# colors = ['slategray', 'magenta']

# # Create distplot with curve_type set to 'normal'
# fig = ff.create_distplot(
#     [x1, x2],
#     group_labels,
#     bin_size=.5,
#     show_hist=False,
#     show_rug=False,
#     curve_type='kde', # override default 'kde'
#     colors=colors
# )

# # Add title
# fig.update_layout(title_text='Distplot with Normal Distribution')
# fig.show()

In [None]:
# 2d plotly density chart
# fig = px.density_contour(samples_df, x='mass_1', y='mass_2', marginal_x='histogram', marginal_y='histogram')
# fig.show()