# CONFIG

In [None]:
# Central experiment configuration read by the cells below.
# Component-specific options are named '<Component>_<parameter>' and are only
# consulted when that component is the one selected.
CFG = {
    # ---- general settings ----
    'fold_num': 5,
    'only_one_fold': True,
    'epochs': 7,
    'seed': 719,
    'train_bs': 16,
    'valid_bs': 32,
    'num_workers': 4,
    'accum_iter': 1,  # gradient accumulation: the optimizer steps every `accum_iter` batches
    'verbose_step': 1,
    'device': 'tpu',  # either 'cuda:0' or 'tpu'

    # ---- dataset settings (switch between the 2020 and the 2019+2020 dataset) ----
    'train_imgs_path': '../input/cassava-leaf-disease-classification/train_images',

    # ---- data cleaning settings ----
    'train_csv_path': '../input/cassava-leaf-disease-classification/train.csv',

    # ---- model architecture settings ----
    'model_arch': 'tf_efficientnet_b4_ns',

    # ---- augmentation ----
    'img_size': 512,

    # ---- loss function settings ----
    # one of: CrossEntropyLoss, LabelSmoothingLoss, FocalLoss, FocalCosineLoss,
    #         SymmetricCrossEntropy, BiTemperedLogisticLoss, TaylorCrossEntropyLoss
    'loss_function': 'CrossEntropyLoss',
    'LabelSmoothingLoss_smoothing': 0.1,
    'FocalLoss_alpha': 1,
    'FocalLoss_gamma': 2,
    'FocalCosineLoss_alpha': 1,
    'FocalCosineLoss_gamma': 2,
    'FocalCosineLoss_xent': 0.1,
    'SymmetricCrossEntropy_alpha': 0.1,
    'SymmetricCrossEntropy_beta': 1.0,
    'BiTemperedLogisticLoss_t1': 0.3,
    'BiTemperedLogisticLoss_t2': 1.0,
    'BiTemperedLogisticLoss_smoothing': 0.0,
    'TaylorCrossEntropyLoss_smoothing': 0.05,
    'TaylorCrossEntropyLoss_n': 2,

    # ---- optimizer settings ----
    'optimizer': 'Adam',  # one of: Adam, SGD
    'Adam_lr': 1e-4,
    'Adam_weight_decay': 1e-6,  # L2 regularisation strength
    'SGD_lr': 1e-4,
    'SGD_momentum': 0.9,

    # ---- learning-rate scheduler settings ----
    # one of: StepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau,
    #         CosineAnnealingWarmRestarts
    'lr_schedular': 'CosineAnnealingLR',
    'StepLR_step_size': 2,
    'StepLR_gamma': 0.5,
    'ExponentialLR_gamma': 0.9,
    'CosineAnnealingLR_T_max': 10,
    'CosineAnnealingLR_eta_min': 0,
    'ReduceLROnPlateau_factor': 0.5,
    'ReduceLROnPlateau_patience': 1,
    'ReduceLROnPlateau_threshold': 0.0001,
    'ReduceLROnPlateau_min_lr': 0,
    'CosineAnnealingWarmRestarts_T_0': 10,
    'CosineAnnealingWarmRestarts_min_lr': 1e-6,
}

In [None]:
import sys

# Directories appended to the import path so the packages they contain
# (here a pytorch-image-models checkout) can be imported by later cells.
package_paths = ['../input/pytorch-image-models/pytorch-image-models-master']
sys.path.extend(package_paths)

# TPU package installation

In [None]:
# Everything in this cell runs only when the configured device is not 'cuda:0'
# (i.e. the TPU path): it installs torch_xla, imports the XLA helpers used by
# the training code, and sets XLA environment variables.
if CFG['device']!='cuda:0':
    # Download and run the PyTorch/XLA environment setup script (stdout discarded).
    # NOTE(review): this fetches from the `master` branch and installs the
    # `nightly` build, so the environment is not pinned — confirm this is intended.
    !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py  > /dev/null
    !python pytorch-xla-env-setup.py --version nightly  > /dev/null
    
    # XLA helpers: `xm` and `pl` are used by the dataloader/training functions below.
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
    import torch_xla.utils.serialization as xser
    import gc
    import os

    # XLA runtime settings, applied through environment variables.
    # presumably XLA_USE_BF16 switches float tensors to bfloat16 on TPU — verify against torch_xla docs
    os.environ['XLA_USE_BF16']="1"
    os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

    # Silence all Python warnings for the remainder of the session.
    import warnings
    warnings.filterwarnings("ignore")

In [None]:
from glob import glob
from sklearn.model_selection import GroupKFold, StratifiedKFold
import cv2
from skimage import io
import torch
from torch import nn
import os
from datetime import datetime
import time
import random
import cv2,gc
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F

import timm

import sklearn
import warnings
import joblib
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
import warnings
import cv2
import pydicom
#from efficientnet_pytorch import EfficientNet
from scipy.ndimage.interpolation import zoom

