In [None]:
import numpy as np 
import pandas as pd
from glob import glob
import matplotlib.pyplot as plt
import torch
import torchaudio
import librosa
import scipy

### Utils

In [None]:
# Sampling rate handed to the constant-Q transform as ``sr``.
# NOTE(review): presumably 2048 Hz for the G2Net strain series — confirm against the dataset description.
SR = 2 ** 11

In [None]:
def get_full_path(
    signal_file_name,
    root="/kaggle/input/g2net-gravitational-wave-detection/train",
):
    """Return the ``.npy`` path of one training sample.

    Samples are nested three directories deep, one level per leading character
    of the id: id ``"abc123"`` maps to ``{root}/a/b/c/abc123.npy``.

    Parameters
    ----------
    signal_file_name : str
        Sample id without extension; must have at least three characters.
    root : str
        Directory containing the nested folders. Defaults to the Kaggle train
        directory; override it to run the notebook elsewhere. A trailing slash
        is ignored.

    Returns
    -------
    str
        The full path to the sample's ``.npy`` file.

    Raises
    ------
    ValueError
        If ``signal_file_name`` has fewer than three characters (the nesting
        scheme needs three).
    """
    if len(signal_file_name) < 3:
        raise ValueError(
            f"sample id must have at least 3 characters, got {signal_file_name!r}"
        )
    first, second, third = signal_file_name[:3]
    base = root.rstrip("/")
    return f"{base}/{first}/{second}/{third}/{signal_file_name}.npy"

In [None]:
def get_signal(file_path):
    """Load and return the array stored in the file at ``file_path``.

    The file is opened in binary mode, parsed with ``np.load`` using its
    default arguments, and closed again before returning.
    """
    with open(file_path, mode="rb") as handle:
        loaded = np.load(handle)
    return loaded

In [None]:
def plot_transformation(
    negative_signal,
    positive_signal,
    transformation,
    title="CQT"
):
    """Show a 3x2 grid of transformed rows for one label-0 and one label-1 sample.

    Row ``i`` of the grid displays ``transformation`` applied to row ``i`` of
    each input. The left column is ``negative_signal`` (label 0) and the right
    column is ``positive_signal`` (label 1).

    Parameters
    ----------
    negative_signal, positive_signal : 2-D array
        Indexed as ``signal[i, :]`` for ``i`` in ``0..2``.
        NOTE(review): presumably one row per detector channel — confirm.
    transformation : callable
        Maps a 1-D row to a 2-D array that ``imshow`` can render.
    title : str
        Figure title.
    """
    fig, ax = plt.subplots(3, 2, figsize=(18, 10))
    fig.suptitle(f"{title}\n\n", fontsize=18)

    for row in range(3):
        ax[row, 0].imshow(transformation(negative_signal[row, :]), aspect='auto')
        ax[row, 1].imshow(transformation(positive_signal[row, :]), aspect='auto')

    for axis in ax.flat:
        axis.set(xlabel='time', ylabel='pseudo freq')

    ax[0, 0].set_title('label=0', fontsize=16)
    ax[0, 1].set_title('label=1', fontsize=16)

    # Lay out only once every image, axis label and title exists. Calling
    # tight_layout on the still-empty grid cannot account for them.
    fig.tight_layout()
    plt.show()

In [None]:
def plot_raw_signal(
    negative_signal,
    positive_signal
):
    """Overlay rows 0..2 of a label-0 and a label-1 sample on two side-by-side axes.

    Parameters
    ----------
    negative_signal, positive_signal : 2-D array
        Indexed as ``signal[i, :]`` for ``i`` in ``0..2``. Each row is drawn as
        one line.
        NOTE(review): presumably one row per detector channel — confirm.
    """
    fig, ax = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle("Raw signal\n\n", fontsize=18)

    panels = ((ax[0], negative_signal, 0), (ax[1], positive_signal, 1))
    for axis, signal, label in panels:
        for channel in range(3):
            axis.plot(signal[channel, :], label=f"channel {channel}")
        axis.set(xlabel='time', ylabel='amplitude')
        axis.set_title(f'label={label}', fontsize=16)
        # Three traces share each axis; without a legend they are indistinguishable.
        axis.legend(loc='upper right')

    fig.tight_layout()
    plt.show()

### Load labels and data paths

In [None]:
# Label table; the columns used below are ``id`` (sample file stem) and ``target``.
train = pd.read_csv("/kaggle/input/g2net-gravitational-wave-detection/training_labels.csv")

# .npy paths split by label; both Series keep the index of ``train``.
all_negatives = train.loc[train["target"] == 0, "id"].apply(get_full_path)
all_positives = train.loc[train["target"] == 1, "id"].apply(get_full_path)

In [None]:
# First sample of each class (label 0, then label 1) — the running example for the plots.
id0, id1 = (get_signal(paths.iloc[0]) for paths in (all_negatives, all_positives))

### Raw signal

In [None]:
plot_raw_signal(negative_signal=id0, positive_signal=id1)

### STFT

