In [None]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm
pd.options.display.max_columns = 100

#from skimage.filters import difference_of_gaussians
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.metrics import f1_score
import random
import time

# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
# from torch.cuda.amp import autocast, GradScaler

import librosa
import librosa.display

from scipy.special import logit, expit

import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
%matplotlib inline

import cv2
from scipy.interpolate import interp1d
import pywt

def sigmoid(var):
    """Element-wise logistic function 1 / (1 + exp(-var)).

    Works on scalars and array-likes: a scalar input returns a NumPy scalar,
    and an array input returns an array of the same shape.

    The naive formula overflows ``np.exp(-var)`` for large negative ``var``
    (about ``var < -709`` in float64). That triggers a RuntimeWarning, or an
    error under ``np.errstate(over='raise')``. This version only ever
    evaluates ``exp(-|var|)``, which lies in (0, 1] and cannot overflow.
    """
    x = np.asarray(var)
    z = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^-x)
    # x <  0: e^x / (1 + e^x)
    # Both forms are algebraically equal to the logistic function.
    out = np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
    return out[()]  # unwrap 0-d arrays back to a scalar

In [None]:
!pip install pycbc
import pycbc
from pycbc.waveform import td_approximants, fd_approximants, get_td_waveform
from pycbc.detector import Detector

In [None]:
# Time-domain approximant names this pycbc installation can generate
available_td = td_approximants()
print(available_td)

In [None]:
# Frequency-domain approximant names this pycbc installation can generate
available_fd = fd_approximants()
print(available_fd)

In [None]:
# Waveform approximants sampled (uniformly) by the generation loop.
# Other SEOBNR variants that were considered and can be added back here:
# 'SEOBNRv2_opt', 'SEOBNRv4_opt', 'SEOBNRv2T', 'SEOBNRv4T'.
gwlist = ['SEOBNRv2', 'SEOBNRv4']

In [None]:
# Seed NumPy's legacy global RNG so the np.random.* draws below are reproducible
np.random.seed(1)

# The three detectors (H1, L1, V1) each simulated waveform is projected onto,
# constructed in that order
det_h1, det_l1, det_v1 = (Detector(ifo) for ifo in ('H1', 'L1', 'V1'))

# --- Generation settings -----------------------------------------------------
SAMPLE_RATE = 4096          # Hz; waveform delta_t and the CWT sampling period both derive from this
WINDOW_LEN = 4096           # samples kept per example
TAPER_LEN = 80              # samples faded in at the start of the retained signal
STRAIN_SCALE = 1e19         # rescales the very small strain values before plotting / CWT
N_EXAMPLES = 100
CWT_SCALES = np.arange(1, 95, 0.62)
CWT_WAVELET = 'cmor1.5-0.95'
GPS_REF = 1192529720        # base GPS time; the end time is a random offset from it


def _fit_to_window(strain, window_len=WINDOW_LEN, taper_len=TAPER_LEN):
    """Return a copy of ``strain`` cropped or zero-padded to ``window_len`` samples.

    ``strain`` is 2-D: one row per detector, samples along axis 1.

    Longer inputs keep their last ``window_len`` samples. Shorter inputs are
    left-padded with zeros so their final sample stays at the right edge.

    In both cases the first ``taper_len`` retained samples are multiplied by
    a linear 0 -> 1 ramp. The abrupt onset (the crop point, or the start of
    the waveform) then fades in instead of producing a step. Inputs shorter
    than ``taper_len`` are tapered over their full length.
    """
    n_samples = strain.shape[1]
    if n_samples > window_len:
        strain = strain[:, n_samples - window_len:]
    strain = np.array(strain, dtype=float)  # copy: never mutate the caller's array
    n_taper = min(taper_len, strain.shape[1])
    if n_taper:
        strain[:, :n_taper] *= np.arange(n_taper) / n_taper
    if strain.shape[1] < window_len:
        strain = np.pad(strain, ((0, 0), (window_len - strain.shape[1], 0)))
    return strain


def _normalized_log_cwt(strain, sampling_period=1.0 / SAMPLE_RATE):
    """Log-magnitude CWT of each detector row, min-max scaled to [0, 1].

    Returns ``(image, freqs)``. ``image`` has shape
    ``(len(CWT_SCALES), n_samples, n_rows)``, i.e. one colour channel per
    detector row. ``freqs`` are the pywt frequencies for ``sampling_period``.
    """
    coeffs, freqs = pywt.cwt(strain, scales=CWT_SCALES, wavelet=CWT_WAVELET,
                             sampling_period=sampling_period, method='fft')
    image = np.log1p(np.abs(coeffs.transpose(0, 2, 1)))
    image -= image.min()
    peak = image.max()
    if peak > 0:  # an all-zero input would otherwise give 0/0 = NaN
        image /= peak
    return image, freqs


for n in range(N_EXAMPLES):
    # Draw source parameters. Keyword arguments are evaluated left to right,
    # so the sequence of seeded RNG draws is unchanged.
    gwapprox = np.random.choice(gwlist)
    hp, hc = get_td_waveform(approximant=gwapprox,
                             mass1=16 + np.random.randint(0, 10),
                             mass2=16 + np.random.randint(0, 10),
                             delta_t=1.0 / SAMPLE_RATE,
                             spin1z=0.5 + np.random.rand() * 0.5,
                             spin2z=0.25 + np.random.rand() * 0.5,
                             inclination=2 * np.pi * np.random.rand(),
                             coa_phase=2 * np.pi * np.random.rand(),
                             phase_order=np.random.randint(2, 8),
                             f_lower=np.random.randint(24, 64),
                             distance=int(np.random.randint(1, 1000)),
                             )

    # GPS end time, sky location and polarization angle.
    # Right ascension and polarization are drawn from [0, 2*pi).
    # Declination is drawn from [-pi/2, pi/2).
    end_time = GPS_REF + np.random.randint(GPS_REF // 100000)
    declination = np.pi * np.random.rand() - np.pi / 2
    right_ascension = 2 * np.pi * np.random.rand()
    polarization = 2 * np.pi * np.random.rand()
    hp.start_time += end_time
    hc.start_time += end_time

    signal_h1 = det_h1.project_wave(hp, hc, right_ascension, declination, polarization)
    signal_l1 = det_l1.project_wave(hp, hc, right_ascension, declination, polarization)
    signal_v1 = det_v1.project_wave(hp, hc, right_ascension, declination, polarization)

    # The projections may differ in length, so trim to the shortest before
    # stacking. Crop/pad/taper decisions use the stacked length, not len(hp).
    minlen = min(len(signal_h1), len(signal_l1), len(signal_v1))
    data = np.stack((signal_h1[:minlen], signal_l1[:minlen], signal_v1[:minlen])) * STRAIN_SCALE
    data = _fit_to_window(data)
    assert data.shape == (3, WINDOW_LEN), data.shape

    fig, ax = plt.subplots()
    for row, ifo in zip(data, ('H1', 'L1', 'V1')):
        ax.plot(row, label=ifo)
    ax.set(title=gwapprox, xlabel='sample', ylabel=f'strain x {STRAIN_SCALE:g}')
    ax.legend()
    plt.show()
    plt.close(fig)

    cwt, freqs = _normalized_log_cwt(data)
    fig, ax = plt.subplots()
    ax.imshow(cv2.resize(cwt, (256, 256)))
    ax.set_title(gwapprox)
    plt.show()
    plt.close(fig)