## Install and Set-Up

In [None]:
# Accelerator selector for this notebook run; the 'TPU' branch additionally
# installs PyTorch/XLA and is checked again in the imports cell.
dev = 'GPU'
if dev == 'TPU':
    # Download the PyTorch/XLA environment setup script, then run it for the
    # nightly build. stdout of both commands is discarded.
    !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py  > /dev/null
    !python pytorch-xla-env-setup.py --version nightly  > /dev/null
# Install timm (stdout discarded).
!pip install timm  > /dev/null
# Fetch sam.py into the working directory so `from sam import SAM` resolves.
!wget https://raw.githubusercontent.com/davda54/sam/main/sam.py
# Install torch_optimizer (imported below as `optim`).
!pip install torch_optimizer

## Imports

In [None]:
##################################### Standard Imports #####################################
import os
import json
import gc
import pickle
import struct
from time import time
from tqdm import tqdm
from uuid import uuid4
from typing import List, Optional

###################################### Data Handlers #######################################
import numpy as np
import pandas as pd

##################################### Image Handlers #######################################
import cv2
import PIL
from PIL import Image
import albumentations as A

####################################### Optimizers #########################################
from sam import SAM
import torch_optimizer as optim

##################################### Plotting Tools #######################################
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.figure_factory as ff
import plotly.express as px
from plotly.subplots import make_subplots

################################## Data Processing Fxns ####################################
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn import metrics

##################################### Torch Imports ########################################
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau

################################### Torch XLA Imports ######################################
if dev == 'TPU':
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
    import torch_xla.utils.serialization as xser
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.data_parallel as dp
    import torch_xla.utils.utils as xu
    import torch_xla.test.test_utils as test_utils

##################################### Other Imports #######################################
import timm
import warnings
warnings.filterwarnings("ignore")
if dev == 'TPU':
    os.environ['XLA_USE_BF16']="1"
    os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

## Helpers

In [None]:
class TwoWayDict:
    """Bidirectional lookup over a one-to-one mapping.

    Stores the mapping as given (``d1``) together with its inverse (``d2``) so
    that a lookup succeeds with either a key or a value of the input mapping.
    """

    def __init__(self, input_dict):
        """Build the forward and inverse tables.

        Args:
            input_dict: Mapping whose values are hashable and unique.

        Raises:
            TypeError: If a value is unhashable and so cannot act as a key.
            ValueError: If two keys share a value (the inverse would be lossy).
        """
        self.d1 = input_dict
        try:
            inverse = {v: k for k, v in input_dict.items()}
        except TypeError as err:
            raise TypeError('TwoWayDict values must be hashable') from err
        # A dict comprehension silently overwrites repeated keys, so duplicates
        # are only detectable by comparing sizes.
        if len(inverse) != len(input_dict):
            raise ValueError('Duplicate Key')
        self.d2 = inverse

    def get(self, key):
        """Return the partner of ``key``; the forward table is tried first.

        Returns ``None`` when ``key`` appears in neither direction.
        """
        if key in self.d1:
            return self.d1[key]
        if key in self.d2:
            return self.d2[key]
        return None

    def length(self):
        """Return the number of pairs stored."""
        return len(self.d1)

## Load Data

In [None]:
# Read the merged training table and shuffle its rows. No seed is passed to
# `sample`, so the row order differs between runs.
df = (
    pd.read_csv('../input/cassava-leaf-disease-merged/merged.csv')
    .sample(frac=1)
    .reset_index(drop=True)
)

# Label-number -> disease-name mapping, wrapped for lookups in both directions.
with open('../input/cassava-leaf-disease-classification/label_num_to_disease_map.json', 'r') as fp:
    label_map = json.load(fp)
class_dict = TwoWayDict(label_map)

# Sorted list of the distinct label values present in the data.
y_cols = sorted(df.label.unique().tolist())

display(df.head())
display(label_map)

In [None]:
# Append one indicator column per label value so that `odf[y_cols]` can be
# selected below. `axis` must be passed by keyword: pandas 2.x no longer
# accepts it positionally.
odf = pd.concat([df, pd.get_dummies(df.label)], axis=1)

# Single stratified 95% / 5% split with a fixed seed. The indicator rows are
# handed to the splitter as the stratification target.
splitter = StratifiedShuffleSplit(n_splits=1, train_size=0.95, test_size=0.05, random_state=42)
train_indices, valid_indices = next(splitter.split(odf.label.values, odf[y_cols].values))

train_df = odf.iloc[train_indices]
valid_df = odf.iloc[valid_indices]

