# Classifier

## Google COLAB Settings
This section contains steps that should be run when executing the notebook on Google COLAB in order to have access to a GPU. If that is the case, uncomment the cells below and run them sequentially. Otherwise, feel free to skip directly to [Imports](#Imports).

### Installs

In [None]:
# %%capture
# from os.path import exists
# from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
# platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
# cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
# accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'
# !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision
# !pip install livelossplot

### Google Drive
This portion is exclusively for development on _my_ end. I use Google Drive to access the training/testing data without having to redownload it each time the Google COLAB runtime is reset. 

Anyone who does not have access to my Google credentials will not be able to access my Drive. Such users should skip directly to [Imports](#Imports). In that case, torchvision will download the CIFAR data from the web each time the COLAB runtime is reset.

#### Mounting Drive
This mounts Google Drive to the local runtime. If Drive is already mounted, it will not be mounted again. Mounting will ask for authentication.

In [None]:
# from google.colab import drive
# drive.mount('/content/gdrive')

#### Importing data from Drive
Here I import the CIFAR data which I have previously downloaded and stored in my Google Drive. To do so, I copy the corresponding directory from my Google Drive into the COLAB runtime, which avoids having to redownload it each time the runtime is reset. The `-n` flag of `cp` is set to avoid overwriting files that already exist.

In [None]:
# !cp -r -n /content/gdrive/My\ Drive/Education/Undergraduate/Year_3/Computer_Science/SSA/Machine_Learning/Coursework/ML_Classifier-Pegasus-Generator/data /content/

## Imports

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision  # provides specific datasets
import matplotlib.pyplot as plt  # provides plotting capabilities
from livelossplot import PlotLosses  # provides live plotting capabilities

## PyTorch settings

In [None]:
# pick the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device being used:", device)

## Functions

In [None]:
def cycle(iterable):
    """Yield the items of ``iterable`` over and over, without end.

    Each time the iterable is exhausted it is iterated again from the
    start, so a caller can keep pulling batches with ``next()``.
    """
    while True:
        yield from iterable

            
def plot_image(i, predictions_array, true_label, img):
    """Show image ``i`` on the current axes, captioned with its prediction.

    The caption reads "<predicted class> <confidence>% (<true class>)" and is
    coloured blue when the prediction matches the true label, red otherwise.
    Class names are looked up in the module-level ``class_names`` list.
    """
    scores = predictions_array[i]
    actual = true_label[i]
    picture = img[i]

    # strip grid and tick marks so only the picture is visible
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(picture, cmap=plt.cm.binary)

    # the predicted class is the one with the highest score
    guess = np.argmax(scores)
    if guess == actual:
        color = '#335599'
    else:
        color = '#ee4433'

    caption = "{} {:2.0f}% ({})".format(
        class_names[guess], 100*np.max(scores), class_names[actual])
    plt.xlabel(caption, color=color)

    
def plot_value_array(i, predictions_array, true_label):
    """Draw the per-class scores of sample ``i`` as a bar chart.

    All bars start grey; the predicted class is then painted red and the true
    class blue (so a correct prediction ends up blue). One bar is drawn per
    entry of the module-level ``class_names`` list.
    """
    scores = predictions_array[i]
    actual = true_label[i]

    # no grid or tick marks; the y-axis is fixed to the [0, 1] range
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    bars = plt.bar(range(len(class_names)), scores, color="#777777")
    plt.ylim([0, 1])

    guess = np.argmax(scores)
    # colour order matters: the true label is painted last and wins on a match
    bars[guess].set_color('#ee4433')
    bars[actual].set_color('#335599')

## Classes

In [None]:
class MyNetwork(nn.Module):
    """A simple fully connected classifier: Linear -> ReLU -> Linear."""

    def __init__(self):
        """Build the three layers and register them in ``self.layers``."""
        super().__init__()
        self.layers = nn.ModuleList([
            # flattened 3 (colors) x 32 (width) x 32 (height) input
            nn.Linear(in_features=3 * 32 * 32, out_features=512),
            # rectifier
            nn.ReLU(),
            # one output per class (100 classes)
            nn.Linear(in_features=512, out_features=100),
        ])

    def forward(self, x):
        """Flatten each sample of the batch and run it through every layer."""
        # the linear layers expect one flat feature vector per sample
        out = x.view(x.size(0), -1)
        for layer in self.layers:
            out = layer(out)
        return out

class EarlyStop:
    """Stateful early-stopping check driven by a loss measured once per call.

    Each call compares the supplied loss with the lowest one seen so far. A
    loss that is higher raises the "annoyance" counter; once the counter
    reaches ``patience`` the ``stop`` flag is set. A loss that is not higher
    resets the counter and saves a checkpoint of the model.

    Attributes are plain instance variables:
      patience        -- number of non-improving calls tolerated
      checkpoint_path -- file the model state dict is written to
      annoyance       -- current count of consecutive non-improving calls
      stop            -- True once patience has run out
      min_loss        -- lowest loss seen so far (``np.inf`` before any call)
      epoch           -- epoch of the most recent checkpoint
    """

    def __init__(self, patience,
                 checkpoint_path='Training/Checkpoints/checkpoint.pt'):
        """Initialise the state.

        Args:
            patience: how many non-improving calls are tolerated before
                ``stop`` is set.
            checkpoint_path: where ``checkpoint`` saves the model; defaults
                to the location previously hard-coded in this class.
        """
        self.patience = patience
        self.checkpoint_path = checkpoint_path
        self.annoyance = 0
        self.stop = False
        # np.inf rather than np.Inf: the capitalised alias no longer exists
        # in NumPy 2.0
        self.min_loss = np.inf
        self.epoch = 0

    def __call__(self, model, epoch, curr_loss):
        """Update the state with the loss measured at ``epoch``.

        Args:
            model: the model to checkpoint when the loss improves.
            epoch: epoch number stored alongside a checkpoint.
            curr_loss: the loss to compare with the best one so far.
        """
        # base case: nothing measured yet, so just remember this loss
        if self.min_loss == np.inf:
            self.min_loss = curr_loss

        # the best loss so far is smaller than the current one: no improvement
        elif self.min_loss < curr_loss:
            self.annoyance += 1
            print('Current loss is greater than previously measured')
            print(f'Increasing annoyance to {self.annoyance} out of {self.patience}')
            # compare against the instance's own patience (the bare name
            # ``patience`` is not defined inside this method)
            if self.annoyance >= self.patience:
                print('Patience limit has been reached -- Early stopping')
                self.stop = True

        # the current loss is at least as good as the best one so far
        else:
            print('Current loss is smaller than previously measured')
            self.min_loss = curr_loss
            self.checkpoint(model, epoch)
            print('Resetting annoyance to 0')
            self.annoyance = 0

    def checkpoint(self, model, epoch):
        """Save the model's state dict and remember the checkpoint epoch."""
        # local import keeps this class self-contained
        import os

        print('Setting checkpoint')
        # torch.save does not create missing directories
        directory = os.path.dirname(self.checkpoint_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(model.state_dict(), self.checkpoint_path)
        self.epoch = epoch
        

## Dataset Setup

In [None]:
# human-readable class names, ten per row; a sample's integer label is used
# as the index into this list
class_names = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle',
    'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
    'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'computer_keyboard',
    'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
    'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
    'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
    'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
    'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm',
]

### Transforms

In [None]:
# training-time preprocessing; augmentation steps can be added to this list
train_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)

# test-time preprocessing; no augmentation should be added here
test_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)

