## Reference

Custom Dataset classes in PyTorch
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html


Calculating the mean and std of a custom Dataset

https://discuss.pytorch.org/t/computing-the-mean-and-std-of-dataset/34949/3

https://discuss.pytorch.org/t/about-normalization-using-pre-trained-vgg16-networks/23560

https://forums.fast.ai/t/image-normalization-in-pytorch/7534/7


## Library imports

In [None]:
# common imports
import os
import random
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
#import math
#import time
#from skimage import io, transform
#from typing import Dict
#from pathlib import Path

# interactive plot libraries
import matplotlib.pyplot as plt
import seaborn as sns
#from plotly.offline import init_notebook_mode, iplot # download_plotlyjs, plot
#import plotly.graph_objs as go
#from plotly.subplots import make_subplots
#init_notebook_mode(connected=True)

# torch imports
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from torchvision.models.resnet import resnet50, resnet18, resnet34, resnet101
import torch.nn.functional as F


# sklearn related imports
# import skorch  # sklearn + pytorch functionalities
from sklearn.model_selection import StratifiedKFold #KFold, 
#from sklearn.model_selection import cross_val_score

#import skorch
#from skorch.callbacks import Checkpoint
#from skorch.callbacks import Freezer
#from skorch.helper import predefined_split
#from skorch import NeuralNetClassifier

## Config files

In [None]:
# Dataset locations plus boolean switches (presumably pipeline stage toggles;
# none of them is read in the visible part of the notebook).
path_cfg = dict(
    train_img_path="cassava-leaf-disease-classification/train_images/",
    train_csv_path='cassava-leaf-disease-classification/train.csv',
    train=True,
    lr_find=False,
    validate=True,
    test=False,
)

# Model selection and optimisation hyper-parameters.
model_cfg = dict(
    model_architecture='resnet18',
    model_name='R18_imagenet',
    init_lr=1e-4,
    weight_path='',
    train_epochs=5,
)

# DataLoader settings, one dict per split.
train_cfg = dict(batch_size=256, shuffle=False, num_workers=4, checkpt_every=1)
valid_cfg = dict(batch_size=16, shuffle=False, num_workers=4, validate_every=1)
test_cfg = dict(batch_size=16, shuffle=False, num_workers=4)

In [None]:
# Integer class index -> human-readable disease name; the position in the
# list is the class index.
index_label_map = dict(enumerate([
    "Cassava Bacterial Blight (CBB)",
    "Cassava Brown Streak Disease (CBSD)",
    "Cassava Green Mottle (CGM)",
    "Cassava Mosaic Disease (CMD)",
    "Healthy",
]))

## EDA

In [None]:
# Load the annotations from the single configured location (path_cfg) so the
# EDA and the Dataset below always read the same file.
train_csv = pd.read_csv(path_cfg['train_csv_path'])
# Add a readable class name next to the integer label for plots and tables.
train_csv['disease'] = train_csv['label'].map(index_label_map)
print(train_csv.shape)
train_csv.head()

In [None]:
# Horizontal bar chart of the number of images per disease class.
_, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
sns.countplot(data=train_csv, y='disease', ax=axes);

In [None]:
# Share of each disease class in the training set (normalised counts).
class_fractions = train_csv['disease'].value_counts(normalize=True)
print(class_fractions)

## Helper functions

In [None]:
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators.

    Also exports ``PYTHONHASHSEED`` into ``os.environ`` as ``str(seed)``.

    Args:
        seed (int): Seed value handed to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)


# Single seed shared by everything in the notebook that needs one.
RANDOM_STATE = 42
set_seed(RANDOM_STATE)

## Dataset class

