In [None]:
# import os
# from google.colab import drive

# os.mkdir("/content/input")
# os.chdir("/content/input")

# drive.mount("/content/drive")

# !cp "/content/drive/MyDrive/Colab Notebooks/datasets/cassava-leaf-disease-classification.zip" .
# !unzip cassava-leaf-disease-classification.zip

In [None]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import cv2
import matplotlib.pyplot as plt
import pandas as pd
import json
import numpy as np
import PIL.Image as Image
import time
import seaborn as sns

from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split, StratifiedKFold

### Steps:
- Build dataset loader
- Visualize some transformed images
- Define model + choose pretrained backbone
- Define loss + optimization function
- Train
- Test

### Notes:
- 15 epochs are enough for one experiment
- In an experiment, we can:
  + Change the initialization method
  + Change the backbone
  + Change the augmentation
  + Apply various methods to prevent overfitting

In [None]:
# Load the training metadata; later cells read its "image_id" and "label" columns.
df = pd.read_csv("../input/cassava-leaf-disease-classification/train.csv")
df.head(5)

In [None]:
# Image ids to exclude from the dataset; an empty list keeps every row.
ignored_images = []

# Drop the ignored rows and rebuild a contiguous 0..n-1 index. Without the reset,
# removing rows leaves gaps in the index, and any later code that mixes positional
# indices (e.g. the ones produced by StratifiedKFold) with label-based lookup would
# select the wrong rows or raise KeyError.
df = df.loc[~df["image_id"].isin(ignored_images)].reset_index(drop=True)

In [None]:
# Read the {label id (as a string) -> disease name} mapping. The context manager
# closes the file handle deterministically instead of leaving it to the garbage collector.
with open("../input/cassava-leaf-disease-classification/label_num_to_disease_map.json", "r") as label_map_file:
    labels = json.load(label_map_file)
# Disease names in the order they appear in the mapping; its length is the number of classes.
labels_name = list(labels.values())
labels

In [None]:
# Short tick label per class: the last space-separated token of each disease name.
shortcut_labels = [disease_name.split(" ")[-1] for disease_name in labels.values()]
# Number of rows per label value, ordered by label value.
label_counts_sorted = df["label"].value_counts().sort_index()
labels_value = label_counts_sorted.values

In [None]:
# Bar chart of the number of training rows per class.
indices = np.arange(len(shortcut_labels))
fig, ax = plt.subplots()
ax.bar(indices, labels_value)
ax.set_xticks(indices)
ax.set_xticklabels(shortcut_labels)
# Label the figure so it can be read on its own when the notebook is skimmed.
ax.set_xlabel("Class")
ax.set_ylabel("Number of images")
ax.set_title("Training label distribution")
plt.show()

In [None]:
class CassavaDataset(Dataset):
    """Map-style dataset yielding ``(image, label, image_id)`` for each dataframe row.

    Args:
        df: Frame with an ``image_id`` column (file name inside ``path``) and,
            optionally, a ``label`` column.
        path: Directory that contains the image files.
        transform: Optional callable invoked as ``transform(image=img)`` that returns
            a mapping with an ``"image"`` entry.
    """

    def __init__(self, df: pd.DataFrame, path, transform=None):
        self.df = df
        self.path = path
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the underlying dataframe."""
        return len(self.df.index)

    def __getitem__(self, i):
        """Load sample ``i`` (positional index).

        Returns:
            Tuple ``(img, label, image_id)``; ``img`` is the RGB image after the
            optional transform, ``label`` is 0 when the row has no ``label`` field.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        row = self.df.iloc[i]
        image_id = row["image_id"]
        # Rows without a label column get the placeholder label 0.
        label = row["label"] if "label" in row else 0

        image_path = f"{self.path}/{image_id}"
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread reports failure by returning None rather than raising; fail
            # here with the offending path instead of an opaque cvtColor error.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform:
            img = self.transform(image=img)["image"]

        return img, label, image_id