[The short-time Fourier transform (STFT) is a Fourier-related transform used to determine the sinusoidal frequency and phase content of local sections of a signal as it changes over time.](https://en.wikipedia.org/wiki/Short-time_Fourier_transform)

In [None]:
def get_stft(signal, n_fft=128):
    """Return the magnitude of the normalized short-time Fourier transform.

    Parameters
    ----------
    signal : 1-D array-like
        Converted to a float32 tensor before the transform.
    n_fft : int, default 128
        FFT size, also used as the Hann window length. The hop length is
        torch's default, ``n_fft // 4``.

    Returns
    -------
    numpy.ndarray
        Real, non-negative array of shape ``(n_fft // 2 + 1, n_frames)``:
        frequency bins along rows, frames along columns.
    """
    # Pass a window explicitly. Without one torch uses a rectangular window,
    # which leaks energy far across bins (recent torch versions warn about this).
    window = torch.hann_window(n_fft)
    spectrum = torch.stft(
        torch.as_tensor(signal, dtype=torch.float32),
        n_fft=n_fft,
        window=window,
        return_complex=True,
        normalized=True,
    )
    return spectrum.abs().numpy()

In [None]:
plot_transformation(negative_signal=id0, positive_signal=id1, transformation=get_stft, title="STFT")

### CQT

[In mathematics and signal processing, the constant-Q transform, simply known as CQT, transforms a data series to the frequency domain. It is related to the Fourier transform and very closely related to the complex Morlet wavelet transform.](https://en.wikipedia.org/wiki/Constant-Q_transform)

In [None]:
def get_cqt(signal, sr=None, hop_length=5 * 8, n_bins=40, bins_per_octave=12):
    """Return the magnitude of ``librosa.cqt`` applied to a 1-D signal.

    All keyword arguments default to the values previously hard-coded, so
    ``get_cqt(signal)`` is unchanged.

    Parameters
    ----------
    signal : 1-D numpy array
        Time series passed straight to ``librosa.cqt``.
    sr : int, optional
        Sampling rate. When omitted, the module-level ``SR`` is read at call time.
    hop_length : int, default 40
        Samples between successive CQT frames.
        NOTE(review): written as 5 * 8, presumably so it stays a multiple of
        2 ** (n_octaves - 1) as librosa's multirate CQT expects — confirm.
    n_bins : int, default 40
        Number of frequency bins.
    bins_per_octave : int, default 12
        Frequency resolution, in bins per octave.

    Returns
    -------
    numpy.ndarray
        Real, non-negative array with ``n_bins`` rows (frequency) and one column
        per frame.
    """
    if sr is None:
        sr = SR
    cqt = librosa.cqt(
        signal,
        sr=sr,
        hop_length=hop_length,
        n_bins=n_bins,
        bins_per_octave=bins_per_octave,
    )
    return np.abs(cqt)

In [None]:
plot_transformation(negative_signal=id0, positive_signal=id1, transformation=get_cqt, title="CQT")

### CWT Morlet

[In mathematics, the continuous wavelet transform (CWT) is a formal (i.e., non-numerical) tool that provides an overcomplete representation of a signal by letting the translation and scale parameter of the wavelets vary continuously](https://en.wikipedia.org/wiki/Continuous_wavelet_transform)

In [None]:
def get_cwt_morlet(signal, widths=None, w=5.0):
    """Return the magnitude of a complex-Morlet continuous wavelet transform.

    This follows the definition of ``scipy.signal.cwt`` used with
    ``scipy.signal.morlet2``, but is written in NumPy only.

    The previous implementation passed ``scipy.signal.morlet`` to ``cwt``.
    SciPy documents that pairing as incompatible: ``cwt`` feeds each width into
    ``morlet``'s ``w`` argument instead of its scale. In addition, ``cwt``,
    ``morlet`` and ``morlet2`` were removed in SciPy 1.15.

    Parameters
    ----------
    signal : 1-D array-like
        Input time series (must be non-empty).
    widths : sequence of positive numbers, optional
        Wavelet scales in samples. Defaults to ``np.arange(1, 10)``, matching
        the original 9 rows.
    w : float, default 5.0
        Morlet omega0 parameter. Scale ``s`` responds most strongly near
        ``w / (2 * pi * s)`` cycles per sample.

    Returns
    -------
    numpy.ndarray
        Real, non-negative array of shape ``(len(widths), len(signal))``.
    """
    data = np.asarray(signal)
    if widths is None:
        widths = np.arange(1, 10)
    n_samples = data.shape[0]
    scalogram = np.empty((len(widths), n_samples), dtype=np.complex128)
    for row, scale in enumerate(widths):
        # Same support as scipy's cwt: 10 scales wide, capped at the signal length.
        n_taps = int(np.ceil(min(10 * scale, n_samples)))
        x = (np.arange(n_taps) - (n_taps - 1.0) / 2) / scale
        wavelet = np.pi ** -0.25 * np.exp(1j * w * x) * np.exp(-0.5 * x ** 2) / np.sqrt(scale)
        # Convolving with the conjugated, reversed wavelet correlates the signal with it.
        scalogram[row] = np.convolve(data, np.conj(wavelet[::-1]), mode="same")
    return np.abs(scalogram)

In [None]:
plot_transformation(negative_signal=id0, positive_signal=id1, transformation=get_cwt_morlet, title="CWT Morlet")