## Library imports

In [1]:
!pip install timm

Collecting timm
  Downloading timm-0.3.4-py3-none-any.whl (244 kB)
[K     |████████████████████████████████| 244 kB 406 kB/s 
Installing collected packages: timm
Successfully installed timm-0.3.4
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


In [2]:
# append package paths
import sys
append_paths = ['../input/image-fmix/FMix-master'] #,'../input/pytorch-image-models/pytorch-image-models-master', 
for package_path in append_paths:
    sys.path.append(package_path)

# basic imports
import os
import numpy as np
import pandas as pd
import random
import itertools
from tqdm.notebook import tqdm
import math

# augmentations library
from albumentations.pytorch import ToTensorV2
from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightnessContrast,ShiftScaleRotate, Cutout, CoarseDropout, 
    IAAAdditiveGaussianNoise, Transpose, MotionBlur, MedianBlur, GaussianBlur, HueSaturationValue
    )
import albumentations as A
from fmix import sample_mask
import cv2

# DL library imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau, CyclicLR, OneCycleLR
from  torch.cuda.amp import autocast, GradScaler

# timm import
import timm

# metrics calculation
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import KFold, StratifiedKFold

# basic plotting library
import matplotlib.pyplot as plt

# interactive plots
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import warnings  
warnings.filterwarnings('ignore')

## Config params

In [3]:
class CFG:
    """Central configuration namespace; every field is read as a class attribute (``CFG.X``)."""
    # pipeline parameters
    SEED        = 42      # passed to set_seed() and used as StratifiedKFold random_state
    NUM_CLASSES = 5       # output size of the replaced ViT classification head
    TGT_LABEL   = 'label' # dataframe column that TrainDataset reads labels from
    TRAIN       = True    # gates the training cell
    LR_FIND     = False   # gates the learning-rate range test cell
    RETRAIN     = False   # when True, weights are loaded from WGT_PATH/WGT_MODEL instead of timm pretrained weights
    TEST        = False   # not referenced in the visible part of the notebook
    DEBUG       = False   # gates the small sanity-check cells
    N_FOLDS     = 5       # number of StratifiedKFold splits
    N_EPOCHS    = 20      # epochs per fold (also sizes the OneCycleLR schedule)
    DF_FRAC     = 1       # values < 1 subsample train_df to that fraction
    TEST_BATCH_SIZE  = 32 # not referenced in the visible part of the notebook
    TRAIN_BATCH_SIZE = 16 # used for both the training and the validation DataLoader
    SIZE             = [384, 384]  # [height, width] passed to Resize and to fmix
    NUM_WORKERS      = 4  # DataLoader worker processes
    FOLD_TO_TRAIN    = [2] # 0-based fold indices that are actually trained
    GRAD_ACC_STEPS   = 4  # batches accumulated before each optimizer step

    # model parameters
    MODEL_ARCH  = 'vit_base_patch16_384'  # name handed to timm.create_model
    MODEL_NAME  = 'vit_v11'               # prefix for checkpoint / LR-list / LR-finder output files
    WGT_PATH    = ''                      # directory of checkpoints to resume from (RETRAIN only)
    WGT_MODEL   = ''                      # checkpoint file prefix to resume from (RETRAIN only)

    # loss fn parameters
    LOSS_FN     = 'LabelSmoothingCrossEntropy' # 'CrossEntropyLoss' selects nn.CrossEntropyLoss; any other value selects label smoothing
    SMOOTHING   = 0.3                          # smoothing factor given to LabelSmoothingCrossEntropy
    MIX_PROB    = 0.25                         # per-batch probability of applying fmix during training
    
    # lr scheduler variables
    MAX_LR    = 3e-4  # optimizer learning rate and peak LR for OneCycleLR / CyclicLR
    MIN_LR    = 1e-6  # eta_min for the cosine schedulers, base_lr for CyclicLR
    SCHEDULER = 'OneCycleLR' # also handled: 'ReduceLROnPlateau', 'CosineAnnealingWarmRestarts', 'CyclicLR'; anything else falls back to CosineAnnealingLR
    # CosineAnnealingWarmRestarts
    T_0       = 1       
    # CosineAnnealingLR & CyclicLR
    T_MAX     = 2.5      # multiplied by len(dataloader_train) where it is used, i.e. expressed in epochs
    # CyclicLR
    GAMMA     = 0.98    
    # ReduceLROnPlateau
    LR_FACTOR = 5       # NOTE(review): passed as `factor`; PyTorch documents new_lr = lr * factor and rejects factor >= 1.0 - confirm the intended value
    PATIENCE  = 3      
    
    # optimizer variables
    OPTIMIZER     = 'Adam'  # 'Adam' selects optim.Adam; any other value selects SGD with momentum 0.9
    WEIGHT_DECAY  = 1e-6
    MAX_GRD_NORM  = 1000    # not referenced in the visible part of the notebook


# Kaggle input locations: competition data plus the pre-converted .npy training images.
DIR_INPUT = '../input/cassava-leaf-disease-classification'
TRAIN_PATH = f'{DIR_INPUT}/train_images'
TEST_PATH = f'{DIR_INPUT}/test_images'
NPY_FOLDER = '../input/cassava-npy-train-images/train_npy_images'

# Class index -> human readable disease name, in index order.
index_label_map = dict(enumerate([
    "Cassava Bacterial Blight (CBB)",
    "Cassava Brown Streak Disease (CBSD)",
    "Cassava Green Mottle (CGM)",
    "Cassava Mosaic Disease (CMD)",
    "Healthy",
]))

# Display names ordered by class index.
class_names = list(index_label_map.values())

## Helper functions

In [4]:
def find_no_of_trainable_params(model):
    """Return the total element count of all parameters of ``model`` that require gradients."""
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable

In [5]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and request deterministic cuDNN kernels."""
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Python / NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    torch.backends.cudnn.deterministic = True

set_seed(CFG.SEED)

In [6]:
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw ``cm`` as an annotated heat map on the current matplotlib figure.

    :param cm: square confusion matrix (rows are true labels, columns are predictions)
    :param classes: tick labels, one per class
    :param normalize: when True each row is divided by its sum before plotting
    :param title: figure title
    :param cmap: matplotlib colormap used for the cells
    """
    if not normalize:
        print('Confusion matrix, without normalization')
    else:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        cm = cm.astype('float') / row_totals
        print("Normalized confusion matrix")

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    tick_positions = np.arange(len(classes))
    plt.xticks(tick_positions, classes, rotation=45)
    plt.yticks(tick_positions, classes)

    # Cells darker than half of the maximum get white text so the number stays readable.
    half_max = cm.max() / 2.
    cell_format = '.2f' if normalize else 'd'
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            text_color = "white" if cm[row, col] > half_max else "black"
            plt.text(col, row, format(cm[row, col], cell_format),
                     horizontalalignment="center", color=text_color)

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

In [7]:
class AverageMeter(object):
    """Keeps the most recent value together with a running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, the mean, the weighted sum and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

## Dataset 

In [8]:
train_df = pd.read_csv(f'{DIR_INPUT}/train.csv')

# Each jpg has a pre-converted .npy twin with the same stem; only the trailing
# extension is rewritten so a 'jpg' elsewhere in a file name is left alone.
train_df['npy_image_id'] = train_df['image_id'].str.replace(r'\.jpg$', '.npy', regex=True)

# Optional subsampling for quick experiments; seeded so re-running the cell
# yields the same subset.
if CFG.DF_FRAC < 1:
    train_df = train_df.sample(frac=CFG.DF_FRAC, random_state=CFG.SEED).reset_index(drop=True)

# Labels are looked up by column name (as TrainDataset does) instead of by position.
train_labels = train_df[CFG.TGT_LABEL].values
print(train_df.shape)

# Stratified splitter consumed by the training cell via folds.split(train_df, train_labels).
folds = StratifiedKFold(n_splits=CFG.N_FOLDS, shuffle=True, random_state=CFG.SEED)

(21397, 3)


In [9]:
class TrainDataset(Dataset):
    """Labelled dataset that reads pre-converted ``.npy`` images from ``NPY_FOLDER``.

    :param df: dataframe with an ``npy_image_id`` column and a ``CFG.TGT_LABEL`` column
    :param transform: optional callable invoked as ``transform(image=...)`` whose
        result is indexed with ``'image'``
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        self.file_names = df['npy_image_id'].values
        self.labels = df[CFG.TGT_LABEL].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where ``label`` is a long tensor."""
        npy_path = f'{NPY_FOLDER}/{self.file_names[idx]}'
        image = np.load(npy_path)
        if self.transform:
            image = self.transform(image=image)['image']
        return image, torch.tensor(self.labels[idx]).long()
    

class TestDataset(Dataset):
    """Unlabelled dataset that reads image files from ``TEST_PATH`` with OpenCV.

    :param df: dataframe with an ``image_id`` column holding file names
    :param transform: optional callable invoked as ``transform(image=...)`` whose
        result is indexed with ``'image'``
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['image_id'].values
        self.transform = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the RGB image for row ``idx`` (transformed when a transform is set).

        :raises FileNotFoundError: if OpenCV cannot read the file
        """
        file_name = self.file_names[idx]
        file_path = f'{TEST_PATH}/{file_name}'
        image = cv2.imread(file_path)
        # cv2.imread signals a missing/unreadable file by returning None; fail
        # here with the path instead of with an opaque cvtColor error.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        # OpenCV loads BGR; the Normalize statistics used downstream are RGB-ordered.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image

In [10]:
if CFG.DEBUG:
    # Sanity check: show the first raw (untransformed) training image with its label.
    train_dataset = TrainDataset(train_df, transform=None)
    i = 0
    image, label = train_dataset[i]
    plt.imshow(image)
    plt.title(f'label: {label}')
    plt.show()

## Transforms for Augmentations

In [11]:
def rand_bbox(size, lam):
    """Sample a random box covering roughly ``1 - lam`` of an image.

    :param size: 4-element size whose entries 2 and 3 are height and width
        (the layout ``cutmix`` slices with ``[:, :, y1:y2, x1:x2]``)
    :param lam: mixing ratio; the box side lengths are scaled by ``sqrt(1 - lam)``
    :return: ``(bbx1, bby1, bbx2, bby2)`` clipped to the image bounds
    """
    # Axis 2 is indexed with the y range and axis 3 with the x range by the
    # caller, so height/width must be read in that order.
    H = size[2]
    W = size[3]
    cut_rat = np.sqrt(1. - lam)
    # Built-in int: the np.int alias was removed from NumPy.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniformly sampled box centre
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix(data, target, alpha):
    """Paste a random box from a shuffled copy of the batch into each image.

    :param data: image batch, sliced as ``[:, :, y1:y2, x1:x2]``
    :param target: labels aligned with ``data``
    :param alpha: parameter of the ``Beta(alpha, alpha)`` draw for the mixing ratio
    :return: ``(new_data, (target, shuffled_target, lam))`` where ``lam`` is the
        fraction of each image that was left untouched
    """
    indices = torch.randperm(data.size(0))
    shuffled_target = target[indices]

    # The sampled ratio is clipped to [0.3, 0.4] before the box is drawn.
    lam = np.clip(np.random.beta(alpha, alpha),0.3,0.4)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
    new_data = data.clone()
    # Only the pasted region is gathered from the shuffled batch; no full
    # shuffled copy of the batch is materialised.
    new_data[:, :, bby1:bby2, bbx1:bbx2] = data[indices, :, bby1:bby2, bbx1:bbx2]
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))
    targets = (target, shuffled_target, lam)
    return new_data, targets

def fmix(data, targets, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """Blend each image with a shuffled partner using a mask from ``sample_mask``.

    :param data: image batch
    :param targets: labels aligned with ``data``
    :param alpha, decay_power, shape, max_soft, reformulate: forwarded unchanged
        to ``sample_mask`` (FMix library)
    :return: ``(mixed, (targets, shuffled_targets, lam))`` with ``lam`` as
        returned by ``sample_mask``
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]
    # Place the mask on the batch's own device rather than on a notebook-level
    # global, so the function works regardless of cell execution order.
    # presumably the mask broadcasts over the batch and channel axes - verify against sample_mask
    mask = torch.from_numpy(mask).to(data.device)
    mixed = mask * data + (1 - mask) * shuffled_data
    targets = (targets, shuffled_targets, lam)
    return mixed, targets