In [None]:
def seed_everything(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's `random`, NumPy's legacy global RNG, and PyTorch (CPU and
    all CUDA devices), sets `PYTHONHASHSEED`, and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also cover multi-GPU setups; this is a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different algorithms from
    # run to run, which defeats the determinism requested above.
    torch.backends.cudnn.benchmark = False
    
def get_img(path):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        path: Path of the image file to read.

    Returns:
        The image array with its last (channel) axis reversed from OpenCV's
        BGR order to RGB. The result is a reversed view of the loaded array,
        not a copy.

    Raises:
        FileNotFoundError: If OpenCV cannot read `path`. `cv2.imread` signals
            this by returning None, which previously surfaced as an opaque
            "'NoneType' object is not subscriptable" TypeError.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        raise FileNotFoundError(f'cv2.imread could not read image: {path}')
    # Reverse the channel axis: BGR -> RGB.
    im_rgb = im_bgr[:, :, ::-1]
    return im_rgb

In [None]:
def rand_bbox(size, lam):
    """Sample a random CutMix box inside a `W x H` area.

    The box is centred on a uniformly sampled point, has side lengths
    `W * sqrt(1 - lam)` and `H * sqrt(1 - lam)` (truncated to int), and is
    clipped to the area bounds, so the returned box can be smaller than that.

    Args:
        size: Sequence whose first two entries are the extents `(W, H)`.
        lam: Mixing coefficient; the unclipped box covers a `1 - lam` fraction
            of the area.

    Returns:
        Tuple `(bbx1, bby1, bbx2, bby2)` of clipped box corner coordinates.
    """
    W = size[0]
    H = size[1]
    cut_rat = np.sqrt(1. - lam)
    # `np.int` was removed in NumPy 1.24; the builtin `int` truncates the same way.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre, sampled uniformly over the area.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2


class CassavaDataset(Dataset):
    """Image dataset backed by a DataFrame with an `image_id` column.

    Each item is loaded from `data_root/<image_id>` with `get_img`, passed
    through `transforms` (called as `transforms(image=img)['image']`), and
    optionally mixed with another randomly chosen sample using FMix and/or
    CutMix. When `output_label` is true, labels are read from the `label`
    column and `(img, target)` is returned; otherwise only `img` is returned.

    NOTE(review): the FMix/CutMix branches blend `target` with the other
    sample's label arithmetically, which is only meaningful for one-hot
    labels; the assertion enforcing that is commented out. They also read
    `target` and `self.labels`, which exist only when `output_label` is true.
    NOTE(review): the FMix branch calls `make_low_freq_image` and
    `binarise_mask`, which are not defined in the visible part of the
    notebook — confirm they are imported before enabling `do_fmix`.
    """
    # The dict defaults below are built once at definition time (so
    # CFG['img_size'] is read then) and shared between instances; the visible
    # code only reads them.
    def __init__(self, df, data_root, 
                 transforms=None, 
                 output_label=True, 
                 one_hot_label=False,
                 do_fmix=False, 
                 fmix_params={
                     'alpha': 1., 
                     'decay_power': 3., 
                     'shape': (CFG['img_size'], CFG['img_size']),
                     'max_soft': True, 
                     'reformulate': False
                 },
                 do_cutmix=False,
                 cutmix_params={
                     'alpha': 1,
                 }
                ):
        
        super().__init__()
        # Work on a re-indexed copy so positions 0..len-1 are valid index labels.
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root
        self.do_fmix = do_fmix
        self.fmix_params = fmix_params
        self.do_cutmix = do_cutmix
        self.cutmix_params = cutmix_params
        
        self.output_label = output_label
        self.one_hot_label = one_hot_label
        
        if output_label == True:
            self.labels = self.df['label'].values
            #print(self.labels)
            
            if one_hot_label is True:
                # Rows of an identity matrix, indexed by label, give one-hot vectors;
                # the number of classes is inferred as max(label) + 1.
                self.labels = np.eye(self.df['label'].max()+1)[self.labels]
                #print(self.labels)
            
    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return self.df.shape[0]
    
    def __getitem__(self, index: int):
        """Load, transform and optionally mix the sample at position `index`.

        Returns `(img, target)` when `output_label` is true, else `img`.
        """
        
        # get labels
        if self.output_label:
            target = self.labels[index]
          
        img  = get_img("{}/{}".format(self.data_root, self.df.loc[index]['image_id']))

        if self.transforms:
            img = self.transforms(image=img)['image']
        
        # FMix: applied with probability 0.5 when enabled.
        if self.do_fmix and np.random.uniform(0., 1., size=1)[0] > 0.5:
            with torch.no_grad():
                #lam, mask = sample_mask(**self.fmix_params)
                
                # Beta(alpha, alpha) sample, clipped to the range [0.6, 0.7].
                lam = np.clip(np.random.beta(self.fmix_params['alpha'], self.fmix_params['alpha']),0.6,0.7)
                
                # Make mask, get mean / std
                mask = make_low_freq_image(self.fmix_params['decay_power'], self.fmix_params['shape'])
                mask = binarise_mask(mask, lam, self.fmix_params['shape'], self.fmix_params['max_soft'])
    
                # Pick the partner sample uniformly from the dataset.
                fmix_ix = np.random.choice(self.df.index, size=1)[0]
                fmix_img  = get_img("{}/{}".format(self.data_root, self.df.iloc[fmix_ix]['image_id']))

                if self.transforms:
                    fmix_img = self.transforms(image=fmix_img)['image']

                mask_torch = torch.from_numpy(mask)
                
                # mix image
                img = mask_torch*img+(1.-mask_torch)*fmix_img

                #print(mask.shape)

                #assert self.output_label==True and self.one_hot_label==True

                # mix target
                # `rate` is the mask sum divided by img_size**2, i.e. the share kept from `img`.
                rate = mask.sum()/CFG['img_size']/CFG['img_size']
                target = rate*target + (1.-rate)*self.labels[fmix_ix]
                #print(target, mask, img)
                #assert False
        
        # CutMix: applied with probability 0.5 when enabled.
        if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > 0.5:
            #print(img.sum(), img.shape)
            with torch.no_grad():
                cmix_ix = np.random.choice(self.df.index, size=1)[0]
                cmix_img  = get_img("{}/{}".format(self.data_root, self.df.iloc[cmix_ix]['image_id']))
                if self.transforms:
                    cmix_img = self.transforms(image=cmix_img)['image']
                    
                # Beta(alpha, alpha) sample, clipped to the range [0.3, 0.4].
                lam = np.clip(np.random.beta(self.cutmix_params['alpha'], self.cutmix_params['alpha']),0.3,0.4)
                bbx1, bby1, bbx2, bby2 = rand_bbox((CFG['img_size'], CFG['img_size']), lam)

                # Paste the partner's box into `img` in place.
                # assumes img is channel-first with two trailing axes of size img_size — TODO confirm
                img[:, bbx1:bbx2, bby1:bby2] = cmix_img[:, bbx1:bbx2, bby1:bby2]

                # Share of the image area left untouched, computed from the clipped box.
                rate = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (CFG['img_size'] * CFG['img_size']))
                target = rate*target + (1.-rate)*self.labels[cmix_ix]
                
            #print('-', img.sum())
            #print(target)
            #assert False
                            
        # do label smoothing
        #print(type(img), type(target))
        if self.output_label == True:
            return img, target
        else:
            return img

# Augmentations

In [None]:
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2

