## References:

In [None]:
https://www.kaggle.com/haqishen/baseline-modified-from-previous-competition

https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-train-amp-aug
    
https://www.kaggle.com/piantic/train-cassava-starter-using-various-loss-funcs

AdamP Optimizer: Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights
    https://arxiv.org/abs/2006.08217
    https://www.kaggle.com/seriousran/adam-vs-adamp-iclr-2021
    https://github.com/clovaai/AdamP

SAM Optimizer: Sharpness-Aware Minimization for Efficiently Improving Generalization
    https://arxiv.org/pdf/2010.01412v2.pdf
    https://github.com/davda54/sam

In [None]:
# Install ildoonet/pytorch-gradual-warmup-lr, which provides the `warmup_scheduler` package imported in the next cell.
# NOTE(review): unpinned git install — pin a commit for reproducible runs.
!pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git

In [None]:
"""
    import external libraries
"""
from warmup_scheduler import GradualWarmupScheduler    
    
# Local source checkouts to make importable (here: pytorch-image-models, i.e. timm).
package_paths = [
    '../input/timm-models/pytorch-image-models-master',
]

import sys

sys.path.extend(package_paths)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from glob import glob
import os, time, gc, random, warnings, joblib
from sklearn.model_selection import StratifiedKFold
import sklearn
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics

import cv2
from skimage import io
from scipy.ndimage.interpolation import zoom
import timm
import albumentations
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.cuda.amp import autocast, GradScaler
from torch.nn.modules.loss import _WeightedLoss

import torchvision
from torchvision import transforms
from tqdm import tqdm
import argparse

In [None]:
# Single source of truth for paths, model choice, optimisation and augmentation switches.
Config = dict(
    seed=42,
    data_image_root='../input/cassava-leaf-disease-classification/train_images/',
    # data_image_root='../input/cassava-leaf-disease-merged/train/',  # to also use the 2019 data in the early training stage
    train_set='../input/train-folds/train_folds.csv',
    # train_set2='../input/train-folds/train_2019_folds.csv',  # to also use the 2019 data in the early training stage
    image_size=512,
    model_arch='tf_efficientnet_b4_ns',  # alternatives: 'gluon_seresnext101_32x4d', 'seresnext50_32x4d'
    epochs=10,
    device='cuda:0',
    kernel_type='training_stage',
    cutmix_alpha=1,
    output_label=True,
    one_hot_label=False,
    do_cutmix=False,
    do_mixup=False,
    num_workers=4,
    batch_size=12,
    pin_memory=True,
    init_lr=1e-5,
    warmup_factor=10,
    warmup_epoch=1,
    weight_decay=1e-6,
    rho=0.05,
    num_classes=5,
    beta=0.0,
    filter_bias_and_bn=True,
    resume=False,
)

## Utils

In [None]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every RNG.
    """
    random.seed(seed)
    # Only affects Python interpreters started after this point (e.g. subprocesses);
    # the running interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose kernels by timing them, which is
    # non-deterministic and defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False
    
def get_img(path):
    """Load the image at ``path`` and return it in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread`` returns
            ``None`` instead of raising, which previously surfaced as an obscure
            ``TypeError`` on the channel slice.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        raise FileNotFoundError("cv2.imread could not read image: {}".format(path))
    # cvtColor yields a contiguous array, unlike the negative-stride view
    # produced by im_bgr[:, :, ::-1]; pixel values are identical.
    return cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB)

def rand_bbox(size, lam):
    """Sample a random CutMix box whose nominal area fraction is ``1 - lam``.

    A box with sides ``size[i] * sqrt(1 - lam)`` is centred on a uniformly random
    point and then clipped to the grid, so the realised area can be smaller.

    Args:
        size: ``(W, H)`` extents; ``bbx*`` index the first axis, ``bby*`` the second.
        lam: mixing ratio in [0, 1].

    Returns:
        (bbx1, bby1, bbx2, bby2) with ``0 <= bbx1 <= bbx2 <= size[0]`` and
        ``0 <= bby1 <= bby2 <= size[1]``.
    """
    W = size[0]
    H = size[1]
    cut_rat = np.sqrt(1. - lam)
    # Built-in int(): the np.int alias was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniformly sampled box centre
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

## Dataset

In [None]:
class CassavaDataset(Dataset):
    """Cassava leaf image dataset with optional CutMix / MixUp augmentation.

    Args:
        df: DataFrame with an ``image_id`` column (and ``label`` when
            ``output_label`` is True).
        data_image_root: directory containing the image files.
        transforms: albumentations-style callable, invoked as
            ``transforms(image=img)['image']``.
        output_label: if True, ``__getitem__`` returns ``(img, target)``;
            otherwise only ``img``.
        one_hot_label: convert labels to one-hot float vectors. This is needed for
            CutMix / MixUp to produce meaningful soft targets.
        do_cutmix: apply CutMix with probability 0.5 per sample.
        cutmix_params: dict with the Beta-distribution ``alpha`` used by CutMix
            (defaults to ``{'alpha': 1.}``).
        do_mixup: apply MixUp with probability 0.5 per sample.
        num_classes: width of the one-hot vectors. Defaults to ``max(label) + 1``
            of ``df``, which is too narrow if this split lacks the highest class.
    """
    def __init__(self, df, data_image_root, transforms=None, output_label=True, one_hot_label=False,
                 do_cutmix=False, cutmix_params=None,
                 do_mixup=False, num_classes=None):

        super().__init__()
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_image_root = data_image_root
        self.do_cutmix = do_cutmix
        # None sentinel instead of a shared mutable dict default.
        self.cutmix_params = {'alpha': 1.} if cutmix_params is None else cutmix_params
        self.do_mixup = do_mixup

        self.output_label = output_label
        self.one_hot_label = one_hot_label

        if output_label:  # set to True for the train set
            self.labels = self.df['label'].values

            if one_hot_label:
                n_classes = int(self.df['label'].max()) + 1 if num_classes is None else num_classes
                self.labels = np.eye(n_classes)[self.labels]
            elif do_cutmix or do_mixup:
                warnings.warn(
                    "CutMix/MixUp with integer labels blends class indices; "
                    "set one_hot_label=True to get soft targets.")

    def __len__(self):
        return self.df.shape[0]

    def _load_image(self, idx):
        """Read the image of row ``idx`` and apply ``self.transforms`` if set."""
        img = get_img("{}/{}".format(self.data_image_root, self.df.at[idx, 'image_id']))
        if self.transforms:
            img = self.transforms(image=img)['image']
        return img

    def __getitem__(self, index: int):
        """Return ``(img, target)`` when ``output_label`` is True, else ``img``."""
        target = self.labels[index] if self.output_label else None
        img = self._load_image(index)

        if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > 0.5:
            with torch.no_grad():
                cutmix_ix = np.random.choice(self.df.index, size=1)[0]
                cutmix_img = self._load_image(cutmix_ix)

                alpha = self.cutmix_params['alpha']
                lam = np.clip(np.random.beta(alpha, alpha), 0.3, 0.7)
                # Use the actual spatial size instead of the global Config.
                # Assumes img is (C, H, W), i.e. a ToTensorV2-style transform ran;
                # TODO confirm for custom transforms.
                height, width = img.shape[-2], img.shape[-1]
                bbx1, bby1, bbx2, bby2 = rand_bbox((height, width), lam)
                img[:, bbx1:bbx2, bby1:bby2] = cutmix_img[:, bbx1:bbx2, bby1:bby2]
                # Fraction of the original image that survives the paste.
                rate = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (height * width))
                if self.output_label:
                    target = rate * target + (1. - rate) * self.labels[cutmix_ix]

        if self.do_mixup and np.random.uniform(0., 1., size=1)[0] > 0.5:
            with torch.no_grad():
                mixup_idx = np.random.choice(self.df.index, size=1)[0]
                mixup_img = self._load_image(mixup_idx)

                lam = np.random.beta(1.0, 1.0)
                lam = max(lam, 1 - lam)  # keep the original image dominant

                img = lam * img + (1 - lam) * mixup_img
                if self.output_label:
                    target = lam * target + (1. - lam) * self.labels[mixup_idx]

        if self.output_label:
            return img, target
        return img

