# PyTorch k-fold cross-validation DataLoader

The purpose of this kernel is to provide an easy-to-use system for stratified k-fold cross-validation with PyTorch models.
***

Features:
* Stratified K-fold cross validation support.
* Data augmentation support.
* Test time augmentation support (TTA).

**_If you have any comments or suggestions please don't hesitate to let me know, any feedback is appreciated._**

In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

import torch
from torch.utils.data import Sampler
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import BatchSampler
from torchvision.transforms import transforms
from sklearn.model_selection import StratifiedKFold

In [None]:
class CassavaLeafDataset(Dataset):
    """Image dataset indexed by ``image_id`` instead of an integer position.

    Labels are read from a CSV file with ``image_id`` and ``label`` columns.
    Each image is loaded from ``<root_dir>/<image_dir_name>/<image_id>``.
    """

    def __init__(self,
                 root_dir='../input/cassava-leaf-disease-classification/',
                 image_dir_name='train_images',
                 label_csv_name='train.csv',
                 transform=transforms.ToTensor(),
                 device='cpu'):
        """Load the label table and remember where the images live.

        Keyword arguments:
        root_dir -- Directory containing the image directory and label CSV.
        image_dir_name -- Name of the image directory inside root_dir.
        label_csv_name -- Name of the label CSV file inside root_dir.
        transform -- Callable applied to each PIL image; its result must
                     support ``.to(device)``.
        device -- Device the returned image and label tensors are moved to.
        """
        super().__init__()
        self.root_dir = root_dir
        self.image_dir = os.path.join(self.root_dir, image_dir_name)
        self.df_labels = pd.read_csv(os.path.join(self.root_dir, label_csv_name))
        self.transform = transform
        self.device = device

    def __getitem__(self, image_id):
        """Return the ``(image, label)`` pair for the given image_id.

        Both tensors are moved to the device passed to the constructor.
        """
        path = os.path.join(self.image_dir, image_id)
        matches = self.df_labels.loc[self.df_labels['image_id'] == image_id, 'label']
        label = torch.from_numpy(np.array(matches.values[0]))
        # Use the instance's device, not a module-level global named `device`.
        label = label.to(self.device)
        # The context manager closes the file handle once the pixel data has
        # been consumed by the transform.
        with Image.open(path) as pil_image:
            image = self.transform(pil_image).to(self.device)
        return (image, label)

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.df_labels)

In [None]:
class ImageSampler(Sampler):
    """Sampler that yields the ``image_id`` values of selected CSV rows."""

    def __init__(self,
                 sample_idx,
                 data_source='../input/cassava-leaf-disease-classification/train.csv'):
        """Read the CSV at ``data_source`` and store the row labels to sample.

        Keyword arguments:
        sample_idx -- Row labels (looked up with ``.loc``) of the samples
                      this sampler should yield.
        data_source -- Path of the CSV file holding an ``image_id`` column.
        """
        super().__init__(data_source)
        self.sample_idx = sample_idx
        self.df_images = pd.read_csv(data_source)

    def __iter__(self):
        """Iterate over the image_ids of the selected rows, in sample_idx order."""
        id_column = self.df_images['image_id']
        return iter(id_column.loc[self.sample_idx])

    def __len__(self):
        """Return how many rows were selected."""
        return len(self.sample_idx)

In [None]:
class ImageBatchSampler(BatchSampler):
    """Batch sampler that repeats every sampled id ``aug_count`` times.

    Each id drawn from the wrapped sampler appears ``aug_count`` times in a
    row, so one batch holds ``batch_size // aug_count`` distinct ids.
    """

    def __init__(self,
                 sampler,
                 aug_count=5,
                 batch_size=30,
                 drop_last=True):
        """Configure the batch sampler.

        Keyword arguments:
        sampler -- Sampler (or iterable with a length) yielding sample ids.
        aug_count -- Number of times each id is repeated; must be >= 1.
        batch_size -- Number of entries per batch; must be a multiple of
                      aug_count.
        drop_last -- If True, a trailing incomplete batch is discarded.

        Raises:
        ValueError -- If aug_count is smaller than 1 or does not divide
                      batch_size.
        """
        super().__init__(sampler, batch_size, drop_last)
        # Explicit check instead of `assert`, which is stripped under -O;
        # testing aug_count >= 1 first also avoids a modulo by zero.
        if aug_count < 1 or self.batch_size % aug_count != 0:
            raise ValueError('Batch size must be an integer multiple of the aug_count.')
        self.aug_count = aug_count

    def __iter__(self):
        """Yield lists of ids, each id repeated ``aug_count`` times."""
        batch = []

        for image_id in self.sampler:
            batch.extend([image_id] * self.aug_count)
            # aug_count divides batch_size, so the batch fills up exactly.
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        """Return the number of batches produced by ``__iter__``."""
        # Every sampled id occupies aug_count slots, so a batch consumes
        # batch_size // aug_count ids from the underlying sampler.
        ids_per_batch = self.batch_size // self.aug_count
        if self.drop_last:
            return len(self.sampler) // ids_per_batch
        return (len(self.sampler) + ids_per_batch - 1) // ids_per_batch

The create_split_loaders function is used to create the two DataLoaders required for each split.

The 'aug_count' parameter will be explained later, ignore it for now.

Each split has (k-1) training folds and 1 validation fold. Two DataLoaders are provided for each split. 

