# Custom Datasets

So far, we have only used the MNIST dataset, which is easily accessible through torchvision. What do we do when we have our own data which we want to use with PyTorch?

In this notebook, we will convert our own raw data into PyTorch datasets that can be processed by our PyTorch models.

## What should our dataset be able to do?
### `__getitem__`
Our dataset should be a collection of many examples. We should be able to index it like `my_dataset[3]` to get the example at position 3. The `__getitem__` method defines how the dataset is indexed: it takes the index of an example as its argument and returns that example.

`my_dataset[2]` is equivalent to `my_dataset.__getitem__(2)`.
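Square-bracket indexing is ordinary Python behaviour, so a toy class is enough to see the dispatch. `SquaresDataset` below is a hypothetical example (it does not even need torch) whose example at position `i` is `i ** 2`.

```python
class SquaresDataset:
    """Hypothetical toy dataset: the example at position i is i ** 2."""

    def __init__(self, n):
        self.n = n  # number of examples in the dataset

    def __getitem__(self, idx):
        # Reject out-of-range indices, like a list would
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx ** 2


my_dataset = SquaresDataset(10)
print(my_dataset[3])              # square brackets call __getitem__ -> 9
print(my_dataset.__getitem__(3))  # the same call made explicitly -> 9
```

Raising `IndexError` for bad indices keeps the class well behaved if someone loops over it directly.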

### `__len__`
The `__len__` method must return the number of examples in the dataset.

`len(my_dataset)` is equivalent to `my_dataset.__len__()`.
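Python's built-in `len()` looks up `__len__` on the object's class; if the method is missing, `len()` raises a `TypeError`. The two hypothetical classes below show both cases.

```python
class NoLenDataset:
    """Hypothetical dataset that forgot to define __len__."""

    def __getitem__(self, idx):
        return idx


class WithLenDataset(NoLenDataset):
    """The same dataset, now reporting how many examples it holds."""

    def __init__(self, n):
        self.n = n  # number of examples

    def __len__(self):
        return self.n


print(len(WithLenDataset(5)))  # len() calls __len__ -> 5

try:
    len(NoLenDataset())
except TypeError as err:
    print(err)  # object of type 'NoLenDataset' has no len()
```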

### It should also inherit from `torch.utils.data.Dataset`
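Subclassing `Dataset` marks our class as a map-style dataset and, together with `__getitem__` and `__len__`, makes it compatible with other torch utilities such as the `DataLoader`. The base class does not check that these methods exist, so implementing them is still up to us.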
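As a sketch of what this buys us, the hypothetical `ToyRegressionDataset` below (each target is just the sum of its features) implements both methods and is handed straight to a `DataLoader`, whose default collate function stacks individual examples into batch tensors.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyRegressionDataset(Dataset):
    """Hypothetical dataset of (feature vector, target) pairs; target = sum of features."""

    def __init__(self, n_examples=10, n_features=3):
        self.x = torch.arange(n_examples * n_features, dtype=torch.float32).reshape(n_examples, n_features)
        self.y = self.x.sum(dim=1)

    def __getitem__(self, idx):
        # One example: a feature vector and its scalar target
        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)


loader = DataLoader(ToyRegressionDataset(), batch_size=4, shuffle=False)
for features, targets in loader:
    # 10 examples in batches of 4 -> shapes (4, 3), (4, 3), then (2, 3)
    print(features.shape, targets.shape)
```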

## Dataset 1: The Auto MPG Dataset

This dataset contains 398 examples of cars with 7 numerical features and their corresponding miles per gallon (MPG) as a label.

```python
import torch
from torch.utils.data import Dataset
import pandas as pd
```

```python
from utils import get_auto_mpg_data

auto_mpg_data = get_auto_mpg_data()
auto_mpg_data.head()
```

```
398
```

| Unnamed: 0 | mpg | cylinders | displacement | horsepower | weight | acceleration | model year | origin |
|---|---|---|---|---|---|---|---|---|
| 0 | 18.0 | 8.0 | 307.0 | 130.0 | 3504.0 | 12.0 | 70.0 | 1.0 |
| 1 | 15.0 | 8.0 | 350.0 | 165.0 | 3693.0 | 11.5 | 70.0 | 1.0 |
| 2 | 18.0 | 8.0 | 318.0 | 150.0 | 3436.0 | 11.0 | 70.0 | 1.0 |
| 3 | 16.0 | 8.0 | 304.0 | 150.0 | 3433.0 | 12.0 | 70.0 | 1.0 |
| 4 | 17.0 | 8.0 | 302.0 | 140.0 | 3449.0 | 10.5 | 70.0 | 1.0 |


```python
class AutoMPGDataset(Dataset):
    def __init__(self):
        ...  # exercise: load the features and labels here
```
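One way to complete `AutoMPGDataset` is sketched below. It assumes the DataFrame returned by `get_auto_mpg_data()` is passed to the constructor, that `mpg` is the label, that `Unnamed: 0` is a leftover CSV index rather than a feature, and that every other column is numeric with missing values already handled. The constructor signature `(df, label_col="mpg")` is our own choice, not something the notebook prescribes.

```python
import pandas as pd
import torch
from torch.utils.data import Dataset


class AutoMPGDataset(Dataset):
    """Sketch: one example = (7 numeric features, MPG label) from a DataFrame."""

    def __init__(self, df, label_col="mpg"):
        # Assumption: "Unnamed: 0" is a leftover CSV index, not a real feature
        df = df.drop(columns=["Unnamed: 0"], errors="ignore")
        # Assumption: all other non-label columns are numeric with no missing values
        self.features = torch.tensor(df.drop(columns=[label_col]).to_numpy(), dtype=torch.float32)
        self.labels = torch.tensor(df[label_col].to_numpy(), dtype=torch.float32)

    def __getitem__(self, idx):
        # Feature vector and label of the example at position idx
        return self.features[idx], self.labels[idx]

    def __len__(self):
        # One example per row of the DataFrame
        return len(self.labels)


# Usage with the first two rows of the table above
df = pd.DataFrame(
    [[0, 18.0, 8.0, 307.0, 130.0, 3504.0, 12.0, 70.0, 1.0],
     [1, 15.0, 8.0, 350.0, 165.0, 3693.0, 11.5, 70.0, 1.0]],
    columns=["Unnamed: 0", "mpg", "cylinders", "displacement", "horsepower",
             "weight", "acceleration", "model year", "origin"],
)
dataset = AutoMPGDataset(df)
features, mpg = dataset[1]
print(len(dataset), features.shape, mpg.item())  # 2 torch.Size([7]) 15.0
```

With the full data this would be `AutoMPGDataset(auto_mpg_data)`. Converting the whole table to tensors once in `__init__` keeps `__getitem__` cheap, which suits a dataset this small.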

## Dataset 2: The S40 Dataset

This dataset pairs each image in `S40-data/images` with an XML annotation file in `S40-data/annotations`. For each example we return the image and the bounding box of the first annotated object, expressed as center co-ordinates, width and height normalised by the image size.

```python
import os
import xml.etree.ElementTree as ET

import torch
from PIL import Image, ImageDraw
from torch.utils.data import Dataset
from torchvision import transforms


class S40dataset(Dataset):  # create dataset class

    def __init__(self, img_dir='S40-data/images', annotation_dir='S40-data/annotations', transform=None):
        self.img_dir = img_dir  # directory containing the images
        self.annotation_dir = annotation_dir  # directory containing the annotations
        self.transform = transform  # transforms passed to the initialiser

        self.img_names = os.listdir(img_dir)  # list all files in the image folder
        self.img_names.sort()  # order the images alphabetically
        self.img_names = [os.path.join(img_dir, img_name) for img_name in self.img_names]  # join folder and file names

        # Sorting both lists pairs image i with annotation i, assuming matching file names
        self.annotation_names = os.listdir(annotation_dir)  # list all annotation files
        self.annotation_names.sort()  # order annotation files alphabetically
        self.annotation_names = [os.path.join(annotation_dir, ann_name) for ann_name in self.annotation_names]

    def __getitem__(self, idx):
        img_name = self.img_names[idx]  # get the path of the image at that index
        img = Image.open(img_name)  # open the image using the path

        annotation_name = self.annotation_names[idx]  # get the path to the label file
        annotation_tree = ET.parse(annotation_name)  # use the XML parser to load the file
        bndbox_xml = annotation_tree.find("object").find("bndbox")  # tag holding the first object's box

        # get the x and y values for the corners of the rectangle
        xmax = int(bndbox_xml.find('xmax').text)
        ymax = int(bndbox_xml.find('ymax').text)
        xmin = int(bndbox_xml.find('xmin').text)
        ymin = int(bndbox_xml.find('ymin').text)

        # Convert from corner co-ordinates into center co-ordinate, width and height format
        w = xmax - xmin
        h = ymax - ymin
        x = xmin + w / 2  # kept as a float so no half-pixel is lost to rounding
        y = ymin + h / 2

        # Normalise the labels so the values are expressed as a proportion of the whole image
        x /= img.size[0]
        w /= img.size[0]
        y /= img.size[1]
        h /= img.size[1]

        bndbox = (x, y, w, h)  # tuple of bounding box dimensions

        if self.transform:  # if any transforms were given to the initialiser
            img = self.transform(img)  # apply them

        bndbox = torch.tensor(bndbox)  # convert bounding box tuple to tensor

        return img, bndbox

    def __len__(self):
        return len(self.img_names)


# Convert from center co-ordinate, width and height format into corner co-ordinates format
def unpack_bndbox(bndbox, img):
    # float() copies each value, so the caller's bndbox tensor is not modified in place
    x, y, w, h = (float(v) for v in bndbox)
    x *= img.size[0]
    w *= img.size[0]
    y *= img.size[1]
    h *= img.size[1]
    xmin = x - w / 2
    xmax = x + w / 2
    ymin = y - h / 2
    ymax = y + h / 2
    return [xmin, ymin, xmax, ymax]


def show(example, pred_bndbox=None):
    img, bndbox = example  # a single (image tensor, bounding box) pair

    img = transforms.ToPILImage()(img)  # tensor -> PIL image so we can draw on it
    img = transforms.Resize((512, 512))(img)  # boxes are normalised, so they scale with the image
    draw = ImageDraw.Draw(img)

    draw.rectangle(unpack_bndbox(bndbox, img))  # ground-truth box in the default colour
    if pred_bndbox is not None:
        draw.rectangle(unpack_bndbox(pred_bndbox, img), outline="red")  # predicted box
    img.show()
```
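The corner-to-center conversion and normalisation in `__getitem__`, and its inverse in `unpack_bndbox`, can be checked in isolation. The two helper functions below are hypothetical (not part of the notebook) and repeat that arithmetic on plain numbers, so a round trip should give back the original corners.

```python
def corners_to_normalised_center(xmin, ymin, xmax, ymax, img_w, img_h):
    """Hypothetical helper: corner pixels -> (x, y, w, h) as fractions of the image size."""
    w = xmax - xmin
    h = ymax - ymin
    x = xmin + w / 2
    y = ymin + h / 2
    return x / img_w, y / img_h, w / img_w, h / img_h


def normalised_center_to_corners(x, y, w, h, img_w, img_h):
    """Hypothetical helper: the inverse, back to [xmin, ymin, xmax, ymax] in pixels."""
    x, w = x * img_w, w * img_w
    y, h = y * img_h, h * img_h
    return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]


box = corners_to_normalised_center(100, 50, 300, 250, img_w=400, img_h=500)
print(box)  # (0.5, 0.3, 0.5, 0.4)
print(normalised_center_to_corners(*box, img_w=400, img_h=500))  # [100.0, 50.0, 300.0, 250.0]
```

Because the box is stored as fractions of the image size, the same four numbers stay valid after the image is resized, which is why `show` can scale the image to 512x512 before drawing.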

```python
from torchvision import transforms

myS40 = S40dataset(transform=transforms.ToTensor()) # initialise our dataset and transform each example with a ToTensor transform
print('len dataset:', len(myS40)) # use our __len__ method to show the length of the dataset
example = myS40[0] # use our __getitem__ method to get the first example
show(example)
```


A common next step is to wrap the dataset in a torch `DataLoader`, which handles batching and shuffling for us.

```python
from torch.utils.data import DataLoader

# batch_size > 1 needs equal-sized images (the default collate function stacks them)
my_dataloader = DataLoader(myS40, batch_size=1, shuffle=True)
```
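If the S40 images differ in size, the default collate function cannot stack them into one batch tensor. Two common options are adding a fixed-size `transforms.Resize` to the dataset's `transform` (the normalised boxes stay valid) or passing a custom `collate_fn`. Below is a minimal sketch of the latter; `collate_variable_size` and `FakeBoxDataset` are hypothetical names, and the fake dataset only stands in for `S40dataset` by returning `(image_tensor, bndbox_tensor)` pairs.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def collate_variable_size(batch):
    """Hypothetical collate_fn: keep images as a list, stack the fixed-size boxes."""
    images = [img for img, _ in batch]
    boxes = torch.stack([box for _, box in batch])  # every box has shape (4,)
    return images, boxes


class FakeBoxDataset(Dataset):
    """Hypothetical stand-in for S40dataset: random images of varying size plus one box."""

    def __init__(self, sizes):
        self.sizes = sizes  # list of (height, width) pairs

    def __getitem__(self, idx):
        h, w = self.sizes[idx]
        img = torch.rand(3, h, w)  # random "image" with 3 channels
        bndbox = torch.tensor([0.5, 0.5, 0.2, 0.2])  # (x, y, w, h), normalised
        return img, bndbox

    def __len__(self):
        return len(self.sizes)


loader = DataLoader(FakeBoxDataset([(64, 80), (120, 90), (50, 50)]),
                    batch_size=3, collate_fn=collate_variable_size)
images, boxes = next(iter(loader))
print([tuple(img.shape) for img in images], boxes.shape)
```

With the real data this would look like `DataLoader(myS40, batch_size=8, shuffle=True, collate_fn=collate_variable_size)`. Most CNNs still expect fixed-size inputs, so resizing in the dataset is often the simpler choice for training.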

## Notebook complete

You should now understand how to create your own dataset classes by inheriting from torch's `Dataset` class and implementing the `__getitem__` and `__len__` methods.

__Next Steps__

- [CNN Detection](https://github.com/AI-Core/Convolutional-Neural-Networks/blob/master/CNN%20Detection.ipynb) - use this dataset to train a CNN to detect single instances in images