## A few ideas inspired by https://www.kaggle.com/code/ttahara/hms-hbac-resnet34d-baseline-training

### Upvote if it helps

In [None]:
import os
import sys
import pandas as pd
import numpy as np
import random

from sklearn import model_selection

import tensorflow

import torch
from torch.utils.data import DataLoader,Dataset
import torch.nn as nn
from torch import optim
import torch.functional as F
from pathlib import Path
from tqdm import tqdm

import matplotlib.pyplot as plt

from torchvision import transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

import timm

# from tqdm import tqdm
from tqdm.auto import tqdm
import torchaudio


import warnings
warnings.filterwarnings('ignore')

In [None]:
# Every path hangs off the parent of the current working directory.
ROOT = Path.cwd().parent
INPUT, OUTPUT, SRC = (ROOT / sub for sub in ("input", "output", "src"))

DATA = INPUT / "hms-harmful-brain-activity-classification"
TRAIN_SPEC, TEST_SPEC = DATA / "train_spectrograms", DATA / "test_spectrograms"

# Scratch directories for the per-label spectrogram crops.
TMP = ROOT / "tmp"
TRAIN_SPEC_SPLIT, TEST_SPEC_SPLIT = (
    TMP / "train_spectrograms_split",
    TMP / "test_spectrograms_split",
)

# TMP must exist before its children; no parents=True, so a missing ROOT still fails.
for _scratch_dir in (TMP, TRAIN_SPEC_SPLIT, TEST_SPEC_SPLIT):
    _scratch_dir.mkdir(exist_ok=True)

In [None]:
class CFG:
    # --- labels -----------------------------------------------------------
    # Vote columns of train.csv; label tensors and the model's 6 outputs follow this order.
    CLASSES = ["seizure_vote", "lpd_vote", "gpd_vote", "lrda_vote", "grda_vote", "other_vote"]
    N_CLASSES = len(CLASSES)

    # --- runtime ----------------------------------------------------------
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    SEED = 1086
    NUM_WORKERS = 4

    # --- preprocessing ----------------------------------------------------
    # Each cached spectrogram crop is resized to 256x256 before reaching the CNN.
    image_transform = transforms.Resize((256, 256))
    # Not referenced by any active code in this notebook.
    t_transform = torchaudio.transforms.Spectrogram()
    # Added to the std in per-image standardisation to avoid division by zero.
    EPS = 1e-5

    # --- training ---------------------------------------------------------
    N_FOLDS = 5
    NUM_EPOCHS = 8
    BATCH_SIZE = 32
    PATIENCE = 3


cfg = CFG()

In [None]:
def seed_everything(seed):
    """Seed every RNG this notebook uses and make cuDNN deterministic.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU and all CUDA devices).
    """
    # Only affects Python interpreters started after this point (e.g. spawned
    # worker processes); the current process's hash seed is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # benchmark=True lets cuDNN auto-tune and possibly pick different algorithms
    # per run, which undermines deterministic=True. Both flags are needed for
    # reproducible convolutions.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Seed all RNGs once, up front, so the data split and model initialisation below are reproducible.
seed_everything(cfg.SEED)

In [None]:
train = pd.read_csv(DATA / "train.csv")
# Row-normalise the expert vote counts so the CLASSES columns of each row sum to 1.
train[cfg.CLASSES] = train[cfg.CLASSES].div(train[cfg.CLASSES].sum(axis=1), axis=0)
# Keep only the first labelled row of every spectrogram.
train = (
    train
    .groupby("spectrogram_id")
    .head(1)
    .reset_index(drop=True)
)
print(train.shape)

In [None]:
# train = train.sample(100).reset_index(drop=True)

In [None]:
# Cache one crop per label as TRAIN_SPEC_SPLIT/{label_id}.npy (float32, up to 300 columns wide).
for spectrogram_id, group in tqdm(train.groupby("spectrogram_id")):
    frame = pd.read_parquet(TRAIN_SPEC / f"{spectrogram_id}.parquet")
    # Zero-fill gaps, drop the first column (presumably the time axis -- verify),
    # and transpose so parquet rows become array columns.
    full_spec = frame.fillna(0).values[:, 1:].T.astype("float32")
    offsets_and_ids = group[["spectrogram_label_offset_seconds", "label_id"]].astype(int).values
    for offset_seconds, label_id in offsets_and_ids:
        # Seconds -> column index; assumes one parquet row per 2 s (TODO confirm).
        start = offset_seconds // 2
        np.save(TRAIN_SPEC_SPLIT / f"{label_id}.npy", full_spec[:, start:start + 300])

In [None]:
# Stratify on the expert-consensus label while keeping each patient inside a single fold.
sgkf = model_selection.StratifiedGroupKFold(
    n_splits=cfg.N_FOLDS,
    shuffle=True,
    random_state=cfg.SEED,
)