In [None]:
# Experiment configuration read by the data-loading, model and training cells below.
PARAMS = dict(
    # --- data loading ---
    TRAINING_BATCH_SIZE=8,
    VAL_BATCH_SIZE=2,
    IMG_SIZE=512,            # side length passed to the crop/resize transforms
    NUM_WORKERS=4,
    PIN_MEMORY=True,
    # --- optimisation ---
    EPOCHS=15,
    T_0=10,                  # T_0 argument of CosineAnnealingWarmRestarts
    LEARNING_RATE=1e-4,
    MIN_LEARNING_RATE=1e-6,  # eta_min argument of the scheduler
    # --- model ---
    MODEL="tf_efficientnet_b4_ns",
    WEIGHT_DECAY=0.0001,
    LS_EPSILON=0.2,          # epsilon of LabelSmoothingCrossEntropy
    EFF_SKIPPED_LAYERS=0,    # number of leading backbone blocks frozen in load_model
    # --- reproducibility / cross-validation ---
    RANDOM_SEED=719,
    TOTAL_FOLDS=5,
)

In [None]:
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)
from albumentations.pytorch import ToTensorV2

# define config here
# NOTE(review): the loaders visible in this notebook read PARAMS["PIN_MEMORY"], not this
# module-level flag — confirm whether anything else still uses it.
PIN_MEMORY = True
# Directory with the training images; CassavaDataset joins it with each image_id.
dataset_path = "../input/cassava-leaf-disease-classification/train_images"

# TODO: modify a little more
# Training-time pipeline: random crop to IMG_SIZE x IMG_SIZE, random transpose/flips
# (each applied with p=0.5), normalisation, coarse dropout, then conversion to a tensor.
# The commented-out entries are augmentations that are currently switched off.
training_transform = Compose([
#     Resize(PARAMS["IMG_SIZE"], PARAMS["IMG_SIZE"]),
    RandomResizedCrop(PARAMS["IMG_SIZE"], PARAMS["IMG_SIZE"]),
    Transpose(p=0.5),
    HorizontalFlip(p=0.5),
    VerticalFlip(p=0.5),
#     ShiftScaleRotate(p=0.5),
#     HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
#     RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    # Same mean/std as in val_transform below (presumably the ImageNet channel statistics — confirm).
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
    # NOTE(review): dropout runs after Normalize, so dropped regions are filled in
    # normalised space rather than in raw pixel space — confirm this is intended.
    CoarseDropout(p=0.5),
#     Cutout(p=0.5),
    ToTensorV2(p=1.0),
], p=1.)

# Validation-time pipeline: deterministic (no random augmentation), same normalisation.
# NOTE(review): Resize to IMG_SIZE directly after a CenterCrop of the same size looks
# like a no-op — confirm whether one of the two can be removed.
val_transform = Compose([
    CenterCrop(PARAMS["IMG_SIZE"], PARAMS["IMG_SIZE"], p=1.),
    Resize(PARAMS["IMG_SIZE"], PARAMS["IMG_SIZE"]),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
], p=1.)

In [None]:
# Hold out 20% of the rows for validation. The split is stratified on the label so that
# both parts keep the class proportions of the full frame, consistent with the
# StratifiedKFold split used by the cross-validation path.
training_df, val_df = train_test_split(
    df,
    test_size=0.2,
    random_state=PARAMS["RANDOM_SEED"],
    stratify=df["label"],
)

training_set = CassavaDataset(training_df, dataset_path, training_transform)
training_loader = DataLoader(
    dataset=training_set,
    batch_size=PARAMS["TRAINING_BATCH_SIZE"],
    num_workers=PARAMS["NUM_WORKERS"],
    pin_memory=PARAMS["PIN_MEMORY"],
    shuffle=True,
    drop_last=True
)

val_set = CassavaDataset(val_df, dataset_path, val_transform)
val_loader = DataLoader(
    dataset=val_set,
    batch_size=PARAMS["VAL_BATCH_SIZE"],
    num_workers=PARAMS["NUM_WORKERS"],
    pin_memory=PARAMS["PIN_MEMORY"],
    shuffle=False,
    # Keep the final partial batch: every validation sample must count towards the metrics.
    drop_last=False
)

In [None]:
# Preview one augmented training batch in a nrows x ncols grid.
nrows = 2
ncols = 5
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(22, 8))

# Must match the mean/std passed to Normalize in the transforms, otherwise the
# preview colours are wrong.
channel_mean = np.array([0.485, 0.456, 0.406])
channel_std = np.array([0.229, 0.224, 0.225])