def get_train_transforms():
    """Build the training augmentation pipeline.

    Random resized crop to `CFG['img_size']`, random flips, shift/scale/rotate
    and brightness/contrast jitter, ImageNet mean/std normalisation, two
    random-erasing style dropouts, and conversion to a tensor.
    """
    side = CFG['img_size']
    steps = [
        RandomResizedCrop(side, side),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
        RandomBrightnessContrast(
            brightness_limit=(-0.5, 0.5),
            contrast_limit=(-0.5, 0.5),
            p=0.5,
        ),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0,
        ),
        CoarseDropout(p=0.5),
        Cutout(p=0.5),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)
  
        
def get_valid_transforms():
    """Build the deterministic validation pipeline.

    Resize to `CFG['img_size']`, apply ImageNet mean/std normalisation, and
    convert to a tensor.
    """
    side = CFG['img_size']
    steps = [
        Resize(side, side),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0,
        ),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)

In [None]:
class CassvaImgClassifier(nn.Module):
    """Thin wrapper around a timm model with an `n_class`-way output."""

    def __init__(self, model_arch, n_class, pretrained=False):
        """Create the timm architecture `model_arch` with `n_class` outputs."""
        super(CassvaImgClassifier, self).__init__()
        self.model = timm.create_model(
            model_arch,
            pretrained=pretrained,
            num_classes=n_class,
        )

    def forward(self, x):
        """Return the wrapped model's output for the batch `x`."""
        return self.model(x)

# LOSS FUNCTION

In [None]:
# ====================================================
# Label Smoothing
# ====================================================
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class receives probability `1 - smoothing`; the remaining
    `smoothing` mass is spread evenly over the other `classes - 1` classes.
    """

    def __init__(self, classes=5, smoothing=0.1, dim=-1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the batch-mean smoothed cross-entropy.

        `pred` holds raw scores (log-softmax is applied along `self.dim`);
        `target` holds integer class indices.
        """
        log_probs = pred.log_softmax(dim=self.dim)
        off_value = self.smoothing / (self.cls - 1)
        with torch.no_grad():
            # Start from the uniform "off" mass, then write the confident
            # mass at each sample's true class.
            smoothed = torch.full_like(log_probs, off_value)
            smoothed.scatter_(1, target.data.unsqueeze(1), self.confidence)
        per_sample = torch.sum(-smoothed * log_probs, dim=self.dim)
        return torch.mean(per_sample)

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: `alpha * (1 - p_t) ** gamma * CE`.

    Args:
        alpha: Constant scale applied to the loss.
        gamma: Focusing exponent; larger values down-weight well-classified
            samples more strongly.
        reduce: If true, return the mean over the batch; otherwise return the
            per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduce = reduce

    def forward(self, inputs, targets):
        """Compute the focal loss for logits `inputs` and class indices `targets`."""
        # The focal weight must be applied per sample. With the default 'mean'
        # reduction the cross-entropy was already averaged over the batch, so
        # every sample got the same weight and `reduce=False` had no effect.
        BCE_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)

        # p_t: probability the model assigns to the true class.
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