train["kfold"] = -1
fold_splits = sgkf.split(train, y=train["expert_consensus"], groups=train["patient_id"])
for fold_id, (_, val_idx) in enumerate(fold_splits):
    # Tag each row with the fold in which it is used for validation.
    train.loc[val_idx, "kfold"] = fold_id

In [None]:
class HMSDataset(Dataset):
    """Serve cached spectrogram crops together with their soft vote-probability labels.

    ``df`` must contain ``label_id`` and every column in ``cfg.CLASSES``. Rows are
    looked up by position, so ``df`` should have a 0..n-1 index (callers use
    ``reset_index(drop=True)``).
    """

    def __init__(self,df):
        self.df = df
        # Extract the needed columns once. Slicing ``self.df`` inside
        # __getitem__ materialised the whole label matrix on every call:
        # O(len(df)) work per item, O(len(df)^2) per epoch.
        self._label_ids = df["label_id"].to_numpy()
        self._labels = df[cfg.CLASSES].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self,idx:int):
        """Return ``(image, label)`` for row ``idx``.

        image: float32 tensor of shape (1, 256, 256). It is the cached crop,
            standardised to zero mean / unit std, with NaNs zeroed, then resized
            by ``cfg.image_transform``.
        label: float32 tensor with one probability per ``cfg.CLASSES`` column.
        """
        img = np.load(TRAIN_SPEC_SPLIT / f"{self._label_ids[idx]}.npy")
        # Per-image standardisation; EPS keeps constant crops from dividing by zero.
        data_mean = img.mean(axis=(0, 1))
        data_std = img.std(axis=(0, 1))
        img = (img - data_mean) / (data_std + cfg.EPS)
        img = np.nan_to_num(img, nan=0.0)
        # Add the channel dimension expected by Conv2d(1, ...).
        data_tensor = torch.from_numpy(img).float().unsqueeze(0)
        # image_transform already returns a fresh tensor. Re-wrapping it with
        # torch.tensor() only added a copy and a UserWarning.
        img = cfg.image_transform(data_tensor)
        # torch.tensor copies, so callers cannot mutate the cached label array.
        label = torch.tensor(self._labels[idx])
        return img, label

In [None]:
class HMSCnn(nn.Module):
    """Small three-block CNN mapping a (1, 256, 256) spectrogram image to 6 logits.

    Each block is a 3x3 conv (padding 1, keeps H/W), then ReLU, then a 2x2 max-pool
    (halves H/W). A 256x256 input therefore reaches the head as 128 x 32 x 32 features.
    """

    def __init__(self):
        super(HMSCnn, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(128 * 32 * 32, 256)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 6)

    def forward(self, x):
        """Return raw logits of shape (batch, 6) for ``x`` of shape (batch, 1, 256, 256)."""
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))
        # Flatten per sample, keeping the batch dimension. With view(-1, 128*32*32),
        # a non-256x256 input could be silently regrouped into a different batch
        # size. Now any size mismatch fails loudly in fc1.
        x = torch.flatten(x, 1)
        x = self.relu4(self.fc1(x))
        # NOTE(review): there is no activation between fc2 and fc3, so the two
        # layers compose into a single linear map; confirm this is intended.
        x = self.fc2(x)
        x = self.fc3(x)
        return x

In [None]:
# Build the network and move it onto the configured device (Module.to returns the module).
model = HMSCnn().to(cfg.DEVICE)

In [None]:
class KLDivLossWithLogits(nn.KLDivLoss):
    """KL divergence between target probabilities ``t`` and ``softmax(y)``.

    ``nn.KLDivLoss`` expects log-probabilities as its input. This subclass
    accepts raw logits ``y`` (softmax taken over dim 1) and averages the
    summed divergence over the batch (``reduction="batchmean"``).
    """

    def __init__(self):
        super().__init__(reduction="batchmean")

    def forward(self, y, t):
        log_probs = y.log_softmax(dim=1)
        return super().forward(log_probs, t)

