![](https://storage.googleapis.com/kaggle-competitions/kaggle/23249/logos/header.png)

# Motivation

Motivation: I wanted to implement a notebook that uses a loss function which optimizes AUC directly (ROC-Star).

Any feedback is appreciated!
Be sure to also check out:
* Inference: tbd

# Config

In [None]:
# ====================================================
# Configurations
# ====================================================
import os
class CFG:
    """Global notebook configuration.

    Every knob read by the training / inference cells lives here as a class
    attribute. The run cell later attaches runtime state (model, criterion,
    gamma, last-epoch buffers) to this same class.
    """

    # ---- run mode -------------------------------------------------------
    DEBUG = True

    # ---- model ----------------------------------------------------------
    device = 'GPU'  # one of ['CPU', 'GPU', 'TPU']
    N_FOLDS = 5
    # Recommended: ['deit_base_patch16_384', 'vit_large_patch16_384',
    #               'tf_efficientnet_b4_ns', 'resnext50_32x4d']
    MODEL_NAME = 'tf_efficientnet_b1_ns'
    pretrained = True
    EPOCHS = 3 if DEBUG else 5  # more epochs are definitely plausible
    TRAIN_FOLDS = [0] if DEBUG else list(range(N_FOLDS))  # folds to be trained
    N_CLASSES = 1
    in_channels = 1

    # ---- scheduler / loss / optimizer -----------------------------------
    # one of ['ReduceLROnPlateau', 'CosineAnnealingLR',
    #         'CosineAnnealingWarmRestarts', 'OneCycleLR',
    #         'GradualWarmupSchedulerV2']
    scheduler_name = 'CosineAnnealingWarmRestarts'
    scheduler_update = 'epoch'  # one of ['batch', 'epoch']
    criterion_name = 'ROC-Star'  # one of ['BCEWithLogitsLoss', 'ROC-Star']
    # one of ['Adam', 'AdamW', 'AdamP', 'Ranger'] -- AdamP doesn't work on TPUs
    optimizer_name = 'AdamP'
    LR_RAMPUP_EPOCHS = 1
    LR_SUSTAIN_EPOCHS = 0

    # ---- optional head-only fine-tuning ---------------------------------
    FREEZE = False  # if True, train only the classifier from START_FREEZE on
    START_FREEZE = 8

    # ---- optimisation ---------------------------------------------------
    BATCH_SIZE = 512
    LR = 2e-4
    LR_START = 1e-4
    LR_MIN = 8e-7
    weight_decay = 0
    eps = 1e-8
    PATIENCE = 2

    T_0 = EPOCHS    # period for CosineAnnealingWarmRestarts
    T_max = EPOCHS  # period for CosineAnnealingLR

    NUM_WORKERS = 4

    model_print = False  # print the model architecture
    tqdm = True          # show the training progress bar

    n_procs = 8  # number of TPU replicas; 1 lets a TPU run like a GPU
    SEED = 42
    saved_models = {}  # fold -> rounded best validation score

In [None]:
!pip install timm

if CFG.optimizer_name == 'Ranger':
    !pip install --quiet '../input/pytorch-ranger'
elif CFG.optimizer_name == 'AdamP':
    !pip install adamp
    
!pip install -q nnAudio

# Library

In [None]:
# ====================================================
# Library
# ====================================================

import random
import math
import time

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.pyplot import imread
import numpy as np
import cv2
from sklearn.model_selection import GroupKFold, StratifiedKFold,KFold
from sklearn.metrics import accuracy_score,roc_auc_score

import torch
import torch.nn as nn
import torchvision
from torchvision import models as tvmodels
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn.functional as F
from tqdm import tqdm
import timm



import albumentations as A
from albumentations import Compose
from albumentations.pytorch import ToTensorV2
import seaborn as sns


from nnAudio.Spectrogram import CQT
from nnAudio.Spectrogram import CQT1992v2

from PIL import Image, ImageOps, ImageEnhance, ImageChops

    
if CFG.scheduler_name == 'GradualWarmupSchedulerV2':
    from warmup_scheduler import GradualWarmupScheduler

if CFG.optimizer_name == 'AdamP':
    from adamp import AdamP

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
import librosa

# Utils

In [None]:
# ====================================================
# Utils
# ====================================================
def get_score(y_true, y_pred):
    """Return the ROC-AUC of scores `y_pred` against binary labels `y_true`."""
    return roc_auc_score(y_true, y_pred)

def seed_everything(seed):
    """Seed every RNG the notebook uses so that runs are reproducible.

    Seeds Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices), and
    puts cuDNN into deterministic mode. Also sets PYTHONHASHSEED. That only
    affects Python interpreters started after this call, because hash
    randomisation is fixed at interpreter start.

    Args:
        seed: integer seed.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick (possibly
    # non-deterministic) algorithms, which defeats `deterministic` above.
    torch.backends.cudnn.benchmark = False
    
# Only defined when the warm-up scheduler is selected: the base class comes from
# the optional `warmup_scheduler` package, imported under the same condition.
if CFG.scheduler_name == 'GradualWarmupSchedulerV2':
    class GradualWarmupSchedulerV2(GradualWarmupScheduler):
        """Linear LR warm-up followed by CosineAnnealingWarmRestarts.

        By default the LR ramps over CFG.LR_RAMPUP_EPOCHS epochs by a factor of
        multiplier = CFG.LR / CFG.LR_START. After that, the cosine scheduler
        built in __init__ takes over.
        NOTE(review): the ramp ends at CFG.LR only if the optimizer was created
        with lr=CFG.LR_START. GetOptimizer does this for Adam/AdamW/AdamP.
        """
        def __init__(self, optimizer = None, multiplier = CFG.LR/CFG.LR_START, total_epoch = CFG.LR_RAMPUP_EPOCHS, after_scheduler=None):
            super(GradualWarmupSchedulerV2, self).__init__(optimizer, multiplier, total_epoch, after_scheduler)
            # Overrides any `after_scheduler` argument. The restart period is
            # shortened by the warm-up length so both phases fit in CFG.T_0 epochs.
            self.after_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0 = CFG.T_0 - CFG.LR_RAMPUP_EPOCHS, T_mult=1, eta_min=CFG.LR_MIN, last_epoch=-1)
        def get_lr(self):
            """Return the learning rates for the current `self.last_epoch`.

            `total_epoch`, `multiplier`, `finished` and `base_lrs` are state
            kept by the GradualWarmupScheduler / torch scheduler base classes.
            """
            if self.last_epoch > self.total_epoch:
                # Warm-up is over. Always truthy here, since __init__ sets it.
                if self.after_scheduler:
                    if not self.finished:
                        # First call after warm-up: start the cosine schedule
                        # from the warmed-up learning rates.
                        self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                        self.finished = True
                    return self.after_scheduler.get_lr()
                return [base_lr * self.multiplier for base_lr in self.base_lrs]
            if self.multiplier == 1.0:
                # Ramp linearly from 0 to base_lr.
                return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
            else:
                # Ramp linearly from base_lr (epoch 0) to base_lr * multiplier
                # (epoch total_epoch).
                return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
  
    
def GetScheduler(scheduler_name,optimizer,batches):
    """Build the learning-rate scheduler named `scheduler_name`, configured from CFG.

    Args:
        scheduler_name: one of 'OneCycleLR', 'CosineAnnealingWarmRestarts',
            'CosineAnnealingLR', 'ReduceLROnPlateau', 'GradualWarmupSchedulerV2'.
        optimizer: the optimizer whose learning rate is scheduled.
        batches: number of training batches per epoch (used by OneCycleLR).

    Returns:
        The scheduler instance.

    Raises:
        ValueError: if `scheduler_name` is not one of the supported names.
            Failing here is clearer than returning None and crashing later at
            `scheduler.step()`.
    """
    if scheduler_name == 'OneCycleLR':
        # NOTE(review): the +1 presumably gives headroom, because OneCycleLR
        # raises once stepped more than epochs * steps_per_epoch times.
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-2, epochs=CFG.EPOCHS,
            steps_per_epoch=batches + 1, pct_start=0.1)
    elif scheduler_name == 'CosineAnnealingWarmRestarts':
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.LR_MIN, last_epoch=-1)
    elif scheduler_name == 'CosineAnnealingLR':
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=CFG.T_max, eta_min=0, last_epoch=-1)
    elif scheduler_name == 'ReduceLROnPlateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.1, patience=1, threshold=0.0001, cooldown=0,
            min_lr=CFG.LR_MIN, eps=CFG.eps)
    elif scheduler_name == 'GradualWarmupSchedulerV2':
        return GradualWarmupSchedulerV2(optimizer=optimizer)
    raise ValueError(f"Unknown scheduler_name: {scheduler_name!r}")
    
def GetOptimizer(optimizer_name,parameters):
    """Build the optimizer named `optimizer_name` over `parameters`, configured from CFG.

    With the 'GradualWarmupSchedulerV2' scheduler, Adam/AdamW/AdamP start at
    CFG.LR_START and the scheduler ramps the LR up. Otherwise they start at
    CFG.LR. Ranger always starts at CFG.LR.

    Args:
        optimizer_name: one of 'Adam', 'AdamW', 'AdamP', 'Ranger'.
        parameters: iterable of parameters to optimize.

    Returns:
        The optimizer instance.

    Raises:
        ValueError: if `optimizer_name` is not one of the supported names.
    """
    if optimizer_name not in ('Adam', 'AdamW', 'AdamP', 'Ranger'):
        raise ValueError(f"Unknown optimizer_name: {optimizer_name!r}")
    if optimizer_name == 'Ranger':
        # NOTE(review): `Ranger` is only pip-installed in this notebook and is
        # never imported in the Library cell; it must be brought into scope.
        return Ranger(parameters, lr=CFG.LR, alpha=0.5, k=6, N_sma_threshhold=5,
                      betas=(0.95, 0.999), eps=CFG.eps, weight_decay=CFG.weight_decay)
    # The warm-up scheduler expects to ramp from LR_START towards LR.
    lr = CFG.LR_START if CFG.scheduler_name == 'GradualWarmupSchedulerV2' else CFG.LR
    if optimizer_name == 'Adam':
        return torch.optim.Adam(parameters, lr=lr, weight_decay=CFG.weight_decay, amsgrad=False)
    if optimizer_name == 'AdamW':
        # Must be AdamW (decoupled weight decay) on both branches, not Adam.
        return torch.optim.AdamW(parameters, lr=lr, weight_decay=CFG.weight_decay, amsgrad=False)
    return AdamP(parameters, lr=lr, weight_decay=CFG.weight_decay)

def print_scheduler(scheduler = None,scheduler_update = CFG.scheduler_update,optimizer = None, batches = -1, epochs = -1, model = None):
    """Plot the learning-rate curve produced by `scheduler`.

    With `scheduler_update == 'epoch'` the scheduler is stepped once per epoch
    (passing the epoch index). With `'batch'` it is stepped `batches` times per
    epoch. After each step the LR of the optimizer's first param group is
    recorded, and the curve is drawn in a new figure. Any other
    `scheduler_update` value does nothing. `model` is unused.

    NOTE: this advances `scheduler`'s state, so plot with a throwaway
    optimizer/scheduler pair.
    """
    def current_lr():
        return optimizer.param_groups[0]["lr"]

    if scheduler_update not in ('epoch', 'batch'):
        return

    history = []
    if scheduler_update == 'epoch':
        for ep in range(epochs):
            scheduler.step(ep)
            history.append(current_lr())
    else:
        for _ in range(epochs):
            for _ in range(batches):
                scheduler.step()
                history.append(current_lr())

    plt.figure(figsize=(15,4))
    plt.plot(history)
    
# Seed every RNG once, then pick the compute device for the whole notebook.
SEED = CFG.SEED
seed_everything(SEED)
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Transforms

In [None]:
# credit https://www.kaggle.com/yasufuminakama/g2net-efficientnet-b7-baseline-training
# ====================================================
# Transforms
# ====================================================
def get_transforms(*, data):
    """Return the albumentations pipeline for a data split.

    'train' and 'valid' currently share the same pipeline: a plain conversion
    to a torch tensor, with no augmentation. Any other `data` value yields None.
    """
    if data not in ('train', 'valid'):
        return None
    return A.Compose([
        ToTensorV2(),
    ])

# Datasets

In [None]:
# ====================================================
# Datasets
# ====================================================
def retrieve_df(df,name,idx):
    """Return column `name` of `df` at positional rows `idx`, re-indexed from 0."""
    return df[name].iloc[idx].reset_index(drop=True)

class G2NetDataset(torch.utils.data.Dataset):
    """Dataset of G2Net samples stored as `.npy` files.

    Args:
        features: sample ids (stored, not read by this class).
        target: label Series. It is read with `.loc[i]`, so it is expected to
            have a 0..N-1 index.
        is_test: stored only; not used by this class.
        file_names: per-sample `.npy` paths, indexed like `target`.
        transform: if True, each file holds raw waves that are turned into a
            CQT image on the fly. If False, each file already holds a 2-D
            image and only gets a leading channel axis.

    Items are `(image_tensor, target_as_float_tensor)`.
    """

    def __init__(self, features, target, is_test=False, file_names=None, transform=True):
        # `None` instead of a mutable `[]` default shared across instances.
        if file_names is None:
            file_names = []
        self.features = features
        self.target = target
        self.is_test = is_test
        self.file_names = file_names
        self.transform = transform

        self.wave_transform = CQT1992v2(sr=2048, fmin=20, fmax=1024, hop_length=64)

        self.image_transform = A.Compose([
            ToTensorV2(),
        ])

    def __getitem__(self, i):
        #adapted from https://www.kaggle.com/yasufuminakama/g2net-efficientnet-b0-baseline-training
        file_path = self.file_names[i]
        if not self.transform:
            # Precomputed 2-D image: add a channel axis.
            image = np.load(file_path)
            image = image[np.newaxis,:,:]
            image = torch.from_numpy(image).float()
        else:
            waves = np.load(file_path)
            image = self.apply_qtransform(waves, self.wave_transform)
            image = image.squeeze().numpy()
            image = self.image_transform(image=image)['image']
        tgt = self.target.loc[i]
        return (image, torch.tensor(tgt, dtype=torch.float))

    def apply_qtransform(self, waves, transform):
        """Concatenate the wave rows end to end, scale by their max, apply `transform`.

        NOTE(review): `waves` presumably has one row per detector. Scaling uses
        `np.max`, not the max absolute value.
        """
        waves = np.hstack(waves)
        waves = waves / np.max(waves)
        waves = torch.from_numpy(waves).float()
        image = transform(waves)
        return image

    def __len__(self): return len(self.target)


# CV Split

In [None]:
# ====================================================
# CV Split
# ====================================================
#adapted from https://www.kaggle.com/yasufuminakama/g2net-efficientnet-b0-baseline-training
# Labels for training and the submission template for the test set.
train_df = pd.read_csv("../input/g2net-gravitational-wave-detection/training_labels.csv")
test_df = pd.read_csv('../input/g2net-gravitational-wave-detection/sample_submission.csv')

# True: raw competition `.npy` files, nested by the first three id characters.
# False: precomputed images from the `g2net-n-mels-128-*` datasets.
USE_RAW_WAVES = True
if USE_RAW_WAVES:
    def get_train_file_path(image_id):
        return "../input/g2net-gravitational-wave-detection/train/{}/{}/{}/{}.npy".format(
            image_id[0], image_id[1], image_id[2], image_id)

    def get_test_file_path(image_id):
        return "../input/g2net-gravitational-wave-detection/test/{}/{}/{}/{}.npy".format(
            image_id[0], image_id[1], image_id[2], image_id)
else:
    def get_train_file_path(image_id):
        return "../input/g2net-n-mels-128-train-images/{}.npy".format(image_id)

    def get_test_file_path(image_id):
        return "../input/g2net-n-mels-128-test-images/{}.npy".format(image_id)

train_df['file_path'] = train_df['id'].apply(get_train_file_path)
test_df['file_path'] = test_df['id'].apply(get_test_file_path)

#modified from https://www.kaggle.com/thedrcat/g2net-fastai-resnet34-starter-mel
if CFG.DEBUG:
    # Keep a 1% sample for quick iterations.
    train_df = train_df.sample(frac=0.01).reset_index(drop=True)
skf = StratifiedKFold(n_splits=CFG.N_FOLDS, shuffle=True, random_state=CFG.SEED)
# One (train_positions, valid_positions) pair per fold.
folds = list(skf.split(np.arange(train_df.shape[0]), train_df['target']))

sns.countplot(data=train_df, x="target")

# Model

In [None]:
# ====================================================
# Model
# ====================================================
class G2Net(nn.Module):
    """Pretrained backbone whose classification head is replaced by a CFG.N_CLASSES linear layer.

    Args:
        model_name: backbone name. The two DeiT names are loaded from
            torch.hub ('facebookresearch/deit:main'); others come from timm.
        pretrained: load pretrained weights.
        in_chans: number of input image channels.
    """

    # Backbones whose classifier lives in `self.model.head`.
    _HEAD_MODELS = ('vit_large_patch16_384', 'deit_base_patch16_224', 'deit_base_patch16_384')
    # Backbones loaded from torch.hub instead of timm.
    _HUB_MODELS = ('deit_base_patch16_224', 'deit_base_patch16_384')

    def __init__(self, model_name=CFG.MODEL_NAME, pretrained=CFG.pretrained,in_chans = CFG.in_channels):
        super().__init__()
        self.model_name = model_name
        if model_name in self._HUB_MODELS:
            self.model = torch.hub.load('facebookresearch/deit:main', model_name, pretrained=pretrained, in_chans=in_chans)
        else:
            self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=in_chans)
        head_attr = self._head_attr()
        if head_attr is not None:
            # Swap the original head for a fresh linear layer with N_CLASSES outputs.
            self.n_features = getattr(self.model, head_attr).in_features
            setattr(self.model, head_attr, nn.Linear(self.n_features, CFG.N_CLASSES))

    def _head_attr(self):
        """Name of the attribute of `self.model` holding the classifier, or None if unknown."""
        if 'efficientnet' in self.model_name:
            return 'classifier'
        if self.model_name in self._HEAD_MODELS:
            return 'head'
        if 'resnext' in self.model_name:
            return 'fc'
        return None

    def forward(self, x):
        return self.model(x)

    def freeze(self):
        """Freeze the backbone so that only the classifier head stays trainable."""
        for param in self.model.parameters():
            param.requires_grad = False
        # Shared lookup with __init__. The previous `== 'vit...' or 'deit...'`
        # test was always true and never reached the resnext `fc` branch.
        head_attr = self._head_attr()
        if head_attr is not None:
            for param in getattr(self.model, head_attr).parameters():
                param.requires_grad = True

    def unfreeze(self):
        """Make every parameter trainable again."""
        for param in self.model.parameters():
            param.requires_grad = True

model = G2Net()

# ROC-Star loss

In [None]:
# Adapted ROC-Star
def epoch_update_gamma(y_true,y_pred, epoch=-1,delta=1):
    """Compute the ROC-Star margin `gamma` from last epoch's targets and predictions.

    Gamma is updated at the end of each epoch. Each class is randomly
    subsampled to roughly SUB_SAMPLE_SIZE elements and every positive/negative
    pair is formed. Gamma is then picked from the sorted score gaps of the
    correctly ordered pairs, so that roughly (delta + 1) times as many of those
    pairs as there are mis-ordered pairs have a gap below gamma.

    Args:
        y_true: `Tensor` of targets (labels), each 0.0 or 1.0.
        y_pred: `Tensor` of predictions, same length as `y_true`.
        epoch: current epoch. For epoch <= -1 (the default) the default gamma
            is returned.
        delta: margin widening factor (the code uses DELTA = delta + 1).

    Returns:
        0-dim tensor gamma. It falls back to 0.2 (on `y_pred`'s device) when
        epoch <= -1, when either class is absent, or when no pair is
        correctly ordered.
    """
    DELTA = delta + 1
    SUB_SAMPLE_SIZE = 2000.0
    # Built on y_pred's device rather than via a hard-coded `.cuda()`,
    # so CPU-only runs work.
    default_gamma = torch.tensor(0.2, dtype=torch.float, device=y_pred.device)
    if epoch <= -1:
        return default_gamma

    pos = y_pred[y_true == 1]
    neg = y_pred[y_true == 0]
    cap_pos = pos.shape[0]
    cap_neg = neg.shape[0]
    if cap_pos == 0 or cap_neg == 0:
        # No positive/negative pairs exist (and the ratios below would divide by 0).
        return default_gamma

    # Subsample each class to bound the O(n_pos * n_neg) pairwise work.
    pos = pos[torch.rand_like(pos) < SUB_SAMPLE_SIZE / cap_pos]
    neg = neg[torch.rand_like(neg) < SUB_SAMPLE_SIZE / cap_neg]
    ln_pos = pos.shape[0]
    ln_neg = neg.shape[0]

    # All (pos, neg) pairs: diff = neg - pos.
    pos_expand = pos.view(-1, 1).expand(-1, ln_neg).reshape(-1)
    neg_expand = neg.repeat(ln_pos)
    diff = neg_expand - pos_expand

    # diff > 0: mis-ordered pair (negative scored above the positive).
    ln_Lp = diff[diff > 0].shape[0] - 1
    # Gaps (pos - neg) of correctly ordered pairs, ascending.
    diff_neg = (-1.0 * diff[diff < 0]).sort()[0]
    if diff_neg.shape[0] == 0:
        return default_gamma

    left_wing = int(ln_Lp * DELTA)
    left_wing = max(0, left_wing)
    left_wing = min(diff_neg.shape[0] - 1, left_wing)
    return diff_neg[left_wing]



def _subsample(values, max_count):
    """Keep each element with probability max_count / len(values) (all of them if fewer)."""
    if values.shape[0] == 0:
        return values
    return values[torch.rand_like(values) < max_count / values.shape[0]]


def roc_star_loss( _y_true, y_pred, gamma, _epoch_true, epoch_pred):
    """Nearly direct loss function for AUC (ROC-Star).

    See C. Reiss, "Roc-star : An objective function for ROC-AUC that actually
    works." https://github.com/iridiumblue/articles/blob/master/roc_star.md

    Each positive in the batch is compared with a random subsample of the
    last epoch's negatives, and each negative in the batch with a subsample
    of the last epoch's positives. Every pair whose `neg - pos + gamma` is
    positive contributes that value squared.

    Args:
        _y_true: `Tensor` of batch targets (labels), each 0.0 or 1.0.
        y_pred: `Tensor` of batch predictions.
        gamma: `Float` (or 0-dim tensor) margin, derived from the last epoch.
        _epoch_true: `Tensor` of targets (labels) from the last epoch.
        epoch_pred: `Tensor` of predictions from the last epoch.

    Returns:
        Scalar loss tensor, with NaN replaced by 0. If the batch holds a single
        class, the tiny stub `sum(y_pred) * 1e-8` is returned instead.
    """
    # convert labels to boolean
    y_true = _y_true >= 0.50
    epoch_true = _epoch_true >= 0.50

    # Batch is all one class: no pairs to rank, return a tiny stub value.
    n_true = torch.sum(y_true)
    if n_true == 0 or n_true == y_true.shape[0]:
        return torch.sum(y_pred) * 1e-8

    pos = y_pred[y_true]
    neg = y_pred[~y_true]

    # Random subsamples of last epoch, each class scaled by its own count.
    # (Negatives used to be scaled by the positive count.) An empty class
    # yields an empty subsample instead of a ZeroDivisionError.
    max_pos = 1000  # target number of last-epoch positive samples
    max_neg = 1000  # target number of last-epoch negative samples
    epoch_pos = _subsample(epoch_pred[epoch_true], max_pos)
    epoch_neg = _subsample(epoch_pred[~epoch_true], max_neg)

    ln_pos = pos.shape[0]
    ln_neg = neg.shape[0]

    # Batch positives against subsampled last-epoch negatives.
    if ln_pos > 0:
        pos_expand = pos.view(-1, 1).expand(-1, epoch_neg.shape[0]).reshape(-1)
        neg_expand = epoch_neg.repeat(ln_pos)
        diff2 = neg_expand - pos_expand + gamma
        l2 = diff2[diff2 > 0]
        m2 = l2 * l2
    else:
        m2 = torch.tensor([0], dtype=torch.float, device=y_pred.device)

    # Batch negatives against subsampled last-epoch positives.
    if ln_neg > 0:
        pos_expand = epoch_pos.view(-1, 1).expand(-1, ln_neg).reshape(-1)
        neg_expand = neg.repeat(epoch_pos.shape[0])
        diff3 = neg_expand - pos_expand + gamma
        l3 = diff3[diff3 > 0]
        m3 = l3 * l3
    else:
        m3 = torch.tensor([0], dtype=torch.float, device=y_pred.device)

    total = torch.sum(m2) + torch.sum(m3)
    if total != 0:
        res2 = torch.sum(m2) / max_pos + torch.sum(m3) / max_neg
    else:
        res2 = total

    res2 = torch.where(torch.isnan(res2), torch.zeros_like(res2), res2)

    return res2
    
class ROC_Star(nn.Module):
    """`nn.Module` wrapper around `roc_star_loss`, driven by global CFG state.

    Uses `CFG.gamma` and the `i`-th stored batch of the previous epoch
    (`CFG.last_epoch_true[i]`, `CFG.last_epoch_pred[i]`, NumPy arrays) as the
    comparison set.
    NOTE(review): the train loader shuffles, so that stored batch is an
    arbitrary batch-sized sample rather than the whole previous epoch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred,_y_true,i):
        # Follow y_pred's device instead of hard-coding `.cuda()`.
        device = y_pred.device
        epoch_true = torch.from_numpy(CFG.last_epoch_true[i]).to(device)
        epoch_pred = torch.from_numpy(CFG.last_epoch_pred[i]).to(device)
        return roc_star_loss(_y_true, y_pred, CFG.gamma, epoch_true, epoch_pred)

# Training Loop

In [None]:
# ====================================================
# Training Loop
# ====================================================
def train_one_epoch(model,optimizer,scheduler,scaler,train_loader,criterion,batches,epoch,DEVICE):
    """Train `model` for one epoch with mixed precision (autocast + GradScaler).

    Epoch 0 uses plain BCEWithLogitsLoss as a warm-up, because the ROC-Star
    criterion reads the previous epoch's stored batches, which do not exist
    yet. Later epochs call `criterion(logits, labels, i)` with the batch index.

    Side effects on CFG: each batch's labels and sigmoid predictions (NumPy)
    are appended to `CFG.last_epoch_true` / `CFG.last_epoch_pred` on epoch 0
    and overwrite index `i` afterwards. This relies on a constant number of
    batches per epoch.

    Args:
        model, optimizer, scheduler, scaler: training objects.
        train_loader: yields `(images, labels)` batches.
        criterion: loss used from epoch 1 on.
        batches: unused (kept for interface compatibility).
        epoch: 0-based epoch index.
        DEVICE: device the batches are moved to.
    """
    tr_loss = 0.0
    trn_epoch_result = dict()
    # Warm-up loss for epoch 0, built once instead of once per batch.
    warmup_criterion = nn.BCEWithLogitsLoss()
    model.train()
    if CFG.tqdm:
        progress = tqdm(enumerate(train_loader), desc="Loss: ", total=len(train_loader))
    else:
        progress = enumerate(train_loader)
    for i, (images,labels) in progress:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        with autocast():
            logits = model(images).reshape((-1))
            if epoch == 0:
                loss = warmup_criterion(logits, labels)
            else:
                loss = criterion(logits, labels, i)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if CFG.scheduler_update == 'batch':
            if CFG.scheduler_name == 'OneCycleLR':
                scheduler.step()
            else:
                # Fractional epoch position for per-batch schedules.
                scheduler.step(epoch + i/len(train_loader))

        tr_loss += loss.detach().item()

        # Stash this batch for next epoch's ROC-Star comparison and gamma update.
        batch_true = labels.detach().to('cpu').numpy()
        batch_pred = logits.sigmoid().detach().to('cpu').numpy()
        if epoch == 0:
            CFG.last_epoch_true.append(batch_true)
            CFG.last_epoch_pred.append(batch_pred)
        else:
            CFG.last_epoch_true[i] = batch_true
            CFG.last_epoch_pred[i] = batch_pred

        if CFG.tqdm:
            trn_epoch_result['Epoch'] = epoch
            trn_epoch_result['train_loss'] = round(tr_loss/(i+1), 4)
            trn_epoch_result['LR'] = round(optimizer.param_groups[0]["lr"],7)

            progress.set_description(str(trn_epoch_result))
        else:
            print(tr_loss/(i+1))
    if CFG.scheduler_update == 'epoch':
        # NOTE(review): a ReduceLROnPlateau scheduler would receive epoch+1
        # as its "metric" here.
        scheduler.step(epoch+1)
        
def val_one_epoch(model,DEVICE,loader,val_criterion,epoch,get_output = False):
    """Evaluate `model` on `loader` without gradients.

    The AUC is computed once over all validation predictions. Averaging
    per-batch AUCs is not the same quantity as the pooled AUC, and
    `roc_auc_score` raises whenever a batch contains a single class.

    Returns:
        `(mean batch loss, pooled AUC)` when `get_output` is True, else None.
        The AUC is shown under the existing 'val_acc' key.
    """
    val_loss = 0.0
    all_labels = []
    all_preds = []
    val_epoch_result = dict()
    model.eval()
    val_progress = tqdm(enumerate(loader), desc="Loss: ", total=len(loader))
    with torch.no_grad():
        for i, (images,labels) in val_progress:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            logits = model(images).reshape((-1))

            val_loss += val_criterion(logits,labels).detach().item()

            all_labels.append(labels.to('cpu').numpy())
            all_preds.append(logits.sigmoid().to('cpu').numpy())

            val_epoch_result['Epoch'] = epoch
            val_epoch_result['val_loss'] = round(val_loss/(i+1), 4)
            val_progress.set_description(str(val_epoch_result))

    score = get_score(np.concatenate(all_labels), np.concatenate(all_preds))
    val_epoch_result['val_acc'] = round(score, 4)
    print(val_epoch_result)
    if get_output:
        return val_loss/len(loader),score
        
def model_train():
    """Train each fold in CFG.TRAIN_FOLDS and keep the best checkpoint per fold.

    For each fold: builds train/valid loaders, creates a fresh model when
    CFG.model is None, and trains for up to CFG.EPOCHS epochs. Training stops
    early after CFG.PATIENCE epochs without validation-AUC improvement. The
    best weights are saved as `{MODEL_NAME}_f{fold}_b{auc}.pth`, with the
    rounded AUC stored in CFG.saved_models[fold]. CFG.gamma and the ROC-Star
    last-epoch buffers are reset per fold, so one fold never compares against
    another fold's predictions.
    """
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for fold,(idxT, idxV) in enumerate(folds):
        if fold not in CFG.TRAIN_FOLDS:
            continue
        # Fresh ROC-Star state: train_one_epoch appends to these on epoch 0.
        CFG.gamma = 0.2
        CFG.last_epoch_true, CFG.last_epoch_pred = [], []

        #______INSTANTIATE TRAINING DATASETS_____
        x_train = retrieve_df(train_df,'id',idxT)
        y_train = retrieve_df(train_df,'target',idxT)
        x_val = retrieve_df(train_df,'id',idxV)
        y_val = retrieve_df(train_df,'target',idxV)
        train_set = G2NetDataset(x_train,y_train, is_test=False,file_names = retrieve_df(train_df,'file_path',idxT))
        val_set = G2NetDataset(x_val,y_val, is_test=False,file_names = retrieve_df(train_df,'file_path',idxV))
        print(f"Start of Fold {fold}")
        # drop_last=True keeps the batch count constant across epochs, which
        # the per-batch ROC-Star buffers rely on.
        train_loader = DataLoader(train_set, batch_size=CFG.BATCH_SIZE, shuffle=True,drop_last=True, num_workers=CFG.NUM_WORKERS,pin_memory = True)
        val_loader = DataLoader(val_set, batch_size=CFG.BATCH_SIZE, shuffle=False,drop_last=True, num_workers=CFG.NUM_WORKERS,pin_memory = True)
        scaler = GradScaler()

        batches = len(train_loader)

        #INSTANTIATE FOLD MODEL
        if CFG.model is None:
            CFG.model = G2Net(model_name=CFG.MODEL_NAME, pretrained=True)
        model = CFG.model.to(DEVICE)

        criterion = CFG.criterion.to(DEVICE)
        val_criterion = CFG.val_criterion.to(DEVICE)

        optimizer = GetOptimizer(CFG.optimizer_name, model.parameters())
        scheduler = GetScheduler(CFG.scheduler_name, optimizer,batches)

        saved_model = None
        best_val_acc = 0.0
        fold_patience = 0
        # Local flag, so freezing in one fold does not disable it for later folds.
        freeze_pending = CFG.FREEZE
        for epoch in range(CFG.EPOCHS):
            if freeze_pending and epoch >= CFG.START_FREEZE:
                freeze_pending = False
                print('Model Frozen -> Train Classifier Only')
                if saved_model is not None:
                    # Continue head-only training from the best weights so far.
                    info = torch.load(saved_model,map_location = torch.device(DEVICE))
                    model.load_state_dict(info)
                model.freeze()
            #______TRAINING______
            train_one_epoch(model,optimizer,scheduler,scaler,train_loader,criterion,batches,epoch,DEVICE)

            #______VALIDATION_______
            val_loss, val_acc = val_one_epoch(model,DEVICE,val_loader,val_criterion,epoch,get_output = True)

            # Update gamma from this epoch's stored training predictions.
            CFG.gamma = epoch_update_gamma(torch.from_numpy(np.concatenate(CFG.last_epoch_true)), torch.from_numpy(np.concatenate(CFG.last_epoch_pred)), epoch=epoch, delta=2)

            if val_acc > best_val_acc:
                fold_patience = 0
                best_val_acc = val_acc
                new_saved_model = f'{CFG.MODEL_NAME}_f{fold}_b{round(best_val_acc, 4)}.pth'
                torch.save(model.state_dict(), new_saved_model)
                # Drop the previous checkpoint, unless it has the same rounded
                # name: torch.save just overwrote it, and deleting it would
                # lose the best model.
                if saved_model is not None and saved_model != new_saved_model:
                    try:
                        os.remove("./"+saved_model)
                    except OSError:
                        pass  # best-effort cleanup
                saved_model = new_saved_model
                CFG.saved_models[fold] = round(best_val_acc, 4)
                print(f'Model Saved at {round(best_val_acc, 5)} accuracy')
            else:
                fold_patience += 1
                if fold_patience >= CFG.PATIENCE:
                    print(f'Early stopping due to model not improving for {CFG.PATIENCE} epochs')
                    CFG.model = None
                    break
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        # The next fold builds a fresh model.
        CFG.model = None
                
def _map_fn(index,flags):
    """Per-process entry point that sets float32 as the default tensor type and runs training.

    `index` and `flags` are unused. NOTE(review): the signature presumably
    matches a TPU multiprocessing spawn callback.
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    model_train()

# Run

In [None]:
# Attach the runtime state that model_train / train_one_epoch / ROC_Star read
# from CFG, then launch training.
CFG.model = model
CFG.criterion = ROC_Star()
CFG.val_criterion = nn.BCEWithLogitsLoss()
# Initial ROC-Star margin and empty per-batch buffers (filled on epoch 0).
CFG.gamma = 0.2
CFG.last_epoch_true = []
CFG.last_epoch_pred = []
torch.set_default_tensor_type('torch.FloatTensor')
model_train()

# Inference

In [None]:
# ====================================================
# Inference
# ====================================================
# Average the sigmoid predictions of every trained fold's best checkpoint
# and write the submission file.
test_df['file_path'] = test_df['id'].apply(get_test_file_path)
x_test = test_df['id']
y_test = test_df['target']
# The test data does not depend on the fold, so build the loader once.
test_set = G2NetDataset(x_test,y_test, is_test=True,file_names = test_df['file_path'])
test_loader = DataLoader(test_set, batch_size=CFG.BATCH_SIZE, shuffle=False,drop_last=False, num_workers=CFG.NUM_WORKERS,pin_memory = True)
# Row k holds the predictions of fold CFG.TRAIN_FOLDS[k].
folds_preds = np.zeros((len(CFG.TRAIN_FOLDS),len(test_df)))
model = model.to(DEVICE)  # no-op when training already moved it
for row, fold in enumerate(CFG.TRAIN_FOLDS):
    info = torch.load(f'{CFG.MODEL_NAME}_f{fold}_b{CFG.saved_models[fold]}.pth',map_location = torch.device(DEVICE))
    model.load_state_dict(info)
    test_progress = tqdm(enumerate(test_loader), desc="Loss: ", total=len(test_loader))
    model.eval()
    preds_arr = []
    with torch.no_grad():
        for _, (images, _labels) in test_progress:
            images = images.to(DEVICE)

            logits = model(images).reshape((-1))
            preds_arr.append(logits.sigmoid().to('cpu').numpy())

    # Index by position, not by fold id: ids need not start at 0 or be
    # contiguous (e.g. TRAIN_FOLDS = [2, 4]).
    folds_preds[row,:] = np.concatenate(preds_arr)
folds_preds = np.mean(folds_preds,axis = 0)
test_df['target'] = folds_preds
test_df.drop('file_path',axis = 1,inplace=True)
test_df.to_csv('output.csv', index=False)