<h1> PyTorch Classification with MLP </h1>
Let's see how we can train our first neural network using the PyTorch functionality we have previously seen! In this notebook we will train a Multilayer Perceptron (MLP) on the MNIST dataset. We will see how to use PyTorch's built-in datasets, how to construct an MLP using the PyTorch nn.Module class, and how to construct a training and testing loop to perform stochastic gradient descent (SGD).

<img src="../data/MNIST.gif" width="700" align="center">

Animation of MNIST digits and an MLP's activations changing via learning

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data.dataloader as dataloader
import torch.optim as optim

from torchvision import transforms
from torchvision.datasets import MNIST
import torchvision

from tqdm.notebook import trange, tqdm
import numpy as np
import matplotlib.pyplot as plt
import time

<h2> Classification </h2>
In this notebook we are performing "classification": we want our model to predict which group a given input belongs to after being trained on a large number of examples. In our case we are using MNIST, a dataset of small black-and-white handwritten digits. We want our model to predict which digit (0-9) is in the image!

[Classification Explained](https://www.datacamp.com/blog/classification-machine-learning)

<h2> Download the MNIST Train and Test set </h2>
The MNIST dataset is a large database of handwritten digits that is commonly used for training and testing in the field of machine learning. It consists of 60,000 training images and 10,000 test images, each with its corresponding digit class (0-9). It has fallen out of fashion these days because it is "too easy" to learn, though it is still used at times as a "proof of concept".  <br>
PyTorch provides a number of "dataset" classes that will automatically download various datasets, making it very easy for us to train our models. We will look more closely at using PyTorch datasets in a later lab.

[Pytorch Datasets](https://pytorch.org/vision/stable/datasets.html)

In [None]:
# Size of our mini-batches
batch_size = 256
data_set_root = "../../datasets"

# Create a train and test dataset using the PyTorch MNIST dataset class
# NOTE: IF YOU DO NOT HAVE A RECENT VERSION OF torchvision YOU WILL NEED TO DOWNLOAD THE MNIST DATASET
# FIRST, BECAUSE THE DOWNLOAD LINK USED BY OLD VERSIONS NO LONGER WORKS!
# SEE THE NEXT BLOCK OF CODE!

train = MNIST(data_set_root, train=True,  download=True, transform=transforms.ToTensor())
test  = MNIST(data_set_root, train=False, download=True, transform=transforms.ToTensor())

# Using the PyTorch dataloader class and the PyTorch datasets we will create iterable dataloader objects
train_loader = dataloader.DataLoader(train, shuffle=True, batch_size=batch_size, num_workers=0, pin_memory=False) 
test_loader = dataloader.DataLoader(test, shuffle=False, batch_size=batch_size, num_workers=0, pin_memory=False)

# NOTE: num_workers is the number of extra worker processes the dataloader will spawn to load the data from file,
# you will rarely need more than 4
# NOTE!!! ON WINDOWS THERE CAN BE ISSUES WITH HAVING MORE THAN 0 WORKERS!! IF YOUR TRAINING LOOP STALLS AND DOES
# NOTHING SET num_workers TO 0!

# NOTE: pin_memory is only useful if you are training with a GPU! If it is True the dataloader places each batch
# in page-locked (pinned) CPU memory, so the CPU-GPU transfer can be handled by the DMA controller, freeing up the CPU
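
To see how a dataloader splits a dataset into mini-batches without downloading anything, here is a minimal sketch. It assumes a stand-in dataset of random tensors with the same shape as MNIST images; `fake_images`, `fake_labels` and `fake_dataset` are hypothetical names used only for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 1000 random single-channel 28x28 "images"
# with random digit labels (assumption: real MNIST samples have this shape)
fake_images = torch.rand(1000, 1, 28, 28)
fake_labels = torch.randint(0, 10, (1000,))
fake_dataset = TensorDataset(fake_images, fake_labels)

# Same settings as the notebook's train_loader, applied to the fake dataset
loader = DataLoader(fake_dataset, batch_size=256, shuffle=True, num_workers=0)

# Record how many samples each mini-batch contains
batch_sizes = [images.shape[0] for images, labels in loader]
print(batch_sizes)
```

Because 1000 is not a multiple of 256, the final mini-batch is smaller than the rest. Code that computes accuracy should therefore count samples per batch rather than assume every batch holds `batch_size` items.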

In [None]:
## IF YOU ARE USING AN OLD VERSION OF PYTORCH
# from torchvision.datasets.utils import download_and_extract_archive
# import os

# url_base = 'https://ossci-datasets.s3.amazonaws.com/mnist/'

# resources = [
#     ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
#     ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
#     ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
#     ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")]


# def download():
#     os.makedirs("./data/MNIST/raw", exist_ok=True)

#     # download files
#     for filename, md5 in resources:
#         url = "{}{}".format(url_base, filename)
#         print("Downloading {}".format(url))
        
#         download_and_extract_archive(
#             url, download_root="./data/MNIST/raw",
#             filename=filename,
#             md5=md5
#         )

        
# download()

In [None]:
# Set device to GPU_indx if a GPU is available
GPU_indx = 0
device = torch.device(GPU_indx if torch.cuda.is_available() else 'cpu')

<h3> Visualise a few training samples </h3>
Let's visualise the mini-batches that the dataloader gives us

In [None]:
# We can create an iterator from the dataloader and take the first mini-batch (a random sample, because shuffle=True)
images, labels = next(iter(train_loader))
print("The input data shape is :\n", images.shape)
print("The target output data shape is :\n", labels.shape)

We can see that (as specified) our mini-batch size is 256. The dataloader has passed us a 4D Tensor as input data: the first dimension (d0) is known as the "batch dimension" (B) and the other three are the image dimensions (CxHxW). We can think of this 4D Tensor as a stack of 256 single-channel 28x28 images.<br>
The image labels are a 1D Tensor, with a single scalar value per image (per mini-batch "instance").
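
As a rough illustration of these shapes, the sketch below builds a random tensor laid out like one mini-batch (an assumption standing in for real MNIST data) and flattens each image into a vector, which is the form a linear layer expects.

```python
import torch

# Stand-in mini-batch with the B x C x H x W layout described above
batch = torch.rand(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))  # one class index per image

# Keep the batch dimension and merge C, H and W into a single dimension
flat = batch.view(batch.shape[0], -1)

print(batch.shape)   # torch.Size([256, 1, 28, 28])
print(flat.shape)    # torch.Size([256, 784])
print(labels.shape)  # torch.Size([256])
```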

In [None]:
# Let's visualise an entire batch of images!
plt.figure(figsize = (20,10))
# nrow is the number of images displayed in each row of the grid
out = torchvision.utils.make_grid(images, nrow=32)
plt.imshow(out.numpy().transpose((1, 2, 0)))

## Define our Neural Network Model 
We define our model using the torch.nn.Module class

In [None]:
# Let's create a simple MLP network similar to the sine wave approximator
class Simple_MLP(nn.Module):
    def __init__(self, num_classes):
        super(Simple_MLP, self).__init__()
        # We will use 4 linear layers
        # The input to the model is 784 (28x28 - the image size)
        # and there should be num_classes outputs
        
        # TO DO-  hidden layer size 512
        # TO DO-  hidden layer size 256
        # TO DO-  hidden layer size 128
        # TO DO-  output layer size num_classes

    def forward(self, x):
        
        # The data we pass the model is a batch of single channel images
        # with shape BSx1x28x28 we need to flatten it to BSx784
        # To use it in a linear layer
        x = x.view(x.shape[0], -1)
        
        # We will use a ReLU activation function for this network! (F.relu)
        # NOTE: F.relu is the "functional" version of the activation function,
        # while nn.ReLU is a module class that you instantiate to get a "ReLU" layer object
        
        # The two behave the same for MOST purposes!
        # TO DO, layer and then activation function
        # TO DO, layer and then activation function
        # TO DO, layer and then activation function
        # TO DO, add layer
        return # TO DO
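
For reference, one possible way to fill in the template above is sketched below. It is not the only valid solution: the class name `SketchMLP` and the attribute names `fc1` to `fc4` are hypothetical, while the layer sizes follow the comments in the template.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchMLP(nn.Module):
    """Hypothetical 4-layer MLP: 784 -> 512 -> 256 -> 128 -> num_classes."""

    def __init__(self, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, num_classes)

    def forward(self, x):
        # Flatten BSx1x28x28 images to BSx784 vectors
        x = x.view(x.shape[0], -1)
        # Linear layer followed by ReLU for each hidden layer
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        # No activation on the output layer: return raw class scores (logits)
        return self.fc4(x)


sketch_model = SketchMLP(num_classes=10)
logits = sketch_model(torch.rand(4, 1, 28, 28))
n_params = sum(p.numel() for p in sketch_model.parameters())
print(logits.shape, n_params)
```

The output layer has no activation because the cross entropy loss used later applies the softmax normalisation itself.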

<h3> Create the model and define the Loss and Optimizer</h3>
Since this is a classification task, we will use Cross Entropy Loss as our criterion. 

[torch.nn.CrossEntropyLoss](https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss)

[Killer Combo: Softmax and Cross Entropy by Paolo Perrotta
](https://levelup.gitconnected.com/killer-combo-softmax-and-cross-entropy-5907442f60ba)

Just as in the sine wave approximation, experiment with different optimizers and hyperparameters.
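
To make the link between softmax and cross entropy concrete, the following sketch compares `nn.CrossEntropyLoss` with an explicit log-softmax followed by the negative log-likelihood loss. The random logits and targets are placeholders, not model outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Placeholder data: raw scores for 8 samples over 10 classes, plus class indices
logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))

# CrossEntropyLoss takes raw logits and integer class labels
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)

# The same quantity computed in two explicit steps
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

print(loss.item(), manual.item())
```

The two values agree up to floating point error, which is why the model should output raw logits rather than softmax probabilities when this loss is used.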

In [None]:
# Create our model
model = # TO DO create a model with 10 classes
# Create our loss function
criterion = # TO DO add the Cross Entropy loss Function
# Define our learning rate and optimizer
lr = # TO DO
optimizer = # TO DO
# Number of Epochs
n_epochs = # TO DO

# We can print out our model structure
print(model)
# Note: this is only the order in which the layers were defined NOT the path of the forward pass!

<h3> Create a function that will train the network for one epoch </h3>

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, loss_logger):
    for batch_idx, (data, target) in enumerate(tqdm(train_loader, desc="Training", leave=False)):   
        
        # Forward pass of model
        outputs = # TO DO
        
        # Calculate loss
        loss = # TO DO
        
        # Zero gradients
        # TO DO
        
        # Backprop loss
        # TO DO
        
        # Optimization Step
        # TO DO
        
        loss_logger.append(loss.item())

    return model, optimizer, loss_logger
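
The steps listed in the comments above can be sketched on toy data as follows. This is an illustrative version under stated assumptions, not the reference solution: `sketch_train_epoch`, the single-layer `toy_model`, the random `toy_data`, and the choice of Adam with a learning rate of 1e-3 are all hypothetical.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def sketch_train_epoch(model, loader, criterion, optimizer, device, loss_logger):
    model.train()
    for data, target in loader:
        # Move the batch to the same device as the model
        data, target = data.to(device), target.to(device)

        outputs = model(data)              # forward pass
        loss = criterion(outputs, target)  # calculate loss

        optimizer.zero_grad()              # clear gradients from the last step
        loss.backward()                    # backprop loss
        optimizer.step()                   # optimization step

        loss_logger.append(loss.item())
    return loss_logger


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-ins: 64 random "images" and a single linear layer as the model
toy_data = TensorDataset(torch.rand(64, 1, 28, 28), torch.randint(0, 10, (64,)))
toy_loader = DataLoader(toy_data, batch_size=16, shuffle=True)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)).to(device)
toy_optimizer = optim.Adam(toy_model.parameters(), lr=1e-3)

losses = sketch_train_epoch(
    toy_model, toy_loader, nn.CrossEntropyLoss(), toy_optimizer, device, []
)
print(len(losses))  # one logged loss per mini-batch
```

Gradients accumulate across calls to `backward()`, so they are zeroed once per step here; zeroing before the forward pass works equally well.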

<h3> Create a function that will evaluate our network's performance on the test set </h3>

In [None]:
def test_model(model, test_loader, criterion, loss_logger):
    with torch.no_grad():
        correct_predictions = 0
        total_predictions = 0
        
        for batch_idx, (data, target) in enumerate(tqdm(test_loader, desc="Testing", leave=False)):   
            
            # Forward pass of model
            outputs = # TO DO           
            
            # Calculate the accuracy of the model
            # You'll need to accumulate the accuracy over multiple steps
            
            # TO DO
            # Number of correctly predicted outputs
            
            correct_predictions += # TO DO
            # Total number of predictions made
            total_predictions += # TO DO
            
            # Calculate the loss
            loss = # TO DO
            loss_logger.append(loss.item())
            
        acc = (correct_predictions/total_predictions) * 100.0
        return loss_logger, acc
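
The accuracy bookkeeping can be illustrated on a tiny hand-written example. The scores and targets below are made up for the illustration; in the real loop `outputs` would come from the model and the counts would be accumulated over all test batches.

```python
import torch

# Made-up scores for 4 samples over 3 classes, and their true classes
outputs = torch.tensor([[2.0, 0.1, 0.3],
                        [0.2, 1.5, 0.1],
                        [0.9, 0.8, 0.1],
                        [0.1, 0.2, 3.0]])
target = torch.tensor([0, 1, 1, 2])

# The predicted class is the index of the largest score in each row
predicted = outputs.argmax(dim=1)

# Count matches and the number of samples in this batch
correct = (predicted == target).sum().item()
total = target.shape[0]

acc = correct / total * 100.0
print(predicted.tolist(), correct, total, acc)
```

Here the third sample is misclassified (class 0 is predicted but the target is class 1), so 3 of 4 predictions are correct.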

## Train the model for N epochs
We call our training and testing functions in a loop, while keeping track of the losses and accuracy. 

In [None]:
# Create empty lists for the train/test losses and the test accuracy
train_loss = []
test_loss  = []
test_acc   = []

In [None]:
for i in trange(n_epochs, desc="Epoch", leave=False):
    # Call the training function to perform an epoch of training
    model, optimizer, train_loss = # TO DO
    
    # Call the testing function to work out the test loss and accuracy!
    test_loss, acc = # TO DO
    test_acc.append(acc)

print("Final Accuracy: %.2f%%" % acc)

## Visualize Training Data

In [None]:
# Plot the Training and Test losses
# TO DO

In [None]:
# Plot the Test Accuracy
# TO DO
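
One way to approach these plots is sketched below with synthetic curves. The loss loggers above store one value per mini-batch, so the train and test lists have different lengths. This sketch assumes both should share an axis measured in epochs and rescales them accordingly. `fake_train_loss`, `fake_test_loss` and `n_epochs_demo` are hypothetical placeholders for the real logged values.

```python
import numpy as np
import matplotlib.pyplot as plt

n_epochs_demo = 3
rng = np.random.default_rng(0)

# Synthetic per-batch losses: more training batches than test batches per epoch
n_train = 235 * n_epochs_demo
n_test = 40 * n_epochs_demo
fake_train_loss = np.exp(-np.linspace(0, 3, n_train)) + 0.02 * rng.random(n_train)
fake_test_loss = np.exp(-np.linspace(0, 3, n_test)) + 0.02 * rng.random(n_test)

fig, ax = plt.subplots(figsize=(10, 5))
# Map each list onto the range [0, n_epochs] so the curves line up
ax.plot(np.linspace(0, n_epochs_demo, n_train), fake_train_loss, label="Train")
ax.plot(np.linspace(0, n_epochs_demo, n_test), fake_test_loss, label="Test")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training and test loss (synthetic data)")
ax.legend()
```

The test accuracy list holds one value per epoch, so it can be plotted directly against the epoch index.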