# Assignment overview <ignore>
The overarching goal of this assignment is to produce a research report in which you implement, analyse, and discuss various Neural Network techniques. You will be guided through the process of producing this report, which will provide you with experience in report writing that will be useful in any research project you might be involved in later in life.

All of your report, including code and Markdown/text, ***must*** be written up in ***this*** notebook. This is not typical for research, but is solely for the purpose of this assignment. Please make sure you change the title of this file so that XXXXXX is replaced by your candidate number. You can use code cells to write code to implement, train, test, and analyse your NNs, as well as to generate figures to plot data and the results of your experiments. You can use Markdown/text cells to describe and discuss the modelling choices you make, the methods you use, and the experiments you conduct. So that we can mark your reports with greater consistency, please ***do not***:

* rearrange the sequence of cells in this notebook.
* delete any cells, including the ones explaining what you need to do.

If you want to add more code cells, for example to help organise the figures you want to show, then please add them directly after the code cells that have already been provided. 

Please provide verbose comments throughout your code so that it is easy for us to interpret what you are attempting to achieve with your code. Long comments are useful at the beginning of a block of code. Short comments, e.g. to explain the purpose of a new variable, or one of several steps in some analyses, are useful on every few lines of code, if not on every line. Please do not use the code cells for writing extensive sentences/paragraphs that should instead be in the Markdown/text cells.

# Abstract/Introduction (instructions) - 15 MARKS <ignore>
Use the next Markdown/text cell to write a short introduction to your report. This should include:
* a brief description of the topic (image classification) and of the dataset being used (CIFAR10 dataset). (2 MARKS)
* a brief description of how the CIFAR10 dataset has aided the development of neural network techniques, with examples. (3 MARKS)
* a descriptive overview of what the goal of your report is, including what you investigated. (5 MARKS)
* a summary of your major findings. (3 MARKS)
* two or more relevant references. (2 MARKS)

The labelled CIFAR-10 dataset has been used for benchmarking and testing in many exploratory and groundbreaking papers on computer vision and image classification, not least those introducing AlexNet [8], ResNet [4] and, most recently, vision transformer architectures [3]. It is fitting, then, to use it to explore some of the fundamental properties of artificial neural networks (NNs) in this assignment.

The first experiment presented here examines the effect of the learning rate (LR) on training and performance. As well as comparing different fixed learning rates, an LR 'scheduler' was designed and its effect on performance analysed through comparison with models trained with static learning rates.

The second experiment demonstrates the impact of introducing a dropout layer into the architecture of the network. Different dropout rates were trialled and their effects compared to baseline performance during both training and evaluation. The effect of dropout was also tested in a transfer learning context.

The third experiment analyses gradient flow during backpropagation in different architectures. Gradient flow was measured for all layers at the start and end of training for a baseline, a dropout and a batch-normalised model, and the results were compared. The overall performance of the model trained with batch normalisation was also compared with that of the others.

# Methodology (instructions) - 55 MARKS <ignore>
Use the next cells in this Methodology section to describe and demonstrate the details of what you did, in practice, for your research. Cite at least two academic papers that support your model choices. The overarching principle of writing the Methodology is to ***provide sufficient details for someone to replicate your model and to reproduce your results, without having to resort to your code***. You must include at least these components in the Methodology:
* Data - Describe the dataset, including how it is divided into training, validation, and test sets. Describe any pre-processing you perform on the data, and explain any advantages or disadvantages to your choice of pre-processing.
* Architecture - Describe the architecture of your model, including all relevant hyperparameters. The architecture must include 3 convolutional layers followed by two fully connected layers. Include a figure with labels to illustrate the architecture.
* Loss function - Describe the loss function(s) you are using, and explain any advantages or disadvantages there are with respect to the classification task.
* Optimiser - Describe the optimiser(s) you are using, including its hyperparameters, and explain any advantages or disadvantages there are to using that optimiser.
* Experiments - Describe how you conducted each experiment, including any changes made to the baseline model that has already been described in the other Methodology sections. Explain the methods used for training the model and for assessing its performance on validation/test data.


## Data (7 MARKS) <ignore>

The CIFAR-10 dataset consists of 60,000 low-resolution (32x32) colour images, split into 50,000 training examples and 10,000 test examples. Each image belongs to one of 10 mutually exclusive classes and is labelled accordingly.

It is conveniently accessible, along with many other benchmarking datasets, through the `torchvision.datasets` module, which makes it straightforward to load the training and test data into separate `torch.utils.data.Dataset` instances; this was the method used here.
 
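During loading, each of the 3 input channels was normalised with a mean of 0.5 and a standard deviation of 0.5, which maps pixel values from [0, 1] to [-1, 1] so that the inputs are roughly centred on zero (the dataset's own per-channel statistics were not used, so the result is only approximately standardised). Zero-centred inputs on a comparable scale tend to improve the conditioning of gradient descent and speed up training.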
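The loading code in Experiment 1 normalises every channel with a mean and standard deviation of 0.5. As a quick check of what those statistics do, the plain-Python sketch below applies the same per-channel `(x - mean) / std` operation that `transforms.Normalize` performs to a handful of toy pixel values; the helper `normalise` and the values are illustrative assumptions, not code from this notebook. It shows that fixed statistics of 0.5 map [0, 1] onto [-1, 1] without guaranteeing unit variance, whereas statistics estimated from the data do give zero mean and unit standard deviation.

```python
from statistics import fmean, pstdev

def normalise(values, mean, std):
    """Hypothetical helper: apply (x - mean) / std to every value, as Normalize does per channel."""
    return [(v - mean) / std for v in values]

# Toy pixel intensities for one channel after ToTensor(), which scales them to [0, 1]
pixels = [0.0, 0.25, 0.5, 0.75, 1.0]

# Fixed statistics of 0.5 map [0, 1] onto [-1, 1]: centred on 0, but not unit variance
fixed = normalise(pixels, 0.5, 0.5)
print(fixed)                       # [-1.0, -0.5, 0.0, 0.5, 1.0]
print(round(pstdev(fixed), 4))     # population std of the result

# Statistics estimated from the data give zero mean and unit (population) std
standardised = normalise(pixels, fmean(pixels), pstdev(pixels))
print(round(fmean(standardised), 6), round(pstdev(standardised), 6))
```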

The training instances were then split to create a validation set of 5,000 samples, leaving 45,000 for training (a fixed random seed was used so that the split was identical across experiments). The class distribution of each dataset was found to be well balanced (see Fig 1), so simple accuracy is a reliable measure of overall performance.

<figure><center><img src="./classdisttraining.png" width=200><img src="./classdistval.png" width=200><img src="./class dist test.png" width=200><figcaption style="max-width: 600px"> Figure 1. Class distributions across the training, validation, and testing datasets</figcaption></center></figure>
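The notebook performs this split with `torch.utils.data.random_split` after calling `torch.manual_seed(0)`. The stdlib sketch below reproduces the idea rather than the exact permutation PyTorch draws: a seeded shuffle of the 50,000 training indices, of which 5,000 become the validation set, followed by a class-count check. The function name `seeded_split` and the synthetic labels (CIFAR-10's training set contains 5,000 images per class) are assumptions for illustration.

```python
import random
from collections import Counter

def seeded_split(n_items, n_val, seed):
    """Hypothetical helper: randomly partition indices 0..n_items-1 into (train, validation) lists."""
    rng = random.Random(seed)          # local generator, so the split depends only on `seed`
    indices = list(range(n_items))
    rng.shuffle(indices)
    return indices[n_val:], indices[:n_val]

train_idx, val_idx = seeded_split(50_000, 5_000, seed=0)
print(len(train_idx), len(val_idx))    # 45000 5000

# Same seed, same split: this is what keeps the validation set fixed across experiments
print(seeded_split(50_000, 5_000, seed=0) == (train_idx, val_idx))

# Class-balance check on synthetic labels (10 classes, 5,000 of each)
labels = [i % 10 for i in range(50_000)]
val_counts = Counter(labels[i] for i in val_idx)
print(sorted(val_counts.items()))
```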

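Batching for mini-batch stochastic gradient descent was handled by PyTorch's `DataLoader` class with a batch size of 64. With shuffling enabled, the training data is reshuffled at the start of every epoch and samples are yielded without replacement, so each sample is used exactly once per epoch.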
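A minimal stdlib sketch of what a shuffling `DataLoader` does each epoch, assuming its default `drop_last=False`: shuffle the indices, then slice them into consecutive batches, so every sample appears exactly once. `iterate_minibatches` is a hypothetical name. With 45,000 training samples and a batch size of 64 this gives 704 batches per epoch, the last of which holds only 8 samples.

```python
import math
import random

def iterate_minibatches(n_items, batch_size, rng):
    """Hypothetical helper: yield index batches covering a freshly shuffled dataset once, without replacement."""
    order = list(range(n_items))
    rng.shuffle(order)
    for start in range(0, n_items, batch_size):
        yield order[start:start + batch_size]   # the final batch may be smaller

rng = random.Random(0)
batches = list(iterate_minibatches(45_000, 64, rng))
print(len(batches), math.ceil(45_000 / 64))     # 704 704
print(len(batches[-1]))                         # 8
```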

A single train and validation split was judged appropriate for this assignment. Cross-validation was rejected because its key benefit, a more accurate estimate of test performance, was not an important consideration here, and it would have multiplied the cost of every training run.

## Architecture (17 MARKS) <ignore>

<figure><center><img src="./baseline_model_diagram.png" width=800><figcaption style="max-width: 600px"> Fig 2. BaselineNet Convolutional Neural Network architecture. </figcaption></center></figure>

<figure><center><img src="./TABLE.PNG" width=600><figcaption style="max-width: 600px"> Table 1: Convolutional Neural Network Architecture</figcaption></center></figure>

The choices for the initial architecture were based on a combination of the assignment brief, initial experimentation, and common practices in the field.

Fig 2 shows the overall architecture of the model, whilst Table 1 details the convolutional layers.

Filter dimensions of 3x3 were chosen because they capture local spatial patterns effectively while keeping the number of parameters relatively low. VGGNet demonstrated the power of stacked 3x3 convolutional layers [14]; although it used a much deeper network, it was also classifying much higher-resolution images.

The increasing number of filters in the convolutional layers allows the network to learn progressively more complex and abstract features as the depth increases, and was another property shown to be effective in the VGG network [14]. 

Setting the stride and padding to 1 in the convolutional layers preserved the spatial resolution of the feature maps, and the zero padding allows each filter to be centred on edge pixels, so information at the image borders is not discarded.

The max pooling layers all have a pool size of 2x2 and a stride of 2. Each one halves the height and width of the feature maps, reducing computation and the number of parameters needed in the fully connected layer that follows. Pooling also provides a degree of translation invariance, because the exact position of a feature within the pooling window becomes less important.
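The standard output-size formula for a convolution or pooling layer, $\lfloor (n + 2p - k)/s \rfloor + 1$ for input size $n$, kernel size $k$, padding $p$ and stride $s$ (with dilation 1, as in the PyTorch defaults), makes it easy to trace a 32x32 image through the three blocks. The sketch below is illustrative and `conv_output_size` is a hypothetical helper; it confirms that each block halves the spatial size (32, 16, 8, 4), so 64 channels of 4x4 give the 1024 features that enter `fc1`.

```python
def conv_output_size(size, kernel, stride=1, padding=0):
    """Hypothetical helper: spatial size after a conv or pooling layer, floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32  # CIFAR-10 images are 32x32
for block in range(1, 4):
    size = conv_output_size(size, kernel=3, stride=1, padding=1)  # 3x3 conv, stride 1, padding 1: unchanged
    size = conv_output_size(size, kernel=2, stride=2)             # 2x2 max pool, stride 2: halved
    print(f"after block {block}: {size}x{size}")

print(64 * size * size)  # flattened features entering fc1
```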

The size of the first fully connected layer was chosen to balance the capacity of the model against the number of parameters that could realistically be trained over the many runs required by the experiments. As `fc1` takes as its input the <lt>$1024$</lt> activations from the flattened output of the final convolutional block, its weight matrix contains <lt>$1024 \times d$</lt> parameters, where $d$ is the output dimensionality of `fc1`. A final value of <lt>$64$</lt> was a good compromise.
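The parameter budget follows directly from the layer shapes of BaselineNet: a convolution with $k \times k$ kernels has $k^2 \cdot c_{in} \cdot c_{out}$ weights plus $c_{out}$ biases, and a linear layer has $n_{in} \cdot n_{out}$ weights plus $n_{out}$ biases. The sketch below tallies them (the helper names are hypothetical) and shows that `fc1` alone accounts for roughly three quarters of all trainable parameters, which is why its width was the main lever on model size.

```python
def conv_params(in_ch, out_ch, kernel=3):
    """Hypothetical helper: kernel*kernel*in_ch*out_ch weights plus one bias per output channel."""
    return kernel * kernel * in_ch * out_ch + out_ch

def linear_params(in_features, out_features):
    """Hypothetical helper: in_features*out_features weights plus one bias per output unit."""
    return in_features * out_features + out_features

layers = {
    "conv1": conv_params(3, 16),
    "conv2": conv_params(16, 32),
    "conv3": conv_params(32, 64),
    "fc1": linear_params(1024, 64),
    "fc2": linear_params(64, 10),
}
total = sum(layers.values())
for name, count in layers.items():
    print(f"{name}: {count} ({100 * count / total:.1f}%)")
print("total:", total)  # 89834
```

For the real model, `sum(p.numel() for p in BaselineNet().parameters())` should return the same total.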

ReLU was chosen as the non-linear activation throughout for the reasons it is commonly used: unlike sigmoid or tanh, it does not saturate for positive inputs, so its derivative there is exactly 1 rather than close to 0. This mitigates the vanishing gradients that arise when many small derivatives are multiplied together during backpropagation. It is also very cheap to compute.
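The saturation argument can be made concrete by comparing local derivatives. In the plain-Python sketch below, the sigmoid derivative never exceeds 0.25 and shrinks quickly away from zero, whereas the ReLU derivative is exactly 1 for any positive input. The final line shows how ten such factors compound during backpropagation, ignoring the weights for simplicity.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x):
    s = sigmoid(x)
    return s * (1.0 - s)          # at most 0.25, reached at x = 0

def relu_grad(x):
    return 1.0 if x > 0 else 0.0  # exactly 1 for any positive input

for x in (0.0, 2.0, 5.0):
    print(x, round(sigmoid_grad(x), 4), relu_grad(x))

# Chain rule through 10 layers, ignoring weights: best-case sigmoid factor vs active ReLU units
print(0.25 ** 10, 1.0 ** 10)
```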

## Loss function (3 MARKS) <ignore>

The loss function used for each experiment was cross-entropy loss, implemented using the `nn.CrossEntropyLoss` class from PyTorch [9].

It is widely used in classification problems such as this one, where the target variable is categorical (binary or multi-class).

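It works by first transforming the raw logits of the output layer into what is effectively a probability distribution via the softmax function. Where <lt>$C$</lt> is the number of classes, softmax outputs a $C$-dimensional vector of real numbers in the range (0, 1) that sum to 1. In PyTorch, `nn.CrossEntropyLoss` applies this step internally (as a log-softmax), which is why the final layer of the network outputs raw logits with no activation.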

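To calculate the loss, this distribution is compared with a one-hot encoding of the true class label, which acts as the target distribution; the cross-entropy quantifies the difference between the two. Because the one-hot target is zero everywhere except at the true class $y$, the loss for a single sample reduces to $-\log p_y$, the negative log of the probability assigned to that class.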

By minimizing the average cross-entropy loss over all training samples, the model learns to assign high probabilities to the correct class and low probabilities to the incorrect ones.
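The per-sample computation can be sketched in plain Python: a numerically stable softmax (subtracting the largest logit before exponentiating) followed by the negative log-probability of the true class. The function names and logits are illustrative. A useful sanity check falls out: with equal logits over 10 classes the loss is $\log 10 \approx 2.303$, roughly what an untrained CIFAR-10 model should report at the start of training.

```python
import math

def softmax(logits):
    m = max(logits)                           # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Cross-entropy against a one-hot target: -log of the probability of the true class."""
    return -math.log(softmax(logits)[target])

logits = [2.0, 1.0, 0.1]                      # raw scores for C = 3 classes
probs = softmax(logits)
print([round(p, 3) for p in probs], round(sum(probs), 6))
print(round(cross_entropy(logits, 0), 4))     # true class already most probable: small loss
print(round(cross_entropy(logits, 2), 4))     # true class given low probability: large loss

# Equal logits over 10 classes: every class gets probability 0.1
print(round(cross_entropy([0.0] * 10, 3), 3))
```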

## Optimiser (4 MARKS) <ignore>

Parameter updates were performed with stochastic gradient descent (SGD), implemented using the `optim.SGD` class from PyTorch with no momentum or weight decay (the PyTorch defaults), so that the LR was the only hyperparameter to set.

SGD estimates the true gradient of the loss with respect to the model parameters by calculating the gradient over a small subset of the training data (a mini-batch). The parameters are then moved in the direction of the negative of this approximate gradient, scaled by the LR, which in plain SGD is fixed and is a user-defined hyperparameter that can be tuned.

This is repeated for successive mini-batches drawn from the training data without replacement until the entire dataset has been seen, which constitutes one 'epoch' of training. Epochs are then repeated until a stopping criterion is met, in this case a set number of epochs.

Mathematically, the estimated gradient for a mini-batch of size $B$ sampled from the training data is computed as:
<lt>$$\nabla_\theta L(\theta_t) \approx \frac{1}{B} \sum_{i=1}^{B} \nabla_\theta L(\theta_t; x_i, y_i)$$</lt>
where <lt>$(x_i, y_i)$</lt> represents the <lt>$i$</lt>-th example in the mini-batch.
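Combining this estimate with the fixed-LR update $\theta_{t+1} = \theta_t - \alpha \, \hat{g}_t$, where $\hat{g}_t$ is the mini-batch gradient above and $\alpha$ is the LR, gives the whole algorithm. The stdlib sketch below applies it to a toy one-parameter least-squares problem (fitting $y = 3x$) rather than to CIFAR-10; the data, the function name `minibatch_sgd` and the hyperparameters are illustrative assumptions. Each epoch reshuffles the data and visits every example once, without replacement.

```python
import random

def minibatch_sgd(xs, ys, lr=0.1, batch_size=4, epochs=50, seed=0):
    """Fit y ~ w * x with plain mini-batch SGD on the squared-error loss (w - x*y relation is 1-D)."""
    rng = random.Random(seed)
    w = 0.0
    indices = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(indices)                      # new order each epoch; each example used once
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            # average of per-example gradients of (w * x - y)^2 with respect to w
            grad = sum(2 * (w * xs[i] - ys[i]) * xs[i] for i in batch) / len(batch)
            w -= lr * grad                        # theta_{t+1} = theta_t - alpha * g_hat
    return w

xs = [i / 10 for i in range(-10, 11)]
ys = [3.0 * x for x in xs]
w = minibatch_sgd(xs, ys)
print(round(w, 4))
```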

Although a number of more sophisticated optimisers are available, SGD was chosen to make analysing the impact of the LR on performance straightforward and transparent. With plain SGD, parameters are updated using only the gradient and the learning rate, so there is no momentum or per-parameter adaptive step size to interact with the LR. Keeping to this direct formulation makes it easier to interpret the impact of the LR on the model's performance, which is the focus of Experiment 1.

## Experiments <ignore>
### Experiment 1 (8 MARKS)

1.1
A range of learning rates (LRs) was explored to determine the upper and lower bounds for stable learning. The extremes were found to be 0.15 and 0.001. Based on these findings, five LRs were selected for comparison: 0.1, 0.075, 0.05, 0.025, and 0.01.

For each LR, five trials were conducted using fixed random seeds (1 to 5) for model initialisation, and each model was trained for 50 epochs using mini-batch stochastic gradient descent with a batch size of 64. The loss and accuracy were recorded for each batch and averaged over each epoch.

After each training epoch, the model was evaluated against the validation data, and at the end of training the model was tested against the test dataset. The test scores for each LR were obtained by averaging the scores of the five models instantiated for that LR.
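The per-epoch averaging across seeds can be sketched in a few lines of plain Python: the runs are regrouped epoch by epoch with `zip(*runs)`, as the Experiment 1 code does, and the sample standard deviation is also reported as a measure of spread across seeds (an extension not used in this report). `aggregate_runs` is a hypothetical helper and the accuracies below are made-up illustrative numbers, not results from the experiments.

```python
from statistics import fmean, stdev

def aggregate_runs(runs):
    """Hypothetical helper: per-epoch mean and sample std across equal-length per-epoch curves."""
    per_epoch = list(zip(*runs))                  # regroup values by epoch across the runs
    means = [fmean(values) for values in per_epoch]
    stds = [stdev(values) for values in per_epoch]
    return means, stds

# Illustrative validation accuracies for 3 seeds over 4 epochs (not real results)
runs = [
    [0.40, 0.52, 0.58, 0.60],
    [0.42, 0.50, 0.57, 0.61],
    [0.38, 0.51, 0.59, 0.62],
]
means, stds = aggregate_runs(runs)
print([round(m, 3) for m in means])
print([round(s, 3) for s in stds])
```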

All results were plotted and saved for analysis below.

1.2
In the second part of the experiment, several LR scheduling approaches were explored in an attempt to find a decay schedule consistent with the findings of 1.1: higher LRs made rapid early progress, while lower LRs overfitted less.

The most suitable function was found to be 'inverse time decay' with a decay rate of 0.25 and an initial LR of 0.1. This function modifies the LR over the epochs according to the equation <lt>$\alpha_t = \frac{\alpha_0}{1 + kt}$</lt>, where <lt>$\alpha_t$</lt> is the LR at epoch <lt>$t$</lt>, <lt>$\alpha_0$</lt> is the initial learning rate and <lt>$k$</lt> is the decay rate.
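Evaluating the schedule at a few epochs shows how it relates to the fixed LRs of 1.1: with $\alpha_0 = 0.1$ and $k = 0.25$, the LR halves to 0.05 after 4 epochs, reaches 0.025 at epoch 12 and 0.01 at epoch 36, so a single run passes through the range of fixed LRs compared earlier. The function below mirrors `adjust_initial_learning_rate` from the Experiment 1 code but simply returns the value instead of writing it into the optimiser.

```python
def inverse_time_decay(epoch, initial_lr=0.1, decay_rate=0.25):
    """alpha_t = alpha_0 / (1 + k * t), evaluated once per epoch."""
    return initial_lr / (1 + decay_rate * epoch)

for epoch in (0, 4, 12, 36, 49):
    print(epoch, round(inverse_time_decay(epoch), 5))
```

In PyTorch the same schedule could alternatively be expressed with `torch.optim.lr_scheduler.LambdaLR(optimiser, lr_lambda=lambda epoch: 1 / (1 + 0.25 * epoch))`, calling `scheduler.step()` once per epoch; `LambdaLR` multiplies the optimiser's initial LR by the returned factor.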

Models were then trained using this LR decay function, applied once per epoch, over the same five random seeds, and the results were gathered and plotted.

### Experiment 2 (8 MARKS) <ignore>

2.1
The original training data was re-split into two halves to create new training and validation datasets of 25,000 each. 

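A new model was defined, incorporating dropout in the fully connected layers. Dropout was applied only to the activations of the first fully connected layer (`fc1`): applying it to the output layer would randomly remove class logits, and standard element-wise dropout tends to be less effective on convolutional feature maps, where neighbouring activations are strongly correlated.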
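PyTorch's `nn.Dropout` uses 'inverted' dropout: during training each activation is zeroed with probability $p$ and the survivors are scaled by $1/(1-p)$, so the expected activation is unchanged, and after `model.eval()` the layer acts as the identity. The stdlib sketch below mimics this behaviour on a list of activations; `dropout` is a hypothetical helper and it assumes $0 \le p < 1$.

```python
import random

def dropout(activations, p, training, rng):
    """Hypothetical inverted dropout: zero each unit with probability p in training, scale survivors by 1/(1-p)."""
    if not training or p == 0.0:
        return list(activations)                 # evaluation mode: identity
    keep = 1.0 - p
    return [a / keep if rng.random() < keep else 0.0 for a in activations]

rng = random.Random(0)
acts = [1.0] * 100_000
out = dropout(acts, p=0.4, training=True, rng=rng)
print(sum(v == 0.0 for v in out) / len(out))     # fraction dropped, close to 0.4
print(sum(out) / len(out))                       # mean activation, close to 1.0
print(dropout([1.0, 2.0], p=0.4, training=False, rng=rng))
```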

The set of dropout rates for experimentation was defined as 0, 0.2, 0.4, 0.6, and 0.8. 

The same approach as in Experiment 1.1 was taken for training and validation, with five trials carried out for each dropout rate using consistent seeding for model initialisation, and all results were plotted and stored.

2.2
The performance of the best performing model from experiment 1 was compared with two others: 
i) a model pretrained on the original data without dropout, then retrained on the new data, and 
ii) a model pretrained on the original data with dropout, then retrained on the new data. In both cases, the retraining was partial and involved transfer learning, where pretrained models had some weights 'frozen' while others were reinitialised and made trainable on the new data.

Transfer learning was implemented by initialising two models, one with dropout and one without, and training them as in previous experiments, iterating over five random seeds and gathering performance data.

The final instance of each model was saved to store the trained model weights. 
The validation and training datasets were then swapped, the models were loaded, and all layers except the fully connected layers were frozen. 

The fully connected layers were manually re-initialised and subjected to training. These models were then trained on the new, swapped data as in previous experiments. By the end of this process, the models had effectively been exposed to two slightly differently distributed datasets in their different layers. Their performance during retraining on training, validation, and testing data was recorded and visualised for comparison.

### Experiment 3 (8 MARKS) <ignore>

3.1, 3.2, 3.3
These experiments investigate gradient flow in different networks.

A new model implementing batch normalisation was defined using PyTorch's built-in batch normalisation layers (`nn.BatchNorm2d` for convolutional feature maps and `nn.BatchNorm1d` for fully connected activations), applied to every layer except the last.
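For a single feature, training-mode batch normalisation subtracts the mini-batch mean, divides by the mini-batch standard deviation (using the biased variance plus a small $\epsilon$) and then applies a learnable scale $\gamma$ and shift $\beta$. The stdlib sketch below shows this for one feature; `batch_norm` is a hypothetical helper. In `nn.BatchNorm2d` the statistics are computed per channel over the batch and spatial dimensions, and running averages are also kept for use at evaluation time, which the sketch omits.

```python
import math
from statistics import fmean, pvariance

def batch_norm(batch, gamma=1.0, beta=0.0, eps=1e-5):
    """Hypothetical training-mode batch normalisation of one feature over a mini-batch."""
    mu = fmean(batch)
    var = pvariance(batch, mu)                   # biased (population) variance
    return [gamma * (x - mu) / math.sqrt(var + eps) + beta for x in batch]

feature = [2.0, 4.0, 6.0, 8.0]
normed = batch_norm(feature)
print([round(v, 3) for v in normed])
print(round(fmean(normed), 6), round(pvariance(normed), 4))
```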

The process for exploring gradient flow in the baseline, dropout and batch-normalised models was the same, and was as follows.

The gradients for each layer in each model were gathered and averaged over the first 5 and the last 5 epochs of training. After each backward pass, PyTorch stores the gradient of every parameter in its `.grad` attribute, so all that was required was to collect these values and calculate the averages for each layer over the correct epochs.
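Assuming, hypothetically, that the statistic recorded for each layer is the mean absolute gradient of its weights in each epoch (in PyTorch this is obtainable after `loss.backward()` as `param.grad.abs().mean().item()`), the first-versus-last comparison reduces to averaging slices of each layer's history. The helper names and the synthetic gradient values below are illustrative assumptions, not measurements from the experiments.

```python
from statistics import fmean

def mean_abs(values):
    """Hypothetical helper: mean absolute value of a layer's gradient entries."""
    return fmean(abs(v) for v in values)

def summarise_gradient_flow(grad_history, k=5):
    """Hypothetical helper: map each layer to (mean over its first k recorded epochs, mean over its last k)."""
    return {layer: (fmean(values[:k]), fmean(values[-k:]))
            for layer, values in grad_history.items()}

# Synthetic per-epoch statistics for two layers over 30 epochs (illustrative only)
history = {
    "conv1": [0.02 / (1 + 0.1 * epoch) for epoch in range(30)],
    "fc2": [0.10 / (1 + 0.1 * epoch) for epoch in range(30)],
}
for layer, (early, late) in summarise_gradient_flow(history).items():
    print(f"{layer}: first 5 epochs {early:.4f}, last 5 epochs {late:.4f}")

print(mean_abs([-0.5, 0.25, -0.25]))
```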

The original data split was reinstated and a fixed LR of 0.05 was used. Training was carried out over 30 epochs and otherwise proceeded as in the previous experiments, but instead of performance data, the gradient data described above was collected.

3.4
Finally, a batch-normalised model was trained on the original data for 50 epochs, five times as in previous experiments, with its performance on the original training, validation and test datasets recorded and plotted. Its performance was then compared with that of the other models and analysed.


In [None]:
############################################
### Code for building the baseline model ###
############################################


# relevant imports

import torch
import torch.nn as nn
import torch.nn.functional as F # as per convention

class BaselineNet(nn.Module):
    def __init__(self):
        super().__init__()
        # max pool layers - not strictly needed as separate instances, but this helps with reference to the diagram
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 3x3 convolutions with stride 1 and padding 1 preserve spatial size; channels double each block
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)

        # fully connected layers: 64 channels x 4 x 4 positions (32 -> 16 -> 8 -> 4 after three pools)
        self.fc1 = nn.Linear(in_features=64 * 4 * 4, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=10)  # one raw logit per class

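    def forward(self, x):
        # block 1: (B, 3, 32, 32) -> conv -> (B, 16, 32, 32) -> pool -> (B, 16, 16, 16)
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        # block 2: -> (B, 32, 16, 16) -> pool -> (B, 32, 8, 8)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        # block 3: -> (B, 64, 8, 8) -> pool -> (B, 64, 4, 4)
        x = F.relu(self.conv3(x))
        x = self.pool3(x)
        # flatten all dimensions except the batch dimension: (B, 1024)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # raw logits: nn.CrossEntropyLoss applies log-softmax internally
        x = self.fc2(x)
        return x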

# Results (instructions) - 55 MARKS <ignore>
Use the Results section to summarise your findings from the experiments. For each experiment, use the Markdown/text cell to describe and explain your results, and use the code cell (and additional code cells if necessary) to conduct the experiment and produce figures to show your results.

### Experiment 1 (17 MARKS) <ignore>

1.1 
<figure><center><img src="./lrchaos.png" width=700><img src="./tranval_no_learn.png" width=700><figcaption style="max-width: 600px"> Fig 1. Showing behavioural extremes for different learning rates: unstable learning at a LR of 0.2, and minimal learning at a LR of 0.001 </figcaption></center></figure>
Initial experimentation established reasonable limits for selecting learning rates (LRs). Rates of 0.15 and above led to erratic behaviour, while extremely low rates prevented learning (Fig. 1).
<figure><center><img src="./lr1.png" width=700><img src="./lr2.png" width=700><img src="./lr3.png" width=700><img src="./lr4.png" width=700><img src="./lr5.png" width=700><figcaption style="max-width: 600px"> Fig 2. Performance plots showing individual and averaged training and validation losses and accuracies for models trained with descending LRs across 50 epochs of training </figcaption></center></figure>
<figure><center><img src="./smoothed loss accuracy.png" width=700><figcaption style="max-width: 600px"> Fig 3. Smoothed averaged results for accuracies and losses across 50 epochs on validation data for models trained with different learning rates</figcaption></center></figure>
Fig. 2 and Fig. 3 show the performance of different LRs. As LRs get smaller, the generalisation gap between training and validation loss and accuracy develops more slowly and is less extreme. Higher LRs fit the training data more quickly but also overfit more quickly. The impact of the LR is also seen in the volatility of the training loss, which is markedly lower at lower learning rates.
<figure><center><img src="./leraning rates test performance.PNG" width=300><figcaption style="max-width: 600px"> Fig 4. Test set performance of models trained with different LRs highlighting the best result for each metric in green</figcaption></center></figure>
For unseen data, lower LRs led to reduced validation loss at the end of 50 epochs due to slower fitting and less overfitting, but also lower accuracy in test and validation (Figs. 3 & 4). This contrasted with the quicker rise to high accuracy for higher learning rates, followed by a plateauing and gradual decline.

1.2
<figure><center><img src="./lr_scheculer experiments.png" width=350><figcaption style="max-width: 600px"> Fig 5. Effect of different LR decay schedules on the active LR across 50 epochs </figcaption></center></figure>
Different LR decay approaches and their effects on the LR are shown in Fig. 5. The inverse time function with a decay rate of 0.25 gave a smooth reduction in LR and performed well relative to the others.
<figure><center><img src="./LR SCHEDULER final results.png" width=700><figcaption style="max-width: 600px"> Fig 6. Performance over 50 epochs of training for model trained with LR scheduler </figcaption></center></figure>
<figure><center><img src="./results accuracy camparison lr and scheduler.png" width=350><figcaption style="max-width: 600px"> Fig 7. Comparison of performance across training of model trained with a LR scheduler, and the best performing model without a scheduler (LR of 0.05)</figcaption></center></figure>
Fig. 7 compares a model using this scheduler with one using a static LR. The scheduler gives a slight improvement in overall performance, the most substantial difference being the stability of the validation loss and accuracy despite overfitting to the training data. The scheduled model's validation accuracy (Fig 6) stabilises in a way that did not occur in the other models, which had previously saturated at close to 100% training accuracy; this is likely due to the very small LRs in later epochs.
<figure><center><img src="./lr decay comparison.PNG" width=300><figcaption style="max-width: 600px"> Fig 8. Comparison of test results between a model trained with a LR scheduler, and the best performing model trained without a scheduler (LR of 0.05) highlighting the best result for each metric in green</figcaption></center></figure>
It is also notable that this model reached 100% accuracy on the training set, which none of the previous models did. This is likely because high static LRs are too coarse to home in on a particular point in parameter space, while low static LRs are unable to traverse the loss landscape effectively and may become trapped in suboptimal local minima.

In [None]:
#############################
### code for Experiment 1 ###
#############################

# Utility functions used here and in all other experiments are defined at the bottom of this cell,
# so that the experiment code comes first for readability.
# Note: Python runs a cell top to bottom, so on a fresh kernel those definitions must be executed
# before the experiment code below (e.g. run them first); otherwise the calls raise a NameError.


# imports 
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
from sklearn.metrics import classification_report
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
import math

# use GPU where available
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")


# EXPERIMENT 1.1 ------------- LRs -------------

# DATA LOADING AND SPLITTING

# set seed for data split
torch.manual_seed(0)

# transform applied each time a sample is loaded: convert the PIL image to a tensor in [0, 1],
# then normalise each channel with mean 0.5 and std 0.5, mapping values to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# get the data - 'train' boolean specifies whether to get training or test data
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform)

# set value for validation split (10% validation)
num_validation_samples = 5000
num_train_samples = len(train_data) - num_validation_samples

# split training data
train_data, val_data = random_split(train_data, [num_train_samples, num_validation_samples])

# confirm split number
print(len(train_data)) # 45000 training examples
print(len(val_data)) # 5000 validation examples
print(len(test_data)) # 10000 test examples

# batch size for mini-batch SGD, then one DataLoader per dataset (training batches reshuffled every epoch)
batch_size = 64
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)


# RUNNING TRAINING AND VALIDATION

num_epochs = 50
random_seeds = list(range(1, 6))

# learning rates to experiment with
learning_rates_for_experiment = [0.1, 0.075, 0.05, 0.025, 0.01]

# initialise dictionary for storing data for saving to JSON
averaged_results = {lr:{} for lr in learning_rates_for_experiment}
path_to_save = f'./run_data/learning_rates/FINAL.json'
path_to_load = f'./run_data/learning_rates/FINAL.json'
save_experiment = True

# iterate over LRs to be tested
for learning_rate in learning_rates_for_experiment:
    
    # initialise empty lists for collecting data for each LRs (over the 5 runs)
    epoch_train_losses_by_run = []
    epoch_val_losses_by_run = []
    epoch_train_accuracies_by_run = []
    epoch_val_accuracies_by_run = []
    test_losses = []
    test_accuracies = []
    reports = []
    
    # 5 random seeds = 5 different runs for each learning rate
    for random_seed in random_seeds:
        # set seed prior to initialising model (as used for initial weights as well as any dropout layers)
        torch.manual_seed(random_seed)
        # initialise model, criterion and optimiser
        model = BaselineNet().to(device)
        criterion = nn.CrossEntropyLoss()
        optimiser = optim.SGD(model.parameters(), lr=learning_rate)
        
        # key function call which actually trains the model and returns all of the run data (see definition below)
        model, train_epoch_losses, train_epoch_accuracy, val_epoch_losses, val_epoch_accuracy, _,_ = run_training_and_validation(model, device, learning_rate, num_epochs, criterion, optimiser, train_dataloader, val_dataloader, manual_lr_schedule=False, plot=True)
        
        # store data
        epoch_train_losses_by_run.append(train_epoch_losses)
        epoch_val_losses_by_run.append(val_epoch_losses)
        epoch_train_accuracies_by_run.append(train_epoch_accuracy)
        epoch_val_accuracies_by_run.append(val_epoch_accuracy)
        
        # run testing
        test_loss, test_accuracy, report = run_testing(model, device, criterion, test_dataloader)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        reports.append(report)
    
    # calculate average per-epoch stats across the 5 runs for this learning rate
    average_train_losses = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in zip(*epoch_train_losses_by_run)]
    average_val_losses = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in zip(*epoch_val_losses_by_run)]
    average_train_accuracies = [sum(epoch_accuracies) / len(epoch_accuracies) for epoch_accuracies in zip(*epoch_train_accuracies_by_run)]
    average_val_accuracies =  [sum(epoch_accuracies) / len(epoch_accuracies) for epoch_accuracies in zip(*epoch_val_accuracies_by_run)]
    
    average_test_loss = sum(test_losses)/len(test_losses)
    average_test_accuracy = sum(test_accuracies)/len(test_accuracies)
    
    # accumulate data into dictionary for saving to JSON
    averaged_results[learning_rate] = {'seeds':random_seeds,
                                       'av_train_losses': average_train_losses,
                                       'av_val_losses': average_val_losses,
                                       'av_train_acc': average_train_accuracies,
                                       'av_val_acc': average_val_accuracies,
                                       'all_train_losses':epoch_train_losses_by_run,
                                       'all_val_losses': epoch_val_losses_by_run,
                                       'all_train_accuracies': epoch_train_accuracies_by_run,
                                       'all_val_accuracies': epoch_val_accuracies_by_run,
                                       'all_test_losses':test_losses, 
                                       'all_test_accuracies':test_accuracies,
                                       'av_test_loss': average_test_loss,
                                       'av_test_accuracy':average_test_accuracy}
    plot_single_train_val_smoothed(average_train_losses,average_val_losses,average_train_accuracies,average_val_accuracies, num_epochs, smoothing_window=5, title=f'lr: {learning_rate}')

# save for future plotting/analysis
if save_experiment:
    with open(path_to_save, 'w') as file:
        json.dump(averaged_results, file, indent=4)  # 'indent' makes the output formatted and easier to read


# PLOTTING

plot_all_models_performance_from_disk(path_to_load, enforce_axis=True)
plot_performance_comparison_from_file(path_to_load, enforce_axis=True)
display_accuracy_heatmap(path_to_load)


# EXPERIMENT 1.2 ------------- LR SCHEDULER -------------

# INVESTIGATE LR DECAY

# exploring different LR decay approaches and plotting them to see how the LR will actually behave across 50 epochs
# helper function defining different basic learning rate decay functions
def adjust_learning_rate(epoch, initial_lr, decay_type, decay_rate=0.1, decay_interval=10):
    if decay_type == 'inverse_time':
        new_lr = initial_lr / (1 + decay_rate * epoch)
    elif decay_type == 'exponential':
        new_lr = initial_lr * math.exp(-decay_rate * epoch)
    elif decay_type == 'step':
        # integer division so the LR drops in discrete steps every `decay_interval` epochs
        num_decays = epoch // decay_interval
        new_lr = initial_lr * (decay_rate ** num_decays)
    else:
        raise ValueError(f"Unknown decay_type: {decay_type}")
    return new_lr

# plotting for visualising LR decay
def plot_learning_rate_decay(num_epochs, initial_lr, decay_functions):
    fig, axs = plt.subplots(len(decay_functions), figsize=(8, 4 * len(decay_functions)))
    if len(decay_functions) == 1:
        axs = [axs]
    
    for i, (decay_type, decay_rate, decay_interval) in enumerate(decay_functions):
        lr_values = [adjust_learning_rate(epoch, initial_lr, decay_type, decay_rate, decay_interval) for epoch in range(num_epochs)]
        
        if decay_type == 'step':
            title = f'Decay Function: {decay_type}, Decay Rate: {decay_rate}, Decay Interval: {decay_interval}'
        else:
            title = f'Decay Function: {decay_type}, Decay Rate: {decay_rate}'
        
        axs[i].plot(range(num_epochs), lr_values)
        axs[i].set_title(title)
        axs[i].set_xlabel('Epoch')
        axs[i].set_ylabel('Learning Rate')
    
    plt.tight_layout()
    plt.show()

num_epochs = 50
initial_lr = 0.1

decay_functions = [
    ('inverse_time', 0.1, 0),
    ('inverse_time', 0.05, 0),
    ('step', 0.5, 10),
    ('step', 0.1, 5),
    ('exponential', 0.25, 0),
    ('exponential', 0.1, 0)
]

# plot the different decays and how LR looks with different values 
plot_learning_rate_decay(num_epochs, initial_lr, decay_functions)


# RUN TRAINING AND VALIDATION WITH LRDECAY

# implementing the LR decay scheduler that best fit what I wanted to happen (inverse time decay)
# creating a function that will be passed in to the training function to be applied at the start of every epoch
def adjust_initial_learning_rate(optimiser, epoch, initial_lr=0.1, decay_rate=0.25):    
    new_lr = initial_lr / (1 + decay_rate *epoch)
    for param_group in optimiser.param_groups:
        param_group['lr'] = new_lr
    print('LR:',new_lr)
    return optimiser


num_epochs = 50

initial_learning_rate = 0.1
decay_rate = 0.25

random_seeds = list(range(1, 6))

averaged_results = {decay_rate:{}}
path_to_save = f'./run_data/lr_decay/final_decaying_lr_initial_lr_{initial_learning_rate}_decay_{decay_rate}.json'
path_to_load = f'./run_data/lr_decay/final_decaying_lr_initial_lr_{initial_learning_rate}_decay_{decay_rate}.json'

save_experiment = True

epoch_train_losses_by_run = []
epoch_val_losses_by_run = []
epoch_train_accuracies_by_run = []
epoch_val_accuracies_by_run = []
test_losses = []
test_accuracies = []
reports = []
    
for random_seed in random_seeds:
    print('DECAY: ', decay_rate)
    print('seed:', random_seed)
    torch.manual_seed(random_seed)

    model = BaselineNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimiser = optim.SGD(model.parameters(), lr=initial_learning_rate)

    # note manual_lr_schedule is set to True and the decay function is also passed
    model,train_epoch_losses, train_epoch_accuracy, val_epoch_losses, val_epoch_accuracy, train_report,val_report = run_training_and_validation(model, device, initial_learning_rate, num_epochs, criterion, optimiser, train_dataloader, val_dataloader, manual_lr_schedule=True, scheduler_func=adjust_initial_learning_rate, plot=True)
    epoch_train_losses_by_run.append(train_epoch_losses)
    epoch_val_losses_by_run.append(val_epoch_losses)
    epoch_train_accuracies_by_run.append(train_epoch_accuracy)
    epoch_val_accuracies_by_run.append(val_epoch_accuracy)
    
    test_loss, test_accuracy, report = run_testing(model, device, criterion, test_dataloader)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
    reports.append(report)

    
# average the per-epoch curves and the test metrics across all seeds (done once, after the seed loop has finished)
average_train_losses = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in zip(*epoch_train_losses_by_run)]
average_val_losses = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in zip(*epoch_val_losses_by_run)]
average_train_accuracies = [sum(epoch_accuracies) / len(epoch_accuracies) for epoch_accuracies in zip(*epoch_train_accuracies_by_run)]
average_val_accuracies = [sum(epoch_accuracies) / len(epoch_accuracies) for epoch_accuracies in zip(*epoch_val_accuracies_by_run)]
average_test_loss = sum(test_losses) / len(test_losses)
average_test_accuracy = sum(test_accuracies) / len(test_accuracies)

averaged_results[decay_rate] = {'seeds': random_seeds,
                                'av_train_losses': average_train_losses,
                                'av_val_losses': average_val_losses,
                                'av_train_acc': average_train_accuracies,
                                'av_val_acc': average_val_accuracies,
                                'all_train_losses': epoch_train_losses_by_run,
                                'all_val_losses': epoch_val_losses_by_run,
                                'all_train_accuracies': epoch_train_accuracies_by_run,
                                'all_val_accuracies': epoch_val_accuracies_by_run,
                                'all_test_losses': test_losses,
                                'all_test_accuracies': test_accuracies,
                                'av_test_loss': average_test_loss,
                                'av_test_accuracy': average_test_accuracy}

plot_single_train_val_smoothed(average_train_losses, average_val_losses, average_train_accuracies, average_val_accuracies, num_epochs, smoothing_window=3, title=f'LR: {initial_learning_rate}, DECAY: {decay_rate}')
    
if save_experiment:
    with open(path_to_save, 'w') as file:
        json.dump(averaged_results, file, indent=4)  # 'indent' makes the output formatted and easier to read

# PLOTTING
lr_decay_data = path_to_load
plot_all_models_performance_from_disk(lr_decay_data, enforce_axis=True)
plot_performance_comparison_from_file(lr_decay_data, enforce_axis=True)
display(display_accuracy_heatmap(lr_decay_data))  # display() renders the styled table even when this is not the last line of the cell


# ---------UTILITY FUNCTIONS USED ACROSS ALL EXPERIMENTS---------------------------------------------

# These functions comprised a utils.py file during development

# MODEL RELATED:

def run_training_and_validation(model, device, initial_lr, num_epochs, criterion, optimiser, train_dataloader, val_dataloader, metrics = False, manual_lr_schedule = False, scheduler_func=None, plot = False):

    # key function which performs training and validation of a model for params and data. 
    
    # returns all of the data gathered from the training and validation run organised by epoch. Optional params were added during development to accommodate different experiments (eg lr_scheduling)
    
    # optional metrics and plot parameters allow for plotting as well as generation of the classification report used for analysis of results
    
    # when plotting, includes a call to plot_single_train_val_smoothed() util function defined below
    # when training includes a call to the get_accuracy() function below

    train_epoch_losses = []
    train_epoch_accuracy = []
    val_epoch_losses = []
    val_epoch_accuracy = []
    
    for epoch in range(num_epochs):
        train_running_batch_losses = []
        train_running_batch_accuracy = []
        
        if epoch == num_epochs-1:
            train_all_preds = []
            train_all_labels = []
            val_all_preds = []
            val_all_labels = []
        
        if manual_lr_schedule:
            optimiser = scheduler_func(optimiser, epoch, initial_lr)

        model.train()
        for i, (images, labels) in enumerate(train_dataloader):
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            
            accuracy = get_accuracy(outputs, labels)
            
            loss.backward()
            optimiser.step()
            optimiser.zero_grad()

            train_running_batch_losses.append(loss.item())
            train_running_batch_accuracy.append(accuracy)
            # if i % 50 == 0:
            #   training_progress_bar.set_description(f'Training Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(train_dataloader)}], Loss: {loss.item():.4f}, Acc: {accuracy:.4f}')
            
            if epoch == num_epochs-1:
                _, preds = torch.max(outputs, 1)
                train_all_preds.extend(preds.cpu().numpy())  # Move predictions to CPU and convert to numpy for sklearn
                train_all_labels.extend(labels.cpu().numpy())  # Move labels to CPU and convert to numpy

        train_epoch_losses.append(sum(train_running_batch_losses)/len(train_running_batch_losses))
        train_epoch_accuracy.append(sum(train_running_batch_accuracy)/len(train_running_batch_accuracy))
        model.eval()
        with torch.no_grad():
            val_running_batch_losses = []
            val_running_batch_accuracy = []

            for i, (images, labels) in enumerate(val_dataloader):
                images = images.to(device)
                labels = labels.to(device)
                
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                accuracy = get_accuracy(outputs, labels)

                val_running_batch_losses.append(loss.item())
                val_running_batch_accuracy.append(accuracy)
                # if i % 20 == 0:
                #   val_progress_bar.set_description(f'Validation Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(val_dataloader)}], Loss: {loss.item():.4f}, Acc: {accuracy:.4f}')
                
                if epoch == num_epochs-1:
                    _, preds = torch.max(outputs, 1)
                    val_all_preds.extend(preds.cpu().numpy())  
                    val_all_labels.extend(labels.cpu().numpy())

            val_epoch_losses.append(sum(val_running_batch_losses)/len(val_running_batch_losses))
            val_epoch_accuracy.append(sum(val_running_batch_accuracy)/len(val_running_batch_accuracy))
            print(f'Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_epoch_losses[epoch]:.4f}, Acc: {train_epoch_accuracy[epoch]:.4f} | Val Loss: {val_epoch_losses[epoch]:.4f}, Acc: {val_epoch_accuracy[epoch]:.4f}')
            class_names = ['plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck']
            
    if plot:
        plot_single_train_val_smoothed(train_epoch_losses, val_epoch_losses, train_epoch_accuracy, val_epoch_accuracy, num_epochs, smoothing_window=10, title=f'single run lr={initial_lr}, decay={manual_lr_schedule}')
    
    if metrics:
        train_report = classification_report(train_all_labels, train_all_preds, target_names=(class_names))
        val_report = classification_report(val_all_labels, val_all_preds, target_names=(class_names))
        # print('FINAL EPOCH TRAINING SUMMARY:')
        # print(train_report)
        # print('FINAL EPOCH VALIDATION SUMMARY:')
        # print(val_report)
        
        return (model,train_epoch_losses, train_epoch_accuracy, val_epoch_losses, val_epoch_accuracy, train_report,val_report)
    else:
        return (model, train_epoch_losses, train_epoch_accuracy, val_epoch_losses, val_epoch_accuracy, 0,0)

def get_accuracy(logits, targets):
    
        # key function used in all training, validation and testing runs to calculate the accuracy of the predictions made by a model.
        
        # takes in logits (raw output scores from the model) and targets (actual class labels) and returns a float representing the accuracy of the predictions.

        # get the indices of the maximum value of all elements in the input tensor (which are the predicted class labels)
        _, predicted_labels = torch.max(logits, 1)
        
        # calculate the number of correctly predicted labels.
        correct_predictions = (predicted_labels == targets).sum().item()
        
        # calculate the accuracy.
        accuracy = correct_predictions / targets.size(0)
        
        return accuracy

def run_testing(model, device, criterion, test_dataloader):
    # this function was used to test trained models on the test dataset
    # it returns the loss, the accuracy and the classification report for analysis
    model.eval()
    with torch.no_grad():
        test_running_batch_losses = []
        test_running_batch_accuracy = []
        test_all_preds = []
        test_all_labels = []

        for i, (images, labels) in enumerate(test_dataloader):
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            accuracy = get_accuracy(outputs, labels)

            test_running_batch_losses.append(loss.item())
            test_running_batch_accuracy.append(accuracy)
            # test_progress_bar.set_description(f'Testing Step [{i}/{len(test_dataloader)}], Loss: {loss.item():.4f}, Acc: {accuracy:.4f}')
            _, preds = torch.max(outputs, 1)
            test_all_preds.extend(preds.cpu().numpy())  # Move predictions to CPU and convert to numpy for sklearn
            test_all_labels.extend(labels.cpu().numpy())  # Move labels to CPU and convert to numpy

    test_loss = sum(test_running_batch_losses)/len(test_running_batch_losses)
    test_accuracy = sum(test_running_batch_accuracy)/len(test_running_batch_accuracy)

    report = classification_report(test_all_labels, test_all_preds, target_names=(['plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck']))
    print(report)
    return test_loss, test_accuracy, report

# PLOTTING/VISUALISING RELATED :

def plot_single_train_val_smoothed(train_epoch_losses, val_epoch_losses, train_epoch_accuracy, val_epoch_accuracy, num_epochs, smoothing_window=5, title=None):
    
    # function used in many contexts to plot training and validation losses and accuracies of a single run
    # takes in the values returned from a single run of training and validation and plots them 
    # smoothing param allows for clearer picture of the progress during validation especially as it can be volatile 
    
    # convert lists to pandas Series
    train_epoch_losses_series = pd.Series(train_epoch_losses)
    val_epoch_losses_series = pd.Series(val_epoch_losses)
    train_epoch_accuracy_series = pd.Series(train_epoch_accuracy)
    val_epoch_accuracy_series = pd.Series(val_epoch_accuracy)

    # calculate moving averages using the provided smoothing window
    smooth_train_epoch_losses = train_epoch_losses_series.rolling(window=smoothing_window).mean()
    smooth_val_epoch_losses = val_epoch_losses_series.rolling(window=smoothing_window).mean()
    smooth_train_epoch_accuracy = train_epoch_accuracy_series.rolling(window=smoothing_window).mean()
    smooth_val_epoch_accuracy = val_epoch_accuracy_series.rolling(window=smoothing_window).mean()

    fig, ax = plt.subplots(1, 2, figsize=(14, 5))

    # Plot training and validation loss with moving averages
    ax[0].plot(train_epoch_losses, label='Training Loss', alpha=0.3)
    ax[0].plot(val_epoch_losses, label='Validation Loss', alpha=0.3)
    ax[0].plot(smooth_train_epoch_losses, label='Smoothed Training Loss', color='blue')
    ax[0].plot(smooth_val_epoch_losses, label='Smoothed Validation Loss', color='orange')
    ax[0].set_xlabel('Epochs')
    ax[0].set_ylabel('Loss')
    ax[0].set_title('Training and Validation Loss')
    ax[0].legend()

    # Set x-axis ticks on the loss plot every 10 epochs
    ax[0].set_xticks(range(0, num_epochs + 1, 10))

    # Plot training and validation accuracy with moving averages
    ax[1].plot(train_epoch_accuracy, label='Training Accuracy', alpha=0.3)
    ax[1].plot(val_epoch_accuracy, label='Validation Accuracy', alpha=0.3)
    ax[1].plot(smooth_train_epoch_accuracy, label='Smoothed Training Accuracy', color='blue')
    ax[1].plot(smooth_val_epoch_accuracy, label='Smoothed Validation Accuracy', color='orange')
    ax[1].set_xlabel('Epochs')
    ax[1].set_ylabel('Accuracy')
    ax[1].set_title('Training and Validation Accuracy')
    ax[1].legend()

    # Set x-axis ticks on the accuracy plot every 10 epochs
    ax[1].set_xticks(range(0, num_epochs + 1, 10))

    # Set y-axis for accuracy to range from 0 to 1 with ticks at intervals of 0.1
    ax[1].set_ylim(0, 1)
    ax[1].set_yticks([i * 0.1 for i in range(11)])
    if title:
        fig.suptitle(title, fontsize=16)

    plt.tight_layout()
    plt.show()

def display_accuracy_heatmap(path_to_load):
    # helper function for displaying best performing models in a convenient way
    with open(path_to_load, 'r') as file:
        results = json.load(file)
    
    rates = []
    av_test_losses = []
    av_test_accuracy = []
    for rate, value_dict in results.items():
        rates.append(rate)
        av_test_losses.append(value_dict['av_test_loss'])
        av_test_accuracy.append(value_dict['av_test_accuracy'])
    
    # Creating the DataFrame
    df = pd.DataFrame({
        'Average Test Loss': av_test_losses,
        'Average Test Accuracy': av_test_accuracy
    }, index=rates)
    
    # Applying conditional formatting to highlight the best value in each column
    def highlight_best(column):
        if column.name == 'Average Test Loss':
            is_best = column == column.min()
        else:
            is_best = column == column.max()
        return ['background: green' if v else '' for v in is_best]
    
    styled_df = df.style.apply(highlight_best, axis=0)
    
    return styled_df

def plot_single_model_performance(single_var_multi_run_data, title=None, enforce_axis=False):
    
    # function used for plotting the performance of single variable being investigated of n multiple runs 
    # for example during experiments 1.1 and 2.1
    # plots individual runs in background and a clearer average run in the foreground
    
    epochs = range(1, len(single_var_multi_run_data['av_train_losses']) + 1)
    n_runs = len(single_var_multi_run_data['all_train_losses'])
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    if title:
        title += f' across {n_runs} runs'
        fig.suptitle(title, fontsize=12)

    # Plot losses
    for train_loss, val_loss in zip(single_var_multi_run_data['all_train_losses'], single_var_multi_run_data['all_val_losses']):
        ax1.plot(epochs, train_loss, color='blue', alpha=0.3, linewidth=0.5, label='Individual Run Training Losses')
        ax1.plot(epochs, val_loss, color='orange', alpha=0.3, linewidth=0.5, label='Individual Run Validation Losses')
    ax1.plot(epochs, single_var_multi_run_data['av_train_losses'], color='blue', linewidth=1.2, label='Average Training Loss')
    ax1.plot(epochs, single_var_multi_run_data['av_val_losses'], color='orange', linewidth=1.2, label='Average Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Losses')
    
    # Remove duplicate labels in the legend
    handles, labels = ax1.get_legend_handles_labels()
    unique_labels = ["Average Training Loss", "Average Validation Loss", "Individual Run Training Losses", "Individual Run Validation Losses"]
    unique_handles = [handles[labels.index(label)] for label in unique_labels]
    ax1.legend(unique_handles, unique_labels)

    # Plot accuracies
    for train_acc, val_acc in zip(single_var_multi_run_data['all_train_accuracies'], single_var_multi_run_data['all_val_accuracies']):
        ax2.plot(epochs, train_acc, color='blue', alpha=0.3, linewidth=0.5, label='Individual Run Training Accuracies')
        ax2.plot(epochs, val_acc, color='orange', alpha=0.3, linewidth=0.5, label='Individual Run Validation Accuracies')
    ax2.plot(epochs, single_var_multi_run_data['av_train_acc'], color='blue', linewidth=1.2, label='Average Training Accuracy')
    ax2.plot(epochs, single_var_multi_run_data['av_val_acc'], color='orange', linewidth=1.2, label='Average Validation Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Accuracies')
    
    # Remove duplicate labels in the legend
    handles, labels = ax2.get_legend_handles_labels()
    unique_labels = ["Average Training Accuracy", "Average Validation Accuracy", "Individual Run Training Accuracies", "Individual Run Validation Accuracies"]
    unique_handles = [handles[labels.index(label)] for label in unique_labels]
    ax2.legend(unique_handles, unique_labels)
    
    if enforce_axis:
        ax1.set_ylim(0, 5)
        ax2.set_ylim(0, 1)

    plt.tight_layout()
    plt.show()    
    

def plot_all_models_performance_from_disk(path_to_load, variable_name=None, enforce_axis=False):
    # used for loading in saved data and plotting with customisable titles (used to generate plots for the assignment)
    with open(path_to_load, 'r') as file:
        averaged_results = json.load(file)
        
    for variable_val, data in averaged_results.items():
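        # plot_single_model_performance() appends ' across {n_runs} runs' itself, so the title does not end with 'across'
        plot_single_model_performance(data, title=f'Training/Validation Losses and Accuracy for {variable_name} = {variable_val}', enforce_axis=enforce_axis)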

def plot_performance_comparison_from_file(path_to_load, enforce_axis=False, smooth_window=5):
    with open(path_to_load, 'r') as file:
        results = json.load(file)
    learning_rates = list(results.keys())
    num_epochs = len(results[learning_rates[0]]['av_train_losses'])

    fig_size = (12, 16)
    fig, ((ax_train_loss, ax_train_acc), (ax_val_loss, ax_val_acc),
          (ax_train_loss_smoothed, ax_train_acc_smoothed),
          (ax_val_loss_smoothed, ax_val_acc_smoothed)) = plt.subplots(4, 2, figsize=fig_size)

    plot_metrics(ax_train_loss, results, learning_rates, num_epochs, 'av_train_losses', 'Average Training Loss')
    plot_metrics(ax_train_acc, results, learning_rates, num_epochs, 'av_train_acc', 'Average Training Accuracy')
    plot_metrics(ax_val_loss, results, learning_rates, num_epochs, 'av_val_losses', 'Average Validation Loss')
    plot_metrics(ax_val_acc, results, learning_rates, num_epochs, 'av_val_acc', 'Average Validation Accuracy')
    plot_metrics(ax_train_loss_smoothed, results, learning_rates, num_epochs, 'av_train_losses', 'Smoothed Training Loss', smoothed=True, smooth_window=smooth_window)
    plot_metrics(ax_train_acc_smoothed, results, learning_rates, num_epochs, 'av_train_acc', 'Smoothed Training Accuracy', smoothed=True, smooth_window=smooth_window)
    plot_metrics(ax_val_loss_smoothed, results, learning_rates, num_epochs, 'av_val_losses', 'Smoothed Validation Loss', smoothed=True, smooth_window=smooth_window)
    plot_metrics(ax_val_acc_smoothed, results, learning_rates, num_epochs, 'av_val_acc', 'Smoothed Validation Accuracy', smoothed=True, smooth_window=smooth_window)

    if enforce_axis:
        for ax in [ax_val_acc, ax_val_loss, ax_train_acc, ax_train_loss,
                   ax_val_acc_smoothed, ax_val_loss_smoothed, ax_train_acc_smoothed, ax_train_loss_smoothed]:
            # loss panels share a 0-5 range, accuracy panels a 0-1 range
            if 'Loss' in ax.get_ylabel():
                ax.set_ylim(0, 5)
            else:
                ax.set_ylim(0, 1)

    plt.tight_layout()
    plt.show()

    if len(learning_rates) > 2:
        plot_comparative_metrics(results, learning_rates, num_epochs, 'Comparative Accuracies', 'av_train_acc', 'av_val_acc', enforce_axis)
        plot_comparative_metrics(results, learning_rates, num_epochs, 'Comparative Accuracies (Smoothed)', 'av_train_acc', 'av_val_acc', enforce_axis, smoothed=True, smooth_window=smooth_window)
    elif len(learning_rates) == 2:
        fig_acc_two, ax_acc_two = plt.subplots(figsize=(6, 4))
        fig_acc_two.suptitle('Comparative Accuracies', fontsize=12)

        for lr in learning_rates:
            ax_acc_two.plot(range(1, num_epochs + 1), results[lr]['av_val_acc'], label=f"Validation ({lr})", linestyle='-')
            ax_acc_two.plot(range(1, num_epochs + 1), results[lr]['av_train_acc'], label=f"Training ({lr})", linestyle='--')

        ax_acc_two.set_xlabel('Epoch')
        ax_acc_two.set_ylabel('Accuracy')
        ax_acc_two.set_title('Accuracy Comparison')
        ax_acc_two.legend(loc='upper right')

        if enforce_axis:
            ax_acc_two.set_ylim(0, 1)

        plt.tight_layout()
        plt.show()

        plot_comparative_metrics(results, learning_rates, num_epochs, 'Comparative Accuracies (Smoothed)', 'av_train_acc', 'av_val_acc', enforce_axis, smoothed=True, smooth_window=smooth_window)

def plot_metrics(ax, results, learning_rates, num_epochs, metric_key, title, smoothed=False, smooth_window=5):
    for lr in learning_rates:
        if smoothed:
            metric = np.convolve(results[lr][metric_key], np.ones(smooth_window) / smooth_window, mode='valid')
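            # 'valid' mode returns num_epochs - smooth_window + 1 points; place each one at the centre of its window (1-indexed epochs)
            x = np.arange(len(metric)) + (smooth_window + 1) / 2
            ax.plot(x, metric, label=str(lr))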
        else:
            ax.plot(range(1, num_epochs + 1), results[lr][metric_key], label=str(lr))
    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.set_title(title)
    ax.legend(title='Learning Rates', loc='lower right')

def plot_comparative_metrics(results, learning_rates, num_epochs, fig_title, train_key, val_key, enforce_axis=False, smoothed=False, smooth_window=5):
    fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(12, 4))
    fig.suptitle(fig_title, fontsize=12)

    plot_metrics(ax_train, results, learning_rates, num_epochs, train_key, f'Training {fig_title}', smoothed, smooth_window)
    plot_metrics(ax_val, results, learning_rates, num_epochs, val_key, f'Validation {fig_title}', smoothed, smooth_window)

    if enforce_axis:
        ax_train.set_ylim(0, 1)
        ax_val.set_ylim(0, 1)

    plt.tight_layout()
    plt.show()


### Experiment 2 (19 MARKS) <ignore>

2.1
<figure><center><img src="./dr0.png" width=800><img src="./dr02.png" width=800><img src="./dr04.png" width=800><img src="./dr046.png" width=800><img src="./dr08.png" width=800><figcaption style="max-width: 600px"> Fig 9. Performance plots showing individual and averaged training and validation losses and accuracies for models trained with increasing dropout rates across 50 epochs of training </figcaption></center></figure>
Fig. 9 shows the effect of increasing the dropout rate from 0 to 0.8. As the dropout rate increases, the model fits the training data more slowly and less closely, as reflected both in the final training accuracy and in how quickly it is reached. There is also a marked decrease in the generalisation gap between training and validation performance, with the highest dropout rate (0.8) producing the smallest gap.
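One way to see why this happens is to look at what a dropout layer does. The sketch below is a minimal pure-Python illustration of inverted dropout, the behaviour PyTorch documents for `nn.Dropout`: during training each activation is zeroed with probability `p` and the survivors are scaled by `1 / (1 - p)`. In evaluation mode the layer is the identity. The function `inverted_dropout` is a hypothetical name and works on a plain list rather than a tensor. It is only meant to show that a larger `p` injects more noise into the `fc1` activations.

```python
import random


def inverted_dropout(activations, p, training=True, rng=None):
    """Hypothetical pure-Python sketch of inverted dropout on a list of activations.

    Training: each activation is zeroed with probability p, and the kept ones are
    scaled by 1 / (1 - p) so the expected value of every unit is unchanged.
    Evaluation: the input is returned unchanged (as a new list).
    """
    if not 0 <= p < 1:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0:
        return list(activations)
    rng = rng or random.Random()
    scale = 1.0 / (1.0 - p)
    # keep each unit with probability (1 - p) and rescale the kept ones
    return [a * scale if rng.random() >= p else 0.0 for a in activations]


# Example: the same input passed through increasing dropout rates (fixed seed for repeatability)
x = [1.0] * 10
for p in [0.0, 0.2, 0.4, 0.6, 0.8]:
    out = inverted_dropout(x, p, training=True, rng=random.Random(0))
    print(p, sum(v == 0.0 for v in out), "of", len(x), "units dropped")
```

With the same seed, the set of dropped units can only grow as `p` rises. The kept activations are scaled up, so the layer's expected output stays the same while its variance increases.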
<figure><center><img src="./overall dropout comparisons.png" width=800><figcaption style="max-width: 600px"> Fig 10. Smoothed averaged results for accuracies and losses across 50 epochs on validation data for models trained with different dropout rates </figcaption></center></figure>
<figure><center><img src="./dropout rates test results.PNG" width=300><figcaption style="max-width: 600px"> Fig 11. Test set performance of models trained with different dropout rates highlighting the best result for each metric in green</figcaption></center></figure>
Fig. 10 shows the comparative performance of models trained with different dropout rates. The validation losses indicate an earlier and more severe onset of the generalisation gap for lower dropout rates. Despite this, accuracy is relatively well preserved, and models with lower dropout rates still reach reasonable performance on both the validation and test datasets. However, the best test loss and the best test accuracy both belong to models with higher dropout rates, albeit by a small margin (Fig. 11).
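One way to put a number on the "onset" of the generalisation gap in Fig. 10 is the epoch at which the averaged validation loss reaches its minimum. After that epoch, validation loss rises while training loss keeps falling. The helper below is a hypothetical sketch. It assumes the results dictionary saved by the experiment code (per-setting keys such as `'av_val_losses'`). Setting names are string keys, because `json.dump` writes the float dropout rates as strings.

```python
def overfitting_onset(results):
    """Hypothetical helper: for each setting, return the 1-indexed epoch with the
    lowest averaged validation loss.

    Assumes `results` has the saved structure {setting: {'av_val_losses': [...], ...}}.
    Ties are resolved in favour of the earliest epoch.
    """
    onset = {}
    for setting, data in results.items():
        val_losses = data['av_val_losses']
        best_idx = min(range(len(val_losses)), key=val_losses.__getitem__)
        onset[setting] = best_idx + 1
    return onset


# Toy example with made-up curves (not experimental results)
toy_results = {'0': {'av_val_losses': [2.0, 1.5, 1.2, 1.4, 1.9]},
               '0.8': {'av_val_losses': [2.2, 1.9, 1.6, 1.4, 1.3]}}
print(overfitting_onset(toy_results))

# With a saved file (requires `import json`):
# with open(path_to_load) as f:
#     print(overfitting_onset(json.load(f)))
```

An earlier onset epoch for lower dropout rates would correspond to the earlier divergence that Fig. 10 shows visually.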

2.2
<figure><center><img src="./pretrained_0.png" width=800><figcaption style="max-width: 600px"> Fig 12. Performance plots showing individual and averaged training and validation losses and accuracies for the Baseline (non-dropout) model (model 0) during training on the original data over 50 epochs. </figcaption></center></figure>
<figure><center><img src="./pretrained_1.png" width=800><figcaption style="max-width: 600px"> Fig 13. Performance plots showing individual and averaged training and validation losses and accuracies for the model with dropout implemented (model 1) during training on the original data over 50 epochs. </figcaption></center></figure>
The findings above are borne out by the freshly trained models shown in Figs. 12 & 13. The model trained without dropout (model 0) shows poor generalisability and marked overfitting, while the model trained with dropout (model 1) fits the training data less closely.
<figure><center><img src="./retrained_baseline.png" width=800><figcaption style="max-width: 600px"> Fig 16. Performance plots showing individual and averaged training and validation losses and accuracies for the Baseline (non-dropout) model (model 0) during retraining on swapped data over 50 epochs</figcaption></center></figure>
<figure><center><img src="./retrained_dropout.png" width=800><figcaption style="max-width: 600px">Fig 17. Performance plots showing individual and averaged training and validation losses and accuracies for the model with dropout (model 1) during retraining on swapped data over 50 epochs</figcaption></center></figure>
Both models then had their fully connected layers retrained on a swapped version of the original dataset, with all other parameters frozen. Figs. 16 and 17 show their performance during this second phase of training.
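The freezing code for this phase is not shown in this section. A minimal PyTorch sketch of the idea follows. It assumes the fully connected layers are named `fc1` and `fc2`, as in `DropoutNet`. Every other parameter gets `requires_grad = False`, and only the remaining parameters are passed to the optimiser. `freeze_all_but_fc` and `TinyNet` are hypothetical names. `TinyNet` is a deliberately small stand-in so that the example runs quickly.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def freeze_all_but_fc(model, fc_prefixes=("fc1", "fc2")):
    """Hypothetical helper: freeze every parameter whose name does not start with one
    of `fc_prefixes`, and return the list of parameters that remain trainable."""
    trainable = []
    for name, param in model.named_parameters():
        is_fc = name.startswith(fc_prefixes)  # e.g. 'fc1.weight', 'fc2.bias'
        param.requires_grad = is_fc
        if is_fc:
            trainable.append(param)
    return trainable


class TinyNet(nn.Module):
    """Small stand-in with the same layer names as DropoutNet (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(4 * 8 * 8, 16)
        self.fc2 = nn.Linear(16, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))  # (N, 4, 8, 8) for 8x8 inputs
        x = torch.flatten(x, 1)
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyNet()
trainable_params = freeze_all_but_fc(model)
# only the unfrozen (fully connected) parameters are handed to the optimiser
optimiser = optim.SGD(trainable_params, lr=0.05)
```

Frozen parameters receive no gradients and are not in the optimiser, so `optimiser.step()` leaves the convolutional weights untouched. Only the classifier head adapts to the new data.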

The dropout-free model fits and overfits to the new data extremely quickly, with only a brief period of learning generalisable information. The model pre-trained and retrained with dropout still overfits but the process is smoother and more gradual, with a less severe transition from fitting with generalisation to overfitting.
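To compare how abruptly each model moves from generalising to overfitting, one option is to track the per-epoch gap between validation and training loss. We can then record the first epoch at which that gap crosses a threshold. This is a hypothetical sketch. The function names and the threshold are assumptions, and the inputs would be the averaged `'av_train_losses'` and `'av_val_losses'` lists saved for each model.

```python
def generalisation_gap(train_losses, val_losses):
    """Per-epoch gap between validation loss and training loss."""
    return [v - t for t, v in zip(train_losses, val_losses)]


def first_epoch_gap_exceeds(train_losses, val_losses, threshold):
    """Return the first 1-indexed epoch at which the gap exceeds `threshold`,
    or None if it never does. The threshold is an arbitrary choice."""
    for epoch, gap in enumerate(generalisation_gap(train_losses, val_losses), start=1):
        if gap > threshold:
            return epoch
    return None


# Toy curves (made up, not experimental results)
train = [2.0, 1.5, 1.0, 0.5, 0.2]
val = [2.1, 1.7, 1.4, 1.3, 1.6]
print(generalisation_gap(train, val))
print(first_epoch_gap_exceeds(train, val, threshold=0.5))
```

An earlier crossing, and faster growth of the gap afterwards, would be consistent with the sharper transition described for the dropout-free model.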
<figure><center><img src="./retrained_comparison.png" width=400><figcaption style="max-width: 600px">Fig 18. Direct comparison of the averaged and smoothed performance of the non-dropout (model 0) and dropout (model 1) models over 50 epochs of training and validation on the swapped data </figcaption></center></figure>

<figure><center><img src="./retrained comparison test results.PNG" width=300><figcaption style="max-width: 600px"> Fig 19. Test set performance of models retrained on the swapped dataset after previous training on the original dataset without (0) and with (1) dropout implemented, highlighting the best result for each metric in green.</figcaption></center></figure>

In terms of overall test set performance (Fig. 19), the model with dropout performs better on both metrics and outperforms all other models except the one using the LR scheduler.
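Each reported test figure is an average over five seeds. The margins in Figs. 11 and 19 are therefore easier to judge next to the seed-to-seed spread. Below is a minimal sketch that assumes the per-seed `'all_test_accuracies'` and `'all_test_losses'` lists saved by the experiment code. It uses the standard library's sample standard deviation. `summarise_test_metrics` is a hypothetical name.

```python
import statistics


def summarise_test_metrics(results):
    """Hypothetical helper: return {setting: (mean_acc, std_acc, mean_loss, std_loss)}.

    Assumes each setting stores per-seed lists under 'all_test_accuracies' and
    'all_test_losses', each with at least two entries (needed for stdev).
    """
    summary = {}
    for setting, data in results.items():
        accs = data['all_test_accuracies']
        losses = data['all_test_losses']
        summary[setting] = (statistics.mean(accs), statistics.stdev(accs),
                            statistics.mean(losses), statistics.stdev(losses))
    return summary


# Toy example (made-up numbers, not experimental results)
toy = {'0': {'all_test_accuracies': [0.60, 0.62, 0.64], 'all_test_losses': [1.2, 1.1, 1.0]}}
for setting, (m_acc, s_acc, m_loss, s_loss) in summarise_test_metrics(toy).items():
    print(f'{setting}: acc {m_acc:.3f} +/- {s_acc:.3f}, loss {m_loss:.3f} +/- {s_loss:.3f}')
```

If the difference between two settings is smaller than these standard deviations, the ranking in the highlighted tables should be read as tentative.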

In summary, the regularisation effect of a single dropout layer was able to improve performance almost to the same level as basic LR scheduling.

#############################
### Code for Experiment 2 ###
#############################

# --- EXPERIMENT 2.1 - Dropout Rates ---

# DATA LOADING AND NEW SPLIT
torch.manual_seed(0)

batch_size = 64

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform)

# half and half split
num_validation_samples = 25000
num_train_samples = len(train_data) - num_validation_samples
train_data, val_data = random_split(train_data, [num_train_samples, num_validation_samples])

print(len(train_data)) # 25000 training egs
print(len(val_data)) # 25000 validation egs
print(len(test_data)) # 10000 test egs

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# DROPOUT MODEL DEFINITION

class DropoutNet(nn.Module):
    def __init__(self, dropout_rate):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(in_features=64 * 4 * 4, out_features=64)
        self.dropout = nn.Dropout(p=dropout_rate)  # Dropout layer after the first FC layer only
        self.fc2 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)  # Applying dropout after activation
        x = self.fc2(x)
        return x


# TRAINING WITH DIFFERENT DROPOUT RATES
# similar to 1.1

num_epochs = 50
learning_rate = 0.05

random_seeds = list(range(1, 6))
dropout_rates_for_experiment = [0, 0.2, 0.4, 0.6, 0.8]

averaged_results = {dr:{} for dr in dropout_rates_for_experiment}

path_to_save = f'./run_data/dropout/C2_final_dropout_rate_compatison_lr_{learning_rate}_{num_epochs}_epochs.json'
path_to_load = f'./run_data/dropout/C2_final_dropout_rate_compatison_lr_{learning_rate}_{num_epochs}_epochs.json'
save_experiment = True


for dropout_rate in dropout_rates_for_experiment:
    print('DR: ', dropout_rate) 
    epoch_train_losses_by_run = []
    epoch_val_losses_by_run = []
    epoch_train_accuracies_by_run = []
    epoch_val_accuracies_by_run = []
    test_losses = []
    test_accuracies = []
    reports = []
    
    for random_seed in random_seeds:
        print('DR: ', dropout_rate) 
        print('seed:', random_seed)
        torch.manual_seed(random_seed)
        
        model = DropoutNet(dropout_rate).to(device)
        criterion = nn.CrossEntropyLoss()
        optimiser = optim.SGD(model.parameters(), lr=learning_rate)

        model, train_epoch_losses, train_epoch_accuracy, val_epoch_losses, val_epoch_accuracy, _,_ = run_training_and_validation(model, device, learning_rate, num_epochs, criterion, optimiser, train_dataloader, val_dataloader, metrics = False, manual_lr_schedule=False, plot=True)
        epoch_train_losses_by_run.append(train_epoch_losses)
        epoch_val_losses_by_run.append(val_epoch_losses)
        epoch_train_accuracies_by_run.append(train_epoch_accuracy)
        epoch_val_accuracies_by_run.append(val_epoch_accuracy)
        
        test_loss, test_accuracy, report = run_testing(model, device, criterion, test_dataloader)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        reports.append(report)
        
    average_train_losses = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in zip(*epoch_train_losses_by_run)]
    average_val_losses = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in zip(*epoch_val_losses_by_run)]
    average_train_accuracies = [sum(epoch_accuracies) / len(epoch_accuracies) for epoch_accuracies in zip(*epoch_train_accuracies_by_run)]
    average_val_accuracies =  [sum(epoch_accuracies) / len(epoch_accuracies) for epoch_accuracies in zip(*epoch_val_accuracies_by_run)]
    average_test_loss = sum(test_losses)/len(test_losses)
    average_test_accuracy = sum(test_accuracies)/len(test_accuracies)
    
    averaged_results[dropout_rate] = {'seeds':random_seeds,'av_train_losses': average_train_losses,
                                       'av_val_losses': average_val_losses,
                                       'av_train_acc': average_train_accuracies,
                                       'av_val_acc': average_val_accuracies,
                                       'all_train_losses':epoch_train_losses_by_run,
                                       'all_val_losses': epoch_val_losses_by_run,
                                       'all_train_accuracies': epoch_train_accuracies_by_run,
                                       'all_val_accuracies': epoch_val_accuracies_by_run,
                                       'all_test_losses':test_losses, 
                                       'all_test_accuracies':test_accuracies,
                                       'av_test_loss': average_test_loss,
                                       'av_test_accuracy':average_test_accuracy}
    print('average for ')
    print('DR: ', dropout_rate) 
    plot_single_train_val_smoothed(average_train_losses,average_val_losses,average_train_accuracies,average_val_accuracies, num_epochs, smoothing_window=3, title=f'DROPOUT: {dropout_rate}')

if save_experiment:
    with open(path_to_save, 'w') as file:
        json.dump(averaged_results, file, indent=4)  # 'indent' makes the output formatted and easier to read
        
# PLOTTING
dropout_data = path_to_load
plot_all_models_performance_from_disk(dropout_data, enforce_axis=True)
plot_performance_comparison_from_file(dropout_data, enforce_axis=True)
display_accuracy_heatmap(dropout_data)

# --- EXPERIMENT 2.1 - TRANSFER LEARNING ---

# SWAP DATASETS WITH NEW DATALOADERS

torch.manual_seed(0)

batch_sise = 64

original_train_dataloader = DataLoader(train_data, batch_size=batch_sise, shuffle=True)
original_val_dataloader = DataLoader(val_data, batch_size=batch_sise, shuffle=True)

swapped_train_dataloader = DataLoader(val_data, batch_size=batch_sise, shuffle=True)
swapped_val_dataloader = DataLoader(train_data, batch_size=batch_sise, shuffle=True)

test_dataloader = DataLoader(test_data, batch_size=batch_sise, shuffle=False)

# TRAINING ON ORIGINAL DATA

# train and save models ready for transfer learning
# train two models, one with dropout and one without, on the ORIGINAL half-and-half data split, then save a copy of each model to disk
best_dropout_rate = 0.6

num_epochs = 50
learning_rate = 0.05

random_seeds = list(range(1, 6))  # flat list of seeds; a nested list would pass [1, ..., 5] to torch.manual_seed


path_to_save = f'./run_data/transfer_learning/transfer_learn_original_dat_{num_epochs}_epochs_lr_{learning_rate}.json'
path_to_load = f'./run_data/transfer_learning/transfer_learn_original_dat_{num_epochs}_epochs_lr_{learning_rate}.json'

models = [0, 1]
averaged_results = {i:{} for i in models}

save_experiment = True

# train them both on the original data
for i, model in enumerate(models):
    epoch_train_losses_by_run = []
    epoch_val_losses_by_run = []
    epoch_train_accuracies_by_run = []
    epoch_val_accuracies_by_run = []
    test_losses = []
    test_accuracies = []
    reports = []
    
    for random_seed in random_seeds:
        print('MODEL: ', i) 
        print('seed:', random_seed)
        torch.manual_seed(random_seed)
        
        model = BaselineNet() if i == 0 else DropoutNet(dropout_rate=best_dropout_rate)
        model.to(device)
        
        criterion = nn.CrossEntropyLoss()
        optimiser = optim.SGD(model.parameters(), lr=learning_rate)
        
        model, train_epoch_losses, train_epoch_accuracy, val_epoch_losses, val_epoch_accuracy, _,_ = run_training_and_validation(model, device, learning_rate, num_epochs, criterion, optimiser, original_train_dataloader, original_val_dataloader, metrics = False, manual_lr_schedule=False, plot=True)
        epoch_train_losses_by_run.append(train_epoch_losses)
        epoch_val_losses_by_run.append(val_epoch_losses)
        epoch_train_accuracies_by_run.append(train_epoch_accuracy)
        epoch_val_accuracies_by_run.append(val_epoch_accuracy)
        
        test_loss, test_accuracy, report = run_testing(model, device, criterion, test_dataloader)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        reports.append(report)
        
    average_train_losses = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in zip(*epoch_train_losses_by_run)]
    average_val_losses = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in zip(*epoch_val_losses_by_run)]
    average_train_accuracies = [sum(epoch_accuracies) / len(epoch_accuracies) for epoch_accuracies in zip(*epoch_train_accuracies_by_run)]
    average_val_accuracies =  [sum(epoch_accuracies) / len(epoch_accuracies) for epoch_accuracies in zip(*epoch_val_accuracies_by_run)]
    average_test_loss = sum(test_losses)/len(test_losses)
    average_test_accuracy = sum(test_accuracies)/len(test_accuracies)
    
    averaged_results[i] = {'seeds':random_seeds,'av_train_losses': average_train_losses,
                                       'av_val_losses': average_val_losses,
                                       'av_train_acc': average_train_accuracies,
                                       'av_val_acc': average_val_accuracies,
                                       'all_train_losses':epoch_train_losses_by_run,
                                       'all_val_losses': epoch_val_losses_by_run,
                                       'all_train_accuracies': epoch_train_accuracies_by_run,
                                       'all_val_accuracies': epoch_val_accuracies_by_run,
                                       'all_test_losses':test_losses, 
                                       'all_test_accuracies':test_accuracies,
                                       'av_test_loss': average_test_loss,
                                       'av_test_accuracy':average_test_accuracy}
    print('average for ')
    print('Model: ', i) 
    plot_single_train_val_smoothed(average_train_losses,average_val_losses,average_train_accuracies,average_val_accuracies, num_epochs, smoothing_window=3, title=f'PRETRAINING MODEL: {i}')
    
    # save last version of model to disk for retraining    
    torch.save(model, f'./models/trained_model_{i}.pth')

    
if save_experiment:
    with open(path_to_save, 'w') as file:
        json.dump(averaged_results, file, indent=4) 
        
# PLOTTING
pre_training_data = path_to_load
plot_all_models_performance_from_disk(pre_training_data, enforce_axis=True)
plot_performance_comparison_from_file(pre_training_data, enforce_axis=True)
display_accuracy_heatmap(pre_training_data)


# PERFORM TRANSFER LEARNING

# load in the two pretrained models and then reinitialise some layers
# retrain on the SWAPPED data

num_epochs = 50
learning_rate = 0.05
random_seeds = list(range(1,6))

path_to_save = f'./run_data/transfer_learning/transfer_learning_data_{num_epochs}_epochs_lr_{learning_rate}.json'
path_to_load = f'./run_data/transfer_learning/transfer_learning_data_{num_epochs}_epochs_lr_{learning_rate}.json'

models = [0, 1]
averaged_results = {i:{} for i in models}

save_experiment = True

# train them both on the swapped train and val data - test data same
for i, model in enumerate(models):
    epoch_train_losses_by_run = []
    epoch_val_losses_by_run = []
    epoch_train_accuracies_by_run = []
    epoch_val_accuracies_by_run = []
    test_losses = []
    test_accuracies = []
    reports = []
    
    for random_seed in random_seeds:
        print('MODEL: ', i) 
        print('seed:', random_seed)
        torch.manual_seed(random_seed)
        # load the saved pretrained model and reinitialise its fully connected layers
        # (weights_only=False is needed to unpickle a full nn.Module in recent PyTorch versions)
        if i == 0:
            pretrained_model_non_dropout = torch.load('./models/trained_model_0.pth', map_location=device, weights_only=False)
            pretrained_model_non_dropout.fc1 = nn.Linear(in_features=64 * 4 * 4, out_features=64)
            pretrained_model_non_dropout.fc2 = nn.Linear(in_features=64, out_features=10)
            model = pretrained_model_non_dropout
        elif i == 1:
            pretrained_model_best_dropout = torch.load('./models/trained_model_1.pth', map_location=device, weights_only=False)
            pretrained_model_best_dropout.fc1 = nn.Linear(in_features=64 * 4 * 4, out_features=64)
            pretrained_model_best_dropout.fc2 = nn.Linear(in_features=64, out_features=10)
            model = pretrained_model_best_dropout
        model.to(device)
        criterion = nn.CrossEntropyLoss()
        optimiser = optim.SGD(model.parameters(), lr=learning_rate)
        model, train_epoch_losses, train_epoch_accuracy, val_epoch_losses, val_epoch_accuracy, _,_ = run_training_and_validation(model, device, learning_rate, num_epochs, criterion, optimiser, swapped_train_dataloader, swapped_val_dataloader, metrics = False, manual_lr_schedule=False, plot=True)
        epoch_train_losses_by_run.append(train_epoch_losses)
        epoch_val_losses_by_run.append(val_epoch_losses)
        epoch_train_accuracies_by_run.append(train_epoch_accuracy)
        epoch_val_accuracies_by_run.append(val_epoch_accuracy)
        
        test_loss, test_accuracy, report = run_testing(model, device, criterion, test_dataloader)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        reports.append(report)
        
    average_train_losses = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in zip(*epoch_train_losses_by_run)]
    average_val_losses = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in zip(*epoch_val_losses_by_run)]
    average_train_accuracies = [sum(epoch_accuracies) / len(epoch_accuracies) for epoch_accuracies in zip(*epoch_train_accuracies_by_run)]
    average_val_accuracies =  [sum(epoch_accuracies) / len(epoch_accuracies) for epoch_accuracies in zip(*epoch_val_accuracies_by_run)]
    average_test_loss = sum(test_losses)/len(test_losses)
    average_test_accuracy = sum(test_accuracies)/len(test_accuracies)
    
    averaged_results[i] = {'seeds':random_seeds,'av_train_losses': average_train_losses,
                                       'av_val_losses': average_val_losses,
                                       'av_train_acc': average_train_accuracies,
                                       'av_val_acc': average_val_accuracies,
                                       'all_train_losses':epoch_train_losses_by_run,
                                       'all_val_losses': epoch_val_losses_by_run,
                                       'all_train_accuracies': epoch_train_accuracies_by_run,
                                       'all_val_accuracies': epoch_val_accuracies_by_run,
                                       'all_test_losses':test_losses, 
                                       'all_test_accuracies':test_accuracies,
                                       'av_test_loss': average_test_loss,
                                       'av_test_accuracy':average_test_accuracy}
    print('average for ')
    print('Model: ', i) 
    plot_single_train_val_smoothed(average_train_losses,average_val_losses,average_train_accuracies,average_val_accuracies, num_epochs, smoothing_window=3, title=f'TRANSFER LEARNING MODEL: {i}')
    


if save_experiment:
    with open(path_to_save, 'w') as file:
        json.dump(averaged_results, file, indent=4)  # 'indent' makes the output formatted and easier to read

# plotting results

transfer_learned_data = path_to_load
plot_all_models_performance_from_disk(transfer_learned_data, enforce_axis=True)
plot_performance_comparison_from_file(transfer_learned_data, enforce_axis=True)
display_accuracy_heatmap(transfer_learned_data)

### Experiment 3 (19 MARKS) <ignore>

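Figs. 20 to 23 show gradient flow through the different models, using the absolute value of the gradient to represent the magnitudes at different layers more clearly. Here an 'episode' is a single mini-batch update: the first 5 episodes are the first five batches of the first epoch, and the last 5 episodes are the final five batches of the last epoch.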
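Taking absolute values matters because signed gradients averaged over many parameters can cancel, giving a mean near zero even when individual updates are large. The plain-Python sketch below uses made-up gradient values (illustrative only, not data from these experiments) and follows the same reduction as the Experiment 3 code: take |grad| per parameter, average over the stored episodes, then average over the parameters of the layer. `layer_summary` is a hypothetical helper, not part of the notebook.

```python
import statistics

# Made-up gradients for one layer: rows are episodes, columns are parameters
episode_grads = [
    [0.40, -0.38, 0.05, -0.07],
    [-0.35, 0.41, -0.06, 0.04],
    [0.39, -0.42, 0.05, -0.05],
]


def layer_summary(grads, use_abs):
    """Mean over episodes for each parameter, then mean over parameters."""
    per_param = []
    for column in zip(*grads):  # one column per parameter
        values = [abs(g) for g in column] if use_abs else list(column)
        per_param.append(statistics.mean(values))
    return statistics.mean(per_param)


signed = layer_summary(episode_grads, use_abs=False)
absolute = layer_summary(episode_grads, use_abs=True)
print(f"signed mean:   {signed:.4f}")    # close to zero: signs cancel
print(f"absolute mean: {absolute:.4f}")  # reflects the typical update size
```

The signed summary is more than two orders of magnitude smaller than the absolute one, even though every parameter receives a sizeable gradient in every episode.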

3.1 
<figure><center><img src="./gradients baseline model.png" width=800><figcaption style="max-width: 600px"> Fig 20. Mean and standard deviation of the gradients of the loss function with respect to the parameters at each layer of the baseline model during training. </figcaption></center></figure>
Fig. 20 shows that for the baseline model the gradients in the first 5 episodes are small overall, with little propagation from later to earlier layers. In the last 5 episodes the gradients are larger and flow better to the earlier layers. The variability appears roughly proportional to the gradient magnitude. These results indicate that parameter updates are initially very small and concentrated in the later layers, with more gradient reaching the earlier layers by the end of training.

3.2
<figure><center><img src="./gradients dropout model.png" width=800><figcaption style="max-width: 600px"> Fig 21. Mean and standard deviation of the gradients of the loss function with respect to the parameters at each layer of the model with dropout implemented during training. </figcaption></center></figure>
Fig. 21 shows that gradient flow is broadly similar with and without dropout: the variability and the pattern of propagation to earlier layers match. The main difference is that the dropout model has higher gradient magnitudes in both the first and the last 5 episodes.
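For reference, the sketch below illustrates inverted dropout, the scheme `nn.Dropout` follows. In training mode each activation is zeroed with probability p and the survivors are scaled by 1/(1 - p), so the local gradient through a kept unit is 1/(1 - p) (2.5 for the p = 0.6 used here) and zero through a dropped unit. In evaluation mode the layer is the identity. This is a plain-Python illustration with made-up inputs, and `inverted_dropout` is a hypothetical helper, not part of the notebook.

```python
import random


def inverted_dropout(x, p, training, rng):
    """Apply inverted dropout with drop probability p to a list of activations.

    Returns (outputs, local_grads), where local_grads[i] is d outputs[i] / d x[i].
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        # Evaluation mode (or p = 0): identity, local gradient 1 everywhere
        return list(x), [1.0] * len(x)
    scale = 1.0 / (1.0 - p)
    # Keep each unit with probability 1 - p
    mask = [0.0 if rng.random() < p else 1.0 for _ in x]
    outputs = [xi * m * scale for xi, m in zip(x, mask)]
    local_grads = [m * scale for m in mask]
    return outputs, local_grads


rng = random.Random(0)  # fixed seed so the sketch is repeatable
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
out, grads = inverted_dropout(x, p=0.6, training=True, rng=rng)
print(out)    # dropped units are 0, kept units are x / (1 - 0.6)
print(grads)  # 0 for dropped units, about 2.5 for kept units
eval_out, _ = inverted_dropout(x, p=0.6, training=False, rng=rng)
print(eval_out)  # unchanged at evaluation time
```

Scaling during training, rather than at test time, keeps the expected activation the same in both modes. This is why `model.eval()` alone is enough to switch dropout off.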

3.3
<figure><center><img src="./gradients batchnorm model (matching others).png" width=800><figcaption style="max-width: 600px"> Fig 22. Mean and standard deviation of the gradients of the loss function with respect to the parameters at each layer of the model with batch normalisation during training. Note that the batch normalisation layers and their parameter gradients are omitted from this plot to facilitate comparison with the previous models. </figcaption></center></figure>

<figure><center><img src="./gradients batchnorm model (not matching others).png" width=800><figcaption style="max-width: 600px"> Fig 23. Mean and standard deviation of the gradients of the loss function with respect to the parameters at each layer of the model with batch normalisation during training, including the batch normalisation parameter gradients. </figcaption></center></figure>
Results for the batch-normalised model (Figs. 22 and 23) differ significantly. The gradients of the convolutional bias terms vanish. Batch normalisation subtracts the per-channel batch mean, so any constant bias added by the preceding layer cancels out, and its role is taken over by the shift parameter of the batch normalisation layer (the 'absorption of bias' phenomenon [1],[6]). In the layers common to all models, gradient magnitudes are dramatically higher in the first 5 episodes. This is most notable in the earlier layers, which receive virtually no gradient in the models without batch normalisation. The last 5 episodes are broadly similar across models. A side-by-side comparison of the gradient statistics of all three models is shown in Fig. 24.
<figure><center><img src="./gradint flow relative metrics.png" width=800><figcaption style="max-width: 600px"> Fig 24. Gradient statistics of the baseline, dropout, and batch-normalised models, grouped by metric. </figcaption></center></figure>
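The cancellation behind the vanishing bias gradients can be checked numerically. In training mode, adding a constant bias before batch normalisation shifts the batch mean by the same amount and leaves the variance unchanged, so the normalised output, and therefore the loss, does not depend on that bias. The plain-Python sketch below demonstrates this for one channel with made-up values and a hypothetical weighted-sum loss, estimating the bias gradient by central finite differences. `batch_norm`, `layer_with_bias` and `loss` are illustrative helpers, not functions from the notebook.

```python
import statistics


def batch_norm(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode batch norm for one channel (batch mean, biased batch variance)."""
    mean = statistics.fmean(values)
    var = statistics.pvariance(values, mu=mean)
    return [gamma * (v - mean) / (var + eps) ** 0.5 + beta for v in values]


def layer_with_bias(pre_bias, bias):
    """Hypothetical layer: add a scalar bias, then batch-normalise."""
    return batch_norm([z + bias for z in pre_bias])


def loss(outputs):
    """Hypothetical loss: a fixed weighted sum of the normalised outputs."""
    weights = [1.0, -2.0, 0.5, 3.0]
    return sum(w * o for w, o in zip(weights, outputs))


z = [0.3, -1.2, 2.5, 0.8]  # made-up pre-bias outputs for one channel
bias, h = 0.5, 1e-4

# Central finite-difference estimate of d loss / d bias
grad_bias = (loss(layer_with_bias(z, bias + h)) - loss(layer_with_bias(z, bias - h))) / (2 * h)
print(f"d loss / d bias ~ {grad_bias:.1e}")  # zero up to floating-point rounding

# Shifting the bias leaves the normalised outputs unchanged
shift = max(abs(a - b) for a, b in zip(layer_with_bias(z, 0.0), layer_with_bias(z, 5.0)))
print(f"max output change for bias 0 -> 5: {shift:.1e}")
```

In PyTorch the shift parameter beta is stored as the `bias` attribute of the batch normalisation layer. For the same reason, a convolution immediately followed by `nn.BatchNorm2d` is often created with `bias=False`.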
3.4
<figure><center><img src="./batch norm performance.png" width=800><figcaption style="max-width: 600px"> Fig 25. Performance plots showing individual and averaged training and validation losses and accuracies for a model with batch normalisation applied and trained on original data over 50 epochs. </figcaption></center></figure>
<figure><center><img src="./batch norm test results.PNG" width=300><figcaption style="max-width: 600px"> Fig 26. Test performance of model trained with batch normalisation</figcaption></center></figure>
Figs. 25 & 26 show the batch normalised model's performance during training, validation and testing. The model overfits quickly and performs poorly on test data, with substantial instability in validation performance. This is perhaps surprising given the regularisation effect often associated with batch normalisation [1],[6].
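A commonly cited source of batch normalisation's regularising effect is that, in training mode, each sample is normalised with the statistics of whichever mini-batch it lands in, which injects noise. In evaluation mode `nn.BatchNorm2d` instead uses running estimates accumulated during training. The plain-Python sketch below uses made-up values for a single channel. It shows the same input being mapped to quite different outputs in two batches, followed by the running-mean update PyTorch applies with its default `momentum=0.1` (the running mean starts at zero). `bn_train` is a hypothetical helper, not part of the notebook, and the sketch ignores the learnable scale and shift parameters.

```python
import statistics


def bn_train(batch, eps=1e-5):
    """Training-mode normalisation of one channel using the current batch's statistics."""
    mean = statistics.fmean(batch)
    var = statistics.pvariance(batch, mu=mean)
    return [(v - mean) / (var + eps) ** 0.5 for v in batch], mean


sample = 1.0  # the same input value, placed first in two different mini-batches
out_a, mean_a = bn_train([sample, 0.0, 2.0, 1.5])
out_b, mean_b = bn_train([sample, 3.0, 4.0, 2.5])
print(f"output in batch A: {out_a[0]:+.3f}")
print(f"output in batch B: {out_b[0]:+.3f}")

# At evaluation time a running mean is used instead. PyTorch's update, with its
# default momentum=0.1, is: running = (1 - momentum) * running + momentum * batch_mean
running_mean, momentum = 0.0, 0.1
for batch_mean in (mean_a, mean_b):
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean
print(f"running mean after two batches: {running_mean:.5f}")
```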

In [None]:
#############################
### Code for Experiment 3 ###
#############################

# return to original data splits

batch_sise = 64

torch.manual_seed(0)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


train_data = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform)

num_validation_samples = 5000
num_train_samples = len(train_data) - num_validation_samples

train_data, val_data = random_split(train_data, [num_train_samples, num_validation_samples])

print(len(train_data)) # 45000 training egs
print(len(val_data)) # 5000 validation egs
print(len(test_data)) # 10000 test egs

train_dataloader = DataLoader(train_data, batch_size=batch_sise, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_sise, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_sise, shuffle=False)



# define functions for collecting the absolute gradients of each layer over the first and last 5 episodes (mini-batches) of training

def collect_gradients_abs_4(model, dataloader, device, criterion, optimiser, num_epochs):
    # initialise dictionary for storing grads for each layer
    first_5_episodes_gradients_abs = {name: [] for name, _ in model.named_parameters()}
    last_5_episodes_gradients_abs = {name: [] for name, _ in model.named_parameters()}
    
    # run training
    for epoch in range(num_epochs):
        model.train().to(device)
        for batch_count, (images, labels) in enumerate(dataloader, 1):
            images = images.to(device)
            labels = labels.to(device)

            optimiser.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()

            episode_gradients_abs = {}
            for name, param in model.named_parameters():
                # if a trainable parameter has no gradient (e.g. it was unused in the forward pass), record zeros so every layer has an entry
                if param.grad is None and param.requires_grad:
                    episode_gradients_abs[name] = torch.zeros_like(param.data)
                elif param.grad is not None:
                    # get the abs gradient for each layer for the episode
                    episode_gradients_abs[name] = torch.abs(param.grad.clone().detach())

            # if it is an episode being stored, add it to the dict
            if epoch == 0 and batch_count <= 5:
                for name, grad_abs in episode_gradients_abs.items():
                    first_5_episodes_gradients_abs[name].append(grad_abs)
            elif epoch == num_epochs - 1 and batch_count > len(dataloader) - 5:
                for name, grad_abs in episode_gradients_abs.items():
                    last_5_episodes_gradients_abs[name].append(grad_abs)

            optimiser.step()

    return first_5_episodes_gradients_abs, last_5_episodes_gradients_abs


def compute_gradient_statistics_abs_4(gradients_abs):
    # gets mean and sd of abs gradients
    mean_gradients_abs = {}
    std_gradients_abs = {}
    for layer_name, layer_gradients_abs in gradients_abs.items():
        layer_gradients_abs = torch.stack(layer_gradients_abs)
        mean_gradients_abs[layer_name] = torch.mean(layer_gradients_abs, dim=0)
        std_gradients_abs[layer_name] = torch.std(layer_gradients_abs, dim=0)
    return mean_gradients_abs, std_gradients_abs

def plot_gradient_statistics_abs_4(mean_gradients_first5_abs, std_gradients_first5_abs, mean_gradients_last5_abs, std_gradients_last5_abs, skip_bn=True):
    # make the plots of the layers and gradients for mean and sd, for first 5 and last 5
    if skip_bn:
        # Filter out batch normalisation layers
        layer_names = [name for name in mean_gradients_first5_abs.keys() if not name.startswith('bn')]
    else:
        layer_names = list(mean_gradients_first5_abs.keys())

    num_layers = len(layer_names)
    x = np.arange(num_layers)
    width = 0.35

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    fig.suptitle('Gradient Statistics (Absolute Means, Absolute Standard Deviations)', fontsize=16)

    # Plot mean absolute gradients
    ax1.bar(x - width/2, [torch.mean(mean_gradients_first5_abs[name]).item() for name in layer_names], width, label='First 5 Episodes')
    ax1.bar(x + width/2, [torch.mean(mean_gradients_last5_abs[name]).item() for name in layer_names], width, label='Last 5 Episodes')
    ax1.set_xticks(x)
    ax1.set_xticklabels(layer_names, rotation=45)
    ax1.set_xlabel('Layer')
    ax1.set_ylabel('Mean of Absolute Gradients')
    ax1.set_title('Mean of Absolute Gradients vs Layer')
    ax1.legend()

    # Plot standard deviations of absolute gradients
    ax2.bar(x - width/2, [torch.mean(std_gradients_first5_abs[name]).item() for name in layer_names], width, label='First 5 Episodes')
    ax2.bar(x + width/2, [torch.mean(std_gradients_last5_abs[name]).item() for name in layer_names], width, label='Last 5 Episodes')
    ax2.set_xticks(x)
    ax2.set_xticklabels(layer_names, rotation=45)
    ax2.set_xlabel('Layer')
    ax2.set_ylabel('Standard Deviation of Absolute Gradients')
    ax2.set_title('Standard Deviation of Absolute Gradients vs Layer')
    ax2.legend()

    plt.tight_layout()
    plt.show()

# for clearer visualisation, plot the three models side by side, with one panel per gradient metric
def plot_model_comparison(first_5_mean_gradients_non_drop, first_5_mean_gradients_dropout, first_5_mean_gradients_bn,
                          last_5_mean_gradients_non_drop, last_5_mean_gradients_dropout, last_5_mean_gradients_bn,
                          first_5_std_gradients_non_drop, first_5_std_gradients_dropout, first_5_std_gradients_bn,
                          last_5_std_gradients_non_drop, last_5_std_gradients_dropout, last_5_std_gradients_bn):
    layer_names = [name for name in first_5_mean_gradients_non_drop.keys() if not name.startswith('bn')]
    num_layers = len(layer_names)
    x = np.arange(num_layers)
    width = 0.2

    fig, axs = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Model Comparison - Gradient Statistics', fontsize=16)

    # Plot mean absolute gradients for the first 5 episodes
    axs[0, 0].bar(x - width, [torch.mean(first_5_mean_gradients_non_drop[name]).item() for name in layer_names], width, label='Non-Dropout')
    axs[0, 0].bar(x, [torch.mean(first_5_mean_gradients_dropout[name]).item() for name in layer_names], width, label='Dropout')
    axs[0, 0].bar(x + width, [torch.mean(first_5_mean_gradients_bn[name]).item() for name in layer_names], width, label='Batch Norm')
    axs[0, 0].set_xticks(x)
    axs[0, 0].set_xticklabels(layer_names, rotation=45)
    axs[0, 0].set_xlabel('Layer')
    axs[0, 0].set_ylabel('Mean of Absolute Gradients')
    axs[0, 0].set_title('First 5 Episodes - Mean of Absolute Gradients')
    axs[0, 0].legend()
    # axs[0, 0].set_ylim(0, 0.04)
    

    # Plot mean absolute gradients for the last 5 episodes
    axs[0, 1].bar(x - width, [torch.mean(last_5_mean_gradients_non_drop[name]).item() for name in layer_names], width, label='Non-Dropout')
    axs[0, 1].bar(x, [torch.mean(last_5_mean_gradients_dropout[name]).item() for name in layer_names], width, label='Dropout')
    axs[0, 1].bar(x + width, [torch.mean(last_5_mean_gradients_bn[name]).item() for name in layer_names], width, label='Batch Norm')
    axs[0, 1].set_xticks(x)
    axs[0, 1].set_xticklabels(layer_names, rotation=45)
    axs[0, 1].set_xlabel('Layer')
    axs[0, 1].set_ylabel('Mean of Absolute Gradients')
    axs[0, 1].set_title('Last 5 Episodes - Mean of Absolute Gradients')
    axs[0, 1].legend()
    # axs[0, 1].set_ylim(0, 0.2)
    

    # Plot standard deviation of absolute gradients for the first 5 episodes
    axs[1, 0].bar(x - width, [torch.mean(first_5_std_gradients_non_drop[name]).item() for name in layer_names], width, label='Non-Dropout')
    axs[1, 0].bar(x, [torch.mean(first_5_std_gradients_dropout[name]).item() for name in layer_names], width, label='Dropout')
    axs[1, 0].bar(x + width, [torch.mean(first_5_std_gradients_bn[name]).item() for name in layer_names], width, label='Batch Norm')
    axs[1, 0].set_xticks(x)
    axs[1, 0].set_xticklabels(layer_names, rotation=45)
    axs[1, 0].set_xlabel('Layer')
    axs[1, 0].set_ylabel('Standard Deviation of Absolute Gradients')
    axs[1, 0].set_title('First 5 Episodes - Standard Deviation of Absolute Gradients')
    axs[1, 0].legend()

    # Plot standard deviation of absolute gradients for the last 5 episodes
    axs[1, 1].bar(x - width, [torch.mean(last_5_std_gradients_non_drop[name]).item() for name in layer_names], width, label='Non-Dropout')
    axs[1, 1].bar(x, [torch.mean(last_5_std_gradients_dropout[name]).item() for name in layer_names], width, label='Dropout')
    axs[1, 1].bar(x + width, [torch.mean(last_5_std_gradients_bn[name]).item() for name in layer_names], width, label='Batch Norm')
    axs[1, 1].set_xticks(x)
    axs[1, 1].set_xticklabels(layer_names, rotation=45)
    axs[1, 1].set_xlabel('Layer')
    axs[1, 1].set_ylabel('Standard Deviation of Absolute Gradients')
    axs[1, 1].set_title('Last 5 Episodes - Standard Deviation of Absolute Gradients')
    axs[1, 1].legend()
    # axs[1, 1].set_ylim(0, 0.2)


    plt.tight_layout()
    plt.show()
    

# Set epochs and learning rate
num_epochs = 50
learning_rate = 0.05

# 3.1-------------------- Gradient flow for the original model --------------------
torch.manual_seed(1984)
non_drop_model = BaselineNet()
criterion = nn.CrossEntropyLoss()
optimiser = optim.SGD(non_drop_model.parameters(), lr=learning_rate)
first_5_epochs_gradients_abs_non_drop, last_5_epochs_gradients_abs_non_drop = collect_gradients_abs_4(non_drop_model, train_dataloader, device, criterion, optimiser, num_epochs)
first_5_mean_gradients_non_drop, first_5_std_gradients_non_drop = compute_gradient_statistics_abs_4(first_5_epochs_gradients_abs_non_drop)
last_5_mean_gradients_non_drop, last_5_std_gradients_non_drop = compute_gradient_statistics_abs_4(last_5_epochs_gradients_abs_non_drop)
plot_gradient_statistics_abs_4(first_5_mean_gradients_non_drop, first_5_std_gradients_non_drop, last_5_mean_gradients_non_drop, last_5_std_gradients_non_drop)

# 3.2 -------------------- Gradient flow for the model with dropout --------------------
torch.manual_seed(1984)
drop_model = DropoutNet(0.6)
criterion = nn.CrossEntropyLoss()
optimiser = optim.SGD(drop_model.parameters(), lr=learning_rate)
first_5_epochs_gradients_abs_dropout, last_5_epochs_gradients_abs_dropout = collect_gradients_abs_4(drop_model, train_dataloader, device, criterion, optimiser, num_epochs)
first_5_mean_gradients_dropout, first_5_std_gradients_dropout = compute_gradient_statistics_abs_4(first_5_epochs_gradients_abs_dropout)
last_5_mean_gradients_dropout, last_5_std_gradients_dropout = compute_gradient_statistics_abs_4(last_5_epochs_gradients_abs_dropout)
plot_gradient_statistics_abs_4(first_5_mean_gradients_dropout, first_5_std_gradients_dropout, last_5_mean_gradients_dropout, last_5_std_gradients_dropout)

# 3.3 -------------------- Gradient flow for the model with batch normalisation --------------------
# create model with batch normalisation as per the brief
class BatchNormNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(in_features=64 * 4 * 4, out_features=64)
        self.bn4 = nn.BatchNorm1d(64)
        self.fc2 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.bn4(self.fc1(x)))
        x = self.fc2(x)
        return x

# this code uses the above to run and plot the gradient for the batch norm model  
torch.manual_seed(1984)
bn_model = BatchNormNet()
criterion = nn.CrossEntropyLoss()
optimiser = optim.SGD(bn_model.parameters(), lr=learning_rate)
first_5_epochs_gradients_abs_bn, last_5_epochs_gradients_abs_bn = collect_gradients_abs_4(bn_model, train_dataloader, device, criterion, optimiser, num_epochs)
first_5_mean_gradients_bn, first_5_std_gradients_bn = compute_gradient_statistics_abs_4(first_5_epochs_gradients_abs_bn)
last_5_mean_gradients_bn, last_5_std_gradients_bn = compute_gradient_statistics_abs_4(last_5_epochs_gradients_abs_bn)
plot_gradient_statistics_abs_4(first_5_mean_gradients_bn, first_5_std_gradients_bn, last_5_mean_gradients_bn, last_5_std_gradients_bn, skip_bn=True)
plot_gradient_statistics_abs_4(first_5_mean_gradients_bn, first_5_std_gradients_bn, last_5_mean_gradients_bn, last_5_std_gradients_bn, skip_bn=False)

# 3.4 
# properly train and visualise performance of batch norm model 

num_epochs = 50
learning_rate = 0.05

random_seeds = list(range(1, 6))
path_to_save = f'./run_data/batch_norm/batch_norm_{num_epochs}_epochs_LR_{learning_rate}.json'
path_to_load = f'./run_data/batch_norm/batch_norm_{num_epochs}_epochs_LR_{learning_rate}.json'
averaged_results = {'bn':{}}
save_experiment = True

# train the batch-normalised model on the original data splits, once per random seed

epoch_train_losses_by_run = []
epoch_val_losses_by_run = []
epoch_train_accuracies_by_run = []
epoch_val_accuracies_by_run = []
test_losses = []
test_accuracies = []
reports = []

for random_seed in random_seeds:
    print('seed:', random_seed)
    
    torch.manual_seed(random_seed)
    
    model = BatchNormNet()
    model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimiser = optim.SGD(model.parameters(), lr=learning_rate)
    
    model, train_epoch_losses, train_epoch_accuracy, val_epoch_losses, val_epoch_accuracy, _,_ = run_training_and_validation(model, device, learning_rate, num_epochs, criterion, optimiser, train_dataloader, val_dataloader, metrics = False, manual_lr_schedule=False, plot=True)
    epoch_train_losses_by_run.append(train_epoch_losses)
    epoch_val_losses_by_run.append(val_epoch_losses)
    epoch_train_accuracies_by_run.append(train_epoch_accuracy)
    epoch_val_accuracies_by_run.append(val_epoch_accuracy)
    
    test_loss, test_accuracy, report = run_testing(model, device, criterion, test_dataloader)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
    reports.append(report)
    
# average each metric across seeds, epoch by epoch (zip(*runs) groups the i-th epoch of every run)
average_train_losses = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in zip(*epoch_train_losses_by_run)]
average_val_losses = [sum(epoch_losses) / len(epoch_losses) for epoch_losses in zip(*epoch_val_losses_by_run)]
average_train_accuracies = [sum(epoch_accuracies) / len(epoch_accuracies) for epoch_accuracies in zip(*epoch_train_accuracies_by_run)]
average_val_accuracies = [sum(epoch_accuracies) / len(epoch_accuracies) for epoch_accuracies in zip(*epoch_val_accuracies_by_run)]
# average the final test metrics across seeds
average_test_loss = sum(test_losses) / len(test_losses)
average_test_accuracy = sum(test_accuracies) / len(test_accuracies)

# store the averaged and per-run results under the 'bn' key so they can be saved to disk
averaged_results['bn'] = {'seeds': random_seeds,
                          'av_train_losses': average_train_losses,
                          'av_val_losses': average_val_losses,
                          'av_train_acc': average_train_accuracies,
                          'av_val_acc': average_val_accuracies,
                          'all_train_losses': epoch_train_losses_by_run,
                          'all_val_losses': epoch_val_losses_by_run,
                          'all_train_accuracies': epoch_train_accuracies_by_run,
                          'all_val_accuracies': epoch_val_accuracies_by_run,
                          'all_test_losses': test_losses,
                          'all_test_accuracies': test_accuracies,
                          'av_test_loss': average_test_loss,
                          'av_test_accuracy': average_test_accuracy}
print(f'averaged results for seeds {random_seeds}:')
plot_single_train_val_smoothed(average_train_losses, average_val_losses, average_train_accuracies, average_val_accuracies, num_epochs, smoothing_window=3, title='BATCH NORM MODEL')

if save_experiment:
    with open(path_to_save, 'w') as file:
        json.dump(averaged_results, file, indent=4)  # 'indent' makes the output formatted and easier to read

# reload the saved results from disk (the same file as path_to_save) and plot them
batch_norm = path_to_load
plot_all_models_performance_from_disk(batch_norm, enforce_axis=True)
plot_performance_comparison_from_file(batch_norm, enforce_axis=True)
display_accuracy_heatmap(batch_norm)


# Conclusions and Discussion (instructions) - 25 MARKS <ignore>
In this section, you are expected to:
* briefly summarise and describe the conclusions from your experiments (8 MARKS).
* discuss whether or not your results are expected, providing scientific reasons (8 MARKS).
* discuss two or more alternative/additional methods that may enhance your model, with scientific reasons (4 MARKS). 
* Reference two or more relevant academic publications that support your discussion. (4 MARKS)

The three experiments demonstrated how the learning rate, dropout, and batch normalisation each shape the training dynamics and generalisation of neural networks.

Experiment one showed how the learning rate (LR) affects a model's ability to fit the training data and, in turn, its generalisation. High LRs led to coarse updates and unstable training, preventing the model from settling into a minimum of the training loss, while low LRs made slow progress but ultimately fitted the training data more closely. An LR scheduler balanced these properties, giving rapid early learning followed by finer-grained updates in later stages, but it did not significantly improve validation or test performance. These results were in keeping with expectations, as the LR dictates the magnitude of each step taken in parameter space and therefore governs training dynamics and convergence [15].
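
As a minimal sketch of why the step size matters (not the notebook's training code), the snippet below runs plain gradient descent on the toy loss $f(w) = w^2$, whose gradient is $2w$, under three hypothetical schedules: a constant high LR, a constant low LR, and an exponentially decaying LR standing in for a scheduler. The toy loss, the starting point `w0`, and the decay factor of 0.9 are illustrative assumptions, not the settings used in Experiment one.

```python
def gradient_descent(lr_schedule, w0=5.0, steps=50):
    """Minimise the toy loss f(w) = w**2 with a per-step learning rate.

    lr_schedule is a hypothetical callable mapping the step index to an LR.
    Returns the loss after every update.
    """
    w = w0
    losses = []
    for t in range(steps):
        lr = lr_schedule(t)
        grad = 2 * w            # derivative of w**2
        w = w - lr * grad       # step size = LR x gradient
        losses.append(w ** 2)
    return losses


# constant high LR: every update overshoots, so w bounces between +5 and -5
high = gradient_descent(lambda t: 1.0)
# constant low LR: every update is small, so progress towards w = 0 is slow
low = gradient_descent(lambda t: 0.01)
# decaying LR (a stand-in for a scheduler): large steps early, fine steps later
decayed = gradient_descent(lambda t: 1.0 * 0.9 ** t)

print(high[-1], low[-1], decayed[-1])
```

With these settings the high-LR run never reduces the loss below its starting value, the low-LR run is still well above zero after 50 steps, and the decayed run ends closest to the minimum. Real loss surfaces are far from quadratic, so this only illustrates the step-size argument.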

Experiment two demonstrated the regularisation effect of dropout: in standard training, increasing the dropout rate markedly reduced both the generalisation gap and the validation loss. Dropout also had a strong regularising effect in transfer learning, although the resulting performance improvements were small. Again, these results were expected: dropout injects noise and stochasticity into the network during training, which prevents over-reliance on any individual neuron or subnetwork of neurons and so improves generalisation and reduces overfitting [15].
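
A minimal pure-Python sketch of the noise that dropout injects is shown below, assuming the common "inverted dropout" formulation, which PyTorch's `nn.Dropout` also uses: during training each activation is zeroed with probability `p` and the survivors are scaled by `1/(1-p)`, so the expected activation is unchanged, while at evaluation time the layer returns its input untouched. The function name and arguments are illustrative, not part of the notebook.

```python
import random


def dropout(activations, p, training=True, rng=random):
    """Inverted dropout on a list of activations (illustrative sketch).

    Each activation is kept with probability 1 - p and rescaled by 1 / (1 - p);
    dropped activations become 0.0. In evaluation mode the input is returned unchanged.
    """
    if not training or p == 0.0:
        return list(activations)
    keep = 1.0 - p
    return [a / keep if rng.random() < keep else 0.0 for a in activations]


rng = random.Random(0)                      # seeded so the mask is reproducible
out = dropout([1.0] * 10_000, p=0.5, rng=rng)
print(sorted(set(out)))                     # only 0.0 (dropped) and 2.0 (kept and rescaled)
print(sum(out) / len(out))                  # close to 1.0: the expected activation is preserved
```

Because a different random mask is drawn on every forward pass, no single unit can be relied upon, which is the mechanism behind the reduced overfitting described above.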

Experiment three clearly demonstrated the powerful impact of batch normalisation on gradient propagation through the network's layers: during the initial epochs, the average gradients reaching the early layers contrasted starkly with those of the other models. However, the effect on model performance was disappointing, with poor generalisation to the test set. This was surprising, as batch normalisation is known to have a regularising effect [1, 6]. Its headline benefit, however, is faster learning through early propagation of gradients to all layers, which *was* observed here. Given this significant change to the network's dynamics, other hyperparameters may need to be re-tuned for the new dynamics before the remaining benefits emerge.
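
To make the mechanism concrete, the sketch below implements the batch normalisation transform of Ioffe and Szegedy [6] for a single feature in plain Python: the feature is normalised with the mini-batch mean and (biased) variance, then scaled by `gamma` and shifted by `beta`. Here `gamma`, `beta` and `eps` are fixed illustrative values rather than learned parameters, and the running statistics used at inference (which `nn.BatchNorm2d` maintains) are omitted. The second call shows that rescaling the inputs leaves the output almost unchanged, the scale invariance which [6] argues makes back-propagation through a layer insensitive to the scale of its weights.

```python
import math


def batch_norm(batch, gamma=1.0, beta=0.0, eps=1e-5):
    """Batch-normalise one feature across a mini-batch (training-mode sketch)."""
    mean = sum(batch) / len(batch)
    # biased variance (divide by N), as used for normalisation during training
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    # normalise to zero mean and unit variance, then apply the affine scale and shift
    return [gamma * (x - mean) / math.sqrt(var + eps) + beta for x in batch]


activations = [1.0, 2.0, 3.0, 4.0]
normed = batch_norm(activations)
# multiplying the inputs by 10 (e.g. larger incoming weights) barely changes the output
normed_scaled = batch_norm([10 * x for x in activations])
print([round(v, 4) for v in normed])
print([round(v, 4) for v in normed_scaled])
```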

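To enhance the model, complementary changes to the training set-up could be explored, such as the Adam optimiser [2], which adapts a separate step size for each parameter from running estimates of the first and second moments of its gradient and is widely recommended for deep learning [15]. Additionally, skip connections, as used in successful architectures such as ResNets [4], could be investigated to deepen the network whilst maintaining gradient flow, since the identity path passes gradients directly to earlier layers.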
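
For reference, a minimal sketch of the Adam update rule from Kingma and Ba [2] is given below, applied to the one-dimensional toy loss $f(w) = (w-3)^2$ rather than to a network. The defaults `beta1=0.9`, `beta2=0.999` and `eps=1e-8` follow the paper's recommendations; the learning rate of 0.1, the number of steps and the toy loss are assumptions chosen so the example converges quickly. In the notebook itself, the equivalent change would be to replace `optim.SGD(...)` with `optim.Adam(model.parameters(), lr=...)`.

```python
import math


def adam_minimise(grad_fn, w0, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=500):
    """Minimise a 1-D function with the Adam update rule, given its gradient."""
    w, m, v = w0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = beta1 * m + (1 - beta1) * g        # running mean of gradients (1st moment)
        v = beta2 * v + (1 - beta2) * g * g    # running mean of squared gradients (2nd moment)
        m_hat = m / (1 - beta1 ** t)           # bias correction for the zero initialisation
        v_hat = v / (1 - beta2 ** t)
        # the ratio m_hat / sqrt(v_hat) normalises the step size for this parameter
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w


# toy loss f(w) = (w - 3)**2, whose gradient is 2 * (w - 3)
w_final = adam_minimise(lambda w: 2 * (w - 3), w0=0.0)
print(w_final)
```

In a real network the same update is applied element-wise to every parameter tensor, which is what gives Adam its per-parameter adaptive step sizes.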

Data augmentation techniques, such as random rotations, flips, crops, and colour jittering, could also be tried to increase the effective size and diversity of the training data, potentially addressing overfitting, one of the key performance issues identified here.
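
As a minimal sketch of two of the augmentations mentioned above, the NumPy function below applies a random horizontal flip and a padded random crop to a single image stored as a height x width x channels array. The function name, the 4-pixel padding and the array layout are assumptions for illustration; in the notebook's PyTorch pipeline the equivalent would be `torchvision.transforms.RandomHorizontalFlip()` and `torchvision.transforms.RandomCrop(size, padding=4)`, applied to the training set only.

```python
import numpy as np


def augment(image, rng, pad=4):
    """Randomly flip and crop an H x W x C image, returning an array of the same shape."""
    # horizontal flip with probability 0.5 (reverse the width axis)
    if rng.random() < 0.5:
        image = image[:, ::-1, :]
    h, w = image.shape[:2]
    # zero-pad the spatial dimensions only, then take a random H x W window
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="constant")
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    return padded[top:top + h, left:left + w, :]


rng = np.random.default_rng(0)
image = np.arange(32 * 32 * 3, dtype=np.float32).reshape(32, 32, 3)  # dummy 32x32 RGB image
augmented = augment(image, rng)
print(augmented.shape)
```

Because a new flip and crop are drawn each time an image is loaded, the network rarely sees exactly the same input twice, which is how augmentation counters overfitting.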

# References (instructions) <ignore>
Use the cell below to add your references. A good format to use for references is like this:

[AB Name], [CD Name], [EF Name] ([year]), [Article title], [Journal/Conference Name] [volume], [page numbers] or [article number] or [doi]

Some examples:

JEM Bennett, A Phillipides, T Nowotny (2021), Learning with reinforcement prediction errors in a model of the Drosophila mushroom body, Nat. Comms 12:2569, doi: 10.1038/s41467-021-22592-4

SO Kaba, AK Mondal, Y Zhang, Y Bengio, S Ravanbakhsh (2023), Proc. 40th Int. Conf. Machine Learning, 15546-15566

<ignore>

1. Bjorck, J., Gomes, C. P., Selman, B., Weinberger, K. Q. (2018), Understanding batch normalization, CoRR abs/1806.02375.

2. Kingma, D. P., Ba, J. (2014), Adam: A method for stochastic optimization, arXiv preprint arXiv:1412.6980.

3. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., et al. (2020), An image is worth 16x16 words: Transformers for image recognition at scale, arXiv preprint arXiv:2010.11929.

4. He, K., Zhang, X., Ren, S., Sun, J. (2016), Deep residual learning for image recognition, In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778.

5. Huang, G., Liu, Z., van der Maaten, L., Weinberger, K. Q. (2016), Densely connected convolutional networks, CoRR abs/1608.06993.

6. Ioffe, S., Szegedy, C. (2015), Batch normalization: Accelerating deep network training by reducing internal covariate shift, CoRR abs/1502.03167.

7. Krizhevsky, A., Hinton, G., et al. (2009), Learning multiple layers of features from tiny images, Technical report, University of Toronto.

8. Krizhevsky, A., Sutskever, I., Hinton, G. E. (2017), ImageNet classification with deep convolutional neural networks, Communications of the ACM 60(6), 84–90.

9. PyTorch Foundation (2024), CrossEntropyLoss - PyTorch 2.3 documentation, https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html (accessed May 12, 2024).

10. PyTorch Foundation (2024), datasets - PyTorch 2.3 documentation, https://pytorch.org/vision/0.8/datasets.html (accessed May 12, 2024).

11. PyTorch Foundation (2024), LogSoftmax - PyTorch 2.3 documentation, https://pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html#torch.nn.LogSoftmax (accessed May 12, 2024).

12. PyTorch Foundation (2024), NLLLoss - PyTorch 2.3 documentation, https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html#torch.nn.NLLLoss (accessed May 12, 2024).

13. PyTorch Foundation (2024), SGD - PyTorch 2.3 documentation, https://pytorch.org/docs/stable/generated/torch.optim.SGD.html (accessed May 12, 2024).

14. Simonyan, K., Zisserman, A. (2014), Very deep convolutional networks for large-scale image recognition, arXiv preprint arXiv:1409.1556.

15. Zhang, A., Lipton, Z. C., Li, M., Smola, A. J. (2020), Dive into Deep Learning, https://d2l.ai/
</ignore>