# G2Net Multi-Model PyTorch Training Script with W&B Experiment Tracking 🚀

A training script that supports multiple model architectures, together with W&B experiment tracking.

I have tried to make this training script as general as possible, and I will be pushing major updates to it in the future, so keep checking!

<center>
<img src="https://i.imgur.com/gb6B4ig.png" width=300px height=100px>
</center>
<br>
Weights & Biases (W&B, `wandb`) is a developer tool that helps companies turn deep learning research projects into deployed software by helping teams track their models, visualize model performance, and easily automate training and improving models. We will use their tools to log hyperparameters and output metrics from our runs, then visualize and compare results and quickly share findings with colleagues.

You can check more about them here: [wandb.ai](https://www.wandb.ai)

*Note: Huge Credits to Y.Nakama's Datasets ([train](https://www.kaggle.com/yasufuminakama/g2net-n-mels-128-train-images) and [test](https://www.kaggle.com/yasufuminakama/g2net-n-mels-128-test-images)) that I am using in this notebook along with the loading functions!*

**If you found this notebook useful you can leave an upvote! If there's scope for improvement or any mistakes you found, please comment down below!**

In [None]:
%%sh
# Install the third-party packages this notebook imports below:
# timm (model backbones) and wandb (experiment tracking, upgraded to the latest release).
# -q keeps pip's output quiet.
pip install -q timm
pip install -q wandb --upgrade

## 1. Imports and Utility functions

Some imports, utility functions and W&B login.

In [None]:
import os
import sys
import gc
import platform
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold

import timm
import cv2

import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import wandb

import warnings
warnings.simplefilter('ignore')

In [None]:
# W&B Login: the API key is read from the Kaggle secrets store.
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
wb_key = user_secrets.get_secret("WANDB_API_KEY")
wandb.login(key=wb_key)

# Hyperparameters and run metadata attached to the W&B run.
# NOTE(review): `autocast` is True here while the `Config` class that the
# training code actually reads sets it to False -- confirm which is intended.
CONFIG = {
    'lr': 1e-5,
    'autocast': True,
    'resize': (224, 224),
    'model_name': 'tf_efficientnet_b4',
    'pretrained': True,
    'epochs': 3,
    'scheduler': 'CosineAnnealingLR',
    'n_splits': 5,
    'split': 0.97,
    'folds': [1, 2, 3, 4, 5],
    'workers': 4,
    'train_bs': 64,
    'valid_bs': 64,
    'seed': 42,
    'num_labels': 1,
    'grad_acc_steps': 1,
    'max_gnorm': 1000,
    'architecture': "CNN",
    'infra': "Kaggle",
    'competition': 'g2net',
    '_wandb_kernel': 'tanaym',
}

# Start the tracked run; `run` is the handle returned by W&B.
run = wandb.init(
    project='g2net',
    config=CONFIG,
    group='effnet',
    job_type='train',
)

In [None]:
class Config:
    """
    Class-level training configuration.

    Values are read straight off the class (e.g. `Config.lr`); the code in
    this file never instantiates it.
    """

    # --- Model ---
    model_name = 'tf_efficientnet_b4'  # name handed to timm.create_model
    pretrained = True                  # handed to timm.create_model as `pretrained`
    num_labels = 1                     # output size of the replaced linear head
    resize = (224, 224)

    # --- Optimisation ---
    lr = 1e-5
    epochs = 3
    scheduler = 'CosineAnnealingLR'
    train_bs = 64
    valid_bs = 64
    grad_acc_steps = 1
    max_gnorm = 1000

    # --- Mixed precision ---
    # When `autocast` is True the trainer runs the forward pass under
    # torch.cuda.amp.autocast and steps the optimizer through `scaler`.
    # NOTE(review): the CONFIG dict logged to W&B records autocast=True,
    # which disagrees with the value used here -- confirm which is intended.
    autocast = False
    scaler = GradScaler()

    # --- Cross-validation ---
    n_splits = 5
    split = 0.97
    folds = [1, 2, 3, 4, 5]
    seed = 42
    target_name = 'target'

    # --- Runtime ---
    workers = 4   # DataLoader worker processes
    wandb = True  # log epoch metrics to W&B when True

    # --- Paths ---
    TRAIN_PATH = "../input/g2net-gravitational-wave-detection/train"
    TEST_PATH = "../input/g2net-gravitational-wave-detection/test"
    train_file = "../input/g2net-gravitational-wave-detection-file-paths/training_labels_with_paths.csv"
    sample_sub = "../input/g2net-gravitational-wave-detection/sample_submission.csv"

In [None]:
def wandb_log(**kwargs):
    """
    Logs the given key-value pairs to W&B in a single call.

    All keyword arguments are sent together so that metrics belonging to the
    same event (e.g. one epoch) share one W&B step. Logging each key with its
    own `wandb.log` call would advance the step once per key and spread the
    metrics of a single epoch over several steps.

    Args:
        **kwargs: metric name -> value pairs to log.
    """
    # Nothing to log: skip the call so no empty step is recorded.
    if not kwargs:
        return
    wandb.log(kwargs)

def get_train_file_path(image_id):
    """
    Builds the path of the training `.npy` file for `image_id`.

    Taken from Y.Nakama's notebook
    """
    return f"../input/g2net-n-mels-128-train-images/{image_id}.npy"

def get_test_file_path(image_id):
    """
    Builds the path of the test `.npy` file for `image_id`.

    Taken from Y.Nakama's notebook
    """
    return f"../input/g2net-n-mels-128-test-images/{image_id}.npy"

## 2. Model Architectures

The different model architectures you can choose from.

In [None]:
# Models
class VITModel(nn.Module):
    """
    timm backbone whose `head` layer is swapped for a fresh linear layer.

    The backbone is created from `Config.model_name` with a single input
    channel, and the new head outputs `Config.num_labels` values.
    """

    def __init__(self):
        super().__init__()
        net = timm.create_model(Config.model_name, pretrained=Config.pretrained, in_chans=1)
        # Remember the feature width feeding the original head, then replace it.
        self.n_f = net.head.in_features
        net.head = nn.Linear(self.n_f, Config.num_labels)
        self.backbone = net

    def forward(self, x):
        """Runs `x` through the backbone and returns its raw output."""
        return self.backbone(x)

class EffNetModel(nn.Module):
    """
    timm backbone whose `classifier` layer is swapped for a fresh linear layer.

    The backbone is created from `Config.model_name` with a single input
    channel, and the new classifier outputs `Config.num_labels` values.
    """

    def __init__(self):
        super().__init__()
        net = timm.create_model(Config.model_name, pretrained=Config.pretrained, in_chans=1)
        # Remember the feature width feeding the original classifier, then replace it.
        self.n_f = net.classifier.in_features
        net.classifier = nn.Linear(self.n_f, Config.num_labels)
        self.backbone = net

    def forward(self, x):
        """Runs `x` through the backbone and returns its raw output."""
        return self.backbone(x)
    
class ResNextModel(nn.Module):
    """
    timm backbone whose `fc` layer is swapped for a fresh linear layer.

    The backbone is created from `Config.model_name` with a single input
    channel, and the new `fc` layer outputs `Config.num_labels` values.
    """

    def __init__(self):
        super().__init__()
        net = timm.create_model(Config.model_name, pretrained=Config.pretrained, in_chans=1)
        # Remember the feature width feeding the original fc layer, then replace it.
        self.n_f = net.fc.in_features
        net.fc = nn.Linear(self.n_f, Config.num_labels)
        self.backbone = net

    def forward(self, x):
        """Runs `x` through the backbone and returns its raw output."""
        return self.backbone(x)

## 3. Custom Dataset Class

Now we define a custom dataset class that will load our data when training the model.

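The dataset does not build file names itself: the `get_train_file_path()` / `get_test_file_path()` helpers defined above take an id and return the corresponding `.npy` file path, and the dataset reads those paths from the `file_path` column of the dataframe.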

In [None]:
class G2NetData(Dataset):
    """
    Dataset that loads one `.npy` array per dataframe row.

    Args:
        data: dataframe with a 'file_path' column and, for labelled data,
            a 'target' column.
        is_test: when True the 'target' column is optional. If it is absent,
            items are returned as the image tensor alone instead of an
            (image, label) pair. When False the column is required and a
            missing column raises KeyError, as before.
    """

    def __init__(self, data, is_test=False):
        self.data = data
        self.is_test = is_test
        self.file_names = self.data['file_path'].values
        # Unlabelled test frames have no 'target' column; only tolerate that
        # when the caller explicitly asked for test mode.
        if is_test and 'target' not in self.data:
            self.labels = None
        else:
            self.labels = self.data['target'].values

    def __len__(self):
        """Number of rows in the underlying dataframe."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Loads the array for row `idx`.

        Returns:
            (image, label) as float tensors when labels are available,
            otherwise just the image tensor. The image gets a new leading
            axis in front of the loaded array's dimensions.
        """
        image = np.load(self.file_names[idx])
        image = torch.from_numpy(image[np.newaxis, :, :]).float()

        if self.labels is None:
            return image

        label = torch.tensor(self.labels[idx]).float()
        return image, label

## 4. Augments

Some basic Augments using Albumentations library.

I am currently not using them in the code but planning to do so in the future.

In [None]:
class Augments:
    """
    Albumentations pipelines for the training and validation splits.

    Each pipeline currently contains a single ToTensorV2 step.
    """
    # Pipeline intended for training samples
    train_augments = A.Compose([ToTensorV2(p=1.0)], p=1.0)

    # Pipeline intended for validation samples
    valid_augments = A.Compose([ToTensorV2(p=1.0)], p=1.0)

## 5. Trainer Class

This is the main trainer class. It has one function that trains the model for a single epoch and one that validates it for a single epoch.

In [None]:
class Trainer:
    """
    Runs single-epoch training and validation passes for a binary classifier.

    The model is expected to emit one logit per sample; the loss is
    BCE-with-logits and validation additionally reports ROC-AUC.
    """

    def __init__(self, model, optimizer, scheduler, train_dataloader, valid_dataloader, device):
        self.model = model
        self.optimizer = optimizer
        # Stored only; none of the methods below step the scheduler.
        self.scheduler = scheduler
        self.train_data = train_dataloader
        self.valid_data = valid_dataloader
        self.device = device
        self.loss_fn = self.yield_loss
        self.val_loss_fn = self.yield_loss

    def yield_loss(self, outputs, targets):
        """
        Computes the BCE-with-logits loss between `outputs` and `targets`
        """
        criterion = nn.BCEWithLogitsLoss()
        return criterion(outputs, targets)

    def _batch_to_device(self, batch):
        """
        Moves the (image, target) pair of a batch to the device as float tensors
        """
        images = batch[0].to(self.device, dtype=torch.float)
        labels = batch[1].to(self.device, dtype=torch.float)
        return images, labels

    def train_one_epoch(self):
        """
        Runs one optimisation pass over the training loader.

        Returns the mean of the per-batch losses.
        """
        progress = tqdm(enumerate(self.train_data), total=len(self.train_data))
        self.model.train()
        running_loss = 0
        for _, batch in progress:
            images, labels = self._batch_to_device(batch)

            if not Config.autocast:
                # Full-precision path
                logits = self.model(images).view(-1)
                loss = self.loss_fn(logits, labels)
                loss.backward()
                self.optimizer.step()
            else:
                # Mixed-precision path: scaled backward, step through the scaler
                with autocast():
                    logits = self.model(images).view(-1)
                    loss = self.loss_fn(logits, labels)
                amp_scaler = Config.scaler
                amp_scaler.scale(loss).backward()
                amp_scaler.step(self.optimizer)
                amp_scaler.update()

            self.optimizer.zero_grad()

            batch_loss = loss.item()
            progress.set_description('loss: {:.2f}'.format(batch_loss))
            running_loss += batch_loss

        return running_loss / len(self.train_data)

    def valid_one_epoch(self):
        """
        Evaluates the model on every validation batch without gradients.

        Returns a tuple (ROC-AUC over all samples, mean per-batch loss).
        """
        progress = tqdm(enumerate(self.valid_data), total=len(self.valid_data))
        self.model.eval()
        seen_targets, seen_probs = [], []
        running_loss = 0
        with torch.no_grad():
            for _, batch in progress:
                images, labels = self._batch_to_device(batch)

                logits = self.model(images).view(-1)
                val_loss = self.val_loss_fn(logits, labels)

                batch_loss = val_loss.item()
                progress.set_description('val_loss: {:.2f}'.format(batch_loss))

                # Collect targets and sigmoid probabilities for the epoch-level ROC-AUC
                seen_targets += labels.detach().cpu().numpy().tolist()
                seen_probs += torch.sigmoid(logits).detach().cpu().numpy().tolist()

                running_loss += batch_loss

        val_roc_auc = roc_auc_score(seen_targets, seen_probs)
        return val_roc_auc, running_loss / len(self.valid_data)

    def get_model(self):
        """Returns the wrapped model."""
        return self.model

## 6. Main Training Code

The code below combines everything into one big training and validation procedure.

In [None]:
# Training Code
if __name__ == '__main__':
    # Apply the configured seed so that the shuffle, the fold assignment,
    # the new head's initialisation and the batch order are repeatable.
    random.seed(Config.seed)
    np.random.seed(Config.seed)
    torch.manual_seed(Config.seed)

    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))
        DEVICE = torch.device('cuda')
    else:
        print("\n[INFO] GPU not found. Using CPU: {}\n".format(platform.processor()))
        DEVICE = torch.device('cpu')

    # Load the label file, attach the .npy path of every sample and shuffle the rows
    data = pd.read_csv(Config.train_file)
    data['file_path'] = data['id'].apply(get_train_file_path)
    data = data.sample(frac=1, random_state=Config.seed).reset_index(drop=True)

    # Stratified folds on the target column; seeded so folds are stable across runs
    kf = StratifiedKFold(n_splits=Config.n_splits, shuffle=True, random_state=Config.seed)

    print(f"Training Model: {Config.model_name}")
    for fold_, (train_idx, valid_idx) in enumerate(kf.split(data, data[Config.target_name])):
        print(f"{'='*20} Fold: {fold_} {'='*20}")

        # kf.split yields positional indices, so select rows by position
        train_data = data.iloc[train_idx].reset_index(drop=True)
        valid_data = data.iloc[valid_idx].reset_index(drop=True)

        print(f"Training on {train_data.shape[0]} samples and validating on {valid_data.shape[0]} samples")

        # Make Training and Validation Datasets
        training_set = G2NetData(
            data=train_data
        )

        validation_set = G2NetData(
            data=valid_data
        )

        train = DataLoader(
            training_set,
            batch_size=Config.train_bs,
            shuffle=True,
            num_workers=Config.workers,
            pin_memory=True
        )

        valid = DataLoader(
            validation_set,
            batch_size=Config.valid_bs,
            shuffle=False,
            num_workers=Config.workers
        )

        # A fresh model and optimizer for every fold; no LR scheduler is used
        model = EffNetModel().to(DEVICE)
        optimizer = torch.optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=1e-6)

        trainer = Trainer(model, optimizer, None, train, valid, DEVICE)

        # Do the training and validation
        for epoch in range(1, Config.epochs+1):
            print(f"\n{'--'*5} EPOCH: {epoch} {'--'*5}\n")

            # Train for 1 epoch
            tr_lss = trainer.train_one_epoch()

            # Validate for 1 epoch
            current_roc, vl_lss = trainer.valid_one_epoch()
            if Config.wandb:
                wandb_log(
                    epoch_train_loss=tr_lss,
                    epoch_valid_loss=vl_lss,
                    roc_auc_score=current_roc,
                )
            print(f"Validation ROC-AUC: {current_roc:.4f}")

            # Overwrites the fold's checkpoint after every epoch (latest weights, not best)
            torch.save(trainer.get_model().state_dict(), f"fold_{fold_}_{Config.model_name}.pt")

        # Drop per-fold objects and release cached GPU memory before the next fold
        del train_data, valid_data, training_set, validation_set, train, valid, model, optimizer, trainer
        gc.collect()
        torch.cuda.empty_cache()