In [None]:
# Upgrade torch before importing it.
# NOTE(review): unpinned upgrade — results can change between runs; consider pinning a version.
!pip install torch --upgrade --quiet

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
from tqdm.auto import tqdm
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

**There is one difference in the `fmin` value**. It was originally 20 Hz, and for this version I raised it slightly to **21.83 Hz**. Lowering the value slightly makes the bright parts of the image dimmer (I describe the effect in terms of the image rather than frequency because it is easier to visualize), while raising it slightly brightens the strongest part, and some of the background noise on the right-hand side of the picture also becomes visible. 

If you'd like to make your own dataset, consider tuning this value as you see fit; a brighter or dimmer setting may or may not work better. 

The second thing is `n_bins`. Setting it too high pushes the highest bins past the Nyquist limit, while setting it too low can darken some of the bright parts of the image. Consider tuning this as well, e.g. try changing it **from 55 to 63** to see the difference. 

Of course, none of this is confirmed. Some bright images will dim when `fmin` and/or `n_bins` is increased, so this requires some experimentation. 

In [None]:
import fastai
import torch
# Display the installed fastai version (fastai is not otherwise used in the visible cells).
fastai.__version__

In [None]:
import glob
import pathlib

head = pathlib.Path("../input/g2net-gravitational-wave-detection")

# Build the pattern from `head` so the dataset root is written in one place.
# Training samples sit three sub-directory levels below train/.
train_files = sorted(glob.glob(str(head / "train" / "*" / "*" / "*" / "*.npy")))

In [None]:
# Load the first training sample (in sorted path order) to experiment with below.
# NOTE(review): min_max_scaler, called on this array later, modifies it in place.
wave = np.load(train_files[0])

In [None]:
import librosa
import librosa.display
import matplotlib.pyplot as plt

In [None]:
from numba import njit, jit, cuda, guvectorize

@njit(nogil=True)
def min_max_scaler(wave):
    """Rescale every row of ``wave`` independently to the range [-1, 1], in place.

    Parameters
    ----------
    wave : 2-D float array; each row is scaled on its own
        (presumably (3, 4096) detector channels in this notebook — TODO confirm).

    Returns
    -------
    The same array object, modified in place. Pass a copy if the unscaled
    values are still needed.

    A row whose values are all equal has no range. It is mapped to 0.0
    (the midpoint of [-1, 1]) instead of the NaN a 0/0 division would give.
    """
    for i in range(len(wave)):
        row = wave[i]  # view into `wave`, so assigning to row[:] updates it
        lo = row.min()
        span = row.max() - lo  # compute min/max once instead of twice
        if span == 0:
            row[:] = 0.0
        else:
            row[:] = 2 * ((row - lo) / span) - 1

    return wave

In [None]:
# min_max_scaler rescales each row in place and returns the same array,
# so `wave1` and `wave` refer to the same (now scaled) data after this cell.
wave1 = min_max_scaler(wave)

In [None]:
# Plot each scaled channel. `wave1` is the array returned by min_max_scaler
# (the same object as `wave`), so non-constant rows lie in [-1, 1].
fig, ax = plt.subplots(dpi=120)
for i in range(len(wave1)):
    ax.plot(range(len(wave1[i])), wave1[i], label=f"channel {i}")
ax.set(xlabel="sample index", ylabel="scaled amplitude")
ax.legend();

# Bandpass filter

In [None]:
from scipy.signal import butter, filtfilt, sosfiltfilt
# from torchaudio.functional import bandpass_biquad

# Sampling setup for one 2-second segment.
T = 2  # sample period, s
fs = 2048.0  # sample rate, Hz
nyq = fs / 2  # Nyquist frequency, Hz
n = int(fs * T)  # number of samples in the segment

# Low-pass design: 3 sine waves over 2 s is 1.5 Hz, so cut off a little above it.
cutoff = 2.5  # desired cutoff frequency, Hz
order = 3  # Butterworth filter order
normal_cutoff = cutoff / nyq  # cutoff as a fraction of Nyquist

In [None]:
def butter_bandpass_filter_torch(data, lowcut, highcut, fs):
    """Band-pass ``data`` between ``lowcut`` and ``highcut`` Hz with torchaudio's biquad.

    Note: despite the name this is a single second-order (biquad) band-pass,
    not a Butterworth design.

    Args:
        data: waveform tensor (torchaudio filters along the last axis).
        lowcut: lower band edge in Hz.
        highcut: upper band edge in Hz (must be greater than ``lowcut``).
        fs: sample rate in Hz.

    Returns:
        Whatever ``bandpass_biquad`` returns (a tensor shaped like ``data``).
    """
    # Imported here because the module-level import is commented out; without
    # it this function raised NameError on every call.
    from torchaudio.functional import bandpass_biquad

    center = (highcut + lowcut) / 2
    # Biquad quality factor: Q = centre frequency / bandwidth. The previous
    # expression (highcut - lowcut) / (highcut + lowcut) equals 1 / (2 * Q).
    q = center / (highcut - lowcut)
    return bandpass_biquad(data, fs, center, Q=q)

In [None]:
# normal_cutoff = (21.83/fs, 500/fs)
# def butter_bandpass_filter(data, normal_cutoff, fs, order=2):
#     b, a = butter(order, normal_cutoff, btype="bandpass", analog=False)
#     y = filtfilt(b, a, data)
#     return y

In [None]:
def butter_bandpass_filter(data, low, high, fs, order):
    """Zero-phase Butterworth band-pass between ``low`` and ``high`` Hz.

    Args:
        data: signal to filter (filtered along the last axis).
        low: lower band edge in Hz.
        high: upper band edge in Hz.
        fs: sample rate in Hz.
        order: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal divided by sqrt(bandwidth / Nyquist).
    """
    nyquist = fs / 2
    # Divide by the square root of the kept fraction of the spectrum.
    gain = np.sqrt((high - low) / nyquist)
    sections = butter(order, [low, high], btype="bandpass", output="sos", fs=fs)
    filtered = sosfiltfilt(sections, data)
    return filtered / gain

In [None]:
def butter_lowpass_filter(data, normal_cutoff, fs, order):
    """Zero-phase Butterworth low-pass filter.

    Args:
        data: signal to filter (filtered along the last axis).
        normal_cutoff: cutoff as a fraction of the Nyquist frequency (0 < x < 1).
        fs: not used — the cutoff is already normalized.
        order: Butterworth filter order.

    Returns:
        The filtered signal as a NumPy array.
    """
    coefficients = butter(order, normal_cutoff, btype="lowpass", analog=False)  # (b, a)
    return filtfilt(*coefficients, data)

In [None]:
# y = min_max_scaler(butter_bandpass_filter(wave, normal_cutoff, fs, 3))
# Band-pass the sample between 21.83 Hz and 500 Hz with butter_bandpass_filter (order=4).
# NOTE(review): `data` is a torch tensor that scipy presumably converts to a NumPy
# array internally; passing `wave` directly would avoid the round trip — confirm.
data = torch.from_numpy(wave)
y = butter_bandpass_filter(data, 21.83, 500, fs, 4)

In [None]:
plt.figure(dpi=120)
# First band-passed channel against its sample index.
plt.plot(np.arange(len(y[0])), y[0])

In [None]:
# NOTE(review): identical to the previous cell (plots y[0] again); consider
# removing it or plotting a different channel such as y[1].
plt.figure(dpi=120)
plt.plot(range(len(wave[0])), y[0])

# Continuation

In [None]:
from scipy.signal import spectrogram

# One panel per channel: drawing every pcolormesh on the same axes only
# leaves the last channel visible.
fig, axes = plt.subplots(len(wave1), 1, dpi=120, sharex=True, squeeze=False)
for i, ax in enumerate(axes[:, 0]):
    # Use the notebook's sample rate (fs, Hz) so `t` is in seconds and `f` in Hz;
    # the previous fs=10 mislabelled both axes.
    f, t, Sxx = spectrogram(wave1[i], fs=fs)
    ax.pcolormesh(t, f, Sxx, shading="gouraud")
    ax.set_ylabel("Hz")
axes[-1, 0].set_xlabel("s");

In [None]:
# plt.figure(dpi=120)
# f, t, Sxx = spectrogram(wave1[0], fs=4096)
# plt.pcolormesh(t, fftshift(f), fftshift(Sxx), shading="gouraud")

In [None]:
def wrapper_plot(m):
    """Call the zero-argument plotting function ``m`` on a new 120-dpi figure, then show it."""
    _fig = plt.figure(dpi=120)  # fresh canvas for whatever `m` draws
    m()
    return plt.show()

In [None]:
# Log-mel spectrogram per channel, each transposed and stacked vertically into one image.
stacked = []
for j in range(len(wave1)):
    # `y=` keyword: newer librosa versions make melspectrogram's arguments keyword-only.
    # sr/fmax follow the notebook's sample rate `fs` (2048 Hz) instead of a
    # hard-coded 4096, which mislabelled every mel band by a factor of two.
    melspec = librosa.feature.melspectrogram(y=wave1[j], sr=fs, n_mels=128, fmin=21.83, fmax=fs / 2)
    melspec = librosa.power_to_db(melspec)
    stacked.append(melspec.T)
image = np.vstack(stacked)
wrapper_plot(lambda: plt.imshow(image))

In [None]:
# Smallest time-bin centre of `t`, the time axis returned by the spectrogram loop earlier in the notebook.
t.min()

## Finish playing
Now it is time to use the dataset created by Y. Nakama and continue. 

In [None]:
# X = np.load("../input/g2net-n-mels-128-train-images-aggregated/X.npy")

In [None]:
# y = np.load("../input/g2net-n-mels-128-train-images-aggregated/y.npy")

In [None]:
# X.shape

In [None]:
# Install nnAudio (unpinned — consider pinning a version for reproducibility).
# NOTE(review): the CWT class uses `nn`, which is not imported explicitly anywhere
# visible; it presumably comes from `from nnAudio.Spectrogram import *` — confirm.
!pip install -q nnAudio

In [None]:
from nnAudio.Spectrogram import *
import torch

In [None]:
# @njit(nogil=True)
# def min_max_scaler_hstack(wave):
#     for i in range(len(wave)):
#         wave[i] = (wave[i] - min(wave[i])) / (max(wave[i]) - min(wave[i]))
        
#     wave = np.hstack(wave)
#     return wave

In [None]:
import gc
# Run the garbage collector to free unreferenced objects before the heavier cells below.
gc.collect()
# import torch
# torch.cuda.empty_cache()

In [None]:
# Band edges 20 Hz and 500 Hz as fractions of the Nyquist frequency.
# NOTE(review): rebinds the low-pass `normal_cutoff`; no active code after this cell reads it.
normal_cutoff = (20/nyq, 500/nyq)

In [None]:
# scipy.signal.cwt / scipy.signal.ricker were deprecated in SciPy 1.12 and removed
# in 1.15. The active code does not call them (the CWT class has its own torch
# ricker), so fall back to None instead of failing the whole cell.
try:
    from scipy.signal import cwt, ricker
except ImportError:
    cwt = ricker = None

In [None]:
# Shape of the currently loaded sample array.
wave.shape

In [None]:
import time

In [None]:
# Taken from https://www.kaggle.com/anjum48/continuous-wavelet-transform-cwt-in-pytorch
# Import `nn` explicitly; otherwise it only exists if an earlier star import
# (such as `from nnAudio.Spectrogram import *`) happens to provide it.
from torch import nn


class CWT(nn.Module):
    def __init__(
        self,
        widths,
        wavelet="ricker",
        channels=1,
        filter_len=2000,
        bs=1,
    ):
        """PyTorch implementation of a continuous wavelet transform.

        Args:
            widths (iterable): The wavelet scales to use, e.g. np.arange(1, 33).
                Must contain at least one scale.
            wavelet (str, optional): Name of the wavelet method on this class:
                "ricker" or "morlet" (real-valued), or "cmorlet" (complex Morlet;
                forward returns its magnitude). Defaults to "ricker".
            channels (int, optional): Number of audio channels in the input.
                Defaults to 1.
            filter_len (int, optional): Size of the wavelet filter bank. Set to
                the number of samples but can be smaller to save memory.
                Defaults to 2000.
            bs (int, optional): Stored as ``self.bs``; not used by the transform.
                Defaults to 1.
        """
        super().__init__()
        # np.asarray first so lists/ranges work as well as numpy arrays.
        self.widths = torch.as_tensor(np.asarray(widths))
        self.wavelet = getattr(self, wavelet)
        self.filter_len = filter_len
        self.bs = bs
        self.channels = channels
        self.wavelet_bank = self._build_wavelet_bank()

    def ricker(self, points, a):
        """Ricker (Mexican hat) wavelet of width ``a`` on ``points`` centred samples."""
        # https://github.com/scipy/scipy/blob/v1.7.1/scipy/signal/wavelets.py#L262-L306
        a = torch.Tensor([a])
        A = 2 / (torch.sqrt(3 * a) * (np.pi ** 0.25))
        wsq = a ** 2
        vec = torch.arange(0, points) - (points - 1.0) / 2
        xsq = vec ** 2
        mod = 1 - xsq / wsq
        gauss = torch.exp(-xsq / (2 * wsq))
        total = A * mod * gauss
        return total

    def morlet(self, points, s):
        """Real Morlet wavelet at scale ``s`` on ``points`` centred samples."""
        s = torch.Tensor([s])
        x = torch.arange(0, points) - (points - 1.0) / 2
        x = x / s
        # https://pywavelets.readthedocs.io/en/latest/ref/cwt.html#morlet-wavelet
        wavelet = torch.exp(-(x ** 2.0) / 2.0) * torch.cos(5.0 * x)
        output = torch.sqrt(1 / s) * wavelet
        return output

    def cmorlet(self, points, s, wavelet_width=1, center_freq=1):
        """Complex Morlet wavelet at scale ``s``, returned as a (real, imag) pair of tensors."""
        # https://pywavelets.readthedocs.io/en/latest/ref/cwt.html#complex-morlet-wavelets
        s = torch.Tensor([s])
        x = torch.arange(0, points) - (points - 1.0) / 2
        x = x / s
        norm_constant = torch.sqrt(torch.Tensor([np.pi * wavelet_width]))
        exp_term = torch.exp(-(x ** 2) / wavelet_width)
        kernel_base = exp_term / norm_constant
        # Real and imaginary parts of kernel_base * exp(1j * 2*pi*center_freq*x),
        # kept as two real tensors.
        kernel_real = kernel_base * torch.cos(2 * np.pi * center_freq * x)
        kernel_imag = kernel_base * torch.sin(2 * np.pi * center_freq * x)
        return kernel_real, kernel_imag

    def _to_filter(self, kernels):
        """Stack 1-D kernels into a conv2d weight of shape (num_widths, 1, channels, filter_len)."""
        bank = torch.stack(kernels)
        bank = bank.view(bank.shape[0], 1, 1, bank.shape[1])
        return torch.cat([bank] * self.channels, 2)

    def _build_wavelet_bank(self):
        """Build the filter bank for every width.

        Returns a single tensor for real wavelets ("ricker", "morlet") or a
        (real, imag) tuple for "cmorlet"; each tensor has shape
        (num_widths, 1, channels, filter_len).

        The kernels are not flipped: conv2d computes cross-correlation, but
        ricker, morlet and the real cmorlet part are symmetric. Flipping would
        only negate the imaginary cmorlet part, which the magnitude ignores.
        """
        if len(self.widths) == 0:
            raise ValueError("widths must contain at least one scale")
        kernels = [self.wavelet(self.filter_len, w) for w in self.widths.tolist()]
        if isinstance(kernels[0], tuple):
            wavelet_bank_real = self._to_filter([k[0] for k in kernels])
            wavelet_bank_imag = self._to_filter([k[1] for k in kernels])
            return wavelet_bank_real, wavelet_bank_imag
        # Real wavelets return a single 1-D tensor per width; indexing it with
        # [0]/[1] (as before) took individual samples and broke the bank.
        return self._to_filter(kernels)

    def forward(self, x):
        """Compute CWT arrays from a batch of multi-channel inputs

        Args:
            x (torch.tensor): Tensor of shape (batch_size, channels, time)

        Returns:
            torch.tensor: Tensor of shape (batch_size, channels, widths, time);
            for "cmorlet" this is the magnitude of the complex response.
        """
        # (B, C, T) -> (B, 1, C, T): one input plane, with the kernel spanning all C rows.
        x = x.unsqueeze(1)
        if isinstance(self.wavelet_bank, tuple):
            wavelet_real = self.wavelet_bank[0].to(device=x.device, dtype=x.dtype)
            wavelet_imag = self.wavelet_bank[1].to(device=x.device, dtype=x.dtype)

            output_real = nn.functional.conv2d(x, wavelet_real, padding="same")
            output_imag = nn.functional.conv2d(x, wavelet_imag, padding="same")
            # (B, widths, C, T) -> (B, C, widths, T)
            output_real = torch.transpose(output_real, 1, 2)
            output_imag = torch.transpose(output_imag, 1, 2)
            return torch.sqrt(output_real**2 + output_imag**2)
        else:
            # Cache the bank on the input's device/dtype for later calls.
            self.wavelet_bank = self.wavelet_bank.to(device=x.device, dtype=x.dtype)
            output = nn.functional.conv2d(x, self.wavelet_bank, padding="same")
            return torch.transpose(output, 1, 2)

In [None]:
# 64 scales (25..88) of the complex Morlet wavelet, for 3-channel inputs
# with a 4096-sample filter bank.
widths = np.arange(25, 89)
pycwt = CWT(widths, wavelet="cmorlet", channels=3, filter_len=4096)

In [None]:
# Real part of the complex Morlet bank; CWT builds it as (num_widths, 1, channels, filter_len).
wavelet_bank_real = pycwt.wavelet_bank[0]
wavelet_bank_real.shape

In [None]:
# Shape of the bank if it were tiled `bs` times along dim 1.
bs = 16
wavelet_bank_real.repeat(1, bs, 1, 1).shape

In [None]:
# Load samples 1-4 into one batch for timing the CWT (presumably shape (4, 3, 4096)).
imgs = []
for i in range(1, 5): imgs.append(np.load(train_files[i]))
imgs = torch.from_numpy(np.array(imgs))
imgs.shape

In [None]:
# Benchmark the CWT on the 4-sample batch.
# NOTE(review): `our_imgs` assigned inside %timeit is presumably not kept in the
# notebook namespace; it is computed again in a later cell.
%timeit our_imgs = pycwt(imgs)

In [None]:
# Load samples 1-8 into one batch for timing the CWT.
# NOTE(review): copy of the 4-sample loading loop; a helper taking the index range would avoid the duplication.
imgs = []
for i in range(1, 9): imgs.append(np.load(train_files[i]))
imgs = torch.from_numpy(np.array(imgs))
imgs.shape

In [None]:
# Benchmark the CWT on the 8-sample batch (result discarded).
%timeit _ = pycwt(imgs)

In [None]:
# CWT of the 8-sample batch; CWT.forward returns (batch, channels, widths, time).
our_imgs = pycwt(imgs)

In [None]:
@njit(nogil=True)
def min_max_scaler_int8(wave):
    """Linearly rescale ``wave`` so its values span [0, 255].

    Despite the name, the result is a float array (same shape as ``wave``)
    covering the uint8 range; image_to_int8 does the rounding and cast.
    A constant input has no range and maps to all zeros instead of NaN.
    """
    lo = wave.min()
    span = wave.max() - lo
    if span == 0:
        # (wave - lo) is all zeros here, so any non-zero divisor yields zeros.
        span = 1.0
    return (wave - lo) / span * 255

In [None]:
def image_to_int8(image):
    """Min-max scale ``image`` to [0, 255], round, and return it as ``np.uint8``.

    Note: the output dtype is unsigned (uint8), despite the ``int8`` name.
    """
    # np.round instead of np.round_, which was deprecated in NumPy 1.25 and removed in 2.0.
    return np.round(min_max_scaler_int8(image)).astype(np.uint8)

In [None]:
# Copy of sample 0, channel 2 of the CWT output as a NumPy array.
# NOTE(review): `m` is not used by any later visible cell.
m = our_imgs[0, 2].numpy().copy()

In [None]:
# Inspect the CWT image of sample 1, channel 1 as a NumPy array.
our_imgs[1, 1].numpy()

In [None]:
from PIL import Image
# Save sample 1 / channel 1 of the CWT batch as a 400x300 RGB PNG.
(
    Image.fromarray(image_to_int8(our_imgs[1, 1].numpy()))
    .convert("RGB")
    .resize((400, 300))
    .save("data.png")
)

With numpy: 2.19s.  
With pytorch (no GPU) also around 2.18s.  
Without using complex numbers: 740ms. (but with slightly different output).

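With all terms: $2.16\,\mathrm{s} \pm 16.4\,\mathrm{ms}$  
Without the normalization constant: $2.2\,\mathrm{s} \pm 108\,\mathrm{ms}$  
Without the exponential: $2.18\,\mathrm{s} \pm 61.9\,\mathrm{ms}$  
Without the kernel calculation: $1.07\,\mathrm{s} \pm 4.39\,\mathrm{ms}$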

In [None]:
# Time the CWT on the single sample `wave`, reshaped to a batch of one.
# NOTE(review): unlike apply_qtransform, this does not cast to float32 first — confirm intended.
%timeit pycwt(torch.from_numpy(wave).view(1, 3, 4096))

In [None]:
# Preview: band-pass 20-500 Hz (order=4), then rescale each row to [-1, 1].
min_max_scaler(butter_bandpass_filter(wave, 20, 500, fs, 4))

Currently we take the mean over all 3 waves. Other ways of combining them (e.g. keeping each wave as a separate image channel) may be worth trying. 

In [None]:
def apply_qtransform(waves, transform=None, cuda=False, low=27.5, high=466.16):
    """Turn one multi-channel sample into a single 2-D wavelet-magnitude image.

    Despite the name, this uses the notebook-level CWT instance ``pycwt``, not
    a constant-Q transform. It band-passes each channel between ``low`` and
    ``high`` Hz, rescales each channel to [-1, 1], runs the CWT, and averages
    the absolute response over channels.

    Args:
        waves: array of shape (channels, time); ``channels`` must match the
            ``channels`` that ``pycwt`` was built with.
        transform: unused; kept so existing calls that pass it keep working.
        cuda: if True, move the input to the GPU before the transform.
        low: lower band-pass edge in Hz (default keeps the previous 27.5 Hz).
        high: upper band-pass edge in Hz (default keeps the previous 466.16 Hz).

    Returns:
        torch.Tensor, typically (num_widths, time) after ``squeeze``; on the
        GPU when ``cuda`` is True.
    """
    filtered = butter_bandpass_filter(waves, low, high, fs, 4)
    scaled = np.ascontiguousarray(min_max_scaler(filtered))
    # Add a batch axis: (channels, time) -> (1, channels, time). Unlike the
    # former view(1, 3, 4096), this works for any sample length.
    batch = torch.from_numpy(scaled).float().unsqueeze(0)
    if cuda:
        batch = batch.cuda()
    image = torch.abs(pycwt(batch))  # (1, channels, widths, time)
    # Average over the channel axis to merge the waves into one image.
    return torch.mean(image, dim=1).squeeze()


# Transform the first 10 training samples for a visual check.
imgs = []
for i in tqdm(range(10)):
    # Separate name so the notebook-level `wave` loaded earlier is not overwritten.
    sample_wave = np.load(train_files[i])
    img = apply_qtransform(sample_wave, cuda=torch.cuda.is_available())
    imgs.append(img)
print(img.shape)

In [None]:
# One figure per preview image.
for i, img in enumerate(imgs[:10]):
    plt.figure(dpi=150)
    plt.imshow(img.cpu().numpy().squeeze(), aspect="auto")

In [None]:
# Release the preview images (possibly on the GPU) before the full export.
del imgs
gc.collect()

In [None]:
OUT_DIR = "train/"
# exist_ok so re-running the cell does not raise FileExistsError.
os.makedirs(OUT_DIR, exist_ok=True)

labels = pd.read_csv("../input/g2net-gravitational-wave-detection/training_labels.csv")
# NOTE(review): assumes the CSV rows are in the same order as the sorted glob of
# training .npy files — TODO confirm by comparing file stems with the CSV's id column.
labels["file_path"] = train_files

pd.set_option("display.max_colwidth", None)
labels.head()

In [None]:
# File paths of every training sample whose target is 1.
ones_train = labels.loc[labels["target"] == 1, "file_path"].to_numpy()

In [None]:
def save_images(file_path, out_dir):
    """Transform one ``.npy`` sample with apply_qtransform and save it as a PNG.

    The image is written to ``<out_dir>/<file stem>.png``.

    Args:
        file_path: path to a ``.npy`` sample file.
        out_dir: existing destination directory.
    """
    file_name = os.path.splitext(os.path.basename(file_path))[0]
    waves = np.load(file_path).astype(np.float32)  # (3, 4096)
    # Transform this file's data (`waves`). The previous code passed the
    # notebook-global `wave`, so every saved PNG showed the same sample.
    # Use the GPU only when available instead of forcing cuda=True.
    image = apply_qtransform(waves, cuda=torch.cuda.is_available())
    plt.imsave(os.path.join(out_dir, file_name + ".png"), image.cpu().numpy().squeeze())

In [None]:
# Saving all the 1's in the 1's folder. 
import joblib
from tqdm.auto import tqdm

# Write one image per positive training sample into train/ones/, using 8 joblib threads.
folder_name = "train/ones/"

os.makedirs(folder_name, exist_ok=True)

_ = joblib.Parallel(n_jobs=8, prefer="threads")(
    joblib.delayed(save_images)(file_path, out_dir=folder_name) for file_path in tqdm(ones_train)
)

In [None]:
# Write one image per negative (target == 0) training sample into train/zero/, using 8 joblib threads.
# NOTE(review): same code as the positive-sample export; a shared helper would remove the duplication.
folder_name = "train/zero/"
zeroes_train = labels[labels["target"] == 0]["file_path"].to_numpy()

os.makedirs(folder_name, exist_ok=True)

_ = joblib.Parallel(n_jobs=8, prefer="threads")(
    joblib.delayed(save_images)(file_path, out_dir=folder_name) for file_path in tqdm(zeroes_train)
)

In [None]:
import os
import shutil

def move_to_destination(origin, destination, percentage_split):
    """Move the first ``percentage_split`` fraction of files (sorted by name) from ``origin`` to ``destination``.

    The count is ``int(len(files) * percentage_split)``; a non-positive count moves nothing.
    """
    names = sorted(os.listdir(origin))
    n_move = int(len(names) * percentage_split)
    for name in names[:max(0, n_move)]:
        shutil.move(os.path.join(origin, name), destination)

In [None]:
# Hold out 20% of each class (the first files in sorted name order) as the validation set.
# NOTE(review): not idempotent — makedirs raises if valid/ already exists, which at least
# stops a re-run from moving another 20% of the remaining training images.
os.makedirs("./valid/ones")
os.makedirs("./valid/zero")
move_to_destination("./train/ones", "./valid/ones", 0.2)
move_to_destination("./train/zero", "./valid/zero", 0.2)

In [None]:
%%time
import shutil

# Base names without a trailing slash: make_archive appends ".zip" to the base
# name, and depending on the Python version "train/" can become "train/.zip",
# i.e. inside the directory that rmtree deletes on the next line.
shutil.make_archive("train", "zip", "train/")
shutil.rmtree("train/")

shutil.make_archive("valid", "zip", "valid/")
shutil.rmtree("valid/")

In [None]:
OUT_DIR = "test/"
# exist_ok so re-running the cell does not raise FileExistsError.
os.makedirs(OUT_DIR, exist_ok=True)
test_files = sorted(glob.glob("../input/g2net-gravitational-wave-detection/test/*/*/*/*.npy"))

# Write one image per test sample into test/, using 8 joblib threads.
_ = joblib.Parallel(n_jobs=8, prefer="threads")(
    joblib.delayed(save_images)(file_path, out_dir=OUT_DIR) for file_path in tqdm(test_files)
)

In [None]:
%%time
# Base name without a trailing slash so the archive is written as test.zip next
# to the directory, not inside the directory that is deleted on the next line.
shutil.make_archive("test", "zip", "test/")
shutil.rmtree("test/")