In [None]:
!git clone https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer.git
!pip install adamp

In [None]:
# Extend the import search path so the packages shipped as notebook inputs
# (or cloned by the install cell) can be imported below.
import sys
# presumably the vendored copy of `timm` — verify against the attached dataset
sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')
# holds ranger.py / ranger913A.py / rangerqh.py, imported as `ranger`, `ranger913A`, `rangerqh`
sys.path.append('./Ranger-Deep-Learning-Optimizer/ranger')
# provides the `fmix` module (`sample_mask`)
sys.path.append('../input/image-fmix/FMix-master')
# sys.setrecursionlimit(10**6)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm.notebook import tqdm
from pprint import pprint
import cv2, glob, time, random, os
import warnings
warnings.filterwarnings("ignore")

import timm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
# https://nvlabs.github.io/iccv2019-mixed-precision-tutorial/files/dusan_stosic_intro_to_mixed_precision_training.pdf
# https://analyticsindiamag.com/pytorch-mixed-precision-training/
# https://pytorch.org/docs/stable/notes/amp_examples.html
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from torch.optim import Adam, AdamW, SGD

import albumentations as A
from albumentations.pytorch import ToTensorV2
from fmix import sample_mask

from adamp import AdamP
# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
from ranger import Ranger  # this is from ranger.py
from ranger913A import RangerVA  # this is from ranger913A.py
from rangerqh import RangerQH  # this is from rangerqh.py

from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold

In [None]:
# List the model architecture names known to timm.
# NOTE(review): called without a filter, so this presumably includes architectures
# that have no pretrained weights — confirm (timm accepts `pretrained=True` here).
model_names = timm.list_models()
pprint(model_names)

<a id = "cont"></a>
## CFG

In [None]:
# ---- Data loading / run length -------------------------------------------
BATCH_SIZE = 16 # 8 for bigger architectures
VAL_BATCH_SIZE = 8
EPOCHS = 2 # train upto 10 epochs
IMG_SIZE = 512 # 384 for bigger architectures
ITER_FREQ = 500 # train_fn prints a progress line every ITER_FREQ steps
NUM_WORKERS = 8
MEAN = [0.485, 0.456, 0.406] # per-channel mean passed to A.Normalize
STD = [0.229, 0.224, 0.225] # per-channel std passed to A.Normalize
SEED = 1111
N_FOLDS = 5
TR_FOLDS = [0,1,2,3,4] # NOTE(review): not referenced anywhere else in the code shown here
START_FOLD = 0 # first fold trained by the main loop; overwritten from MODEL_PATH when resuming

# Checkpoint to resume from (None = train from scratch). The main loop and
# engine() parse the fold and epoch numbers out of the file name.
MODEL_PATH = None

# ---- Optimisation ---------------------------------------------------------
LR = 1e-4
MIN_LR = 5e-5 # eta_min for CosineAnnealingWarmRestarts, min_lr for ReduceLROnPlateau
WEIGHT_DECAY = 1e-6
MOMENTUM = 0.9
T_0 = EPOCHS # restart period passed to CosineAnnealingWarmRestarts
MAX_NORM = 1000 # max_norm for gradient clipping in train_fn

# ---- SnapMix --------------------------------------------------------------
SNAPMIX = False
SNAPMIX_ALPHA = 5.0 # Beta(alpha, alpha) parameter used by snapmix()
SNAPMIX_PCT = 0.5 # train_fn applies SnapMix when a uniform draw exceeds 1 - SNAPMIX_PCT

MODEL_ARCH = 'tf_efficientnet_b3_ns' # tf_efficientnet_b4_ns, tf_efficientnet_b5_ns, resnext50_32x4d
ITERS_TO_ACCUMULATE = 1 # optimizer step every N batches (gradient accumulation)

# ---- CutMix / FMix --------------------------------------------------------
CUTMIX = False
CM_START = 0 # NOTE(review): not referenced anywhere else in the code shown here
CM_ALPHA = 1 # Beta(alpha, alpha) parameter handed to CutMixCollator
DECAY_POWER = 5 # only appears in the commented-out fmix() call in train_fn

BASE_OPTIMIZER = SGD # optimizer class wrapped by SAM
OPTIMIZER = 'AdamW' # Ranger, Adam, AdamP, SGD, SAM

SCHEDULER = 'CosineAnnealingWarmRestarts' # ReduceLROnPlateau, CosineAnnealingLR, CosineAnnealingWarmRestarts, OneCycleLR
SCHEDULER_UPDATE = 'epoch' # batch

