<div style="text-align: center"><h2><font color="skyblue">G2Net Gravitational Wave Detection</font></h2></div>

![](https://www.nasa.gov/sites/default/files/thumbnails/image/smbhb_rotate_banner.gif)
Simulation of spiraling supermassive black holes. Source - [NASA](https://www.nasa.gov/feature/goddard/2018/new-simulation-sheds-light-on-spiraling-supermassive-black-holes)

## Table of Contents
1. [Resources](#Resources)
1. [Metadata](#Metadata)
1. [Visualize Waves](#Visualize-Waves)
1. [Dataset and Model](#Dataset-and-Model)



## Resources
<a id='Resources'></a>
- Extensive survey of machine-learning work on gravitational waves - [link](https://iphysresearch.github.io/Survey4GWML/)
- CNN for simulated G-Wave detection - [link](https://arxiv.org/pdf/2011.04418.pdf)

In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import urllib
import torchvision
from torchvision import transforms as T
import torch
import pytorch_lightning as pl
import torchmetrics
import pandas as pd
import librosa
import librosa.display
import numpy as np
import typing
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
import warnings

sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm

from kaggle_secrets import UserSecretsClient
from pytorch_lightning.loggers import WandbLogger

# user_secrets = UserSecretsClient()
# import wandb
# os.environ["WANDB_API_KEY"] = user_secrets.get_secret("wandb_api_key")
# print(wandb.__version__)

### Configuration

In [None]:
class opt:
    """Global run settings, read everywhere as class attributes (``opt.lr``, ``opt.seed``, ...)."""

    # --- data loading ---
    batch_size: int = 128
    num_workers: int = 8
    pin_memory: bool = True
    img_dim: int = 224  # side length of the square image the dataset resizes inputs to

    # --- optimisation ---
    lr: float = 3e-4
    epochs: int = 10

    # --- model / cross-validation ---
    model_name: str = "efficientnet_b0"  # architecture name passed to timm.create_model
    n_fold: int = 5
    seed: int = 400


# Seed Python / NumPy / torch RNGs once, up front, so folds and training are reproducible.
pl.seed_everything(seed=opt.seed)

## Metadata
<a id='Metadata'></a>

In [None]:
# Root of the competition dataset; the labelled file paths come from a separate mapping dataset.
root = "/kaggle/input/g2net-gravitational-wave-detection/"

train_df = pd.read_csv("/kaggle/input/g2net-mapping-id-to-file-path/training_labels_with_paths.csv")
test_df = pd.read_csv(os.path.join(root, "sample_submission.csv"))

print(train_df.head(2))
print(test_df.head(2))

# Assign every training row to exactly one stratified validation fold (0 .. n_fold-1).
# StratifiedKFold.split yields *positional* indices, so fold ids are written into a
# positional array instead of via .loc — this stays correct even if train_df does
# not carry a default RangeIndex.
fold_ids = np.full(len(train_df), -1, dtype=int)
KFold = StratifiedKFold(n_splits=opt.n_fold, shuffle=True, random_state=opt.seed)
for n, (train_index, val_index) in enumerate(KFold.split(train_df, train_df["target"])):
    fold_ids[val_index] = n
train_df["fold"] = fold_ids
assert (train_df["fold"] >= 0).all(), "every row must be assigned to a validation fold"

print(train_df.groupby(['fold', "target"]).size())

## Visualize Waves
<a id='Visualize-Waves'></a>

In [None]:
def _use_muted_style() -> None:
    """Apply matplotlib's muted seaborn style under whichever name this matplotlib ships.

    matplotlib 3.6 renamed its bundled ``seaborn-*`` styles to ``seaborn-v0_8-*`` and
    later releases removed the old names, so ``plt.style.use("seaborn-muted")`` alone
    raises on current versions.
    """
    style = "seaborn-v0_8-muted" if "seaborn-v0_8-muted" in plt.style.available else "seaborn-muted"
    plt.style.use(style)


def plot_waves(sample: np.ndarray, title: str,
               colors: typing.Sequence[str] = ("pink", "skyblue", "lightgreen")) -> None:
    """Draw each channel of ``sample`` as a line plot, one panel per channel, sharing the y-axis.

    Args:
        sample: array-like indexable by channel; ``sample[i]`` is one 1-D signal
            (e.g. the array returned by ``np.load`` on a training file).
        title: figure title.
        colors: line colour per channel, cycled when there are more channels than colours.

    Raises:
        ValueError: if ``sample`` has no channels.
    """
    n_channels = len(sample)
    if n_channels == 0:
        raise ValueError("sample has no channels to plot")
    _use_muted_style()
    # squeeze=False keeps axs 2-D even for a single channel, so indexing is uniform.
    fig, axs = plt.subplots(1, n_channels, figsize=(20, 3), sharey=True, squeeze=False)
    fig.suptitle(title)
    for i, (ax, channel) in enumerate(zip(axs[0], sample)):
        ax.plot(channel, c=colors[i % len(colors)])
        ax.set_title(f"channel {i}")


def plot_specs(sample: typing.Sequence[np.ndarray], title: str,
               sr: int = 2048, vmin: float = -200, vmax: float = 50) -> None:
    """Show each channel's spectrogram with ``librosa.display.specshow``, one panel per channel.

    Args:
        sample: sequence of 2-D spectrograms, one per channel (e.g. a list built
            by applying ``get_spec`` to each channel of a sample).
        title: figure title.
        sr: sampling rate forwarded to ``specshow`` for axis scaling.
        vmin, vmax: colour-scale limits shared by every panel so they are directly comparable.

    Raises:
        ValueError: if ``sample`` is empty.
    """
    n_channels = len(sample)
    if n_channels == 0:
        raise ValueError("sample has no channels to plot")
    _use_muted_style()
    fig, axs = plt.subplots(1, n_channels, figsize=(20, 3), sharey=True, squeeze=False)
    fig.suptitle(title)
    for i, (ax, spec) in enumerate(zip(axs[0], sample)):
        librosa.display.specshow(spec, sr=sr, ax=ax, vmin=vmin, vmax=vmax)
        ax.set_title(f"channel {i}")


# Take the first negative (target 0) and first positive (target 1) training rows
# and plot their raw per-channel signals for a side-by-side comparison.
neg_df = train_df.loc[train_df["target"] == 0].iloc[0]
pos_df = train_df.loc[train_df["target"] == 1].iloc[0]

neg_wav = np.load(neg_df["filepath"])
plot_waves(neg_wav, f"target {neg_df['target']}")
plt.show()  # flush the first figure so the two render as separate outputs

pos_wav = np.load(pos_df["filepath"])
plot_waves(pos_wav, f"target {pos_df['target']}")

In [None]:
def get_spec(x: np.ndarray, sr: int = 2048) -> np.ndarray:
    """Return the mel spectrogram of a single 1-D signal, converted to decibels.

    The signal is first scaled by its peak absolute amplitude, then passed to
    ``librosa.feature.melspectrogram`` and ``librosa.power_to_db`` (library defaults).

    Args:
        x: one channel of a sample (1-D array).
        sr: sampling rate passed to ``melspectrogram`` (default 2048, the value used previously).

    Returns:
        2-D array of dB values (mel bins x frames).
    """
    peak = np.abs(x).max()
    # Normalise by peak |x| rather than max(x): max(x) can be zero or negative
    # (silent or all-negative segment), which would give inf/NaN or a sign-flipped,
    # badly-scaled signal. A truly silent signal is left unscaled.
    if peak > 0:
        x = x / peak
    mel = librosa.feature.melspectrogram(y=x, sr=sr)
    return librosa.power_to_db(mel)
    
# Per-channel spectrograms of the same negative / positive examples loaded above.
neg_spec = list(map(get_spec, neg_wav))
pos_spec = list(map(get_spec, pos_wav))

plot_specs(pos_spec, "target 1")
plot_specs(neg_spec, "target 0")

## Dataset and Model

In [None]:
class G2NetData(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image, target)`` pairs from per-sample ``.npy`` files.

    Each item is loaded from disk, rearranged into a single-channel 2-D tensor and
    resized to ``(opt.img_dim, opt.img_dim)``; in training mode random horizontal
    and vertical flips are applied as well.
    """

    def __init__(self, df: pd.DataFrame, train: bool = True):
        """
        Args:
            df: frame with a ``target`` column and a file-path column, named either
                ``path`` or ``filepath`` (the training CSV loaded in this notebook
                exposes ``filepath``).
            train: if True use the augmenting transform pipeline, otherwise resize only.
        """
        self.paths = self._paths_from(df)
        self.targets = df["target"].values
        transforms_train = T.Compose([T.Resize((opt.img_dim, opt.img_dim)),
                                      T.RandomHorizontalFlip(),
                                      T.RandomVerticalFlip()])
        transforms_test = T.Compose([T.Resize((opt.img_dim, opt.img_dim))])
        self.transforms = transforms_train if train else transforms_test

    @staticmethod
    def _paths_from(df: pd.DataFrame) -> np.ndarray:
        """Return the file paths in ``df``, preferring a ``path`` column, else ``filepath``."""
        for col in ("path", "filepath"):
            if col in df.columns:
                return df[col].values
        raise KeyError("expected a 'path' or 'filepath' column in the dataframe")

    # The loaders are staticmethods rather than lambda attributes so the dataset can be
    # pickled (required when DataLoader workers are started with the 'spawn' method).
    @staticmethod
    def npy_loader(path) -> np.ndarray:
        """Load a ``.npy`` file, keep every second entry along axis 0, cast to float32."""
        # NOTE(review): ``[::2]`` slices axis 0. If files are (channels, time) this keeps
        # channels 0 and 2 and drops channel 1; if temporal down-sampling was intended,
        # ``[:, ::2]`` is what's needed — confirm intent.
        return np.load(path)[::2].astype(np.float32)

    @staticmethod
    def flatten_channels(x) -> np.ndarray:
        """Stack ``x`` row-wise and transpose, e.g. (C, L) -> (L, C)."""
        return np.vstack(x).T

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(transformed 1xHxW float tensor, float32 target)`` for row ``idx``."""
        inp = self.npy_loader(self.paths[idx])
        inp = self.flatten_channels(inp)
        # unsqueeze(0) adds the single image channel expected by the 1-channel backbone.
        inp = self.transforms(torch.from_numpy(inp).unsqueeze(0))
        target = self.targets[idx].astype(np.float32)

        return inp, target

class G2Net(pl.LightningModule):
    """Binary classifier: a timm backbone producing one logit per input, trained with BCE-with-logits."""

    def __init__(self, pretrained: bool = True):
        """
        Args:
            pretrained: forwarded to ``timm.create_model`` to load pretrained weights.
        """
        super().__init__()
        # in_chans=1 matches the single-channel tensors built by the dataset; num_classes=1 -> one logit.
        self.model = timm.create_model(model_name=opt.model_name, pretrained=pretrained, in_chans=1, num_classes=1)
        # compute_on_step=False: calling the metric only accumulates state; the epoch-level
        # AUROC is computed when Lightning reduces the logged metric object.
        self.val_score = torchmetrics.AUROC(pos_label=1, compute_on_step=False)
        self.test_score = torchmetrics.AUROC(pos_label=1, compute_on_step=False)

    def forward(self, x):
        """Return raw backbone logits (presumably shape (batch, 1) given num_classes=1)."""
        return self.model(x)

    def _logits_and_loss(self, batch):
        """Shared step: flatten logits to 1-D and compute BCE-with-logits against the targets."""
        x, y = batch
        y_hat = self(x).view(-1)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y)
        return y_hat, y, loss

    def training_step(self, batch, batch_idx):
        _, _, loss = self._logits_and_loss(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        y_hat, y, loss = self._logits_and_loss(batch)
        # AUROC expects probabilities and integer labels.
        self.val_score(y_hat.sigmoid(), y.to(torch.int))
        self.log('val_score', self.val_score, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        y_hat, y, loss = self._logits_and_loss(batch)
        self.test_score(y_hat.sigmoid(), y.to(torch.int))
        self.log('test_score', self.test_score, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {"test_loss": loss}

    def predict_step(self, batch, batch_idx: int, dataloader_idx: typing.Optional[int] = None):
        """Return sigmoid probabilities for a batch; this is the hook ``Trainer.predict`` calls.

        Accepts either an ``(inputs, targets)`` pair, as produced by the dataset, or a bare input tensor.
        """
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(x).view(-1).sigmoid()

    def predict(self, batch, batch_idx: int, dataloader_idx: typing.Optional[int] = None):
        """Backward-compatible alias for :meth:`predict_step` (the hook's name in older Lightning)."""
        return self.predict_step(batch, batch_idx, dataloader_idx)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=opt.lr)


<div class="alert alert-block alert-info">
<b>Note:</b>  Work in Progress
</div>