## Reference

Custom Dataset classes in pytorch
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

k-Fold validation pytorch
--
1. https://stackoverflow.com/questions/58996242/cross-validation-for-mnist-dataset-with-pytorch-and-sklearn
2. https://discuss.pytorch.org/t/i-need-help-in-this-k-fold-cross-validation-implementation/90705/5
3. https://github.com/buomsoo-kim/PyTorch-learners-tutorial/blob/master/PyTorch%20Basics/pytorch-datasets-2.ipynb


kFold split sklearn
--
1. sklearn.model_selection.KFold -  normal ordered splits without any shuffle by default. 
2. sklearn.model_selection.StratifiedKFold - tries to preserve the distribution of each class in each set
3. GroupKFold - ensures the group of data is not repeated in any fold; little complex concept
4. RepeatedKFold - repeat kfold n times with different random state each instance

In [1]:
#!pip install -U skorch

## Library imports

In [2]:
# common imports
import os
import random
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
#import math
#import time
#from skimage import io, transform
#from typing import Dict
#from pathlib import Path

# interactive plot libraries
import matplotlib.pyplot as plt
import seaborn as sns
#from plotly.offline import init_notebook_mode, iplot # download_plotlyjs, plot
#import plotly.graph_objs as go
#from plotly.subplots import make_subplots
#init_notebook_mode(connected=True)

# torch imports
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from torchvision.models.resnet import resnet50, resnet18, resnet34, resnet101
import torch.nn.functional as F


# sklearn related imports
# import skorch #sklearn + pytorch functionalitites
from sklearn.model_selection import StratifiedKFold #KFold, 
#from sklearn.model_selection import cross_val_score

#import skorch
#from skorch.callbacks import Checkpoint
#from skorch.callbacks import Freezer
#from skorch.helper import predefined_split
#from skorch import NeuralNetClassifier

## Config files

In [3]:
path_cfg = {'train_img_path': "cassava-leaf-disease-classification/train_images/",
            'train_csv_path': 'cassava-leaf-disease-classification/train.csv',
            'train' : True, 'lr_find' : False, 'validate' : True, 'test' : False}

model_cfg = {'model_architecture': 'resnet18', 'model_name': 'R18_imagenet',
             'init_lr': 1e-4, 'weight_path': '', 'train_epochs':5}

train_cfg = {'batch_size': 16, 'shuffle': False, 'num_workers': 4, 'checkpt_every' : 1 }
valid_cfg = {'batch_size': 16, 'shuffle': False, 'num_workers': 4, 'validate_every' : 1 }
test_cfg  = {'batch_size': 16, 'shuffle': False, 'num_workers': 4}

In [4]:
index_label_map = {
                0: "Cassava Bacterial Blight (CBB)", 
                1: "Cassava Brown Streak Disease (CBSD)",
                2: "Cassava Green Mottle (CGM)", 
                3: "Cassava Mosaic Disease (CMD)", 
                4: "Healthy"
                }

## TODO

- load images into dataset (Dataset class of pytorch maybe)
- split into 5 fold data - scikit learn
- simple network -r18, r50 with last layers changed to 5 lables
- adam optimizer, lr_finder, cross entropy loss
- cv score

## Helper functions

In [5]:
def find_no_of_trainable_params(model):
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    #print(total_trainable_params)
    return total_trainable_params

In [6]:
def get_cv_splits(csv_path, cv_splits=3):
    df = pd.read_csv(csv_path)
    y = df['label'].values
    X = np.zeros(y.shape)
    
    cv_split_fn = StratifiedKFold(n_splits=cv_splits, shuffle=True, random_state=RANDOM_STATE)
    
    cv_split_idx = {}
    for idx, (train_idx, test_idx) in enumerate(cv_split_fn.split(X,y)):
        cv_split_idx['split' + str(idx+1) + '_train'] = train_idx
        cv_split_idx['split' + str(idx+1) + '_test']  = test_idx
    return cv_split_idx

In [7]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    
RANDOM_STATE = 42
set_seed(RANDOM_STATE)

## Dataset class

