This is the training notebook of the VIT 2020 model that was used for our final submission, which scored 91.3% on both the public and the private leaderboard of the Cassava Leaf Disease Classification 2020 competition (1st place).

Special thanks go to [Mobassir](https://www.kaggle.com/mobassir), who posted his notebook on VIT TPU training, which is the basis of this notebook.

# Installation of TPU dependencies

In [None]:
!python3.7 -m pip install --upgrade pip
!pip install   torch==1.7.0
!pip install  torchvision 
!pip install  cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl

In [None]:
# Installed straight from GitHub; presumably this is what provides the
# `warmup_scheduler` module imported in the imports cell — confirm.
!pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git

# imports 

In [None]:
import os
import torch
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.utils.serialization as xser
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
from warmup_scheduler import GradualWarmupScheduler
import sys; 
package_paths = [
    '../input/pytorch-image-models/pytorch-image-models-master',
    '../input/image-fmix/FMix-master'
]


for pth in package_paths:
    sys.path.append(pth)
    
import warnings
import pandas as pd
import numpy as np
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn import metrics
from transformers import get_linear_schedule_with_warmup
import time
import torchvision
import torch.nn as nn
from tqdm import tqdm_notebook as tqdm
from PIL import Image, ImageFile
from torch.utils.data import Dataset
import torch.optim as optim
from torchvision import transforms
from torch.optim import lr_scheduler
import sys
import gc
import os
import random
import skimage.io
from PIL import Image
import scipy as sp
import sklearn.metrics
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from functools import partial
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
from albumentations import Compose, Normalize, HorizontalFlip, VerticalFlip
from albumentations.pytorch import ToTensorV2
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)
from contextlib import contextmanager
from pathlib import Path
from collections import defaultdict, Counter
from fmix import sample_mask, make_low_freq_image, binarise_mask
from glob import glob
from sklearn.model_selection import GroupKFold, StratifiedKFold
import cv2
from skimage import io
import torch
from torch import nn
import os
from datetime import datetime
import time
import random

import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
from tqdm import tqdm

import albumentations

import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.cuda.amp import autocast, GradScaler
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F

import timm

import sklearn

import joblib
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
import cv2
import pydicom
from scipy.ndimage.interpolation import zoom


# Environment switches for the XLA runtime.
# NOTE(review): XLA_USE_BF16=1 presumably makes XLA run float tensors as
# bfloat16 — confirm against the torch_xla docs for the installed version.
os.environ["XLA_USE_BF16"] = "1"
# NOTE(review): presumably an upper bound for the XLA tensor allocator — confirm unit and effect.
os.environ["XLA_TENSOR_ALLOCATOR_MAXSIZE"] = "100000000"

# Silence every Python warning for the rest of the notebook.
warnings.filterwarnings("ignore")

# Record the torch version that is actually in use.
print(torch.__version__)

# Configuration of parameters

In [None]:
# Central configuration dictionary read by the dataset, transforms, loss and training loop.
CFG = {
    'fold_num': 5,  # not referenced in the visible code; folds are read from the CSV `fold` column
    'seed': 719,  # not referenced in the visible code
    'model_arch': 'vit_base_patch16_384',  # timm architecture name passed to CassvaImgClassifier
    'img_size': 384,  # side length used by the crop/resize transforms and the FMix/CutMix geometry
    'epochs': 10,  # epochs per fold
    'train_bs': 10,  # batch size of the training DataLoader
    'valid_bs': 10,  # batch size of the validation DataLoader
    'T_0': 10,  # T_0 of CosineAnnealingWarmRestarts
    'lr': 1e-4,  # base learning rate; multiplied by the XLA world size in train_model
    'min_lr': 1e-4,  # eta_min of CosineAnnealingWarmRestarts
    'smoothing' : 0.06,  # label_smoothing passed to the bi-tempered loss
    't1' : 0.8,  # temperature 1 of the bi-tempered loss
    't2' : 1.4,  # temperature 2 of the bi-tempered loss
    'warmup_factor' : 7,  # the optimizer starts at lr / warmup_factor
    'warmup_epo' : 1,  # not referenced in the visible code
    'num_workers': 4,#8  -- DataLoader worker processes
    'accum_iter': 2, # optimizer step every `accum_iter` batches (gradient accumulation)
    'verbose_step': 1  # not referenced in the visible code
}