# A single batch is enough for the preview.
X, Y, _ = next(iter(training_loader))
shown = min(len(X), nrows * ncols)
for i, x in enumerate(X[:shown]):
    label = labels[str(Y[i].item())]

    row, col = divmod(i, ncols)
    # CHW tensor -> HWC array, undo the normalisation, and clip to [0, 1], which is the
    # value range imshow expects for float images.
    img = x.permute(1, 2, 0).numpy() * channel_std + channel_mean
    ax[row][col].imshow(np.clip(img, 0.0, 1.0))
    ax[row][col].set_title(label)

# The batch can hold fewer images than the grid has panels; hide the unused ones.
for unused_ax in ax.ravel()[shown:]:
    unused_ax.axis("off")

In [None]:
# Use the first CUDA device when at least one GPU is visible, otherwise the CPU.
device_name = "cpu" if torch.cuda.device_count() == 0 else "cuda:0"
device = torch.device(device_name)
device

In [None]:
# freeze layers
# def freeze(model):
#     # To freeze the residual layers
#     for param in model.parameters():
#         param.requires_grad = False

#     for param in model.fc.parameters():
#         param.requires_grad = True

#     return model

# def unfreeze(model):
#     # Unfreeze all layers
#     for param in model.parameters():
#         param.requires_grad = True

#     return model

# resnext50
# finetuned_net = torchvision.models.resnext50_32x4d(pretrained=True)
# finetuned_net.fc = nn.Linear(finetuned_net.fc.in_features, len(labels_name))

# for param in finetuned_net.layer1.parameters():
#     param.requires_grad = False
# for param in finetuned_net.layer2.parameters():
#     param.requires_grad = False
# for param in finetuned_net.layer3.parameters():
#     param.requires_grad = False
# for param in finetuned_net.layer4.parameters():
#     param.requires_grad = True

!pip install timm
# efficientnet b4
import timm

def load_model():
    """Create the pretrained timm model named by PARAMS["MODEL"] with a new head.

    The model's ``classifier`` is replaced by a linear layer with one output per entry
    of ``labels_name``. The first PARAMS["EFF_SKIPPED_LAYERS"] entries of ``blocks``
    have ``requires_grad`` switched off (none when that setting is 0).

    Returns:
        The configured model (not yet moved to any device).
    """
    model = timm.create_model(PARAMS["MODEL"], pretrained=True)

    head_in_features = model.classifier.in_features
    model.classifier = nn.Linear(head_in_features, len(labels_name))

    # Freeze the leading blocks so only the remaining layers are updated.
    frozen_blocks = model.blocks[:PARAMS["EFF_SKIPPED_LAYERS"]]
    for frozen_block in frozen_blocks:
        for weight in frozen_block.parameters():
            weight.requires_grad = False

    return model
# TODO: Try different types of initialization

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive calls without an improvement
        self.best_score = None    # best (negated) validation loss seen so far
        self.early_stop = False   # set to True once `patience` is exhausted
        # `np.Inf` was removed in NumPy 2.0; `np.inf` is the spelling that works everywhere.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation result; checkpoint on improvement, else count towards patience.

        Args:
            val_loss: Validation loss of the current epoch (lower is better).
            model: Model whose ``state_dict`` is saved when the loss improves.
        """
        # Work with the negated loss so that "greater is better".
        score = -val_loss

        if self.best_score is None:
            # First call: always checkpoint.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not better than the best score by at least `delta`.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
import torch.nn.functional as F

# idea: https://amaarora.github.io/2020/07/18/label-smoothing.html
# from fastai
def linear_combination(x, y, epsilon):
    """Blend two values: ``x`` is weighted by ``epsilon`` and ``y`` by ``1 - epsilon``."""
    weighted_x = epsilon * x
    weighted_y = (1 - epsilon) * y
    return weighted_x + weighted_y


def reduce_loss(loss, reduction='mean'):
    """Reduce a loss tensor: mean for ``'mean'``, sum for ``'sum'``, unchanged for anything else."""
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss with label smoothing.

    The result mixes two terms with weight ``epsilon``: the negative log-probability
    summed over all classes and divided by the number of classes, and the usual
    negative log-likelihood of the target class.

    Args:
        epsilon: Weight of the all-classes term; ``1 - epsilon`` weights the NLL term.
        reduction: Reduction applied to both terms (``'mean'``, ``'sum'`` or none).
    """

    def __init__(self, epsilon: float = 0.1, reduction='mean'):
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, preds, target):
        """Return the smoothed loss for logits ``preds`` (classes on the last axis) and class ids ``target``."""
        num_classes = preds.size()[-1]
        log_probs = F.log_softmax(preds, dim=-1)

        # Term spread over every class.
        all_classes_term = reduce_loss(-log_probs.sum(dim=-1), self.reduction)
        # Term for the target class only.
        target_term = F.nll_loss(log_probs, target, reduction=self.reduction)

        return linear_combination(all_classes_term / num_classes, target_term, self.epsilon)
    
    