## Data Augmentations

In [None]:
"""
Note: use these augmentations in combination with CutMix (set Config['do_cutmix'] = True) during the early stage of model training.
"""
#def get_train_transforms():
#    return albumentations.Compose([
#        albumentations.RandomResizedCrop(Config['image_size'], Config['image_size']),
#        albumentations.Transpose(p=0.5),
#        albumentations.HorizontalFlip(p=0.5),
#        albumentations.VerticalFlip(p=0.5),
#        albumentations.ShiftScaleRotate(p=0.5),
#        albumentations.Cutout(max_h_size=int(Config['image_size'] * 0.4), max_w_size=int(Config['image_size'] * 0.4), num_holes=1, p=0.5),
#        albumentations.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225], max_pixel_value=255,p=1.0),
#        ToTensorV2(p=1.0),
#    ], p=1.0)

"""
Note: for fine-tuning, use these augmentations and disable CutMix (Config['do_cutmix'] = False).
"""
def get_train_transforms():
    """Augmentation pipeline for the fine-tuning stage (CutMix disabled).

    Steps:
    - crop to a ``Config['image_size']`` square;
    - apply geometric, distortion and colour augmentations;
    - normalize with ImageNet statistics;
    - convert to a tensor via ``ToTensorV2``.
    """
    size = Config['image_size']
    return albumentations.Compose([
        albumentations.RandomCrop(size, size),
        # Same size as the crop, so effectively a no-op; guards against a changed crop size.
        albumentations.Resize(size, size),
        albumentations.HorizontalFlip(p=0.5),
        # Pass p by keyword. In the albumentations 0.x API this notebook targets
        # (it uses IAA* transforms), the first positional parameter is
        # ``always_apply``, so ``VerticalFlip(0.5)`` flipped every image.
        albumentations.VerticalFlip(p=0.5),
        albumentations.Transpose(p=0.5),
        albumentations.Rotate(limit=(-90, 90), p=0.5),
        albumentations.OneOf([
            albumentations.ShiftScaleRotate(),
            albumentations.ElasticTransform(alpha=3)
        ], p=0.5),
        albumentations.OneOf([
            albumentations.OpticalDistortion(distort_limit=1.0),
            albumentations.GridDistortion(num_steps=5, distort_limit=1.0)
        ], p=0.5),
        albumentations.OneOf([
            # NOTE(review): HSV shift limits are absolute units (library defaults
            # 20/30/20), so 0.2 is close to a no-op. Confirm the intended values.
            albumentations.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2),
            albumentations.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1)),
            albumentations.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
            albumentations.FancyPCA(),
            albumentations.CLAHE(clip_limit=4.0)
        ], p=0.5),
        # NOTE(review): IAA* transforms depend on imgaug and are gone from albumentations >= 1.0.
        albumentations.OneOf([
            albumentations.IAAAffine(),
            albumentations.IAAPerspective(),
            albumentations.IAAPiecewiseAffine(),
            albumentations.IAASuperpixels()
        ], p=0.5),
        albumentations.Cutout(max_h_size=int(size * 0.375), max_w_size=int(size * 0.375), num_holes=1, p=0.5),
        albumentations.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255, p=1.0),
        ToTensorV2(p=1.0)
    ])