### Getting the Datasets
The training and validation datasets will both originate from the CIFAR training set. This is because later on they will be split according to some ratio such that for example training is 4/5 of the training set and validation is 1/5 of the training set. This is to perform regularization (such as early stopping) in an isolated manner from the testing dataset, which is in fact loaded from the test set.

The careful reader will notice that despite being loaded from the same set, training and validation datasets still differ as each will have their own transforms, allowing the former for example to be augmented. 

In [None]:
# CIFAR-100 training split, stored under ./data (downloaded if needed)
train_dataset = torchvision.datasets.CIFAR100(
    root='data',
    train=True,
    download=True,
    transform=train_transform,
)

# CIFAR-100 test split, stored under ./data (downloaded if needed)
test_dataset = torchvision.datasets.CIFAR100(
    root='data',
    train=False,
    download=True,
    transform=test_transform,
)

### Loading the Datasets for usage
We've downloaded the data. Now we have to load it for our model to use it.


#### Instantiate data objects

In [None]:
BATCH_SIZE = 16

# training batches are reshuffled every epoch; an incomplete last batch is dropped
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# test batches keep the dataset order; an incomplete last batch is dropped
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

# endless batch iterators for later use (cycle already returns an iterator)
train_iterator = cycle(train_loader)
test_iterator = cycle(test_loader)

print(f'> Size of training dataset {len(train_dataset)}')
print(f'> Size of test dataset {len(test_dataset)}')

### Viewing Some of Test Dataset

In [None]:
# show the first 25 test images on a 5 by 5 grid
plt.figure(figsize=(10, 10))
for i in range(25):
    image, label = test_loader.dataset[i]
    plt.subplot(5, 5, i + 1)
    # no tick marks and no grid, just the picture
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # move the first axis to the end, same result as the two chained permutes
    plt.imshow(image.permute(1, 2, 0), cmap=plt.cm.binary)
    # caption with the class name of the image's label
    plt.xlabel(class_names[label])

## Training/Testing

### Set up

In [None]:
# build the model and move it to the selected device
N = MyNetwork().to(device)

print(f'> Number of network parameters {sum(p.numel() for p in N.parameters())}')

# plain stochastic gradient descent over all model parameters
optimiser = torch.optim.SGD(N.parameters(), lr=0.001)
# training starts from epoch zero
epoch = 0
# live loss/accuracy plot
liveplot = PlotLosses()

# True: plot the metrics live; False: record them in lists instead
VIS_BOOL = True

if not VIS_BOOL:
    # per-epoch history of the training metrics
    train_acc, train_loss = [], []
    # per-epoch history of the test metrics
    test_acc, test_loss = [], []

# early stopping helper with a patience of 15
early_stopper = EarlyStop(15)

