**Hello! Welcome to my Cassava Leaf Disease Classification notebook, modeled using PyTorch.** 

In my approach to this Kaggle problem, I (like many others) chose EfficientNet as the model to train and then use to infer the Cassava leaf disease in each image. More specifically, I used the B4 variant of the scaled EfficientNet family, loaded from TIMM (`tf_efficientnet_b4_ns`) with weights pretrained using Noisy Student self-training, because many forum posts reported success with this configuration.

I would also like to thank the authors of the following notebooks for sharing their code with other Kagglers in this competition. I've used code from both of them in this notebook, and they were instrumental in getting a working notebook up and running. I've also trained my EfficientNet model with a custom loss function that is designed to be more robust to noisy labels; the code for this loss was borrowed from the second notebook below. If you want more helpful, easy-to-follow resources, I highly recommend checking these notebooks out:
* General code usage: https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-train-amp-aug
* Bi-Tempered Logistic Loss: https://www.kaggle.com/capiru/cassavanet-starter-easy-gpu-tpu-cv-0-9


Speaking of Bi-Tempered Logistic Loss: I decided to use this less common loss function instead of PyTorch's popular `nn.CrossEntropyLoss()`. As many Kagglers know by now, the training and test data in this competition have significantly noisy labels. Bi-Tempered Logistic Loss is designed to be robust to such noise through two "temperature" parameters: t1 (< 1.0) makes the loss bounded, so badly mislabeled examples far from the decision boundary cannot dominate training, and t2 (> 1.0) gives the softmax heavier tails, which helps with noisy examples close to the decision boundary. Setting t1 = t2 = 1.0 recovers the ordinary logistic (cross-entropy) loss. See this Google AI Blog post for more information: https://ai.googleblog.com/2019/08/bi-tempered-logistic-loss-for-training.html
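To make the two temperatures concrete, here is a small standalone sketch of the tempered logarithm and exponential on plain Python floats. `log_t` uses the same formula as the loss cell in this notebook; `exp_t` is the standard tempered exponential (the inverse of `log_t`). The sketch illustrates the two properties described above: for t < 1, `log_t` is bounded below by -1/(1 - t), and for t > 1, `exp_t` decays polynomially rather than exponentially.

```python
import math

def log_t(u, t):
    """Tempered logarithm; equals math.log(u) when t == 1."""
    if t == 1.0:
        return math.log(u)
    return (u ** (1.0 - t) - 1.0) / (1.0 - t)

def exp_t(u, t):
    """Tempered exponential (inverse of log_t); equals math.exp(u) when t == 1.
    For t > 1 it is only finite for u < 1 / (t - 1)."""
    if t == 1.0:
        return math.exp(u)
    return max(0.0, 1.0 + (1.0 - t) * u) ** (1.0 / (1.0 - t))

# t1 < 1: log_t stays above -1 / (1 - t1), so the -log_t(p) term of the loss
# can never exceed 1 / (1 - t1) = 1.25, however small the probability p gets
t1 = 0.2
for p in [1e-1, 1e-4, 1e-12]:
    print(f"p={p:g}: -log(p)={-math.log(p):.2f}, -log_t(p, {t1})={-log_t(p, t1):.4f}")

# t2 > 1: exp_t has a much heavier tail than exp
t2 = 2.0
print(f"exp(-10) = {math.exp(-10):.2e}, exp_t(-10, {t2}) = {exp_t(-10.0, t2):.2e}")
```

With t = 1.0 both functions fall back to the ordinary `log`/`exp`, which is why t1 = t2 = 1.0 gives back cross-entropy.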

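Lastly, as the title suggests, I also freeze the Batch Normalization (BN) layers in EfficientNet. Freezing these layers is common practice when fine-tuning a pretrained EfficientNet model (see the sources below). However, TIMM does not do this automatically, so it has to be done manually; this notebook shows how. Note that setting `requires_grad = False` only freezes the BN layers' learnable scale and shift; their running mean and variance still update whenever the model is in training mode, unless those layers are also switched to eval mode.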
* https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/203594
* https://www.kaggle.com/c/siim-isic-melanoma-classification/discussion/172882
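A minimal sketch of freezing BN layers completely (both the learnable parameters and the running statistics), assuming a toy network: `TinyNet` is a hypothetical stand-in for the EfficientNet backbone, not this notebook's `LeafDiseaseClassifier`. It walks every submodule with `modules()`, freezes the affine parameters, and overrides `train()` so BN layers return to eval mode whenever the model is put in training mode.

```python
import torch
from torch import nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for a backbone with BN nested inside blocks."""
    def __init__(self, num_classes=5):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.blocks = nn.Sequential(
            nn.Sequential(nn.BatchNorm2d(8), nn.ReLU()),  # nested BN, like EfficientNet blocks
        )
        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, num_classes))

    def forward(self, x):
        return self.head(self.blocks(self.stem(x)))

    def freeze_batch_norm(self):
        # modules() recurses into nested containers; children() would not
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                for param in module.parameters():
                    param.requires_grad = False  # freeze BN weight (gamma) and bias (beta)

    def train(self, mode=True):
        super().train(mode)
        if mode:
            # Keep BN in eval mode so running_mean / running_var are not updated
            for module in self.modules():
                if isinstance(module, nn.BatchNorm2d):
                    module.eval()
        return self

model = TinyNet()
model.freeze_batch_norm()
model.train()
bn = model.blocks[0][0]
mean_before = bn.running_mean.clone()
model(torch.randn(4, 3, 16, 16))
print(model.training, bn.training, torch.equal(mean_before, bn.running_mean))  # True False True
```

Overriding `train()` matters because a typical training loop calls `model.train()` at the start of every epoch, which would otherwise switch the BN layers back into training mode.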


In [None]:
package_paths = [
    '../input/pytorch-image-models/pytorch-image-models-master',  # alternative: '../input/efficientnet-pytorch-07/efficientnet_pytorch-0.7.0'
]
import sys

for pth in package_paths:
    sys.path.append(pth)

In [None]:
'''
IMPORTS
'''

import cv2
import torch
import os
from torch import nn
from datetime import datetime
import time
import random
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
from tqdm import tqdm


import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import SequentialSampler, RandomSampler, WeightedRandomSampler
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F

import timm

import sklearn
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
from sklearn.model_selection import GroupKFold, StratifiedKFold

In [None]:
'''
Load Data
'''

data = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
print(data.head())   # only a cell's last expression is displayed, so print explicitly

print(len(data))     # Number of samples = 21397

In [None]:
'''
CONFIGURATION
'''
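config = {
    'seed': 419,
    'img_size': 512,
    'tta': 3,
    'num_folds': 5,

    # TIMM's default config for this model: input_size = (3, 380, 380), pool_size = (12, 12).
    # This notebook trains at img_size = 512 instead.
    # DOC: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/efficientnet.py
    'model_arch': 'tf_efficientnet_b4_ns',

    'train_bs': 16,
    'valid_bs': 32,
    'num_workers': 2,
    'epochs': 10,
    'device': 'cuda:0',

    # Bi-Tempered Logistic Loss; with T2 = 1.0 the tempered softmax is the ordinary softmax
    'T1': 0.2,
    'T2': 1.0,
    'label_smooth': 0.2,

    'lr': 1e-4,
    'min_lr': 1e-6,
    'T_0': 10,                # CosineAnnealingWarmRestarts period
    'weight_decay': 1e-6,
    'ep_patience': 4,         # early-stopping patience, in epochs
    'factor': 0.2,            # only used by the commented-out ReduceLROnPlateau scheduler
    #'accum_iter': 2,
    'update_on_batch': True,  # step the LR scheduler after every batch instead of every epoch
    'use_wrs': False,         # use a WeightedRandomSampler for class balancing
}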
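With `'update_on_batch': True`, the scheduler is stepped once per batch. For `CosineAnnealingWarmRestarts`, the PyTorch documentation shows passing a fractional epoch to `scheduler.step()` in this situation, so that `T_0` still counts epochs rather than batches. A small sketch with a dummy parameter, reusing the `lr`, `min_lr` and `T_0` values from the config above; `ITERS_PER_EPOCH = 100` is a hypothetical batch count.

```python
import torch

# Values taken from the config above; ITERS_PER_EPOCH is a hypothetical batch count
LR, MIN_LR, T_0 = 1e-4, 1e-6, 10
ITERS_PER_EPOCH = 100

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter so the optimizer has something to hold
optimizer = torch.optim.Adam([param], lr=LR)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=T_0, T_mult=1, eta_min=MIN_LR)

lrs = []
for epoch in range(T_0):
    for i in range(ITERS_PER_EPOCH):
        optimizer.step()  # forward/backward omitted; with no gradients, Adam skips the update
        # Fractional epoch: T_0 keeps meaning "epochs" even though we step every batch
        scheduler.step(epoch + (i + 1) / ITERS_PER_EPOCH)
        lrs.append(optimizer.param_groups[0]['lr'])

print(f"LR halfway through the cycle: {lrs[5 * ITERS_PER_EPOCH - 1]:.3e}")
print(f"Lowest LR: {min(lrs):.3e}, LR at the restart: {lrs[-1]:.3e}")
```

Calling `scheduler.step()` without an argument once per batch would instead treat `T_0 = 10` as 10 batches, restarting the cosine cycle many times per epoch.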

In [None]:
'''
LOSS FUNCTIONS
'''

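def log_t(u, t):
    """Compute log_t (the tempered logarithm) for `u`."""
    if t == 1.0:
        return u.log()
    else:
        return (u.pow(1.0 - t) - 1.0) / (1.0 - t)

def exp_t(u, t):
    """Compute exp_t (the tempered exponential) for `u`."""
    if t == 1.0:
        return u.exp()
    else:
        return (1.0 + (1.0 - t) * u).relu().pow(1.0 / (1.0 - t))

def compute_normalization(activations, t, num_iters=5):
    """Compute the tempered-softmax normalization constants with a fixed-point
    iteration (valid for t > 1.0). Only needed by tempered_softmax when T2 != 1.0."""
    mu, _ = torch.max(activations, dim=-1, keepdim=True)
    normalized_activations_step_0 = activations - mu
    normalized_activations = normalized_activations_step_0
    for _ in range(num_iters):
        logt_partition = torch.sum(exp_t(normalized_activations, t), dim=-1, keepdim=True)
        normalized_activations = normalized_activations_step_0 * logt_partition.pow(1.0 - t)
    logt_partition = torch.sum(exp_t(normalized_activations, t), dim=-1, keepdim=True)
    return -log_t(1.0 / logt_partition, t) + mu

def tempered_softmax(activations, t, num_iters=5):
    """Tempered softmax function.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature > 1.0 (t == 1.0 reduces to the standard softmax).
      num_iters: Number of iterations to run the method.
    Returns:
      A probabilities tensor.
    """
    if t == 1.0:
        return activations.softmax(dim=-1)

    normalization_constants = compute_normalization(activations, t, num_iters)
    return exp_t(activations - normalization_constants, t)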

def bi_tempered_logistic_loss(activations,
        labels,
        t1 = config['T1'],
        t2 = config['T2'],
        label_smoothing=config['label_smooth'],
        num_iters=5,
        reduction = 'mean'):

    """Bi-Tempered Logistic Loss.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: A tensor with shape and dtype as activations (onehot), 
        or a long tensor of one dimension less than activations (pytorch standard)
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1). Default 0.0.
      num_iters: Number of iterations to run the method. Default 5.
      reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``.
        ``'none'``: No reduction is applied, return shape is shape of
        activations without the last dimension.
        ``'mean'``: Loss is averaged over minibatch. Return shape (1,)
        ``'sum'``: Loss is summed over minibatch. Return shape (1,)
    Returns:
      A loss tensor.
    """

    if len(labels.shape)<len(activations.shape): #not one-hot
        labels_onehot = torch.zeros_like(activations)
        labels_onehot.scatter_(1, labels[..., None], 1)
    else:
        labels_onehot = labels

    if label_smoothing > 0:
        num_classes = labels_onehot.shape[-1]
        labels_onehot = ( 1 - label_smoothing * num_classes / (num_classes - 1) ) \
                * labels_onehot + \
                label_smoothing / (num_classes - 1)

    probabilities = tempered_softmax(activations, t2, num_iters)

    loss_values = labels_onehot * log_t(labels_onehot + 1e-10, t1) \
            - labels_onehot * log_t(probabilities, t1) \
            - labels_onehot.pow(2.0 - t1) / (2.0 - t1) \
            + probabilities.pow(2.0 - t1) / (2.0 - t1)
    loss_values = loss_values.sum(dim = -1) #sum over classes

    if reduction == 'none':
        return loss_values
    if reduction == 'sum':
        return loss_values.sum()
    if reduction == 'mean':
        return loss_values.mean()

class BiTemperedLogistic(nn.Module):
    def __init__(self, T1 = config['T1'], T2 = config['T2'], LABEL_SMOOTH = config['label_smooth']):
        super().__init__()
        self.T1 = T1
        self.T2 = T2
        self.LABEL_SMOOTH = LABEL_SMOOTH

    def forward(self, logits,labels):
        return bi_tempered_logistic_loss(logits, labels,t1 = self.T1,t2 = self.T2, label_smoothing = self.LABEL_SMOOTH)
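As a quick sanity check of the loss above, here is a standalone sketch that compares it with cross-entropy on an increasingly confident wrong prediction. It assumes the configured T2 = 1.0 (so the tempered softmax is the ordinary softmax) and no label smoothing; the helper `bi_tempered_loss_t2_is_1` is hypothetical and simply restates the notebook's formula for that case. With t1 = 1.0 it should match `F.cross_entropy`; with t1 = 0.2 it stays at or below 1 / (1 - t1) = 1.25 instead of growing without bound, which is the boundedness that limits the influence of mislabeled images.

```python
import torch
import torch.nn.functional as F

def log_t(u, t):
    """Tempered logarithm, as in the loss cell above; reduces to log(u) at t = 1."""
    if t == 1.0:
        return u.log()
    return (u.pow(1.0 - t) - 1.0) / (1.0 - t)

def bi_tempered_loss_t2_is_1(logits, labels, t1):
    """Hypothetical helper: the notebook's loss for t2 = 1.0 and label_smoothing = 0."""
    y = F.one_hot(labels, num_classes=logits.shape[-1]).to(logits.dtype)
    p = logits.softmax(dim=-1)  # tempered softmax with t2 = 1.0 is the ordinary softmax
    loss = (y * log_t(y + 1e-10, t1)
            - y * log_t(p, t1)
            - y.pow(2.0 - t1) / (2.0 - t1)
            + p.pow(2.0 - t1) / (2.0 - t1))
    return loss.sum(dim=-1).mean()

labels = torch.tensor([0])  # the true class is 0
for margin in [1.0, 10.0, 100.0]:
    # The model grows ever more confident in the wrong class (class 1)
    logits = torch.tensor([[0.0, margin, 0.0, 0.0, 0.0]])
    ce = F.cross_entropy(logits, labels).item()
    bt = bi_tempered_loss_t2_is_1(logits, labels, t1=0.2).item()
    print(f"margin={margin:>5}: cross-entropy={ce:8.3f}, bi-tempered (t1=0.2)={bt:.3f}")
```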


In [None]:
'''
AUGMENTATIONS
'''

from albumentations import (
    ShiftScaleRotate, Normalize, Compose, CenterCrop, Resize, HorizontalFlip,
    VerticalFlip, Transpose, RandomResizedCrop, HueSaturationValue, RandomBrightnessContrast,
    CoarseDropout, Cutout
)

from albumentations.pytorch import ToTensorV2

def get_train_transforms():
    return Compose([
            RandomResizedCrop(config['img_size'], config['img_size']),
            Transpose(p=0.5),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            ShiftScaleRotate(p=0.5),
            HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            CoarseDropout(p=0.5),
            Cutout(p=0.5),
            ToTensorV2(p=1.0),
        ], p=1.0)

def get_valid_transforms():
    return Compose([
                CenterCrop(config['img_size'], config['img_size'], p=1.0),
                Resize(config['img_size'], config['img_size']),
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
                ToTensorV2(p=1.0),
        ], p=1.0)

In [None]:
'''
HELPER FUNCTIONS
'''

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN select convolution algorithms that can differ between
    # runs, which defeats the deterministic flag above; set it to True if speed matters more
    torch.backends.cudnn.benchmark = False
    
def get_img(path):
    # OpenCV loads images in BGR order; convert to RGB for the ImageNet-pretrained model
    im_bgr = cv2.imread(path)
    im_rgb = cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB)
    return im_rgb

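def prep_dataloader(df, train_idx, val_idx, use_wrs, sampler, data_dir):
    
    train_ref = df.loc[train_idx, :].reset_index(drop=True)
    valid_ref = df.loc[val_idx, :].reset_index(drop=True)
    
    train_ds = LeafDataset(train_ref, data_dir, transforms=get_train_transforms(), include_labels=True)
    valid_ds = LeafDataset(valid_ref, data_dir, transforms=get_valid_transforms(), include_labels=True)
    
    if use_wrs:
        # The sampler's indices must refer to train_ds (this fold only), so the passed-in
        # sampler (built on the full DataFrame) is replaced by one built from this fold's labels
        fold_labels = train_ref['label'].values
        sample_weights = 1.0 / np.bincount(fold_labels)[fold_labels]
        sampler = WeightedRandomSampler(torch.as_tensor(sample_weights, dtype=torch.double),
                                        num_samples=len(train_ref), replacement=True)
        train_loader = torch.utils.data.DataLoader(
            train_ds,
            batch_size=config['train_bs'],
            drop_last=True,
            shuffle=False,   # must be False when a sampler is given
            num_workers=config['num_workers'],
            sampler=sampler
        )
    else:
        train_loader = torch.utils.data.DataLoader(
            train_ds,
            batch_size=config['train_bs'],
            drop_last=True,
            shuffle=True,
            num_workers=config['num_workers'],
        )
    # Keep every validation image: drop_last=True would silently skip the last partial batch
    val_loader = torch.utils.data.DataLoader(
        valid_ds,
        batch_size=config['valid_bs'],
        drop_last=False,
        shuffle=False,
        num_workers=config['num_workers']
    )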
    return train_loader, val_loader
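When `use_wrs` is enabled, the sampler's indices must be valid for the fold's training dataset, and `num_samples` is the number of draws per epoch. A minimal sketch of a class-balanced `WeightedRandomSampler` for a single fold; the helper name `make_class_balanced_sampler` and the toy label counts are hypothetical.

```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

def make_class_balanced_sampler(labels, seed=0):
    """Hypothetical helper: class-balanced sampler whose indices refer to `labels` (one fold)."""
    labels = np.asarray(labels)
    class_counts = np.bincount(labels)
    sample_weights = 1.0 / class_counts[labels]  # rarer classes get larger per-sample weights
    generator = torch.Generator().manual_seed(seed)
    return WeightedRandomSampler(
        torch.as_tensor(sample_weights, dtype=torch.double),
        num_samples=len(labels),  # draws per epoch: one fold's worth
        replacement=True,
        generator=generator,
    )

# Hypothetical, imbalanced fold with 5 classes (1,000 images)
fold_labels = np.array([0] * 50 + [1] * 100 + [2] * 100 + [3] * 600 + [4] * 150)
sampler = make_class_balanced_sampler(fold_labels)
drawn = fold_labels[list(sampler)]
print(np.bincount(drawn, minlength=5) / len(drawn))  # each class close to 0.2
```

Because each class's weights sum to 1, every class is drawn with equal probability. The sampler is passed as `DataLoader(..., sampler=sampler, shuffle=False)`; `DataLoader` rejects `sampler` combined with `shuffle=True`.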

In [None]:
'''
DATASET
'''

class LeafDataset(Dataset):
    def __init__(self, df, img_dir, transforms=None, include_labels=True):
        super().__init__()
        self.df = df   # expects a 0-based index (prep_dataloader resets it)
        self.img_dir = img_dir
        self.transforms = transforms
        self.include_labels = include_labels
        
        if include_labels:
            self.labels = self.df['label'].values
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index: int):
        # Build the image path and load the image as an RGB array
        img_path = os.path.join(self.img_dir, self.df.loc[index, 'image_id'])
        img = get_img(img_path)
        if self.transforms:
            img = self.transforms(image=img)['image']
        
        if self.include_labels:
            label = self.labels[index]
            return img, label
        else:
            return img
        

In [None]:
'''
MODEL
'''

class LeafDiseaseClassifier(nn.Module):
    def __init__(self, model_arch, num_classes, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(n_features, num_classes)

    
    def forward(self, x):
        x = self.model(x)
        return x
        
    def freeze_batch_norm(self):
        # Walk all submodules recursively: most BN layers sit inside nested
        # nn.Sequential blocks, which a loop over direct children would miss.
        # (TIMM's BatchNormAct2d, where used, subclasses nn.BatchNorm2d, so it matches too.)
        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                for param in module.parameters():
                    param.requires_grad = False

In [None]:
'''
MAIN
'''

if __name__ == '__main__':
    seed_everything(config['seed'])
    
    # Calculate class stats for WeightedRandomSampler
    n_classes = data.label.nunique()
    class_counts = data['label'].value_counts().sort_index().tolist()

    num_samples = len(data)
    targs = data['label']
    class_weights = [num_samples / class_counts[i] for i in range(n_classes)]
    weights = [class_weights[targs[i]] for i in range(num_samples)]
    # num_samples is the number of draws per epoch (passing n_classes would draw only 5 images)
    sampler = WeightedRandomSampler(torch.DoubleTensor(weights), num_samples=num_samples)
    
    # Initialize lists for tracking batch and epoch losses
    train_batch_loss = []
    val_batch_loss = []
    mean_val_epoch_loss = []
    epoch_acc = []
    
    folds = StratifiedKFold(n_splits=config['num_folds'], shuffle=True, random_state=config['seed']).split(X=np.zeros(len(data)), y=data.label.values)
    print('Initializing folds...')
    
    # K-Fold Cross Validation Loop
    for fold, (trn_idx, val_idx) in enumerate(folds):
        
        # Already trained network and extracted best model, so break after first fold
        if fold > 0:
            break

        train_loader, val_loader = prep_dataloader(data, trn_idx, val_idx, config['use_wrs'], sampler, data_dir='../input/cassava-leaf-disease-classification/train_images/')
                                                                                                    
        print('Initializing model, fold {} selected.'.format(fold))                                                                                                   
        device = torch.device(config['device'] if torch.cuda.is_available() else "cpu")
        
        model = LeafDiseaseClassifier(config['model_arch'], data.label.nunique(), pretrained=True).to(device)
        model.freeze_batch_norm()   
        
        scaler = GradScaler()
        optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
        #scheduler =  torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='max', patience=config['patience'], verbose=True, factor=config['factor'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=config['T_0'], T_mult=1, eta_min=config['min_lr'], last_epoch=-1)
        
        #loss_tr = nn.CrossEntropyLoss().to(device)
        #loss_val = nn.CrossEntropyLoss().to(device)
        
        loss_tr = BiTemperedLogistic()
        loss_val = BiTemperedLogistic()
        best_acc = 0
        bad_ep_count = 1
        
        for epoch in range(config['epochs']):
            
            # TRAINING LOOP
            model.train()
            t = time.time()
            running_loss = None
            
            pbar = tqdm(enumerate(train_loader), total=len(train_loader))
            for step, trn_batch in pbar:
                trn_x, trn_labels = trn_batch
                trn_x = trn_x.to(device).float()
                trn_labels = trn_labels.to(device).long()
                

                with autocast():
                    trn_preds = model(trn_x)
                    loss = loss_tr(trn_preds, trn_labels)
                    
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                
                train_batch_loss.append(loss.item())
                if running_loss is None:
                    running_loss = loss.item()
                else:
                    running_loss = running_loss * .99 + loss.item() * .01
                if scheduler is not None and config['update_on_batch']:
                    # Pass a fractional epoch so that T_0 counts epochs, not batches
                    scheduler.step(epoch + (step + 1) / len(train_loader))
                if step+1 == len(train_loader):
                    desc = 'Fold {0}, Epoch {1} Train Loss: {2:.4f}'.format(fold, epoch, running_loss)
                    pbar.set_description(desc)
                    
            if scheduler is not None and not config['update_on_batch']:
                scheduler.step()
            
            
            # VALIDATION LOOP
            with torch.no_grad():
                model.eval()
                
                t = time.time()
                loss_sum = 0
                num_samples = 0
                val_preds_list = []
                val_labels_list = []
                
                pbar = tqdm(enumerate(val_loader), total=len(val_loader))
                for step, val_batch in pbar:
                    val_x, val_labels = val_batch
                    val_x = val_x.to(device).float()
                    val_labels = val_labels.to(device).long()
                    
                    val_preds = model(val_x)
                    val_preds_list += [torch.argmax(val_preds, 1).detach().cpu().numpy()]
                    val_labels_list += [val_labels.detach().cpu().numpy()]
                    
                    loss = loss_val(val_preds, val_labels)
                    num_samples += val_labels.shape[0]
                    loss_sum += loss.item() * val_labels.shape[0]
                    val_batch_loss.append(loss.item())
                    
                    if step+1 == len(val_loader):
                        desc = 'Fold {0}, Epoch {1} Val Loss: {2:.4f}'.format(fold, epoch, loss_sum/num_samples)
                        pbar.set_description(desc)
                    
                val_preds_list = np.concatenate(val_preds_list)
                val_labels_list = np.concatenate(val_labels_list)
                acc = (val_preds_list==val_labels_list).mean()

                # Check whether this is the best accuracy so far
                if acc > best_acc:
                    best_acc = acc
                    torch.save(model.state_dict(), '{0}_Fold{1}_Epoch{2}_Acc_{3}.pth'.format(config['model_arch'], fold, epoch, round(best_acc, 4)))
                    bad_ep_count = 0
                else:
                    bad_ep_count += 1

                epoch_acc.append(acc)
                mean_val_epoch_loss.append(loss_sum / num_samples)
                print('Validated on Fold {0}, epoch {1}; Validation Loss: {2:.4f}, Accuracy: {3:.4f}. Best Accuracy: {4:.4f}'.format(fold, epoch, loss_sum / num_samples, acc, best_acc))

                # Exit early if not improving
                if bad_ep_count >= config['ep_patience']:
                    print('Early stopping: model has not improved for {0} epochs'.format(config['ep_patience']))
                    break
    
    
    ## Plot Results
#    plt.figure()
#    plt.plot(train_batch_loss)
#    plt.title('Training Loss Per Batch')
#    plt.show()
    
#    plt.figure()
#    plt.plot(val_batch_loss)
#    plt.title('Validation Loss Per Batch')
#    plt.show()
    
#    plt.figure()
#    plt.plot(mean_val_epoch_loss)
#    plt.title('Mean Validation Loss each Epoch')
#    plt.show()
    
#    plt.figure()
#    plt.plot(epoch_acc)
#    plt.title('Accuracy each Epoch')
#    plt.show()
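The early-stopping rule in the training loop can be isolated into a small, testable helper. This is a sketch with a hypothetical `EarlyStopper` class, not part of the original notebook: the counter resets whenever validation accuracy improves and otherwise counts consecutive non-improving epochs, stopping once it reaches the patience (the configured `ep_patience = 4` is used below). The accuracy values are made up for illustration.

```python
class EarlyStopper:
    """Hypothetical helper: stop after `patience` consecutive epochs without improvement."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float('-inf')
        self.bad_epochs = 0

    def step(self, metric):
        """Record one epoch's metric; return True if training should stop."""
        if metric > self.best:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopper(patience=4)
accuracies = [0.80, 0.85, 0.84, 0.86, 0.85, 0.85, 0.86, 0.855]  # made-up values
stopped_at = None
for epoch, acc in enumerate(accuracies):
    if stopper.step(acc):
        stopped_at = epoch
        break
print(f"Stopped after epoch {stopped_at}; best accuracy {stopper.best}")
```

Keeping the "reset on improvement, otherwise increment" logic in one place makes the patience semantics explicit: training stops after exactly `patience` epochs without a new best.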
    
  

In [None]:
del model, optimizer, train_loader, val_loader, scaler, scheduler
torch.cuda.empty_cache()

**And thus concludes my Cassava Leaf Disease Classification notebook.**

I hope you enjoyed reading this and learned something. This is my first notebook and submission to a Kaggle competition, so please feel free to leave any suggestions or advice for next time in the comments! I'm always looking to learn and grow. If you have any questions about the code or the design decisions, I'd be happy to answer them in the comments below.

Lastly, here are the sources that helped me get started writing the code for this notebook. Please give them (and this notebook) an upvote if you found them useful:
* General code usage: https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-train-amp-aug
* Bi-Tempered Logistic Loss: https://www.kaggle.com/capiru/cassavanet-starter-easy-gpu-tpu-cv-0-9