# PyTorch datasets and dataloaders

PyTorch gives you access to a range of ready-made datasets. Several of them, including the MNIST dataset used below, are in **torchvision**, PyTorch's computer vision library.



In [None]:
import torch
from torchvision import datasets, transforms

# GET THE TRAINING DATASET
train_data = datasets.MNIST(root='MNIST-data',                        # where is the data (going to be) stored
                            transform=transforms.ToTensor(),          # transform the data from a PIL image to a tensor
                            train=True,                               # is this training data?
                            download=True                             # download the data if it is not already in root
                           )

# GET THE TEST DATASET
test_data = datasets.MNIST(root='MNIST-data',
                           transform=transforms.ToTensor(),
                           train=False,
                          )

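import numpy as np
import matplotlib.pyplot as plt

# PRINT THEIR LENGTHS AND VISUALISE AN EXAMPLE
print(len(train_data), len(test_data))          # number of examples in each dataset
idx = np.random.randint(0, len(train_data))     # index of a random training example
x = train_data[idx][0]                          # each item is an (image, label) pair; take the image
plt.imshow(x[0].numpy(), cmap='gray')           # x has shape (1, 28, 28); x[0] is the 2D image
plt.show()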
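
To make the indexing above concrete, here is a minimal sketch of what a dataset object looks like from the outside: it has a length, and indexing it returns an (image, label) pair. `FakeDigits` is a hypothetical stand-in filled with random pixels so that the example runs without downloading anything; it is not part of PyTorch or torchvision, and it only imitates the shapes that the MNIST dataset returns after `ToTensor()`.

```python
import torch
from torch.utils.data import Dataset


class FakeDigits(Dataset):
    """Hypothetical stand-in for MNIST: random 1x28x28 images with labels 0-9."""

    def __init__(self, n_examples=100, seed=0):
        g = torch.Generator().manual_seed(seed)
        self.images = torch.rand(n_examples, 1, 28, 28, generator=g)
        self.labels = torch.randint(0, 10, (n_examples,), generator=g)

    def __len__(self):
        # number of examples in the dataset
        return len(self.images)

    def __getitem__(self, idx):
        # return one (image, label) pair, like train_data[idx] above
        return self.images[idx], int(self.labels[idx])


fake_data = FakeDigits()
image, label = fake_data[0]
print(len(fake_data))   # number of examples
print(image.shape)      # torch.Size([1, 28, 28])
print(type(label))      # <class 'int'>
```

Any object that implements `__len__` and `__getitem__` in this way can be passed to the splitting and loading utilities in the same way as a built-in dataset.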

## Splitting our data in PyTorch

Like scikit-learn, PyTorch provides a utility for splitting datasets, `torch.utils.data.random_split`. MNIST already comes with a separate test set, so here we use it to split the 60,000 training examples into training and validation sets.

In [None]:
# FURTHER SPLIT THE TRAINING INTO TRAINING AND VALIDATION
train_data, val_data = torch.utils.data.random_split(train_data, [50000, 10000])    # split into 50K training & 10K validation
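
By default the split is different every time the cell is run. If you want a repeatable split, `random_split` accepts a `generator` argument. The sketch below assumes a small random `TensorDataset` as stand-in data (60 examples instead of 60,000) so that it runs without the MNIST download; the seed value 42 is arbitrary.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in data: 60 random "images" with labels, in place of the MNIST examples
full_data = TensorDataset(torch.rand(60, 1, 28, 28), torch.randint(0, 10, (60,)))

# A seeded generator makes the split repeatable
generator = torch.Generator().manual_seed(42)
train_subset, val_subset = random_split(full_data, [50, 10], generator=generator)

print(len(train_subset), len(val_subset))

# Each subset stores the indices it was given; the two sets of indices do not overlap
overlap = set(train_subset.indices) & set(val_subset.indices)
print(len(overlap))
```

The lengths passed to `random_split` must add up to the length of the dataset being split.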

## Data Loaders

PyTorch provides a `DataLoader` class which prepares our data to be passed through a PyTorch model by doing a few useful things. These include:
- batching our data into minibatches
- shuffling the data
- fetching each example from the dataset, which is where the dataset's transforms are applied, such as converting an image (which PyTorch models cannot process) to a torch tensor (which PyTorch models can process)

In [None]:
batch_size = 256

# MAKE TRAINING DATALOADER
train_loader = torch.utils.data.DataLoader( # create a data loader
    train_data, # what dataset should it sample from?
    shuffle=True, # should it shuffle the examples?
    batch_size=batch_size # how large should the batches that it samples be?
)

# MAKE VALIDATION DATALOADER
val_loader = torch.utils.data.DataLoader(
    val_data,
    shuffle=False, # evaluation does not depend on the order of the examples
    batch_size=batch_size
)

# MAKE TEST DATALOADER
test_loader = torch.utils.data.DataLoader(
    test_data,
    shuffle=False, # evaluation does not depend on the order of the examples
    batch_size=batch_size
)
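
A data loader is used by iterating over it: each iteration yields one batch of inputs and the matching batch of labels. The sketch below shows this on a small random `TensorDataset`, a stand-in assumed here so that the example runs without the MNIST download; `demo_data` and `demo_loader` are illustrative names only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 1,000 random 1x28x28 "images" with labels 0-9
demo_data = TensorDataset(torch.rand(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))
demo_loader = DataLoader(demo_data, shuffle=True, batch_size=256)

batch_sizes = []
for images, labels in demo_loader:
    # images has shape (batch, 1, 28, 28) and labels has shape (batch,)
    batch_sizes.append(images.shape[0])

print(len(demo_loader))   # number of batches per epoch
print(batch_sizes)        # the last batch is smaller because 1000 is not a multiple of 256
```

If a smaller final batch is a problem for your model, `DataLoader` also accepts `drop_last=True`, which discards the incomplete batch.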