In [None]:
class FocalCosineLoss(nn.Module):#few shot and inbalanced dataset
    """Cosine-embedding loss plus a weighted focal cross-entropy term.

    Args:
        alpha: Scale of the focal term.
        gamma: Focusing exponent of the focal term.
        xent: Weight of the focal term relative to the cosine term.
    """

    def __init__(self, alpha=1, gamma=2, xent=0.1):
        super(FocalCosineLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.xent = xent

        # "Similar pair" label for cosine_embedding_loss. It is kept on the CPU
        # and moved to the input's device in forward(); the previous
        # unconditional .cuda() failed on TPU and CPU-only runs.
        self.y = torch.Tensor([1])

    def forward(self, input, target, reduction="mean"):
        """Compute the combined loss for logits `input` and class indices `target`.

        `reduction` is forwarded to the cosine term; the focal term is averaged
        only when it equals "mean".
        """
        # One-hot targets in the logits' dtype, so both cosine operands match.
        one_hot = F.one_hot(target, num_classes=input.size(-1)).to(input.dtype)
        # One "+1" label per sample, on the same device and dtype as the logits.
        y = self.y.to(device=input.device, dtype=input.dtype).expand(input.size(0))
        cosine_loss = F.cosine_embedding_loss(input, one_hot, y, reduction=reduction)

        # `reduce=False` is deprecated; reduction='none' is the equivalent spelling.
        cent_loss = F.cross_entropy(F.normalize(input), target, reduction='none')
        pt = torch.exp(-cent_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * cent_loss

        if reduction == "mean":
            focal_loss = torch.mean(focal_loss)

        return cosine_loss + self.xent * focal_loss

In [None]:
class SymmetricCrossEntropy(nn.Module):#noisy data
    """Weighted sum of cross-entropy and a second, clamped cross-entropy term.

    Args:
        alpha: Weight of the standard cross-entropy term.
        beta: Weight of the second ("rce") term.
        num_classes: Number of classes used to one-hot encode the targets.

    NOTE(review): as written, the "rce" term is `-sum(onehot * log(softmax))`,
    i.e. ordinary cross-entropy with probabilities clamped to [1e-7, 1]. The
    published reverse cross-entropy swaps the roles of prediction and label —
    confirm which one is intended before changing it.
    """

    def __init__(self, alpha=0.1, beta=1.0, num_classes=5):
        super(SymmetricCrossEntropy, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.num_classes = num_classes

    def forward(self, logits, targets, reduction='mean'):
        """Compute the loss for `logits` and integer class `targets`."""
        # Build the one-hot matrix on the targets' own device. The previous
        # torch.eye(...)[targets].cuda() indexed a CPU tensor with device
        # indices and then forced CUDA, which breaks on TPU and CPU-only runs.
        onehot_targets = F.one_hot(targets, num_classes=self.num_classes).float()
        ce_loss = F.cross_entropy(logits, targets, reduction=reduction)
        rce_loss = (-onehot_targets*logits.softmax(1).clamp(1e-7, 1.0).log()).sum(1)
        if reduction == 'mean':
            rce_loss = rce_loss.mean()
        elif reduction == 'sum':
            rce_loss = rce_loss.sum()
        return self.alpha * ce_loss + self.beta * rce_loss

In [None]:
#Bi-Tempered-Loss(noisy data)
def log_t(u, t):
    """Tempered logarithm of tensor `u` at temperature `t`.

    Falls back to the natural logarithm when `t` equals 1; otherwise computes
    `(u ** (1 - t) - 1) / (1 - t)`.
    """
    if t != 1.0:
        return (u.pow(1.0 - t) - 1.0) / (1.0 - t)
    return u.log()

def exp_t(u, t):
    """Tempered exponential of tensor `u` at temperature `t`.

    Falls back to the ordinary exponential when `t` equals 1; otherwise
    computes `relu(1 + (1 - t) * u) ** (1 / (1 - t))`.
    """
    if t != 1:
        clipped_base = (1.0 + (1.0-t)*u).relu()
        return clipped_base.pow(1.0 / (1.0 - t))
    return u.exp()

def compute_normalization_fixed_point(activations, t, num_iters):
    """Compute per-example normalization constants by fixed-point iteration.

    This is the solver chosen for `t >= 1.0` by `ComputeNormalization`.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness).
      num_iters: Number of fixed-point iterations to run.
    Return: A tensor shaped like `activations` except that the last dimension is 1.
    """
    # Shift by the per-example maximum, so the largest shifted activation is 0.
    mu = torch.max(activations, -1, keepdim=True)[0]
    shifted = activations - mu

    current = shifted
    for _step in range(num_iters):
        partition = torch.sum(exp_t(current, t), -1, keepdim=True)
        current = shifted * partition.pow(1.0-t)

    # Partition function of the final iterate, mapped back to a constant.
    partition = torch.sum(exp_t(current, t), -1, keepdim=True)
    return - log_t(1.0 / partition, t) + mu

def compute_normalization_binary_search(activations, t, num_iters):
    """Compute per-example normalization constants by bisection.

    This is the solver chosen for `t < 1.0` by `ComputeNormalization`.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (< 1.0 for finite support).
      num_iters: Number of bisection steps to run.
    Return: A tensor of the same rank as `activations` with last dimension 1.
    """
    mu = torch.max(activations, -1, keepdim=True)[0]
    shifted = activations - mu

    # Number of classes whose shifted activation lies above -1 / (1 - t).
    above_cutoff = (shifted > -1.0 / (1.0-t)).to(torch.int32)
    effective_dim = torch.sum(above_cutoff, dim=-1, keepdim=True).to(activations.dtype)

    shape_partition = activations.shape[:-1] + (1,)
    lower = torch.zeros(shape_partition, dtype=activations.dtype, device=activations.device)
    upper = -log_t(1.0/effective_dim, t) * torch.ones_like(lower)

    for _step in range(num_iters):
        midpoint = (upper + lower)/2.0
        sum_probs = torch.sum(exp_t(shifted - midpoint, t), dim=-1, keepdim=True)
        # update == 1 where the tempered probabilities sum to less than one.
        update = (sum_probs < 1.0).to(activations.dtype)
        new_lower = lower * update + (1.0-update) * midpoint
        new_upper = upper * (1.0 - update) + update * midpoint
        lower = torch.reshape(new_lower, shape_partition)
        upper = torch.reshape(new_upper, shape_partition)

    return (upper + lower)/2.0 + mu

class ComputeNormalization(torch.autograd.Function):
    """
    Autograd function computing the tempered-softmax normalization constants
    with a hand-written backward pass. See compute_normalization.
    """
    @staticmethod
    def forward(ctx, activations, t, num_iters):
        # Bisection for t < 1.0, fixed-point iteration otherwise.
        solver = (compute_normalization_binary_search
                  if t < 1.0
                  else compute_normalization_fixed_point)
        normalization_constants = solver(activations, t, num_iters)

        ctx.save_for_backward(activations, normalization_constants)
        ctx.t=t
        return normalization_constants

    @staticmethod
    def backward(ctx, grad_output):
        activations, normalization_constants = ctx.saved_tensors
        temperature = ctx.t
        probabilities = exp_t(activations - normalization_constants, temperature)
        # Probabilities raised to the power t and renormalised over classes.
        escorts = probabilities.pow(temperature)
        escorts = escorts / escorts.sum(dim=-1, keepdim=True)

        # Gradients exist only for `activations`; `t` and `num_iters` get None.
        return escorts * grad_output, None, None

def compute_normalization(activations, t, num_iters=5):
    """Differentiable entry point for the normalization constants.

    Delegates to `ComputeNormalization`, which provides the custom backward pass.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      num_iters: Number of solver iterations to run.
    Return: A tensor of the same rank as `activations` with last dimension 1.
    """
    normalization_constants = ComputeNormalization.apply(activations, t, num_iters)
    return normalization_constants

def tempered_sigmoid(activations, t, num_iters = 5):
    """Tempered sigmoid, computed as a two-class tempered softmax.

    Args:
      activations: Activations for the positive class for binary classification.
      t: Temperature tensor > 0.0.
      num_iters: Number of iterations to run the method.
    Returns:
      A tensor with the tempered probability of the positive class.
    """
    # Pair each activation with a zero activation for the negative class.
    negative_class = torch.zeros_like(activations)
    two_class = torch.stack([activations, negative_class], dim=-1)
    two_class_probs = tempered_softmax(two_class, t, num_iters)
    return two_class_probs[..., 0]


def tempered_softmax(activations, t, num_iters=5):
    """Tempered softmax over the last dimension.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature > 1.0.
      num_iters: Number of iterations used to find the normalization constants.
    Returns:
      A probabilities tensor.
    """
    if t != 1.0:
        shift = compute_normalization(activations, t, num_iters)
        return exp_t(activations - shift, t)
    # t == 1.0 is exactly the ordinary softmax.
    return activations.softmax(dim=-1)

def bi_tempered_binary_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing = 0.0,
        num_iters=5,
        reduction='mean'):
    """Bi-Tempered logistic loss for binary classification.

    The binary problem is expanded into a two-class one (positive class versus
    a zero-activation negative class) and passed to `bi_tempered_logistic_loss`.

    Args:
      activations: A tensor containing activations for class 1.
      labels: A tensor with the same shape as activations, containing
        probabilities for class 1.
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing
      num_iters: Number of iterations to run the method.
      reduction: Forwarded unchanged to `bi_tempered_logistic_loss`.
    Returns:
      A loss tensor.
    """
    positive_prob = labels.to(activations.dtype)
    negative_prob = 1.0 - labels.to(activations.dtype)

    two_class_activations = torch.stack(
        [activations, torch.zeros_like(activations)], dim=-1)
    two_class_labels = torch.stack([positive_prob, negative_prob], dim=-1)

    return bi_tempered_logistic_loss(
        two_class_activations,
        two_class_labels,
        t1,
        t2,
        label_smoothing=label_smoothing,
        num_iters=num_iters,
        reduction=reduction,
    )

def bi_tempered_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing=0.0,
        num_iters=5,
        reduction = 'mean'):

    """Bi-Tempered Logistic Loss.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: A tensor with shape and dtype as activations (onehot), 
        or a long tensor of one dimension less than activations (pytorch standard)
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1). Default 0.0.
      num_iters: Number of iterations to run the method. Default 5.
      reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``.
        ``'none'``: No reduction is applied, return shape is shape of
        activations without the last dimension.
        ``'mean'``: Loss is averaged over all per-example values (scalar tensor).
        ``'sum'``: Loss is summed over all per-example values (scalar tensor).
    Returns:
      A loss tensor.
    Raises:
      ValueError: If `reduction` is not one of the three supported values.
        Previously an unknown value made the function silently return None.
    """
    # Validate first so a typo fails immediately instead of after the forward pass.
    if reduction not in ('none', 'sum', 'mean'):
        raise ValueError(
            "reduction must be 'none', 'mean' or 'sum', got {!r}".format(reduction))

    if len(labels.shape)<len(activations.shape): #not one-hot
        labels_onehot = torch.zeros_like(activations)
        # Scatter along the class (last) dimension. The previous hard-coded
        # dim 1 matched it only for 2-D activations.
        labels_onehot.scatter_(-1, labels[..., None], 1)
    else:
        labels_onehot = labels

    if label_smoothing > 0:
        num_classes = labels_onehot.shape[-1]
        labels_onehot = ( 1 - label_smoothing * num_classes / (num_classes - 1) ) \
                * labels_onehot + \
                label_smoothing / (num_classes - 1)

    probabilities = tempered_softmax(activations, t2, num_iters)

    # The 1e-10 offset keeps log_t finite where a label entry is exactly 0.
    loss_values = labels_onehot * log_t(labels_onehot + 1e-10, t1) \
            - labels_onehot * log_t(probabilities, t1) \
            - labels_onehot.pow(2.0 - t1) / (2.0 - t1) \
            + probabilities.pow(2.0 - t1) / (2.0 - t1)
    loss_values = loss_values.sum(dim = -1) #sum over classes

    if reduction == 'none':
        return loss_values
    if reduction == 'sum':
        return loss_values.sum()
    return loss_values.mean()
    
class BiTemperedLogisticLoss(nn.Module):
    """`nn.Module` wrapper returning the batch-mean bi-tempered logistic loss."""

    def __init__(self, t1=0.3, t2=1.0, smoothing=0.0):
        super().__init__()
        self.t1 = t1
        self.t2 = t2
        self.smoothing = smoothing

    def forward(self, logit_label, truth_label):
        """Compute per-example losses for the logits and average them."""
        per_example = bi_tempered_logistic_loss(
            logit_label,
            truth_label,
            t1=self.t1,
            t2=self.t2,
            label_smoothing=self.smoothing,
            reduction='none',
        )
        return per_example.mean()

In [None]:
class TaylorSoftmax(nn.Module):
    '''
    Softmax variant that replaces exp(x) with its Taylor polynomial of even
    order `n`, then normalises along `dim` (autograd version).
    '''
    def __init__(self, dim=1, n=2):
        super().__init__()
        assert n % 2 == 0
        self.dim = dim
        self.n = n

    def forward(self, x):
        '''
        usage similar to nn.Softmax:
            >>> mod = TaylorSoftmax(dim=1, n=4)
            >>> inten = torch.randn(1, 32, 64, 64)
            >>> out = mod(inten)
        '''
        # Accumulate 1 + x + x^2/2! + ... + x^n/n!, tracking k! incrementally.
        series = torch.ones_like(x)
        factorial = 1.
        for order in range(1, self.n+1):
            factorial *= order
            series = series + x.pow(order) / factorial
        normaliser = series.sum(dim=self.dim, keepdims=True)
        return series / normaliser

    
class TaylorCrossEntropyLoss(nn.Module):
    """Label-smoothed cross-entropy computed on Taylor-softmax probabilities.

    The smoothing loss is built for 5 classes. `reduction` and `ignore_index`
    are stored but not used by `forward`.
    """

    def __init__(self, n=2, ignore_index=-1, reduction='mean', smoothing=0.05):
        super().__init__()
        assert n % 2 == 0
        self.taylor_softmax = TaylorSoftmax(dim=1, n=n)
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.lab_smooth = LabelSmoothingLoss(5, smoothing=smoothing)

    def forward(self, logits, labels):
        """Return the smoothed loss of the log Taylor-softmax probabilities."""
        taylor_probs = self.taylor_softmax(logits)
        log_probs = taylor_probs.log()
        return self.lab_smooth(log_probs, labels)

In [None]:
def define_loss_function():
    """Instantiate the loss selected by `CFG['loss_function']`.

    Hyper-parameters for the chosen loss are read from the matching
    `CFG['<LossName>_<parameter>']` entries.

    Returns:
        An `nn.Module` loss instance.

    Raises:
        ValueError: If `CFG['loss_function']` is not a supported name.
            Previously this returned None, and training failed later with a
            "'NoneType' object is not callable" error.
    """
    name = CFG['loss_function']
    if name=='CrossEntropyLoss':
        return nn.CrossEntropyLoss()
    if name=='LabelSmoothingLoss':
        return LabelSmoothingLoss(smoothing=CFG['LabelSmoothingLoss_smoothing'])
    if name=='FocalLoss':
        return FocalLoss(alpha=CFG['FocalLoss_alpha'],gamma=CFG['FocalLoss_gamma'])
    if name=='FocalCosineLoss':
        return FocalCosineLoss(alpha=CFG['FocalCosineLoss_alpha'],gamma=CFG['FocalCosineLoss_gamma'],xent=CFG['FocalCosineLoss_xent'])
    if name=='SymmetricCrossEntropy':
        return SymmetricCrossEntropy(alpha=CFG['SymmetricCrossEntropy_alpha'],beta=CFG['SymmetricCrossEntropy_beta'])
    if name=='BiTemperedLogisticLoss':
        return BiTemperedLogisticLoss(t1=CFG['BiTemperedLogisticLoss_t1'],t2=CFG['BiTemperedLogisticLoss_t2'],smoothing=CFG['BiTemperedLogisticLoss_smoothing'])
    if name=='TaylorCrossEntropyLoss':
        return TaylorCrossEntropyLoss(n=CFG['TaylorCrossEntropyLoss_n'],smoothing=CFG['TaylorCrossEntropyLoss_smoothing'])
    raise ValueError(f"Unsupported CFG['loss_function']: {name!r}")

In [None]:
def prepare_dataloader(df, trn_idx, val_idx,device,data_root=CFG['train_imgs_path']):
    """Build the training and validation loaders for one fold.

    Args:
        df: DataFrame with the samples; rows are selected with `df.loc`.
        trn_idx: Index labels of the training rows.
        val_idx: Index labels of the validation rows.
        device: Target device; only used on the non-CUDA (XLA) path.
        data_root: Directory containing the image files.

    Returns:
        `(train_loader, valid_loader)`. When `CFG['device']` is 'cuda:0' these
        are plain DataLoaders; otherwise each replica gets a
        DistributedSampler shard wrapped in `pl.MpDeviceLoader`.
    """
    train_rows = df.loc[trn_idx,:].reset_index(drop=True)
    valid_rows = df.loc[val_idx,:].reset_index(drop=True)

    train_ds = CassavaDataset(
        train_rows, data_root,
        transforms=get_train_transforms(),
        output_label=True, one_hot_label=False,
        do_fmix=False, do_cutmix=False,
    )
    valid_ds = CassavaDataset(
        valid_rows, data_root,
        transforms=get_valid_transforms(),
        output_label=True,
    )

    make_loader = torch.utils.data.DataLoader

    # ---- single-GPU path: ordinary shuffled / sequential loaders ----
    if CFG['device'] == 'cuda:0':
        train_loader = make_loader(
            train_ds,
            batch_size=CFG['train_bs'],
            pin_memory=False,
            drop_last=False,
            shuffle=True,
            num_workers=CFG['num_workers'],
        )
        valid_loader = make_loader(
            valid_ds,
            batch_size=CFG['valid_bs'],
            num_workers=CFG['num_workers'],
            shuffle=False,
            pin_memory=False,
        )
        return train_loader, valid_loader

    # ---- XLA path: shard each dataset across the replicas ----
    make_sampler = torch.utils.data.distributed.DistributedSampler

    train_sampler = make_sampler(
        train_ds,
        num_replicas=xm.xrt_world_size(),  # number of replicas sharing the dataset
        rank=xm.get_ordinal(),             # this replica's index
        shuffle=True,
    )
    train_loader = make_loader(
        train_ds,
        batch_size=CFG['train_bs'],
        sampler=train_sampler,
        num_workers=CFG['num_workers'],
        drop_last=True,
    )

    valid_sampler = make_sampler(
        valid_ds,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False,
    )
    valid_loader = make_loader(
        valid_ds,
        batch_size=CFG['valid_bs'],
        sampler=valid_sampler,
        num_workers=CFG['num_workers'],
        drop_last=False,
    )

    # Wrap both loaders so batches are delivered on `device`.
    train_loader = pl.MpDeviceLoader(train_loader, device)
    valid_loader = pl.MpDeviceLoader(valid_loader, device)

    return train_loader, valid_loader

def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, device, scheduler=None, schd_batch_update=False):
    """Train `model` for one pass over `train_loader`.

    Gradients are accumulated over `CFG['accum_iter']` batches (and flushed on
    the last batch) before each optimizer step. On 'cuda:0' an exponentially
    smoothed loss is shown on a tqdm bar; on the XLA path the loss is averaged
    across replicas with `xm.mesh_reduce` and printed by the master process.

    Args:
        epoch: Epoch number, used only in progress output.
        model: Network to train.
        loss_fn: Callable `(predictions, labels) -> loss`.
        optimizer: Optimizer updating `model`'s parameters.
        train_loader: Iterable of `(images, labels)` batches supporting `len()`.
        device: Device the batches are moved to.
        scheduler: Optional LR scheduler.
        schd_batch_update: If true, `scheduler.step()` is called after every
            optimizer step; otherwise once at the end of the epoch.
    """
    model.train()
    running_loss = None
    on_gpu = CFG['device']=='cuda:0'
    num_batches = len(train_loader)

    if on_gpu:
        pbar = tqdm(enumerate(train_loader), total=num_batches)
    else:
        pbar = enumerate(train_loader)
        xm.master_print('start_training')

    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        image_preds = model(imgs)
        loss = loss_fn(image_preds, image_labels)
        loss.backward()

        is_last_batch = (step + 1) == num_batches
        if ((step + 1) % CFG['accum_iter'] == 0) or is_last_batch:
            if on_gpu:
                optimizer.step()
            else:
                xm.optimizer_step(optimizer)
            optimizer.zero_grad()

            # Per-batch scheduling. This step was missing, so passing
            # schd_batch_update=True only disabled the end-of-epoch step and
            # the scheduler never advanced at all.
            if scheduler is not None and schd_batch_update:
                scheduler.step()

        if on_gpu:
            # Exponential moving average of the loss, for display only.
            if running_loss is None:
                running_loss = loss.item()
            else:
                running_loss = running_loss * .99 + loss.item() * .01

            if ((step + 1) % CFG['verbose_step'] == 0) or is_last_batch:
                description = f'epoch {epoch} loss: {running_loss:.4f}'
                pbar.set_description(description)
        else:
            loss_reduced = xm.mesh_reduce('loss_reduce',loss, lambda x: sum(x) / len(x))
            xm.master_print(f'training step : {step} loss_reduced : {loss_reduced}')

    if scheduler is not None and not schd_batch_update:
        scheduler.step()
        
def valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False):
    """Evaluate `model` on `val_loader` and return the multi-class accuracy.

    Args:
        epoch: Epoch number, used only in progress output.
        model: Network to evaluate (switched to eval mode).
        loss_fn: Callable `(predictions, labels) -> loss`.
        val_loader: Iterable of `(images, labels)` batches supporting `len()`.
        device: Device the batches are moved to.
        scheduler: Optional LR scheduler, stepped once after validation.
        schd_loss_update: If true, the scheduler is stepped with the mean
            validation loss; otherwise it is stepped without arguments.

    Returns:
        Fraction of samples whose argmax prediction equals the label.

    NOTE(review): `loss_sum` is only accumulated on the 'cuda:0' path, so on
    the XLA path `schd_loss_update=True` feeds the scheduler a loss of 0 —
    confirm before combining loss-driven schedulers with TPU runs.
    """
    model.eval()

    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []
    on_gpu = CFG['device']=='cuda:0'
    num_batches = len(val_loader)

    if on_gpu:
        pbar = tqdm(enumerate(val_loader), total=num_batches)
    else:
        pbar = enumerate(val_loader)
        xm.master_print('start validation')

    # Evaluation needs no gradients; skipping the autograd graph saves memory
    # without changing the predictions or the loss values.
    with torch.no_grad():
        for step, (imgs, image_labels) in pbar:
            imgs = imgs.to(device).float()
            image_labels = image_labels.to(device).long()

            image_preds = model(imgs)
            image_preds_all += [torch.argmax(image_preds, 1).detach().cpu().numpy()]
            image_targets_all += [image_labels.detach().cpu().numpy()]

            loss = loss_fn(image_preds, image_labels)
            sample_num += image_labels.shape[0]

            is_report_step = ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == num_batches)
            if on_gpu:
                # Weight by batch size, so loss_sum / sample_num is a per-sample mean.
                loss_sum += loss.item()*image_labels.shape[0]
                if is_report_step:
                    description = f'epoch {epoch} loss: {loss_sum/sample_num:.4f}'
                    pbar.set_description(description)
            elif is_report_step:
                # The last-batch check previously read the undefined name
                # `train_loader`, raising NameError whenever it was evaluated.
                xm.master_print(f'validation step : {step}')

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    accuracy = (image_preds_all==image_targets_all).mean()

    if on_gpu or xm.is_master_ordinal():
        print('validation multi-class accuracy = {:.4f}'.format(accuracy))

    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(loss_sum/sample_num)
        else:
            scheduler.step()
    return accuracy