In [None]:
class LeafDataset(Dataset):
    """Image classification dataset backed by a dataframe of file names and labels.

    Each item is a dict with the image as a float32 tensor in channel-first
    layout (``"image"``) and its encoded label as a long tensor (``"targets"``).

    NOTE(review): pixel values are only cast to float32, never rescaled, so
    they stay in whatever range the decoder/augmentation produced. Confirm this
    matches the statistics used by any ``Normalize`` transform passed in.
    """

    def __init__(self, df, aug=None, transform=None, encoding="Label", process_first=False, path_prefix=""):
        """Index the dataframe and optionally pre-process every image to disk.

        Args:
            df: Dataframe with an ``image_id`` column (file names) and a
                ``label`` column.
            aug: Optional callable invoked as ``aug(image=...)`` that returns a
                mapping with an ``"image"`` entry; applied to the RGB array.
            transform: Optional callable applied to the image tensor.
            encoding: ``"Label"`` fits a ``LabelEncoder`` on ``df.label``;
                ``"OneHot"`` uses ``pd.get_dummies``.
            process_first: When true, every image is loaded, augmented and
                transformed once here, saved under ``./temp_data`` and later
                read back from disk instead of being recomputed.
            path_prefix: String prepended to each ``image_id`` to form the path.

        Raises:
            ValueError: If ``encoding`` is not one of the supported names.
            FileNotFoundError: If ``process_first`` is set and an image cannot
                be read.
        """
        self.y_cols = y_cols
        self.image_paths = df.image_id.values
        if encoding == "Label":
            # The encoder is fitted on this dataframe only.
            self.enc = LabelEncoder()
            self.enc.fit(df.label)
            self.targets = self.enc.transform(df.label)
        elif encoding == "OneHot":
            self.targets = pd.get_dummies(df.label).values
        else:
            # The original raised a bare string, which is itself a TypeError.
            raise ValueError(f"Unsupported target encoding: {encoding!r}")
        self.aug = aug
        self.transform = transform
        self.path_prefix = path_prefix
        self.process_first = process_first
        self.process_dict = {}  # image_id -> path of the cached tensor file

        if self.process_first:
            if not os.path.exists('./temp_data'):
                os.makedirs('./temp_data')
            for path in tqdm(self.image_paths):
                save_name = './temp_data/' + uuid4().hex + '.tnsr'
                self.process_dict[path] = save_name
                torch.save(self._load_image(path), save_name)

    def _load_image(self, path):
        """Read one image and run it through augmentation and transform."""
        full_path = self.path_prefix + path
        image = cv2.imread(full_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {full_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.aug is not None:
            image = self.aug(image=image)["image"]
        # Move the channel axis first before building the tensor.
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        image = torch.tensor(image, dtype=torch.float32)
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Return the number of images."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return the ``{"image", "targets"}`` dict for position ``index``."""
        path = self.image_paths[index]
        if self.process_first:
            image = torch.load(self.process_dict[path])
        else:
            image = self._load_image(path)
        return {
            "image": image,
            "targets": torch.tensor(self.targets[index], dtype=torch.long),
        }

In [None]:
IMG_SIZE = 224

# Values shared by the train and validation tensor pipelines.
_TARGET_SIZE = (IMG_SIZE, IMG_SIZE)
_NORM_MEAN = [0.42984136, 0.49624753, 0.3129598]
_NORM_STD = [0.21417203, 0.21910103, 0.19542212]

# Tensor-level training pipeline: resize, random flips and perspective warp,
# then per-channel normalisation. Earlier experiments (RandomApply around
# CenterCrop, and around ColorJitter + RandomRotation(45)) are disabled.
_train_steps = [
    transforms.Resize(_TARGET_SIZE, interpolation=3),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomPerspective(0.3, p=0.5),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
train_transforms = transforms.Compose(_train_steps)

# Validation pipeline: deterministic resize + the same normalisation.
_valid_steps = [
    transforms.Resize(_TARGET_SIZE, interpolation=3),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
valid_transforms = transforms.Compose(_valid_steps)

# Array-level augmentation pipelines. They are built here, but both datasets
# below are constructed with aug=None, so they are not applied in this cell.
_resize_train = A.Resize(IMG_SIZE, IMG_SIZE, interpolation=4, always_apply=True)
_colour_shift = A.HueSaturationValue(
    hue_shift_limit=0.2,
    sat_shift_limit=0.2,
    val_shift_limit=0.2,
    p=0.5
)
_brightness_contrast = A.RandomBrightnessContrast(
    brightness_limit=(-0.1, 0.1),
    contrast_limit=(-0.1, 0.1),
    p=0.5
)
train_aug = A.Compose(
    [
        _resize_train,
        A.Transpose(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.5),
        _colour_shift,
        _brightness_contrast,
        A.CoarseDropout(p=0.5),
        A.Cutout(p=0.5),
    ]
)

valid_aug = A.Compose([A.Resize(IMG_SIZE, IMG_SIZE, interpolation=4, always_apply=True)])

# Both splits read images from the same directory and use integer labels.
_IMAGE_DIR = "../input/cassava-leaf-disease-merged/train/"
train_dataset = LeafDataset(train_df, encoding="Label", transform=train_transforms, aug=None, path_prefix=_IMAGE_DIR)
valid_dataset = LeafDataset(valid_df, encoding="Label", transform=valid_transforms, aug=None, path_prefix=_IMAGE_DIR)

## Load Model

In [None]:
# timm.create_model(f"vit_base_patch32_384", pretrained=False)
# timm.list_models()
# timm.create_model(f"tf_efficientnet_b1_ns", pretrained=False)

In [None]:
def get_model():
    """Build the pretrained EfficientNet-B1 (noisy-student) 5-class classifier.

    The stock classification layer is replaced by a two-layer head
    (Linear -> Dropout -> Linear -> Softmax) and every parameter is left
    trainable.

    NOTE(review): the head ends in ``Softmax``, so the model emits
    probabilities, while some criteria in this file (e.g. ``FocalLoss``, which
    applies ``log_softmax`` itself) document raw scores as their input.
    Confirm the head matches the criterion actually used.

    Returns:
        The timm model with the replaced ``classifier`` head.
    """
    effnet = timm.create_model("tf_efficientnet_b1_ns", pretrained=True)

    # Take the input width from the layer being replaced instead of
    # hard-coding a per-variant constant (1280 for this backbone).
    in_features = effnet.classifier.in_features
    effnet.classifier = nn.Sequential(
        nn.Linear(in_features, 384),
        nn.Dropout(0.1),
        nn.Linear(384, 5),
        nn.Softmax(1)
    )

    # Fine-tune the whole network, backbone included.
    for param in effnet.parameters():
        param.requires_grad = True

    return effnet

## Training Helpers 

### Helpers

In [None]:
class AverageMeter(object):
    """Tracks the latest value and the running weighted mean of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def asMinutes(s):
    """Format a duration in seconds as ``'<minutes>m <seconds>s'``.

    Args:
        s: Duration in seconds (int or float).

    Returns:
        The formatted string; both parts are rendered with ``%d``.
    """
    # Imported locally: the notebook's import cell does not import `math`.
    import math

    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    """Describe elapsed time and the estimated time remaining.

    Args:
        since: Start timestamp in seconds, as returned by ``time.time()``.
        percent: Fraction of the work completed so far; must be non-zero
            because the elapsed time is divided by it.

    Returns:
        ``'<elapsed> (remain <remaining>)'``, both parts formatted by
        ``asMinutes``.
    """
    # The notebook binds the name `time` to the function `time.time`
    # (`from time import time`), so `time.time()` fails there. Import the
    # module under a private alias instead.
    import time as _time

    now = _time.time()
    s = now - since
    es = s / (percent)  # projected total duration
    rs = es - s         # projected remaining duration
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

### Loss

In [None]:
def log_t(u, t):
    """Tempered logarithm of `u` at temperature `t` (plain log when t == 1.0)."""
    if t != 1.0:
        return (u.pow(1.0 - t) - 1.0) / (1.0 - t)
    return u.log()

def exp_t(u, t):
    """Tempered exponential of `u` at temperature `t` (plain exp when t == 1)."""
    if t != 1:
        # The base is clamped at zero before being raised to 1 / (1 - t).
        base = (1.0 + (1.0-t)*u).relu()
        return base.pow(1.0 / (1.0 - t))
    return u.exp()

def compute_normalization_fixed_point(activations, t, num_iters):
    """Fixed-point estimate of the per-example normalization value (t > 1.0).

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness).
      num_iters: Number of fixed-point iterations to run.
    Return: A tensor shaped like `activations` except that the last dimension is 1.
    """
    # Shift by the per-example maximum so the largest activation becomes 0.
    mu, _ = torch.max(activations, -1, keepdim=True)
    shifted = activations - mu

    current = shifted
    for _ in range(num_iters):
        partition = torch.sum(exp_t(current, t), -1, keepdim=True)
        current = shifted * partition.pow(1.0-t)

    partition = torch.sum(exp_t(current, t), -1, keepdim=True)
    return - log_t(1.0 / partition, t) + mu

def compute_normalization_binary_search(activations, t, num_iters):
    """Bisection estimate of the per-example normalization value (t < 1.0).

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (< 1.0 for finite support).
      num_iters: Number of bisection steps to run.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """
    mu, _ = torch.max(activations, -1, keepdim=True)
    shifted = activations - mu

    # Count, per example, the entries above the threshold -1 / (1 - t).
    above_threshold = (shifted > -1.0 / (1.0-t)).to(torch.int32)
    effective_dim = torch.sum(above_threshold, dim=-1, keepdim=True).to(activations.dtype)

    bound_shape = activations.shape[:-1] + (1,)
    lower = torch.zeros(bound_shape, dtype=activations.dtype, device=activations.device)
    upper = -log_t(1.0/effective_dim, t) * torch.ones_like(lower)

    for _ in range(num_iters):
        midpoint = (upper + lower)/2.0
        total = torch.sum(exp_t(shifted - midpoint, t), dim=-1, keepdim=True)
        # 1.0 where the tempered probabilities at the midpoint sum to < 1.
        below_one = (total < 1.0).to(activations.dtype)
        lower = torch.reshape(
            lower * below_one + (1.0-below_one) * midpoint,
            bound_shape)
        upper = torch.reshape(
            upper * (1.0 - below_one) + below_one * midpoint,
            bound_shape)

    midpoint = (upper + lower)/2.0
    return midpoint + mu

class ComputeNormalization(torch.autograd.Function):
    """
    Autograd function wrapping the normalization solvers with a hand-written
    backward pass. See compute_normalization.
    """
    @staticmethod
    def forward(ctx, activations, t, num_iters):
        # Binary search for t < 1.0, fixed-point iteration otherwise.
        if t < 1.0:
            solver = compute_normalization_binary_search
        else:
            solver = compute_normalization_fixed_point
        normalization_constants = solver(activations, t, num_iters)

        ctx.save_for_backward(activations, normalization_constants)
        ctx.t=t
        return normalization_constants

    @staticmethod
    def backward(ctx, grad_output):
        activations, normalization_constants = ctx.saved_tensors
        t = ctx.t
        # Probabilities raised to the power t, renormalised over the last axis.
        escorts = exp_t(activations - normalization_constants, t).pow(t)
        escorts = escorts / escorts.sum(dim=-1, keepdim=True)

        # Only `activations` receives a gradient; `t` and `num_iters` do not.
        return escorts * grad_output, None, None

def compute_normalization(activations, t, num_iters=5):
    """Per-example normalization value, differentiable w.r.t. `activations`.

    Delegates to `ComputeNormalization`, which supplies the backward pass.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """
    normalization_constants = ComputeNormalization.apply(activations, t, num_iters)
    return normalization_constants

def tempered_sigmoid(activations, t, num_iters = 5):
    """Tempered sigmoid, computed as a two-class tempered softmax.

    Args:
      activations: Activations for the positive class for binary classification.
      t: Temperature tensor > 0.0.
      num_iters: Number of iterations to run the method.
    Returns:
      A probabilities tensor.
    """
    # Pair each activation with a zero activation for the other class.
    two_class = torch.stack([activations, torch.zeros_like(activations)], dim=-1)
    return tempered_softmax(two_class, t, num_iters)[..., 0]


def tempered_softmax(activations, t, num_iters=5):
    """Tempered softmax over the last dimension.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature > 1.0.
      num_iters: Number of iterations to run the method.
    Returns:
      A probabilities tensor.
    """
    if t != 1.0:
        normalization_constants = compute_normalization(activations, t, num_iters)
        return exp_t(activations - normalization_constants, t)
    # t == 1.0 is the ordinary softmax.
    return activations.softmax(dim=-1)

def bi_tempered_binary_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing = 0.0,
        num_iters=5,
        reduction='mean'):

    """Bi-Tempered binary logistic loss, expressed as the two-class case.

    Args:
      activations: A tensor containing activations for class 1.
      labels: A tensor with shape as activations, containing probabilities for class 1
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing
      num_iters: Number of iterations to run the method.
    Returns:
      A loss tensor.
    """
    # Class-0 activation is fixed at zero; class-0 label is the complement.
    zeros = torch.zeros_like(activations)
    two_class_activations = torch.stack([activations, zeros], dim=-1)

    positive = labels.to(activations.dtype)
    negative = 1.0 - labels.to(activations.dtype)
    two_class_labels = torch.stack([positive, negative], dim=-1)

    return bi_tempered_logistic_loss(
        two_class_activations,
        two_class_labels,
        t1,
        t2,
        label_smoothing = label_smoothing,
        num_iters = num_iters,
        reduction = reduction)

def bi_tempered_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing=0.0,
        num_iters=5,
        reduction = 'mean'):

    """Bi-Tempered Logistic Loss.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: A tensor with shape and dtype as activations (onehot),
        or a long tensor of one dimension less than activations (pytorch standard)
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1). Default 0.0.
      num_iters: Number of iterations to run the method. Default 5.
      reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``.
        ``'none'``: No reduction is applied, return shape is shape of
        activations without the last dimension.
        ``'mean'``: Loss is averaged over minibatch.
        ``'sum'``: Loss is summed over minibatch.
    Returns:
      A loss tensor.
    Raises:
      ValueError: If `reduction` is not one of the three supported names
        (previously this case silently returned None).
    """
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(
            "reduction must be one of 'none', 'mean', 'sum', got {!r}".format(reduction))

    if len(labels.shape)<len(activations.shape): #not one-hot
        labels_onehot = torch.zeros_like(activations)
        # Scatter along the LAST axis (the class axis). Identical to dim=1 for
        # 2-D activations, and also correct for inputs with extra leading dims.
        labels_onehot.scatter_(-1, labels[..., None], 1)
    else:
        labels_onehot = labels

    if label_smoothing > 0:
        num_classes = labels_onehot.shape[-1]
        labels_onehot = ( 1 - label_smoothing * num_classes / (num_classes - 1) ) \
                * labels_onehot + \
                label_smoothing / (num_classes - 1)

    probabilities = tempered_softmax(activations, t2, num_iters)

    # The 1e-10 offset keeps log_t finite where a label entry is exactly zero.
    loss_values = labels_onehot * log_t(labels_onehot + 1e-10, t1) \
            - labels_onehot * log_t(probabilities, t1) \
            - labels_onehot.pow(2.0 - t1) / (2.0 - t1) \
            + probabilities.pow(2.0 - t1) / (2.0 - t1)
    loss_values = loss_values.sum(dim = -1) #sum over classes

    if reduction == 'none':
        return loss_values
    if reduction == 'sum':
        return loss_values.sum()
    return loss_values.mean()

In [None]:
class FocalLoss(nn.Module):
    """ Focal Loss, as described in https://arxiv.org/abs/1708.02002.
    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.
    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
    """

    def __init__(self,
                 alpha: Optional[Tensor] = None,
                 gamma: float = 0.,
                 reduction: str = 'mean',
                 ignore_index: int = -100):
        """Constructor.
        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper.
                Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore.
                Defaults to -100.
        Raises:
            ValueError: If `reduction` is not 'mean', 'sum' or 'none'.
        """
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                'Reduction must be one of: "mean", "sum", "none".')

        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

        # Per-sample weighted NLL; the reduction is applied in forward().
        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction='none', ignore_index=ignore_index)

    def __repr__(self):
        arg_keys = ['alpha', 'gamma', 'ignore_index', 'reduction']
        arg_vals = [self.__dict__[k] for k in arg_keys]
        arg_strs = [f'{k}={v}' for k, v in zip(arg_keys, arg_vals)]
        arg_str = ', '.join(arg_strs)
        return f'{type(self).__name__}({arg_str})'

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Compute the focal loss of scores `x` against labels `y`.

        Returns a tensor in every case. When every label equals
        `ignore_index` the result is a zero-valued scalar tensor that is
        still connected to `x`, so callers can use `.item()` / `.backward()`
        (the previous float return value supported neither).
        """
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            y = y.reshape(-1)

        unignored_mask = y != self.ignore_index
        y = y[unignored_mask]
        x = x[unignored_mask]
        if len(y) == 0:
            # Sum over an empty selection: exactly zero, dtype/device of x.
            return x.sum()

        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is already part of self.nll_loss)
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)

        # get true class column from each row; the index lives on x's device
        all_rows = torch.arange(len(x), device=x.device)
        log_pt = log_p[all_rows, y]

        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt)**self.gamma

        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss

In [None]:
class LabelSmoothingLoss(nn.Module):
    """Mean of -sum(smoothed_target * pred) over the batch.

    `pred` is used exactly as passed in: no log-softmax is applied here.
    """

    def __init__(self, classes=5, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        # Build the smoothed target without tracking gradients: every class
        # gets smoothing / (classes - 1), the true class gets `confidence`.
        with torch.no_grad():
            off_value = self.smoothing / (self.cls - 1)
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(off_value)
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        per_sample = torch.sum(-true_dist * pred, dim=self.dim)
        return torch.mean(per_sample)

class TaylorSoftmax(nn.Module):
    '''
    Softmax variant that replaces exp(x) by its order-n Taylor polynomial
    (autograd version). n must be even.
    '''
    def __init__(self, dim=1, n=2):
        super(TaylorSoftmax, self).__init__()
        assert n % 2 == 0
        self.dim = dim
        self.n = n

    def forward(self, x):
        '''
        usage similar to nn.Softmax:
            >>> mod = TaylorSoftmax(dim=1, n=4)
            >>> inten = torch.randn(1, 32, 64, 64)
            >>> out = mod(inten)
        '''
        # series = 1 + x + x^2/2! + ... + x^n/n!
        series = torch.ones_like(x)
        factorial = 1.
        for order in range(1, self.n+1):
            factorial *= order
            series = series + x.pow(order) / factorial
        total = series.sum(dim=self.dim, keepdims=True)
        return series / total
    
class TaylorCrossEntropyLoss(nn.Module):
    '''
    Negative log-likelihood computed on TaylorSoftmax probabilities
    (autograd version). n must be even.
    '''
    def __init__(self, n=2, ignore_index=-1, reduction='mean'):
        super(TaylorCrossEntropyLoss, self).__init__()
        assert n % 2 == 0
        self.taylor_softmax = TaylorSoftmax(dim=1, n=n)
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        '''
        usage similar to nn.CrossEntropyLoss:
            >>> crit = TaylorCrossEntropyLoss(n=4)
            >>> inten = torch.randn(1, 10, 64, 64)
            >>> label = torch.randint(0, 10, (1, 64, 64))
            >>> out = crit(inten, label)
        '''
        probabilities = self.taylor_softmax(logits)
        return F.nll_loss(
            probabilities.log(),
            labels,
            reduction=self.reduction,
            ignore_index=self.ignore_index)
    
class TaylorSmoothedLoss(nn.Module):
    '''
    Label-smoothed loss computed on the log of TaylorSoftmax probabilities.

    `reduction` and `ignore_index` are stored but not consulted in forward();
    the reduction is whatever LabelSmoothingLoss applies.
    '''

    def __init__(self, n=2, classes=5, ignore_index=-1, reduction='mean', smoothing=0.2):
        super(TaylorSmoothedLoss, self).__init__()
        assert n % 2 == 0
        self.taylor_softmax = TaylorSoftmax(dim=1, n=n)
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.lab_smooth = LabelSmoothingLoss(classes, smoothing=smoothing)

    def forward(self, logits, labels):
        probabilities = self.taylor_softmax(logits)
        return self.lab_smooth(probabilities.log(), labels)

### Optimizer

In [None]:
class SAM_XLA(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    One update consists of two gradient evaluations: `first_step` moves the
    weights along the current gradient by `rho` (after dividing by the global
    gradient norm), the closure recomputes gradients at that perturbed point,
    and `second_step` restores the weights and lets the base optimizer apply
    the new gradients.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        """
        Args:
            params: Parameters or parameter-group dicts to optimize.
            base_optimizer: Optimizer class, called as
                `base_optimizer(param_groups, **kwargs)`.
            rho: Non-negative size of the perturbation.
            **kwargs: Forwarded to `base_optimizer` (e.g. `lr`).
        """
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM_XLA, self).__init__(params, defaults)

        # Both optimizers share the same param-group list.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb every parameter that has a gradient: w <- w + e(w)."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            # An explicit loop is required: `map(...)` is lazy in Python 3,
            # so the previous `map(update_state, ...)` never executed.
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo the perturbation, then take the base optimizer step."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = self.state[p].get("e_w")
                if e_w is None:
                    # No perturbation was recorded for this parameter.
                    continue
                p.sub_(e_w)  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run a full SAM update.

        Gradients for the first step must already be populated by the caller.
        `closure` must perform a complete forward and backward pass; it is
        invoked once, at the perturbed weights.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the L2 norm over all gradients as a 0-d tensor.

        Returns a zero tensor when no parameter has a gradient.
        """
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        # Flatten across groups so groups of different sizes can be combined.
        norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        if not norms:
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(norms), p=2)
    
#     def _grad_norm(self):
#         shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
#         print(shared_device)
#         norm = torch.norm(
#                     torch.stack([
#                         p.grad.norm(p=2).to(shared_device)
#                         for group in self.param_groups for p in group["params"]
#                         if p.grad is not None
#                     ]),
#                     p=2
#                )
#         return norm

### Wrappers

In [None]:
def get_optimizer(model):
    """Return a SAM_XLA optimizer over `model`'s parameters, wrapping RAdam (lr=0.0002)."""
    return SAM_XLA(model.parameters(), optim.RAdam, lr=0.0002)
def get_scheduler(optimizer):
    """Return a ReduceLROnPlateau scheduler in 'max' mode (patience 3, cooldown 2)."""
    plateau_options = dict(mode='max', patience=3, cooldown=2)
    return ReduceLROnPlateau(optimizer, **plateau_options)
def save_checkpoint(model, optimizer, path):
    """Save the model and optimizer state dicts to `path`.

    Missing parent directories are created. A bare filename (no directory
    component) is written to the current working directory; previously that
    case crashed because `os.makedirs('')` was attempted.
    """
    directory = os.path.dirname(path)
    if directory and not os.path.exists(directory):
        print("Creating directories on path: `{}`".format(path))
        os.makedirs(directory)
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)
def save_model(model, path):
    """Save only the model state dict to `path`.

    Missing parent directories are created. A bare filename (no directory
    component) is written to the current working directory; previously that
    case crashed because `os.makedirs('')` was attempted.
    """
    directory = os.path.dirname(path)
    if directory and not os.path.exists(directory):
        print("Creating directories on path: `{}`".format(path))
        os.makedirs(directory)
    torch.save({
        "model_state_dict": model.state_dict(),
    }, path)
