## Library imports

In [1]:
# Make the offline package sources importable: `timm` and `fmix` are imported
# further down from these input directories.
import sys
append_paths = [
    '../input/pytorch-image-models/pytorch-image-models-master',
    '../input/image-fmix/FMix-master',
]
sys.path.extend(append_paths)

# basic imports
import os
import numpy as np
import pandas as pd
import random
import itertools
from tqdm.notebook import tqdm
import math

# augmentations library
from albumentations.pytorch import ToTensorV2
from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightnessContrast,ShiftScaleRotate, Cutout, CoarseDropout, 
    IAAAdditiveGaussianNoise, Transpose, MotionBlur, MedianBlur, GaussianBlur, HueSaturationValue
    )
import albumentations as A
from fmix import sample_mask
import cv2

# DL library imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from  torch.cuda.amp import autocast, GradScaler

# timm import
import timm

# metrics calculation
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import KFold, StratifiedKFold

# basic plotting library
import matplotlib.pyplot as plt

# interactive plots
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import warnings  
# Silence every warning for the rest of the session to keep notebook output clean.
# NOTE(review): this also hides library deprecation warnings; consider narrowing
# the filter (by category/module) so API removals are noticed before they break.
warnings.filterwarnings('ignore')

## Config params

In [2]:
class CFG:
    """Central configuration for this notebook.

    Settings are plain class attributes, read as ``CFG.<NAME>`` by the cells below.
    """
    # pipeline parameters
    SEED        = 42
    NUM_CLASSES = 5
    TGT_LABEL   = 'label'          # target column name in train.csv
    TRAIN       = True
    LR_FIND     = False            # run the learning-rate range test cell
    RETRAIN     = False            # start from saved weights at WGT_PATH/WGT_MODEL
    TEST        = False
    DEBUG       = False            # enables the quick sanity-check cells
    N_FOLDS     = 5 
    N_EPOCHS    = 27 
    DF_FRAC     = 1                # fraction of train.csv to use; < 1 subsamples
    TEST_BATCH_SIZE  = 32
    TRAIN_BATCH_SIZE = 16
    SIZE             = [448, 448]  # (height, width) used by the resize/crop transforms
    NUM_WORKERS      = 4
    FOLD_TO_TRAIN    = [0] # folds to train; add 1, 2, 3, 4 for the full CV run

    # model parameters
    MODEL_ARCH  = 'tf_efficientnet_b4_ns'   # timm architecture name
    MODEL_NAME  = 'eff_b4_v5'               # prefix for output artefacts
    WGT_PATH    = ''                        # directory of checkpoints to resume from
    WGT_MODEL   = ''                        # checkpoint stem: <WGT_MODEL>_fold<i>.pth
    
    # probability of applying FMix to a training batch
    MIX_PROB    = 0.25
    
    # loss fn parameters
    LOSS_FN     = 'CrossEntropyLoss' # alternatives: 'BiTemperedLogisticLoss', 'LabelSmoothingCrossEntropy'
    
    # LabelSmoothingCrossEntropy param
    SMOOTHING   = 0.3
    
    # BiTemperedLogisticLoss params
    LABEL_SMOOTH = 0.2
    T1 = 0.8   # temperature 1 (< 1.0 bounds the loss)
    T2 = 1.4   # temperature 2 (> 1.0 gives heavier tails)

    # scheduler variables
    MAX_LR    = 1e-3
    MIN_LR    = 1e-6
    SCHEDULER = 'CosineAnnealingLR'  # alternatives: 'ReduceLROnPlateau', 'OneCycleLR', 'CosineAnnealingWarmRestarts'
    T_0       = 10   # CosineAnnealingWarmRestarts
    T_MAX     = 2.5    # CosineAnnealingLR

    # optimizer variables
    OPTIMIZER     = 'Adam'
    WEIGHT_DECAY  = 1e-6
    GRD_ACC_STEPS = 1
    MAX_GRD_NORM  = 1000
    


# Competition input locations.
DIR_INPUT = '../input/cassava-leaf-disease-classification'
TRAIN_PATH = f'{DIR_INPUT}/train_images'
TEST_PATH = f'{DIR_INPUT}/test_images'
# Training images pre-converted to .npy arrays (loaded by TrainDataset).
NPY_FOLDER = '../input/cassava-npy-train-images/train_npy_images'

# Integer label -> human-readable class name.
index_label_map = {
    0: "Cassava Bacterial Blight (CBB)",
    1: "Cassava Brown Streak Disease (CBSD)",
    2: "Cassava Green Mottle (CGM)",
    3: "Cassava Mosaic Disease (CMD)",
    4: "Healthy",
}

# Class names in label-index order (dict insertion order 0..4).
class_names = list(index_label_map.values())

## Helper functions

In [3]:
def find_no_of_trainable_params(model):
    """Return how many scalar parameters of ``model`` have ``requires_grad`` set."""
    count = 0
    for param in model.parameters():
        if param.requires_grad:
            count += param.numel()
    return count

In [4]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Also exports ``PYTHONHASHSEED``; note that this only affects Python
    interpreters started afterwards, not the already-running one.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True

# Seed all RNGs once, up front, so subsequent random operations are reproducible.
set_seed(CFG.SEED)