# OPTIMIZER

In [None]:
def define_optimizer(model):
    """Create the optimizer selected by `CFG['optimizer']` for `model`.

    Args:
        model: Module whose `parameters()` are optimised.

    Returns:
        `torch.optim.Adam` (using `Adam_lr`, `Adam_weight_decay`) or
        `torch.optim.SGD` (using `SGD_lr`, `SGD_momentum`).

    Raises:
        ValueError: If `CFG['optimizer']` is neither 'Adam' nor 'SGD'.
            Previously this returned None and failed later in the training loop.
    """
    name = CFG['optimizer']
    if name=='Adam':
        return torch.optim.Adam(model.parameters(), lr=CFG['Adam_lr'], weight_decay=CFG['Adam_weight_decay'])
    if name=='SGD':
        return torch.optim.SGD(model.parameters(), lr=CFG['SGD_lr'], momentum=CFG['SGD_momentum'])
    raise ValueError(f"Unsupported CFG['optimizer']: {name!r}")

# LR_SCHEDULER

In [None]:
def define_lr_schedular(optimizer):
    """Build the learning-rate scheduler selected by ``CFG['lr_schedular']``.

    Args:
        optimizer: optimizer instance the scheduler is attached to.

    Returns:
        The matching ``torch.optim.lr_scheduler`` object, or ``None`` when the
        configured name is not one of the five handled here.
    """
    name = CFG['lr_schedular']
    schedulers = torch.optim.lr_scheduler

    if name == 'StepLR':
        return schedulers.StepLR(optimizer, step_size=CFG['StepLR_step_size'], gamma=CFG['StepLR_gamma'])
    if name == 'ExponentialLR':
        return schedulers.ExponentialLR(optimizer, gamma=CFG['ExponentialLR_gamma'])
    if name == 'CosineAnnealingLR':
        return schedulers.CosineAnnealingLR(optimizer, T_max=CFG['CosineAnnealingLR_T_max'],
                                            eta_min=CFG['CosineAnnealingLR_eta_min'])
    if name == 'ReduceLROnPlateau':
        return schedulers.ReduceLROnPlateau(optimizer, mode='min', factor=CFG['ReduceLROnPlateau_factor'],
                                            patience=CFG['ReduceLROnPlateau_patience'],
                                            threshold=CFG['ReduceLROnPlateau_threshold'], threshold_mode='rel',
                                            cooldown=0, min_lr=CFG['ReduceLROnPlateau_min_lr'], eps=1e-08)
    if name == 'CosineAnnealingWarmRestarts':
        # eta_min comes from the '_min_lr' key, the same key
        # print_hyperparameter reports for this scheduler.
        return schedulers.CosineAnnealingWarmRestarts(optimizer, T_0=CFG['CosineAnnealingWarmRestarts_T_0'],
                                                      eta_min=CFG['CosineAnnealingWarmRestarts_min_lr'])
    # Unknown name: no scheduler is created.
    return None