The _train_loader_ is used to iterate through all batches in the training folds of a given split. The _valid_loader_ is used to iterate through all samples in the validation fold of a given split.

The function returns the two DataLoaders for the split in a tuple.

In [None]:
def create_split_loaders(dataset, split, aug_count, batch_size):
    """Build the training and validation DataLoaders for a single split.

    Keyword arguments:
    dataset -- Dataset both loaders draw samples from.
    split -- Pair whose first item holds the training-fold indices and
             whose second item holds the validation-fold indices.
    aug_count -- Number of repetitions of each sample in training batches.
    batch_size -- Batch size used by both loaders.

    Returns a tuple (train_loader, valid_loader).
    """
    train_sampler, valid_sampler = (
        ImageSampler(fold_idx) for fold_idx in (split[0], split[1])
    )

    # Training batches repeat each sample aug_count times.
    train_batches = ImageBatchSampler(train_sampler, aug_count, batch_size)
    # Validation uses each sample once and keeps the final partial batch.
    valid_batches = ImageBatchSampler(valid_sampler,
                                      aug_count=1,
                                      batch_size=batch_size,
                                      drop_last=False)

    return (DataLoader(dataset, batch_sampler=train_batches),
            DataLoader(dataset, batch_sampler=valid_batches))

The get_all_split_loaders function is used to create the DataLoaders for each split.

The 'aug_count' parameter specifies how many variations of each image in the dataset must be provided. 

Random transformations are applied to each image when it is sampled from the dataset. If you would like your model to be trained seeing 5 different variations of each image in the training set, use 'aug_count=5'. 

_Note: Regardless of the value of aug_count all images in the dataset will be seen by the model. The dataset size is N_samples = N_images * aug_count._

The create_split_loaders function is called for each split, creating two DataLoaders per split: one DataLoader for the samples in the training folds and one DataLoader for the samples in the validation fold.

The DataLoaders for each split are stored in a tuple. All tuples are stored in a list which is returned by the function.

In [None]:
def get_all_split_loaders(dataset, cv_splits, aug_count=5, batch_size=30):
    """Create the (train_loader, valid_loader) pair for every split.

    Keyword arguments:
    dataset -- Dataset to sample from.
    cv_splits -- Sequence with one entry per split; each entry holds the
                 training-fold and validation-fold sample indices.
    aug_count -- Number of variations for each sample in dataset.
    batch_size -- batch size.

    Returns a list with one (train_loader, valid_loader) tuple per split,
    in the order of cv_splits.
    """
    return [
        create_split_loaders(dataset, fold_split, aug_count, batch_size)
        for fold_split in cv_splits
    ]

### The system can be used as follows:

We begin by loading in the image_ids and their corresponding class labels. This information is stored in the train.csv file.

In [None]:
# Label table for the training images; its 'image_id' and 'label' columns
# drive the stratified split below.
df_train = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')

Use sklearn to split the dataset into stratified folds.

In [None]:
# Four shuffled, label-stratified folds; the fixed seed keeps the folds reproducible.
splitter = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)

# One (train indices, validation indices) tuple per split.
splits = [
    (train_idx, test_idx)
    for train_idx, test_idx in splitter.split(df_train['image_id'], df_train['label'])
]

Specify the image transformations to be used for data augmentation. 

Create an instance of the dataset.

In [None]:
# Random augmentations applied every time an image is read from the dataset,
# followed by a fixed-size centre crop and conversion to a tensor.
transform = transforms.Compose([
    transforms.RandomAffine(degrees=45, translate=(0.05, 0.05), scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.CenterCrop((400, 500)),
    transforms.ToTensor(),
])

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
dataset = CassavaLeafDataset(transform=transform, device=device)

Create a list containing the DataLoaders for each split:

[(train_loader_split_1, val_loader_split_1), 
 (train_loader_split_2, val_loader_split_2), ... ]

In [None]:
# One (train_loader, valid_loader) tuple per split; each training batch of 10
# entries holds 5 repetitions of each sampled image.
dataloaders = get_all_split_loaders(dataset, splits, aug_count=5, batch_size=10)

This is the training loop for the model.

In [None]:
# train_batch_loader -- batches all samples in training folds.
# valid_batch_loader -- batches all samples in validation fold.
for train_batch_loader, valid_batch_loader in dataloaders:
    # Loop through all batches in training folds for a given split.
    for batch in train_batch_loader:
        # Train model on the training folds in the split.
        break
    
    # Loop through all batches in validation fold for a given split.
    for batch in valid_batch_loader:
        # Test model on the validation fold in the split.   
        break
    break

Example training batch:

batch_size = 10

aug_count = 5

In [None]:
# Fetch only the first training batch of the first split.
for train_batch_loader, valid_batch_loader in dataloaders:
    for batch in train_batch_loader:
        train_batch = batch
        break
    break

# Show the batch on a 2 x 5 grid, filled row by row.
fig, ax = plt.subplots(2, 5, figsize=(20, 10))
images = train_batch[0]
for axis, image in zip(ax.flat, images):
    # imshow expects (H, W, C). permute(1, 2, 0) reorders (C, H, W) without
    # swapping height and width, which transpose(0, 2) would do.
    # .cpu() is required because the dataset moves tensors to `device`,
    # which may be a GPU.
    axis.imshow(image.permute(1, 2, 0).cpu().numpy())
plt.show()