### Training

In [None]:
# start training process
# NOTE(review): the early_stopper object created during set-up is never
# called in this loop, so training always runs for the full 100 epochs.
while epoch < 100:  # training for 100 epochs
    # per-batch metrics for this epoch; plain lists are used because
    # np.append copies the whole array on every call
    train_loss_arr = []
    train_acc_arr = []
    test_loss_arr = []
    test_acc_arr = []

    # set the model to training mode
    N.train()
    # iterate over the training dataset
    for x, t in train_loader:
        # move the data and target tensors to the selected device
        x, t = x.to(device), t.to(device)
        # reset the gradients accumulated by the previous step
        optimiser.zero_grad()
        # run the batch through the network
        p = N(x)
        # loss between prediction and target
        loss = torch.nn.functional.cross_entropy(p, t)
        # backpropagate and update the parameters
        loss.backward()
        optimiser.step()
        # always record the batch metrics: the epoch averages are needed both
        # for the live plot and for the recorded history
        train_loss_arr.append(loss.item())
        train_acc_arr.append((p.argmax(dim=1) == t).float().mean().item())

    # set the model to evaluation mode
    N.eval()
    # no gradients are needed while evaluating
    with torch.no_grad():
        for x, t in test_loader:
            x, t = x.to(device), t.to(device)
            p = N(x)
            loss = torch.nn.functional.cross_entropy(p, t)
            test_loss_arr.append(loss.item())
            test_acc_arr.append((p.argmax(dim=1) == t).float().mean().item())

    # average the batch metrics into the epoch metrics
    # training
    epoch_train_acc = np.mean(train_acc_arr)
    epoch_train_loss = np.mean(train_loss_arr)
    # testing
    epoch_test_acc = np.mean(test_acc_arr)
    epoch_test_loss = np.mean(test_loss_arr)

    # if we want to visualize live
    if VIS_BOOL:
        # NOTE: live plot library has naming forcing our 'test' to be called 'validation'
        liveplot.update({
            'accuracy': epoch_train_acc,
            'val_accuracy': epoch_test_acc,
            'loss': epoch_train_loss,
            'val_loss': epoch_test_loss
        })
        liveplot.draw()
    # if we prefer to record the results for later usage
    else:
        train_acc.append(epoch_train_acc)
        train_loss.append(epoch_train_loss)
        # the test history must hold the test metrics, not the training ones
        test_acc.append(epoch_test_acc)
        test_loss.append(epoch_test_loss)

        # print to keep track of progress
        print(epoch)

    # move on to the next epoch
    epoch += 1

print("Training completed")

In [None]:
# # load the checkpointed model
# N.load_state_dict(torch.load('Training/Checkpoints/checkpoint.pt'))
# epoch = early_stopper.epoch

## Results

### Inference

In [None]:
# take the next batch from the endless test iterator and move it to the device
test_images, test_labels = next(test_iterator)
test_images, test_labels = test_images.to(device), test_labels.to(device)
# perform inference on the batch; no gradients are needed here
with torch.no_grad():
    logits = N(test_images).view(test_images.size(0), len(class_names))
    # softmax turns each row into probabilities that sum to 1; the array
    # stays two-dimensional (one row per image) so it can be indexed per
    # image even when the batch holds a single image
    test_preds = torch.softmax(logits, dim=1).cpu().numpy()


### Plotting Inference

In [None]:
num_rows = 4
num_cols = 4
num_images = num_rows*num_cols

# convert once, outside the loop: labels on the CPU, and images with the
# second axis moved to the end (same result as the two chained permutes)
labels_cpu = test_labels.cpu()
images_cpu = test_images.cpu().squeeze().permute(0, 2, 3, 1)

# every image gets two panels side by side: the picture and its score bars
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2*num_cols, 2*i+1)
    plot_image(i, test_preds, labels_cpu, images_cpu)
    plt.subplot(num_rows, 2*num_cols, 2*i+2)
    plot_value_array(i, test_preds, labels_cpu)

### Plotting Test Loss and Accuracy

In [None]:
# if we had previously decided to record the test accuracy and loss throughout
if not VIS_BOOL:
    # one x value per recorded epoch; derived from the history itself so the
    # x and y lengths match even if `epoch` has been changed since training
    epochs_arr = list(range(1, len(test_acc) + 1))
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    # left panel: accuracy
    ax1.plot(epochs_arr, test_acc)
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Accuracy")
    ax1.set_title("Test Accuracy over Training Epochs")

    # right panel: loss (its y-axis shows the loss, not the accuracy)
    ax2.plot(epochs_arr, test_loss)
    ax2.set_xlabel("Epochs")
    ax2.set_ylabel("Loss")
    ax2.set_title("Test Loss over Training Epochs")

    fig.tight_layout()
else:
    print("You have not previously recorded the test accuracy and loss over time")

### Print Best Accuracy

In [None]:
# the recorded history only exists when live visualisation was switched off
if VIS_BOOL:
    print("You have not previously recorded the test accuracy and loss over time")
else:
    # report the accuracy of the last recorded epoch as a percentage
    print("The value for the test accuracy is {} %".format(test_acc[-1]*100))