# Name looked up by GetCriterion. The label-smoothing option is spelled
# 'LabelSmoothingLoss' there (and inside LabelSmoothingLoss.forward).
CRITERION = 'TaylorCrossEntropyLoss' # CrossEntropyLoss, TaylorSmoothedLoss, LabelSmoothingLoss
LABEL_SMOOTH = 0.5 # label-smoothing factor
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class AverageMeter(object):
    """Running statistics for a scalar metric.

    Attributes:
        val:   the most recent value passed to `update`.
        sum:   weighted sum of all values seen since the last `reset`.
        count: total weight (e.g. number of samples) seen so far.
        avg:   `sum / count`, refreshed on every `update`.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.sum = self.avg = self.count = 0

    def update(self, val, n=1):
        """Record `val`, weighted as if it had been observed `n` times."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def seed_torch(seed):
    """Seed every random number generator used in this notebook.

    Covers Python's `random`, the hash seed environment variable, NumPy and
    PyTorch (CPU and all CUDA devices), and switches cuDNN to deterministic
    behaviour.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed the non-current GPUs so multi-GPU runs are reproducible.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels from run to run, which
    # defeats `deterministic = True`; switch it off explicitly.
    torch.backends.cudnn.benchmark = False
    
# Seed all RNGs once, before any data split, model or loader is created.
seed_torch(SEED)

In [None]:
# Directory that image file names are appended to when loading (see CLD_train_DS).
TRAIN_DIR = '../input/cassava-leaf-disease-merged/train/'
# Label table; the code below reads its `image_id` and `label` columns.
df = pd.read_csv('../input/cassava-leaf-disease-merged/merged.csv')
# df

In [None]:
# Tag every row with the index of the stratified fold in which it is used for
# validation; engine() later splits train/validation on this `fold` column.
folds = df.copy()
Fold = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
for n, (train_index, val_index) in enumerate(Fold.split(folds, folds['label'])):
    # `val_index` is positional while `.loc` is label-based — this assumes
    # `df` carries the default 0..N-1 index (TODO confirm).
    folds.loc[val_index, 'fold'] = int(n)
# Cast the column to int (presumably it is created as float while only
# partially filled in the loop above).
folds['fold'] = folds['fold'].astype(int)

In [None]:
class CLD_train_DS(Dataset):
    """Dataset yielding `(image, label)` pairs for the rows of a DataFrame.

    Args:
        df: table with an `image_id` column (file name relative to
            `TRAIN_DIR`) and a `label` column (integer class id).
        transform: optional albumentations-style callable invoked as
            `transform(image=img)`; its `'image'` entry is returned.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.image_id = df['image_id'].values
        self.label = df['label'].values
        self.transform = transform

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            (img, label): the RGB image as float32 scaled to [0, 1] (or the
            transform's output when a transform is set) and the label as a
            0-dim int64 tensor.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        image_path = TRAIN_DIR + self.image_id[idx]
        img = cv2.imread(image_path)
        # cv2.imread signals a missing/corrupt file by returning None rather
        # than raising; fail here with the offending path instead of letting
        # cvtColor die with an opaque assertion.
        if img is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        img = img / 255.0

        if self.transform:
            img = self.transform(image=img)['image']

        label = torch.tensor(self.label[idx]).long()

        return img, label

In [None]:
def get_transform(*, train=True):
    """Build the albumentations pipeline for training or validation.

    Args:
        train: when True, return the augmenting pipeline (random resized
            crop, transpose, flips, shift/scale/rotate); otherwise a plain
            resize.

    Returns:
        An `A.Compose` that ends with normalisation and `ToTensorV2`.
    """
    # CLD_train_DS already rescales pixels to [0, 1] before calling the
    # transform. A.Normalize multiplies mean/std by `max_pixel_value`, so the
    # default of 255 would normalise [0, 1] data against a 255-scaled mean and
    # squash every image to a near-constant value. Use 1.0 to match the input.
    normalize = A.Normalize(mean=MEAN, std=STD, max_pixel_value=1.0, p=1.0)

    if train:
        return A.Compose([
            A.RandomResizedCrop(IMG_SIZE, IMG_SIZE),
            A.Transpose(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(p=0.5),
            normalize,
            ToTensorV2(),
        ])
    else:
        return A.Compose([
            A.Resize(IMG_SIZE, IMG_SIZE),
            normalize,
            ToTensorV2(),
        ])

[SAM optimizer](https://github.com/davda54/sam)

In [None]:
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    One update needs two forward/backward passes: `first_step` moves the
    weights along the gradient by a step of norm `rho`, and `second_step`
    undoes that move and lets the base optimizer update the weights using
    the gradients computed at the perturbed point.

    Args:
        params: parameters (or parameter groups) to optimise.
        base_optimizer: optimizer class instantiated on the same groups.
        rho: size of the perturbation step; must be non-negative.
        **kwargs: forwarded to `base_optimizer` and stored as group defaults.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        super(SAM, self).__init__(params, dict(rho=rho, **kwargs))

        # Build the base optimizer on this wrapper's groups, then adopt its
        # groups so that both objects share the same param_groups list.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb every parameter that has a gradient to `w + e(w)`."""
        total_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (total_norm + 1e-12)
            for param in group["params"]:
                if param.grad is not None:
                    perturbation = param.grad * scale.to(param)
                    param.add_(perturbation)  # climb to the local maximum "w + e(w)"
                    # remembered so second_step can undo the move
                    self.state[param]["e_w"] = perturbation

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the original weights, then run the base optimizer step."""
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    param.sub_(self.state[param]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run both SAM steps; `closure` must do a full forward-backward pass."""
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        # step() itself runs under no_grad, so re-enable autograd for the closure
        closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the L2 norm over the gradients of all parameters that have one."""
        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]["params"][0].device
        per_param_norms = []
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    per_param_norms.append(param.grad.norm(p=2).to(shared_device))
        return torch.norm(torch.stack(per_param_norms), p=2)

[Back to CFG(Click here)](#cont)

## Loss Functions

In [None]:
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class receives probability `1 - smoothing`; the remaining
    `smoothing` mass is shared equally among the other `classes - 1` classes.

    Args:
        classes: number of classes.
        smoothing: probability mass moved away from the true class.
        dim: dimension that is reduced when summing the per-class terms.
    """

    def __init__(self, classes=5, smoothing=0.0, dim=-1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy over the batch.

        `pred` is converted to log-probabilities here only when the global
        `CRITERION` is 'LabelSmoothingLoss'; for any other setting it is used
        as given, so the caller must already pass log-probabilities.
        """
        if CRITERION == 'LabelSmoothingLoss':
            pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            off_target_mass = self.smoothing / (self.cls - 1)
            true_dist = torch.full_like(pred, off_target_mass)
            # overwrite the true-class column with the confidence
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        per_sample = torch.sum(-true_dist * pred, dim=self.dim)
        return torch.mean(per_sample)
    
class TaylorSoftmax(nn.Module):
    '''
    Softmax variant that replaces exp(x) by its order-n Taylor polynomial
    1 + x + x^2/2! + ... + x^n/n! (autograd version).

    n must be even (enforced by an assertion).
    '''
    def __init__(self, dim=1, n=2):
        super().__init__()
        assert n % 2 == 0
        self.dim = dim
        self.n = n

    def forward(self, x):
        '''
        usage similar to nn.Softmax:
            >>> mod = TaylorSoftmax(dim=1, n=4)
            >>> inten = torch.randn(1, 32, 64, 64)
            >>> out = mod(inten)
        '''
        numerator = torch.ones_like(x)  # the order-0 term
        factorial = 1.
        for order in range(1, self.n + 1):
            factorial *= order
            numerator = numerator + x.pow(order) / factorial
        # normalise so the values along `dim` sum to one
        return numerator / numerator.sum(dim=self.dim, keepdim=True)
    
class TaylorCrossEntropyLoss(nn.Module):
    '''
    Negative log-likelihood computed on Taylor-softmax probabilities
    (autograd version).

    n must be even; `ignore_index` and `reduction` are forwarded to
    F.nll_loss.
    '''
    def __init__(self, n=2, ignore_index=-1, reduction='mean'):
        super().__init__()
        assert n % 2 == 0
        self.taylor_softmax = TaylorSoftmax(dim=1, n=n)
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        '''
        usage similar to nn.CrossEntropyLoss:
            >>> crit = TaylorCrossEntropyLoss(n=4)
            >>> inten = torch.randn(1, 10, 64, 64)
            >>> label = torch.randint(0, 10, (1, 64, 64))
            >>> out = crit(inten, label)
        '''
        probs = self.taylor_softmax(logits)
        return F.nll_loss(
            probs.log(),
            labels,
            reduction=self.reduction,
            ignore_index=self.ignore_index,
        )
    
class TaylorSmoothedLoss(nn.Module):
    """Label-smoothed loss computed on Taylor-softmax log-probabilities.

    Args:
        n: even order of the Taylor expansion used by `TaylorSoftmax`.
        ignore_index: stored on the module; not used by `forward`.
        reduction: stored on the module; not used by `forward`.
        smoothing: label-smoothing factor. `None` (the default) falls back
            to the notebook-wide `LABEL_SMOOTH`. Previously this argument
            was accepted but silently ignored in favour of `LABEL_SMOOTH`.
    """

    def __init__(self, n=2, ignore_index=-1, reduction='mean', smoothing=None):
        super(TaylorSmoothedLoss, self).__init__()
        assert n % 2 == 0
        self.taylor_softmax = TaylorSoftmax(dim=1, n=n)
        self.reduction = reduction
        self.ignore_index = ignore_index
        if smoothing is None:
            smoothing = LABEL_SMOOTH
        self.lab_smooth = LabelSmoothingLoss(5, smoothing=smoothing)

    def forward(self, logits, labels):
        """Return the mean smoothed loss for `logits` against integer `labels`."""
        # LabelSmoothingLoss applies its own log_softmax only when CRITERION is
        # 'LabelSmoothingLoss', so here it consumes these log-probabilities as-is.
        log_probs = self.taylor_softmax(logits).log()
        loss = self.lab_smooth(log_probs, labels)
        return loss

## Augmentations

[SnapMix](https://github.com/Shaoli-Huang/SnapMix)

In [None]:
def rand_bbox(size, lam):
    """Sample a random box whose area is roughly `(1 - lam)` of the image.

    Args:
        size: 4-element size of the batch tensor; `size[2]` and `size[3]`
            are taken as the two spatial extents (named W and H here).
        lam: mixing coefficient in [0, 1].

    Returns:
        (bbx1, bby1, bbx2, bby2): box corners, clipped to the image bounds.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # `np.int` was only an alias of the builtin and was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # box centre drawn uniformly over the image
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def get_spm(input,target,model):
    """Compute per-sample semantic percent maps for SnapMix.

    The model's linear classifier is applied as a 1x1 convolution over the
    ReLU'd feature maps to obtain a class activation map for each sample's
    target class. The map is resized to `(IMG_SIZE, IMG_SIZE)`, shifted so
    its minimum is zero and divided by its sum.

    Args:
        input: image batch fed to `model`.
        target: integer class index per sample.
        model: module returning `(logits, feature_maps)` and exposing a
            linear `classifier` attribute.

    Returns:
        (outmaps, clslogit): the normalised maps of shape
        `(batch, IMG_SIZE, IMG_SIZE)` and the softmax probability the
        classifier assigns to each sample's target class.
    """
    imgsize = (IMG_SIZE, IMG_SIZE)
    bs = input.size(0)
    with torch.no_grad():
        output,fms = model(input)
        clsw = model.classifier
        weight = clsw.weight.data
        bias = clsw.bias.data
        # reshape the linear weights so they can be used as a 1x1 conv kernel
        weight = weight.view(weight.size(0),weight.size(1),1,1)
        fms = F.relu(fms)
        # flatten(1) keeps the batch dimension even for a batch of one,
        # which a bare .squeeze() would drop
        poolfea = F.adaptive_avg_pool2d(fms,(1,1)).flatten(1)
        clslogit = F.softmax(clsw.forward(poolfea), dim=1)

        # pick, for every sample, the entry belonging to its own target class
        batch_idx = torch.arange(bs, device=target.device)
        clslogit = clslogit[batch_idx, target]

        out = F.conv2d(fms, weight, bias=bias)
        outmaps = out[batch_idx, target]

        outmaps = outmaps.unsqueeze(1)
        outmaps = F.interpolate(outmaps,imgsize,mode='bilinear',align_corners=False)
        # remove only the channel dimension added above
        outmaps = outmaps.squeeze(1)

        for i in range(bs):
            outmaps[i] -= outmaps[i].min()
            # a constant map sums to zero here and yields NaN; snapmix()
            # replaces NaN weights afterwards
            outmaps[i] /= outmaps[i].sum()

    return outmaps,clslogit


def snapmix(input, target, alpha, model=None):
    """Apply SnapMix to a batch (modifies `input` in place).

    A random box from a shuffled copy of the batch is resized and pasted
    into another random box of each image. The label weights are derived
    from the semantic percent maps returned by `get_spm`.

    Args:
        input: image batch; the pasted region is written into it.
        target: integer class index per sample.
        alpha: parameter of the Beta(alpha, alpha) distribution used to
            size both boxes.
        model: network handed to `get_spm`.

    Returns:
        (input, target, target_b, lam_a, lam_b): the mixed batch, the
        original and shuffled targets, and their per-sample weights (placed
        on the same device as `input`).
    """
    # This draw is unused, but it is kept so the NumPy random stream (and
    # therefore seeded runs) stays the same as before.
    np.random.rand(1)

    device = input.device
    bs = input.size(0)
    lam_a = torch.ones(bs)
    lam_b = 1 - lam_a

    wfmaps,_ = get_spm(input, target, model)
    lam = np.random.beta(alpha, alpha)
    lam1 = np.random.beta(alpha, alpha)
    # follow the input's device instead of assuming a CUDA device exists
    rand_index = torch.randperm(bs).to(device)
    wfmaps_b = wfmaps[rand_index,:,:]
    target_b = target[rand_index]

    same_label = target == target_b
    bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
    bbx1_1, bby1_1, bbx2_1, bby2_1 = rand_bbox(input.size(), lam1)

    area = (bby2-bby1)*(bbx2-bbx1)
    area1 = (bby2_1-bby1_1)*(bbx2_1-bbx1_1)

    # mix only when both boxes are non-empty; otherwise the batch is returned
    # unchanged with weights (1, 0)
    if  area1 > 0 and  area>0:
        ncont = input[rand_index, :, bbx1_1:bbx2_1, bby1_1:bby2_1].clone()
        ncont = F.interpolate(ncont, size=(bbx2-bbx1,bby2-bby1), mode='bilinear', align_corners=True)
        input[:, :, bbx1:bbx2, bby1:bby2] = ncont
        lam_a = 1 - wfmaps[:,bbx1:bbx2,bby1:bby2].sum(2).sum(1)/(wfmaps.sum(2).sum(1)+1e-8)
        lam_b = wfmaps_b[:,bbx1_1:bbx2_1,bby1_1:bby2_1].sum(2).sum(1)/(wfmaps_b.sum(2).sum(1)+1e-8)
        # when both images share a label, each weight receives the other's share
        tmp = lam_a.clone()
        lam_a[same_label] += lam_b[same_label]
        lam_b[same_label] += tmp[same_label]
        # area-based fallback for weights that came out as NaN
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
        lam_a[torch.isnan(lam_a)] = lam
        lam_b[torch.isnan(lam_b)] = 1-lam

    return input,target,target_b,lam_a.to(device),lam_b.to(device)

class SnapMixLoss(nn.Module):
    """Combine a criterion's value on two target sets using SnapMix weights."""

    def __init__(self):
        super().__init__()

    def forward(self, criterion, outputs, ya, yb, lam_a, lam_b):
        """Return `mean(criterion(outputs, ya) * lam_a + criterion(outputs, yb) * lam_b)`."""
        weighted_a = criterion(outputs, ya) * lam_a
        weighted_b = criterion(outputs, yb) * lam_b
        return torch.mean(weighted_a + weighted_b)

CutMix

In [None]:
def cutmix(batch, alpha):
    """Apply CutMix to a collated batch (modifies the image tensor in place).

    Args:
        batch: `(data, targets)` with `data` of shape (N, C, H, W) and
            `targets` of shape (N,).
        alpha: parameter of the Beta(alpha, alpha) distribution lambda is
            drawn from.

    Returns:
        (data, return_targets): the mixed images and a float tensor of shape
        (N, 3) holding the original targets in column 0, the shuffled
        targets in column 1 and lambda at position [0, 2].
    """
    data, targets = batch

    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]

    lam = np.random.beta(alpha, alpha)

    image_h, image_w = data.shape[2:]
    cx = np.random.uniform(0, image_w)
    cy = np.random.uniform(0, image_h)
    w = image_w * np.sqrt(1 - lam)
    h = image_h * np.sqrt(1 - lam)
    x0 = int(np.round(max(cx - w / 2, 0)))
    x1 = int(np.round(min(cx + w / 2, image_w)))
    y0 = int(np.round(max(cy - h / 2, 0)))
    y1 = int(np.round(min(cy + h / 2, image_h)))

    data[:, :, y0:y1, x0:x1] = shuffled_data[:, :, y0:y1, x0:x1]
    # Use a float tensor: lambda lies in [0, 1], and storing it in an int64
    # tensor truncated it to 0, so the mixed loss only ever used the shuffled
    # targets. Class ids are small integers and are exact in float32.
    return_targets = torch.zeros((len(targets),3),dtype=torch.float32)
    return_targets[:,0] = targets
    return_targets[:,1] = shuffled_targets
    return_targets[0,2] = lam

    return data, return_targets

class CutMixCollator:
    """`collate_fn` that collates samples normally and then applies `cutmix`."""

    def __init__(self, alpha):
        self.alpha = alpha

    def __call__(self, batch):
        batch = torch.utils.data.dataloader.default_collate(batch)
        batch = cutmix(batch, self.alpha)
        return batch

class CutMixCriterion(nn.Module):
    """Loss for batches produced by `cutmix`.

    Args:
        criterion: base loss taking `(preds, integer_targets)`.
    """

    def __init__(self, criterion):
        super(CutMixCriterion, self).__init__()
        self.criterion = criterion

    def forward(self, preds, labels):
        """Return `lam * loss(targets1) + (1 - lam) * loss(targets2)`.

        `labels` is the (N, 3) tensor built by `cutmix`.
        """
        # the targets travel in a float tensor; cast back to class indices
        targets1 = labels[:,0].long()
        targets2 = labels[:,1].long()
        lam = labels[0,2]
        loss = lam * self.criterion(preds, targets1) + (1 - lam) * self.criterion(preds, targets2)
        return loss

In [None]:
def rand_bbox(size, lam):
    """Sample a random box whose area is roughly `(1 - lam)` of the image.

    NOTE: this redefines (and therefore shadows) the identical helper that
    is already defined in the SnapMix section of this notebook.

    Args:
        size: 4-element size of the batch tensor; `size[2]` and `size[3]`
            are taken as the two spatial extents (named W and H here).
        lam: mixing coefficient in [0, 1].

    Returns:
        (bbx1, bby1, bbx2, bby2): box corners, clipped to the image bounds.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # `np.int` was only an alias of the builtin and was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # box centre drawn uniformly over the image
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix_(data, target, alpha):
    """CutMix on an already-collated batch, returning a new image tensor.

    Lambda is drawn from Beta(alpha, alpha) and clipped to [0.3, 0.4]
    before the box is sampled. The returned lambda is recomputed from the
    actual box area.

    Args:
        data: image batch of shape (N, C, H, W); left unmodified.
        target: label per sample.
        alpha: Beta distribution parameter.

    Returns:
        (new_data, (target, shuffled_target, lam)).
    """
    perm = torch.randperm(data.size(0))
    shuffled_data = data[perm]
    shuffled_target = target[perm]

    lam = np.clip(np.random.beta(alpha, alpha), 0.3, 0.4)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)

    new_data = data.clone()
    new_data[:, :, bby1:bby2, bbx1:bbx2] = shuffled_data[:, :, bby1:bby2, bbx1:bbx2]

    # adjust lambda to exactly match pixel ratio
    box_area = (bbx2 - bbx1) * (bby2 - bby1)
    lam = 1 - (box_area / (data.size()[-1] * data.size()[-2]))

    return new_data, (target, shuffled_target, lam)

def fmix(data, targets, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """Blend each image with a shuffled partner using a mask from `sample_mask`.

    Args:
        data: image batch.
        targets: label per sample.
        alpha, decay_power, shape, max_soft, reformulate: forwarded
            unchanged to `sample_mask`.

    Returns:
        (mixed, (targets, shuffled_targets, lam)) where
        `mixed = mask * data + (1 - mask) * shuffled_data`.
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)

    perm = torch.randperm(data.size(0))
    shuffled_data = data[perm]
    shuffled_targets = targets[perm]

    # the mask arrives as a NumPy array; move it and its complement to DEVICE
    keep_mask = torch.from_numpy(mask).to(DEVICE)
    swap_mask = torch.from_numpy(1 - mask).to(DEVICE)
    mixed = keep_mask * data + swap_mask * shuffled_data

    return mixed, (targets, shuffled_targets, lam)

## Model

In [None]:
# Modified for SnapMix
class CassavaNet(nn.Module):
    """`MODEL_ARCH` backbone that returns both logits and feature maps.

    Modified for SnapMix: `get_spm` needs the feature maps produced before
    pooling and direct access to the linear `classifier`.
    """

    def __init__(self):
        super().__init__()
        backbone = timm.create_model(MODEL_ARCH, pretrained=True)
        n_features = backbone.classifier.in_features
        # keep every child module except the last two
        trunk_layers = list(backbone.children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)
        self.classifier = nn.Linear(n_features, 5)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward_features(self, x):
        """Return the backbone feature maps for `x`."""
        return self.backbone(x)

    def forward(self, x):
        """Return `(logits, feature_maps)`."""
        feats = self.forward_features(x)
        pooled = self.pool(feats).view(x.size(0), -1)
        return self.classifier(pooled), feats

class CustomEffNet(nn.Module):
    """timm model whose `classifier` head is replaced by a fresh linear layer.

    Args:
        model_arch: timm architecture name.
        n_class: number of output classes.
        pretrained: whether to load pretrained weights.
    """

    def __init__(self, model_arch, n_class, pretrained=True):
        super().__init__()
        # Pass the options by keyword: the position of `num_classes` in
        # timm.create_model's signature is not stable across timm releases,
        # so a positional `n_class` can end up bound to a different parameter.
        self.model = timm.create_model(model_arch, pretrained=pretrained, num_classes=n_class)
        n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(n_features, n_class)

    def forward(self, x):
        """Return the class logits for `x`."""
        x = self.model(x)
        return x
    
class CustomResNext(nn.Module):
    """timm model whose `fc` head is replaced by a fresh linear layer.

    Args:
        model_arch: timm architecture name (the model must expose `fc`).
        n_class: number of output classes.
        pretrained: whether to load pretrained weights.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        n_features = self.model.fc.in_features
        # Replace the head that is actually read above (`fc`). Assigning the new
        # layer to `.classifier` only attached an unused module and left the
        # original output layer in place.
        self.model.fc = nn.Linear(n_features, n_class)

    def forward(self, x):
        """Return the class logits for `x`."""
        x = self.model(x)
        return x

class CustomViT(nn.Module):
    """timm model whose `head` is replaced by a fresh linear layer.

    Args:
        model_arch: timm architecture name (the model must expose `head`).
        n_class: number of output classes.
        pretrained: whether to load pretrained weights.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        self.model.head = nn.Linear(self.model.head.in_features, n_class)

    def forward(self, x):
        """Return the class logits for `x`."""
        return self.model(x)

class CustomDeiT(nn.Module):
    """Model loaded from the `facebookresearch/deit:main` hub repo with a new `head`.

    Args:
        model_arch: entry-point name passed to `torch.hub.load`.
        n_class: number of output classes.
        pretrained: whether to load pretrained weights.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        self.model = torch.hub.load('facebookresearch/deit:main', model_arch, pretrained=pretrained)
        self.model.head = nn.Linear(self.model.head.in_features, n_class)

    def forward(self, x):
        """Return the class logits for `x`."""
        return self.model(x)

[Back to CFG(Click here)](#cont)

In [None]:
def GetCriterion(criterion_name, criterion=None):
    """Build the loss module registered under `criterion_name`.

    Args:
        criterion_name: one of 'CrossEntropyLoss', 'LabelSmoothingLoss',
            'TaylorCrossEntropyLoss', 'TaylorSmoothedLoss', 'CutMix',
            'SnapMix'.
        criterion: base loss wrapped by 'CutMix'; it is also returned
            unchanged when `criterion_name` is not recognised.

    Returns:
        The loss module.

    Raises:
        ValueError: if the name is not recognised and no fallback
            `criterion` was supplied (this used to return None silently).
    """
    if criterion_name == 'CrossEntropyLoss':
        return nn.CrossEntropyLoss()
    if criterion_name == 'LabelSmoothingLoss':
        # Honour the configured smoothing factor; the class default of 0.0
        # disables smoothing altogether.
        return LabelSmoothingLoss(smoothing=LABEL_SMOOTH)
    if criterion_name == 'TaylorCrossEntropyLoss':
        return TaylorCrossEntropyLoss()
    if criterion_name == 'TaylorSmoothedLoss':
        return TaylorSmoothedLoss()
    if criterion_name == 'CutMix':
        return CutMixCriterion(criterion)
    if criterion_name == 'SnapMix':
        return SnapMixLoss()

    if criterion is None:
        raise ValueError(f'Unknown criterion name: {criterion_name!r}')
    return criterion
    
    
def GetScheduler(scheduler_name, optimizer, batches=None):
    """Build the learning-rate scheduler registered under `scheduler_name`.

    Args:
        scheduler_name: 'OneCycleLR', 'CosineAnnealingWarmRestarts',
            'CosineAnnealingLR' or 'ReduceLROnPlateau'.
        optimizer: optimizer whose learning rate is scheduled.
        batches: number of training batches per epoch; required for
            'OneCycleLR' only.

    Returns:
        The scheduler, or None when the name is not recognised.

    Raises:
        ValueError: if 'OneCycleLR' is requested without `batches`.
    """
    if scheduler_name == 'OneCycleLR':
        if batches is None:
            raise ValueError("OneCycleLR requires `batches` (training batches per epoch)")
        # The run length comes from the notebook-level EPOCHS constant; the
        # `CFG` namespace referenced here before is not defined in this notebook.
        return torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2, epochs=EPOCHS,
                                                   steps_per_epoch=batches+1, pct_start=0.1)
    if scheduler_name == 'CosineAnnealingWarmRestarts':
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=1,
                                                                    eta_min=MIN_LR, last_epoch=-1)
    if scheduler_name == 'CosineAnnealingLR':
        # One half cosine cycle spanning the whole run. This replaces the
        # undefined `CFG.T_max`.
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0, last_epoch=-1)
    if scheduler_name == 'ReduceLROnPlateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=1, threshold=0.0001,
                                                          cooldown=0, min_lr=MIN_LR)
    # Unknown name: no scheduler. Callers check for None before stepping.
    return None
#     elif scheduler_name == 'GradualWarmupSchedulerV2':
#         return GradualWarmupSchedulerV2(optimizer=optimizer)
    
def GetOptimizer(optimizer_name,parameters):
    """Build the optimizer registered under `optimizer_name`.

    Args:
        optimizer_name: 'Adam', 'AdamW', 'AdamP', 'SGD', 'Ranger' or 'SAM'.
        parameters: iterable of parameters to optimise.

    Returns:
        The optimizer instance.

    Raises:
        ValueError: if the name is not recognised (this used to return None
            silently).
    """
    if optimizer_name == 'Adam':
        return torch.optim.Adam(parameters, lr=LR, weight_decay=WEIGHT_DECAY, amsgrad=False)
    elif optimizer_name == 'AdamW':
        # This branch used to build plain Adam, i.e. L2-style weight decay
        # instead of AdamW's decoupled decay.
        return torch.optim.AdamW(parameters, lr=LR, weight_decay=WEIGHT_DECAY, amsgrad=False)
    elif optimizer_name == 'AdamP':
        return AdamP(parameters, lr=LR, weight_decay=WEIGHT_DECAY)
    elif optimizer_name == 'SGD':
        return torch.optim.SGD(parameters, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    elif optimizer_name == 'Ranger':
        return Ranger(parameters, lr = LR, alpha = 0.5, k = 6, N_sma_threshhold = 5,
                      betas = (0.95,0.999), weight_decay=WEIGHT_DECAY)
    elif optimizer_name == 'SAM':
        # NOTE(review): SAM uses its own hard-coded hyper-parameters rather
        # than LR / WEIGHT_DECAY — confirm this is intended.
        return SAM(parameters, BASE_OPTIMIZER, lr=0.1, momentum=0.9,weight_decay=0.0005)

    raise ValueError(f'Unknown optimizer name: {optimizer_name!r}')

# Train and validation functions

In [None]:
def train_fn(model, dataloader, device, epoch, optimizer, criterion, scheduler):
    """Train `model` for one epoch with mixed precision.

    Args:
        model: network to train. On the SnapMix path it must return
            `(logits, feature_maps)`.
        dataloader: yields `(images, labels)`. With CUTMIX enabled `labels`
            is the (N, 3) tensor produced by `cutmix`.
        device: device the batches are moved to.
        epoch: zero-based epoch index (used for logging only).
        optimizer: optimizer stepped every ITERS_TO_ACCUMULATE batches.
        criterion: training loss.
        scheduler: LR scheduler or None. It is stepped per batch or per epoch
            according to SCHEDULER_UPDATE.

    Returns:
        (average loss, average accuracy) over the epoch.

    NOTE(review): `SAM.step` asserts that a closure is passed, so
    OPTIMIZER = 'SAM' presumably cannot be driven through
    `scaler.step(optimizer)` as written — confirm before enabling it.
    """
    data_time = AverageMeter()
    batch_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    # ReduceLROnPlateau must be stepped with a validation metric, which is not
    # available here, so it is never stepped from the training loop.
    steps_here = scheduler is not None and not isinstance(scheduler, ReduceLROnPlateau)

    # Built once, under their own names: reassigning `criterion` inside the
    # loop would also replace the loss used by later SnapMix batches.
    snapmix_criterion = GetCriterion('SnapMix') if SNAPMIX else None
    plain_criterion = GetCriterion('CrossEntropyLoss').to(device) if SNAPMIX else None

    model.train()
    # https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation
    scaler = GradScaler()
    # Discard gradients left over from an incomplete accumulation window of a
    # previous epoch; they were scaled by a different GradScaler.
    optimizer.zero_grad()
    start_time = time.time()
    loader = tqdm(dataloader, total=len(dataloader))
    for step, (images, labels) in enumerate(loader):

        images = images.to(device).float()
        labels = labels.to(device)
        batch_size = images.size(0)
        data_time.update(time.time() - start_time)

        # Run forward pass with autocasting.
        with autocast():
            if SNAPMIX:
                if np.random.rand() > (1.0-SNAPMIX_PCT):
                    X, ya, yb, lam_a, lam_b = snapmix(images, labels, SNAPMIX_ALPHA, model)
                    output, _ = model(X)
                    loss = snapmix_criterion(criterion, output, ya, yb, lam_a, lam_b)
                else:
                    output, _ = model(images)
                    loss = plain_criterion(output, labels)
            else:
                output = model(images)
                loss = criterion(output, labels)

        # Record a plain float on every path. Storing the loss tensor would
        # keep its autograd graph alive, and the SnapMix path used to skip
        # the meter entirely.
        losses.update(loss.item(), batch_size)

        # Backward runs outside autocast. Divide by the accumulation length so
        # the accumulated gradient is an average over the window.
        scaler.scale(loss / ITERS_TO_ACCUMULATE).backward()

        # softmax is monotonic, so the arg-max of the raw logits is the prediction
        preds = output.argmax(dim=1)
        if not CUTMIX:
            accuracy = (preds == labels).float().mean()
        elif labels[0,2] >= 0.5:
            # score against whichever target set carries the larger weight
            accuracy = (preds == labels[:,0]).float().mean()
        else:
            accuracy = (preds == labels[:,1]).float().mean()

        accuracies.update(accuracy.item(), batch_size)

        if (step+1) % ITERS_TO_ACCUMULATE == 0:
            # Clip the true gradients: unscale first, otherwise MAX_NORM is
            # compared against gradients multiplied by the loss scale.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm = MAX_NORM)
            # optimizer.step() is skipped if the gradients contain infs or NaNs.
            scaler.step(optimizer)
            # Update the scale for next iteration.
            scaler.update()
            optimizer.zero_grad()

        if steps_here and SCHEDULER_UPDATE == 'batch':
            scheduler.step()

        batch_time.update(time.time() - start_time)
        start_time = time.time()

        if step % ITER_FREQ == 0:

            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s), '
                  'Data Time {data_time.val:.3f}s ({data_time.avg:.3f}s)\t'
                  'Loss: {loss.val:.4f} ({loss.avg:.4f}), '
                  'Accuracy {accuracy.val:.4f} ({accuracy.avg:.4f})'.format((epoch+1),
                                                                             step, len(dataloader),
                                                                             batch_time=batch_time,
                                                                             data_time=data_time,
                                                                             loss=losses,
                                                                             accuracy=accuracies))
        # To check the loss real-time while iterating over data.
        loader.set_postfix(loss=losses.avg, accuracy=accuracies.avg)

    if steps_here and SCHEDULER_UPDATE == 'epoch':
        scheduler.step()

    return losses.avg, accuracies.avg

In [None]:
def valid_fn(epoch, model, criterion, val_loader, device, scheduler):
    """Evaluate `model` on the validation loader.

    Args:
        epoch: zero-based epoch index (unused; kept for the call signature).
        model: network to evaluate; must return logits.
        criterion: validation loss taking `(logits, labels)`.
        val_loader: yields `(images, labels)`.
        device: device the batches are moved to.
        scheduler: LR scheduler or None. Only a ReduceLROnPlateau scheduler
            is stepped here, using the average validation loss.

    Returns:
        (average loss, average accuracy), each weighted by batch size.
    """
    model.eval()
    losses = AverageMeter()
    accuracies = AverageMeter()

    loader = tqdm(val_loader, total=len(val_loader))
    with torch.no_grad():  # without torch.no_grad() will make the CUDA run OOM.
        for images, labels in loader:

            images = images.to(device)
            labels = labels.to(device)
            # Weight by the real batch size: the last batch can be smaller
            # (drop_last=False), and the training BATCH_SIZE does not apply here.
            batch_size = images.size(0)

            output = model(images)

            loss = criterion(output, labels)
            losses.update(loss.item(), batch_size)
            # softmax is monotonic, so the arg-max of the raw logits is the prediction
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            accuracies.update(accuracy.item(), batch_size)

            loader.set_postfix(loss=losses.avg, accuracy=accuracies.avg)

    # train_fn already advances the scheduler according to SCHEDULER_UPDATE, so
    # stepping it again here would advance the schedule twice per epoch. Only
    # ReduceLROnPlateau belongs here, because it needs the validation metric.
    if isinstance(scheduler, ReduceLROnPlateau):
        scheduler.step(losses.avg)

    return losses.avg, accuracies.avg

[Back to CFG(Click here)](#cont)

# Main

In [None]:
def engine(device, folds, fold, model_path=None):
    """Train and validate one cross-validation fold.

    Args:
        device: device to train on.
        folds: DataFrame with `image_id`, `label` and `fold` columns.
        fold: index of the fold held out for validation.
        model_path: optional path to a whole-model checkpoint saved by this
            function. The text after its last '_' is parsed as the epoch to
            resume from.

    Returns:
        dict with the per-epoch training `'loss'` and `'accuracy'` lists.

    Side effects:
        Appends one line per epoch to a text log and saves both the
        state_dict and the whole model after every epoch.
    """
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)

    train_data = CLD_train_DS(train_folds, transform=get_transform())
    val_data = CLD_train_DS(valid_folds, transform=get_transform(train=False))

    # With CUTMIX on, the training loss and the accuracy code expect the
    # (N, 3) label tensor that only the CutMix collator produces, so the
    # collator has to be handed to the loader. None selects default collation.
    collator = None
    if CUTMIX:
        print('Using CutMix')
        collator = CutMixCollator(CM_ALPHA)

    train_loader = DataLoader(train_data,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS,
                              collate_fn=collator,
                              pin_memory=True, # enables faster data transfer to CUDA-enabled GPUs.
                              drop_last=True)
    val_loader = DataLoader(val_data,
                              batch_size=VAL_BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              shuffle=False,
                              pin_memory=True,
                              drop_last=False)

    if model_path is not None:
        # torch.load unpickles arbitrary objects; only load trusted checkpoints.
        model = torch.load(model_path)
        START_EPOCH = int(model_path.split('_')[-1])
    else:
        model = CustomEffNet(MODEL_ARCH, 5, True)
        START_EPOCH = 0
    model.to(device)

    params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = GetOptimizer(OPTIMIZER, params)
    if not CUTMIX:
        criterion = GetCriterion(CRITERION).to(device)
    else:
        criterion = GetCriterion('CutMix', GetCriterion(CRITERION)).to(device)

    # validation always uses the un-mixed criterion
    val_criterion = GetCriterion(CRITERION).to(device)
    scheduler = GetScheduler(SCHEDULER, optimizer)

    loss = []
    accuracy = []
    for epoch in range(START_EPOCH, EPOCHS):

        epoch_start = time.time()
        avg_loss, avg_accuracy = train_fn(model, train_loader, device, epoch, optimizer, criterion, scheduler)

        torch.cuda.empty_cache()
        avg_val_loss, avg_val_acc = valid_fn(epoch, model, val_criterion, val_loader, device, scheduler)
        epoch_end = time.time() - epoch_start

        print(f'Validation accuracy after epoch {epoch+1}: {avg_val_acc:.4f}')
        loss.append(avg_loss)
        accuracy.append(avg_accuracy)

        content = f'Fold {fold} Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} avg_train_accuracy: {avg_accuracy:.4f} avg_val_loss: {avg_val_loss:.4f} avg_val_acc: {avg_val_acc:.4f} time: {epoch_end:.0f}s'
        with open(f'GPU_{MODEL_ARCH}_{OPTIMIZER}_{CRITERION}.txt', 'a') as appender:
            appender.write(content + '\n')

        # Save the model to use it for inference.
        torch.save(model.state_dict(), f'GPU_{MODEL_ARCH}_fold_{fold}_epoch_{(epoch+1)}.pth')
        torch.save(model, f'GPU_{MODEL_ARCH}_fold_{fold}_epoch_{(epoch+1)}')
        torch.cuda.empty_cache()

    return {'loss':loss, 'accuracy':accuracy}

In [None]:
# Train every fold from START_FOLD onwards. When MODEL_PATH is set, training
# resumes at the fold encoded in the checkpoint's file name.
if __name__ == '__main__':

    if MODEL_PATH is not None:
        # checkpoint names end in "..._fold_<fold>_epoch_<epoch>"
        START_FOLD = int(MODEL_PATH.split('_')[-3])

    for fold in range(START_FOLD, N_FOLDS):
        print(f'===== Fold {fold} Starting =====')
        fold_start = time.time()
        # The checkpoint belongs to the resumed fold only. Passing it to the
        # later folds would make them load that fold's weights and begin at
        # its epoch number instead of training from scratch.
        resume_path = MODEL_PATH if fold == START_FOLD else None
        logs = engine(DEVICE, folds, fold, resume_path)
        print(f'Time taken in fold {fold}: {time.time()-fold_start}')