In [None]:
# Soft-label loss: KL divergence computed from logits (an MSE criterion was tried earlier).
criterion = KLDivLossWithLogits()
# Plain SGD at lr=0.01 with torch's default (zero) momentum and weight decay.
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss, checkpointing ``model`` on improvement.

        Sets ``self.early_stop`` once ``patience`` consecutive calls have failed to
        beat the best score by at least ``delta``.
        """
        # Higher score is better, so negate the loss.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
def train_fold(model,train_loader,valid_loader,patience,n_epochs,criterion,optimizer):
    """Train ``model`` for up to ``n_epochs`` epochs with early stopping on validation loss.

    Args:
        model: network to train (already on ``cfg.DEVICE``); updated in place.
        train_loader / valid_loader: loaders yielding ``(data, target)`` batches.
        patience: epochs without validation improvement before stopping.
        n_epochs: maximum number of epochs.
        criterion: loss taking ``(output, target)``.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``(model, avg_train_losses, avg_valid_losses)``. ``model`` has the weights of
        the best checkpoint saved by EarlyStopping. The two lists hold the per-epoch
        mean batch losses.
    """
    avg_train_losses = []
    avg_valid_losses = []

    early_stopping = EarlyStopping(patience=patience, verbose=True)
    epoch_len = len(str(n_epochs))

    # Honour the n_epochs argument; this loop previously read cfg.NUM_EPOCHS and ignored it.
    for epoch in tqdm(range(1, n_epochs + 1)):
        model.train()
        train_losses = []
        for data, target in train_loader:
            data, target = data.to(cfg.DEVICE), target.to(cfg.DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        model.eval()
        valid_losses = []
        # No autograd graph is needed for evaluation; building one only wastes memory and time.
        with torch.no_grad():
            for data, target in valid_loader:
                data, target = data.to(cfg.DEVICE), target.to(cfg.DEVICE)
                valid_losses.append(criterion(model(data), target).item())

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')
        print(print_msg)

        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Reload the best weights from wherever EarlyStopping actually wrote them.
    model.load_state_dict(torch.load(early_stopping.path, map_location=cfg.DEVICE))

    return model, avg_train_losses, avg_valid_losses

In [None]:
def plot_loss(fold,train_loss,valid_loss):
    """Plot per-epoch training/validation loss for one fold and save it as a PNG.

    Args:
        fold: fold index, used in the title and the output filename.
        train_loss, valid_loss: sequences of per-epoch average losses.

    Writes ``loss_plot_fold_{fold}.png`` to the working directory. Previously
    every fold overwrote ``loss_plot.png``. A dashed red line marks the first
    epoch with the lowest validation loss.
    """
    fig = plt.figure(figsize=(5,5))

    plt.plot(range(1, len(train_loss) + 1), train_loss, label='Training Loss')
    plt.plot(range(1, len(valid_loss) + 1), valid_loss, label='Validation Loss')

    # Guard: min() on an empty sequence raises ValueError.
    if len(valid_loss) > 0:
        # Epochs are 1-based on the x axis.
        minposs = list(valid_loss).index(min(valid_loss)) + 1
        plt.axvline(minposs, linestyle='--', color='r', label='Early Stopping Checkpoint')

    plt.title(f"Plot for fold={fold}")
    plt.xlabel('epochs')
    plt.ylabel('loss')
    # NOTE(review): a fixed 0-0.5 y range keeps folds comparable but clips
    # curves whose loss is above 0.5.
    plt.ylim(0, 0.5)
    plt.xlim(0, len(train_loss) + 1)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    # Save before show(), and close afterwards so figures do not pile up across folds.
    fig.savefig(f'loss_plot_fold_{fold}.png', bbox_inches='tight')
    plt.show()
    plt.close(fig)

In [None]:
def train_on_folds(fold):
    """Train a freshly initialised model with ``fold`` held out, then plot and save it.

    Rows with ``kfold != fold`` are used for training and rows with
    ``kfold == fold`` for validation. The best weights are saved to
    ``hms_model_fold_{fold}.bin``.
    """
    train_df = train.query("kfold!=@fold").reset_index(drop=True)
    val_df = train.query("kfold==@fold").reset_index(drop=True)

    trainset = HMSDataset(train_df)
    train_loader = DataLoader(trainset, batch_size=cfg.BATCH_SIZE, num_workers=cfg.NUM_WORKERS, shuffle=True)

    valset = HMSDataset(val_df)
    val_loader = DataLoader(valset, batch_size=cfg.BATCH_SIZE, num_workers=cfg.NUM_WORKERS, shuffle=False)

    # Re-initialise the network and optimizer for every fold. Reusing the global
    # model would start this fold from weights already trained on its validation
    # rows (they were training rows in earlier folds), giving leaked, optimistic
    # CV scores. The optimizer is rebuilt with the global optimizer's class and
    # hyper-parameters so the configuration stays defined in one place.
    fold_model = HMSCnn().to(cfg.DEVICE)
    fold_optimizer = type(optimizer)(fold_model.parameters(), **optimizer.defaults)

    m, atl, avl = train_fold(fold_model,
                             train_loader,
                             val_loader,
                             cfg.PATIENCE,
                             cfg.NUM_EPOCHS,
                             criterion,
                             fold_optimizer)

    plot_loss(fold, atl, avl)

    torch.save(m.state_dict(), f"hms_model_fold_{fold}.bin")

In [None]:
# Train and evaluate one model per cross-validation fold. The loop follows cfg.N_FOLDS,
# so changing the fold count in CFG no longer requires editing this cell.
for fold in range(cfg.N_FOLDS):
    train_on_folds(fold)