def load_checkpoint(model, path):
    """Restore model weights from `path` and rebuild the optimizer with its saved state.

    Returns:
        `(model, optimizer)`; the optimizer is freshly created by
        `get_optimizer` and then loaded with the stored optimizer state.
    """
    saved = torch.load(path)
    model.load_state_dict(saved["model_state_dict"])
    # The optimizer is created after the weights are in place.
    restored_optimizer = get_optimizer(model)
    restored_optimizer.load_state_dict(saved["optimizer_state_dict"])
    return model, restored_optimizer
def load_model(model, path):
    """Load weights saved under "model_state_dict" into `model` and switch it to eval mode."""
    weights = torch.load(path)["model_state_dict"]
    model.load_state_dict(weights)
    model.eval()
    return model

### Train & Eval Fxns

In [None]:
def _classification_metrics(targets, outputs, ndigits=None):
    """Return (accuracy, Cohen's kappa, one-vs-rest ROC AUC) for class scores.

    `outputs` holds one row of per-class scores per sample; predictions are
    the row-wise argmax. When `ndigits` is given all three values are rounded.
    """
    targets = np.array(targets)
    outputs = np.array(outputs)
    predictions = outputs.argmax(axis=1)
    accuracy = metrics.accuracy_score(targets, predictions)
    cohen_kappa = metrics.cohen_kappa_score(targets, predictions)
    try:
        roc_auc = metrics.roc_auc_score(targets, outputs, multi_class='ovr', labels=config['labels'])
    except Exception:
        # Deliberately best-effort, as before: any failure reports 0 instead
        # of aborting the training loop.
        roc_auc = 0
    if ndigits is not None:
        accuracy = round(accuracy, ndigits)
        cohen_kappa = round(cohen_kappa, ndigits)
        roc_auc = round(roc_auc, ndigits)
    return accuracy, cohen_kappa, roc_auc