def get_valid_transforms():
    """Deterministic validation pipeline.

    Center-crops to ``Config['image_size']``, normalizes with ImageNet
    statistics and converts to a tensor.
    """
    side = Config['image_size']
    steps = [
        albumentations.CenterCrop(side, side, p=1.0),
        albumentations.Resize(side, side),
        albumentations.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return albumentations.Compose(steps, p=1.0)

In [None]:
def prepare_dataloader(df, df2, train_idx, valid_idx, data_image_root):
    """Build the train and validation DataLoaders for one fold.

    Args:
        df: training DataFrame with ``image_id`` and ``label`` columns.
        df2: extra (2019) data; unused unless the commented-out concat below is enabled.
        train_idx, valid_idx: index labels of ``df`` (passed to ``.loc``) for each split.
        data_image_root: directory holding the images.

    Returns:
        (train_dataloader, valid_dataloader)
    """
    train_ = df.loc[train_idx, :].reset_index(drop=True)

    # train_ = pd.concat([
    #     df.loc[train_idx, :].reset_index(drop=True),
    #     df2
    # ], ignore_index=True)  # enable this to also train on the 2019 data

    valid_ = df.loc[valid_idx, :].reset_index(drop=True)

    train_ds = CassavaDataset(train_,
                              data_image_root,
                              transforms=get_train_transforms(),
                              output_label=Config['output_label'],
                              one_hot_label=Config['one_hot_label'],
                              do_cutmix=Config['do_cutmix'],
                              # Wire the configured Beta alpha through; it was previously ignored.
                              cutmix_params={'alpha': Config['cutmix_alpha']},
                              do_mixup=Config['do_mixup'])

    valid_ds = CassavaDataset(valid_,
                              data_image_root,
                              transforms=get_valid_transforms(),
                              output_label=Config['output_label'])

    train_dataloader = DataLoader(train_ds,
                                  batch_size=Config['batch_size'],
                                  pin_memory=Config['pin_memory'],
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=Config['num_workers'])

    valid_dataloader = DataLoader(valid_ds,
                                  batch_size=Config['batch_size'],
                                  pin_memory=Config['pin_memory'],
                                  shuffle=False, drop_last=False,
                                  num_workers=Config['num_workers'])

    return train_dataloader, valid_dataloader

## Model

In [None]:
class CassavaImageClassifier(nn.Module):
    """timm backbone whose classification head is rebuilt for ``n_class`` outputs.

    Args:
        model_arch: timm model name, e.g. ``'tf_efficientnet_b4_ns'``.
        n_class: number of output classes.
        pretrained: load timm's pretrained weights.
    """
    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()

        self.model = timm.create_model(model_arch, pretrained=pretrained)
        # Architecture-agnostic head replacement. timm's reset_classifier rebuilds
        # EfficientNet's ``classifier`` or (SE-)ResNeXt's ``fc`` for ``n_class``
        # outputs, so no per-architecture branch is needed.
        self.model.reset_classifier(n_class)

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        return self.model(x)

## Train & Validation Step Function

In [None]:
def train_one_epoch(model, epoch, criterion, optimizer, train_loader, device):
    """Train ``model`` for one epoch with a two-step (SAM-style) optimizer.

    ``optimizer`` must expose ``first_step(zero_grad=...)`` and
    ``second_step(zero_grad=...)``. Each batch runs two forward/backward passes:
    one before ``first_step`` and one before ``second_step``.

    Args:
        model: network to train; switched to train mode.
        epoch: current epoch index (unused; kept for API compatibility).
        criterion: loss callable ``criterion(logits, labels)``.
        optimizer: optimizer with the two-step interface above.
        train_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        (train_loss, acc):
        - train_loss: per-step losses from the first forward pass, as floats;
        - acc: epoch accuracy of the first-pass predictions.
    """
    model.train()

    train_loss = []
    image_preds_all = []
    image_targets_all = []

    pbar = tqdm(train_loader, total=len(train_loader))
    for imgs, image_labels in pbar:
        imgs = imgs.to(device).float()
        # NOTE: soft (mixed) targets would be truncated by .long().
        image_labels = image_labels.to(device).long()

        # First forward-backward step.
        image_preds = model(imgs)
        loss = criterion(image_preds, image_labels)
        loss.mean().backward()
        optimizer.first_step(zero_grad=True)

        # Second forward-backward step.
        # NOTE(review): BatchNorm running stats are updated in both passes; the
        # SAM reference implementation suggests disabling them for this pass.
        criterion(model(imgs), image_labels).mean().backward()
        optimizer.second_step(zero_grad=True)

        # argmax of the logits equals argmax of their softmax; skip the softmax.
        image_preds_all.append(image_preds.detach().argmax(1))
        image_targets_all.append(image_labels)

        loss_value = loss.mean().item()
        train_loss.append(loss_value)
        recent = train_loss[-100:]
        smooth_loss = sum(recent) / len(recent)
        pbar.set_description('loss: %.5f, smooth: %.5f' % (loss_value, smooth_loss))

    image_preds_all = torch.cat(image_preds_all).cpu().numpy()
    image_targets_all = torch.cat(image_targets_all).cpu().numpy()
    acc = (image_preds_all == image_targets_all).mean()

    return train_loss, acc

def valid_one_epoch(model, criterion, valid_loader, device):
    """Evaluate ``model`` on ``valid_loader`` without gradient tracking.

    Returns:
        (valid_loss, acc):
        - valid_loss: per-sample mean loss. Each batch's loss is weighted by its
          size, so an uneven final batch is not over- or under-weighted.
        - acc: classification accuracy.
    """
    model.eval()

    loss_sum = 0.0
    n_samples = 0
    image_preds_all = []
    image_targets_all = []

    with torch.no_grad():
        for imgs, image_labels in tqdm(valid_loader):
            imgs = imgs.to(device).float()
            image_labels = image_labels.to(device).long()

            image_logits = model(imgs)
            loss = criterion(image_logits, image_labels)

            batch_size = image_labels.size(0)
            loss_sum += loss.mean().item() * batch_size
            n_samples += batch_size

            # argmax of the logits equals argmax of their softmax.
            image_preds_all.append(image_logits.argmax(1))
            image_targets_all.append(image_labels)

    valid_loss = loss_sum / n_samples

    image_preds_all = torch.cat(image_preds_all).cpu().numpy()
    image_targets_all = torch.cat(image_targets_all).cpu().numpy()
    acc = (image_preds_all == image_targets_all).mean()

    return valid_loss, acc

## Learning Rate Scheduler

In [None]:
class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    """Warm-up LR scheduler variant.

    - While ``last_epoch <= total_epoch``, the LR ramps linearly: from 0 to
      ``base_lr`` when ``multiplier == 1.0``, otherwise from ``base_lr`` to
      ``base_lr * multiplier``.
    - Afterwards, ``after_scheduler`` (if any) takes over. On the first hand-over
      its ``base_lrs`` are set to ``base_lr * multiplier``.
    - Without ``after_scheduler``, the LR stays at ``base_lr * multiplier``.
    """
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super().__init__(optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        epoch = self.last_epoch
        warmup = self.total_epoch

        if epoch <= warmup:
            if self.multiplier == 1.0:
                return [lr * (float(epoch) / warmup) for lr in self.base_lrs]
            return [lr * ((self.multiplier - 1.) * epoch / warmup + 1.) for lr in self.base_lrs]

        if not self.after_scheduler:
            return [lr * self.multiplier for lr in self.base_lrs]

        if not self.finished:
            self.after_scheduler.base_lrs = [lr * self.multiplier for lr in self.base_lrs]
            self.finished = True
        return self.after_scheduler.get_lr()

## Loss Functions / Criterions

In [None]:
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class gets probability ``1 - smoothing``. The remaining
    ``smoothing`` mass is spread uniformly over the other ``classes - 1`` classes.

    Args:
        classes: number of classes.
        smoothing: probability mass moved off the true class.
        dim: class dimension of ``pred``.
    """
    def __init__(self, classes=5, smoothing=0.06, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy.

        Args:
            pred: scores with classes along ``self.dim``; ``log_softmax`` is applied here.
            target: integer class indices, shaped like ``pred`` without ``self.dim``.
        """
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            true_dist = torch.full_like(pred, self.smoothing / (self.cls - 1))
            # Scatter along the same class dim used above (was hard-coded to 1).
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
    
class FocalCosineLoss(nn.Module):
    """Cosine embedding loss plus a weighted focal cross-entropy term.

    The loss is ``cosine_embedding_loss(logits, one_hot(target))`` plus
    ``xent * focal_loss(normalize(logits), target)``.

    Args:
        alpha: focal-loss scale.
        gamma: focal-loss focusing exponent.
        xent: weight of the focal term.
    """
    def __init__(self, alpha=1, gamma=2, xent=.1):  # tuned alternative used elsewhere: gamma=1.934, alpha=0.3868
        super(FocalCosineLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

        self.xent = xent

    def forward(self, input, target, reduction="mean"):
        """Compute the loss.

        Args:
            input: (batch, classes) logits.
            target: (batch,) integer class indices.
            reduction: ``'mean'``, ``'sum'`` or ``'none'``, applied to both terms.
        """
        one_hot = F.one_hot(target, num_classes=input.size(-1)).to(input.dtype)
        # "Similar" (+1) targets built on input's device; the old hard-coded
        # .cuda() tensor crashed on CPU and on non-default GPUs.
        y = input.new_ones(input.size(0))
        cosine_loss = F.cosine_embedding_loss(input, one_hot, y, reduction=reduction)

        cent_loss = F.cross_entropy(F.normalize(input), target, reduction="none")
        pt = torch.exp(-cent_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * cent_loss

        # Reduce the focal term the same way as the cosine term so shapes agree.
        if reduction == "mean":
            focal_loss = torch.mean(focal_loss)
        elif reduction == "sum":
            focal_loss = torch.sum(focal_loss)

        return cosine_loss + self.xent * focal_loss
    
def log_t(u, t):
    """Tempered logarithm of the tensor ``u``.

    Equals the natural log when ``t == 1.0``, otherwise
    ``(u ** (1 - t) - 1) / (1 - t)``.
    """
    if t != 1.0:
        return (u.pow(1.0 - t) - 1.0) / (1.0 - t)
    return u.log()

def exp_t(u, t):
    """Tempered exponential of the tensor ``u``.

    Equals ``exp(u)`` when ``t == 1``, otherwise
    ``relu(1 + (1 - t) * u) ** (1 / (1 - t))``.
    """
    if t != 1:
        base = (1.0 + (1.0 - t) * u).relu()
        return base.pow(1.0 / (1.0 - t))
    return u.exp()

def compute_normalization_fixed_point(activations, t, num_iters):

    """Returns the normalization value for each example (t > 1.0).
    Uses a fixed-point iteration on the max-shifted activations.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same shape as activation with the last dimension being 1.
    """
    # Shift by the per-example max for numerical stability; mu is added back at the end.
    mu, _ = torch.max(activations, -1, keepdim=True)
    normalized_activations_step_0 = activations - mu

    normalized_activations = normalized_activations_step_0

    # Fixed-point update: rescale the shifted activations by Z^(1-t), where Z is
    # the current sum of tempered exponentials over the last axis.
    for _ in range(num_iters):
        logt_partition = torch.sum(
                exp_t(normalized_activations, t), -1, keepdim=True)
        normalized_activations = normalized_activations_step_0 * \
                logt_partition.pow(1.0-t)

    # Final partition estimate; the normalizer is -log_t(1/Z) shifted back by mu.
    logt_partition = torch.sum(
            exp_t(normalized_activations, t), -1, keepdim=True)
    normalization_constants = - log_t(1.0 / logt_partition, t) + mu

    return normalization_constants

def compute_normalization_binary_search(activations, t, num_iters):

    """Returns the normalization value for each example (t < 1.0).
    Uses bisection on the log_t partition value of the max-shifted activations.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (< 1.0 for finite support).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """

    # Shift by the per-example max for numerical stability; mu is added back at the end.
    mu, _ = torch.max(activations, -1, keepdim=True)
    normalized_activations = activations - mu

    # Count of classes whose shifted activation is above -1/(1-t), i.e. where
    # exp_t is non-zero for t < 1 (finite support).
    effective_dim = \
        torch.sum(
                (normalized_activations > -1.0 / (1.0-t)).to(torch.int32),
            dim=-1, keepdim=True).to(activations.dtype)

    # Initial bisection bracket [0, -log_t(1/effective_dim)] per example.
    shape_partition = activations.shape[:-1] + (1,)
    lower = torch.zeros(shape_partition, dtype=activations.dtype, device=activations.device)
    upper = -log_t(1.0/effective_dim, t) * torch.ones_like(lower)

    for _ in range(num_iters):
        logt_partition = (upper + lower)/2.0
        sum_probs = torch.sum(
                exp_t(normalized_activations - logt_partition, t),
                dim=-1, keepdim=True)
        # update == 1 where the probabilities sum to < 1 (midpoint too large):
        # move `upper` down there; elsewhere move `lower` up.
        update = (sum_probs < 1.0).to(activations.dtype)
        lower = torch.reshape(
                lower * update + (1.0-update) * logt_partition,
                shape_partition)
        upper = torch.reshape(
                upper * (1.0 - update) + update * logt_partition,
                shape_partition)

    logt_partition = (upper + lower)/2.0
    return logt_partition + mu

class ComputeNormalization(torch.autograd.Function):
    """
    Class implementing custom backward pass for compute_normalization. See compute_normalization.
    The forward pass dispatches to bisection for t < 1.0 and to the fixed-point
    iteration otherwise; the backward pass uses the closed-form gradient.
    """
    @staticmethod
    def forward(ctx, activations, t, num_iters):
        """Compute the per-example normalization constants (last dim of size 1)."""
        if t < 1.0:
            normalization_constants = compute_normalization_binary_search(activations, t, num_iters)
        else:
            normalization_constants = compute_normalization_fixed_point(activations, t, num_iters)

        ctx.save_for_backward(activations, normalization_constants)
        ctx.t=t
        return normalization_constants

    @staticmethod
    def backward(ctx, grad_output):
        """Gradient w.r.t. activations: the escort distribution p^t / sum(p^t)
        scaled by grad_output; t and num_iters receive no gradient."""
        activations, normalization_constants = ctx.saved_tensors
        t = ctx.t
        normalized_activations = activations - normalization_constants 
        probabilities = exp_t(normalized_activations, t)
        escorts = probabilities.pow(t)
        escorts = escorts / escorts.sum(dim=-1, keepdim=True)
        grad_input = escorts * grad_output
        
        return grad_input, None, None

def compute_normalization(activations, t, num_iters=5):
    """Per-example normalization constants for the tempered softmax.

    Differentiable through ComputeNormalization's custom backward pass.

    Args:
      activations: tensor whose last dimension is `num_classes`.
      t: temperature (> 1.0 heavy tails, < 1.0 finite support).
      num_iters: number of solver iterations.
    Returns:
      Tensor with the same rank as `activations` and last dimension 1.
    """
    normalize = ComputeNormalization.apply
    return normalize(activations, t, num_iters)

def tempered_sigmoid(activations, t, num_iters = 5):
    """Tempered sigmoid.

    Computed as the positive-class entry of a two-class tempered softmax over
    ``[activations, 0]``.

    Args:
      activations: activations for the positive class.
      t: temperature > 0.0.
      num_iters: number of solver iterations.
    Returns:
      Probabilities tensor shaped like `activations`.
    """
    two_class = torch.stack([activations, torch.zeros_like(activations)], dim=-1)
    return tempered_softmax(two_class, t, num_iters)[..., 0]


def tempered_softmax(activations, t, num_iters=5):
    """Tempered softmax over the last dimension.

    Args:
      activations: tensor whose last dimension is `num_classes`.
      t: temperature. 1.0 gives the ordinary softmax; other values use the
        tempered normalization (> 1.0 heavy tails, < 1.0 finite support).
      num_iters: number of solver iterations.
    Returns:
      Probabilities tensor shaped like `activations`.
    """
    if t != 1.0:
        log_partition = compute_normalization(activations, t, num_iters)
        return exp_t(activations - log_partition, t)
    return activations.softmax(dim=-1)

def bi_tempered_binary_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing = 0.0,
        num_iters=5,
        reduction='mean'):
    """Bi-tempered logistic loss for binary classification.

    Expands the problem to two classes: activations ``[a, 0]`` and labels
    ``[y, 1 - y]``. It then delegates to ``bi_tempered_logistic_loss``.

    Args:
      activations: activations for class 1.
      labels: probabilities for class 1, shaped like `activations`.
      t1: temperature 1 (< 1.0 for boundedness).
      t2: temperature 2 (> 1.0 tail heaviness, < 1.0 finite support).
      label_smoothing: label smoothing amount.
      num_iters: number of solver iterations.
      reduction: 'none' | 'mean' | 'sum'.
    Returns:
      A loss tensor.
    """
    labels_as_float = labels.to(activations.dtype)
    two_class_logits = torch.stack([activations, torch.zeros_like(activations)], dim=-1)
    two_class_labels = torch.stack([labels_as_float, 1.0 - labels_as_float], dim=-1)
    return bi_tempered_logistic_loss(two_class_logits,
                                     two_class_labels,
                                     t1,
                                     t2,
                                     label_smoothing=label_smoothing,
                                     num_iters=num_iters,
                                     reduction=reduction)

def bi_tempered_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing=0.0,
        num_iters=5,
        reduction = 'mean'):

    """Bi-Tempered Logistic Loss.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: A tensor with shape and dtype as activations (onehot),
        or a long tensor of one dimension less than activations (pytorch standard)
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1). Default 0.0.
      num_iters: Number of iterations to run the method. Default 5.
      reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``.
        ``'none'``: No reduction is applied, return shape is shape of
        activations without the last dimension.
        ``'mean'``: Loss is averaged over all remaining elements; returns a scalar.
        ``'sum'``: Loss is summed over all remaining elements; returns a scalar.
    Returns:
      A loss tensor.
    Raises:
      ValueError: if `reduction` is not one of the values above (previously the
        function silently returned None).
    """
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError("reduction must be 'none', 'mean' or 'sum', got {!r}".format(reduction))

    if len(labels.shape) < len(activations.shape):  # class indices, not one-hot
        labels_onehot = torch.zeros_like(activations)
        # Classes live on the last axis (see Args), so scatter along -1; the old
        # hard-coded dim 1 was only correct for 2-D inputs.
        labels_onehot.scatter_(-1, labels.unsqueeze(-1), 1)
    else:
        labels_onehot = labels

    if label_smoothing > 0:
        num_classes = labels_onehot.shape[-1]
        labels_onehot = ( 1 - label_smoothing * num_classes / (num_classes - 1) ) \
                * labels_onehot + \
                label_smoothing / (num_classes - 1)

    probabilities = tempered_softmax(activations, t2, num_iters)

    # The 1e-10 keeps log_t finite on zero label entries.
    loss_values = labels_onehot * log_t(labels_onehot + 1e-10, t1) \
            - labels_onehot * log_t(probabilities, t1) \
            - labels_onehot.pow(2.0 - t1) / (2.0 - t1) \
            + probabilities.pow(2.0 - t1) / (2.0 - t1)
    loss_values = loss_values.sum(dim = -1)  # sum over classes

    if reduction == 'none':
        return loss_values
    if reduction == 'sum':
        return loss_values.sum()
    return loss_values.mean()

class BiTemperedLogisticLoss(nn.Module):
    """Module wrapper around ``bi_tempered_logistic_loss`` returning the mean loss.

    Args:
        t1: temperature 1 (< 1.0 for boundedness).
        t2: temperature 2 (> 1.0 tail heaviness, < 1.0 finite support).
        smoothing: label smoothing amount.
    """
    def __init__(self, t1, t2, smoothing=0.0):
        super().__init__()
        self.t1 = t1
        self.t2 = t2
        self.smoothing = smoothing

    def forward(self, logit_label, truth_label):
        per_sample = bi_tempered_logistic_loss(logit_label,
                                               truth_label,
                                               t1=self.t1,
                                               t2=self.t2,
                                               label_smoothing=self.smoothing,
                                               reduction='none')
        return per_sample.mean()
    
class TaylorSoftmax(nn.Module):
    '''
    Softmax variant that replaces exp(x) with its order-``n`` Taylor polynomial.

    The polynomial ``1 + x + x^2/2! + ... + x^n/n!`` is normalized along ``dim``.
    ``n`` must be even (asserted). This is the autograd version.
    '''
    def __init__(self, dim=1, n=2):
        super().__init__()
        assert n % 2 == 0
        self.dim = dim
        self.n = n

    def forward(self, x):
        '''
        usage similar to nn.Softmax:
            >>> mod = TaylorSoftmax(dim=1, n=4)
            >>> inten = torch.randn(1, 32, 64, 64)
            >>> out = mod(inten)
        '''
        poly = torch.ones_like(x)
        factorial = 1.
        for power in range(1, self.n + 1):
            factorial *= power
            poly = poly + x.pow(power) / factorial
        return poly / poly.sum(dim=self.dim, keepdim=True)


##
# version 1: use torch.autograd
class BaseTaylorCrossEntropyLoss(nn.Module):
    '''
    Negative log-likelihood on Taylor-softmax probabilities (autograd version).

    Args:
        n: even order of the Taylor expansion.
        ignore_index: target value excluded from the loss.
        reduction: passed to ``F.nll_loss``.
    '''
    def __init__(self, n=2, ignore_index=-1, reduction='mean'):
        super().__init__()
        assert n % 2 == 0
        self.taylor_softmax = TaylorSoftmax(dim=1, n=n)
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        '''
        usage similar to nn.CrossEntropyLoss:
            >>> crit = BaseTaylorCrossEntropyLoss(n=4)
            >>> inten = torch.randn(1, 10, 64, 64)
            >>> label = torch.randint(0, 10, (1, 64, 64))
            >>> out = crit(inten, label)
        '''
        probs = self.taylor_softmax(logits)
        return F.nll_loss(probs.log(), labels,
                          reduction=self.reduction,
                          ignore_index=self.ignore_index)

class TaylorCrossEntropyLoss(nn.Module):
    """Label-smoothed cross-entropy computed on Taylor-softmax probabilities.

    Args:
        n: even order of the Taylor expansion used by ``TaylorSoftmax``.
        ignore_index: stored for API compatibility; NOT applied by ``forward``.
        reduction: stored for API compatibility; ``forward`` always returns the
            mean produced by ``LabelSmoothingLoss``.
        smoothing: label-smoothing mass passed to ``LabelSmoothingLoss``.
        num_classes: number of classes; defaults to ``Config['num_classes']``
            so existing callers are unaffected.
    """
    def __init__(self, n=2, ignore_index=-1, reduction='mean', smoothing=0.2, num_classes=None):
        super(TaylorCrossEntropyLoss, self).__init__()
        assert n % 2 == 0
        self.taylor_softmax = TaylorSoftmax(dim=1, n=n)
        self.reduction = reduction
        self.ignore_index = ignore_index
        if num_classes is None:
            num_classes = Config['num_classes']
        self.lab_smooth = LabelSmoothingLoss(num_classes, smoothing=smoothing)

    def forward(self, logits, labels):
        # LabelSmoothingLoss re-applies log_softmax. On log-probabilities whose
        # probabilities sum to one, that is the identity up to rounding.
        log_probs = self.taylor_softmax(logits).log()
        return self.lab_smooth(log_probs, labels)
    
    
## reference: https://github.com/HanxunH/Active-Passive-Losses    
class NormalizedCrossEntropy(nn.Module):
    """Normalized cross-entropy (NCE) from the active-passive loss framework.

    For each sample, the cross-entropy of the true class is divided by the sum
    of ``-log p`` over all classes. The batch mean is then multiplied by ``scale``.
    """
    def __init__(self, num_classes, scale=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.scale = scale
        # Not used by forward(); kept so the module tree is unchanged.
        self.taylor_softmax = TaylorSoftmax(dim=1, n=2)

    def forward(self, pred, labels):
        log_probs = F.log_softmax(pred, dim=1)
        targets = torch.nn.functional.one_hot(labels, self.num_classes).float()
        numerator = -(targets * log_probs).sum(dim=1)
        denominator = -log_probs.sum(dim=1)
        return self.scale * (numerator / denominator).mean()
    
class ReverseCrossEntropy(nn.Module):
    """Reverse cross-entropy (RCE): ``-sum(p * log(clamped one-hot))``, mean over batch, times ``scale``.

    Predicted probabilities are clamped to ``[1e-7, 1]`` and the one-hot targets
    to ``[1e-4, 1]``, so ``log(0)`` never occurs.
    """
    def __init__(self, num_classes, scale=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.scale = scale
        # Not used by forward(); kept so the module tree is unchanged.
        self.taylor_softmax = TaylorSoftmax(dim=1, n=2)

    def forward(self, pred, labels):
        probs = torch.clamp(F.softmax(pred, dim=1), min=1e-7, max=1.0)
        targets = torch.nn.functional.one_hot(labels, self.num_classes).float()
        targets = torch.clamp(targets, min=1e-4, max=1.0)
        rce = -(probs * torch.log(targets)).sum(dim=1)
        return self.scale * rce.mean()
    

class NCEandRCE(nn.Module):
    """Active-passive loss: ``alpha``-scaled NCE plus ``beta``-scaled RCE."""
    def __init__(self, alpha, beta, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.nce = NormalizedCrossEntropy(scale=alpha, num_classes=num_classes)
        self.rce = ReverseCrossEntropy(scale=beta, num_classes=num_classes)

    def forward(self, pred, labels):
        nce_term = self.nce(pred, labels)
        rce_term = self.rce(pred, labels)
        return nce_term + rce_term
    
class ComboLoss(nn.Module):
    """Sum of three losses with fixed hyper-parameters.

    The terms are: bi-tempered logistic, Taylor cross-entropy with label
    smoothing, and NCE+RCE.
    """
    def __init__(self):
        super().__init__()

        self.loss_1 = BiTemperedLogisticLoss(t1=0.32, t2=1.0, smoothing=0.0)
        self.loss_2 = TaylorCrossEntropyLoss(smoothing=0.06)
        self.loss_3 = NCEandRCE(alpha=0.17, beta=0.34, num_classes=5)
        # Optional 4th term, currently disabled:
        # self.loss_4 = FocalCosineLoss(alpha=0.3868, gamma=1.934, xent=0.1)

    def forward(self, preds, labels):
        total = self.loss_1(preds, labels)
        total = total + self.loss_2(preds, labels)
        total = total + self.loss_3(preds, labels)
        return total

## Optimizers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer, required
import math

class AdamP(Optimizer):
    """Adam with projection for scale-invariant weights (AdamP, arXiv:2006.08217).

    For every parameter with more than one dimension, the gradient/weight
    cosine similarity is checked per slice along dim 0 and then for the tensor
    as a whole. On the first view where it falls below
    ``delta / sqrt(row_length)``, the component of the update along the weight
    is removed, and weight decay for that parameter is multiplied by
    ``wd_ratio``.

    Args:
        params: Iterable of parameters or param-group dicts.
        lr: Learning rate.
        betas: Decay rates of the first and second moment estimates.
        eps: Denominator epsilon; also used for the cosine/norm epsilons.
        weight_decay: Decoupled decay, applied as ``p *= 1 - lr * weight_decay * ratio``.
        delta: Threshold factor for triggering the projection.
        wd_ratio: Weight-decay multiplier used when the projection fires.
        nesterov: Use the Nesterov-style update direction.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
        hyper = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                     delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
        super(AdamP, self).__init__(params, hyper)

    def _channel_view(self, x):
        """Reshape ``x`` to one row per slice along dim 0."""
        return x.view(x.size(0), -1)

    def _layer_view(self, x):
        """Reshape ``x`` to a single row."""
        return x.view(1, -1)

    def _cosine_similarity(self, x, y, eps, view_func):
        """Row-wise absolute cosine similarity of ``x`` and ``y`` after ``view_func``."""
        similarity = F.cosine_similarity(view_func(x), view_func(y), dim=1, eps=eps)
        return similarity.abs_()

    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
        """Remove the radial part of ``perturb`` when ``grad`` is ~orthogonal to ``p``.

        Returns ``(perturb, wd)`` where ``wd`` is ``wd_ratio`` if a projection
        was applied (``perturb`` is then modified in place) and ``1`` otherwise.
        """
        broadcast_shape = [-1] + [1] * (len(p.shape) - 1)
        for view_func in [self._channel_view, self._layer_view]:
            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
            weight_rows = view_func(p.data)
            if cosine_sim.max() < delta / math.sqrt(weight_rows.size(1)):
                unit_weight = p.data / weight_rows.norm(dim=1).view(broadcast_shape).add_(eps)
                perturb -= unit_weight * view_func(unit_weight * perturb).sum(dim=1).view(broadcast_shape)
                return perturb, wd_ratio
        return perturb, 1

    def step(self, closure=None):
        """Apply one AdamP update to every parameter that has a gradient.

        Args:
            closure: Optional callable re-evaluating the model; its return value
                is passed back to the caller.

        Returns:
            The closure's result, or ``None`` when no closure is given.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                beta1, beta2 = group['betas']
                use_nesterov = group['nesterov']

                state = self.state[p]
                if not state:
                    # Lazily create the step counter and moment buffers.
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                first_moment, second_moment = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1
                t = state['step']

                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = (second_moment.sqrt() / math.sqrt(1 - beta2 ** t)).add_(group['eps'])
                step_size = group['lr'] / (1 - beta1 ** t)

                if use_nesterov:
                    perturb = (beta1 * first_moment + (1 - beta1) * grad) / denom
                else:
                    perturb = first_moment / denom

                # Only multi-dimensional weights are candidates for projection.
                wd_ratio = 1
                if len(p.shape) > 1:
                    perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])

                if group['weight_decay'] > 0:
                    p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio)

                p.data.add_(perturb, alpha=-step_size)

        return loss

class SGDP(Optimizer):
    """SGD with momentum and projection (SGDP), from the AdamP paper (arXiv:2006.08217).

    Uses the same projection rule as the AdamP class in this notebook. For
    parameters with more than one dimension, when the gradient is nearly
    orthogonal to the weight (per slice along dim 0, then for the whole tensor),
    the component of the update along the weight is removed. Weight decay is
    then scaled by ``wd_ratio``.

    Args:
        params: Iterable of parameters or param-group dicts.
        lr: Learning rate (``required``: must be supplied).
        momentum: Momentum factor.
        dampening: Dampening applied to the gradient when accumulating momentum.
        weight_decay: Decoupled decay, applied as
            ``p *= 1 - lr * weight_decay * ratio / (1 - momentum)``.
        nesterov: Use ``grad + momentum * buf`` as the update direction.
        eps: Epsilon for the cosine-similarity and norm computations.
        delta: Threshold factor for triggering the projection.
        wd_ratio: Weight-decay multiplier used when the projection fires.
    """
    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
                        nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
        super(SGDP, self).__init__(params, defaults)

    def _channel_view(self, x):
        # One row per slice along dim 0, remaining dims flattened.
        return x.view(x.size(0), -1)

    def _layer_view(self, x):
        # Whole tensor as a single row.
        return x.view(1, -1)

    def _cosine_similarity(self, x, y, eps, view_func):
        # Row-wise |cos(x, y)| after reshaping both tensors with view_func.
        x = view_func(x)
        y = view_func(y)

        return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()

    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
        """Project ``perturb`` onto the tangent space of ``p`` if ``grad`` is ~orthogonal to ``p``.

        Returns ``(perturb, wd)``. ``wd`` is ``wd_ratio`` when a projection was
        applied; ``perturb`` is then modified IN PLACE. Otherwise ``wd`` is 1.
        """
        wd = 1
        expand_size = [-1] + [1] * (len(p.shape) - 1)
        for view_func in [self._channel_view, self._layer_view]:

            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)

            # Threshold shrinks with the row length (delta / sqrt(dim)).
            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
                p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
                wd = wd_ratio

                # Stop at the first view that triggers the projection.
                return perturb, wd

        return perturb, wd

    def step(self, closure=None):
        """Apply one SGDP update to every parameter with a gradient.

        Returns the closure's result, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['momentum'] = torch.zeros_like(p.data)

                # SGD
                buf = state['momentum']
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)
                if nesterov:
                    d_p = grad + momentum * buf
                else:
                    # NOTE(review): d_p aliases the momentum buffer here, so the
                    # in-place `perturb -=` in _projection also rewrites
                    # state['momentum']. Confirm that projecting the stored
                    # momentum is intended.
                    d_p = buf

                # Projection
                wd_ratio = 1
                if len(p.shape) > 1:
                    d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])

                # Weight decay
                # NOTE(review): divides by (1 - momentum); momentum == 1 raises ZeroDivisionError.
                if group['weight_decay'] > 0:
                    p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))

                # Step
                p.data.add_(d_p, alpha=-group['lr'])

        return loss

class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer (arXiv:2010.01412).

    Usage per batch: compute loss and ``backward()``, call ``first_step()``
    (moves to ``w + e(w)``), compute loss and ``backward()`` again, then call
    ``second_step()``. That restores ``w`` and applies the base optimizer step.
    ``step()`` is intentionally unsupported.

    Args:
        params: Iterable of parameters or param-group dicts.
        base_optimizer: Optimizer class (not instance) performing the real update.
        rho: Radius of the ascent step; must be non-negative.
        **kwargs: Forwarded to ``base_optimizer`` and stored as group defaults.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share the groups so LR schedulers acting on SAM reach the base optimizer.
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Move each parameter with a gradient to ``w + rho * g / ||g||``.

        The current weights are saved so that ``second_step`` can restore them
        exactly. Subtracting ``e(w)`` back would not round-trip in floating point.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.detach().clone()
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore every parameter perturbed by ``first_step``, then step the base optimizer."""
        for group in self.param_groups:
            for p in group["params"]:
                # Restore even if p.grad was cleared in between; pop so a stale
                # copy can never be reused on a later iteration.
                old_p = self.state[p].pop("old_p", None) if p in self.state else None
                if old_p is not None:
                    p.copy_(old_p)  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    def step(self, closure=None):
        raise NotImplementedError("SAM doesn't work like the other optimizers, you should first call `first_step` and the `second_step`; see the documentation for more info.")

    def _grad_norm(self):
        """Global L2 norm of all gradients; a zero scalar when no parameter has one."""
        all_params = [p for group in self.param_groups for p in group["params"]]
        shared_device = all_params[0].device  # put everything on the same device, in case of model parallelism
        grad_norms = [p.grad.norm(p=2).to(shared_device) for p in all_params if p.grad is not None]
        if not grad_norms:
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(grad_norms), p=2)