In [12]:
def generate_transforms():
    """Build the albumentations pipelines for training, validation and test.

    :return: dict with keys ``'train_transforms'``, ``'val_transforms'`` and
        ``'test_transform'``; every pipeline resizes to ``CFG.SIZE``, normalises
        with the mean/std below and converts to a tensor
    """
    def _resize():
        return Resize(height=CFG.SIZE[0], width=CFG.SIZE[1])  # RandomResizedCrop(CFG.size, CFG.size),

    def _normalize_to_tensor():
        return [
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ]

    geometric = [Transpose(p=0.3), VerticalFlip(p=0.3), HorizontalFlip(p=0.3), ShiftScaleRotate(p=0.4)]
    photometric = [
        RandomBrightnessContrast(p=0.4),
        IAAAdditiveGaussianNoise(p=0.3),  # sharpen, affine transform
        OneOf([MotionBlur(blur_limit=3), MedianBlur(blur_limit=3), GaussianBlur(blur_limit=3)], p=0.3),
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.3),
    ]
    occlusion = [CoarseDropout(p=0.4), Cutout(p=0.4)]
    # Other candidates noted but not used: RandomCrop, RandomResizedCrop(sz,sz),
    # CLAHE, ImageCompression, MaskDropout, elastictransform, IAAAffine

    train_transforms = Compose([_resize(), *geometric, *photometric, *occlusion, *_normalize_to_tensor()])
    val_transforms = Compose([_resize(), *_normalize_to_tensor()])
    test_transforms = Compose([_resize(), *_normalize_to_tensor()])

    return {
        'train_transforms': train_transforms,
        'val_transforms': val_transforms,
        'test_transform': test_transforms,
    }

In [13]:
if CFG.DEBUG:
    # Sanity check: show channel 0 of the first training image after the
    # training augmentations have been applied.
    train_dataset = TrainDataset(train_df, transform=generate_transforms()['train_transforms'])
    i = 0
    image, label = train_dataset[i]
    plt.imshow(image[0])
    plt.title(f'label: {label}')
    plt.show()

## Model class

In [14]:
class ViTBase16(nn.Module):
    """timm model whose ``head`` is replaced by a linear layer with ``CFG.NUM_CLASSES`` outputs.

    :param model_name: architecture name handed to ``timm.create_model``
    :param pretrained: forwarded to ``timm.create_model``
    """

    def __init__(self, model_name=CFG.MODEL_ARCH, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        # Swap the classifier for one sized to this task, keeping its input width.
        head_in_features = self.model.head.in_features
        self.model.head = nn.Linear(head_in_features, CFG.NUM_CLASSES)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

In [15]:
if CFG.DEBUG == True:
    # Smoke test: push one small augmented batch through an untrained model.
    # ViTBase16 is the model class defined in this notebook (the previously
    # referenced CustomResNext is not defined anywhere in it).
    model = ViTBase16(model_name=CFG.MODEL_ARCH, pretrained=False)
    train_dataset = TrainDataset(train_df, transform=generate_transforms()['train_transforms'])
    train_loader = DataLoader(train_dataset, batch_size= 4, shuffle=True,
                              num_workers=CFG.NUM_WORKERS, pin_memory=True, drop_last=True)
    for image, label in train_loader:
        output = model(image)
        print(output)
        break

## Loss function

In [16]:
class LabelSmoothingCrossEntropy(nn.Module):
    """
    Cross entropy that blends the target-class NLL with the mean NLL over all classes.
    """
    def __init__(self, smoothing=0.1):
        """
        :param smoothing: weight of the all-class term; must be below 1.0.
            The target-class term is weighted by ``1 - smoothing``.
        """
        super().__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        """
        :param x: logits, classes on the last axis
        :param target: integer class index per sample
        :return: scalar loss averaged over the batch
        """
        log_probs = F.log_softmax(x, dim=-1)
        # Negative log-probability of the labelled class, one value per sample.
        target_nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1))
        target_nll = target_nll.squeeze(1)
        # Negative log-probability averaged over every class, one value per sample.
        all_class_nll = -log_probs.mean(dim=-1)
        per_sample = self.confidence * target_nll + self.smoothing * all_class_nll
        return per_sample.mean()

In [17]:
## Device: first CUDA GPU when available, otherwise CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Loss: plain cross entropy only when explicitly requested; every other
# CFG.LOSS_FN value selects the label-smoothing variant.
criterion = (
    nn.CrossEntropyLoss()
    if CFG.LOSS_FN == 'CrossEntropyLoss'
    else LabelSmoothingCrossEntropy(smoothing=CFG.SMOOTHING)
)

cuda:0


## Lr_find

In [18]:
def plot_lr_finder_results(lr_finder):
    """Write the LR-finder curves to ``<CFG.MODEL_NAME>_lr_find.html``.

    :param lr_finder: dict with ``'log_lr'``, ``'smooth_loss'`` and ``'grad_loss'``
        entries; the smoothed loss goes in the left panel and its gradient in
        the right one, both against ``log_lr``
    """
    fig = make_subplots(rows=1, cols=2)

    # (result key, legend name) for panel 1 and panel 2 respectively
    panels = [
        ('smooth_loss', 'log_lr vs smooth_loss'),
        ('grad_loss', 'log_lr vs loss gradient'),
    ]
    for col, (key, legend_name) in enumerate(panels, start=1):
        trace = go.Scatter(x=lr_finder['log_lr'], y=lr_finder[key], name=legend_name)
        fig.add_trace(trace, row=1, col=col)

    fig.write_html(CFG.MODEL_NAME + '_lr_find.html')