def train_model(data_loader, model, optimizer, scheduler, criterion, device, epoch=0, update_after=3):
    """Train `model` for one epoch and report running metrics.

    Args:
        data_loader: Yields dicts with "image" and "targets" tensors.
        model: Network to train.
        optimizer: Optimizer whose `step` accepts a closure (e.g. SAM_XLA).
        scheduler: Stepped once at the end of the epoch with the last batch loss.
        criterion: Callable `criterion(outputs, targets)` returning the loss.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, used for logging only.
        update_after: Metrics are recomputed and logged every this many steps
            and on the final step.

    Returns:
        `(mean_loss, accuracy, cohen_kappa, roc_auc)` from the last logging step.

    Reads the module-level `config` for 'device_name', 'epochs' and 'labels'.
    """
    model.train()
    final_targets = []
    final_outputs = []
    losses = AverageMeter()
    for step, data in enumerate(data_loader):
        inputs = data["image"].to(device)
        targets = data["targets"].to(device)

        def closure():
            # Second forward/backward pass, run by the optimizer. It is not
            # added to `losses`, so each batch is counted exactly once.
            closure_loss = criterion(model(inputs), targets)
            closure_loss.backward()
            return closure_loss

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # Weight by the real batch size so a short final batch is not over-counted.
        losses.update(loss.item(), inputs.size(0))
        # Populate gradients for the optimizer's first step. The original
        # `loss.backward` lacked the call parentheses and so did nothing.
        loss.backward()
        if config['device_name'] == 'TPU':
            xm.optimizer_step(optimizer, optimizer_args={'closure': closure})
            optimizer.zero_grad()
        elif config['device_name'] == 'GPU':
            optimizer.step(closure)
            optimizer.zero_grad()
        del inputs
        targets = targets.detach().cpu().numpy().tolist()
        outputs = outputs.detach().cpu().numpy().tolist()

        final_targets.extend(targets)
        final_outputs.extend(outputs)
        if (step + 1) % update_after == 0 or (step + 1) == len(data_loader):
            if config['device_name'] == 'TPU':
                # Average the running loss over the participating cores.
                loss_reduced = xm.mesh_reduce('loss_reduce', losses.avg, lambda x: sum(x) / len(x))
                final_targets_tensor = torch.tensor(final_targets, device=device)
                final_outputs_tensor = torch.tensor(final_outputs, device=device)
                # Gather along the sample axis for BOTH tensors so rows stay
                # aligned; gathering outputs along dim=1 concatenated class
                # scores of different samples instead.
                targets_reduced = xm.all_gather(final_targets_tensor, dim=0)
                outputs_reduced = xm.all_gather(final_outputs_tensor, dim=0)
                lr = optimizer.param_groups[0]['lr']
                accuracy, cohen_kappa, roc_auc = _classification_metrics(
                    targets_reduced.cpu(), outputs_reduced.cpu())
                # master_print emits the message once rather than per core.
                xm.master_print(f"Epoch {epoch + 1}/{config['epochs']}, Step {step + 1}/{len(data_loader)} :: Train Loss={loss_reduced}, LR={lr}, Accuracy={accuracy}, Cohen Kappa Score={cohen_kappa}, MultiClass ROC AUC={roc_auc}")
            elif config['device_name'] == 'GPU':
                lr = optimizer.param_groups[0]['lr']
                accuracy, cohen_kappa, roc_auc = _classification_metrics(
                    final_targets, final_outputs, ndigits=5)
                print(f"Epoch {epoch + 1}/{config['epochs']}, Step {step + 1}/{len(data_loader)} :: Train Loss={losses.avg}, LR={lr}, Accuracy={accuracy}, Cohen Kappa Score={cohen_kappa}, MultiClass ROC AUC={roc_auc}", end='\r')

        del targets, outputs
    gc.collect() # delete for memory conservation

    # NOTE(review): the scheduler built by get_scheduler uses mode='max', yet
    # it is stepped here with a loss value - confirm which metric is intended.
    scheduler.step(loss)
    if config['device_name'] == 'GPU':
        torch.cuda.empty_cache()
        print(f"Epoch {epoch + 1}/{config['epochs']}, Step {step + 1}/{len(data_loader)} :: Train Loss={losses.avg}, LR={lr}, Accuracy={accuracy}, Cohen Kappa Score={cohen_kappa}, MultiClass ROC AUC={roc_auc}")
    return losses.avg, accuracy, cohen_kappa, roc_auc