def split_df_into_folds(df, n_splits=5):
    """Build stratified K-fold data loaders from ``df``.

    Args:
        df: Frame with ``image_id`` and ``label`` columns.
        n_splits: Number of folds.

    Returns:
        A list with one ``[training_loader, val_loader]`` pair per fold. Folds are
        stratified on ``label`` and shuffled with PARAMS["RANDOM_SEED"].
    """
    splits = []
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=PARAMS["RANDOM_SEED"])
    columns = ["image_id", "label"]

    for training_index, val_index in splitter.split(df["image_id"], df["label"]):
        # The splitter yields *positional* indices, so rows must be selected with .iloc.
        # Label-based lookup (`series[index_array]`) only matches by accident while the
        # frame has a default 0..n-1 index and breaks as soon as rows have been filtered out.
        training_df = df.iloc[training_index][columns]
        training_set = CassavaDataset(training_df, dataset_path, training_transform)
        training_loader = DataLoader(
            dataset=training_set,
            batch_size=PARAMS["TRAINING_BATCH_SIZE"],
            num_workers=PARAMS["NUM_WORKERS"],
            pin_memory=PARAMS["PIN_MEMORY"],
            shuffle=True,
            drop_last=True
        )

        val_df = df.iloc[val_index][columns]
        val_set = CassavaDataset(val_df, dataset_path, val_transform)
        val_loader = DataLoader(
            dataset=val_set,
            batch_size=PARAMS["VAL_BATCH_SIZE"],
            num_workers=PARAMS["NUM_WORKERS"],
            pin_memory=PARAMS["PIN_MEMORY"],
            shuffle=False,
            # Keep the final partial batch so every validation sample is evaluated.
            drop_last=False
        )

        splits.append([training_loader, val_loader])

    return splits

In [None]:
def calculate_accurate_percent(outputs, labels):
    """Return ``(correct, total)`` for one batch.

    ``correct`` is a 0-dim tensor counting rows of ``outputs`` whose arg-max along
    dim 1 equals the label; ``total`` is ``len(outputs)``. Despite the name, no
    percentage is computed here.
    """
    predicted = outputs.argmax(dim=1)
    hits = (predicted == labels.data).sum()
    return hits, len(outputs)


