## About this kernel  
#### Uses SE-ResNeXt50 (`seresnext50_32x4d`) with SnapMix augmentation
## Source kernel  
#### This notebook was written with reference to the great kernels below, so please don't forget to check them out and upvote them.  
* https://www.kaggle.com/sachinprabhu/pytorch-resnet50-snapmix-train-pipeline  
* https://www.kaggle.com/shaolihuang/training-with-snapmix  

In [None]:
# Install the third-party packages used below: timm (provides the backbone via
# timm.create_model) and torchsummary (used for the model summary cell).
!pip install timm
!pip install torchsummary

In [None]:
import cv2
import torch
from torch import nn
import random
import os
import time
from tqdm import tqdm_notebook as tqdm
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
import pickle
import timm

from PIL import Image
from albumentations import *
from albumentations.pytorch import ToTensorV2

### Config

In [None]:
# Central configuration dictionary read by the transform, model and training cells below.
CFG = {
    'scheduler' : 'MultiStepLR',  # 'CosineAnnealingWarmRestarts', 'ReduceLROnPlateau' or 'MultiStepLR'
    'optimizer' : 'SGD',  # 'SGD' or 'Adam'
    'snapmix_alpha' : 5.0,  # Beta(alpha, alpha) parameter used to sample the SnapMix box sizes
    'snapmix_ptc' : 1.0,  # probability that a training step applies SnapMix
    'model_name' : 'seresnext50_32x4d',  # architecture name passed to timm.create_model
    'num_classes' : 5,  # output size of the classifier layer
    'accum_iter' : 1,  # number of batches between optimizer steps
    'current_fold': 0,  # the only fold that the training cell actually runs
    'fold_num' : 5,  # total number of folds iterated over (all but current_fold are skipped)
    'img_size': 448,  # side of the square crop fed to the model
    'resize' : 512,  # side of the square resize applied before cropping
    'epochs': 10,
    # The scheduler is stepped once per epoch, so with 10 epochs only
    # milestone 1 changes the learning rate while training is still running.
    'milestones': [1, 10, 20], # MultiStepLR
    'train_bs': 32,  # training batch size
    'valid_bs': 64,  # validation batch size
    'weight_decay_Adam' : 1e-6,
    'weight_decay_SGD' : 1e-2,
    'num_workers': 8,  # DataLoader worker processes
    'device': 'cuda:0',
    'seed' : 3,  # passed to seed_everything
    'verbose_step': 1,  # progress-bar description is refreshed every this many steps
    'T_0': 10, # CosineAnnealingWarmRestarts
    'lr': 1e-3,  # only used by the Adam branch; the SGD branch hard-codes its learning rates
    'min_lr': 1e-5 # CosineAnnealingWarmRestarts
}

### Load csv file

In [None]:
# Load the training table. CassavaDataset reads the image file name from
# column 0 and the target from the `label` column.
train = pd.read_csv('../input/colabcassava/train.csv')
# Cell output: number of rows per label value.
train.label.value_counts()

In [None]:
def seed_everything(seed = 3):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and puts cuDNN in deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may select different kernels from run to run, which
    # defeats `deterministic = True`; it has to be off for reproducible runs.
    torch.backends.cudnn.benchmark = False
    