In [None]:
def _evaluation_metrics(targets, outputs, ndigits=None):
    """Compute accuracy, Cohen's kappa and one-vs-rest ROC AUC from class scores.

    Args:
        targets: array-like of true class labels, one per sample.
        outputs: array-like of per-class scores, one row per sample; the
            predicted class is the arg-max of each row.
        ndigits: when given, every metric is rounded to this many decimals.

    Returns:
        Tuple ``(accuracy, cohen_kappa, roc_auc)``. ``roc_auc`` is 0 when
        scikit-learn cannot compute it for the given data.
    """
    targets = np.asarray(targets)
    outputs = np.asarray(outputs)
    predictions = outputs.argmax(axis=1)

    accuracy = metrics.accuracy_score(targets, predictions)
    cohen_kappa = metrics.cohen_kappa_score(targets, predictions)
    try:
        roc_auc = metrics.roc_auc_score(targets, outputs, multi_class='ovr', labels=config['labels'])
    except ValueError:
        # Best effort: roc_auc_score raises ValueError e.g. when a class is
        # missing from the targets or the scores are not probabilities.
        roc_auc = 0

    if ndigits is not None:
        accuracy = round(accuracy, ndigits)
        cohen_kappa = round(cohen_kappa, ndigits)
        roc_auc = round(roc_auc, ndigits)
    return accuracy, cohen_kappa, roc_auc