In [8]:
class CassavaDataset(Dataset):
    """Cassave leaf disease detection dataset."""

    def __init__(self, csv_file, root_dir, transform=None, idx_list=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            idx_list (list of ints): select only certain rows from csv 
        """
        self.cassava_leaf_disease = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        if idx_list != None:
            self.cassava_leaf_disease = self.cassava_leaf_disease.iloc[idx_list, :]


    def __len__(self):
        return len(self.cassava_leaf_disease)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.cassava_leaf_disease.iloc[idx, 0])
        image = Image.open(img_name)
        if self.transform != None:
            image = self.transform(image)
        
        label = np.array(self.cassava_leaf_disease.iloc[idx, 1])
        return (image, label)

## Transforms and Dataloader

In [9]:
transforms = transforms.Compose([
    transforms.RandomResizedCrop(300),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])
])

In [10]:
cassava_dataset = CassavaDataset(csv_file=path_cfg['train_csv_path'], root_dir=path_cfg['train_img_path'], 
                                 transform=transforms)

# split to train and validation sets
cv_splits = get_cv_splits(path_cfg['train_csv_path'], cv_splits=3)

# Datasets
train_data = Subset(cassava_dataset, cv_splits['split1_train'])
test_data  = Subset(cassava_dataset, cv_splits['split1_test'])

# Dataloaders
trainloader = DataLoader(train_data, batch_size=train_cfg['batch_size'],shuffle=train_cfg['shuffle'])
validateloader = DataLoader(test_data, batch_size=valid_cfg['batch_size'],shuffle=valid_cfg['shuffle'])

## Pretrained model

In [11]:
output_features = 5
model = models.resnet18(pretrained=True)

# Freeze parameters so we don't backprop through them
for param in model.parameters():
    param.requires_grad = False
    
model.fc = nn.Sequential(nn.Linear(model.fc.in_features, 128), nn.ReLU(), 
                                 nn.Linear(128, output_features), nn.LogSoftmax(dim=1)
                                )
print('Trainable Parameters :', find_no_of_trainable_params(model))
#print(model.model.classifier)

Trainable Parameters : 66309


## Device, loss fn, optimizer

In [12]:
# Use GPU if it's available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device);

# loss function
criterion = nn.NLLLoss()

# Only train the classifier parameters, feature parameters are frozen
optimizer = optim.Adam(model.fc.parameters(), lr=model_cfg['init_lr'])

## Training & validation loops

In [13]:
# load previous weight file
if model_cfg['weight_path'] != '':
    state_dict = torch.load(model_cfg['weight_path'])
    model.load_state_dict(state_dict)

In [14]:
def validate(validate_dataloader):
    valid_it = iter(validate_dataloader)
    progress_bar = tqdm(range(len(validate_dataloader)))
    
    test_loss = 0
    test_accuracy = 0
    model.eval()

    with torch.no_grad():
        for batch_idx in progress_bar: 
            try:
                inputs, labels = next(valid_it)
            except StopIteration:
                valid_it = iter(validate_dataloader)
                inputs, labels = next(valid_it)

            inputs, labels = inputs.to(device), labels.to(device)
            logps = model.forward(inputs)
            batch_loss = criterion(logps, labels)
            test_loss += batch_loss.item()

            # Calculate accuracy
            ps = torch.exp(logps)
            top_p, top_class = ps.topk(1, dim=1)
            equals = top_class == labels.view(*top_class.shape)
            test_accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
    
    test_loss = test_loss/len(validate_dataloader)
    test_accuracy = test_accuracy/len(validate_dataloader)
    return test_accuracy, test_loss

In [15]:
if path_cfg['train'] == True:
    results = {}
    results['train_losses'] = []
    #results['train_accuracy'] = []
    results['validate_losses'] = []
    results['validate_accuracy'] = []

    for epoch in range(model_cfg['train_epochs']):
        tr_it = iter(trainloader)
        progress_bar = tqdm(range(len(trainloader)))            
        running_loss = 0.0
        model.train()
        
        for batch_idx in progress_bar:
            try:
                inputs, labels = next(tr_it)
            except StopIteration:
                tr_it = iter(trainloader)
                inputs, labels = next(tr_it)

            # Move input and label tensors to the default device
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            log_ps = model(inputs)
            loss = criterion(log_ps, labels)

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # store losses
            running_loss += loss.item()

            # print to console
            progress_bar.set_description(f"loss: {loss.item()} loss(avg): {running_loss/(batch_idx+1)}")
        
        results['train_losses'].append(running_loss/len(trainloader))
        
        # save weights periodically
        if (epoch % train_cfg['checkpt_every'] == 0):
            torch.save(model.state_dict(), model_cfg['model_name'] + str(epoch+1) + '_epochs.pth')
        
        # validate periodically
        if (epoch % valid_cfg['validate_every'] == 0):
            val_loss, val_accuracy = validate(validateloader)
            results['validate_losses'].append(val_loss)
            results['validate_accuracy'].append(val_accuracy)

loss: 1.043710708618164 loss(avg): 0.9644037976259608: 100%|██████████| 892/892 [08:07<00:00,  1.83it/s]  
100%|██████████| 446/446 [03:53<00:00,  1.91it/s]
loss: 0.959989070892334 loss(avg): 0.8108083269654902: 100%|██████████| 892/892 [08:06<00:00,  1.83it/s]  
100%|██████████| 446/446 [03:53<00:00,  1.91it/s]
loss: 1.3579087257385254 loss(avg): 0.7732734635405476: 100%|██████████| 892/892 [08:04<00:00,  1.84it/s] 
100%|██████████| 446/446 [03:55<00:00,  1.89it/s]
loss: 1.0391736030578613 loss(avg): 0.7529018251350642: 100%|██████████| 892/892 [08:03<00:00,  1.84it/s] 
100%|██████████| 446/446 [03:52<00:00,  1.92it/s]
loss: 0.8170357942581177 loss(avg): 0.7438279592302616: 100%|██████████| 892/892 [08:03<00:00,  1.85it/s] 
100%|██████████| 446/446 [03:53<00:00,  1.91it/s]


In [17]:
results['validate_accuracy']

[0.8143060549492259,
 0.7539211985709421,
 0.7269521639619707,
 0.7044892492850265,
 0.6967150224163928,
 0.7009488511245882]

In [18]:
results['validate_losses']

[0.6990018110104206,
 0.7246464298445013,
 0.7395653673886183,
 0.7410421697548152,
 0.7464319593168695,
 0.7442329251980033]

In [None]:
results