# PRINT HYPERPARAMETERS

In [None]:
def print_hyperparameter():
    """Print the run configuration held in ``CFG``.

    Always prints the epoch count, model architecture, data paths and the
    selected loss function, optimizer and LR scheduler. For each of those three
    selections, one extra line follows with the ``CFG`` entries that belong to
    the chosen option, formatted as ``key:value`` pairs separated by one space.
    Options with no dedicated entries (e.g. ``CrossEntropyLoss``) print no
    extra line.
    """
    # Option name -> CFG keys that only matter when that option is selected.
    loss_keys = {
        'LabelSmoothingLoss': ('LabelSmoothingLoss_smoothing',),
        'FocalLoss': ('FocalLoss_alpha', 'FocalLoss_gamma'),
        'FocalCosineLoss': ('FocalCosineLoss_alpha', 'FocalCosineLoss_gamma', 'FocalCosineLoss_xent'),
        'SymmetricCrossEntropy': ('SymmetricCrossEntropy_alpha', 'SymmetricCrossEntropy_beta'),
        'BiTemperedLogisticLoss': ('BiTemperedLogisticLoss_t1', 'BiTemperedLogisticLoss_t2',
                                   'BiTemperedLogisticLoss_smoothing'),
        'TaylorCrossEntropyLoss': ('TaylorCrossEntropyLoss_smoothing', 'TaylorCrossEntropyLoss_n'),
    }
    optimizer_keys = {
        'Adam': ('Adam_lr', 'Adam_weight_decay'),
        'SGD': ('SGD_lr', 'SGD_momentum'),
    }
    scheduler_keys = {
        'StepLR': ('StepLR_step_size', 'StepLR_gamma'),
        'ExponentialLR': ('ExponentialLR_gamma',),
        'CosineAnnealingLR': ('CosineAnnealingLR_T_max', 'CosineAnnealingLR_eta_min'),
        'ReduceLROnPlateau': ('ReduceLROnPlateau_factor', 'ReduceLROnPlateau_patience',
                              'ReduceLROnPlateau_threshold', 'ReduceLROnPlateau_min_lr'),
        'CosineAnnealingWarmRestarts': ('CosineAnnealingWarmRestarts_T_0',
                                        'CosineAnnealingWarmRestarts_min_lr'),
    }

    def _print_selection(label, cfg_key, keys_by_option):
        # Print the chosen option, then its own settings on a single line.
        selected = CFG[cfg_key]
        print(f'{label}:{selected}')
        keys = keys_by_option.get(selected, ())
        if keys:
            # Joining explicitly keeps source indentation out of the output.
            print(' '.join(f'{key}:{CFG[key]}' for key in keys))

    print(f'overall epochs:{CFG["epochs"]}')
    print(f'archetecture:{CFG["model_arch"]}')
    print(f'train images path:{CFG["train_imgs_path"]}')
    print(f'train csv path:{CFG["train_csv_path"]}')

    _print_selection('loss function', 'loss_function', loss_keys)
    _print_selection('optimizer', 'optimizer', optimizer_keys)
    _print_selection('lr_schedular', 'lr_schedular', scheduler_keys)