In [5]:
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw ``cm`` as an annotated heat-map on the current matplotlib figure.

    Args:
        cm: confusion-matrix array; rows are true labels, columns predictions
            (e.g. the output of ``sklearn.metrics.confusion_matrix``).
        classes: tick labels, one per row/column.
        normalize: if True, divide each row by its total so cells show the
            fraction of each true class (per-class recall).
        title: figure title.
        cmap: matplotlib colormap for the heat-map.
    """
    if not normalize:
        print('Confusion matrix, without normalization')
    else:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        cm = cm.astype('float') / row_totals
        print("Normalized confusion matrix")

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    positions = np.arange(len(classes))
    plt.xticks(positions, classes, rotation=45)
    plt.yticks(positions, classes)

    # Annotate each cell; switch text colour on dark cells for contrast.
    cell_format = '.2f' if normalize else 'd'
    midpoint = cm.max() / 2.
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            value = cm[row, col]
            plt.text(col, row, format(value, cell_format),
                     horizontalalignment="center",
                     color="white" if value > midpoint else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

In [6]:
class AverageMeter(object):
    """Track the latest value and the running, sample-weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running mean, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` new samples and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

## Dataset 

In [7]:
train_df = pd.read_csv(f'{DIR_INPUT}/train.csv')
# Pre-decoded images live in NPY_FOLDER as <stem>.npy. Only the trailing ".jpg"
# extension is swapped, so ids that contain "jpg" elsewhere stay intact.
train_df['npy_image_id'] = train_df['image_id'].str.replace(r'\.jpg$', '.npy', regex=True)
if CFG.DF_FRAC < 1:
    # Subsample for quick experiments; seeded so the subset is reproducible.
    train_df = train_df.sample(frac=CFG.DF_FRAC, random_state=CFG.SEED).reset_index(drop=True)
# Select the target by name rather than by column position.
train_labels = train_df[CFG.TGT_LABEL].values
print(train_df.shape)
folds = StratifiedKFold(n_splits=CFG.N_FOLDS, shuffle=True, random_state=CFG.SEED)

if CFG.DEBUG:
    # Sanity check: each fold should hold a similar class distribution.
    fold_check = train_df[[CFG.TGT_LABEL]].copy()
    fold_check['fold'] = -1
    for n, (_, val_index) in enumerate(folds.split(fold_check, fold_check[CFG.TGT_LABEL])):
        fold_check.loc[val_index, 'fold'] = n
    print(fold_check.groupby(['fold', CFG.TGT_LABEL]).size())

(21397, 3)


In [8]:
class TrainDataset(Dataset):
    """Training dataset that loads pre-decoded ``.npy`` images with integer labels.

    Args:
        df: frame with an ``npy_image_id`` column (file names inside ``NPY_FOLDER``)
            and the target column ``CFG.TGT_LABEL``.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=...)`` and returning a dict with an ``'image'`` key.

    Each item is ``(image, label)`` with ``label`` as a ``torch.long`` scalar tensor.
    """
    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['npy_image_id'].values
        self.labels = df[CFG.TGT_LABEL].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        path = f'{NPY_FOLDER}/{self.file_names[idx]}'
        image = np.load(path)
        if self.transform:
            image = self.transform(image=image)['image']
        return image, torch.tensor(self.labels[idx]).long()
    

class TestDataset(Dataset):
    """Inference dataset that decodes JPEGs from ``TEST_PATH`` on the fly.

    Args:
        df: frame with an ``image_id`` column of file names inside ``TEST_PATH``.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=...)`` and returning a dict with an ``'image'`` key.

    Each item is the (optionally transformed) RGB image.

    Raises:
        FileNotFoundError: from ``__getitem__`` when the image cannot be read.
    """
    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['image_id'].values
        self.transform = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        file_path = f'{TEST_PATH}/{file_name}'
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None instead
            # of raising; fail here with a clear message rather than inside cvtColor.
            raise FileNotFoundError(f'Could not read test image: {file_path}')
        # OpenCV decodes as BGR; the model pipeline expects RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image

In [9]:
if CFG.DEBUG:
    # Visual check of one raw (untransformed) training sample.
    train_dataset = TrainDataset(train_df, transform=None)
    image, label = train_dataset[0]
    plt.imshow(image)
    # .item() so the title shows the class index, not the tensor repr.
    plt.title(f'label: {label.item()}')
    plt.show()

## Transforms for augmentations

In [10]:
def rand_bbox(size, lam):
    """Sample a random box covering roughly ``1 - lam`` of an image's area (CutMix).

    Args:
        size: batch shape ``(N, C, H, W)``, e.g. ``tensor.size()``.
        lam: mixing ratio in [0, 1]; each box side is scaled by ``sqrt(1 - lam)``.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``. The x bounds index the width axis (dim 3) and
        the y bounds index the height axis (dim 2). Both are clipped to the image,
        so boxes near a border can be smaller than requested.
    """
    # Height is dim 2 and width is dim 3 in NCHW; callers slice [..., y1:y2, x1:x2].
    H = size[2]
    W = size[3]
    cut_rat = np.sqrt(1. - lam)
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # box centre sampled uniformly over the image
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix(data, target, alpha, lam_min=0.3, lam_max=0.4):
    """CutMix: paste a random box from a shuffled copy of the batch into each image.

    Args:
        data: image batch ``(N, C, H, W)``.
        target: labels for ``data`` (indexable by a permutation tensor).
        alpha: concentration of the Beta(alpha, alpha) prior for the mixing ratio.
        lam_min, lam_max: bounds applied to the sampled ratio before the box is
            drawn. The defaults reproduce the previously hard-coded 0.3-0.4 range.

    Returns:
        ``(mixed_data, (target, shuffled_target, lam))``. ``lam`` is the fraction
        of each image that still comes from its original sample.
    """
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_target = target[indices]

    lam = np.clip(np.random.beta(alpha, alpha), lam_min, lam_max)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
    new_data = data.clone()
    new_data[:, :, bby1:bby2, bbx1:bbx2] = shuffled_data[:, :, bby1:bby2, bbx1:bbx2]
    # Recompute lam from the actual (possibly border-clipped) box area.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))
    return new_data, (target, shuffled_target, lam)

def fmix(data, targets, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """FMix: blend each image with a shuffled partner through a sampled mask.

    Args:
        data: image batch tensor (first dimension is the batch).
        targets: labels for ``data`` (indexable by a permutation tensor).
        alpha, decay_power, shape, max_soft, reformulate: forwarded unchanged to
            ``sample_mask``.

    Returns:
        ``(mixed_data, (targets, shuffled_targets, lam))``. ``lam`` is the ratio
        returned by ``sample_mask`` together with the mask.
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
    # Put the mask on the batch's own device and dtype. This removes the dependency
    # on a global `device` and prevents a higher-precision NumPy mask from
    # upcasting the images.
    mask = torch.from_numpy(mask).to(device=data.device, dtype=data.dtype)
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]
    mixed_data = mask * data + (1 - mask) * shuffled_data
    return mixed_data, (targets, shuffled_targets, lam)

In [11]:
def generate_transforms():
    """Build the albumentations pipelines for training, validation and test.

    Returns:
        dict with keys ``'train_transforms'``, ``'val_transforms'`` and
        ``'test_transform'`` (singular: callers index it by that exact key).
        Every pipeline ends with ImageNet mean/std normalisation and
        ``ToTensorV2``. Validation and test only resize to ``CFG.SIZE``.
    """
    height, width = CFG.SIZE[0], CFG.SIZE[1]

    def imagenet_tail():
        # Fresh Normalize/ToTensorV2 objects for each pipeline.
        return [
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ]

    train_transforms = Compose([
        RandomResizedCrop(height, width),
        Resize(height=height, width=width),
        Transpose(p=0.3),
        VerticalFlip(p=0.3),
        HorizontalFlip(p=0.3),
        ShiftScaleRotate(p=0.4),
        RandomBrightnessContrast(p=0.4),
        IAAAdditiveGaussianNoise(p=0.3),
        OneOf([MotionBlur(blur_limit=3), MedianBlur(blur_limit=3), GaussianBlur(blur_limit=3)], p=0.3),
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.3),
        CoarseDropout(p=0.4),
        Cutout(p=0.4),
        *imagenet_tail(),
    ])
    val_transforms = Compose([Resize(height=height, width=width), *imagenet_tail()])
    test_transforms = Compose([Resize(height=height, width=width), *imagenet_tail()])

    return {'train_transforms': train_transforms, 'val_transforms': val_transforms, 'test_transform': test_transforms}

In [12]:
if CFG.DEBUG:
    # Visual check of one augmented sample. After ToTensorV2 the image is
    # channel-first, so image[0] displays only its first (normalised) channel.
    train_dataset = TrainDataset(train_df, transform=generate_transforms()['train_transforms'])
    image, label = train_dataset[0]
    plt.imshow(image[0])
    # .item() so the title shows the class index, not the tensor repr.
    plt.title(f'label: {label.item()}')
    plt.show()

## Model class

In [13]:
# timm backbone with a task-specific linear head
class EfficientnetClassifier(nn.Module):
    """Wrap a timm model and swap its ``classifier`` layer for a fresh linear head.

    Args:
        model_arch: timm architecture name passed to ``timm.create_model``. The
            model must expose a ``classifier`` attribute with ``in_features``.
        n_class: number of output logits.
        pretrained: whether timm should load its pretrained weights.
    """
    def __init__(self, model_arch, n_class=CFG.NUM_CLASSES, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_arch, pretrained=pretrained)
        backbone.classifier = nn.Linear(backbone.classifier.in_features, n_class)
        self.model = backbone

    def forward(self, x):
        return self.model(x)

In [14]:
if CFG.DEBUG:
    # Smoke test: push one augmented mini-batch through an untrained model on the CPU.
    model = EfficientnetClassifier(model_arch=CFG.MODEL_ARCH, pretrained=False)
    train_dataset = TrainDataset(train_df, transform=generate_transforms()['train_transforms'])
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,
                              num_workers=CFG.NUM_WORKERS, pin_memory=True, drop_last=True)
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        image, label = first_batch
        output = model(image)
        print(output)

## Loss function

In [15]:
# Code taken from https://github.com/fhopfmueller/bi-tempered-loss-pytorch/blob/master/bi_tempered_loss_pytorch.py
def log_t(u, t):
    """Tempered logarithm of tensor ``u``.

    Equals ``u.log()`` when ``t == 1``; otherwise ``(u**(1 - t) - 1) / (1 - t)``.
    """
    if t != 1.0:
        return (u.pow(1.0 - t) - 1.0) / (1.0 - t)
    return u.log()

def exp_t(u, t):
    """Tempered exponential of tensor ``u``.

    Equals ``u.exp()`` when ``t == 1``. Otherwise it is
    ``relu(1 + (1 - t) * u) ** (1 / (1 - t))``, which is the inverse of ``log_t``
    wherever the relu is inactive.
    """
    if t != 1:
        base = (1.0 + (1.0 - t) * u).relu()
        return base.pow(1.0 / (1.0 - t))
    return u.exp()

def compute_normalization_fixed_point(activations, t, num_iters):

    """Returns the normalization value for each example (t > 1.0).

    Uses fixed-point iteration to approximate the constant that makes
    ``exp_t(activations - constant, t)`` sum to 1 along the last dimension.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same shape as activation with the last dimension being 1.
    """
    # Shift by the per-example max for numerical stability; mu is added back at the end.
    mu, _ = torch.max(activations, -1, keepdim=True)
    normalized_activations_step_0 = activations - mu

    normalized_activations = normalized_activations_step_0

    for _ in range(num_iters):
        # Partition sum under the current scaling, then rescale the shifted
        # activations by partition**(1 - t) for the next iterate.
        logt_partition = torch.sum(
                exp_t(normalized_activations, t), -1, keepdim=True)
        normalized_activations = normalized_activations_step_0 * \
                logt_partition.pow(1.0-t)

    logt_partition = torch.sum(
            exp_t(normalized_activations, t), -1, keepdim=True)
    # Convert the final partition sum into an additive constant in activation space.
    normalization_constants = - log_t(1.0 / logt_partition, t) + mu

    return normalization_constants

def compute_normalization_binary_search(activations, t, num_iters):

    """Returns the normalization value for each example (t < 1.0).

    Bisects for the constant at which ``exp_t(activations - constant, t)`` sums to 1
    along the last dimension.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (< 1.0 for finite support).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """

    # Shift by the per-example max; mu is added back on return.
    mu, _ = torch.max(activations, -1, keepdim=True)
    normalized_activations = activations - mu

    # Number of classes inside exp_t's finite support (a > -1/(1-t)); used to
    # set the upper end of the search bracket.
    effective_dim = \
        torch.sum(
                (normalized_activations > -1.0 / (1.0-t)).to(torch.int32),
            dim=-1, keepdim=True).to(activations.dtype)

    # Initial bracket [lower, upper] for the (shifted) normalisation constant.
    shape_partition = activations.shape[:-1] + (1,)
    lower = torch.zeros(shape_partition, dtype=activations.dtype, device=activations.device)
    upper = -log_t(1.0/effective_dim, t) * torch.ones_like(lower)

    for _ in range(num_iters):
        logt_partition = (upper + lower)/2.0
        sum_probs = torch.sum(
                exp_t(normalized_activations - logt_partition, t),
                dim=-1, keepdim=True)
        # If the probabilities sum to < 1 the midpoint is too large, so move `upper`
        # down to it; otherwise move `lower` up to it.
        update = (sum_probs < 1.0).to(activations.dtype)
        lower = torch.reshape(
                lower * update + (1.0-update) * logt_partition,
                shape_partition)
        upper = torch.reshape(
                upper * (1.0 - update) + update * logt_partition,
                shape_partition)

    logt_partition = (upper + lower)/2.0
    return logt_partition + mu

class ComputeNormalization(torch.autograd.Function):
    """
    Class implementing custom backward pass for compute_normalization. See compute_normalization.

    Forward dispatches on ``t``: bisection when t < 1, fixed-point iteration otherwise.
    Backward uses a closed-form gradient (the escort distribution) instead of
    differentiating through the iterative solver.
    """
    @staticmethod
    def forward(ctx, activations, t, num_iters):
        if t < 1.0:
            normalization_constants = compute_normalization_binary_search(activations, t, num_iters)
        else:
            normalization_constants = compute_normalization_fixed_point(activations, t, num_iters)

        # Keep the inputs and outputs needed by backward; t is stored on ctx
        # directly rather than saved as a tensor.
        ctx.save_for_backward(activations, normalization_constants)
        ctx.t=t
        return normalization_constants

    @staticmethod
    def backward(ctx, grad_output):
        activations, normalization_constants = ctx.saved_tensors
        t = ctx.t
        # Gradient w.r.t. activations is the escort distribution p**t / sum(p**t),
        # with p = exp_t(activations - normalization_constants, t).
        normalized_activations = activations - normalization_constants 
        probabilities = exp_t(normalized_activations, t)
        escorts = probabilities.pow(t)
        escorts = escorts / escorts.sum(dim=-1, keepdim=True)
        grad_input = escorts * grad_output
        
        # t and num_iters receive no gradient.
        return grad_input, None, None

def compute_normalization(activations, t, num_iters=5):
    """Per-example normalisation constant for the tempered softmax (custom backward).

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      num_iters: Number of iterations for the underlying solver.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """
    normalize = ComputeNormalization.apply
    return normalize(activations, t, num_iters)

def tempered_sigmoid(activations, t, num_iters = 5):
    """Tempered sigmoid: tempered softmax over the pair (activation, 0).

    Args:
      activations: Activations for the positive class for binary classification.
      t: Temperature tensor > 0.0.
      num_iters: Number of iterations to run the method.
    Returns:
      Tensor shaped like `activations` holding the positive-class probability.
    """
    zeros = torch.zeros_like(activations)
    pair = torch.stack([activations, zeros], dim=-1)
    return tempered_softmax(pair, t, num_iters)[..., 0]


def tempered_softmax(activations, t, num_iters=5):
    """Tempered softmax over the last dimension of `activations`.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature. t == 1.0 gives the ordinary softmax, t > 1.0 gives heavier
        tails, and t < 1.0 gives finite support.
      num_iters: Iterations for the normalisation solver (unused when t == 1.0).
    Returns:
      A probabilities tensor with the same shape as `activations`.
    """
    if t != 1.0:
        shift = compute_normalization(activations, t, num_iters)
        return exp_t(activations - shift, t)
    return activations.softmax(dim=-1)

def bi_tempered_binary_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing = 0.0,
        num_iters=5,
        reduction='mean'):
    """Bi-Tempered binary logistic loss.

    The binary problem is expressed as two classes, logits (activation, 0) and
    targets (label, 1 - label), and passed to `bi_tempered_logistic_loss`.

    Args:
      activations: A tensor containing activations for class 1.
      labels: A tensor with shape as activations, containing probabilities for class 1
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing
      num_iters: Number of iterations to run the method.
      reduction: forwarded to `bi_tempered_logistic_loss`.
    Returns:
      A loss tensor.
    """
    dtype = activations.dtype
    two_class_activations = torch.stack([activations, torch.zeros_like(activations)], dim=-1)
    positive = labels.to(dtype)
    two_class_labels = torch.stack([positive, 1.0 - positive], dim=-1)
    return bi_tempered_logistic_loss(two_class_activations, two_class_labels, t1, t2,
                                     label_smoothing=label_smoothing,
                                     num_iters=num_iters,
                                     reduction=reduction)

def bi_tempered_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing=0.0,
        num_iters=5,
        reduction = 'mean'):

    """Bi-Tempered Logistic Loss.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: A tensor with shape and dtype as activations (onehot), 
        or a long tensor of one dimension less than activations (pytorch standard)
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1). Default 0.0.
        The true class keeps 1 - label_smoothing; the rest is spread evenly
        over the other classes.
      num_iters: Number of iterations to run the method. Default 5.
      reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``.
        ``'none'``: No reduction is applied, return shape is shape of
        activations without the last dimension.
        ``'mean'``: Loss is averaged over minibatch. Return shape (1,)
        ``'sum'``: Loss is summed over minibatch. Return shape (1,)
    Returns:
      A loss tensor.
    Raises:
      ValueError: if `reduction` is not one of the values above.
    """

    if labels.dim() < activations.dim():  # class indices, not one-hot
        labels_onehot = torch.zeros_like(activations)
        # Scatter along the class (last) dimension so inputs with extra leading
        # dimensions are handled as well as plain (batch, classes).
        labels_onehot.scatter_(-1, labels.unsqueeze(-1), 1)
    else:
        labels_onehot = labels

    if label_smoothing > 0:
        num_classes = labels_onehot.shape[-1]
        labels_onehot = ( 1 - label_smoothing * num_classes / (num_classes - 1) ) \
                * labels_onehot + \
                label_smoothing / (num_classes - 1)

    probabilities = tempered_softmax(activations, t2, num_iters)

    # Tempered Bregman divergence between targets and predictions; the 1e-10
    # keeps log_t finite for zero targets.
    loss_values = labels_onehot * log_t(labels_onehot + 1e-10, t1) \
            - labels_onehot * log_t(probabilities, t1) \
            - labels_onehot.pow(2.0 - t1) / (2.0 - t1) \
            + probabilities.pow(2.0 - t1) / (2.0 - t1)
    loss_values = loss_values.sum(dim = -1) #sum over classes

    if reduction == 'none':
        return loss_values
    if reduction == 'sum':
        return loss_values.sum()
    if reduction == 'mean':
        return loss_values.mean()
    # Previously fell through and returned None, which only failed later in backward().
    raise ValueError(f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")

In [16]:
class LabelSmoothingCrossEntropy(nn.Module):
    """
    NLL loss with label smoothing.

    Per sample: (1 - smoothing) * NLL of the target class plus smoothing times the
    mean negative log-probability over all classes. The result is averaged over
    the batch.
    """
    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor (must be < 1.0)
        """
        super().__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        log_probs = F.log_softmax(x, dim=-1)
        target_nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        uniform_nll = -log_probs.mean(dim=-1)
        per_sample = self.confidence * target_nll + self.smoothing * uniform_nll
        return per_sample.mean()

In [17]:
class BiTemperedLogisticLoss(nn.Module):
    """``nn.Module`` wrapper around ``bi_tempered_logistic_loss`` with fixed settings.

    The defaults are read from ``CFG.T1``, ``CFG.T2`` and ``CFG.LABEL_SMOOTH`` once,
    when the class is defined. The loss uses the function's default reduction.
    """
    def __init__(self, t1=CFG.T1, t2=CFG.T2, label_smoothing=CFG.LABEL_SMOOTH):
        super().__init__()
        assert label_smoothing < 1.0
        self.t1, self.t2 = t1, t2
        self.label_smoothing = label_smoothing

    def forward(self, preds, labels):
        return bi_tempered_logistic_loss(
            preds, labels,
            t1=self.t1, t2=self.t2,
            label_smoothing=self.label_smoothing,
        )

In [18]:
# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Select the training loss by name. An unknown CFG.LOSS_FN fails fast instead
# of silently falling back to plain cross-entropy.
if CFG.LOSS_FN == 'LabelSmoothingCrossEntropy':
    criterion = LabelSmoothingCrossEntropy(smoothing=CFG.SMOOTHING)
elif CFG.LOSS_FN == 'BiTemperedLogisticLoss':
    criterion = BiTemperedLogisticLoss()
elif CFG.LOSS_FN == 'CrossEntropyLoss':
    criterion = nn.CrossEntropyLoss()
else:
    raise ValueError(f'Unknown CFG.LOSS_FN: {CFG.LOSS_FN!r}')

cuda:0


## Learning-rate finder (LR range test)

In [19]:
def plot_lr_finder_results(lr_finder):
    """Write a two-panel plotly figure of LR-finder results to ``<CFG.MODEL_NAME>_lr_find.html``.

    The left panel plots the smoothed loss and the right panel its gradient, both
    against ``log_lr``.

    Args:
        lr_finder: dict with ``'log_lr'``, ``'smooth_loss'`` and ``'grad_loss'``
            entries, as returned by ``find_lr``.
    """
    fig = make_subplots(rows=1, cols=2)
    panels = [
        (lr_finder['smooth_loss'], 'log_lr vs smooth_loss'),
        (lr_finder['grad_loss'], 'log_lr vs loss gradient'),
    ]
    for col, (y_values, label) in enumerate(panels, start=1):
        fig.add_trace(go.Scatter(x=lr_finder['log_lr'], y=y_values, name=label), row=1, col=col)
    fig.write_html(CFG.MODEL_NAME + '_lr_find.html')

In [20]:
def _lr_finder_results(log_lrs, raw_losses, smooth_losses):
    """Package the LR-finder histories into the dict consumed by plot_lr_finder_results."""
    # np.gradient needs at least two points; use 0.0 otherwise.
    grad_loss = np.gradient(smooth_losses) if len(smooth_losses) > 1 else 0.0
    return {'log_lr': log_lrs, 'raw_loss': raw_losses,
            'smooth_loss': smooth_losses, 'grad_loss': grad_loss}


def find_lr(model, optimizer, data_loader, init_value = 1e-8, final_value=100.0, beta = 0.98, num_batches = 200):
    """Learning-rate range test: grow the LR geometrically and record the loss.

    The first param group's LR rises from ``init_value`` towards ``final_value`` over
    ``num_batches`` optimisation steps. The loader is restarted if it runs out.
    The test stops early if the bias-corrected smoothed loss exceeds 50x the best
    seen so far, or if the forward pass / loss computation raises.

    Uses the notebook-level ``device`` and ``criterion``.

    Returns:
        dict with ``'log_lr'`` (log10 of the LR at each recorded step),
        ``'raw_loss'``, ``'smooth_loss'`` and ``'grad_loss'`` (np.gradient of the
        smoothed loss, or 0.0 when fewer than two points were recorded).
    """
    assert num_batches > 0
    mult = (final_value / init_value) ** (1 / num_batches)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.0
    best_loss = 0.0
    smooth_losses, raw_losses, log_lrs = [], [], []
    dataloader_it = iter(data_loader)
    progress_bar = tqdm(range(num_batches))

    for batch_num, _ in enumerate(progress_bar, start=1):
        try:
            images, labels = next(dataloader_it)
        except StopIteration:
            # Loader exhausted before num_batches: start another pass.
            dataloader_it = iter(data_loader)
            images, labels = next(dataloader_it)

        images = images.to(device)
        labels = labels.to(device)

        try:
            y_preds = model(images.float())
            loss = criterion(y_preds, labels)
        except Exception as exc:
            # Best effort: return the history gathered so far instead of crashing.
            # Unlike a bare `except`, this lets KeyboardInterrupt propagate.
            print(f'find_lr stopped early at lr={lr}: {exc!r}')
            return _lr_finder_results(log_lrs, raw_losses, smooth_losses)

        loss_value = loss.item()
        # Exponential moving average with bias correction for the first steps.
        avg_loss = beta * avg_loss + (1 - beta) * loss_value
        smoothed_loss = avg_loss / (1 - beta ** batch_num)

        # Stop if the loss is exploding.
        if batch_num > 1 and smoothed_loss > 50 * best_loss:
            return _lr_finder_results(log_lrs, raw_losses, smooth_losses)

        if smoothed_loss < best_loss or batch_num == 1:
            best_loss = smoothed_loss

        raw_losses.append(loss_value)
        smooth_losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress_bar.set_description(f"loss: {loss_value},smoothed_loss: {smoothed_loss},lr : {lr}")

        # Geometric LR increase for the next step.
        lr *= mult
        optimizer.param_groups[0]['lr'] = lr

    return _lr_finder_results(log_lrs, raw_losses, smooth_losses)

In [21]:
if CFG.LR_FIND:
    # Run the LR range test over the training set with the training augmentations.
    temp_train_dataset = TrainDataset(train_df, transform=generate_transforms()['train_transforms'])
    temp_train_dataloader = DataLoader(temp_train_dataset, batch_size= CFG.TRAIN_BATCH_SIZE, shuffle=True,
                          num_workers=CFG.NUM_WORKERS, pin_memory=False, drop_last=False)

    # Start from the saved fold-0 checkpoint when retraining, otherwise from
    # timm's pretrained weights. Both branches use the same model class.
    if CFG.RETRAIN:
        i_fold = 0
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        checkpoint = torch.load(f'{CFG.WGT_PATH}/{CFG.WGT_MODEL}_fold{i_fold}.pth', map_location=device)
        model = EfficientnetClassifier(model_arch=CFG.MODEL_ARCH, pretrained=False)
        model.to(device)
        model.load_state_dict(checkpoint['model'])
        print(f'Model loaded for {CFG.WGT_MODEL}_fold{i_fold}')
    else:
        model = EfficientnetClassifier(model_arch=CFG.MODEL_ARCH, pretrained=True)
        model.to(device)
    optimizer = optim.Adam(model.parameters(), weight_decay=CFG.WEIGHT_DECAY, lr=CFG.MAX_LR)
    lr_finder_results = find_lr(model, optimizer, temp_train_dataloader)
    plot_lr_finder_results(lr_finder_results)

## Single-fold training and validation

In [22]:
def _train_one_epoch(model, optimizer, scheduler, scaler, dataloader_train, lr_list):
    """Run one training epoch and return the summed (not averaged) batch loss.

    Per batch: optional FMix (with probability ``CFG.MIX_PROB``), AMP forward pass and loss,
    scaled backward, scaler/optimizer step, then one ``scheduler.step()``. The learning rate
    after each step is appended to ``lr_list`` (mutated in place).
    """
    model.train()
    tr_loss = 0.0
    train_progress_bar = tqdm(dataloader_train, total=len(dataloader_train))

    for idx, (images, labels) in enumerate(train_progress_bar):
        images = images.to(device)
        labels = labels.to(device)

        mix_decision = np.random.rand()
        if mix_decision < CFG.MIX_PROB:
            # After fmix, `labels` is indexed below as (targets_a, targets_b, lam);
            # presumably lam is the mixing weight of targets_a -- confirm against fmix.
            images, labels = fmix(images, labels, alpha=1., decay_power=5., shape=(CFG.SIZE[0], CFG.SIZE[1]))

        # Only the forward pass and loss run under autocast. The backward pass is done
        # outside the context, as recommended by the PyTorch AMP documentation.
        with autocast():
            y_preds = model(images.float())
            if mix_decision < CFG.MIX_PROB:
                loss = criterion(y_preds, labels[0]) * labels[2] + criterion(y_preds, labels[1]) * (1.0 - labels[2])
            else:
                loss = criterion(y_preds, labels)

        scaler.scale(loss).backward()
        tr_loss += loss.item()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # Stepped once per batch (OneCycleLR / CosineAnnealingLR sized in batches).
        # NOTE: CosineAnnealingWarmRestarts was meant to use a fractional-epoch step,
        # scheduler.step(epoch + idx / len(dataloader_train)), which this helper does not do.
        scheduler.step()

        # Read the live param group; optimizer.state_dict() would build a full copy every batch.
        lr_list.append(optimizer.param_groups[0]['lr'])
        train_progress_bar.set_description(f"Train_loss: {tr_loss} loss(avg): {tr_loss/(idx+1)}")

    return tr_loss


def _validate_one_epoch(model, dataloader_valid):
    """Evaluate ``model`` on ``dataloader_valid`` without gradients.

    Returns:
        (val_loss, val_preds, val_labels): the summed batch loss, predicted class ids
        (argmax of the softmax) as a numpy array, and ground-truth labels as a numpy array,
        both in loader order.
    """
    model.eval()
    val_loss = 0.0
    prob_batches = []
    label_batches = []
    valid_progress_bar = tqdm(dataloader_valid, total=len(dataloader_valid))

    with torch.no_grad():
        for idx, (images, labels) in enumerate(valid_progress_bar):
            images = images.to(device)
            labels = labels.to(device)

            y_preds = model(images)
            loss = criterion(y_preds, labels)
            val_loss += loss.item()

            # Collect per-batch results and concatenate once after the loop; calling torch.cat
            # inside the loop re-copies everything accumulated so far on every batch.
            prob_batches.append(torch.softmax(y_preds, dim=1).cpu())
            label_batches.append(labels.cpu())

            valid_progress_bar.set_description(f"val_loss: {val_loss} loss(avg): {val_loss/(idx+1)}")

    val_preds = np.argmax(torch.cat(prob_batches).numpy(), axis=1)
    val_labels = torch.cat(label_batches).numpy()
    return val_loss, val_preds, val_labels


def _class_wise_accuracy(val_labels, val_preds, num_classes):
    """Return per-class recall in percent (confusion-matrix diagonal / row sum).

    ``labels=range(num_classes)`` pins the matrix to ``num_classes`` rows, so position i
    always means class i even when a class is missing from this fold. A class with no
    ground-truth samples gets NaN instead of a 0/0 division.
    """
    cm = confusion_matrix(val_labels, val_preds, labels=list(range(num_classes)))
    class_wise_acc = []
    for i, row in enumerate(cm):
        n_true = row.sum()
        class_wise_acc.append(row[i] / n_true * 100 if n_true else float('nan'))
    return class_wise_acc


def train_one_fold(i_fold, model, optimizer, scheduler, scaler, dataloader_train, dataloader_valid):
    """Train and validate ``model`` on one CV fold for ``CFG.N_EPOCHS`` epochs.

    After every epoch the model is validated. Whenever validation accuracy improves, the
    weights and that epoch's val_preds/val_labels are saved to
    ``{CFG.MODEL_NAME}_fold_{i_fold}_epoch{epoch}.pth``. At the end, the per-batch learning
    rates are saved to ``{CFG.MODEL_NAME}_fold{i_fold}_LRlist.npy``.

    Uses notebook globals defined elsewhere: ``CFG``, ``criterion``, ``device``, ``fmix``.

    Returns:
        A list of dicts, one per epoch, with keys 'fold', 'epoch', 'train_loss' and
        'valid_loss' (mean per batch), 'valid_score' (accuracy as a fraction) and
        'class_wise_acc' (a list of per-class percentages).
    """
    train_fold_results = []
    lr_list = []
    best_val_acc = 0.0
    best_epoch = 0

    for epoch in range(CFG.N_EPOCHS):
        print('  Epoch {}/{}'.format(epoch + 1, CFG.N_EPOCHS))

        tr_loss = _train_one_epoch(model, optimizer, scheduler, scaler, dataloader_train, lr_list)
        val_loss, val_preds, val_labels = _validate_one_epoch(model, dataloader_valid)

        val_score = accuracy_score(val_labels, val_preds)
        class_wise_acc = _class_wise_accuracy(val_labels, val_preds, CFG.NUM_CLASSES)
        # Two adjacent f-strings rather than a backslash continuation inside one literal,
        # so the source indentation is not printed in the middle of the line.
        print(f"Fold:{i_fold}, Epoch:{epoch}, Overall accuracy : {val_score * 100.0}, "
              f"Classwise_acc:{class_wise_acc}")

        train_fold_results.append({'fold': i_fold, 'epoch': epoch,
                                   'train_loss': tr_loss / len(dataloader_train),
                                   'valid_loss': val_loss / len(dataloader_valid),
                                   'valid_score': val_score,
                                   'class_wise_acc': class_wise_acc})

        # Checkpoint only when validation accuracy strictly improves.
        if val_score > best_val_acc:
            best_val_acc = val_score
            best_epoch = epoch
            torch.save({'model': model.state_dict(), 'val_preds': val_preds, 'val_labels': val_labels},
                       f"{CFG.MODEL_NAME}_fold_{i_fold}_epoch{epoch}.pth")

    print(f"For Fold {i_fold}, Best validation accuracy of {best_val_acc} was got at epoch {best_epoch}")
    np.save(f"{CFG.MODEL_NAME}_fold{i_fold}_LRlist.npy", np.array(lr_list))
    return train_fold_results

## Training and validation function calls

In [23]:
if CFG.TRAIN:
    train_results = []

    for i_fold, (train_idx, valid_idx) in enumerate(folds.split(train_df, train_labels)):
        # Iterate the full split so fold ids stay stable; train only the selected folds.
        if i_fold not in CFG.FOLD_TO_TRAIN:
            continue
        print("Fold {}/{}".format(i_fold + 1, CFG.N_FOLDS))

        # create fold data
        train_data = train_df.iloc[train_idx].reset_index()
        valid_data = train_df.iloc[valid_idx].reset_index()
        print(train_data.shape, valid_data.shape)

        transforms = generate_transforms()
        dataset_train = TrainDataset(train_data, transform=transforms['train_transforms'])
        dataset_valid = TrainDataset(valid_data, transform=transforms['val_transforms'])
        dataloader_train = DataLoader(dataset_train, batch_size=CFG.TRAIN_BATCH_SIZE, shuffle=True,
                                      num_workers=CFG.NUM_WORKERS, pin_memory=False, drop_last=False)
        # Validation is not shuffled: the metrics don't need it, and a fixed order keeps the
        # val_preds saved in checkpoints aligned with valid_idx / valid_data rows.
        dataloader_valid = DataLoader(dataset_valid, batch_size=CFG.TEST_BATCH_SIZE, shuffle=False,
                                      num_workers=CFG.NUM_WORKERS, pin_memory=False, drop_last=False)

        # load pretrained weight file
        if CFG.RETRAIN:
            # map_location lets a checkpoint saved on one device load onto the current `device`.
            checkpoint = torch.load(f'{CFG.WGT_PATH}/{CFG.WGT_MODEL}_fold{i_fold}.pth', map_location=device)
            model = EfficientnetClassifier(model_arch=CFG.MODEL_ARCH, pretrained=False)
            model.to(device)
            model.load_state_dict(checkpoint['model'])
            print(f'Model loaded for {CFG.WGT_MODEL}_fold{i_fold}')
        else:
            model = EfficientnetClassifier(model_arch=CFG.MODEL_ARCH, pretrained=True)
            model.to(device)

        # scaler to handle AMP
        scaler = GradScaler()

        if CFG.OPTIMIZER == 'Adam':
            optimizer = optim.Adam(model.parameters(), weight_decay=CFG.WEIGHT_DECAY, lr=CFG.MAX_LR)
        else:
            optimizer = optim.SGD(model.parameters(), weight_decay=CFG.WEIGHT_DECAY, lr=CFG.MAX_LR, momentum=0.9)

        # train_one_fold calls scheduler.step() once per batch, so schedules are sized in batches.
        if CFG.SCHEDULER == 'OneCycleLR':
            scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=CFG.MAX_LR, epochs=CFG.N_EPOCHS,
                                                      steps_per_epoch=len(dataloader_train), pct_start=0.25,
                                                      div_factor=10, anneal_strategy='cos')
        elif CFG.SCHEDULER == 'CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.MIN_LR, last_epoch=-1)
        else:
            scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_MAX * len(dataloader_train), eta_min=CFG.MIN_LR, last_epoch=-1)

        print(f"optimizer={optimizer}, scheduler={scheduler}, loss_fn={criterion}")

        train_fold_results = train_one_fold(i_fold, model, optimizer, scheduler, scaler, dataloader_train, dataloader_valid)
        train_results.extend(train_fold_results)

Fold 1/5
(17117, 4) (4280, 4)


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b4_ns-d6313a46.pth


optimizer=Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 0.001
    lr: 0.001
    weight_decay: 1e-06
), scheduler=<torch.optim.lr_scheduler.CosineAnnealingLR object at 0x7f790a89b350>, loss_fn=CrossEntropyLoss()
  Epoch 1/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:0, Overall accuracy : 82.57009345794393,                Classwise_acc:[74.77064220183486, 55.47945205479452, 60.167714884696025, 93.15849486887116, 75.5813953488372]
  Epoch 2/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:1, Overall accuracy : 87.38317757009347,                Classwise_acc:[46.330275229357795, 79.90867579908677, 75.47169811320755, 96.84530596731281, 73.83720930232558]
  Epoch 3/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:2, Overall accuracy : 87.64018691588785,                Classwise_acc:[52.752293577981646, 75.34246575342466, 74.63312368972747, 96.08513873052071, 81.78294573643412]
  Epoch 4/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:3, Overall accuracy : 86.1214953271028,                Classwise_acc:[45.87155963302752, 67.35159817351598, 68.13417190775681, 96.95933105283162, 80.42635658914729]
  Epoch 5/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:4, Overall accuracy : 82.64018691588785,                Classwise_acc:[27.981651376146786, 74.65753424657534, 51.57232704402516, 94.2987457240593, 81.78294573643412]
  Epoch 6/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:5, Overall accuracy : 83.78504672897196,                Classwise_acc:[57.798165137614674, 75.57077625570776, 78.40670859538784, 89.09160015203345, 79.65116279069767]
  Epoch 7/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:6, Overall accuracy : 87.73364485981308,                Classwise_acc:[60.550458715596335, 75.57077625570776, 75.47169811320755, 95.81908019764349, 79.65116279069767]
  Epoch 8/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:7, Overall accuracy : 87.80373831775701,                Classwise_acc:[52.752293577981646, 76.9406392694064, 76.31027253668763, 95.59103002660585, 82.75193798449612]
  Epoch 9/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:8, Overall accuracy : 84.25233644859813,                Classwise_acc:[73.39449541284404, 72.3744292237443, 61.84486373165618, 92.39832763207906, 78.10077519379846]
  Epoch 10/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:9, Overall accuracy : 84.74299065420561,                Classwise_acc:[61.00917431192661, 50.456621004566216, 73.58490566037736, 96.54123907259597, 74.03100775193798]
  Epoch 11/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:10, Overall accuracy : 87.71028037383178,                Classwise_acc:[63.30275229357798, 72.6027397260274, 81.13207547169812, 95.47700494108705, 77.32558139534885]
  Epoch 12/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:11, Overall accuracy : 88.76168224299066,                Classwise_acc:[60.09174311926605, 77.6255707762557, 80.71278825995807, 96.95933105283162, 75.96899224806202]
  Epoch 13/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:12, Overall accuracy : 88.66822429906543,                Classwise_acc:[55.04587155963303, 78.31050228310502, 77.9874213836478, 96.54123907259597, 81.3953488372093]
  Epoch 14/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:13, Overall accuracy : 86.61214953271028,                Classwise_acc:[63.76146788990825, 78.99543378995433, 78.40670859538784, 93.6145952109464, 74.6124031007752]
  Epoch 15/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:14, Overall accuracy : 87.54672897196262,                Classwise_acc:[66.5137614678899, 79.68036529680366, 75.8909853249476, 96.38920562523755, 68.7984496124031]
  Epoch 16/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:15, Overall accuracy : 87.89719626168224,                Classwise_acc:[61.00917431192661, 78.31050228310502, 82.18029350104821, 95.59103002660585, 73.44961240310077]
  Epoch 17/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:16, Overall accuracy : 88.36448598130842,                Classwise_acc:[64.22018348623854, 74.88584474885845, 82.38993710691824, 95.40098821740783, 79.65116279069767]
  Epoch 18/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:17, Overall accuracy : 88.8785046728972,                Classwise_acc:[56.88073394495413, 79.45205479452055, 80.29350104821803, 96.42721398707715, 79.84496124031007]
  Epoch 19/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:18, Overall accuracy : 85.70093457943925,                Classwise_acc:[71.55963302752293, 73.74429223744292, 75.8909853249476, 91.9042189281642, 79.26356589147287]
  Epoch 20/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:19, Overall accuracy : 86.26168224299066,                Classwise_acc:[40.825688073394495, 81.5068493150685, 66.87631027253668, 95.47700494108705, 80.42635658914729]
  Epoch 21/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:20, Overall accuracy : 87.03271028037383,                Classwise_acc:[39.908256880733944, 83.56164383561644, 70.64989517819707, 95.096921322691, 83.91472868217055]
  Epoch 22/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:21, Overall accuracy : 88.29439252336448,                Classwise_acc:[54.58715596330275, 76.48401826484019, 80.71278825995807, 96.16115545419991, 79.45736434108527]
  Epoch 23/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:22, Overall accuracy : 88.03738317757009,                Classwise_acc:[57.3394495412844, 76.9406392694064, 78.61635220125787, 95.66704675028507, 80.23255813953489]
  Epoch 24/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:23, Overall accuracy : 86.49532710280374,                Classwise_acc:[44.03669724770643, 73.97260273972603, 68.34381551362684, 96.08513873052071, 82.94573643410853]
  Epoch 25/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:24, Overall accuracy : 86.63551401869158,                Classwise_acc:[72.47706422018348, 76.9406392694064, 85.74423480083857, 92.8164196123147, 70.15503875968993]
  Epoch 26/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:25, Overall accuracy : 87.45327102803738,                Classwise_acc:[44.03669724770643, 75.79908675799086, 79.24528301886792, 95.85708855948309, 80.42635658914729]
  Epoch 27/27


HBox(children=(FloatProgress(value=0.0, max=1070.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))


Fold:0, Epoch:26, Overall accuracy : 88.69158878504673,                Classwise_acc:[63.76146788990825, 77.3972602739726, 80.50314465408806, 96.12314709236031, 78.48837209302324]
For Fold 0, Best validation accuracy of 0.888785046728972 was got at epoch 17


## Plot training results

In [24]:
def plot_training_results(results=None):
    """Plot per-epoch train/valid loss (top row) and validation accuracy (bottom row) per fold.

    Args:
        results: DataFrame with columns 'fold', 'epoch', 'train_loss', 'valid_loss' and
            'valid_score' (one row per fold/epoch, as produced from ``train_one_fold``).
            Defaults to the notebook-global ``train_results``, which must already have been
            converted to a DataFrame.
    """
    if results is None:
        results = train_results

    # subplot_titles puts each title directly above its own row of this 2-row x 1-col layout.
    fig = make_subplots(rows=2, cols=1,
                        subplot_titles=("Train / valid losses", "Validation scores"))

    # (dark, light) color pair per fold id: dark for train loss / valid score, light for valid loss.
    colors = [
        ('#d32f2f', '#ef5350'),
        ('#303f9f', '#5c6bc0'),
        ('#00796b', '#26a69a'),
        ('#fbc02d', '#ffeb3b'),
        ('#5d4037', '#8d6e63'),
    ]

    # Plot only folds that were actually trained (CFG.FOLD_TO_TRAIN may be a subset of
    # CFG.N_FOLDS); colors stay keyed by fold id and wrap if there are more folds than colors.
    fold_ids = sorted(results['fold'].unique())
    for plot_idx, fold in enumerate(fold_ids):
        data = results[results['fold'] == fold]
        dark, light = colors[int(fold) % len(colors)]
        # Only the first plotted fold's loss curves start visible; the rest are toggled via the legend.
        visibility = True if plot_idx == 0 else 'legendonly'

        fig.add_trace(go.Scatter(x=data['epoch'].values,
                                 y=data['train_loss'].values,
                                 mode='lines',
                                 visible=visibility,
                                 line=dict(color=dark, width=2),
                                 name='Train loss - Fold #{}'.format(fold)),
                      row=1, col=1)

        fig.add_trace(go.Scatter(x=data['epoch'].values,
                                 y=data['valid_loss'].values,
                                 mode='lines+markers',
                                 visible=visibility,
                                 line=dict(color=light, width=2),
                                 name='Valid loss - Fold #{}'.format(fold)),
                      row=1, col=1)

        fig.add_trace(go.Scatter(x=data['epoch'].values,
                                 y=data['valid_score'].values,
                                 mode='lines+markers',
                                 line=dict(color=dark, width=2),
                                 name='Valid score - Fold #{}'.format(fold),
                                 showlegend=False),
                      row=2, col=1)

    fig.show()

# Confusion matrix for a previously saved fold-0 run. These .npy artifacts are not written by
# `train_one_fold` (it stores val_preds/val_labels inside its .pth checkpoints), and the "R18"
# prefix suggests a different experiment -- TODO confirm the intended source. The load is
# guarded so Restart & Run All does not crash when the files are absent.
VAL_PREDS_0_PATH = './R18_imagenet_v2_val_preds_0.npy'
VAL_LABELS_0_PATH = './R18_imagenet_v2_val_labels_0.npy'

if os.path.exists(VAL_PREDS_0_PATH) and os.path.exists(VAL_LABELS_0_PATH):
    val_preds_0 = np.load(VAL_PREDS_0_PATH)
    val_labels_0 = np.load(VAL_LABELS_0_PATH)

    cm = confusion_matrix(val_labels_0, val_preds_0)
    print(cm)
    plt.figure(figsize=(8, 8))
    plot_confusion_matrix(cm, classes=class_names, normalize=True)
else:
    print(f'Skipping confusion matrix: {VAL_PREDS_0_PATH} / {VAL_LABELS_0_PATH} not found')

In [25]:
if CFG.TRAIN:
    # Turn the per-epoch records from all trained folds into one table and persist it.
    train_results = pd.DataFrame(train_results)
    print(train_results)
    train_results.to_csv('train_results.csv', index=False)

    # CV score: the best epoch's validation accuracy for each trained fold.
    best_folds = np.array([
        train_results.loc[train_results['fold'] == fold_id, 'valid_score'].max()
        for fold_id in CFG.FOLD_TO_TRAIN
    ])
    print(f'Overall CV accuracy : {best_folds.mean()}, std: {best_folds.std()}')

    plot_training_results()

    fold  epoch  train_loss  valid_loss  valid_score  \
0      0      0    0.722584    0.502727     0.825701   
1      0      1    0.549066    0.368869     0.873832   
2      0      2    0.504923    0.379628     0.876402   
3      0      3    0.536700    0.424554     0.861215   
4      0      4    0.586780    0.503547     0.826402   
5      0      5    0.552341    0.476545     0.837850   
6      0      6    0.492360    0.365667     0.877336   
7      0      7    0.449103    0.365253     0.878037   
8      0      8    0.488236    0.483387     0.842523   
9      0      9    0.539505    0.452714     0.847430   
10     0     10    0.536633    0.390180     0.877103   
11     0     11    0.480153    0.344321     0.887617   
12     0     12    0.453277    0.346015     0.886682   
13     0     13    0.482147    0.414724     0.866121   
14     0     14    0.517948    0.383010     0.875467   
15     0     15    0.519327    0.371456     0.878972   
16     0     16    0.468210    0.347640     0.88