In [None]:
!pip install -q nnAudio

In [None]:
# installing the torch-xla nightly version
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev

In [None]:
!pip install pretrainedmodels

In [None]:
# `copy` ships with the Python standard library, so import it directly instead
# of changing directory into a vendored CPython source tree and importing
# `Lib.copy` from there (which only works while that extra dataset is attached).
# NOTE(review): the original cell ended with `%cd /kaggle/working`; this version
# never leaves the starting directory. Later cells use paths relative to the
# working directory, so confirm the notebook starts in /kaggle/working.
import copy

In [None]:
!pip install torch-summary

In [None]:
import torch_xla
import torch_xla.distributed.parallel_loader as pl
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal
import torch
from torch.utils.data import Dataset
from nnAudio.Spectrogram import CQT # CQT is an alias of CQT1992v2

import warnings
warnings.filterwarnings("ignore")

%matplotlib inline

from torchsummary import summary
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn import model_selection
from PIL import Image
import albumentations
from torch.utils.data import DataLoader
import torch.nn.functional as F
import gc
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from torch.utils.data.sampler import SequentialSampler

import pretrainedmodels
TRAIN_BATCH_SIZE = 1

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Constant-Q transform layer; the dataset class below calls this same object.
cqt = CQT(
    sr=2048,        # sample rate
    fmin=20,        # min freq
    fmax=1024,      # max freq
    hop_length=64,  # hop length
    verbose=False,
)

# Preview one test recording: join its rows end to end, divide by the largest
# value and convert the result to a float tensor.
sample = np.load("../input/g2net-gravitational-wave-detection/test/0/1/0/01002036c9.npy")
sample = np.concatenate(sample, axis=0)
sample = torch.tensor(sample / np.max(sample), dtype=torch.float)

# Drop singleton dimensions, then tile the result three times along a new
# leading axis.
const_q_transform = cqt(sample).squeeze().repeat(3, 1, 1)

plt.figure(figsize=(17, 3))
plt.title("01002036c9.npy")
plt.pcolormesh(const_q_transform[0])

const_q_transform.size()

In [None]:
# Training labels table; later cells read its `id` and `target` columns.
df = pd.read_csv(
    "../input/g2net-gravitational-wave-detection/training_labels.csv"
)
df.head()

In [None]:
# Assign every row to one of five stratified folds, stored in column `kfold`.
df["kfold"] = -1
# Shuffle with a fixed seed so the fold membership is reproducible between runs
# (an unseeded shuffle produces different folds on every execution).
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
y = df.target.values
kf = model_selection.StratifiedKFold(n_splits=5)

for f, (t_, v_) in enumerate(kf.split(X=df, y=y)):
    # v_ holds positional indices of this fold's validation rows; after
    # reset_index(drop=True) those positions are also the index labels.
    df.loc[v_, 'kfold'] = f
df.head(10)

In [None]:
class waveformClassification(Dataset):
    """Dataset yielding a CQT spectrogram and label for each training waveform.

    Args:
        ids: sequence of row labels into ``tabular_data``; item ``k`` of the
            dataset corresponds to row ``ids[k]``.
        tabular_data: table with an ``id`` column (waveform file stem) and a
            ``target`` column (label), indexable as ``tabular_data[col][row]``.

    Each item is a dict with keys ``file_name``, ``tabular_data`` (the
    spectrogram repeated three times along a new leading axis) and ``label``.
    """

    def __init__(self, ids, tabular_data):
        self.ids = ids
        self.tabular_data = tabular_data

    def __len__(self):
        return len(self.ids)

    def _sample_id(self, index):
        """Return the waveform id for dataset position ``index``."""
        # Go through self.ids so that a subset or re-ordering of rows is
        # honoured; with ids == range(n) this equals the raw position.
        return self.tabular_data["id"][self.ids[index]]

    @staticmethod
    def _file_path(sample_id):
        """Build the .npy path; files are nested by the id's first three characters."""
        return (
            "../input/g2net-gravitational-wave-detection/train/"
            f"{sample_id[0]}/{sample_id[1]}/{sample_id[2]}/{sample_id}.npy"
        )

    def __getitem__(self, index):
        row = self.ids[index]
        sample_id = self.tabular_data["id"][row]

        sample = np.load(self._file_path(sample_id))
        # Join the rows of the loaded array end to end into one 1-D signal.
        sample = np.concatenate(sample, axis=0)
        # Scale by the maximum value (not the maximum absolute value).
        sample = sample / np.max(sample)
        sample = torch.tensor(sample, dtype=torch.float)

        # `cqt` is the module-level CQT layer defined earlier in the notebook.
        const_q_transform = cqt(sample).squeeze()
        const_q_transform = const_q_transform.repeat(3, 1, 1)

        return {
            'file_name': sample_id + ".npy",
            'tabular_data': const_q_transform,
            'label': torch.tensor(self.tabular_data["target"][row], dtype=torch.float),
        }

In [None]:
# Datasets built over the full label table (no fold split applied here).
train_data = waveformClassification(ids=list(range(len(df))), tabular_data=df)