In [None]:
class CassavaDataset(Dataset):
    """Cassava leaf disease detection dataset.

    Each item is an ``(image, label)`` pair read from an annotation CSV whose
    first column is the image file name (relative to ``root_dir``) and whose
    second column is the class label.
    """

    def __init__(self, csv_file, root_dir, transform=None, idx_list=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            idx_list (sequence of ints, optional): Positional row indices to
                keep from the csv (list or NumPy array, e.g. the indices of
                one cross-validation fold). ``None`` keeps every row.
        """
        self.cassava_leaf_disease = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        # Identity check on purpose: `idx_list != None` is evaluated
        # elementwise for NumPy arrays, and `if` then raises ValueError
        # ("truth value of an array ... is ambiguous").
        if idx_list is not None:
            self.cassava_leaf_disease = self.cassava_leaf_disease.iloc[idx_list, :]

    def __len__(self):
        """Return the number of (selected) annotation rows."""
        return len(self.cassava_leaf_disease)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The image is opened with PIL and passed through ``self.transform``
        when one was given; the label is returned as a NumPy array.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Column 0 holds the file name, column 1 the label.
        img_name = os.path.join(self.root_dir,
                                self.cassava_leaf_disease.iloc[idx, 0])
        image = Image.open(img_name)
        if self.transform is not None:
            image = self.transform(image)

        label = np.array(self.cassava_leaf_disease.iloc[idx, 1])
        return (image, label)

## Device 

In [None]:
# Use GPU if it's available, unless FORCE_CPU pins the notebook to the CPU.
# The override is an explicit switch rather than a second assignment, so the
# detection is no longer silently discarded and `device` is always a
# torch.device. Set FORCE_CPU = False to run on CUDA when present.
FORCE_CPU = True
device = torch.device(
    "cuda" if torch.cuda.is_available() and not FORCE_CPU else "cpu"
)
print(device)

## Transforms and Dataloader

In [None]:
# Only convert images to tensors here: the statistics computed below are the
# inputs a later Normalize step would need, so no normalisation is applied yet.
# The pipeline gets its own name so the imported `torchvision.transforms`
# module is not overwritten (which would break re-running this cell).
to_tensor_transform = transforms.Compose([
    transforms.ToTensor(),
])

cassava_dataset = CassavaDataset(
    csv_file=path_cfg['train_csv_path'],
    root_dir=path_cfg['train_img_path'],
    transform=to_tensor_transform,
)

# Interpolate the length into the message (previously it was passed as a
# one-element set and printed as "{N}").
print(f'Length of total Dataset is {len(cassava_dataset)}')

cassava_dataloader = DataLoader(
    cassava_dataset,
    batch_size=train_cfg['batch_size'],
    shuffle=train_cfg['shuffle'],
)

In [None]:
# Three independent, initially empty accumulators for the per-batch
# statistics (mean, std with ddof=0, std with ddof=1).
pop_mean, pop_std0, pop_std1 = [], [], []

In [None]:
# Exact per-channel statistics over every pixel of the dataset.
# Averaging per-batch means/stds is only an approximation: a smaller final
# batch is over-weighted and the spread *between* batches is ignored, which
# underestimates the std. Accumulating sum(x) and sum(x^2) gives the true
# values in one pass. The accumulators live in this cell, so it can be
# re-run safely.
channel_sum = 0.0
channel_sq_sum = 0.0
pixel_count = 0  # pixels seen per channel

for idx, (data, _) in enumerate(cassava_dataloader):
    # shape (batch_size, 3, height, width)
    numpy_image = data.numpy()

    # shape (3,); accumulate in float64 so long sums do not lose precision
    channel_sum = channel_sum + numpy_image.sum(axis=(0, 2, 3), dtype=np.float64)
    channel_sq_sum = channel_sq_sum + np.square(numpy_image).sum(
        axis=(0, 2, 3), dtype=np.float64
    )
    pixel_count += (
        numpy_image.shape[0] * numpy_image.shape[2] * numpy_image.shape[3]
    )

    # lightweight progress indicator
    if idx % 5 == 0:
        print(idx)

if pixel_count == 0:
    raise ValueError("cassava_dataloader yielded no images; cannot compute mean/std")

pop_mean = channel_sum / pixel_count
# Var[x] = E[x^2] - E[x]^2, clipped at 0 to absorb floating-point round-off
pop_var = np.maximum(channel_sq_sum / pixel_count - pop_mean ** 2, 0.0)
# ddof=0 (population) and ddof=1 (sample, Bessel-corrected) standard deviations
pop_std0 = np.sqrt(pop_var)
pop_std1 = np.sqrt(pop_var * pixel_count / max(pixel_count - 1, 1))

print(pop_mean, pop_std0, pop_std1)