In [19]:
def _lr_finder_results(log_lrs, raw_losses, smooth_losses):
    """Bundle the recorded curves; the loss gradient needs two points, otherwise 0.0 is reported."""
    grad_loss = np.gradient(smooth_losses) if len(smooth_losses) > 1 else 0.0
    return {'log_lr': log_lrs, 'raw_loss': raw_losses,
            'smooth_loss': smooth_losses, 'grad_loss': grad_loss}


def _set_optimizer_lr(optimizer, lr):
    """Assign ``lr`` to every parameter group of ``optimizer``."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def find_lr(model, optimizer, data_loader, init_value = 1e-8, final_value=100.0, beta = 0.98, num_batches = 200):
    """Learning-rate range test: train while growing the LR geometrically and record the loss.

    The LR starts at ``init_value`` and is multiplied after every batch so that it
    would reach ``final_value`` after ``num_batches`` batches. The run stops early
    when the bias-corrected smoothed loss exceeds 50x its best value, or when the
    forward pass / loss raises. Note that ``model`` and ``optimizer`` are updated
    in place by the test.

    :param model: network to probe
    :param optimizer: optimizer whose learning rate is overwritten
    :param data_loader: iterable of ``(images, labels)``; restarted when exhausted
    :param init_value: first learning rate
    :param final_value: learning rate targeted for the last batch
    :param beta: exponential smoothing factor for the loss
    :param num_batches: maximum number of batches to run (must be > 0)
    :return: dict with ``'log_lr'`` (log10 of the LR), ``'raw_loss'``,
        ``'smooth_loss'`` and ``'grad_loss'`` (``np.gradient`` of the smoothed
        loss, or ``0.0`` when fewer than two points were recorded)
    """
    assert(num_batches > 0)
    mult = (final_value / init_value) ** (1/num_batches)
    lr = init_value
    _set_optimizer_lr(optimizer, lr)
    avg_loss = 0.0
    best_loss = 0.0
    smooth_losses = []
    raw_losses = []
    log_lrs = []
    dataloader_it = iter(data_loader)
    progress_bar = tqdm(range(num_batches))

    for batch_num, _ in enumerate(progress_bar, start=1):
        try:
            images, labels = next(dataloader_it)
        except StopIteration:
            # Loader exhausted before num_batches was reached: start another pass.
            dataloader_it = iter(data_loader)
            images, labels = next(dataloader_it)

        # Move input and label tensors to the default device
        images = images.to(device)
        labels = labels.to(device)

        # Very large learning rates can make the forward pass or the loss fail;
        # in that case return what was recorded so far. Only Exception is caught
        # so that KeyboardInterrupt still aborts the run.
        try:
            y_preds = model(images.float())
            loss = criterion(y_preds, labels)
        except Exception as exc:
            print(f"LR finder stopped early at lr={lr}: {exc!r}")
            return _lr_finder_results(log_lrs, raw_losses, smooth_losses)

        # Exponentially smoothed loss with bias correction
        avg_loss = beta * avg_loss + (1-beta) *loss.item()
        smoothed_loss = avg_loss / (1 - beta**batch_num)

        # Stop if the loss is exploding
        if batch_num > 1 and smoothed_loss > 50 * best_loss:
            return _lr_finder_results(log_lrs, raw_losses, smooth_losses)

        # Record the best loss
        if smoothed_loss < best_loss or batch_num==1:
            best_loss = smoothed_loss

        # Store the values
        raw_losses.append(loss.item())
        smooth_losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print info
        progress_bar.set_description(f"loss: {loss.item()},smoothed_loss: {smoothed_loss},lr : {lr}")

        # Update the lr for the next step
        lr *= mult
        _set_optimizer_lr(optimizer, lr)

    return _lr_finder_results(log_lrs, raw_losses, smooth_losses)

In [20]:
if CFG.LR_FIND == True:
    # Data for the range test: the full training frame with training augmentations.
    temp_train_dataset = TrainDataset(train_df, transform=generate_transforms()['train_transforms'])
    temp_train_dataloader = DataLoader(temp_train_dataset, batch_size= CFG.TRAIN_BATCH_SIZE, shuffle=True,
                          num_workers=CFG.NUM_WORKERS, pin_memory=False, drop_last=False)

    # create model instance
    # load pretrained weight file, if present
    if CFG.RETRAIN == True:
        i_fold = 0
        # map_location lets a checkpoint saved on a GPU be restored on whatever
        # device this session runs on (including CPU-only machines).
        checkpoint = torch.load(f'{CFG.WGT_PATH}/{CFG.WGT_MODEL}_fold{i_fold}.pth', map_location=device)
        model = ViTBase16(model_name=CFG.MODEL_ARCH, pretrained=False)
        model.to(device)
        model.load_state_dict(checkpoint['model'])
        print(f'Model loaded for {CFG.WGT_MODEL}_fold{i_fold}')
            
    else:
        model = ViTBase16(model_name=CFG.MODEL_ARCH, pretrained=True)
        model.to(device)

    # find_lr overwrites the learning rate; lr here is only the constructor value.
    optimizer = optim.Adam(model.parameters(), weight_decay=CFG.WEIGHT_DECAY, lr=CFG.MAX_LR)
    lr_finder_results = find_lr(model, optimizer, temp_train_dataloader)
    plot_lr_finder_results(lr_finder_results)

## One fold train and validation function

In [21]:
def _train_epoch(model, optimizer, scheduler, scaler, dataloader_train, lr_list):
    """Run one training pass over ``dataloader_train``; return the summed per-batch loss.

    The learning rate in effect after each batch is appended to ``lr_list``.
    """
    model.train()
    tr_loss = 0.0
    n_batches = len(dataloader_train)
    train_progress_bar = tqdm(dataloader_train, total=n_batches)

    for idx, (images, labels) in enumerate(train_progress_bar):
        images = images.to(device)
        labels = labels.to(device)

        # With probability MIX_PROB the batch is FMix-blended; labels then
        # becomes the tuple (targets, shuffled_targets, lam).
        use_mix = np.random.rand() < CFG.MIX_PROB
        if use_mix:
            images, labels = fmix(images, labels, alpha=1., decay_power=5., shape=(CFG.SIZE[0],CFG.SIZE[1]))

        # Only the forward pass and the loss run under autocast; backward and
        # the optimizer step are done outside of it.
        with autocast():
            y_preds = model(images.float())
            if use_mix:
                targets_a, targets_b, lam = labels
                loss = criterion(y_preds, targets_a) * lam + criterion(y_preds, targets_b) * (1.0 - lam)
            else:
                loss = criterion(y_preds, labels)
        tr_loss += loss.item()

        # Scale so that the accumulated gradient is an average over the
        # accumulation window rather than a sum.
        if CFG.GRAD_ACC_STEPS > 1:
            loss = loss / CFG.GRAD_ACC_STEPS

        # Backward pass
        scaler.scale(loss).backward()

        # Step every GRAD_ACC_STEPS batches and on the last batch of the epoch.
        if ((idx + 1) % CFG.GRAD_ACC_STEPS == 0) or ((idx + 1) == n_batches):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Per-batch schedulers (sized in batches by the caller).
        # NOTE(review): 'CosineAnnealingWarmRestarts' is stepped neither here nor
        # per epoch, so its learning rate never changes - confirm the intended policy.
        if CFG.SCHEDULER in ['CosineAnnealingLR', 'OneCycleLR', 'CyclicLR']:
            scheduler.step()

        lr_list.append(optimizer.state_dict()["param_groups"][0]['lr'])
        train_progress_bar.set_description(f"Train_loss: {tr_loss} loss(avg): {tr_loss/(idx+1)}")

    return tr_loss


def _validate_epoch(model, dataloader_valid):
    """Evaluate ``model`` on ``dataloader_valid``.

    :return: ``(val_loss, val_preds, val_labels)`` - the summed per-batch loss and
        NumPy arrays of predicted / true class indices in loader order
    """
    model.eval()
    val_loss = 0.0
    pred_chunks = []
    label_chunks = []
    valid_progress_bar = tqdm(dataloader_valid, total=len(dataloader_valid))

    with torch.no_grad():
        for idx, (images, labels) in enumerate(valid_progress_bar):
            images = images.to(device)
            labels = labels.to(device)

            y_preds = model(images)
            val_loss += criterion(y_preds, labels).item()

            # Softmax is monotonic, so the arg-max of the logits is the predicted
            # class; chunks are moved to the CPU and concatenated once at the end.
            pred_chunks.append(y_preds.argmax(dim=1).cpu())
            label_chunks.append(labels.cpu())

            valid_progress_bar.set_description(f"val_loss: {val_loss} loss(avg): {val_loss/(idx+1)}")

    val_preds = torch.cat(pred_chunks, dim=0).numpy()
    val_labels = torch.cat(label_chunks, dim=0).numpy()
    return val_loss, val_preds, val_labels


def train_one_fold(i_fold, model, optimizer, scheduler, scaler, dataloader_train, dataloader_valid):
    """Train and validate ``model`` for ``CFG.N_EPOCHS`` epochs on one fold.

    After every epoch the overall and class-wise validation accuracy is printed;
    whenever the overall accuracy improves, a checkpoint with the model weights
    and that epoch's validation predictions/labels is written to
    ``<CFG.MODEL_NAME>_fold_<i_fold>_epoch<epoch>.pth``. The per-batch learning
    rates are saved to ``<CFG.MODEL_NAME>_fold<i_fold>_LRlist.npy`` at the end.

    :param i_fold: fold index, used for logging and file names
    :param model: network to train (already on ``device``)
    :param optimizer: optimizer over ``model``'s parameters
    :param scheduler: LR scheduler matching ``CFG.SCHEDULER``
    :param scaler: ``GradScaler`` used for mixed-precision training
    :param dataloader_train: yields ``(images, labels)`` training batches
    :param dataloader_valid: yields ``(images, labels)`` validation batches
    :return: list with one dict per epoch (``fold``, ``epoch``, ``train_loss``,
        ``valid_loss``, ``valid_score``, ``class_wise_acc``)
    """
    train_fold_results = []
    lr_list = []
    best_val_acc = 0.0
    best_epoch = 0

    for epoch in range(CFG.N_EPOCHS):
        print('  Epoch {}/{}'.format(epoch + 1, CFG.N_EPOCHS))

        tr_loss = _train_epoch(model, optimizer, scheduler, scaler, dataloader_train, lr_list)
        val_loss, val_preds, val_labels = _validate_epoch(model, dataloader_valid)

        # Plateau scheduler is driven by the mean validation loss once per epoch.
        if CFG.SCHEDULER == 'ReduceLROnPlateau':
            scheduler.step(val_loss / len(dataloader_valid))

        # compute accuracy
        val_score = accuracy_score(val_labels, val_preds)
        # class wise accuracy: diagonal entry over row total, in percent
        cm = confusion_matrix(val_labels, val_preds)
        class_wise_acc = [row[i] / sum(row) * 100 for i, row in enumerate(cm)]
        print(f"Fold:{i_fold}, Epoch:{epoch}, Overall accuracy : {val_score * 100.0}, \
               Classwise_acc:{class_wise_acc}")

        # store results
        train_fold_results.append({ 'fold': i_fold, 'epoch': epoch, 'train_loss': tr_loss / len(dataloader_train),
                                    'valid_loss': val_loss / len(dataloader_valid), 'valid_score': val_score,
                                    'class_wise_acc': class_wise_acc})

        # save best models
        if val_score > best_val_acc:
            best_val_acc = val_score
            best_epoch = epoch

            # save model weights
            torch.save({'model': model.state_dict(), 'val_preds':val_preds, 'val_labels':val_labels},
                        f"{CFG.MODEL_NAME}_fold_{i_fold}_epoch{epoch}.pth")

    print(f"For Fold {i_fold}, Best validation accuracy of {best_val_acc} was got at epoch {best_epoch}")
    lr_list = np.array(lr_list)
    np.save(f"{CFG.MODEL_NAME}_fold{i_fold}_LRlist.npy", lr_list)
    return train_fold_results

## Training and validation function calls

In [22]:
if CFG.TRAIN == True:
    train_results = []

    for i_fold, (train_idx, valid_idx) in enumerate(folds.split(train_df, train_labels)):
        if i_fold in CFG.FOLD_TO_TRAIN:
            print("Fold {}/{}".format(i_fold + 1, CFG.N_FOLDS))
            
            # create fold data
            train_data = train_df.iloc[train_idx].reset_index()    
            valid_data = train_df.iloc[valid_idx].reset_index()
            print(train_data.shape, valid_data.shape)

            dataset_train = TrainDataset(train_data, transform=generate_transforms()['train_transforms'])
            dataset_valid = TrainDataset(valid_data, transform=generate_transforms()['val_transforms'])            
            dataloader_train = DataLoader(dataset_train, batch_size= CFG.TRAIN_BATCH_SIZE, shuffle=True,
                          num_workers=CFG.NUM_WORKERS, pin_memory=False, drop_last=False)
            # Validation is not shuffled: the metrics do not depend on order, and
            # the predictions stored in the checkpoint then follow valid_data's rows.
            dataloader_valid = DataLoader(dataset_valid, batch_size= CFG.TRAIN_BATCH_SIZE, shuffle=False,
                          num_workers=CFG.NUM_WORKERS, pin_memory=False, drop_last=False)

            # load pretrained weight file
            if CFG.RETRAIN == True:
                # map_location lets a GPU-saved checkpoint load on the current device.
                checkpoint = torch.load(f'{CFG.WGT_PATH}/{CFG.WGT_MODEL}_fold{i_fold}.pth', map_location=device)
                model = ViTBase16(model_name=CFG.MODEL_ARCH, pretrained=False)
                model.to(device)
                model.load_state_dict(checkpoint['model'])
                print(f'Model loaded for {CFG.WGT_MODEL}_fold{i_fold}')
            
            else:
                model = ViTBase16(model_name=CFG.MODEL_ARCH, pretrained=True)
                model.to(device)

            # scaler to handle AMP
            scaler = GradScaler()   
            
            if CFG.OPTIMIZER == 'Adam':
                optimizer = optim.Adam(model.parameters(), weight_decay=CFG.WEIGHT_DECAY, lr=CFG.MAX_LR)
            else:
                optimizer = optim.SGD(model.parameters(), weight_decay=CFG.WEIGHT_DECAY, lr=CFG.MAX_LR, momentum=0.9)
                
                
            # LR scheduler. The per-batch schedulers are sized in batches because
            # train_one_fold steps them once per batch.
            if CFG.SCHEDULER == 'OneCycleLR':
                scheduler = OneCycleLR(optimizer, max_lr= CFG.MAX_LR, epochs = CFG.N_EPOCHS,steps_per_epoch = len(dataloader_train),  
                                  pct_start=0.4, div_factor=10, anneal_strategy='cos')
            
            elif CFG.SCHEDULER == 'CosineAnnealingWarmRestarts':
                scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.MIN_LR, last_epoch=-1)
            
            elif CFG.SCHEDULER == 'CyclicLR':
                scheduler = CyclicLR(optimizer, base_lr=CFG.MIN_LR, max_lr=CFG.MAX_LR, 
                                     step_size_up= int(CFG.T_MAX*len(dataloader_train)),
                                    mode='exp_range', gamma=CFG.GAMMA, cycle_momentum=False)
                
            elif CFG.SCHEDULER == 'ReduceLROnPlateau':
                # NOTE(review): PyTorch rejects factor >= 1.0; check CFG.LR_FACTOR before using this branch.
                scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.LR_FACTOR, patience=CFG.PATIENCE,
                                              threshold_mode='abs')

            else:
                # The keyword is spelled T_max; it is cast to int to match the CyclicLR branch.
                scheduler = CosineAnnealingLR(optimizer, T_max=int(CFG.T_MAX*len(dataloader_train)), eta_min=CFG.MIN_LR, last_epoch=-1)
            print(scheduler)
            train_fold_results = train_one_fold(i_fold, model, optimizer, scheduler, scaler, dataloader_train, dataloader_valid)
            train_results = train_results + train_fold_results

Fold 3/5
(17118, 4) (4279, 4)


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth" to /root/.cache/torch/hub/checkpoints/jx_vit_base_p16_384-83fb41ba.pth


<torch.optim.lr_scheduler.OneCycleLR object at 0x7f64410f4690>
  Epoch 1/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:0, Overall accuracy : 86.5389109605048,                Classwise_acc:[46.08294930875576, 68.0365296803653, 72.32704402515722, 96.35258358662614, 82.33009708737865]
  Epoch 2/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:1, Overall accuracy : 86.95956999298902,                Classwise_acc:[64.97695852534562, 72.8310502283105, 75.47169811320755, 97.6823708206687, 64.07766990291263]
  Epoch 3/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:2, Overall accuracy : 87.02967983173639,                Classwise_acc:[60.36866359447005, 71.46118721461188, 73.37526205450735, 97.53039513677811, 70.48543689320388]
  Epoch 4/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:3, Overall accuracy : 86.60902079925215,                Classwise_acc:[59.907834101382484, 57.534246575342465, 72.32704402515722, 97.98632218844985, 77.66990291262135]
  Epoch 5/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:4, Overall accuracy : 81.72470203318532,                Classwise_acc:[56.68202764976959, 23.74429223744292, 62.68343815513627, 99.24012158054711, 69.70873786407768]
  Epoch 6/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:5, Overall accuracy : 85.06660434681001,                Classwise_acc:[38.70967741935484, 81.27853881278538, 66.87631027253668, 93.5790273556231, 81.16504854368932]
  Epoch 7/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:6, Overall accuracy : 85.41715354054685,                Classwise_acc:[49.76958525345622, 69.40639269406392, 66.87631027253668, 95.06079027355622, 81.94174757281554]
  Epoch 8/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:7, Overall accuracy : 82.37906052816079,                Classwise_acc:[57.14285714285714, 63.242009132420094, 67.71488469601678, 89.70364741641338, 85.43689320388349]
  Epoch 9/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:8, Overall accuracy : 86.98293993923814,                Classwise_acc:[56.68202764976959, 75.79908675799086, 72.53668763102725, 96.88449848024317, 72.03883495145631]
  Epoch 10/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:9, Overall accuracy : 86.46880112175742,                Classwise_acc:[50.23041474654379, 66.89497716894978, 69.18238993710692, 95.93465045592706, 86.01941747572816]
  Epoch 11/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:10, Overall accuracy : 84.59920542182753,                Classwise_acc:[64.97695852534562, 69.86301369863014, 65.40880503144653, 93.00911854103343, 80.19417475728156]
  Epoch 12/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:11, Overall accuracy : 86.58565085300303,                Classwise_acc:[63.133640552995395, 78.08219178082192, 77.14884696016772, 94.98480243161094, 69.51456310679612]
  Epoch 13/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:12, Overall accuracy : 86.63239074550128,                Classwise_acc:[34.56221198156682, 82.1917808219178, 67.29559748427673, 97.30243161094225, 75.72815533980582]
  Epoch 14/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:13, Overall accuracy : 86.02477214302408,                Classwise_acc:[52.07373271889401, 80.82191780821918, 59.11949685534591, 97.98632218844985, 68.54368932038835]
  Epoch 15/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:14, Overall accuracy : 87.75414816545923,                Classwise_acc:[65.43778801843318, 74.88584474885845, 74.63312368972747, 97.15045592705167, 72.23300970873787]
  Epoch 16/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:15, Overall accuracy : 87.59055854171535,                Classwise_acc:[64.0552995391705, 75.57077625570776, 73.37526205450735, 96.65653495440729, 74.5631067961165]
  Epoch 17/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:16, Overall accuracy : 87.7307782192101,                Classwise_acc:[62.67281105990783, 77.85388127853882, 76.31027253668763, 96.80851063829788, 70.87378640776699]
  Epoch 18/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:17, Overall accuracy : 87.6372984342136,                Classwise_acc:[62.21198156682027, 75.79908675799086, 74.42348008385744, 96.42857142857143, 75.72815533980582]
  Epoch 19/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:18, Overall accuracy : 87.87099789670484,                Classwise_acc:[65.43778801843318, 75.34246575342466, 75.47169811320755, 96.69452887537993, 74.36893203883496]
  Epoch 20/20


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=268.0), HTML(value='')))


Fold:2, Epoch:19, Overall accuracy : 87.91773778920309,                Classwise_acc:[65.43778801843318, 75.79908675799086, 75.05241090146751, 96.69452887537993, 74.75728155339806]
For Fold 2, Best validation accuracy of 0.8791773778920309 was got at epoch 19


## Plot training results

In [23]:
def plot_training_results():
    """Plot per-fold training curves from the global ``train_results`` frame.

    Builds a two-row plotly figure and shows it:

    * row 1 - train loss (lines) and valid loss (lines + markers) per epoch;
    * row 2 - validation score per epoch.

    One group of traces is drawn for every fold that actually has rows in
    ``train_results`` (columns used: ``fold``, ``epoch``, ``train_loss``,
    ``valid_loss``, ``valid_score``). The loss traces of the first plotted
    fold are visible; those of the remaining folds start as ``legendonly``
    and can be toggled from the legend. Score traces are always visible and
    are kept out of the legend.

    Returns:
        None. The figure is rendered with ``fig.show()``.
    """
    # (dark, light) colour pair per fold: dark for train loss / valid score,
    # light for valid loss.
    colors = [
        ('#d32f2f', '#ef5350'),
        ('#303f9f', '#5c6bc0'),
        ('#00796b', '#26a69a'),
        ('#fbc02d', '#ffeb3b'),
        ('#5d4037', '#8d6e63'),
    ]

    # Panel titles are attached through make_subplots so they follow the
    # actual 2x1 grid instead of being pinned to hand-written paper
    # coordinates.
    fig = make_subplots(rows=2, cols=1,
                        subplot_titles=('Train / valid losses',
                                        'Validation scores'))

    # Iterate over the folds that were really trained. Looping over every
    # possible fold index would leave the only visible loss trace empty
    # whenever fold 0 is not among the trained folds.
    trained_folds = sorted(train_results['fold'].unique())

    for position, fold in enumerate(trained_folds):
        data = train_results[train_results['fold'] == fold]
        # Wrap around the palette so more than len(colors) folds cannot
        # raise an IndexError.
        dark, light = colors[int(fold) % len(colors)]
        loss_visibility = True if position == 0 else 'legendonly'

        fig.add_trace(go.Scatter(x=data['epoch'].values,
                                 y=data['train_loss'].values,
                                 mode='lines',
                                 visible=loss_visibility,
                                 line=dict(color=dark, width=2),
                                 name='Train loss - Fold #{}'.format(fold)),
                     row=1, col=1)

        fig.add_trace(go.Scatter(x=data['epoch'].values,
                                 y=data['valid_loss'].values,
                                 mode='lines+markers',
                                 visible=loss_visibility,
                                 line=dict(color=light, width=2),
                                 name='Valid loss - Fold #{}'.format(fold)),
                     row=1, col=1)

        fig.add_trace(go.Scatter(x=data['epoch'].values,
                                 y=data['valid_score'].values,
                                 mode='lines+markers',
                                 line=dict(color=dark, width=2),
                                 name='Valid score - Fold #{}'.format(fold),
                                 showlegend=False),
                     row=2, col=1)

    # Axis labels so each panel can be read on its own.
    fig.update_yaxes(title_text='Loss', row=1, col=1)
    fig.update_yaxes(title_text='Score', row=2, col=1)
    fig.update_xaxes(title_text='Epoch', row=2, col=1)

    fig.show()

# Load validation predictions and labels previously saved as .npy files in the
# working directory (presumably those of fold 0, judging by the "_0" suffix --
# confirm against the code that writes them).
val_preds_0 = np.load('./R18_imagenet_v2_val_preds_0.npy')
val_labels_0 = np.load('./R18_imagenet_v2_val_labels_0.npy')

# confusion_matrix takes (y_true, y_pred): rows are true labels, columns are
# predicted labels.
cm = confusion_matrix(val_labels_0, val_preds_0)
print(cm)
plt.figure(figsize=(8,8))
# plot_confusion_matrix and class_names are not defined in this cell; they must
# already exist in the notebook namespace when this runs.
plot_confusion_matrix(cm, classes=class_names, normalize=True)

In [24]:
if CFG.TRAIN == True:
    # Turn the collected per-epoch records into a frame, echo it and persist
    # it next to the notebook.
    train_results = pd.DataFrame(train_results)
    print(train_results)
    train_results.to_csv('train_results.csv', index=False)

    # Best validation score reached by each fold that was scheduled for
    # training; their mean / std is reported as the cross-validation result.
    best_folds = np.array([
        train_results.loc[train_results['fold'] == fold_id, 'valid_score'].max()
        for fold_id in CFG.FOLD_TO_TRAIN
    ])
    print(f'Overall CV accuracy : {best_folds.mean()}, std: {best_folds.std()}')

    plot_training_results()

    fold  epoch  train_loss  valid_loss  valid_score  \
0      2      0    1.165265    1.064151     0.865389   
1      2      1    1.090477    1.046992     0.869596   
2      2      2    1.076745    1.042073     0.870297   
3      2      3    1.078238    1.059497     0.866090   
4      2      4    1.079170    1.123302     0.817247   
5      2      5    1.085362    1.068543     0.850666   
6      2      6    1.080656    1.069479     0.854172   
7      2      7    1.073797    1.096788     0.823791   
8      2      8    1.069541    1.043894     0.869829   
9      2      9    1.053037    1.054913     0.864688   
10     2     10    1.046385    1.071126     0.845992   
11     2     11    1.034312    1.053369     0.865857   
12     2     12    1.023803    1.057301     0.866324   
13     2     13    1.010141    1.072457     0.860248   
14     2     14    0.991824    1.050073     0.877541   
15     2     15    0.985131    1.048692     0.875906   
16     2     16    0.966997    1.052442     0.87

## Testing function

In [25]:
if CFG.TEST == True:
    # Start from the sample submission shipped with the input data.
    sample_submission_path = DIR_INPUT + '/sample_submission.csv'
    submission_df = pd.read_csv(sample_submission_path)

    # Fill the second column with the constant 4 for every row, then write
    # the result out as the submission file.
    submission_df.iloc[:, 1] = 4
    submission_df.to_csv('submission.csv', index=False)

# Fold-ensemble inference: run every fold's checkpoint over the test images,
# average the per-class softmax outputs across folds and write the argmax
# class to submission.csv.
# NOTE(review): this block reads `pipeline`, `N_FOLDS`, `BATCH_SIZE`,
# `model_cfg` and `device`, whereas the cell above is driven by `CFG.TEST` --
# confirm these names are still defined before running it.
if pipeline["test"] == True:
    # read submission file
    submission_df = pd.read_csv(DIR_INPUT + '/sample_submission.csv')
    # reset the second column to 0; it is overwritten with predictions below
    submission_df.iloc[:, 1] = 0
    #print(submission_df.head())


    # just for debugging purpose: when the sample file has a single row,
    # replace it with a two-row frame (the same image id twice)
    if submission_df.shape[0] == 1:
        submission_df = pd.DataFrame([{'image_id': '2216849948.jpg', 'label': 0},{'image_id': '2216849948.jpg', 'label': 0}])
        submission_df.reset_index(drop=True, inplace=True)
    #print(submission_df.head())


    # Creating test dataset and dataloaders
    # shuffle=False keeps the prediction order aligned with submission_df rows
    dataset_test = CassavaDataset(df=submission_df, dataset='test', transforms=transforms_test)
    dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, num_workers=4, shuffle=False)
    
    
    # accumulator for the fold-averaged predictions (set on the first fold)
    submissions = None

    # The bare string below is an expression statement used as a block
    # comment. Note that the per-fold predictions are softmax outputs
    # (see the loop body), not literally one-hot vectors.
    """
    1. Iterate and store predictions (one-hot encoded format) of N-folds of model 
    2. Average the predictions of all folds
    3. argmax of mean one-hot encoded prediction is output
    """
    for i_fold in range(N_FOLDS):
        print(f'Inference for {i_fold}th fold')
        # build the architecture without pretrained weights; the fold's
        # checkpoint is loaded right after
        model = CassavaModel(num_classes=5, use_pretrained_weights=False)
        model.to(device)

        # checkpoint file is "<weight_path>/<model_name>_fold_<i_fold>.pth";
        # the weights are stored under the 'model_state_dict' key
        checkpoint = torch.load(f"{model_cfg['weight_path']}/{model_cfg['model_name']}_fold_{i_fold}.pth", map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        model.eval()
        # predictions of this fold, concatenated batch by batch along dim 0
        test_preds = None

        # the second item yielded by the loader is ignored at test time
        for step, (images, _) in enumerate(dataloader_test):
            images = images.to(device, dtype=torch.float)
            with torch.no_grad():
                outputs = model(images)
                # softmax over dim 1 of the model output, moved to the CPU
                preds = torch.softmax(outputs, dim=1).data.cpu()
                if test_preds is None:
                    test_preds = preds
                else:
                    test_preds = torch.cat((test_preds, preds), dim=0)

        # submission_df[['label']] = test_preds.argmax(test_preds, dim=1)
        # submission_df.to_csv('submission_fold_{}.csv'.format(i_fold), index=False)

        # average of the softmax outputs over folds: each fold contributes
        # test_preds / N_FOLDS to the running sum
        if submissions is None:
            submissions = test_preds / N_FOLDS
        else:
            submissions += test_preds / N_FOLDS
            
        
    #print(submissions[:10])
    # argmax of predictions and write to csv
    submission_df['label'] = torch.argmax(submissions, dim=1)
    submission_df.to_csv('submission.csv', index=False)
    #print(submission_df.head())