# CBC waveform in the wavelet domain

In [None]:
import numpy as np
from scipy.signal.windows import tukey
import matplotlib.pyplot as plt
from scipy.signal import chirp, spectrogram
from pycbc.waveform import get_td_waveform
from pycbc.conversions import mass1_from_mchirp_q, mass2_from_mchirp_q
from pywavelet.waveform_generator.generators import FunctionalWaveformGenerator  


def cbc_waveform(mc, q=1, delta_t=1.0 / 4096, f_lower=20):
    """Generate the plus polarisation of an IMRPhenomD time-domain waveform.

    The chirp mass and mass ratio are converted to component masses with
    pycbc's ``mass1_from_mchirp_q`` / ``mass2_from_mchirp_q`` helpers and
    forwarded to ``get_td_waveform``.

    Parameters
    ----------
    mc
        Chirp mass handed to the pycbc conversion helpers.
    q
        Mass ratio handed to the pycbc conversion helpers.
    delta_t
        Forwarded to ``get_td_waveform`` as ``delta_t``.
    f_lower
        Forwarded to ``get_td_waveform`` as ``f_lower``.

    Returns
    -------
    tuple
        ``(sample_times, data)`` of the plus polarisation. The cross
        polarisation returned by pycbc is discarded.
    """
    component_masses = {
        "mass1": mass1_from_mchirp_q(mc, q),
        "mass2": mass2_from_mchirp_q(mc, q),
    }
    plus, _cross = get_td_waveform(
        approximant="IMRPhenomD",
        delta_t=delta_t,
        f_lower=f_lower,
        **component_masses,
    )
    return plus.sample_times, plus.data


def plot_wavelet_domain_signal(wavelet_data, time_grid, freq_grid, freq_range):
    """Draw the magnitude of a wavelet-domain array as an image.

    Parameters
    ----------
    wavelet_data
        2-D array of wavelet coefficients. It is rotated by 90 degrees with
        ``np.rot90`` before plotting -- presumably it is indexed
        (time, frequency); TODO confirm against the generator's output.
    time_grid
        Sequence whose first and last entries set the horizontal extent.
    freq_grid
        Sequence whose first and last entries set the vertical extent.
    freq_range
        Values unpacked into ``plt.ylim`` to restrict the visible band.

    Returns
    -------
    The newly created matplotlib figure.
    """
    fig = plt.figure()

    magnitude = np.abs(np.rot90(wavelet_data))
    t_start, t_end = time_grid[0], time_grid[-1]
    f_low, f_high = freq_grid[0], freq_grid[-1]
    plt.imshow(magnitude, aspect="auto", extent=[t_start, t_end, f_low, f_high])

    plt.colorbar().set_label("Wavelet Amplitude")

    plt.xlabel("Time (s)")
    plt.ylabel("Frequency (Hz)")
    plt.ylim(*freq_range)

    plt.tight_layout()
    return fig

def create_cbc_wavelet_waveform_generator():
    """Build a wavelet-domain generator for CBC waveforms plus its grids.

    A reference waveform (chirp mass 15, ``q=1``) is generated once to obtain
    the time grid; it must contain exactly ``Nf * Nt`` samples. The frequency
    grid is the one-sided FFT grid for that many samples at spacing ``dt``.

    Returns
    -------
    tuple
        ``(waveform_generator, t_vals, f_vals)`` where ``waveform_generator``
        is a ``FunctionalWaveformGenerator`` wrapping a function of ``mc``,
        ``t_vals`` are the reference waveform's sample times and ``f_vals``
        is an array of ``Nf * Nt // 2 + 1`` frequencies.

    Raises
    ------
    AssertionError
        If the reference waveform does not have ``Nf * Nt`` samples.
    """
    Nf, Nt = 64, 64
    mult = 16
    ND = Nf * Nt
    dt = 1 / 256
    fmin = 20

    def h_func(mc):
        # Use the same nominal dt and fmin as the reference call below.
        # dt is deliberately never rebound after this point: the closure
        # looks it up at call time, so rebinding it to a measured sample
        # spacing would silently change delta_t for every later waveform.
        return cbc_waveform(mc, q=1, delta_t=dt, f_lower=fmin)[1]

    t_vals = cbc_waveform(15, q=1, delta_t=dt, f_lower=fmin)[0]
    assert len(t_vals) == ND, (
        f"reference waveform has {len(t_vals)} samples, expected {ND}"
    )

    # The duration is ND * dt wherever the time axis origin lies. The
    # largest timestamp only approximates it when the series starts at 0
    # (and is then still one sample short), so derive the frequency grid
    # from the sample count and spacing instead.
    f_vals = np.fft.rfftfreq(ND, d=dt)

    waveform_generator = FunctionalWaveformGenerator(h_func, Nf=Nf, Nt=Nt, mult=mult)
    return waveform_generator, t_vals, f_vals

# Build the generator once, then draw one wavelet-domain figure per chirp mass.
wavelet_generator, ts, fs = create_cbc_wavelet_waveform_generator()
for mc in range(15, 50, 5):
    wavelet_matrix = wavelet_generator(mc=mc)
    # Only the 0-64 band of the frequency axis is shown.
    fig = plot_wavelet_domain_signal(wavelet_matrix, ts, fs, (0, 64))
    fig.suptitle(f"mc={mc}")
    fig.show()