def evaluate_model(data_loader, model, criterion, device, epoch=0, update_after=3):
    """Run one validation pass over ``data_loader`` and report classification metrics.

    Args:
        data_loader: iterable of batches; every batch is a mapping holding an
            ``"image"`` tensor and a ``"targets"`` tensor.
        model: network to evaluate; it is switched to eval mode here.
        criterion: callable ``criterion(outputs, targets)`` returning a scalar
            loss tensor.
        device: device the batch tensors are moved to.
        epoch: zero-based epoch index, used only in the progress messages.
        update_after: metrics are recomputed and printed every this many
            steps, and always on the last step.

    Returns:
        Tuple ``(average loss, accuracy, cohen kappa, roc auc)``. The metrics
        are 0 when the loader yields no batches.
    """
    model.eval()
    final_targets = []
    final_outputs = []
    losses = AverageMeter()
    # Defaults keep the return statement valid even for an empty loader.
    accuracy, cohen_kappa, roc_auc = 0, 0, 0
    on_tpu = config['device_name'] == 'TPU'

    # No gradients are needed for validation; skipping autograd saves memory.
    with torch.no_grad():
        for step, data in enumerate(data_loader):
            inputs = data["image"].to(device)
            targets = data["targets"].to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            # Weight by the real batch size: the last batch may be smaller
            # than config['batch_size'].
            losses.update(loss.item(), targets.size(0))
            del inputs

            final_targets.extend(targets.detach().cpu().numpy().tolist())
            final_outputs.extend(outputs.detach().cpu().numpy().tolist())
            del targets, outputs

            if (step + 1) % update_after == 0 or (step + 1) == len(data_loader):
                if on_tpu:
                    # since the loss is on all 8 cores, reduce the loss values and print the average
                    loss_reduced = xm.mesh_reduce('loss_reduce', losses.avg, lambda x: sum(x) / len(x))
                    # all_gather works on tensors; concatenate along the sample
                    # axis (dim 0) for both so rows stay aligned with targets.
                    targets_reduced = xm.all_gather(torch.tensor(final_targets, device=device), dim=0)
                    outputs_reduced = xm.all_gather(torch.tensor(final_outputs, device=device), dim=0)
                    accuracy, cohen_kappa, roc_auc = _evaluation_metrics(
                        targets_reduced.cpu().numpy(), outputs_reduced.cpu().numpy())
                    # master_print will only print once (not from all 8 cores)
                    xm.master_print(f"Epoch {epoch + 1}/{config['epochs']}, Step {step + 1}/{len(data_loader)} :: Valid Loss={loss_reduced}, Accuracy={accuracy}, Cohen Kappa Score={cohen_kappa}, MultiClass ROC AUC={roc_auc}")
                else:  # GPU and CPU
                    accuracy, cohen_kappa, roc_auc = _evaluation_metrics(final_targets, final_outputs, ndigits=5)
                    print(f"Epoch {epoch + 1}/{config['epochs']}, Step {step + 1}/{len(data_loader)} :: Valid Loss={losses.avg}, Accuracy={accuracy}, Cohen Kappa Score={cohen_kappa}, MultiClass ROC AUC={roc_auc}", end='\r')

    gc.collect() # delete for memory conservation
    if config['device_name'] == 'GPU':
        torch.cuda.empty_cache()
    if not on_tpu and final_targets:
        # Re-print the last progress line without '\r' so it stays visible.
        print(f"Epoch {epoch + 1}/{config['epochs']}, Step {step + 1}/{len(data_loader)} :: Valid Loss={losses.avg}, Accuracy={accuracy}, Cohen Kappa Score={cohen_kappa}, MultiClass ROC AUC={roc_auc}")
    return losses.avg, accuracy, cohen_kappa, roc_auc

## Training

In [None]:
def train_model_with_data(rank, model, train_dataset, valid_dataset):
    """Train ``model`` for ``config['epochs']`` epochs, validating and checkpointing after each.

    Args:
        rank: process index passed in by ``xmp.spawn`` on TPU (0 on the GPU
            path). It is not read here; the TPU ordinal comes from
            ``xm.get_ordinal()``.
        model: network to train; it is moved to the configured device.
        train_dataset: dataset used for the training loader.
        valid_dataset: dataset used for the validation loader.

    Returns:
        Dict mapping ``train_*`` / ``valid_*`` metric names (loss, accuracy,
        cohen_kappa, roc_auc) to lists with one entry per epoch.
    """

    ######################################################################
    ######################## TRAINING DEPENDENCIES #######################
    ######################################################################

    device = config['device'](config['device_name'])
    if config['device_name'] == 'TPU':
        MX = xmp.MpModelWrapper(model)
        model = MX.to(device)
    else:  # GPU and CPU
        model = model.to(device)

    optimizer = config['optimizer'](model)
    scheduler = config['scheduler'](optimizer)
    criterion = config['loss_fn']

    gc.collect()  # delete for memory conservation

    ######################################################################
    ############################ LOAD DATA ###############################
    ######################################################################

    num_workers = config['num_workers']
    # DataLoader only accepts prefetch_factor / persistent_workers when worker
    # processes are actually used, so drop them for num_workers == 0.
    worker_kwargs = dict(prefetch_factor=8, persistent_workers=True) if num_workers > 0 else {}

    if config['device_name'] == 'TPU':
        # special sampler needed for distributed/multi-core (divides dataset among the replicas/cores/devices)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(), #divide dataset among this many replicas
            rank=xm.get_ordinal(), #which replica/device/core
            shuffle=True)

        # define DataLoader with the defined sampler
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config['batch_size'],
            sampler=train_sampler,
            num_workers=num_workers,
            drop_last=True,
            **worker_kwargs
        )

        # same as train but with valid data
        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)

        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=config['batch_size'],
            sampler=valid_sampler,
            num_workers=num_workers,
            drop_last=False,
            **worker_kwargs
        )

        train_loader = pl.MpDeviceLoader(train_loader, device) # puts the train data onto the current TPU core
        valid_loader = pl.MpDeviceLoader(valid_loader, device) # puts the valid data onto the current TPU core
    else:  # GPU and CPU
        # Training data is reshuffled every epoch (matching the TPU sampler);
        # validation order stays fixed.
        train_loader = DataLoader(dataset=train_dataset, batch_size=config['batch_size'], num_workers=num_workers, pin_memory=False, shuffle=True, **worker_kwargs)
        valid_loader = DataLoader(dataset=valid_dataset, batch_size=config['batch_size'], num_workers=num_workers, pin_memory=False, shuffle=False, **worker_kwargs)

    ######################################################################
    ############################## TRAIN LOOP ############################
    ######################################################################

    metric_names = ("loss", "accuracy", "cohen_kappa", "roc_auc")
    train_history = {
        f"{phase}_{name}": [] for name in metric_names for phase in ("train", "valid")
    }

    for epoch in range(config['epochs']):

        train_results = train_model(train_loader, model, optimizer, scheduler, criterion, device, epoch, update_after=1)
        valid_results = evaluate_model(valid_loader, model, criterion, device, epoch)

        # Both calls return (loss, accuracy, cohen_kappa, roc_auc) in this order.
        for name, train_value, valid_value in zip(metric_names, train_results, valid_results):
            train_history[f"train_{name}"].append(train_value)
            train_history[f"valid_{name}"].append(valid_value)

        save_checkpoint(model, optimizer, f'./effnet_{epoch}.pth')

    return train_history

