## Reference

Custom Dataset classes in PyTorch
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

k-Fold cross-validation in PyTorch
--
1. https://stackoverflow.com/questions/58996242/cross-validation-for-mnist-dataset-with-pytorch-and-sklearn
2. https://discuss.pytorch.org/t/i-need-help-in-this-k-fold-cross-validation-implementation/90705/5
3. https://github.com/buomsoo-kim/PyTorch-learners-tutorial/blob/master/PyTorch%20Basics/pytorch-datasets-2.ipynb


k-Fold splitters in sklearn
--
1. sklearn.model_selection.KFold - splits the data into k consecutive folds; the samples are not shuffled by default.
2. sklearn.model_selection.StratifiedKFold - like KFold, but each fold preserves the class distribution of the full dataset as closely as possible.
3. sklearn.model_selection.GroupKFold - keeps all samples of a group in the same fold, so a group never appears in both the training and the test set of a split.
4. sklearn.model_selection.RepeatedKFold - repeats KFold n times, with a different randomization in each repetition.

In [None]:
#!pip install -U skorch

## Library imports

In [1]:
# common imports
import os
import random
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
#import math
#import time
#from skimage import io, transform
#from typing import Dict
#from pathlib import Path

# interactive plot libraries
import matplotlib.pyplot as plt
import seaborn as sns
#from plotly.offline import init_notebook_mode, iplot # download_plotlyjs, plot
#import plotly.graph_objs as go
#from plotly.subplots import make_subplots
#init_notebook_mode(connected=True)

# torch imports
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models.resnet import resnet50, resnet18, resnet34, resnet101
import torch.nn.functional as F


# sklearn related imports
# import skorch #sklearn + pytorch functionalitites
from sklearn.model_selection import KFold, StratifiedKFold
from skorch import NeuralNetClassifier
from sklearn.model_selection import cross_val_score

#import skorch
#from skorch.callbacks import Checkpoint
#from skorch.callbacks import Freezer
#from skorch.helper import predefined_split

## Config files

In [2]:
# Notebook-wide settings, grouped by purpose.
cfg = {
    # Dataset locations (relative paths).
    'train_img_path': "cassava-leaf-disease-classification/train_images/",
    'train_csv_path': 'cassava-leaf-disease-classification/train.csv',

    # Model selection and the 0/1 flags stored alongside it.
    'model_params': {
        'model_architecture': 'resnet18',
        'model_name': "R18_pretrain_imagenet",
        'lr': 1e-4,
        'weight_path': "",
        'lr_find': 0,
        'train': 1,
        'validate': 0,
        'test': 0,
    },

    # One DataLoader option set per data subset.
    'train_data_loader': {
        'batch_size': 16,
        'shuffle': False,
        'num_workers': 4,
    },
    'val_data_loader': {
        'batch_size': 16,
        'shuffle': False,
        'num_workers': 4,
    },
    'test_data_loader': {
        'batch_size': 32,
        'shuffle': False,
        'num_workers': 4,
    },

    # Training-loop controls.
    'train_params': {
        'train_start_batch_index': 117001,
        'max_num_steps': 11,
        'checkpoint_every_n_steps': 5,
    },
}

In [3]:
# Class index -> human-readable class name; the position in the list is the index.
index_label_map = dict(enumerate([
    "Cassava Bacterial Blight (CBB)",
    "Cassava Brown Streak Disease (CBSD)",
    "Cassava Green Mottle (CGM)",
    "Cassava Mosaic Disease (CMD)",
    "Healthy",
]))

## Helper functions

In [4]:
def find_no_of_trainable_params(model):
    """Count the elements of every parameter of ``model`` that requires gradients.

    Args:
        model: object whose ``parameters()`` yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        int: sum of ``numel()`` over the parameters whose ``requires_grad``
        flag is set; ``0`` when there are none.
    """
    trainable_count = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_count += param.numel()
    return trainable_count

In [5]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device), and puts cuDNN into deterministic mode so that runs
    with the same seed are repeatable.

    Args:
        seed (int): value applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    # The running interpreter fixed its hash seed at start-up; this only
    # reaches Python processes launched after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not just the current one; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    # cuDNN autotuning can select different (and non-deterministic) kernels
    # from run to run, so disable it and request deterministic algorithms.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


RANDOM_STATE = 42
set_seed(RANDOM_STATE)

In [6]:
# Read the training CSV (path taken from cfg) into a DataFrame.
df = pd.read_csv(cfg['train_csv_path'])

In [7]:
# The class labels are what the folds get stratified on.
y = df['label'].values
# StratifiedKFold.split needs only y to build the folds, so a zero-filled
# placeholder with the same shape stands in for the real features.
X = np.zeros_like(y, dtype=float)
skf = StratifiedKFold(
    n_splits=3,
    shuffle=True,
    random_state=RANDOM_STATE,
)

In [8]:
# Per-split class histograms, keyed 'split<k>_train' / 'split<k>_test'
# where k is the 1-based split number.
split_data = {}
for idx, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    print(f'{idx} split , Train idx len = {len(train_idx)}, Test idx len = {len(test_idx)}')
    # minlength pads each histogram to the number of known classes, so every
    # column has the same length even when the highest class id is missing
    # from a split (otherwise building a DataFrame from split_data fails).
    split_data[f'split{idx + 1}_train'] = np.bincount(
        y[train_idx], minlength=len(index_label_map))
    split_data[f'split{idx + 1}_test'] = np.bincount(
        y[test_idx], minlength=len(index_label_map))

0 split , Train idx len = 14264, Test idx len = 7133
1 split , Train idx len = 14265, Test idx len = 7132
2 split , Train idx len = 14265, Test idx len = 7132


In [9]:
# One row per class index, one column per entry of split_data; the row labels
# are then swapped from class indices to readable class names.
test_df = pd.DataFrame(split_data)
test_df.index = test_df.index.map(index_label_map)
test_df

Unnamed: 0,split1_train,split1_test,split2_train,split2_test,split3_train,split3_test
Cassava Bacterial Blight (CBB),724,363,725,362,725,362
Cassava Brown Streak Disease (CBSD),1460,729,1459,730,1459,730
Cassava Green Mottle (CGM),1590,796,1591,795,1591,795
Cassava Mosaic Disease (CMD),8772,4386,8772,4386,8772,4386
Healthy,1718,859,1718,859,1718,859
