# About this notebook  
This notebook was developed for the Kaggle competition "Cassava Leaf Disease Classification" and builds on the [starter code](https://www.kaggle.com/yasufuminakama/cassava-resnext50-32x4d-starter-training/) shared by other competitors.

For the loss functions we referenced [this workbook](https://www.kaggle.com/piantic/train-cassava-starter-using-various-loss-funcs/notebook).

For the Test-Time Augmentation we referenced [this workbook](https://www.kaggle.com/japandata509/ensemble-resnext50-32x4d-efficientnet-0-903).

# Data Loading

In [None]:
import os

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

In [None]:
os.listdir('../input/cassava-leaf-disease-classification')

In [None]:
# Load the competition tables: training labels, the sample submission (used as the
# test frame), an extra validation CSV and the label-id -> disease-name mapping.
train = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
test = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
# NOTE(review): val_data.csv comes from a separate '../input/valid' dataset;
# presumably a hold-out split with 'image_id'/'label' columns - confirm.
val = pd.read_csv('../input/valid/val_data.csv')
label_map = pd.read_json('../input/cassava-leaf-disease-classification/label_num_to_disease_map.json', 
                         orient='index')
# `display` is the rich-output helper provided by IPython/Jupyter.
display(train.head())
display(test.head())
display(label_map)

# Directory settings

In [None]:
# ====================================================
# Directory settings
# ====================================================
import os

# Directory that receives everything this notebook writes (log file,
# checkpoints, out-of-fold predictions, submission).
OUTPUT_DIR = './'
# exist_ok=True creates the directory only when missing and avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Image folders of the competition dataset.
TRAIN_PATH = '../input/cassava-leaf-disease-classification/train_images'
TEST_PATH = '../input/cassava-leaf-disease-classification/test_images'
# Folder the previously trained fold checkpoints are read from at inference time.
OUTPUT_PATH = '../input/models/'
# Folder of the extra validation dataset.
VAL_PATH = '../input/valid/'

# Explore Your Data

In [None]:
# Number of non-null entries per column for each label, i.e. images per class.
counts = train.groupby('label').count()
# A bare expression on the last line of a notebook cell renders the frame.
counts

In [None]:
# Histogram of the label column; the bin edges run 0..(number of rows in
# label_map), which gives one bin per label id.
train.label.hist(bins=range(label_map.shape[0]+1))
plt.grid(False)
plt.xlabel('Cassava Leaf Disease Classification Labels')
plt.ylabel('Label Frequency')
plt.title('Cassava Leaf Disease Classification Class Distribution')

The classes are highly imbalanced, so I apply StratifiedKFold for model selection. Refer to class CFG for the configuration (chosen parameters).

In [None]:
# Read the first training image and print its minimum and maximum pixel value
# to see the value range the raw images use.
file_name = train['image_id'][0]
file_path = f'{TRAIN_PATH}/{file_name}'
image = plt.imread(file_path)
print('Min: {}, Max: {}'.format(image.min(), image.max()))

In [None]:
# show 30 sample images from each class
# Split the training frame by label. Each subset must be filtered on its own
# label id `i`; filtering on a constant would show the same class five times
# under five different titles.
classes = [train[train['label'] == i] for i in range(5)]

for i, df in enumerate(classes):
    # Draw 30 distinct row positions at random from this class.
    indexes = np.random.choice(df.shape[0], size=30, replace=False)
    file_names = df['image_id'].values
    plt.figure(figsize = (24, 18))
    plt.suptitle('Label: %d' %i, y = 0.9)
    # 5 x 6 grid -> 30 thumbnails per figure.
    for j, idx in enumerate(indexes):
        plt.subplot(5, 6, j+1)
        file_name = file_names[idx]
        file_path = f'{TRAIN_PATH}/{file_name}'
        image = plt.imread(file_path)
        plt.imshow(image)

After running the above cells a few times, it can be seen that the training set is fairly noisy. Some of the images of healthy leaves (class 4) appear diseased: the leaves are withered and/or show a yellowish color with brown spots. Some images do not show leaves at all.

# CFG
Class CFG contains configuration parameters.

In [None]:
# ====================================================
# CFG
# ====================================================
class CFG:
    """Central configuration read by the training and inference code."""

    # --- logging / data loading -------------------------------------------
    print_freq = 100
    num_workers = 4
    # please change size to 224 when using vit or deit models
    size = 256

    # --- model / optimisation choices -------------------------------------
    # one of: 'resnext50_32x4d', 'tf_efficientnet_b3_ns',
    #         'vit_small_patch16_224', 'vit_base_patch16_224', 'deit_base_patch16_224'
    model_name = 'resnext50_32x4d'
    # one of: 'Adam', 'AdamW', 'AdamP', 'Ranger'
    optimizer = 'Adam'
    # one of: 'ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts'
    scheduler = 'CosineAnnealingWarmRestarts'
    # one of: 'CrossEntropyLoss', 'SymmetricCrossEntropyLoss', 'LabelSmoothingLoss',
    #         'FocalLoss', 'FocalCosineLoss', 'TaylorCrossEntropyLoss'
    criterion = 'LabelSmoothingLoss'
    epochs = 20

    # --- scheduler parameters ----------------------------------------------
    factor = 0.2    # ReduceLROnPlateau
    patience = 4    # ReduceLROnPlateau
    eps = 1e-6      # ReduceLROnPlateau
    T_max = 10      # CosineAnnealingLR
    T_0 = 10        # CosineAnnealingWarmRestarts (number of scheduler steps before the first restart)

    # --- loss / optimiser hyper-parameters ----------------------------------
    smoothing = 0.5
    lr = 1e-4
    min_lr = 1e-6
    batch_size = 32
    weight_decay = 1e-6
    max_grad_norm = 1000

    # --- reproducibility / cross-validation ---------------------------------
    seed = 42
    n_fold = 5
    trn_fold = [0, 1, 2, 3, 4]

    # --- run mode -------------------------------------------------------------
    train = False
    inference = True

    # --- second (EfficientNet) ensemble member: folds to load and timm name ---
    enet_weights = [0, 1, 2, 3, 4]
    enet_name = 'tf_efficientnet_b3_ns'

# Installation

In [None]:
# !pip install adamp
# !pip install pytorch-ranger

# Library

In [None]:
# ====================================================
# Library
# ====================================================

import math
import time
import random
from contextlib import contextmanager

from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold

from tqdm.auto import tqdm

import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam, AdamW
# from adamp import AdamP
# from pytorch_ranger import Ranger

from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau

from albumentations import (
    Compose, Normalize, Resize, RandomResizedCrop, HorizontalFlip, VerticalFlip, ShiftScaleRotate, Transpose
    )
from albumentations.pytorch import ToTensorV2

import sys
sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')
import timm

# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Utils

In [None]:
# ====================================================
# Utils
# ====================================================
def get_score(y_true, y_pred):
    """Return the accuracy of the predicted labels ``y_pred`` against ``y_true``."""
    score = accuracy_score(y_true, y_pred)
    return score


@contextmanager
def timer(name):
    """Log when the wrapped block starts and how long it took.

    ``name`` labels both log lines. The closing line is only written when the
    block finishes without raising.
    """
    started_at = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    elapsed = time.time() - started_at
    LOGGER.info(f'[{name}] done in {elapsed:.0f} s.')


def init_logger(log_file=OUTPUT_DIR+'train.log'):
    """Create the module logger that writes plain messages to the console and a file.

    Args:
        log_file: Path of the log file that receives a copy of every message.

    Returns:
        The configured ``logging.Logger`` (level INFO).
    """
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # getLogger() returns the same object on every call, so re-running this
    # cell would stack another pair of handlers and print each message several
    # times. Detach (and close) the handlers of earlier calls first.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = init_logger()


def seed_torch(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and current CUDA device) with ``seed``.

    Also exports the seed as ``PYTHONHASHSEED`` and switches cuDNN to its
    deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Every random number source used in this file gets the same seed.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(seed=CFG.seed)

# Stratified K Folds

In [None]:
# Give every training row a 'fold' id: the index of the stratified split in
# which that row is part of the validation subset.
folds = train.copy()
Fold = StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
for n, (train_index, val_index) in enumerate(Fold.split(folds, folds['label'])):
    # split() yields positional indices while .loc is label based; the two agree
    # as long as `folds` keeps a default 0..N-1 index (assumed - confirm if the
    # frame is ever filtered or re-indexed before this cell).
    folds.loc[val_index, 'fold'] = int(n)
# The column is filled piecewise inside the loop, so cast it to int once complete.
folds['fold'] = folds['fold'].astype(int)
display(folds.head())

In [None]:
print(folds.groupby(['fold', 'label']).size())

# Dataset

In [None]:
# ====================================================
# Dataset
# ====================================================
# Both datasets subclass torch.utils.data.Dataset and return one sample per index.
# A DataLoader fetches samples from a Dataset and serves them in batches; with worker processes it loads and pre-fetches in parallel to reduce data-loading time.
class TrainDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs read from ``TRAIN_PATH``.

    Args:
        df: DataFrame with an ``image_id`` column (file names inside
            ``TRAIN_PATH``) and a ``label`` column.
        transform: Optional callable invoked as ``transform(image=image)`` that
            returns a mapping with an ``'image'`` entry. With ``None`` the
            RGB array is returned unchanged.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['image_id'].values
        self.labels = df['label'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image and its label as a long tensor.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        file_name = self.file_names[idx]
        file_path = f'{TRAIN_PATH}/{file_name}'
        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None instead of raising; fail
        # here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        label = torch.tensor(self.labels[idx]).long()
        return image, label
    

class TestDataset(Dataset):
    """Dataset yielding unlabeled images for inference.

    Args:
        df: DataFrame with an ``image_id`` column of file names.
        transform: Optional callable invoked as ``transform(image=image)`` that
            returns a mapping with an ``'image'`` entry.
        image_dir: Folder containing the images. Defaults to ``TEST_PATH``
            (test submission); pass ``TRAIN_PATH`` to score a validation frame
            whose images live in the training folder.
    """

    def __init__(self, df, transform=None, image_dir=None):
        self.df = df
        self.file_names = df['image_id'].values
        self.transform = transform
        self.image_dir = TEST_PATH if image_dir is None else image_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image at position ``idx``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        file_name = self.file_names[idx]
        file_path = f'{self.image_dir}/{file_name}'
        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image

In [None]:
# Sanity check of the dataset without any transform: show the first sample
# (an RGB array) together with its label.
train_dataset = TrainDataset(train, transform=None)

for i in range(1):
    image, label = train_dataset[i]
    plt.imshow(image)
    plt.title(f'label: {label}')
    plt.show() 

# Data Augmentation and Normalization

In [None]:
# ====================================================
# Transforms
# ====================================================
def get_transforms(*, data):
    """Build the albumentations pipeline for one dataset split.

    Args:
        data: ``'train'`` for the augmenting pipeline (random resized crop,
            transpose, flips, shift/scale/rotate) or ``'valid'`` for the
            deterministic resize-only pipeline. Both end with the same
            normalization followed by ``ToTensorV2``.

    Returns:
        An ``albumentations.Compose`` instance.

    Raises:
        ValueError: If ``data`` is neither ``'train'`` nor ``'valid'``.
    """
    def standardize():
        # Normalization shared by both pipelines, then conversion to a torch tensor.
        return [
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ]

    if data == 'train':
        return Compose([
            RandomResizedCrop(CFG.size, CFG.size),
            Transpose(p=0.5),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            ShiftScaleRotate(p=0.5),
            *standardize(),
        ])

    if data == 'valid':
        return Compose([
            Resize(CFG.size, CFG.size),
            *standardize(),
        ])

    # Falling through would return None, which the datasets interpret as
    # "no transform" and would silently feed raw images to the model.
    raise ValueError(f"Unknown data split {data!r}; expected 'train' or 'valid'")

In [None]:
# Sanity check of the training pipeline: show the first sample after augmentation.
train_dataset = TrainDataset(train, transform=get_transforms(data='train'))

for i in range(1):
    image, label = train_dataset[i]
    # The transform ends with ToTensorV2, so `image` is a tensor here; image[0]
    # plots a single slice of it (presumably the first colour channel of a
    # channel-first tensor - confirm).
    plt.imshow(image[0])
    plt.title(f'label: {label}')
    plt.show() 

# MODEL

In [None]:
# ====================================================
# MODEL
# ====================================================
# nn.Module serves as base class for all neural network modules
class CustomEfficientNet(nn.Module):
    """timm backbone whose ``classifier`` layer is replaced by a 5-output linear head.

    Args:
        model: timm model name handed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model=CFG.model_name, pretrained=True):
        super().__init__()
        backbone = timm.create_model(model, pretrained=pretrained)
        # Keep the classifier's input width, swap in a fresh layer with 5 outputs.
        backbone.classifier = nn.Linear(backbone.classifier.in_features, 5)
        self.model = backbone

    def forward(self, x):
        """Return the backbone output for the batch ``x``."""
        return self.model(x)
    
    
class CustomResNext(nn.Module):
    """timm backbone whose ``fc`` layer is replaced by a 5-output linear head.

    Args:
        model: timm model name handed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model=CFG.model_name, pretrained=True):
        super().__init__()
        backbone = timm.create_model(model, pretrained=pretrained)
        # Keep the final layer's input width, swap in a fresh layer with 5 outputs.
        backbone.fc = nn.Linear(backbone.fc.in_features, 5)
        self.model = backbone

    def forward(self, x):
        """Return the backbone output for the batch ``x``."""
        return self.model(x)

    
class CustomDeit(nn.Module):
    """DeiT model from torch.hub with its ``head`` replaced by a 5-output linear layer.

    Args:
        model_name: Entry point name in the ``facebookresearch/deit:main`` hub repo.
        pretrained: Forwarded to the hub entry point.
    """

    def __init__(self, model_name=CFG.model_name, pretrained=False):
        super().__init__()
        # Honour the caller's `pretrained` flag; it used to be ignored because
        # False was hard-coded in this call.
        self.model = torch.hub.load('facebookresearch/deit:main', model_name, pretrained=pretrained)
        self.n_features = self.model.head.in_features
        self.model.head = nn.Linear(self.n_features, 5)

    def forward(self, x):
        """Return the model output for the batch ``x``."""
        x = self.model(x)
        return x

class CustomVit(nn.Module):
    """timm Vision Transformer whose ``head`` is replaced by a 5-output linear layer.

    Args:
        model_name: timm model name handed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name=CFG.model_name, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        n_features = self.model.head.in_features
        # The classification layer is `head` (its in_features are read just
        # above). Assigning the new layer to `fc` only added an unused attribute
        # and left the original head in place.
        self.model.head = nn.Linear(n_features, 5)

    def forward(self, x):
        """Return the model output for the batch ``x``."""
        x = self.model(x)
        return x

# Loss Function

In [None]:
class SymmetricCrossEntropy(nn.Module):
    """Weighted sum ``alpha * ce + beta * rce`` of two cross-entropy style terms.

    Args:
        classes: Number of classes used for the one-hot encoding of the targets.
        alpha: Weight of the standard cross-entropy term.
        beta: Weight of the second ("reverse") term.
    """

    def __init__(self, classes=5, alpha=0.1, beta=1.0):
        super(SymmetricCrossEntropy, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.classes = classes

    def forward(self, logits, targets, reduction='mean'):
        """Return the combined loss.

        Args:
            logits: Raw class scores; softmax is taken over dim 1.
            targets: Integer class indices.
            reduction: ``'mean'``, ``'sum'`` or anything else for per-sample values.
        """
        # Build the one-hot targets on the same device/dtype as the logits
        # instead of forcing CUDA, so the loss also works on CPU.
        onehot_targets = F.one_hot(targets, num_classes=self.classes).to(logits.dtype)
        ce_loss = F.cross_entropy(logits, targets, reduction=reduction)
        # NOTE(review): this term is -sum(onehot * log(softmax)), i.e. cross-entropy
        # again (up to the clamp). A true reverse cross-entropy would swap the roles
        # of prediction and target. Numerics are kept as they were - confirm intent.
        rce_loss = (-onehot_targets*logits.softmax(1).clamp(1e-7, 1.0).log()).sum(1)
        if reduction == 'mean':
            rce_loss = rce_loss.mean()
        elif reduction == 'sum':
            rce_loss = rce_loss.sum()
        return self.alpha * ce_loss + self.beta * rce_loss

In [None]:
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a smoothed target distribution.

    The true class receives probability ``1 - smoothing``; the remaining
    ``smoothing`` mass is shared equally by the other ``classes - 1`` classes.

    Args:
        classes: Number of classes.
        smoothing: Probability mass moved away from the true class.
        dim: Dimension over which log-softmax and the class sum are taken.
    """

    def __init__(self, classes=5, smoothing=0.0, dim=-1):
        super().__init__()
        self.classes = classes
        self.dim = dim
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, pred, target):
        """Return the batch-mean smoothed cross-entropy of ``pred`` w.r.t. ``target``."""
        log_probs = F.log_softmax(pred, dim=self.dim)
        with torch.no_grad():
            # Start from the off-target value everywhere, then write the
            # confidence into the column of the true class.
            off_value = self.smoothing / (self.classes - 1)
            soft_targets = torch.full_like(log_probs, off_value)
            soft_targets.scatter_(1, target.data.unsqueeze(1), self.confidence)
        per_sample = torch.sum(-soft_targets * log_probs, dim=self.dim)
        return per_sample.mean()

In [None]:
class FocalLoss(nn.Module):
    """Focal loss: cross-entropy scaled per sample by ``alpha * (1 - p_t) ** gamma``.

    Args:
        alpha: Constant scale factor.
        gamma: Focusing exponent; larger values down-weight well classified samples more.
        reduce: If true return the batch mean, otherwise the per-sample losses.
    """

    def __init__(self, alpha=1, gamma=5, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduce = reduce

    def forward(self, inputs, targets):
        """Return the focal loss of the logits ``inputs`` for class indices ``targets``."""
        # Keep per-sample losses. With the default mean reduction the focal
        # weight would be computed once from the batch average, which removes
        # the per-sample re-weighting.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # exp(-CE) is the probability assigned to the true class.
        pt = torch.exp(-ce_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * ce_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

In [None]:
class FocalCosineLoss(nn.Module):
    """Cosine embedding loss plus ``xent`` times a focal loss on normalized logits.

    Args:
        alpha: Scale factor of the focal term.
        gamma: Focusing exponent of the focal term.
        xent: Weight of the focal term in the total loss.
    """

    def __init__(self, alpha=1, gamma=2, xent=.1):
        super(FocalCosineLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.xent = xent
        # Target value 1 ("similar") for the cosine embedding loss. Kept on the
        # CPU here and moved to the input's device in forward(), so constructing
        # the loss no longer requires CUDA.
        self.y = torch.Tensor([1])

    def forward(self, input, target, reduction="mean"):
        """Return the combined loss.

        Args:
            input: Logits with the class dimension last.
            target: Integer class indices.
            reduction: ``"mean"``, ``"sum"`` or anything else for an unreduced focal term.
        """
        onehot = F.one_hot(target, num_classes=input.size(-1)).to(input.dtype)
        # One "similar" target per sample, on the same device/dtype as the logits.
        y = self.y.to(device=input.device, dtype=input.dtype).expand(input.size(0))
        cosine_loss = F.cosine_embedding_loss(input, onehot, y, reduction=reduction)
        # reduction='none' replaces the deprecated reduce=False.
        cent_loss = F.cross_entropy(F.normalize(input), target, reduction='none')
        pt = torch.exp(-cent_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * cent_loss
        if reduction == "mean":
            focal_loss = torch.mean(focal_loss)
        elif reduction == "sum":
            # Keep the focal term consistent with the summed cosine term.
            focal_loss = torch.sum(focal_loss)
        return cosine_loss + self.xent * focal_loss

In [None]:
class TaylorSoftmax(nn.Module):
    '''
    Softmax variant that replaces exp(x) by its Taylor polynomial of even
    order n (autograd version).
    '''

    def __init__(self, dim=1, n=2):
        super().__init__()
        assert n % 2 == 0
        self.n = n
        self.dim = dim

    def forward(self, x):
        '''
        usage similar to nn.Softmax:
            >>> mod = TaylorSoftmax(dim=1, n=4)
            >>> inten = torch.randn(1, 32, 64, 64)
            >>> out = mod(inten)
        '''
        # Accumulate 1 + x + x^2/2! + ... + x^n/n!
        polynomial = torch.ones_like(x)
        factorial = 1.
        for order in range(1, self.n+1):
            factorial *= order
            polynomial = polynomial + x.pow(order) / factorial
        # Normalize along `dim` so the values sum to one.
        normalizer = polynomial.sum(dim=self.dim, keepdims=True)
        return polynomial / normalizer


In [None]:
class TaylorCrossEntropyLoss(nn.Module):
    """Label-smoothed loss computed on Taylor-softmax probabilities.

    Args:
        n: Even order of the Taylor polynomial used by ``TaylorSoftmax``.
        ignore_index: Stored on the module; not used by ``forward``.
        reduction: Stored on the module; not used by ``forward``.
        smoothing: Smoothing factor handed to ``LabelSmoothingLoss`` (5 classes).
    """

    def __init__(self, n=2, ignore_index=-1, reduction='mean', smoothing=0.0):
        super().__init__()
        assert n % 2 == 0
        self.taylor_softmax = TaylorSoftmax(dim=1, n=n)
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.lab_smooth = LabelSmoothingLoss(5, smoothing=smoothing)

    def forward(self, logits, labels):
        """Return the smoothed loss of ``logits`` for the class indices ``labels``."""
        taylor_probs = self.taylor_softmax(logits)
        # The log-probabilities are passed to the label-smoothing criterion.
        return self.lab_smooth(taylor_probs.log(), labels)

# Helper functions

In [None]:
# ====================================================
# Helper functions
# ====================================================
class AverageMeter(object):
    """Tracks the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def asMinutes(s):
    """Format a duration of ``s`` seconds as ``'<minutes>m <seconds>s'``."""
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '%dm %ds' % (minutes, seconds)


def timeSince(since, percent):
    """Return ``'<elapsed> (remain <remaining>)'`` for a task started at ``since``.

    Args:
        since: Start time as returned by ``time.time()``.
        percent: Fraction of the work done so far; must be non-zero.
    """
    elapsed = time.time() - since
    # Extrapolate the total duration from the completed fraction.
    estimated_total = elapsed / (percent)
    remaining = estimated_total - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))


def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """
    Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` batches.
        model: Module being trained; expected to already live on ``device``.
        criterion: Loss called as ``criterion(y_preds, labels)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Zero-based epoch number, used only in the progress print.
        scheduler: Only queried for the current learning rate; it is not
            stepped inside this function.
        device: Device the batches are moved to.

    Returns:
        The sample-weighted average training loss of the epoch.
    """ 
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    
    # switch to train mode
    model.train()
    start = end = time.time()
    # counts optimizer steps; not read anywhere in this function
    global_step = 0
    
    for step, (images, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        
        # move data to the target device
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        
        # forward pass: compute predicted y by passing x to the model. 
        # Module objects override the __call__ operator so they can be called like functions
        y_preds = model(images)
        
        # the criterion receives the raw model outputs and the labels
        loss = criterion(y_preds, labels)
        
        # loss.item() returns the value of this tensor as a python float
        losses.update(loss.item(), batch_size)
        
        # backward pass: compute gradient of the loss wrt all learnable parameters
        loss.backward()
            
        # clip the total gradient norm; the returned norm is only used for logging
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        
        # calling the step function updates the parameters
        optimizer.step()
        # zero the gradients before the backward pass of the next batch,
        # because pytorch accumulates gradients across backward passes
        optimizer.zero_grad()
        global_step += 1
        
        # ReduceLROnPlateau is read through the optimizer's param group instead of get_last_lr()
        if CFG.scheduler == 'ReduceLROnPlateau':
            curr_lr = optimizer.param_groups[0]['lr']
        else:
            curr_lr = scheduler.get_last_lr()[0]
        
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # print progress every CFG.print_freq steps and on the last step
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Batch (curr/avg): ({batch_time.val:.3f}s/{batch_time.avg:.3f}s) '
                  'Elapsed {remain:s} '
                  'Loss (curr/avg): ({loss.val:.4f}/{loss.avg:.4f}) '
                  'Grad: {grad_norm:.4f}  '
                  'LR: {lr:.6f}  '
                  .format(
                      epoch+1, step, len(train_loader), 
                      batch_time=batch_time,
                      remain=timeSince(start, float(step+1)/len(train_loader)),
                      loss=losses,
                      grad_norm=grad_norm,
                      lr=curr_lr
                   ))
    return losses.avg


def valid_fn(valid_loader, model, criterion, device):
    """
    Evaluate ``model`` on ``valid_loader``.

    Args:
        valid_loader: Iterable yielding ``(images, labels)`` batches.
        model: Module to evaluate; expected to already live on ``device``.
        criterion: Loss called as ``criterion(y_preds, labels)``.
        device: Device the batches are moved to.

    Returns:
        A tuple ``(avg_loss, predictions)``: the sample-weighted average loss
        and a numpy array with the softmax outputs of all batches concatenated
        in loader order.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    
    # switch to evaluation mode
    model.eval()
    preds = []
    start = end = time.time()
    
    for step, (images, labels) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)

        # don't compute gradients for the validation set (saves memory)
        with torch.no_grad():
            y_preds = model(images)
        loss = criterion(y_preds, labels)
        losses.update(loss.item(), batch_size)
        
        # store class probabilities (softmax over dim 1) on the CPU
        preds.append(y_preds.softmax(1).detach().to('cpu').numpy())
            
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # print progress every CFG.print_freq steps and on the last step
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            print('EVAL: [{0}/{1}] '
                  'Batch (curr/avg): ({batch_time.val:.3f}s/{batch_time.avg:.3f}s) '
                  'Elapsed {remain:s} '
                  'Loss (curr/avg): ({loss.val:.4f}/{loss.avg:.4f}) '
                  .format(
                      step, len(valid_loader),
                      batch_time=batch_time,
                      remain=timeSince(start, float(step+1)/len(valid_loader)),
                      loss=losses
                   ))
    predictions = np.concatenate(preds)
    return losses.avg, predictions


def get_loaders(folds, fold):
    """
    Split ``folds`` on its ``fold`` column and wrap both parts in data loaders.

    Rows whose ``fold`` equals ``fold`` form the validation part, all other
    rows the training part.

    Returns:
        ``(train_folds, train_loader, valid_folds, valid_loader)``
    """
    held_out_idx = folds[folds['fold'] == fold].index
    fit_idx = folds[folds['fold'] != fold].index

    train_folds = folds.loc[fit_idx].reset_index(drop=True)
    valid_folds = folds.loc[held_out_idx].reset_index(drop=True)

    # Options common to both loaders; pin_memory speeds up host-to-GPU copies.
    shared = dict(batch_size=CFG.batch_size,
                  num_workers=CFG.num_workers,
                  pin_memory=True)

    # Training: augmented, shuffled, incomplete last batch discarded.
    train_loader = DataLoader(
        TrainDataset(train_folds, transform=get_transforms(data='train')),
        shuffle=True, drop_last=True, **shared)
    # Validation: deterministic transform, original order, every sample kept.
    valid_loader = DataLoader(
        TrainDataset(valid_folds, transform=get_transforms(data='valid')),
        shuffle=False, drop_last=False, **shared)
    return train_folds, train_loader, valid_folds, valid_loader
    

def get_model(pretrained=True):
    """
    Build the model wrapper selected by ``CFG.model_name``.

    Args:
        pretrained: Forwarded to the wrapper's constructor.

    Raises:
        ValueError: If ``CFG.model_name`` is not one of the supported names.
    """
    if CFG.model_name=='resnext50_32x4d':
        model = CustomResNext(pretrained=pretrained)
    elif  CFG.model_name=='tf_efficientnet_b3_ns':
        model = CustomEfficientNet(pretrained=pretrained)
    elif CFG.model_name=='deit_base_patch16_224':
        model = CustomDeit(pretrained=pretrained)
    elif CFG.model_name=='vit_base_patch16_224' or CFG.model_name=='vit_small_patch16_224':
        # ViT names belong to the timm-based wrapper, not the DeiT hub wrapper.
        model = CustomVit(pretrained=pretrained)
    else:
        raise ValueError(f'Unsupported CFG.model_name: {CFG.model_name!r}')
    return model


def get_optimizer(model_params):
    """
    Build the optimizer selected by ``CFG.optimizer``.

    Args:
        model_params: Parameters to optimize (e.g. ``model.parameters()``).

    Returns:
        The optimizer, configured with ``CFG.lr`` and ``CFG.weight_decay``.

    Raises:
        ImportError: If 'AdamP' or 'Ranger' is requested but the class has not
            been imported into this module.
        ValueError: If ``CFG.optimizer`` is not a supported name.
    """
    name = CFG.optimizer
    settings = dict(lr=CFG.lr, weight_decay=CFG.weight_decay)
    if name == 'Adam':
        return Adam(model_params, **settings)
    if name == 'AdamW':
        # Previously this branch silently built plain Adam.
        return AdamW(model_params, **settings)
    if name in ('AdamP', 'Ranger'):
        # These classes come from optional packages whose imports are commented
        # out in the Library cell; use them only when they have been imported,
        # rather than substituting Adam or failing with a NameError.
        optimizer_cls = globals().get(name)
        if optimizer_cls is None:
            raise ImportError(
                f"CFG.optimizer={name!r} requires the {name} class; install the "
                f"package and enable its import in the Library cell")
        return optimizer_cls(model_params, **settings)
    raise ValueError(f'Unsupported CFG.optimizer: {name!r}')

    
def get_scheduler(optimizer):
    """
    Build the learning-rate scheduler selected by ``CFG.scheduler``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.

    Raises:
        ValueError: If ``CFG.scheduler`` is not a supported name.
    """
    if CFG.scheduler=='ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer, factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
    elif CFG.scheduler=='CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    elif CFG.scheduler=='CosineAnnealingWarmRestarts':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)
    else:
        # Without this branch an unknown name surfaced as UnboundLocalError.
        raise ValueError(f'Unsupported CFG.scheduler: {CFG.scheduler!r}')
    return scheduler


def get_criterion():
    """
    Build the loss function selected by ``CFG.criterion``.

    Raises:
        ValueError: If ``CFG.criterion`` is not a supported name.
    """
    if CFG.criterion=='CrossEntropyLoss':
        criterion = nn.CrossEntropyLoss()
    elif CFG.criterion=='SymmetricCrossEntropyLoss':
        criterion = SymmetricCrossEntropy()
    elif CFG.criterion=='LabelSmoothingLoss':
        criterion = LabelSmoothingLoss(smoothing=CFG.smoothing)
    elif CFG.criterion=='FocalLoss':
        # Listed as an option in CFG and defined in this file, but previously unreachable.
        criterion = FocalLoss()
    elif CFG.criterion=='FocalCosineLoss':
        criterion = FocalCosineLoss()
    elif CFG.criterion=='TaylorCrossEntropyLoss':
        criterion = TaylorCrossEntropyLoss(smoothing=CFG.smoothing)
    else:
        # Without this branch an unknown name surfaced as UnboundLocalError.
        raise ValueError(f'Unsupported CFG.criterion: {CFG.criterion!r}')
    return criterion
    
    
def inference(model, states, test_loader, device):
    """Predict class probabilities, averaged over several checkpoints.

    Args:
        model: Network into which each checkpoint is loaded in turn.
        states: Checkpoints; each is a mapping whose ``'model'`` entry is a state dict.
        test_loader: Loader yielding image batches.
        device: Device used for the forward passes.

    Returns:
        Numpy array with one row of averaged softmax probabilities per image.
    """
    model.to(device)
    progress = tqdm(enumerate(test_loader), total=len(test_loader))
    batch_probs = []
    for _, images in progress:
        images = images.to(device)
        per_state = []
        # Score the same batch with every checkpoint.
        for state in states:
            model.load_state_dict(state['model'])
            model.eval()
            with torch.no_grad():
                logits = model(images)
            per_state.append(logits.softmax(1).to('cpu').numpy())
        # Average the probabilities across checkpoints.
        batch_probs.append(np.mean(per_state, axis=0))
    return np.concatenate(batch_probs)


def inference_tta(model, states, test_loader, device):
    """Predict class probabilities with 8-fold test-time augmentation.

    Every batch is scored in eight variants (identity, the three flips, and the
    transposed image with the same four flip settings). The model outputs are
    averaged over the variants, passed through softmax, and the probabilities
    are then averaged over the checkpoints.

    Works for any batch size, image size and number of classes. The images
    must be square, because transposed and untransposed views are stacked
    into one tensor.

    Args:
        model: Network into which each checkpoint is loaded in turn.
        states: Checkpoints; each is a mapping whose ``'model'`` entry is a state dict.
        test_loader: Loader yielding image batches.
        device: Device used for the forward passes.

    Returns:
        Numpy array with one row of averaged probabilities per image.
    """
    model.to(device)
    tk0 = tqdm(enumerate(test_loader), total=len(test_loader))
    probs = []
    for i, (images) in tk0:
        x = images.to(device)
        xt = x.transpose(-1, -2)
        # Shape (8, batch, C, H, W): one slice per augmented view.
        views = torch.stack([x, x.flip(-1), x.flip(-2), x.flip(-1,-2),
                             xt, xt.flip(-1),
                             xt.flip(-2), xt.flip(-1,-2)], 0)
        n_views, batch_size = views.shape[:2]
        # Fold the view axis into the batch axis for a single forward pass.
        flat_views = views.reshape(n_views * batch_size, *views.shape[2:])
        avg_preds = []
        for state in states:
            model.load_state_dict(state['model'])
            model.eval()
            with torch.no_grad():
                y_preds = model(flat_views)
            # Unfold to (views, batch, classes) and average over the views; the
            # previous hard-coded view(1, 8, 5) only worked for batch size 1.
            y_preds = y_preds.reshape(n_views, batch_size, -1).mean(dim=0)
            avg_preds.append(y_preds.softmax(1).to('cpu').numpy())
        avg_preds = np.mean(avg_preds, axis=0)
        probs.append(avg_preds)
    probs = np.concatenate(probs)
    return probs

# Train loop

In [None]:
# ====================================================
# Train loop
# ====================================================
def train_loop(folds, fold):
    """Train one cross-validation fold and return its out-of-fold predictions.

    The checkpoint with the best validation accuracy is saved to
    ``OUTPUT_DIR`` and read back after the last epoch.

    Args:
        folds: Training frame with a ``fold`` column.
        fold: Id of the fold used for validation.

    Returns:
        The validation frame extended with one probability column per class
        (``'0'``..``'4'``) and a ``preds`` column holding the arg-max class.
    """
    LOGGER.info(f"========== fold: {fold} training ==========")

    # ---- data ----
    train_folds, train_loader, valid_folds, valid_loader = get_loaders(folds, fold)

    # ---- model, optimizer, LR schedule, loss ----
    model = get_model(pretrained=True)
    model.to(device)

    optimizer = get_optimizer(model.parameters())
    scheduler = get_scheduler(optimizer)
    criterion = get_criterion()

    # ---- epochs ----
    checkpoint_path = OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.pth'
    best_score = 0.

    for epoch in range(CFG.epochs):
        epoch_started = time.time()

        avg_train_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device)
        avg_val_loss, val_preds = valid_fn(valid_loader, model, criterion, device)
        valid_labels = valid_folds['label'].values

        # ReduceLROnPlateau needs the monitored value; the cosine schedulers do not.
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, (CosineAnnealingLR, CosineAnnealingWarmRestarts)):
            scheduler.step()

        val_score = get_score(valid_labels, val_preds.argmax(1))
        elapsed = time.time() - epoch_started

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_train_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - val_accuracy: {val_score:.4f}')

        # Keep only the checkpoint with the highest validation accuracy so far.
        if val_score > best_score:
            best_score = val_score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'model': model.state_dict(),
                        'preds': val_preds},
                       checkpoint_path)

    # Attach the predictions stored with the best checkpoint to the validation frame.
    check_point = torch.load(checkpoint_path)
    best_preds = check_point['preds']
    valid_folds[[str(c) for c in range(5)]] = best_preds
    valid_folds['preds'] = best_preds.argmax(1)

    return valid_folds

In [None]:
# ====================================================
# main
# ====================================================
def main():

    """
    Run training and/or inference as selected by ``CFG.train`` / ``CFG.inference``.

    Training: trains every fold listed in ``CFG.trn_fold``, logs the per-fold
    and overall accuracy and writes the out-of-fold predictions to
    ``oof_df.csv`` in ``OUTPUT_DIR``.

    Inference: averages the fold checkpoints of the ResNeXt model (plain
    inference) and of the EfficientNet model (test-time augmentation), blends
    both with equal weight and writes ``submission.csv`` to ``OUTPUT_DIR``.
    The global ``test`` frame gets a ``label`` column as a side effect.
    """

    def get_result(result_df):
        # Log the accuracy of the 'preds' column against the 'label' column.
        preds = result_df['preds'].values
        labels = result_df['label'].values
        score = get_score(labels, preds)
        LOGGER.info(f'Score: {score:<.5f}')
    
    if CFG.train:
        # train 
        oof_df = pd.DataFrame()
        for fold in range(CFG.n_fold):
            if fold in CFG.trn_fold:
                _oof_df = train_loop(folds, fold)
                oof_df = pd.concat([oof_df, _oof_df])
                LOGGER.info(f"========== fold: {fold} result ==========")
                get_result(_oof_df)
        # CV result
        LOGGER.info(f"========== CV ==========")
        get_result(oof_df)
        # save result
        oof_df.to_csv(OUTPUT_DIR+'oof_df.csv', index=False)
    
    if CFG.inference:
        # validation set
#         test_dataset = TestDataset(val, transform=get_transforms(data='valid'))
        
        # test submission
        test_dataset = TestDataset(test, transform=get_transforms(data='valid'))

        # batch_size=1: one image per batch
        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, 
                                 num_workers=CFG.num_workers, pin_memory=True)
 
        # ResNext: checkpoints of the folds in CFG.trn_fold, read from OUTPUT_PATH
        model = CustomResNext(pretrained=False)
        states = [torch.load(OUTPUT_PATH+f'{CFG.model_name}_fold{fold}_best.pth') for fold in CFG.trn_fold]
        predictions = inference(model, states, test_loader, device)
        
        # EfficientNet with TTA: checkpoints of the folds in CFG.enet_weights
        enet = CustomEfficientNet(model=CFG.enet_name, pretrained=False)
        enet_states = [torch.load(OUTPUT_PATH+f'{CFG.enet_name}_fold{fold}_best.pth') for fold in CFG.enet_weights]
#         predictions_no_tta = inference(enet, enet_states, test_loader, device)
        predictions2 = inference_tta(enet, enet_states, test_loader, device)

        # equal-weight blend of the two models' probabilities
        pred = 0.5 * predictions + 0.5 * predictions2
        # validation-mode scoring (enable together with the validation dataset above)
#         valid_labels = val['label'].values
#         ent_score = get_score(valid_labels, predictions_no_tta.argmax(1))
#         print('EfficientNet accuracy: {:.4f}'.format(ent_score))
#         ent_score_tta = get_score(valid_labels, predictions2.argmax(1))
#         print('EfficientNet with TTA accuracy: {:.4f}'.format(ent_score_tta))
#         resnext_score = get_score(valid_labels, predictions.argmax(1))
#         print('ResNext accuracy: {:.4f}'.format(resnext_score))
#         ensemble_score = get_score(valid_labels, pred.argmax(1))
#         print('Ensemble accuracy: {:.4f}'.format(ensemble_score))
        
        # submission: the predicted class is the arg-max of the blended probabilities
        test['label'] = pred.argmax(1)
        print(test[['image_id', 'label']])
        test[['image_id', 'label']].to_csv(OUTPUT_DIR+'submission.csv', index=False)

In [None]:
if __name__ == '__main__':
    main()