# MAIN

In [None]:
def _mp_fn(rank):
    """Train and validate over stratified folds, saving a checkpoint every epoch.

    Used both for the single-process ``cuda:0`` run and as the per-process
    entry point handed to ``xmp.spawn``. For each fold it builds the loaders,
    model, optimizer, scheduler and loss, then alternates ``train_one_epoch``
    and ``valid_one_epoch`` for ``CFG['epochs']`` epochs. The best validation
    accuracy per fold and the epoch it occurred in are recorded and summarised
    at the end.

    Args:
        rank: process index supplied by the launcher; not read in the body.
    """
    global MX

    on_gpu = CFG['device'] == 'cuda:0'

    def _should_log():
        # GPU runs are single-process; otherwise only the master ordinal prints.
        return on_gpu or xm.is_master_ordinal()

    if _should_log():
        print_hyperparameter()

    torch.set_default_tensor_type('torch.FloatTensor')
    seed_everything(CFG['seed'])
    device = torch.device('cuda:0') if on_gpu else xm.xla_device()

    train = pd.read_csv(CFG['train_csv_path'])
    folds = StratifiedKFold(n_splits=CFG['fold_num'], shuffle=True, random_state=CFG['seed']).split(
        np.arange(train.shape[0]), train.label.values)

    fold_accuracy = []  # best validation accuracy of each trained fold
    fold_epoch = []     # epoch index at which that best accuracy occurred
    for fold, (trn_idx, val_idx) in enumerate(folds):
        # With 'only_one_fold' set, stop after fold 0.
        if fold > 0 and CFG['only_one_fold']:
            break
        if _should_log():
            print('Training with {} started ,{} train ,{} test'.format(fold, len(trn_idx), len(val_idx)))

        train_loader, val_loader = prepare_dataloader(train, trn_idx, val_idx, device=device,
                                                      data_root=CFG['train_imgs_path'])

        if on_gpu:
            model = CassvaImgClassifier(CFG['model_arch'], train.label.nunique(), pretrained=True).to(device)
        else:
            # NOTE(review): every fold reuses the same MX instance here, so weights
            # appear to carry over between folds when more than one fold is
            # trained - confirm this is intended.
            model = MX.to(device)
        optimizer = define_optimizer(model)
        scheduler = define_lr_schedular(optimizer)

        loss_tr = define_loss_function().to(device)
        loss_fn = define_loss_function().to(device)

        epoch_highest_acc = 0     # best validation accuracy seen in this fold
        epoch_highest_record = 0  # epoch index of that best accuracy

        for epoch in range(CFG['epochs']):
            train_one_epoch(epoch, model, loss_tr, optimizer, train_loader, device,
                            scheduler=scheduler, schd_batch_update=False)

            with torch.no_grad():
                # Validation is called without a scheduler.
                accuracy = valid_one_epoch(epoch, model, loss_fn, val_loader, device,
                                           scheduler=None, schd_loss_update=False)
                if accuracy > epoch_highest_acc:
                    epoch_highest_acc = accuracy
                    epoch_highest_record = epoch

            checkpoint_name = '{}_fold_{}_{}_{:.4f}'.format(CFG['model_arch'], fold, epoch, accuracy)
            if on_gpu:
                torch.save(model.state_dict(), checkpoint_name)
            else:
                xm.rendezvous('save_model')
                xm.master_print('save model')
                xm.save(model.state_dict(), checkpoint_name)

        fold_accuracy.append(epoch_highest_acc)
        fold_epoch.append(epoch_highest_record)
        if _should_log():
            print(f'fold {fold} finish,highest accuracy:{epoch_highest_acc},highest epoch:{epoch_highest_record}')

        # Only names actually bound above are deleted; deleting an unbound local
        # would raise UnboundLocalError and abort before the summary.
        del model, optimizer, train_loader, val_loader, scheduler
        torch.cuda.empty_cache()

    if _should_log():
        print_hyperparameter()
        print('----------ALL FOLDS FINISHED----------')
        # Built from the folds actually trained, so it works for any fold count.
        best_epochs = ','.join(f'fold{idx}:{best}' for idx, best in enumerate(fold_epoch))
        print(f'best epochs {best_epochs}')
        print('final accuracy = {:.4f}'.format(np.mean(fold_accuracy)))

In [None]:
if __name__ == '__main__':
    if CFG['device']=='cuda:0':
        # Single-process GPU run: call the training entry point directly.
        # `_mp_fn` already contains the 'cuda:0' branches.
        _mp_fn(0)
    else:
        # Wrap one model instance; _mp_fn reads it through the MX global and
        # moves it to its own device.
        MX=xmp.MpModelWrapper(CassvaImgClassifier(CFG['model_arch'], 5, pretrained=True))
        # Launch 8 forked processes, each running _mp_fn.
        xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')