In [None]:
## reference: https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/optim_factory.py
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split a model's trainable parameters into a no-decay and a decay param group.

    A parameter is exempt from decay when it is 1-D (e.g. biases and
    normalization affine terms), when its name ends with ``".bias"``, or when
    its name is contained in ``skip_list``. Parameters with
    ``requires_grad=False`` are left out entirely. Adapted from timm's
    ``optim_factory``.

    Args:
        model: Module whose ``named_parameters()`` are grouped.
        weight_decay: Decay applied to the non-exempt group.
        skip_list: Container of parameter names to exempt (membership via ``in``).

    Returns:
        ``[{'params': no_decay, 'weight_decay': 0.}, {'params': decay, 'weight_decay': weight_decay}]``
        with parameters in ``named_parameters()`` order.
    """
    def _exempt_from_decay(name, param):
        return len(param.shape) == 1 or name.endswith(".bias") or name in skip_list

    trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
    no_decay = [param for name, param in trainable if _exempt_from_decay(name, param)]
    decay = [param for name, param in trainable if not _exempt_from_decay(name, param)]
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay},
    ]

In [None]:
def main(fold):
    """Train one cross-validation fold and return its best validation accuracy.

    Rows of ``Config['train_set']`` whose ``fold`` column equals ``fold`` are
    used for validation and the others for training. ``Config['train_set2']``
    is read and passed to ``prepare_dataloader`` unchanged. The model is
    optimized with SAM wrapping AdamP, under a warmup + cosine LR schedule and
    ComboLoss.

    Side effects: appends one line per epoch to
    ``log_{kernel_type}_{model_arch}_fold_{fold}.txt``. Whenever validation
    accuracy improves, it saves ``model.state_dict()`` to
    ``{kernel_type}_{model_arch}_best_fold_{fold}.pth``.

    When ``Config['resume']`` is truthy, weights are loaded before training
    from ``Config['resume_path']``. If that key is absent, this fold's best
    checkpoint path is used.

    Args:
        fold: Value of the ``fold`` column held out for validation.

    Returns:
        The best validation accuracy over all epochs (0 if it never exceeded 0).
    """
    seed_everything(Config['seed'])
    device = torch.device(Config['device'])

    train = pd.read_csv(Config['train_set'])
    train2 = pd.read_csv(Config['train_set2'])

    train_idx = np.where((train['fold'] != fold))[0]
    valid_idx = np.where((train['fold'] == fold))[0]

    num_classes = train.label.nunique()

    print('Training with {} started'.format(fold))

    print("Training set size :", len(train_idx), "Validation set size :", len(valid_idx))
    train_loader, valid_loader = prepare_dataloader(train, train2, train_idx, valid_idx, data_image_root=Config['data_image_root'])

    model = CassavaImageClassifier(Config['model_arch'], num_classes, pretrained=True).to(device)

    best_model_path = '{}_{}_best_fold_{}.pth'.format(Config['kernel_type'], Config['model_arch'], fold)

    # load trained model
    if Config['resume']:
        # 'resume_path' is optional; default to the checkpoint this function saves for the fold.
        resume_path = Config.get('resume_path', best_model_path)
        model.load_state_dict(torch.load(resume_path, map_location=device))

    # reference: https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/optim_factory.py
    weight_decay = Config['weight_decay']
    if weight_decay and Config['filter_bias_and_bn']:
        skip = set()
        if hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        model_parameters = add_weight_decay(model, weight_decay, skip)
        # Decay is now set per param group; the optimizer-level default must be 0.
        weight_decay = 0.
    else:
        model_parameters = model.parameters()

    # set optimizer
    base_optimizer = AdamP
    optimizer = SAM(model_parameters,
                    base_optimizer,
                    rho=Config['rho'],
                    lr=Config['init_lr'],
                    weight_decay=weight_decay,
                    nesterov=True)

    # set learning rate scheduler
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, Config['epochs']-Config['warmup_epoch'])
    scheduler = GradualWarmupScheduler(optimizer, multiplier=Config['warmup_factor'], total_epoch=Config['warmup_epoch'], after_scheduler=scheduler_cosine)

    # set loss function
    criterion = ComboLoss().to(device)

    best_accuracy = 0

    for epoch in range(1, Config['epochs']+1):
        print(time.ctime(), 'Epoch :', epoch)
        # Scheduler is driven by an explicit 0-based epoch index.
        scheduler.step(epoch-1)

        train_loss, train_accuracy = train_one_epoch(model, epoch, criterion, optimizer, train_loader, device)
        valid_loss, valid_accuracy = valid_one_epoch(model, criterion, valid_loader, device)

        content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, train loss: {np.mean(train_loss):.5f}, train accuracy: {(train_accuracy):.5f}, valid loss: {np.mean(valid_loss):.5f}, validation accuracy: {(valid_accuracy):.5f}'
        print(content)
        with open('log_{}_{}_fold_{}.txt'.format(Config['kernel_type'], Config['model_arch'], fold), 'a') as appender:
            appender.write(content + '\n')

        ## check maximum accuracy per validation and save best model
        if valid_accuracy > best_accuracy:
            print('valid accuracy ({:.6f} --> {:.6f}).  Saving model ...'.format(best_accuracy, valid_accuracy))
            torch.save(model.state_dict(), best_model_path)
            best_accuracy = valid_accuracy

    del model, optimizer, train_loader, valid_loader, scheduler_cosine, scheduler
    torch.cuda.empty_cache()
    gc.collect()

    return best_accuracy

In [None]:
#main(fold=1)