# Custom Dataset and Data Loader

This document closely follows [pytorch-101-building-neural-networks](https://blog.paperspace.com/pytorch-101-building-neural-networks/).

Use the following two commands to download and extract the CIFAR-10 dataset.

```shell
wget http://pjreddie.com/media/files/cifar.tgz
tar xzf cifar.tgz
```

```python
import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
import numpy as np
import pickle
import os
from PIL import Image
import random
import time
import torchvision
```

```python
train_data_dir = "../data/cifar/train"
test_data_dir = "../data/cifar/test"
label_data_dir = "../data/cifar/labels.txt"
```

We now read the labels of the classes present in the CIFAR dataset.

```python
with open(label_data_dir) as label_file:
    labels = label_file.read().split()
    label_mapping = dict(zip(labels, list(range(len(labels)))))

print(label_mapping)
```

```
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
```


## Input Format

The type of data we are dealing with dictates the input format. In PyTorch, the batch is conventionally the first dimension. Since we are dealing with images here, I will describe the input format used for images.

The input format for images is `[B, C, H, W]`, where `B` is the batch size, `C` is the number of channels, `H` is the height and `W` is the width.
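To make the layout concrete, here is a minimal sketch that stacks a few dummy images into one batch. It uses NumPy only, and the sizes (4 images, 3 channels, 32 x 32 pixels) are illustrative assumptions rather than anything read from disk.

```python
import numpy as np

# Four dummy images, each already in channel-first (C, H, W) layout.
images = [np.zeros((3, 32, 32), dtype=np.float32) for _ in range(4)]

# Stacking along a new leading axis yields the (B, C, H, W) layout.
batch = np.stack(images, axis=0)

B, C, H, W = batch.shape
print(batch.shape)  # (4, 3, 32, 32)
```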

We will read images using the `PIL` library. Before writing the code that loads our data, we write a preprocessing function that does the following things.

1. Randomly flip the image upside down with a probability of 0.5.
2. Normalise the image with the per-channel mean and standard deviation of the CIFAR dataset.
3. Transpose it from `H x W x C` to `C x H x W`.

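```python
def preprocess(image):
    # Convert the PIL image to a float array of shape (H, W, C), scaled to
    # [0, 1] because the mean/std values below are for that pixel range.
    image = np.array(image, dtype=np.float32) / 255.0

    # Flip the image upside down (reverse the row axis) with probability 0.5.
    if random.random() > 0.5:
        image = image[::-1, :, :]

    # Per-channel statistics, shaped (1, 1, C) so they broadcast over H and W.
    cifar_mean = np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, -1)
    cifar_std = np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, -1)
    image = (image - cifar_mean) / cifar_std

    # Reorder the axes from (H, W, C) to (C, H, W).
    image = image.transpose(2, 0, 1)
    return image
```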
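A quick way to sanity-check the axis reordering is to run it on a non-square dummy image, where a height/width mix-up shows up in the shape (CIFAR images are square, so the shape alone would not reveal it). The sketch below uses only NumPy; the 2 x 5 image size is an arbitrary assumption for illustration.

```python
import numpy as np

# Dummy image with height 2, width 5 and 3 channels: layout (H, W, C),
# which is what np.array() returns for an RGB PIL image.
hwc = np.arange(2 * 5 * 3).reshape(2, 5, 3)

# (H, W, C) -> (C, H, W): channels first, height and width keep their order.
chw = hwc.transpose(2, 0, 1)
print(chw.shape)  # (3, 2, 5)

# transpose(2, 1, 0) also puts channels first, but swaps height and width.
cwh = hwc.transpose(2, 1, 0)
print(cwh.shape)  # (3, 5, 2)

# The pixel at row 1, column 4 of channel 0 is found at the same (row, column)
# position after the (2, 0, 1) reordering.
assert chw[0, 1, 4] == hwc[1, 4, 0]
```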

PyTorch provides two classes for building input pipelines that load data.

1. `torch.utils.data.Dataset`, which we will refer to as the dataset class from now on.
2. `torch.utils.data.DataLoader`, which we will refer to as the dataloader class from now on.

## torch.utils.data.Dataset

`Dataset` is a class that knows how to load individual data examples and exposes them by index, so that you can iterate over it. It also lets you incorporate data augmentation techniques into the input pipeline.

If you want to create a `Dataset` object for your data, you need to override three methods.

1. `__init__`. Here you define everything related to your dataset, most importantly the location of your data. You can also define the data augmentations you want to apply.
2. `__len__`. Here you just return the length of the dataset.
3. `__getitem__`. This method takes an index `i` as its argument and returns one data example. During training it is called with a different `i` for every example that has to be loaded.

Here is an implementation of our `Dataset` object for the CIFAR dataset.

```python
class Cifar10Dataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, data_size=0, transforms=None):
        # Store only the file paths here; images are read lazily in __getitem__.
        files = os.listdir(data_dir)
        files = [os.path.join(data_dir, x) for x in files]

        if data_size < 0 or data_size > len(files):
            raise ValueError("Data size should be between 0 and {0}".format(len(files)))

        # data_size == 0 means "use every file in the directory".
        if data_size == 0:
            data_size = len(files)

        self.data_size = data_size
        # Keep a random subset of data_size files.
        self.files = random.sample(files, self.data_size)
        self.transforms = transforms

    def __len__(self):
        return self.data_size

    def __getitem__(self, index):
        image_address = self.files[index]
        image = Image.open(image_address)
        image = preprocess(image)

        # The class name is the part of the file name after the last "_",
        # once the 4-character extension has been stripped.
        label_name = image_address[:-4].split("_")[-1]
        label = label_mapping[label_name]

        image = image.astype(np.float32)

        if self.transforms:
            image = self.transforms(image)

        return image, label
```

We also use the `__getitem__` function to extract the label of an image, which is encoded in its file name.
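The file-name convention can be illustrated in isolation. The sketch below assumes names of the form `<id>_<class>.png`; the example path and the helper `label_from_path` are hypothetical and only mirror the string handling done in `__getitem__`.

```python
import os

label_mapping = {"airplane": 0, "automobile": 1, "bird": 2, "cat": 3, "deer": 4,
                 "dog": 5, "frog": 6, "horse": 7, "ship": 8, "truck": 9}

def label_from_path(path):
    """Return the integer label encoded in a file name such as '123_frog.png'."""
    stem, _ext = os.path.splitext(os.path.basename(path))  # '123_frog'
    label_name = stem.split("_")[-1]                        # 'frog'
    return label_mapping[label_name]

print(label_from_path("../data/cifar/train/123_frog.png"))  # 6
```

Using `os.path.splitext` rather than slicing off the last four characters also copes with extensions of a different length.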

The `Dataset` class lets us follow the lazy data loading principle. Instead of loading all the data into memory at once (which we would do by reading every image in `__init__` rather than just storing the image paths), it loads a data example only when it is needed, that is, when `__getitem__` is called.

When you create an object of the `Dataset` class, you can iterate over it as you would over any Python iterable. On each iteration, `__getitem__` is called with the incremented index `i` as its input argument.
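Lazy loading and index-based iteration can both be seen with a small stand-in class. The sketch below is plain Python with no PyTorch import; `LazySquares` is a hypothetical toy class that only mimics the `__len__`/`__getitem__` interface and records when examples are actually loaded. As far as I know, a map-style `Dataset` is iterated through the same `__getitem__` fallback, which stops at the first `IndexError`.

```python
class LazySquares:
    """Toy dataset: example i is i * i, computed only when requested."""

    def __init__(self, size):
        self.size = size
        self.loaded = []  # indices that have actually been loaded

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        if index >= self.size:
            raise IndexError(index)  # tells Python's for loop to stop
        self.loaded.append(index)
        return index * index


dataset = LazySquares(4)
print(dataset.loaded)  # [] : nothing is loaded at construction time

values = [x for x in dataset]  # Python calls __getitem__ with 0, 1, 2, ...
print(values)          # [0, 1, 4, 9]
print(dataset.loaded)  # [0, 1, 2, 3]
```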

I've also passed a `transforms` argument to the `__init__` function. This can be any Python callable that performs data augmentation. Whether you do the augmentation right inside your preprocessing code or apply it in `__getitem__` is just a matter of taste. There are plenty of libraries that can be used to augment data. In our case, we use the `torchvision` library, which provides many pre-built transforms along with the ability to compose them into one bigger transform.
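As a sketch of what such a callable can look like, the example below chains two small NumPy augmentations. The helpers `compose`, `horizontal_flip` and `add_offset` are hypothetical and are not part of `torchvision`; they only illustrate that anything accepting and returning a `C x H x W` array can be passed as `transforms`.

```python
import numpy as np

def compose(*funcs):
    """Chain several transforms into a single callable, applied left to right."""
    def composed(image):
        for func in funcs:
            image = func(image)
        return image
    return composed

def horizontal_flip(image):
    """Mirror a (C, H, W) image left to right by reversing the width axis."""
    # copy() returns a contiguous array instead of a negatively-strided view.
    return image[:, :, ::-1].copy()

def add_offset(image, offset=0.1):
    """Shift all pixel values by a constant (a crude brightness change)."""
    return image + offset

transform = compose(horizontal_flip, add_offset)

image = np.arange(6, dtype=np.float32).reshape(1, 2, 3)  # C=1, H=2, W=3
out = transform(image)
print(out.shape)  # (1, 2, 3)
```

A callable built this way could then be handed to the dataset, for example `Cifar10Dataset(data_dir=train_data_dir, transforms=transform)`.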

## torch.utils.data.DataLoader

The `DataLoader` class facilitates:

1. Batching of data.
2. Shuffling of data.
3. Loading multiple examples in parallel using worker processes.
4. Prefetching: while the GPU crunches the current batch, the `DataLoader` workers (when `num_workers` is greater than 0) can load the next batch into memory in the meantime. The GPU doesn't have to wait for the next batch, which speeds up training.

We instantiate a `DataLoader` object with a `Dataset` object. Then we can iterate over the `DataLoader` instance just like we did with the `Dataset` instance. In addition, we can specify various options (such as `batch_size`, `shuffle` and `num_workers`) to control the looping process.
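The effect of these options can be tried on a tiny in-memory dataset. The sketch below assumes PyTorch is installed and uses `TensorDataset` with made-up data (10 examples with 3 features each) instead of CIFAR; `num_workers=0` keeps loading in the main process so the example stays simple.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Made-up data: 10 examples with 3 features each, plus integer labels 0..9.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

batch_sizes = []
seen_labels = []
for x, y in loader:
    batch_sizes.append(x.shape[0])
    seen_labels.extend(y.tolist())

print(batch_sizes)          # [4, 4, 2]: the last batch holds the remainder
print(sorted(seen_labels))  # every label 0..9 appears exactly once per epoch
```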

```python
trainset = Cifar10Dataset(data_dir=train_data_dir, data_size=500, transforms=None)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = Cifar10Dataset(data_dir=test_data_dir, data_size=500, transforms=None)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=True, num_workers=2)
```

Both the `trainset` and `trainloader` objects are Python iterables, which can be iterated over in the following fashion.

```python
for data in trainloader:   # or trainset
    img, label = data
```

However, the `DataLoader` class makes things much more convenient than the `Dataset` class. While on each iteration the `Dataset` class only returns the output of the `__getitem__` function, `DataLoader` does much more than that:

1. Notice that the `__getitem__` method of the `Cifar10Dataset` class returns a numpy array of shape `3 x 32 x 32`. `DataLoader` batches the images into a Tensor of shape `B x 3 x 32 x 32`, where `B` is the batch size specified by `batch_size`.
2. Also notice that while the `__getitem__` method outputs a numpy array, the `DataLoader` class automatically converts it into a Tensor.
3. Even if the `__getitem__` method returns an object of a non-numerical type, the `DataLoader` class turns it into a list/tuple of size `B`. Suppose that `__getitem__` returns a string, namely the label string. If we set `batch_size=128` while instantiating the dataloader, then on each iteration `DataLoader` will give us a tuple of 128 strings.

Add prefetching and parallel loading to the benefits above, and using the `DataLoader` class is preferred almost every time.
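These conversions can be checked on a toy dataset. The sketch below assumes PyTorch and NumPy are installed; `ToyDataset` and its contents are hypothetical, and the example relies on the default collate function of `DataLoader`.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Each example is (float32 array of shape 3x32x32, int label, str name)."""

    def __len__(self):
        return 5

    def __getitem__(self, index):
        image = np.full((3, 32, 32), index, dtype=np.float32)
        return image, index, "sample_{0}".format(index)

loader = DataLoader(ToyDataset(), batch_size=5, shuffle=False)
images, labels, names = next(iter(loader))

print(type(images), images.shape)  # a Tensor of shape [5, 3, 32, 32]
print(labels)                      # integer labels collated into a Tensor
print(names)                       # the 5 strings, kept as a plain sequence
```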

```python
for idx, data in enumerate(trainloader):
    img, label = data
    print("the image sample with index:{0}, image shape:{1}, and label shape:{2}".format(idx, img.shape, label.shape))
```

```
the image sample with index:0, image shape:torch.Size([128, 3, 32, 32]), and label shape:torch.Size([128])
the image sample with index:1, image shape:torch.Size([128, 3, 32, 32]), and label shape:torch.Size([128])
the image sample with index:2, image shape:torch.Size([128, 3, 32, 32]), and label shape:torch.Size([128])
the image sample with index:3, image shape:torch.Size([116, 3, 32, 32]), and label shape:torch.Size([116])
```
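The number of batches and the size of the last one follow from simple arithmetic on the dataset size and `batch_size`: 500 samples in batches of 128 give three full batches and one of 116. A small sketch in plain Python (the helper `batch_sizes` is hypothetical):

```python
import math

def batch_sizes(num_samples, batch_size):
    """Sizes of the batches produced when the last, smaller batch is kept."""
    num_batches = math.ceil(num_samples / batch_size)
    sizes = [batch_size] * (num_samples // batch_size)
    if num_samples % batch_size:
        sizes.append(num_samples % batch_size)
    assert len(sizes) == num_batches
    return sizes

print(batch_sizes(500, 128))  # [128, 128, 128, 116]
```

Passing `drop_last=True` to `DataLoader` would discard the final incomplete batch instead.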