In [None]:
# Run-wide settings and factories (really a bag of dependencies). The factory
# entries are lambdas so the helpers they name are looked up only when called.
config = dict(
    # Builds the optimizer for a given model.
    optimizer=lambda _model: get_optimizer(_model),
    # Builds the LR scheduler for a given optimizer.
    scheduler=lambda optimizer: get_scheduler(optimizer),
    # Alternative losses that were tried:
#     'loss_fn': lambda outputs, targets: bi_tempered_logistic_loss(outputs, targets, 0.7, 1.3, label_smoothing=0.2),
#     'loss_fn': lambda outputs, targets: FocalLoss().to(config['device'](config['device_name']))(outputs, targets),
    # Loss callable: a TaylorSmoothedLoss is created on the configured device
    # on every call and applied to (outputs, targets).
    loss_fn=lambda outputs, targets: TaylorSmoothedLoss(classes=5, smoothing=0.2).to(config['device'](config['device_name']))(outputs, targets),
    # Maps a device name to a device object: the XLA device for 'TPU', CUDA
    # for 'GPU' when available, otherwise the CPU.
    device=lambda d: (
        torch.device('cuda' if torch.cuda.is_available() and d == 'GPU' else 'cpu')
        if d != 'TPU'
        else xm.xla_device()
    ),
    device_name=dev,
    epochs=25,
    num_workers=4,
    batch_size=48,
    # Class labels handed to the ROC AUC metric.
    labels=[0.0, 1.0, 2.0, 3.0, 4.0],
)

In [None]:
# Launch training on the configured device.
if config['device_name'] == 'GPU':
    torch.cuda.empty_cache()
#     model, _ = load_checkpoint(get_model(), '../input/cassavapytorcheffnettrain/effnet_9.pth')
    model = get_model()
    history = train_model_with_data(0, model, train_dataset, valid_dataset)
elif config['device_name'] == 'TPU':
#     xmp.spawn(train_model_with_data, args=(get_model(), train_dataset, valid_dataset), nprocs=1, start_method='fork')
    # nprocs is the number of TPU processes, not the DataLoader worker count:
    # xmp.spawn accepts only 1 or all available devices, and None selects all.
    xmp.spawn(train_model_with_data, args=(get_model(), train_dataset, valid_dataset), nprocs=None, start_method='fork')

In [None]:
def plot_fold(history, title):
    """Draw per-epoch loss and accuracy curves on one dual-axis figure.

    Loss curves use the primary y-axis and accuracy curves the secondary one.

    Args:
        history: mapping with 'train_loss', 'valid_loss', 'train_accuracy'
            and 'valid_accuracy' per-epoch sequences.
        title: figure title; it is rendered in bold.
    """
    # (history key, legend name, dash style, line colour, on secondary y-axis)
    curve_specs = (
        ('train_loss', 'train loss', 'dash', '#2ca02c', False),
        ('valid_loss', 'valid loss', 'dash', '#d62728', False),
        ('train_accuracy', 'train accuracy', 'dashdot', '#ff7f0e', True),
        ('valid_accuracy', 'valid accuracy', 'dashdot', '#1f77b4', True),
    )

    fig = make_subplots(specs=[[{"secondary_y": True}]])
    epoch_axis = np.arange(len(history['train_loss']))

    for key, legend_name, dash_style, colour, on_secondary in curve_specs:
        # Let plotly express build the line trace, then restyle the copy
        # that the figure holds.
        fig.add_trace(px.line(history, x=epoch_axis, y=key).data[0], secondary_y=on_secondary)
        curve = fig.data[-1]
        curve.line.dash = dash_style
        curve.mode = 'markers+lines'
        curve.line.color = colour
        curve.line.width = 3
        curve.hovertemplate = None
        curve.name = legend_name

    # Axis titles
    fig.update_xaxes(title_text="Epoch")
    fig.update_yaxes(title_text="Loss", secondary_y=False)
    fig.update_yaxes(title_text="Accuracy", secondary_y=True)

    legend_style = dict(orientation='h', yanchor='top', y=1.03, xanchor='left', x=0.15)
    fig.update_layout(
        height=450,
        margin=dict(r=5, t=50, b=50, l=5),
        title_text='<b>' + title + '</b>',
        title_font_size=12,
        legend=legend_style,
        font_size=12,
    )
    fig.for_each_annotation(lambda a: a.update(font=dict(size=14)))
    fig.update_layout(hovermode="x unified")
    fig.update_traces(showlegend=True)

    fig.show()
    

    
def plot_confusion_matrix(label, pred):
    """Show a confusion matrix, normalised per true class, as an annotated heatmap.

    Args:
        label: true class indices (0 .. len(LABELS) - 1).
        pred: predicted class indices.

    The axis tick names come from the module-level ``LABELS`` sequence.
    """
    class_ids = range(len(LABELS))
    normalised = metrics.confusion_matrix(label, pred, labels=class_ids, normalize='true')
    table = pd.DataFrame(normalised, index=LABELS, columns=LABELS)

    cell_values = table.values
    # Cell annotations show the same values rounded to two decimals.
    cell_text = np.around(cell_values, decimals=2)

    fig = ff.create_annotated_heatmap(
        cell_values,
        annotation_text=cell_text,
        x=LABELS,
        y=LABELS,
        colorscale='PuBu',
    )
    fig.update_layout(font_size=9, height=450, margin=dict(r=5, t=50, b=50, l=5))

    fig.show()

In [None]:
# Plot the per-epoch loss/accuracy curves collected during training.
# NOTE(review): `history` is only assigned in the GPU branch of the launch
# cell; on the TPU path this name looks undefined here — confirm.
plot_fold(history, "Training")