# Reading train.csv file

We load here meta data necessary for the out of fold training and prediction process.

In [None]:
# Load the train / validation metadata tables; train_model later filters both on their `fold` column.
import pandas as pd

train = pd.read_csv('../input/cassava-leaf-disease-metadata-vit/train_j.csv')
val = pd.read_csv('../input/cassava-leaf-disease-metadata-vit/val_j.csv')

# Rename to the column names CassavaDataset reads ('label' and 'image_id').
train = train.rename(columns={'target': 'label', 'img_name' : "image_id"})
val = val.rename(columns={'target': 'label', 'img_name': "image_id"})


def get_img(path):
    """Read an image from disk and return it in RGB channel order.

    Args:
        path: Path of the image file, handed to ``cv2.imread``.

    Returns:
        The decoded image with its last (channel) axis reversed, i.e. the
        BGR array produced by OpenCV turned into RGB.

    Raises:
        FileNotFoundError: If OpenCV could not read the file. ``cv2.imread``
            returns ``None`` instead of raising for missing or undecodable
            files, which previously surfaced as an opaque ``TypeError``.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return im_bgr[:, :, ::-1]

# Dataset Class


In [None]:
def rand_bbox(size, lam):
    """Sample a random CutMix box inside a ``size[0]`` x ``size[1]`` area.

    The box side lengths are ``sqrt(1 - lam)`` times the area's side lengths,
    the centre is drawn uniformly, and the corners are clipped to the area,
    so boxes near a border come out smaller.

    Args:
        size: Sequence whose first two entries are the extents ``W`` and ``H``.
        lam: Mixing coefficient; ``1 - lam`` is the target area fraction of
            the box before clipping.

    Returns:
        Tuple ``(bbx1, bby1, bbx2, bby2)`` with the clipped box corners.
    """
    W = size[0]
    H = size[1]
    cut_rat = np.sqrt(1. - lam)
    # Built-in int: the `np.int` alias was deprecated in NumPy 1.20 and removed in 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Uniformly distributed box centre.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2


class CassavaDataset(Dataset):
    """Image dataset backed by a DataFrame with ``image_id`` / ``label`` columns.

    Each item is read from ``data_root/<image_id>``, passed through the
    optional ``transforms`` callable, and, when enabled, blended with a second
    randomly chosen sample using FMix and/or CutMix (each applied with
    probability 0.5).

    Args:
        df: Metadata frame. Needs an ``image_id`` column and, when
            ``output_label`` is set, a ``label`` column. It is copied and its
            index is reset.
        data_root: Directory prefix of the image files.
        transforms: Callable invoked as ``transforms(image=img)['image']``.
        output_label: Whether ``__getitem__`` returns ``(img, target)``
            instead of ``img`` alone.
        one_hot_label: Whether the labels are expanded to one-hot rows.
        do_fmix: Enables FMix mixing.
        fmix_params: FMix settings (``alpha``, ``decay_power``, ``shape``,
            ``max_soft``, ``reformulate``). ``None`` selects the defaults,
            whose shape is ``CFG['img_size']`` squared.
        do_cutmix: Enables CutMix mixing.
        cutmix_params: CutMix settings (``alpha``). ``None`` selects
            ``{'alpha': 1}``.
    """

    def __init__(self, df, data_root,
                 transforms=None,
                 output_label=True,
                 one_hot_label=False,
                 do_fmix=False,
                 fmix_params=None,
                 do_cutmix=False,
                 cutmix_params=None
                ):

        super().__init__()
        # Defaults are built per instance so no dict is shared between datasets
        # (a dict literal in the signature is created once and shared).
        if fmix_params is None:
            fmix_params = {
                'alpha': 1.,
                'decay_power': 3.,
                'shape': (CFG['img_size'], CFG['img_size']),
                'max_soft': True,
                'reformulate': False
            }
        if cutmix_params is None:
            cutmix_params = {
                'alpha': 1,
            }

        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root
        self.do_fmix = do_fmix
        self.fmix_params = fmix_params
        self.do_cutmix = do_cutmix
        self.cutmix_params = cutmix_params

        self.output_label = output_label
        self.one_hot_label = one_hot_label

        if output_label:
            self.labels = self.df['label'].values

            if one_hot_label:
                self.labels = np.eye(self.df['label'].max()+1)[self.labels]

    def __len__(self):
        return self.df.shape[0]

    def _load_image(self, index):
        """Read the image at positional ``index`` and apply ``self.transforms`` if set."""
        img = get_img("{}/{}".format(self.data_root, self.df.iloc[index]['image_id']))
        if self.transforms:
            img = self.transforms(image=img)['image']
        return img

    def __getitem__(self, index: int):
        # `target` stays None when labels are not requested, so the mixing
        # branches below cannot hit an unbound name.
        target = self.labels[index] if self.output_label else None

        img = self._load_image(index)

        if self.do_fmix and np.random.uniform(0., 1., size=1)[0] > 0.5:
            with torch.no_grad():
                lam = np.clip(np.random.beta(self.fmix_params['alpha'], self.fmix_params['alpha']),0.6,0.7)

                # Low-frequency mask, binarised so that roughly `lam` of it is set.
                mask = make_low_freq_image(self.fmix_params['decay_power'], self.fmix_params['shape'])
                mask = binarise_mask(mask, lam, self.fmix_params['shape'], self.fmix_params['max_soft'])

                fmix_ix = np.random.choice(self.df.index, size=1)[0]
                fmix_img = self._load_image(fmix_ix)

                mask_torch = torch.from_numpy(mask)

                # Blend the two images with the mask.
                img = mask_torch*img+(1.-mask_torch)*fmix_img

                if target is not None:
                    # Weight the labels by the fraction of the image kept from `img`.
                    rate = mask.sum()/CFG['img_size']/CFG['img_size']
                    target = rate*target + (1.-rate)*self.labels[fmix_ix]

        if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > 0.5:
            with torch.no_grad():
                cmix_ix = np.random.choice(self.df.index, size=1)[0]
                cmix_img = self._load_image(cmix_ix)

                lam = np.clip(np.random.beta(self.cutmix_params['alpha'], self.cutmix_params['alpha']),0.3,0.4)
                bbx1, bby1, bbx2, bby2 = rand_bbox((CFG['img_size'], CFG['img_size']), lam)

                # Paste the box taken from the second image into the first one.
                img[:, bbx1:bbx2, bby1:bby2] = cmix_img[:, bbx1:bbx2, bby1:bby2]

                if target is not None:
                    # Weight the labels by the area fraction left untouched.
                    rate = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (CFG['img_size'] * CFG['img_size']))
                    target = rate*target + (1.-rate)*self.labels[cmix_ix]

        if self.output_label:
            return img, target
        return img

# Augmentation



In [None]:
def get_train_transforms():
    """Build the albumentations pipeline applied to training images.

    The pipeline takes a random resized crop of ``CFG['img_size']`` squared,
    applies random flips/transposes, geometric and colour jitter, normalises
    with the listed mean/std, randomly drops regions, and finally converts
    the image to a tensor.
    """
    side = CFG['img_size']
    steps = [
        RandomResizedCrop(side, side),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        # The region-dropping steps run after normalisation.
        CoarseDropout(p=0.5),
        Cutout(p=0.5),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)
  
        
def get_valid_transforms():
    """Build the albumentations pipeline applied to validation images.

    Random horizontal/vertical flips are followed by a centre crop and a
    resize to ``CFG['img_size']`` squared, normalisation with the listed
    mean/std, and conversion to a tensor.
    """
    side = CFG['img_size']
    steps = [
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        CenterCrop(side, side, p=1.),
        Resize(side, side),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)

# Dataloader


In [None]:
def prepare_dataloader(train_, valid_, data_root='../input/cassava-leaf-disease-classification/train_images/'):
    """Create the per-core training and validation DataLoaders.

    Both datasets are wrapped in a ``DistributedSampler`` configured with the
    XLA world size and this process's ordinal, so every core reads its own
    shard.

    Args:
        train_: Metadata frame of the training samples.
        valid_: Metadata frame of the validation samples.
        data_root: Directory prefix of the image files.

    Returns:
        Tuple ``(train_loader, valid_loader)``.
    """
    # The unused function-local `catalyst` import was dropped: it made the
    # function fail with ImportError whenever catalyst was not installed.
    train_ds = CassavaDataset(train_, data_root, transforms=get_train_transforms(), output_label=True, one_hot_label=False, do_fmix=False, do_cutmix=False)
    valid_ds = CassavaDataset(valid_, data_root, transforms=get_valid_transforms(), output_label=True)

    def _sharded_sampler(dataset, shuffle):
        # One shard per XLA process, selected by this process's ordinal.
        return torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=shuffle)

    train_sampler = _sharded_sampler(train_ds, shuffle=True)
    valid_sampler = _sharded_sampler(valid_ds, shuffle=False)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_ds,
        batch_size=CFG['train_bs'],
        sampler=train_sampler,
        drop_last=True,
        num_workers=CFG['num_workers'],
    )

    # NOTE: drop_last=True also discards the final partial validation batch,
    # so those samples do not contribute to the validation metrics.
    valid_loader = torch.utils.data.DataLoader(
        dataset=valid_ds,
        batch_size=CFG['valid_bs'],
        sampler=valid_sampler,
        drop_last=True,
        num_workers=CFG['num_workers'],
    )
    return train_loader, valid_loader



# Loss

In [None]:
# Code taken from https://github.com/fhopfmueller/bi-tempered-loss-pytorch/blob/master/bi_tempered_loss_pytorch.py

def log_t(u, t):
    """Tempered logarithm of tensor ``u`` at temperature ``t``.

    Evaluates ``(u^(1-t) - 1) / (1-t)``; for ``t == 1`` this is the natural log.
    """
    if t == 1.0:
        return u.log()
    one_minus_t = 1.0 - t
    return (u.pow(one_minus_t) - 1.0) / one_minus_t

def exp_t(u, t):
    """Tempered exponential of tensor ``u`` at temperature ``t``.

    Evaluates ``relu(1 + (1-t) * u)^(1/(1-t))``; for ``t == 1`` this is the
    ordinary exponential.
    """
    if t == 1:
        return u.exp()
    one_minus_t = 1.0 - t
    # relu clamps the base at zero before the power is taken.
    base = (1.0 + one_minus_t * u).relu()
    return base.pow(1.0 / one_minus_t)

def compute_normalization_fixed_point(activations, t, num_iters):
    """Normalization constants via fixed-point iteration (intended for t > 1.0).

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness).
      num_iters: Number of fixed-point iterations.
    Return: A tensor shaped like `activations` except that the last dimension is 1.
    """
    mu, _ = torch.max(activations, -1, keepdim=True)
    # Shift by the per-example maximum; every iterate is a rescaling of this tensor.
    shifted = activations - mu

    current = shifted
    partition = torch.sum(exp_t(current, t), -1, keepdim=True)
    for _ in range(num_iters):
        current = shifted * partition.pow(1.0-t)
        partition = torch.sum(exp_t(current, t), -1, keepdim=True)

    return - log_t(1.0 / partition, t) + mu

def compute_normalization_binary_search(activations, t, num_iters):
    """Normalization constants via bisection (intended for t < 1.0).

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (< 1.0 for finite support).
      num_iters: Number of bisection steps.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """
    mu = torch.max(activations, -1, keepdim=True)[0]
    shifted = activations - mu

    # Per example, count the classes whose shifted activation exceeds -1 / (1 - t).
    above_threshold = (shifted > -1.0 / (1.0-t)).to(torch.int32)
    effective_dim = torch.sum(above_threshold, dim=-1, keepdim=True).to(activations.dtype)

    partition_shape = activations.shape[:-1] + (1,)
    lo = torch.zeros(partition_shape, dtype=activations.dtype, device=activations.device)
    hi = -log_t(1.0/effective_dim, t) * torch.ones_like(lo)

    for _ in range(num_iters):
        mid = (hi + lo)/2.0
        total_prob = torch.sum(exp_t(shifted - mid, t), dim=-1, keepdim=True)
        # 1.0 where the probabilities sum to less than one, else 0.0; in that
        # case the midpoint becomes the new upper bound, otherwise the new lower bound.
        below_one = (total_prob < 1.0).to(activations.dtype)
        lo = torch.reshape(lo * below_one + (1.0-below_one) * mid, partition_shape)
        hi = torch.reshape(hi * (1.0 - below_one) + below_one * mid, partition_shape)

    return (hi + lo)/2.0 + mu

class ComputeNormalization(torch.autograd.Function):
    """
    Autograd function wrapping the normalization solvers with a hand-written
    backward pass. Use it through compute_normalization.
    """
    @staticmethod
    def forward(ctx, activations, t, num_iters):
        # Bisection for t < 1.0, fixed-point iteration otherwise.
        solver = (compute_normalization_binary_search if t < 1.0
                  else compute_normalization_fixed_point)
        normalization_constants = solver(activations, t, num_iters)

        ctx.save_for_backward(activations, normalization_constants)
        ctx.t=t
        return normalization_constants

    @staticmethod
    def backward(ctx, grad_output):
        activations, normalization_constants = ctx.saved_tensors
        t = ctx.t
        probabilities = exp_t(activations - normalization_constants, t)
        # probabilities^t, renormalised over the class dimension.
        unnormalized = probabilities.pow(t)
        escorts = unnormalized / unnormalized.sum(dim=-1, keepdim=True)

        # Gradients only flow to `activations`; `t` and `num_iters` get None.
        return escorts * grad_output, None, None

def compute_normalization(activations, t, num_iters=5):
    """Compute the per-example normalization value (differentiable).

    Delegates to ComputeNormalization, which supplies the backward pass.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      num_iters: Number of solver iterations.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """
    normalization_constants = ComputeNormalization.apply(activations, t, num_iters)
    return normalization_constants

def tempered_sigmoid(activations, t, num_iters = 5):
    """Tempered sigmoid, computed as a two-class tempered softmax.

    Args:
      activations: Activations for the positive class for binary classification.
      t: Temperature tensor > 0.0.
      num_iters: Number of iterations to run the method.
    Returns:
      A probabilities tensor.
    """
    # Pair every activation with a zero activation for the other class.
    zeros = torch.zeros_like(activations)
    two_class = torch.stack([activations, zeros], dim=-1)
    # Index 0 of the last axis holds the original (positive-class) activation.
    return tempered_softmax(two_class, t, num_iters)[..., 0]


def tempered_softmax(activations, t, num_iters=5):
    """Tempered softmax over the last dimension.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature > 1.0.
      num_iters: Number of iterations to run the method.
    Returns:
      A probabilities tensor.
    """
    if t != 1.0:
        shift = compute_normalization(activations, t, num_iters)
        return exp_t(activations - shift, t)
    # t == 1.0 is the ordinary softmax.
    return activations.softmax(dim=-1)

def bi_tempered_binary_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing = 0.0,
        num_iters=5,
        reduction='mean'):

    """Bi-Tempered binary logistic loss, expressed as the two-class loss.

    Args:
      activations: A tensor containing activations for class 1.
      labels: A tensor with shape as activations, containing probabilities for class 1
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing
      num_iters: Number of iterations to run the method.
    Returns:
      A loss tensor.
    """
    positive = labels.to(activations.dtype)
    # Stack [class-1, other-class] along a new last axis for both inputs.
    two_class_activations = torch.stack(
        [activations, torch.zeros_like(activations)], dim=-1)
    two_class_labels = torch.stack([positive, 1.0 - positive], dim=-1)
    return bi_tempered_logistic_loss(
        two_class_activations,
        two_class_labels,
        t1,
        t2,
        label_smoothing=label_smoothing,
        num_iters=num_iters,
        reduction=reduction)

def bi_tempered_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing=0.0,
        num_iters=5,
        reduction = 'mean'):

    """Bi-Tempered Logistic Loss.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: A tensor with shape and dtype as activations (onehot), 
        or a long tensor of one dimension less than activations (pytorch standard)
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1). Default 0.0.
      num_iters: Number of iterations to run the method. Default 5.
      reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``.
        ``'none'``: No reduction is applied, return shape is shape of
        activations without the last dimension.
        ``'mean'``: Loss is averaged over all examples; returns a scalar tensor.
        ``'sum'``: Loss is summed over all examples; returns a scalar tensor.
    Returns:
      A loss tensor.
    Raises:
      ValueError: If `reduction` is not one of the three supported values
        (previously the function silently returned None).
    """
    # Validate up front so a typo fails before any computation is done.
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(
            "reduction must be 'none', 'mean' or 'sum', got {!r}".format(reduction))

    if len(labels.shape)<len(activations.shape): #not one-hot
        labels_onehot = torch.zeros_like(activations)
        # Scatter along the class (last) axis so inputs with more than two
        # dimensions work; for 2-D activations this equals the former dim=1.
        labels_onehot.scatter_(-1, labels[..., None], 1)
    else:
        labels_onehot = labels

    if label_smoothing > 0:
        num_classes = labels_onehot.shape[-1]
        labels_onehot = ( 1 - label_smoothing * num_classes / (num_classes - 1) ) \
                * labels_onehot + \
                label_smoothing / (num_classes - 1)

    probabilities = tempered_softmax(activations, t2, num_iters)

    # The 1e-10 keeps log_t away from an exact zero label.
    loss_values = labels_onehot * log_t(labels_onehot + 1e-10, t1) \
            - labels_onehot * log_t(probabilities, t1) \
            - labels_onehot.pow(2.0 - t1) / (2.0 - t1) \
            + probabilities.pow(2.0 - t1) / (2.0 - t1)
    loss_values = loss_values.sum(dim = -1) #sum over classes

    if reduction == 'none':
        return loss_values
    if reduction == 'sum':
        return loss_values.sum()
    return loss_values.mean()
    
    

def get_probs(activations,
        labels,
        t1,
        t2,
        label_smoothing=0.0,
        num_iters=5,
        reduction = 'mean'):

    """Return the tempered-softmax probabilities for `activations`.

    The signature mirrors `bi_tempered_logistic_loss` so both can be called
    with the same arguments, but only `activations`, `t2` and `num_iters`
    influence the result. The former one-hot / label-smoothing preamble was
    removed because its result was never used.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: Unused; accepted for signature compatibility.
      t1: Unused; accepted for signature compatibility.
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Unused; accepted for signature compatibility.
      num_iters: Number of iterations to run the method. Default 5.
      reduction: Unused; accepted for signature compatibility.
    Returns:
      A probabilities tensor with the same shape as `activations`.
    """
    return tempered_softmax(activations, t2, num_iters)

# Training and Validation

1. Each call to train_one_epoch() trains the model **for 1 epoch**, using the train_loader prepared above for 8-core TPU training. **For every batch of images and labels from *train_loader* we feed the images to the model, calculate the loss, backpropagate, accumulate gradients, and take the optimizer and scheduler steps.**

2. valid_one_epoch() uses the validation dataloader created above for 8-core TPU training. It calculates the validation loss and validation accuracy on our validation dataset and returns them. When calculating the validation accuracy, reduce it across cores with **accuracy = xm.mesh_reduce('test_accuracy', acc, np.mean)**. If you don't, you won't be able to save the best weight file for each fold and xm.save() will hang forever (I made this silly mistake a few days ago while working on this kernel, so keep it in mind).

In [None]:
def train_one_epoch(epoch, model, optimizer, train_loader, device, scheduler=None, schd_batch_update=False):
    """Train `model` for one pass over `train_loader`.

    Gradients are accumulated over `CFG['accum_iter']` batches (and flushed on
    the last batch) before `xm.optimizer_step` is called. The loss is not
    divided by the accumulation count.

    Args:
        epoch: Current epoch index; not used inside the function.
        model: Model to train; switched to train mode.
        optimizer: Optimizer stepped through `xm.optimizer_step`.
        train_loader: Iterable of `(imgs, image_labels)` batches that also supports `len()`.
        device: Device the batches are moved to.
        scheduler: Optional LR scheduler.
        schd_batch_update: If true, the scheduler steps after every optimizer
            step; otherwise once at the end of the epoch.
    """
    model.train()

    num_steps = len(train_loader)
    pbar = tqdm(enumerate(train_loader), total=num_steps)
    for step, (imgs, image_labels) in pbar:
        # Garbage-collect every 20 batches.
        if (step + 1) % 20 == 0:
            gc.collect()

        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        # Only the forward pass and loss belong under autocast.
        with autocast():
            image_preds = model(imgs)   #output = model(input)
            loss = bi_tempered_logistic_loss(image_preds, image_labels, t1=CFG['t1'], t2=CFG['t2'], label_smoothing=CFG['smoothing'])

        # Backward runs outside autocast: the PyTorch AMP docs advise against
        # backward passes under autocast.
        loss.backward()

        if ((step + 1) % CFG['accum_iter'] == 0) or ((step + 1) == num_steps):
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()

            if scheduler is not None and schd_batch_update:
                scheduler.step()
    pbar.close()
    if scheduler is not None and not schd_batch_update:
        scheduler.step()
        
def valid_one_epoch(epoch, model, val_loader, device):
    """Evaluate `model` on `val_loader` and reduce the metrics across cores.

    Args:
        epoch: Current epoch index; not used inside the function.
        model: Model to evaluate; switched to eval mode.
        val_loader: Iterable of `(imgs, image_labels)` batches that also supports `len()`.
        device: Device the batches are moved to.

    Returns:
        Tuple `(loss, accuracy)`. `loss` is a 0-dim tensor holding the
        sample-weighted mean bi-tempered loss over this core's batches,
        averaged across cores (previously only the last batch's loss was
        returned). `accuracy` is the mean of the per-core accuracies.

    Raises:
        ValueError: If `val_loader` yields no batches.
    """
    model.eval()

    loss_sum = 0.0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []

    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        for step, (imgs, image_labels) in pbar:
            imgs = imgs.to(device).float()
            image_labels = image_labels.to(device).long()

            image_preds = model(imgs)   #output = model(input)
            loss = bi_tempered_logistic_loss(image_preds, image_labels, t1=CFG['t1'], t2=CFG['t2'], label_smoothing=CFG['smoothing'])
            xm.mark_step()

            image_preds_all.append(torch.argmax(image_preds, 1).detach().cpu().numpy())
            image_targets_all.append(image_labels.detach().cpu().numpy())

            batch_size = image_labels.shape[0]
            loss_sum += loss.item() * batch_size
            sample_num += batch_size

    pbar.close()
    if sample_num == 0:
        raise ValueError("valid_one_epoch received an empty validation loader")

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    acc = (image_preds_all==image_targets_all).mean()

    # Every process must take part in both reductions, in the same order.
    accuracy = xm.mesh_reduce('test_accuracy', acc, np.mean)
    mean_loss = xm.mesh_reduce('test_loss', loss_sum / sample_num, np.mean)
    xm.master_print("Validation Accuracy = ",accuracy)

    # Returned as a 0-dim tensor so callers can still format it or call .item().
    return torch.tensor(float(mean_loss)), accuracy

**CassvaImgClassifier()** class below is used for loading pretrained **vit_base_patch16_384** model from [timm](https://github.com/rwightman/pytorch-image-models/tree/master/timm)

In [None]:
class CassvaImgClassifier(nn.Module):
    """timm backbone whose `head` layer is replaced by a new `n_class`-way linear layer."""

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_arch, pretrained=pretrained)
        # Swap the classification head, keeping its input width.
        backbone.head = nn.Linear(backbone.head.in_features, n_class)
        self.model = backbone

    def forward(self, x):
        return self.model(x)

# Loading the Model

In [None]:
# Module-level model shared by train_model(); one output per distinct label in `train`.
model = CassvaImgClassifier(CFG['model_arch'], train.label.nunique(), pretrained=True)

# Train Model

**In order to train the model, we need to spawn the training processes on each of the TPU cores. We will call the train_model() function during the spawn call.**

We use StratifiedKFold splits (read here from the `fold` column of the metadata files) to reduce overfitting, call prepare_dataloader() to prepare the train and valid loaders, and use the Adam optimizer, the CosineAnnealingWarmRestarts scheduler and the bi-tempered logistic loss as our loss function. Inside the epoch loop we run **train_one_epoch() and valid_one_epoch()** as discussed above and save the best weights of each fold.

In [None]:
def train_model(folds = range(0, 5)):
    """Train and validate the module-level `model` on the requested folds.

    For every fold a fresh Adam optimizer and CosineAnnealingWarmRestarts
    scheduler are created. Each epoch rebuilds the dataloaders, trains,
    validates, appends a log line, and saves a checkpoint; the best-accuracy
    weights of the fold are saved separately.

    NOTE(review): the training frame is the fold's `train` rows concatenated
    with its `val` rows, so the validation samples are also trained on —
    confirm this is intended.
    NOTE(review): the same `model` instance is reused for all folds without
    re-initialisation, so a later fold starts from the previous fold's
    weights — confirm before training more than one fold per call.

    Args:
        folds: Iterable of fold ids matched against the `fold` column.
    """
    device = xm.xla_device()
    model.to(device)

    for fold in folds:
        trn_idx = train[train['fold'] == fold]
        val_idx = val[val['fold'] == fold]
        trn_idx = pd.concat([trn_idx, val_idx])

        # The base learning rate is scaled by the number of XLA processes.
        lr =  CFG['lr']* xm.xrt_world_size()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr/CFG['warmup_factor'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CFG['T_0'], T_mult=1, eta_min=CFG['min_lr'], last_epoch=-1)

        best_accuracy = 0.0

        for epoch in range(CFG['epochs']):
            gc.collect()
            train_loader, val_loader = prepare_dataloader(trn_idx, val_idx, data_root='../input/cassava-leaf-disease-classification/train_images/')
            # The loaders (and their samplers) are rebuilt every epoch, so without
            # set_epoch the DistributedSampler would reuse the same shuffle order.
            train_loader.sampler.set_epoch(epoch)
            para_loader = pl.ParallelLoader(train_loader, [device])
            train_one_epoch(epoch, model, optimizer, para_loader.per_device_loader(device), device, scheduler=scheduler, schd_batch_update=False)

            del para_loader
            del train_loader
            gc.collect()
            para_loader = pl.ParallelLoader(val_loader, [device])
            val_loss,cur_accuracy = valid_one_epoch(epoch, model, para_loader.per_device_loader(device), device)

            del para_loader
            del val_loader
            gc.collect()

            content = time.ctime() + ' ' + f'FOLD -> {fold} --> Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val_loss : {(val_loss):.5f},Validation_Accuracy: {(cur_accuracy):.5f}'

            # Only the master process writes; otherwise every process appends the same epoch line.
            if xm.is_master_ordinal():
                with open('log.txt', 'a') as appender:
                    appender.write(content + '\n')

            # xm.save must be called by every process. The file name now carries fold,
            # epoch and accuracy in matching slots (the old template put the epoch in
            # the fold slot and the accuracy in the epoch slot).
            xm.save(model.state_dict(),'{}_fold_{}_epoch_{}_Validation_Accuracy_{:.5f}.h5'.format(CFG['model_arch'], fold, epoch, cur_accuracy))

            if cur_accuracy >= best_accuracy:
                xm.save(model.state_dict(),'{}_fold_{}_best_epoch'.format(CFG['model_arch'], fold))
                best_accuracy = cur_accuracy
                    
        
        

# Start training processes

In [None]:
%%time

# Important: the folds to train in this run (one or more of 0..4); read by _mp_fn below.
folds_to_train  = [1]

def _mp_fn(rank, flags):
    """Per-process entry point passed to `xmp.spawn`; `rank` and `flags` are not used."""
    torch.set_default_tensor_type('torch.FloatTensor')
    train_model(folds=folds_to_train)

# No flags are needed; _mp_fn does not read its `flags` argument.
FLAGS={}
# Start 8 forked processes (nprocs=8) with _mp_fn as their entry point.
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')

In [None]:
# Free what we can once training has finished.
gc.collect()
torch.cuda.empty_cache()
try:
    del model
except NameError:
    # `model` is already gone (e.g. the cell was re-run); nothing to free.
    # Only NameError is expected here, so nothing else is swallowed.
    pass
gc.collect()

# Full Training Log

In [None]:
# Print the training log; the context manager closes the file handle afterwards.
with open('./log.txt', "r") as f:
    print(f.read())