# PyTorch Lightning EfficientNet for G2Net

based on:

- [G2Net / efficientnet_b7 / baseline [training]](https://www.kaggle.com/yasufuminakama/g2net-efficientnet-b7-baseline-training) 
- [PL 1Fold CQT + DeepSpeed Op Baseline + W&B [.84]](https://www.kaggle.com/ligtfeather/pl-1fold-cqt-deepspeed-op-baseline-w-b-84)

TODOs:

- Add DeepSpeed optimizers - important.
- Use 16-bit precision - important. TESTED: extremely slow compared to 32-bit precision. Reason unknown (may need DeepSpeed).
- Add a callback for timing - nice to have.
- Model optimization - important.

RECORDINGs:
    
- highest validation auc in v7 (32 bit precision): 0.853

FATAL ERROR: something makes the notebook's system memory usage grow with every batch, and I can't find the cause ...

## Preparations

In [None]:
# Install nnAudio, which provides the CQT1992v2 layer imported below (-q/-qq keep pip quiet).
!pip install -q nnAudio -qq

In [None]:
import os
import sys

# Path settings.
OUTPUT_DIR = './'
# exist_ok=True makes this idempotent and removes the check-then-create race.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Make the bundled timm sources importable (see `import timm` below).
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')

import math
import time
import random
import shutil
from pathlib import Path
from contextlib import contextmanager
from collections import defaultdict, Counter

import scipy as sp
import numpy as np
import pandas as pd

from sklearn import preprocessing
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

from tqdm.auto import tqdm
from functools import partial

import cv2
from PIL import Image

import plotly.express as px

from nnAudio.Spectrogram import CQT1992v2 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler

import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import timm

import warnings
warnings.filterwarnings('ignore')

In [None]:
# General parameter settings.
class CFG:
    """Notebook-wide settings, read everywhere as plain class attributes.

    The scheduler-specific values (``factor``/``patience``/``eps`` for
    ReduceLROnPlateau, ``T_0`` for CosineAnnealingWarmRestarts) are defined
    alongside the active ones so that switching ``scheduler`` to any of the
    listed options does not fail on a missing attribute.
    """
    apex = False
    debug = False
    print_freq = 100
    num_workers = 4
    model_name = 'tf_efficientnet_b7_ns'
    scheduler = 'CosineAnnealingLR'  # ['ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts']
    epochs = 3
    precision = 32  # 16
    factor = 0.2  # ReduceLROnPlateau
    patience = 4  # ReduceLROnPlateau
    eps = 1e-6  # ReduceLROnPlateau
    T_max = 3  # CosineAnnealingLR
    T_0 = 3  # CosineAnnealingWarmRestarts
    lr = 1e-4
    min_lr = 1e-6
    batch_size = 64  # Tried: 48 for 32 bit precision (memory error), 64 for 16 bit precision.
    weight_decay = 1e-6
    gradient_accumulation_steps = 1
    max_grad_norm = 1000
    qtransform_params = {"sr": 2048, "fmin": 20, "fmax": 1024, "hop_length": 32, "bins_per_octave": 8}
    seed = 42
    target_size = 1
    target_col = 'target'
    n_fold = 5
    trn_fold = [0]  # [0, 1, 2, 3, 4]
    train = True
    grad_cam = True
    
if CFG.debug:
    # Debug runs only need a single epoch.
    CFG.epochs = 1
    # NOTE: no data frame exists at this point (the CSV files are read further
    # down), so the training set cannot be subsampled here. The former
    # `train = train.sample(...)` line referenced an undefined name and raised
    # NameError as soon as debug was switched on. Subsample `train_df` right
    # after it has been loaded instead.
    
def class2dict(f):
    """Collect the attributes of ``f`` whose names do not start with '__' into a dict."""
    return {
        attr: getattr(f, attr)
        for attr in dir(f)
        if not attr.startswith('__')
    }

In [None]:
# Wandb settings - log in using Kaggle secrets.
from kaggle_secrets import UserSecretsClient
import wandb

# Fetch the W&B API key stored under the "wandb_api" secret and log in with it,
# so the key never appears in the notebook itself.
user_secrets = UserSecretsClient()
wandb_api = user_secrets.get_secret("wandb_api")
wandb.login(key=wandb_api)

def class2dict(f):
    """Return the attributes of ``f`` whose names do not start with '__' as a dict.

    Same helper as the one defined next to ``CFG``; it is declared again in this cell.
    """
    public = {}
    for key in dir(f):
        if key.startswith('__'):
            continue
        public[key] = getattr(f, key)
    return public

# Open the W&B run for this session; the whole CFG is stored as the run config.
run = wandb.init(
    job_type='train',
    config=class2dict(CFG),
    group=CFG.model_name,  # group changes with model for split cv
    project='G2Net-exp',
)

In [None]:
# Seed python/numpy/torch from the configured value so that changing CFG.seed
# affects the whole run (the fold split already uses CFG.seed).
seed_everything(CFG.seed)

## Data Loading & Data Engineering

In [None]:
# Training labels and the sample submission (used as the list of test ids).
train_df = pd.read_csv('../input/g2net-gravitational-wave-detection/training_labels.csv')
test_df = pd.read_csv('../input/g2net-gravitational-wave-detection/sample_submission.csv')

# Debug mode: keep a small random subset of the training rows. This has to
# happen here, once the frame exists; the cap guards against frames that have
# fewer than 10000 rows.
if CFG.debug:
    train_df = train_df.sample(n=min(10000, len(train_df)),
                               random_state=CFG.seed).reset_index(drop=True)

train_dir = '../input/g2net-gravitational-wave-detection/train/'
test_dir = '../input/g2net-gravitational-wave-detection/test/'


def id2path(name, folder=train_dir):
    """Build the ``.npy`` path of sample ``name`` below ``folder``.

    The file sits three directory levels deep, one level per leading
    character of the id: ``<folder>/<c0>/<c1>/<c2>/<name>.npy``.
    """
    return os.path.join(folder, f'{name[0]}/{name[1]}/{name[2]}/{name}.npy')

# Resolve every sample id to its waveform file. The test paths must be built
# from the test ids: using train_df['id'] here pointed them at training ids
# inside the test folder (and misaligned the rows).
train_df['path'] = train_df['id'].apply(lambda x: id2path(x, train_dir))
test_df['path'] = test_df['id'].apply(lambda x: id2path(x, test_dir))

In [None]:
# Preview the first rows of the training table, including the new 'path' column.
train_df.head()

In [None]:
# Preview the first rows of the test table, including the new 'path' column.
test_df.head()

## DataModule

In [None]:
# Transform functions. Add Channel dimension (default 1) for (H, W) input.
def get_train_transforms():
    """Image pipeline for training samples: only the array-to-tensor conversion."""
    steps = [ToTensorV2(p=1.0)]
    return A.Compose(steps)

def get_valid_transforms():
    """Image pipeline for validation samples: only the array-to-tensor conversion."""
    to_tensor = ToTensorV2(p=1.0)
    return A.Compose([to_tensor])

def get_test_transforms():
    """Image pipeline for test samples: only the array-to-tensor conversion."""
    pipeline = A.Compose([ToTensorV2(p=1.0)])
    return pipeline

In [None]:
# Dataset.
class BaseDataset(Dataset):
    """Dataset shared by the training, validation and test splits.

    Every item is loaded from the ``.npy`` file listed in the ``path`` column,
    converted to a CQT image and returned together with its label taken from
    the ``CFG.target_col`` column.

    Check the tensor rank after every step when changing this class. As the
    shape notes below record, the CQT layer returns an extra leading axis that
    is squeezed away before the image transform, and the image transform adds a
    channel axis back in front.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['path'].values
        self.labels = df[CFG.target_col].values
        # Wave -> image layer, configured from CFG.qtransform_params.
        self.wave_transform = CQT1992v2(**CFG.qtransform_params)
        # Optional image-level transform, called as transform(image=...)['image'].
        self.transform = transform

    def apply_qtransform(self, waves):
        """Flatten ``waves``, divide by its maximum and run the CQT layer on it."""
        flat = np.hstack(waves)  # 1D data array
        scaled = torch.from_numpy(flat / np.max(flat)).float()  # normalization
        return self.wave_transform(scaled)  # [1, 46, 385]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        raw = np.load(self.file_names[idx])

        # Wave2Image: drop the singleton axis and go back to numpy for the
        # image transform.
        spectrogram = self.apply_qtransform(raw).squeeze().numpy()  # [46, 385]

        if self.transform:
            spectrogram = self.transform(image=spectrogram)['image']  # [1, 46, 385]

        target = torch.tensor(self.labels[idx]).float()
        return spectrogram, target

    def __len__(self):
        """Number of samples, i.e. number of file paths."""
        return len(self.file_names)

In [None]:
# View dataset samples: image - 2d single channel.
train_dataset = BaseDataset(train_df, transform=get_train_transforms())

# Plot three randomly chosen samples with their label as the title.
for i in random.sample(range(len(train_dataset)), 3):
    image, label = train_dataset[i]
    fig = px.imshow(image.squeeze(), title=f'label: {int(label)}')
    fig.update_layout(
        height=200,
        title=dict(x=0.5, y=0.9),
        margin={'t': 10, 'b': 10, 'l': 10, 'r': 10},
    )
    fig.show()

In [None]:
# Data split before feeding into dataloader.
fold = StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)

# split() yields *positional* indices. Writing them into a positional array
# (instead of label-based `.loc`) keeps the assignment correct even when
# train_df does not carry a default 0..n-1 index, and gives an integer column
# directly without a float round-trip.
fold_ids = np.full(len(train_df), -1, dtype=int)
for n, (train_idx, val_idx) in enumerate(fold.split(train_df, train_df[CFG.target_col])):
    fold_ids[val_idx] = n  # Give fold label for validation folds.

train_df['fold'] = fold_ids
# Class balance per fold; grouped by the configured target column.
print(train_df.groupby(['fold', CFG.target_col]).size())

In [None]:
# DataModule that builds the per-fold split. pin_memory speeds up transfers from CPU to GPU.
class G2NetDataModule(pl.LightningDataModule):
    """Provides the train/validation/test dataloaders for one CV fold.

    Args:
        df: training frame with a ``fold`` column; rows whose ``fold`` equals
            ``fold`` form the validation split, all other rows the training split.
        fold: id of the fold held out for validation.
    """

    def __init__(self, df: pd.DataFrame, fold: int = 0):
        super().__init__()
        self.df = df
        self.fold = fold

    def setup(self, stage=None):
        """Build the three datasets.

        ``stage`` is accepted (and ignored) because Lightning passes it when it
        calls this hook itself; without the parameter that call fails with a
        TypeError. Calling ``setup()`` manually without arguments still works.
        """
        train_idxs = self.df[self.df['fold'] != self.fold].index
        valid_idxs = self.df[self.df['fold'] == self.fold].index

        train_fold = self.df.loc[train_idxs].reset_index(drop=True)
        valid_fold = self.df.loc[valid_idxs].reset_index(drop=True)

        self.train_dataset = BaseDataset(train_fold, transform=get_train_transforms())
        self.valid_dataset = BaseDataset(valid_fold, transform=get_valid_transforms())
        # The test set is the module-level test_df; it does not depend on the fold.
        self.test_dataset = BaseDataset(test_df, transform=get_test_transforms())

    def _make_loader(self, dataset, shuffle):
        """Create a DataLoader with the settings shared by all three splits."""
        return DataLoader(dataset,
                          batch_size=CFG.batch_size,
                          shuffle=shuffle,
                          num_workers=CFG.num_workers,
                          pin_memory=True)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation split."""
        return self._make_loader(self.valid_dataset, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the test set."""
        return self._make_loader(self.test_dataset, shuffle=False)

## Lightning Module

In [None]:
class G2NetLightningModule(pl.LightningModule):
    """Binary classifier: a timm backbone with a single-logit head.

    Logs loss and ROC-AUC per training step, and ROC-AUC per epoch for both
    training (``train_auc_epoch``) and validation (``val_auc_epoch``).

    Args:
        pretrained: forwarded to ``timm.create_model``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Single-channel input; the classifier is replaced by a linear layer
        # with CFG.target_size outputs.
        self.model = timm.create_model(CFG.model_name, pretrained=pretrained, in_chans=1)
        n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(n_features, CFG.target_size)

        # Loss on raw logits (sigmoid is applied inside the loss).
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Return the raw logits of the backbone for ``x``."""
        output = self.model(x)
        return output

    @staticmethod
    def _safe_auc(labels, logits):
        """ROC-AUC of ``sigmoid(logits)`` against ``labels``.

        Returns NaN instead of raising when ``labels`` holds a single class
        (possible for a small or unlucky batch), where ROC-AUC is undefined and
        ``roc_auc_score`` raises ValueError.
        """
        y_true = labels.detach().cpu().numpy()
        if np.unique(y_true).size < 2:
            return float('nan')
        y_score = logits.detach().float().sigmoid().cpu().numpy()
        return roc_auc_score(y_true, y_score)

    def _log_epoch_auc(self, outputs, name):
        """Concatenate the per-step outputs and log their ROC-AUC as ``name``."""
        if not outputs:
            return
        # torch.cat joins the per-batch vectors directly; the previous
        # element-by-element `+=` created one 0-d tensor per sample.
        preds = torch.cat([out['predictions'] for out in outputs])
        labels = torch.cat([out['labels'] for out in outputs])
        self.log(name, self._safe_auc(labels, preds), prog_bar=True, logger=True)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and log step-level loss and AUC."""
        x, y = batch
        y_preds = self.model(x).view(-1)  # Reduce one dimension.
        loss = self.criterion(y_preds, y)
        score = self._safe_auc(y, y_preds)

        # Log score and loss for each step (batch).
        self.log('train_auc_step', score, on_step=True, prog_bar=True, logger=True)
        self.log('train_loss_step', loss, on_step=True, prog_bar=True, logger=True)

        # Detach the predictions: they are kept until training_epoch_end and
        # must not hold on to the autograd graph of every batch.
        return dict(loss=loss, predictions=y_preds.detach(), labels=y)

    def training_epoch_end(self, outputs):
        """Compute and log the epoch-level training ROC-AUC."""
        self._log_epoch_auc(outputs, 'train_auc_epoch')

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss; return predictions for the epoch AUC."""
        x, y = batch
        y_preds = self.model(x).view(-1)
        loss = self.criterion(y_preds, y)

        self.log('val_loss', loss, prog_bar=True, logger=True)

        return dict(predictions=y_preds.detach(), labels=y)

    def validation_epoch_end(self, outputs):
        """Compute and log the epoch-level validation ROC-AUC."""
        self._log_epoch_auc(outputs, 'val_auc_epoch')

    def configure_optimizers(self):
        """Create Adam plus the scheduler selected by ``CFG.scheduler``.

        Raises:
            ValueError: if ``CFG.scheduler`` is not one of the supported names.
        """
        optimizer = Adam(self.model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)

        if CFG.scheduler == 'ReduceLROnPlateau':
            # Fallbacks are the values noted next to the (optional) CFG entries.
            scheduler = ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=getattr(CFG, 'factor', 0.2),
                                          patience=getattr(CFG, 'patience', 4),
                                          verbose=True,
                                          eps=getattr(CFG, 'eps', 1e-6))
            # Lightning needs to know which logged metric drives a plateau
            # scheduler; mode='min' matches the validation loss.
            return dict(optimizer=optimizer,
                        lr_scheduler=dict(scheduler=scheduler, monitor='val_loss'))
        elif CFG.scheduler == 'CosineAnnealingLR':
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CFG.T_max,
                                          eta_min=CFG.min_lr,
                                          last_epoch=-1)
        elif CFG.scheduler == 'CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer,
                                                    T_0=getattr(CFG, 'T_0', 3),
                                                    T_mult=1,
                                                    eta_min=CFG.min_lr,
                                                    last_epoch=-1)
        else:
            # Previously an unknown name fell through and failed later with an
            # unbound `scheduler` variable.
            raise ValueError(f'Unsupported scheduler: {CFG.scheduler!r}')

        return dict(optimizer=optimizer,
                    lr_scheduler=scheduler)

In [None]:
# Set up callbacks, including checkpoints, early stop, etc.
# Stop when the epoch-level validation AUC has not improved for 5 validation runs.
early_stop_callback = EarlyStopping(monitor='val_auc_epoch',
                                    mode='max',
                                    patience=5) # Early stop in each epoch.
# Keep the two checkpoints with the highest validation AUC.
# Filename placeholders are filled from logged metrics only, and a name that is
# not logged is rendered as 0. The former '{fold}' and '{val_auc}' therefore
# always produced 'fold=0' / 'val_auc=0.000'; use the metric that is actually
# logged ('val_auc_epoch') together with the epoch number.
checkpoint_callback = ModelCheckpoint(dirpath='./checkpoints/',
                                      filename='best-checkpoint-{epoch:02d}-{val_auc_epoch:.3f}',
                                      monitor='val_auc_epoch',
                                      mode='max',
                                      save_top_k=2,
                                      verbose=True)

In [None]:
# Logger settings, including runtime logger and wandb setting.
def get_logger(fold: int):
    """Create the W&B logger for one CV fold; only the run name depends on ``fold``."""
    settings = dict(
        project='G2Net-exp',  # same as init
        config=class2dict(CFG),  # same as init
        group=CFG.model_name,  # group changes with model
        name=f'fold-{fold}',  # fold changes
        job_type='train',
    )
    return WandbLogger(**settings)

## Train Loop

In [None]:
import copy

for fold in range(CFG.n_fold):
    # Train only the folds selected in the config (CFG.trn_fold).
    if fold not in CFG.trn_fold:
        continue

    # Initialize a fresh model for every fold: reusing one instance would carry
    # weights trained on a previous fold's training split into a fold where
    # those samples are validation data.
    model = G2NetLightningModule()

    # DataModule.
    data_module = G2NetDataModule(train_df, fold)
    data_module.setup() # Just to check dataset information by running setup.

    # Logger.
    wandb_logger = get_logger(fold)

    # Callbacks keep state (best score, patience counter, saved checkpoints),
    # so every fold works on its own pristine copies of the configured ones.
    fold_callbacks = [copy.deepcopy(early_stop_callback),
                      copy.deepcopy(checkpoint_callback)]

    # Trainer.
    trainer = pl.Trainer(gpus=1,
                         precision=CFG.precision,
                         callbacks=fold_callbacks,
                         max_epochs=CFG.epochs,
                         logger=wandb_logger,
                         deterministic=True,
                         stochastic_weight_avg=True, # https://arxiv.org/abs/1803.05407
                         progress_bar_refresh_rate=1)

    # Let's go
    trainer.fit(model, data_module)

    print(f'===Fold-{fold} running successful===')