val_data = waveformClassification(ids=list(range(len(df))), tabular_data=df)

#dry run
idx = 472226
# Fetch the item once: every indexing call reloads the .npy file and
# recomputes the CQT, so indexing three times triples the work.
dry_run_item = val_data[idx]
plt.figure(figsize=(17, 3))
plt.title(dry_run_item["file_name"])
plt.pcolormesh(dry_run_item["tabular_data"][0])

print(dry_run_item["label"])

In [None]:
class ResNet18(nn.Module):
    """ResNet-18 feature extractor followed by dropout and a one-unit linear head.

    The backbone comes from ``pretrainedmodels`` with ``pretrained=None``;
    ``forward`` returns one raw (un-activated) value per input sample.
    """

    def __init__(self):
        super().__init__()
        self.model = pretrainedmodels.__dict__['resnet18'](pretrained=None)
        self.dropout = nn.Dropout(0.1)
        self.final_layer = nn.Linear(512, 1)

    def forward(self, inputs):
        # Unpacking four names requires a 4-D input.
        n_samples, _channels, _height, _width = inputs.shape

        feature_maps = self.model.features(inputs)

        # Global average pool to 1x1, then flatten to (n_samples, features).
        pooled = F.adaptive_avg_pool2d(feature_maps, 1)
        flattened = pooled.reshape(n_samples, -1)

        return self.final_layer(self.dropout(flattened))
    
# Build the network and move it onto the XLA device in one step.
model = ResNet18().to(xm.xla_device())

In [None]:
# Print a layer-by-layer summary of the model for a (3, 69, 193) input.
# NOTE(review): presumably this matches the size printed by the sample-CQT
# cell (const_q_transform.size()) — confirm against that output.
summary(model, (3, 69, 193))

In [None]:
EPOCHS = 5

# Adam, with the base learning rate multiplied by the XLA world size.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3 * 0.95 * xm.xrt_world_size(),
)

# Binary cross-entropy computed directly on the model's raw outputs.
loss_fn = torch.nn.BCEWithLogitsLoss()

# Multiply the learning rate by 0.1 after every 2 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=2,
    gamma=0.1,
)

In [None]:
# defining the training loop
def train_loop_fn(data_loader, model, optimizer, device, scheduler=None):
    """Run one training pass over ``data_loader`` and report loss and ROC AUC.

    Args:
        data_loader: iterable of batches; each batch is a dict with a
            ``"tabular_data"`` input tensor and a ``"label"`` target tensor.
        model: network to train (switched to train mode here).
        optimizer: optimizer stepped through ``xm.optimizer_step``.
        device: device the batch tensors are moved to.
        scheduler: optional LR scheduler, stepped once after the pass.

    Returns:
        ``(train_loss, train_roc_score)``: the mean per-batch loss and the ROC
        AUC of the raw model outputs against the targets.

    Raises:
        ValueError: if ``data_loader`` yields no batches.

    Uses the module-level ``loss_fn``.
    """
    running_loss = 0.0
    num_batches = 0
    target_chunks = []
    prediction_chunks = []

    model.train()

    for batch in data_loader:
        input_data = batch["tabular_data"].to(device, dtype=torch.float32)
        # Add a trailing axis so targets line up with the (batch, 1) outputs.
        targets = batch["label"].to(device, dtype=torch.float).unsqueeze(1)

        optimizer.zero_grad()

        outputs = model(input_data)
        loss = loss_fn(outputs, targets)

        loss.backward()
        xm.optimizer_step(optimizer)

        running_loss += loss.item()
        num_batches += 1

        # Collect per-batch arrays and concatenate once at the end; growing an
        # array with np.concatenate inside the loop is quadratic.
        target_chunks.append(targets.detach().cpu().numpy())
        prediction_chunks.append(outputs.detach().cpu().numpy())

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")

    # Step the scheduler once per pass. Stepping StepLR on every batch (and
    # passing the loss as its `epoch` argument) shrinks the learning rate to
    # almost nothing within a handful of batches.
    if scheduler is not None:
        scheduler.step()

    all_targets = np.concatenate(target_chunks, axis=0)
    all_predictions = np.concatenate(prediction_chunks, axis=0)

    # Average over the batches actually seen rather than the length of an
    # unrelated module-level dataset.
    train_loss = running_loss / num_batches
    train_roc_score = roc_auc_score(all_targets, all_predictions)

    return train_loss, train_roc_score