def get_learning_rate(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group, or None if it has none."""
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']
    
    
def train_one_epoch(training_loader, net, criterion, optimizer, device):
    """Run one optimisation pass over ``training_loader``.

    Args:
        training_loader: Iterable of ``(inputs, labels, _)`` batches.
        net: Model to train; switched to train mode here.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(elapsed_seconds, mean_batch_loss, accuracy)``; ``accuracy`` is a
        0-dim tensor on ``device`` (correct predictions / samples seen).
    """
    batch_losses = []
    correct_total = torch.tensor(0, device=device, dtype=torch.int)
    seen_total = torch.tensor(0, device=device, dtype=torch.int)

    started_at = time.time()
    net.train()
    for batch_inputs, batch_labels, _ in training_loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()

        logits = net(batch_inputs)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

        hits, batch_size = calculate_accurate_percent(logits, batch_labels)
        correct_total += hits
        seen_total += batch_size

    elapsed = time.time() - started_at
    # Average over the number of batches reported by the loader.
    mean_loss = sum(batch_losses) / len(training_loader)
    accuracy = correct_total / seen_total

    return elapsed, mean_loss, accuracy


def val_one_epoch(val_loader, net, criterion, device):
    """Evaluate ``net`` on ``val_loader`` without tracking gradients.

    Args:
        val_loader: Iterable of ``(inputs, labels, _)`` batches.
        net: Model to evaluate; switched to eval mode here.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(elapsed_seconds, mean_batch_loss, accuracy, labels, preds)``.
        ``accuracy`` is a 0-dim tensor on ``device``; ``labels`` and ``preds`` are
        lists of 0-dim CPU tensors, one entry per evaluated sample.
    """
    val_phase_loss = []
    val_phase_correct = torch.tensor(0, device=device, dtype=torch.int)
    val_phase_trained = torch.tensor(0, device=device, dtype=torch.int)
    val_labels = []
    val_preds = []

    start_time = time.time()
    net.eval()
    # Disable autograd here instead of relying on every caller to do it: evaluation
    # needs no gradients, and building the graph only costs memory and time.
    with torch.no_grad():
        for (inputs, labels, _) in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = net(inputs)

            loss = criterion(outputs, labels)
            val_phase_loss.append(loss.item())

            correct, trained = calculate_accurate_percent(outputs, labels)
            val_phase_correct += correct
            val_phase_trained += trained

            # Move each batch to the CPU once; extending a list with a 1-D tensor
            # appends its elements as 0-dim tensors, the same items as before.
            val_labels.extend(labels.cpu())
            val_preds.extend(torch.argmax(outputs, dim=1).cpu())

    val_end_time = time.time() - start_time
    val_loss_avg = sum(val_phase_loss) / len(val_loader)
    val_accurate_percent = val_phase_correct / val_phase_trained

    return val_end_time, val_loss_avg, val_accurate_percent, val_labels, val_preds


def show_heatmap(labels, preds):
    """Plot the confusion matrix of ``labels`` (rows) against ``preds`` (columns).

    Args:
        labels: True class ids, one per sample (ints or 0-dim tensors).
        preds: Predicted class ids, aligned with ``labels``.
    """
    # Fix the class set explicitly so the matrix always has one row/column per entry of
    # `shortcut_labels`; otherwise a class missing from the data shrinks the matrix and
    # the tick labels no longer line up with the cells.
    class_ids = list(range(len(shortcut_labels)))
    true_ids = [int(value) for value in labels]
    predicted_ids = [int(value) for value in preds]

    fig, ax = plt.subplots(figsize=(8, 8))
    cm = confusion_matrix(true_ids, predicted_ids, labels=class_ids)
    # Draw on the axes created above rather than on whatever the current axes happen to be.
    sns.heatmap(
        cm,
        annot=True,
        cmap="Blues",
        fmt="g",
        ax=ax,
        xticklabels=shortcut_labels,
        yticklabels=shortcut_labels,
    )
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    plt.show()
    
    
def train_finetuned_model_no_cv(training_loader, val_loader, epochs, device, use_pretrained=True, is_classifier=True):
    """Train one model on a fixed train/validation split.

    Args:
        training_loader: Loader of ``(inputs, labels, _)`` training batches.
        val_loader: Loader of ``(inputs, labels, _)`` validation batches.
        epochs: Number of epochs to train for.
        device: Device the model and batches are moved to.
        use_pretrained, is_classifier: When ``use_pretrained`` is true and
            ``is_classifier`` is false, the classification head is trained with 10x
            the base learning rate; otherwise all parameters share one rate.

    Returns:
        Dict of per-epoch lists: ``lr``, ``training_loss``, ``training_correct_percent``,
        ``val_loss`` and ``val_correct_percent`` (all plain Python floats).

    Side effects:
        Saves the final weights to ``model.pth`` and shows the confusion matrix of the
        last validation pass.
    """
    net = load_model()
    net = net.to(device)

    if use_pretrained and not is_classifier:
        # load_model() replaces the head stored in `classifier` (the model has no `fc`),
        # so split the parameters on that prefix; every parameter lands in exactly one group.
        backbone_params = [param for name, param in net.named_parameters() if not name.startswith("classifier.")]
        optimizer = torch.optim.Adam([
            {'params': backbone_params},
            {'params': net.classifier.parameters(), 'lr': PARAMS["LEARNING_RATE"] * 10}
        ], lr=PARAMS["LEARNING_RATE"], weight_decay=PARAMS["WEIGHT_DECAY"])
    else:
        optimizer = torch.optim.Adam(net.parameters(), lr=PARAMS["LEARNING_RATE"], weight_decay=PARAMS["WEIGHT_DECAY"])

    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=PARAMS["T_0"], T_mult=1, eta_min=PARAMS['MIN_LEARNING_RATE'], last_epoch=-1)
#     early_stopping = EarlyStopping(patience=3, verbose=True)
    criterion = LabelSmoothingCrossEntropy(PARAMS["LS_EPSILON"], reduction='mean')

    statistic = {
        "lr": [],
        "training_loss": [],
        "training_correct_percent": [],
        "val_loss": [],
        "val_correct_percent": [],
    }
    val_labels, val_preds = [], []

    # Honour the `epochs` argument (the progress line below already reports it as the total).
    for i in range(1, epochs + 1):
        training_end_time, training_loss_avg, training_accurate_percent = train_one_epoch(training_loader, net, criterion, optimizer, device)
        statistic["training_loss"].append(training_loss_avg)
        # Store plain floats so the history does not hold on to device tensors.
        statistic["training_correct_percent"].append(float(training_accurate_percent))

        # Validation needs no gradients.
        with torch.no_grad():
            val_end_time, val_loss_avg, val_accurate_percent, val_labels, val_preds = val_one_epoch(val_loader, net, criterion, device)
        statistic["val_loss"].append(val_loss_avg)
        statistic["val_correct_percent"].append(float(val_accurate_percent))

        statistic["lr"].append(optimizer.param_groups[0]["lr"])
        # CosineAnnealingWarmRestarts.step() takes an optional *epoch*, not a metric.
        # Passing the validation loss made the scheduler treat it as a fractional epoch,
        # so the learning rate never followed the cosine schedule.
        scheduler.step()
#         early_stopping(val_loss_avg, net)

#         if early_stopping.early_stop:
#             print("Early stopping")
#             break

        # calculate accurate percentage
        print(f"\nEpoch: {i}/{epochs}")
        print("Learning rate: %.8f" % get_learning_rate(optimizer))
        print("Training loss: %.4f" % training_loss_avg)
        print("Training accurate percent: %.4f" % training_accurate_percent)
        print("Training time: %.4f" % training_end_time)
        print("-------")
        print("Validation loss: %.4f" % val_loss_avg)
        print("Validation accurate percent: %.4f" % val_accurate_percent)
        print("Validation time: %.4f" % val_end_time)
        print("=" * 50)

    print("Saving model")

    torch.save(net.state_dict(), "model.pth")
    # Drop the local references and release cached GPU memory before plotting.
    del net, criterion, optimizer, training_loader, val_loader, scheduler
    torch.cuda.empty_cache()

    # Nothing to plot when no epoch ran.
    if val_labels:
        show_heatmap(val_labels, val_preds)

    return statistic


def train_finetuned_model_cv(df, device, use_pretrained=True, is_classifier=True):
    """Train with the stratified K-fold split of ``df``.

    Only the first fold is trained; the remaining folds are skipped by the guard at
    the top of the loop. Training runs for PARAMS["EPOCHS"] epochs.

    Args:
        df: Frame with ``image_id`` and ``label`` columns.
        device: Device the model and batches are moved to.
        use_pretrained, is_classifier: When ``use_pretrained`` is true and
            ``is_classifier`` is false, the classification head is trained with 10x
            the base learning rate; otherwise all parameters share one rate.

    Side effects:
        Writes ``model_<MODEL>_fold_<k>_epoch_<i>.pth`` after every epoch and
        ``best_model_<MODEL>_fold_<k>.pth`` whenever the validation loss reaches a new
        minimum, then shows the confusion matrix of the last validation pass.
    """
    # split dataframe into folds
    total_folds = PARAMS["TOTAL_FOLDS"]
    folds = split_df_into_folds(df, total_folds)

    # train model by each fold
    for fold_index, (training_loader, val_loader) in enumerate(folds):
        # Only the first fold is trained.
        if fold_index != 0:
            continue

        # init model
        net = load_model()
        net = net.to(device)

        if use_pretrained and not is_classifier:
            # load_model() replaces the head stored in `classifier` (the model has no `fc`),
            # so split the parameters on that prefix; every parameter lands in exactly one group.
            backbone_params = [param for name, param in net.named_parameters() if not name.startswith("classifier.")]
            optimizer = torch.optim.Adam([
                {'params': backbone_params},
                {'params': net.classifier.parameters(), 'lr': PARAMS["LEARNING_RATE"] * 10}
            ], lr=PARAMS["LEARNING_RATE"], weight_decay=PARAMS["WEIGHT_DECAY"])
        else:
            optimizer = torch.optim.Adam(net.parameters(), lr=PARAMS["LEARNING_RATE"], weight_decay=PARAMS["WEIGHT_DECAY"])

#         init scheduler + criterion + early_stopping + optimizer
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=PARAMS["T_0"], T_mult=1, eta_min=PARAMS["MIN_LEARNING_RATE"], last_epoch=-1)
#         early_stopping = EarlyStopping(patience=5, verbose=True)
        criterion = LabelSmoothingCrossEntropy(PARAMS["LS_EPSILON"], reduction="mean")
        # Start from +inf so the first epoch always counts as the best one, whatever
        # the scale of the loss.
        min_val_loss = float("inf")
        val_labels, val_preds = [], []

#         begin training
        for i in range(1, PARAMS["EPOCHS"] + 1):
            training_end_time, training_loss_avg, training_accurate_percent = train_one_epoch(training_loader, net, criterion, optimizer, device)

            with torch.no_grad():
                val_end_time, val_loss_avg, val_accurate_percent, val_labels, val_preds = val_one_epoch(val_loader, net, criterion, device)

            # CosineAnnealingWarmRestarts.step() takes an optional *epoch*, not a metric.
            # Passing the validation loss made the scheduler treat it as a fractional
            # epoch, so the learning rate never followed the cosine schedule.
            scheduler.step()
#             early_stopping(val_loss_avg, net)

            if min_val_loss > val_loss_avg:
                torch.save(net.state_dict(), f"best_model_{PARAMS['MODEL']}_fold_{fold_index + 1}.pth")
                min_val_loss = val_loss_avg

#             if early_stopping.early_stop:
#                 print("Early stopping")
#                 break

            # calculate accurate percentage
            print(f"\nEpoch: {i}/{PARAMS['EPOCHS']}")
            print(f"Fold: {fold_index + 1}/{total_folds}")
            print("Learning rate: %.8f" % get_learning_rate(optimizer))
            print("Training loss: %.4f" % training_loss_avg)
            print("Training accurate percent: %.4f" % training_accurate_percent)
            print("Training time: %.4f" % training_end_time)
            print("-------")
            print("Validation loss: %.4f" % val_loss_avg)
            print("Validation accurate percent: %.4f" % val_accurate_percent)
            print("Validation time: %.4f" % val_end_time)
            print("=" * 50)

            # Per-epoch snapshot, kept in addition to the best-so-far checkpoint above.
            torch.save(net.state_dict(), f"model_{PARAMS['MODEL']}_fold_{fold_index + 1}_epoch_{i}.pth")
        # show confusion matrix

#         print(f"Saving model fold {fold_index}")
#         torch.save(net.state_dict(), f"model_{PARAMS['MODEL']}_fold_{fold_index + 1}.pth")
        # Drop the local references and release cached GPU memory before plotting.
        del net, criterion, optimizer, training_loader, val_loader, scheduler
        torch.cuda.empty_cache()

        # Nothing to plot when no epoch ran.
        if val_labels:
            show_heatmap(val_labels, val_preds)

        
# Fixed-split training run, currently disabled:
# statistic = train_finetuned_model_no_cv(
#     training_loader,
#     val_loader,
#     PARAMS["EPOCHS"],
#     device)

# Start the cross-validation training run on the full dataframe.
train_finetuned_model_cv(df, device)

In [None]:
# show confusion matrix
# https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

### V5 changes:
- Replace torchvision transforms with albumentations
- Add weight decay (trivial: 1e-6)
- Add early stopping
- Use the same learning rate for every layer