def get_img(path):
    """Read an image from disk and return it as an RGB array.

    Args:
        path: Path of the image file.

    Returns:
        A C-contiguous array with the channel order reversed from OpenCV's
        BGR to RGB.

    Raises:
        FileNotFoundError: If OpenCV could not read the file (``cv2.imread``
            returns ``None`` instead of raising).
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        raise FileNotFoundError('Could not read image: {}'.format(path))
    # The reversed slice is a negative-stride view; copy it into a contiguous
    # array so downstream consumers that reject negative strides still work.
    return np.ascontiguousarray(im_bgr[:, :, ::-1])

### Dataset

In [None]:
class CassavaDataset(Dataset):
    """Cassava dataset.

    Each item is a dict ``{'image': ..., 'label': ...}``. The training and
    validation loops unpack batches with ``data.values()``, so the key order
    (image first, label second) must not change.
    """

    def __init__(self, dataframe, root_dir, transforms=None):
        """
        Args:
            dataframe: Table whose column 0 holds the image file name and which
                has a ``label`` column.
            root_dir: Directory the file names are joined to.
            transforms: Optional callable invoked as ``transforms(image=img)``
                and returning a dict with an ``'image'`` entry.
        """
        super().__init__()
        self.dataframe = dataframe
        self.root_dir = root_dir
        self.transforms = transforms

    def __len__(self):
        return len(self.dataframe)

    def get_img_bgr_to_rgb(self, path):
        """Read ``path`` with OpenCV and return a contiguous RGB array.

        Raises:
            FileNotFoundError: If OpenCV could not read the file
                (``cv2.imread`` returns ``None`` instead of raising).
        """
        im_bgr = cv2.imread(path)
        if im_bgr is None:
            raise FileNotFoundError('Could not read image: {}'.format(path))
        # Copy the negative-stride BGR->RGB view into a contiguous array so it
        # can be consumed directly even when no transform is applied.
        return np.ascontiguousarray(im_bgr[:, :, ::-1])

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = os.path.join(self.root_dir,
                                self.dataframe.iloc[idx, 0])
        image = self.get_img_bgr_to_rgb(img_name)
        if self.transforms:
            image = self.transforms(image=image)['image']
        csv_row = self.dataframe.iloc[idx, 1:]
        sample = {
            'image': image,
            'label': csv_row.label,
        }
        return sample

### Data augmentation

In [None]:
def get_train_transforms():
    """Build the training augmentation pipeline.

    Resizes to a CFG['resize'] square, takes a random CFG['img_size'] crop,
    applies a random horizontal flip and a mild brightness/contrast jitter,
    then normalizes and converts to a tensor.
    """
    full_side = CFG['resize']
    crop_side = CFG['img_size']
    # Vertical flip, transpose and rotation are intentionally not part of the pipeline.
    steps = [
        Resize(full_side, full_side, p = 1.),
        RandomCrop(crop_side, crop_side, p = 1.),
        HorizontalFlip(p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)
    
def get_valid_transforms():
    """Build the deterministic validation pipeline.

    Resizes to a CFG['resize'] square, takes the central CFG['img_size'] crop,
    then normalizes and converts to a tensor.
    """
    full_side = CFG['resize']
    crop_side = CFG['img_size']
    steps = [
        Resize(full_side, full_side, p = 1.),
        CenterCrop(crop_side, crop_side, p=1.),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)

In [None]:
# Show which timm model names match the "seresnext*" pattern; CFG['model_name']
# is later passed to timm.create_model.
print("Available SEresnext Models: ")
timm.list_models("seresnext*")

### Model

In [None]:
class CassavaNet(nn.Module):
    """timm backbone with its last two child modules removed, plus a new linear head.

    ``forward`` returns both the class logits and the backbone feature maps;
    the SnapMix code in this notebook reads the feature maps and
    ``self.classifier``.
    """

    def __init__ (self):
        super().__init__()
        pretrained_net = timm.create_model(CFG['model_name'], pretrained=True)
        head_in = pretrained_net.fc.in_features
        # Wrap all children in a Sequential and slice off the final two modules.
        all_layers = nn.Sequential(*pretrained_net.children())
        self.backbone = all_layers[:-2]
        self.classifier = nn.Linear(head_in, CFG['num_classes'])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward_features(self, x):
        """Return the backbone output for ``x``."""
        return self.backbone(x)

    def forward(self, x):
        """Return ``(logits, feature_maps)`` for the batch ``x``."""
        feats = self.forward_features(x)
        pooled = self.pool(feats)
        flat = pooled.view(x.size(0), -1)
        logits = self.classifier(flat)
        return logits, feats

### Model summary

In [None]:
# Instantiate the model once only to print a summary for a
# (3, img_size, img_size) input, then drop the reference again.
from torchsummary import summary
device = torch.device(CFG['device'])
model = CassavaNet().to(device)
# NOTE(review): torchsummary's summary() appears to print the table itself, so
# the surrounding print() probably just adds a trailing "None" — confirm.
print(summary(model, (3, CFG['img_size'], CFG['img_size'])))
del model

### SnapMix augmentation

In [None]:
def rand_bbox(size, lam):
    """Sample a random box covering roughly ``1 - lam`` of the image area.

    Args:
        size: 4-element size of the batch tensor; ``size[2]`` and ``size[3]``
            are the extents of the two spatial dimensions.
        lam: Value in [0, 1]; the box side ratio is ``sqrt(1 - lam)``.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` as plain Python ints. The ``bbx*`` values
        index dimension 2 and the ``bby*`` values index dimension 3. The box is
        clipped to the image, so it can be smaller than requested or empty.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # `np.int` was removed in NumPy 1.24; the builtin int truncates identically.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre is drawn uniformly over the image.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = int(np.clip(cx - cut_w // 2, 0, W))
    bby1 = int(np.clip(cy - cut_h // 2, 0, H))
    bbx2 = int(np.clip(cx + cut_w // 2, 0, W))
    bby2 = int(np.clip(cy + cut_h // 2, 0, H))

    return bbx1, bby1, bbx2, bby2

def get_spm(input,target,model):
    """Compute the per-sample semantic percent maps used by SnapMix.

    The classifier weights are applied as a 1x1 convolution to the rectified
    backbone feature maps; the map of each sample's target class is upsampled
    to the input resolution, shifted to be non-negative and normalized to sum
    to one.

    Args:
        input: Image batch passed to ``model``.
        target: Integer class index per sample.
        model: Module returning ``(logits, feature_maps)`` and exposing a
            linear ``classifier`` attribute.

    Returns:
        ``(outmaps, clslogit)``: ``outmaps`` has shape
        ``(batch, img_size, img_size)`` in float32, and ``clslogit`` holds the
        softmax probability of each sample's target class. A map that is
        constant yields NaNs (0/0), which ``snapmix`` handles via its NaN
        fallback.
    """
    imgsize = (CFG['img_size'], CFG['img_size'])
    bs = input.size(0)
    with torch.no_grad():
        _, fms = model(input)
        clsw = model.classifier
        # Reshape the linear weights to (classes, channels, 1, 1) so they can
        # be used as a 1x1 convolution over the feature maps.
        weight = clsw.weight.data
        bias = clsw.bias.data
        weight = weight.view(weight.size(0), weight.size(1), 1, 1)
        fms = F.relu(fms)

        # flatten(1) rather than squeeze() so a batch of one stays 2-D.
        poolfea = F.adaptive_avg_pool2d(fms, (1, 1)).flatten(1)
        clsprob = F.softmax(clsw(poolfea), dim=1)
        sample_idx = torch.arange(bs, device=target.device)
        clslogit = clsprob[sample_idx, target]

        # Class activation maps; keep only each sample's target-class map.
        out = F.conv2d(fms, weight, bias=bias)
        outmaps = out[sample_idx, target].unsqueeze(1)
        outmaps = F.interpolate(outmaps, imgsize, mode='bilinear', align_corners=False)

        # Only drop the channel axis (never the batch axis) and normalize in
        # float32: summing a full-resolution map can overflow in fp16 when
        # this runs under autocast.
        outmaps = outmaps.squeeze(1).float()
        flat = outmaps.flatten(1)
        flat = flat - flat.min(dim=1, keepdim=True)[0]
        flat = flat / flat.sum(dim=1, keepdim=True)
        outmaps = flat.view_as(outmaps)

    return outmaps,clslogit


def snapmix(input, target, alpha, model=None):
    """Apply SnapMix to a batch (``input`` is modified in place).

    A random box from a shuffled copy of the batch is resized and pasted into
    a second random box of every image. The label weights come from the
    semantic percent maps returned by ``get_spm`` rather than from the box
    area alone.

    Args:
        input: Image batch of shape (batch, channels, dim2, dim3).
        target: Integer class index per sample.
        alpha: Parameter of the Beta(alpha, alpha) distribution used to sample
            the two box sizes.
        model: Model forwarded to ``get_spm``.

    Returns:
        ``(input, target, target_b, lam_a, lam_b)`` where ``target_b`` are the
        labels of the shuffled partners and ``lam_a`` / ``lam_b`` are the
        per-sample weights for ``target`` / ``target_b``, placed on the same
        device as ``input``. If either sampled box is empty the batch is left
        untouched and the weights are 1 and 0.
    """
    # Follow the input's device instead of assuming CUDA.
    device = input.device
    bs = input.size(0)
    lam_a = torch.ones(bs, device=device)
    lam_b = 1 - lam_a

    wfmaps, _ = get_spm(input, target, model)
    lam = np.random.beta(alpha, alpha)
    lam1 = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(bs).to(device)
    wfmaps_b = wfmaps[rand_index, :, :]
    target_b = target[rand_index]

    same_label = target == target_b
    # First box: destination region in every image. Second box: source region
    # taken from the shuffled partner image.
    bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
    bbx1_1, bby1_1, bbx2_1, bby2_1 = rand_bbox(input.size(), lam1)

    area = (bby2 - bby1) * (bbx2 - bbx1)
    area1 = (bby2_1 - bby1_1) * (bbx2_1 - bbx1_1)

    if area1 > 0 and area > 0:
        ncont = input[rand_index, :, bbx1_1:bbx2_1, bby1_1:bby2_1].clone()
        ncont = F.interpolate(ncont, size=(bbx2 - bbx1, bby2 - bby1), mode='bilinear', align_corners=True)
        input[:, :, bbx1:bbx2, bby1:bby2] = ncont
        # Semantic share that survives in the original image / arrives with the patch.
        lam_a = 1 - wfmaps[:, bbx1:bbx2, bby1:bby2].sum(2).sum(1) / (wfmaps.sum(2).sum(1) + 1e-8)
        lam_b = wfmaps_b[:, bbx1_1:bbx2_1, bby1_1:bby2_1].sum(2).sum(1) / (wfmaps_b.sum(2).sum(1) + 1e-8)
        # When both images share a label, both weights receive the combined share.
        tmp = lam_a.clone()
        lam_a[same_label] += lam_b[same_label]
        lam_b[same_label] += tmp[same_label]
        # Fall back to the area-based ratio where the maps produced NaNs.
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
        lam_a[torch.isnan(lam_a)] = lam
        lam_b[torch.isnan(lam_b)] = 1 - lam

    return input, target, target_b, lam_a.to(device), lam_b.to(device)

### SnapMix Loss

In [None]:
class SnapMixLoss(nn.Module):
    """Mean of two criterion outputs, each scaled by its SnapMix label weight."""

    def __init__(self):
        super().__init__()

    def forward(self, criterion, outputs, ya, yb, lam_a, lam_b):
        """Return ``mean(criterion(outputs, ya) * lam_a + criterion(outputs, yb) * lam_b)``."""
        per_target_a = criterion(outputs, ya)
        per_target_b = criterion(outputs, yb)
        weighted = per_target_a * lam_a + per_target_b * lam_b
        return torch.mean(weighted)

In [None]:
def prepare_dataloader(df, trn_idx, val_idx,
                       data_root=None):
    """Build the training and validation DataLoaders for one fold.

    Args:
        df: Full training table.
        trn_idx: Index labels (used with ``df.loc``) of the training rows.
        val_idx: Index labels (used with ``df.loc``) of the validation rows.
        data_root: Directory containing the image files.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    train_frame = df.loc[trn_idx, :].reset_index(drop=True)
    valid_frame = df.loc[val_idx, :].reset_index(drop=True)

    train_ds = CassavaDataset(train_frame, data_root, transforms=get_train_transforms())
    valid_ds = CassavaDataset(valid_frame, data_root, transforms=get_valid_transforms())

    # Settings shared by both loaders.
    common = dict(num_workers=CFG['num_workers'], pin_memory=False)

    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=CFG['train_bs'],
        shuffle=True,
        drop_last=False,
        **common,
    )
    val_loader = torch.utils.data.DataLoader(
        valid_ds,
        batch_size=CFG['valid_bs'],
        shuffle=False,
        **common,
    )
    return train_loader, val_loader

In [None]:
def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, device):
    """Train ``model`` for one epoch with mixed precision and no SnapMix.

    Reads the module-level ``scaler`` (a GradScaler created in the training
    cell) and ``CFG``.

    Args:
        epoch: Epoch number, used only in the progress-bar text.
        model: Network to train; it may return either logits or a tuple whose
            first element is the logits.
        loss_fn: Criterion called as ``loss_fn(logits, labels)``; scalar and
            per-sample outputs are both accepted (the mean is taken).
        optimizer: Optimizer stepped through ``scaler``.
        train_loader: Iterable of dict batches with image then label.
        device: Device the batch is moved to.
    """
    model.train()

    accum_iter = CFG['accum_iter']
    n_steps = len(train_loader)
    running_loss = 0.0
    sample = 0

    pbar = tqdm(enumerate(train_loader), total=n_steps)
    for step, data in pbar:
        (imgs, image_labels) = data.values()
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        # Only the forward pass and the loss belong inside autocast.
        with autocast():
            outputs = model(imgs)
            # CassavaNet returns (logits, feature_maps); keep the logits only.
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            loss = torch.mean(loss_fn(outputs, image_labels))

        # Divide so that accumulated gradients average over the accumulation
        # window instead of summing; a no-op for accum_iter == 1.
        scaler.scale(loss / accum_iter).backward()

        batch_size = image_labels.shape[0]
        running_loss += loss.item() * batch_size
        sample += batch_size

        is_last_step = (step + 1) == n_steps
        if ((step + 1) % accum_iter == 0) or is_last_step:
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        if ((step + 1) % CFG['verbose_step'] == 0) or is_last_step:
            description = f'epoch {epoch} loss: {running_loss/sample:.4f}'
            pbar.set_description(description)

def train_one_epoch_snapmix(epoch, model, loss_fn, optimizer, train_loader, device):
    """Train ``model`` for one epoch, applying SnapMix with probability CFG['snapmix_ptc'].

    Reads the module-level ``scaler`` (a GradScaler created in the training
    cell) and ``CFG``.

    Args:
        epoch: Epoch number, used only in the progress-bar text.
        model: Network returning ``(logits, feature_maps)``.
        loss_fn: Criterion called as ``loss_fn(logits, labels)``. It is
            expected to return per-sample losses so that the SnapMix weights
            can be applied element-wise before averaging.
        optimizer: Optimizer stepped through ``scaler``.
        train_loader: Iterable of dict batches with image then label.
        device: Device the batch is moved to.
    """
    model.train()
    snapmix_criterion = SnapMixLoss().to(device)

    accum_iter = CFG['accum_iter']
    n_steps = len(train_loader)
    running_loss = 0.0
    sample = 0
    count_snapmix = 0

    pbar = tqdm(enumerate(train_loader), total=n_steps)
    for step, data in pbar:
        (imgs, image_labels) = data.values()
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        # Only the forward passes and the loss belong inside autocast.
        with autocast():
            rand = np.random.rand()
            if rand > (1.0 - CFG['snapmix_ptc']):
                count_snapmix += 1
                # snapmix edits imgs in place and returns both label sets with their weights.
                imgs, ya, yb, lam_a, lam_b = snapmix(imgs, image_labels, CFG['snapmix_alpha'], model)
                outputs, _ = model(imgs)
                loss = snapmix_criterion(loss_fn, outputs, ya, yb, lam_a, lam_b)
            else:
                outputs, _ = model(imgs)
                loss = torch.mean(loss_fn(outputs, image_labels))

        # Divide so that accumulated gradients average over the accumulation
        # window instead of summing; a no-op for accum_iter == 1.
        scaler.scale(loss / accum_iter).backward()

        batch_size = image_labels.shape[0]
        running_loss += loss.item() * batch_size
        sample += batch_size

        is_last_step = (step + 1) == n_steps
        if ((step + 1) % accum_iter == 0) or is_last_step:
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        if ((step + 1) % CFG['verbose_step'] == 0) or is_last_step:
            description = f'epoch {epoch} loss: {running_loss/sample:.4f}'
            pbar.set_description(description)
    print('Number of snapmix iterations: {}/{}'.format(count_snapmix, len(train_loader)))

def valid_one_epoch(epoch, model, loss_fn, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` and return the accuracy.

    Gradient tracking is disabled inside the function, so it is safe to call
    with or without an enclosing ``torch.no_grad()``.

    Args:
        epoch: Epoch number, used only in the progress-bar text.
        model: Network returning ``(logits, feature_maps)``.
        loss_fn: Criterion returning a scalar loss for a batch.
        val_loader: Iterable of dict batches with image then label.
        device: Device the batch is moved to.

    Returns:
        Fraction of validation samples whose arg-max prediction equals the label.
    """
    model.eval()

    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []

    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    with torch.no_grad():
        for step, data in pbar:
            (imgs, image_labels) = data.values()
            imgs = imgs.to(device).float()
            image_labels = image_labels.to(device).long()

            image_preds, _ = model(imgs)
            image_preds_all += [torch.argmax(image_preds, 1).detach().cpu().numpy()]
            image_targets_all += [image_labels.detach().cpu().numpy()]

            loss = loss_fn(image_preds, image_labels)

            # Weight each batch loss by its size so the running value is a per-sample mean.
            loss_sum += loss.item()*image_labels.shape[0]
            sample_num += image_labels.shape[0]

            if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == len(val_loader)):
                description = f'epoch {epoch} loss: {loss_sum/sample_num:.4f}'
                pbar.set_description(description)

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    accuracy = (image_preds_all==image_targets_all).mean()
    print('validation multi-class accuracy = {:.4f}'.format(accuracy))

    return accuracy

### Train

In [None]:
if __name__ == '__main__':
    # Mixed-precision training: torch.cuda.amp must be available in the installed PyTorch.

    seed_everything(CFG['seed'])

    for fold in range(CFG['fold_num']):

        # Only the fold selected in the config is trained in this run.
        if fold != CFG['current_fold']:
            continue

        # NOTE: pickle.load can execute arbitrary code; only load fold files
        # that come from a trusted source.
        with open("../input/5folds/train_fold " + str(fold) + ".pickle","rb") as pickle_in:
            train_folds = pickle.load(pickle_in)

        with open("../input/5folds/valid_fold " + str(fold) + ".pickle","rb") as pickle_in:
            valid_folds = pickle.load(pickle_in)

        trn_idx = list(train_folds)
        val_idx = list(valid_folds)

        print('Training with {} started'.format(fold))
        print(len(trn_idx), len(val_idx))

        train_loader, val_loader = prepare_dataloader(train, trn_idx, val_idx,
                                    data_root='../input/colabcassava/handle_data')

        device = torch.device(CFG['device'])

        model = CassavaNet().to(device)

        # `scaler` must stay a module-level name: the train_one_epoch*
        # functions read it as a global.
        scaler = GradScaler()
        # The backbone gets its own learning rate; the classifier uses the
        # optimizer's default one.
        param_groups = [
            {'params': model.backbone.parameters(), 'lr': 1e-2},
            {'params': model.classifier.parameters()},
        ]

        if CFG['optimizer'] == 'SGD':
            optimizer = torch.optim.SGD(param_groups, lr=1e-1, momentum=0.9,
                                    weight_decay=CFG['weight_decay_SGD'], nesterov=True)
        elif CFG['optimizer'] == 'Adam':
            # The selector has to be read from CFG; comparing the (still unset)
            # optimizer object made this branch unreachable.
            optimizer = torch.optim.Adam(param_groups, lr=CFG['lr'],
                                         weight_decay=CFG['weight_decay_Adam'])
        else:
            raise ValueError('Unsupported optimizer: {}'.format(CFG['optimizer']))

        if CFG['scheduler'] == 'CosineAnnealingWarmRestarts':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CFG['T_0'],
                                    T_mult=1, eta_min=CFG['min_lr'], last_epoch=-1, verbose = 1)
        elif CFG['scheduler'] == 'ReduceLROnPlateau':
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.3,
                                      patience=2, verbose=True, eps=1e-6)
        elif CFG['scheduler'] == 'MultiStepLR':
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=CFG['milestones'],
                                                     gamma=0.1, last_epoch=-1, verbose=True)
        else:
            raise ValueError('Unsupported scheduler: {}'.format(CFG['scheduler']))

        # Per-sample losses for training (needed for the SnapMix weighting),
        # batch-mean loss for validation.
        loss_tr = nn.CrossEntropyLoss(reduction='none').to(device)
        loss_fn = nn.CrossEntropyLoss().to(device)

        best_accuracy = 0
        for epoch in range(CFG['epochs']):
            train_one_epoch_snapmix(epoch, model, loss_tr, optimizer, train_loader, device)

            with torch.no_grad():
                valid_acc = valid_one_epoch(epoch, model, loss_fn, val_loader, device)

            # ReduceLROnPlateau steps on the monitored metric; the others step per epoch.
            if CFG['scheduler'] == 'ReduceLROnPlateau':
                scheduler.step(valid_acc)
            else:
                scheduler.step()

            if valid_acc > best_accuracy:
                best_accuracy = valid_acc
                print('Save best model at epoch {} : {}'.format(epoch, valid_acc))
                torch.save(model.state_dict(), '{}_fold_{}.pth'.format(CFG['model_name'], fold))
            # Always keep the latest weights as well, independent of the best checkpoint.
            torch.save(model.state_dict(), '{}_cur_fold_{}.pth'.format(CFG['model_name'], fold))

        del model, optimizer, train_loader, val_loader, scaler, scheduler
        torch.cuda.empty_cache()

*If you find this kernel helpful, please upvote it! Have fun.*