# Custom Dataset and Data Loader

This notebook is largely adapted from [pytorch-101-building-neural-networks](https://blog.paperspace.com/pytorch-101-building-neural-networks/).

```python
import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
import numpy as np
import pickle
import os
from PIL import Image
import random
import time
import torchvision
```

PyTorch provides two classes for building input pipelines that load data:

1. `torch.utils.data.Dataset`, which we will simply call the dataset class.
2. `torch.utils.data.DataLoader`, which we will simply call the dataloader class.
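To see how the two classes fit together, here is a minimal sketch using `TensorDataset`, a ready-made dataset in `torch.utils.data` that wraps existing tensors. The toy features and labels are made up for illustration; the dataset returns one `(x, y)` pair per index, and the dataloader groups those pairs into batches.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy data: 10 examples with 3 features each, plus an integer label.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10) % 2

# Dataset: holds the data and returns a single (x, y) pair per index.
dataset = TensorDataset(features, labels)
x0, y0 = dataset[0]

# DataLoader: picks indices, fetches examples from the dataset, and stacks them into batches.
loader = DataLoader(dataset, batch_size=4, shuffle=False)
batch_shapes = [(xb.shape, yb.shape) for xb, yb in loader]
print(batch_shapes)  # batches of 4, 4 and 2 examples
```

With 10 examples and `batch_size=4`, the last batch holds only the 2 remaining examples; passing `drop_last=True` to `DataLoader` would discard it instead.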

The `dataset` class holds (or knows how to load) your data and returns one example for a given index, so you can access the examples one at a time. It is also the natural place to add data augmentation to the input pipeline.
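One way to add augmentation to the pipeline is to apply a transform each time an example is fetched. The sketch below assumes a hypothetical wrapper, `AugmentedDataset`, and a hand-written `random_hflip` transform; neither is part of PyTorch. Because the transform runs on every access, a random augmentation produces a different variant of each example every epoch.

```python
import torch
from torch.utils.data import Dataset, TensorDataset


class AugmentedDataset(Dataset):
    """Hypothetical wrapper: applies `transform` to each input of another dataset on access."""

    def __init__(self, base, transform=None):
        self.base = base
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        x, y = self.base[index]
        if self.transform is not None:
            x = self.transform(x)  # runs every time this example is fetched
        return x, y


def random_hflip(img, p=0.5):
    """Flip a (C, H, W) tensor left-right with probability p."""
    if torch.rand(1).item() < p:
        return torch.flip(img, dims=[-1])
    return img


# Two tiny 1x2x3 "images" with labels 0 and 1.
images = torch.arange(12, dtype=torch.float32).reshape(2, 1, 2, 3)
labels = torch.tensor([0, 1])

# p=1.0 makes the flip deterministic so the effect is easy to see.
aug = AugmentedDataset(TensorDataset(images, labels), transform=lambda x: random_hflip(x, p=1.0))
x, y = aug[0]
print(x.tolist())  # [[[2.0, 1.0, 0.0], [5.0, 4.0, 3.0]]]
```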

To create a dataset for your own data, subclass `torch.utils.data.Dataset` and override three methods:

1. `__init__`: define everything your dataset needs here, most importantly the location of your data. You can also set up the data augmentations you want to apply.
2. `__len__`: return the number of examples in the dataset.
3. `__getitem__`: take an index `i` as its argument and return the corresponding example. During training, the `DataLoader` calls this method with a different `i` for each example it puts into a batch.
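The three methods above can be sketched with a small in-memory dataset. `SquaresDataset` is a hypothetical toy example (example `i` is the pair `(i, i**2)`), not part of PyTorch; it only shows where each method fits and how a `DataLoader` consumes it.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: example i is (i, i**2) as 1-element float tensors."""

    def __init__(self, n):
        # __init__: store whatever describes the data (here, just its size).
        self.n = n

    def __len__(self):
        # __len__: number of examples; the DataLoader uses it to generate indices.
        return self.n

    def __getitem__(self, i):
        # __getitem__: build and return the i-th example.
        if not 0 <= i < self.n:
            raise IndexError(i)
        x = torch.tensor([float(i)])
        y = torch.tensor([float(i * i)])
        return x, y


ds = SquaresDataset(5)
loader = DataLoader(ds, batch_size=2, shuffle=False)

# Each batch stacks (1,)-shaped tensors into (batch, 1); squeeze for readability.
batches = [(xb.squeeze(1).tolist(), yb.squeeze(1).tolist()) for xb, yb in loader]
print(batches)  # [([0.0, 1.0], [0.0, 1.0]), ([2.0, 3.0], [4.0, 9.0]), ([4.0], [16.0])]
```

Setting `shuffle=True` would only change the order in which indices are passed to `__getitem__`; the dataset itself stays the same.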

Here is an implementation of a `dataset` class for CIFAR-10 images stored as individual files in a directory.

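```python
class Cifar10Dataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, data_size, transforms=None):
        # Collect the full path of every file in data_dir.
        files = [os.path.join(data_dir, x) for x in os.listdir(data_dir)]

        if data_size < 0 or data_size > len(files):
            raise ValueError("Data size should be between 0 and {0}".format(len(files)))

        # A data_size of 0 means "use every file".
        if data_size == 0:
            data_size = len(files)

        self.data_size = data_size
        # Randomly choose data_size files, without replacement.
        self.files = random.sample(files, self.data_size)
        self.transforms = transforms

    def __len__(self):
        return self.data_size

    def __getitem__(self, index):
        image_address = self.files[index]
        # Convert to RGB so every example has three channels.
        image = Image.open(image_address).convert("RGB")
        # Apply the optional transforms (e.g. conversion to a tensor, augmentation).
        if self.transforms is not None:
            image = self.transforms(image)
        # Only the image is returned; deriving the label depends on the file naming scheme.
        return image
```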
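The `transforms` argument can be any callable that takes a PIL image. As an illustration, here is a hypothetical `pil_to_normalized_tensor` function that does roughly what `torchvision.transforms.ToTensor` followed by `Normalize` would do: scale pixels to [0, 1], reorder to (C, H, W), and normalize each channel. The mean and standard deviation of 0.5 are placeholder values, not CIFAR-10 statistics, and the solid-color 32x32 image stands in for a real CIFAR-10 file.

```python
import numpy as np
import torch
from PIL import Image


def pil_to_normalized_tensor(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Hypothetical transform: PIL image -> float32 (C, H, W) tensor, normalized per channel."""
    array = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0  # (H, W, C) in [0, 1]
    tensor = torch.from_numpy(array).permute(2, 0, 1)                  # (C, H, W)
    mean_t = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)
    return (tensor - mean_t) / std_t


# A synthetic 32x32 RGB image standing in for one CIFAR-10 file.
img = Image.new("RGB", (32, 32), color=(255, 0, 128))
out = pil_to_normalized_tensor(img)
print(out.shape)  # torch.Size([3, 32, 32])
```

With the class above, such a function would be passed as `Cifar10Dataset(data_dir, data_size, transforms=pil_to_normalized_tensor)`, so each image is converted only when `__getitem__` fetches it.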