This notebook shows how to compute the Continuous Wavelet Transform (CWT) of waveforms with the PyTorchWavelets package (https://github.com/tomrunia/PyTorchWavelets).

Note: because of an unfixed bug in the official version, I use the fixed version from https://github.com/ar4/PyTorchWavelets/blob/master/wavelets_pytorch/transform.py

PyTorchWavelets is a SciPy/PyTorch implementation of the wavelet analysis outlined in Torrence and Compo (BAMS, 1998).
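
To give a feel for what such a transform computes, here is a minimal NumPy sketch of an FFT-based Morlet CWT that follows the formulas in Torrence and Compo. It is not the PyTorchWavelets code: the function name `morlet_cwt`, the defaults `dj=0.125` and `w0=6.0`, and the choice of smallest scale `s0 = 2*dt` are assumptions made for this illustration.

```python
import numpy as np

def morlet_cwt(x, dt, dj=0.125, w0=6.0):
    """Illustrative Morlet CWT (hypothetical helper, not the package API)."""
    x = np.asarray(x, dtype=float)
    n = x.size

    # Scale grid: s_j = s0 * 2**(j*dj), with s0 = 2*dt assumed here
    s0 = 2.0 * dt
    n_octaves = np.log2(n * dt / s0)
    scales = s0 * 2.0 ** (dj * np.arange(int(n_octaves / dj) + 1))

    # Angular frequencies of the FFT bins and the spectrum of the signal
    omega = 2.0 * np.pi * np.fft.fftfreq(n, d=dt)
    x_hat = np.fft.fft(x - x.mean())

    # Fourier transform of the Morlet wavelet at every (scale, frequency)
    # pair; it is non-zero only for positive frequencies
    s_omega = scales[:, None] * omega[None, :]
    psi_hat = np.pi ** -0.25 * np.exp(-0.5 * (s_omega - w0) ** 2) * (omega > 0)
    # Normalize so that each scaled wavelet has unit energy
    psi_hat = psi_hat * np.sqrt(2.0 * np.pi * scales[:, None] / dt)

    # Convolution theorem: multiply the spectra, then transform back.
    # The Morlet transform is real, so no complex conjugate is needed.
    coeffs = np.fft.ifft(x_hat[None, :] * psi_hat, axis=1)

    # Equivalent Fourier period of every scale
    periods = 4.0 * np.pi * scales / (w0 + np.sqrt(2.0 + w0 ** 2))
    return coeffs, scales, periods
```

The wavelet power is `np.abs(coeffs) ** 2`, with one row per scale and one column per time step. For a pure sine, the row with the highest mean power should be the one whose period is closest to the period of the sine.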

Have any questions or suggestions? Please comment below.

**<font color='red'>And if you liked this notebook, please upvote it!</font>**

**Changelog**
* v2 - the number of processed samples can now be changed via the `num_samples` variable
* v1 - initial version

## Import packages

In [None]:
!git clone https://github.com/ar4/PyTorchWavelets.git > /dev/null
%cd PyTorchWavelets
!pip install -r requirements.txt > /dev/null
!python setup.py install > /dev/null

In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from scipy import signal
from scipy.cluster.vq import whiten
import torch
from torch.utils.data import Dataset
from wavelets_pytorch.transform import WaveletTransform # Use WaveletTransformTorch to use with PyTorch

import warnings
warnings.filterwarnings("ignore")

%matplotlib inline

In [None]:
num_samples = 4 # first N samples to process

## Define dataset

Let's define a dataset to work with.

In [None]:
class G2NetDataset(Dataset):
    def __init__(self, paths, targets, use_filter=True): 
        self.paths = paths
        self.targets = targets
        self.use_filter = use_filter
        if self.use_filter:
            self.bHP, self.aHP = signal.butter(8, (20, 500), btype='bandpass', fs=2048)

    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, index):      
        waves = np.load(self.paths[index])
        waves = np.concatenate(waves, axis=0)
        if self.use_filter:
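            # Taper the edges before filtering; recent SciPy releases expose
            # the window only as signal.windows.tukey
            waves *= signal.windows.tukey(4096*3, 0.2)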
            waves = signal.filtfilt(self.bHP, self.aHP, waves)
        waves = waves / np.max(waves)
        targets = self.targets[index]
                
        return {
            "waves": torch.tensor(waves, dtype=torch.float),
            "target": torch.tensor(targets, dtype=torch.long),
        }
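
To see what the optional preprocessing does, the sketch below applies a Tukey window and a 20-500 Hz Butterworth bandpass to a synthetic signal made of a 5 Hz tone (outside the band) and a 100 Hz tone (inside the band). The synthetic signal and all variable names are made up for this illustration. Unlike the dataset class, it uses the second-order-sections form (`output='sos'` with `sosfiltfilt`), a numerically more robust alternative to the `(b, a)` form.

```python
import numpy as np
from scipy import signal

fs = 2048                       # sampling rate in Hz, as in the dataset class
n = 4096                        # two seconds of synthetic data
t = np.arange(n) / fs

# 5 Hz lies below the passband, 100 Hz lies inside it
x = np.sin(2 * np.pi * 5 * t) + np.sin(2 * np.pi * 100 * t)

# The Tukey window tapers both ends of the signal smoothly to zero
window = signal.windows.tukey(n, 0.2)

# Same filter order and band as in the dataset class, in SOS form
sos = signal.butter(8, (20, 500), btype='bandpass', fs=fs, output='sos')
y = signal.sosfiltfilt(sos, x * window)

# Compare the spectral amplitude of both tones after filtering
spectrum = np.abs(np.fft.rfft(y))
freqs = np.fft.rfftfreq(n, d=1 / fs)
amp_5hz = spectrum[np.argmin(np.abs(freqs - 5))]
amp_100hz = spectrum[np.argmin(np.abs(freqs - 100))]
print(f'5 Hz: {amp_5hz:.3g}, 100 Hz: {amp_100hz:.3g}')
```

The 5 Hz component should come out strongly suppressed relative to the 100 Hz component, which is the point of filtering before the CWT.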

## Read training labels

Now we read the training labels and build the path to the .npy file of each sample.

In [None]:
ROOT_DIR = '/kaggle/input/g2net-gravitational-wave-detection'
df = pd.read_csv(os.path.join(ROOT_DIR, 'training_labels.csv'))
df['path'] = df['id'].apply(lambda x: f'{ROOT_DIR}/train/{x[0]}/{x[1]}/{x[2]}/{x}.npy')
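
The path rule used above places each file in nested folders named after the first three characters of its id. A small sketch of the same mapping as a standalone function is shown below. The helper name `id_to_path`, the `split` parameter, and the id `abc123` are hypothetical and used only for illustration.

```python
ROOT_DIR = '/kaggle/input/g2net-gravitational-wave-detection'

def id_to_path(sample_id, root=ROOT_DIR, split='train'):
    """Build the .npy path for one sample id (hypothetical helper)."""
    # The first three characters of the id select the nested sub-folders
    a, b, c = sample_id[0], sample_id[1], sample_id[2]
    return f'{root}/{split}/{a}/{b}/{c}/{sample_id}.npy'

# 'abc123' is a made-up id, used only to show the folder layout
example_path = id_to_path('abc123')
print(example_path)
```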

## Demonstrate CWT usage

Let's calculate the CWT for the first `num_samples` signals, with and without a bandpass filter (20-500 Hz), and plot the results!

In [None]:
transform = WaveletTransform(dt=0.1)  

ds = G2NetDataset(df['path'], df['target'], use_filter=False)
ds_f = G2NetDataset(df['path'], df['target'], use_filter=True)

waves = []
waves_f = []
cwts = []
cwts_f = []
for i in range(num_samples):
    # Index the datasets directly instead of calling __getitem__ by hand
    waves.append(ds[i]['waves'])
    waves_f.append(ds_f[i]['waves'])
    cwts.append(transform.power(waves[i]).squeeze())
    cwts_f.append(transform.power(waves_f[i]).squeeze())
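
The `dt` argument is the sampling period, and it decides which physical frequency each CWT row is labeled with. The sketch below estimates those frequencies for a Morlet wavelet under Torrence and Compo conventions. It assumes a smallest scale of `s0 = 2*dt`, `dj=0.125` and `w0=6.0`, and the helper name `morlet_frequencies` is hypothetical. The package may pick its scale grid slightly differently, so treat the numbers as approximate.

```python
import math

def morlet_frequencies(n_samples, dt, dj=0.125, w0=6.0):
    """Approximate frequency of every CWT row (hypothetical helper)."""
    s0 = 2.0 * dt                                    # assumed smallest scale
    n_scales = int(math.log2(n_samples * dt / s0) / dj) + 1
    # Ratio between Fourier period and scale for the Morlet wavelet
    fourier_factor = 4.0 * math.pi / (w0 + math.sqrt(2.0 + w0 ** 2))
    return [1.0 / (fourier_factor * s0 * 2.0 ** (j * dj))
            for j in range(n_scales)]

n = 3 * 4096                                         # three concatenated detectors
freqs_hz = morlet_frequencies(n, dt=1 / 2048)        # dt matching a 2048 Hz rate
freqs_dt01 = morlet_frequencies(n, dt=0.1)           # dt as passed above

print(f'highest frequency with dt=1/2048: {freqs_hz[0]:.1f}')
print(f'highest frequency with dt=0.1:    {freqs_dt01[0]:.2f}')
```

Under these assumptions the number of rows is the same for both values of `dt`, and only the frequency labels differ by a constant factor. The plots in this notebook use row indices on the vertical axis, so they do not depend on these labels. If you want the axis in Hz, passing `dt=1/2048` would match the 2048 Hz rate used for the bandpass filter.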

### Without a filter

In [None]:
fig, axs = plt.subplots(num_samples)
fig.set_figheight(15)
fig.set_figwidth(15)
for i in range(num_samples):
    nid = df['id'][i]
    ntarget = df['target'][i]
    axs[i].title.set_text(f'{nid}.npy, target: {ntarget}')
    axs[i].plot(waves[i])

In [None]:
fig, axs = plt.subplots(num_samples)
fig.set_figheight(15)
fig.set_figwidth(15)
for i in range(num_samples):
    nid = df['id'][i]
    ntarget = df['target'][i]
    axs[i].title.set_text(f'{nid}.npy, target: {ntarget}')
    axs[i].pcolormesh(cwts[i])

### With a bandpass filter and Tukey window

In [None]:
fig, axs = plt.subplots(num_samples)
fig.set_figheight(15)
fig.set_figwidth(15)
for i in range(num_samples):
    nid = df['id'][i]
    ntarget = df['target'][i]
    axs[i].title.set_text(f'{nid}.npy, target: {ntarget}')
    axs[i].plot(waves_f[i])

In [None]:
fig, axs = plt.subplots(num_samples)
fig.set_figheight(15)
fig.set_figwidth(15)
for i in range(num_samples):
    nid = df['id'][i]
    ntarget = df['target'][i]
    axs[i].title.set_text(f'{nid}.npy, target: {ntarget}')
    axs[i].pcolormesh(cwts_f[i])

You can use `WaveletTransformTorch()` as a block in your PyTorch model to convert waves to CWT on the fly.