In [None]:
def eval_loop_fn(data_loader, model, device):
    """Evaluate ``model`` over ``data_loader`` and report loss and ROC AUC.

    Args:
        data_loader: iterable of batches; each batch is a dict with a
            ``'tabular_data'`` input tensor and a ``'label'`` target tensor.
        model: network to evaluate (switched to eval mode here).
        device: device the batch tensors are moved to.

    Returns:
        ``(val_loss, val_roc_score)``: the mean per-batch loss and the ROC AUC
        of the raw model outputs against the targets.

    Raises:
        ValueError: if ``data_loader`` yields no batches.

    Uses the module-level ``loss_fn``.
    """
    running_loss = 0.0
    num_batches = 0
    target_chunks = []
    prediction_chunks = []

    model.eval()

    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for batch in data_loader:
            input_data = batch['tabular_data'].to(device, dtype=torch.float32)
            # Add a trailing axis so targets line up with the (batch, 1) outputs.
            targets = batch['label'].to(device, dtype=torch.float).unsqueeze(1)

            outputs = model(input_data)
            loss = loss_fn(outputs, targets)

            running_loss += loss.item()
            num_batches += 1

            # Concatenate once after the loop instead of on every batch.
            target_chunks.append(targets.detach().cpu().numpy())
            prediction_chunks.append(outputs.detach().cpu().numpy())

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")

    all_targets = np.concatenate(target_chunks, axis=0)
    all_predictions = np.concatenate(prediction_chunks, axis=0)

    # Average over the batches actually seen rather than the length of an
    # unrelated module-level dataset.
    val_loss = running_loss / num_batches
    val_roc_score = roc_auc_score(all_targets, all_predictions)

    return val_loss, val_roc_score

In [None]:
def _run():
    """Train and validate over the configured folds and save the best model.

    For each fold, rows with ``kfold == fold`` form the validation set and the
    rest the training set. The model is trained for ``EPOCHS`` epochs, and the
    fold whose best validation ROC AUC is highest has its model copied and
    written to ``./first_basic_model.bin``.

    Returns:
        A deep copy of the model taken after the best fold finished training.

    Uses the module-level ``df``, ``model``, ``optimizer``, ``scheduler``,
    ``EPOCHS`` and ``TRAIN_BATCH_SIZE``.
    """
    no_of_folds = 1
    best_accuracy = None
    best_model = None
    device = xm.xla_device()
    a_string = "*" * 20

    for fold in range(no_of_folds):
        print(a_string, " FOLD NUMBER ", fold, a_string)

        # Use the loop's own fold index; it was previously overwritten with 0
        # inside the loop, which would make every fold identical.
        # NOTE(review): the same `model`/`optimizer` objects are reused across
        # folds, so raising no_of_folds above 1 continues training rather than
        # starting each fold from scratch.
        df_train = df[df.kfold != fold].reset_index(drop=True)
        df_valid = df[df.kfold == fold].reset_index(drop=True)

        fold_train_data = waveformClassification(ids=list(range(len(df_train))),
                                                 tabular_data=df_train)
        fold_val_data = waveformClassification(ids=list(range(len(df_valid))),
                                               tabular_data=df_valid)

        train_sampler = torch.utils.data.distributed.DistributedSampler(
            fold_train_data,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)

        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            fold_val_data,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)

        training_dataloader = DataLoader(fold_train_data,
                                         num_workers=4,
                                         batch_size=TRAIN_BATCH_SIZE,
                                         sampler=train_sampler,
                                         drop_last=True)

        val_dataloader = DataLoader(fold_val_data,
                                    num_workers=4,
                                    batch_size=TRAIN_BATCH_SIZE,
                                    sampler=valid_sampler,
                                    drop_last=False)
        all_accuracies = []

        for epoch in range(EPOCHS):
            xm.master_print(f"Epoch --> {epoch+1} / {EPOCHS}")
            xm.master_print(f"-------------------------------")

            train_para_loader = pl.ParallelLoader(training_dataloader, [device])
            train_loss, train_roc = train_loop_fn(
                train_para_loader.per_device_loader(device),
                model, optimizer, device, scheduler)
            xm.master_print(f'training Loss: {train_loss} & training ROC Score: {train_roc}.')

            val_para_loader = pl.ParallelLoader(val_dataloader, [device])
            valid_loss, val_roc = eval_loop_fn(
                val_para_loader.per_device_loader(device), model, device)
            xm.master_print(f'validation Loss: {valid_loss} & validation ROC Score: {val_roc} \n')

            all_accuracies.append(val_roc)
        xm.master_print('\n')

        # Keep this fold's model unless an earlier fold scored strictly higher.
        # NOTE(review): the copy is taken after the final epoch, while the
        # score is the maximum over all epochs, so the two may not coincide.
        fold_best = max(all_accuracies)
        if best_accuracy is None or fold_best >= best_accuracy:
            best_accuracy = fold_best
            best_model = copy.deepcopy(model)

    # xm.save moves the tensors to CPU before writing and, by default, writes
    # from the master ordinal only.
    xm.save(best_model.state_dict(), './first_basic_model.bin')
    xm.master_print()
    xm.master_print("The highest ROC core that we got across all the folds is {:.2f}".format(best_accuracy))

    return best_model

In [None]:
# initializing the training of model
def _mp_fn(rank, flags):
    """Per-process entry point passed to ``xmp.spawn``.

    ``rank`` and ``flags`` are accepted but not used; the function sets the
    default tensor type to float32 CPU tensors and runs the training routine.
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    _run()
    
# Launch training through XLA multiprocessing; nprocs=1 starts one process,
# created with the 'fork' start method.
FLAGS = {}
xmp.spawn(
    _mp_fn,
    args=(FLAGS,),
    nprocs=1,
    start_method='fork',
)