## Week 2: Training neural networks

In this notebook we are going to look in more depth at the process of training neural networks, with some nice visual representations where we can see our networks learning over time. We are going to try editing training hyperparameters such as the **learning rate** and **momentum**, using different optimisation algorithms, and eventually editing the neural networks ourselves.

The type of neural network we are using in this class is a [Compositional Pattern-Producing Network](http://eplex.cs.ucf.edu/papers/stanley_gpem07.pdf) (CPPN). CPPNs are very simple neural networks that can be trained quickly, which is highly unusual for generative neural networks. So what's the catch? Well, in the standard use case they can only learn to generate a single image. Still, they have a very distinctive aesthetic and would be a great topic for further investigation in your mini-project, especially if you have limited computational resources.

CPPNs have a long and interesting history that predates the modern discourse on creative AI and generative AI. If you want to know more about them, [you can watch this bonus lecture on the subject from a previous year](https://ual.cloud.panopto.eu/Panopto/Pages/Viewer.aspx?id=b3df426b-94fa-40b2-b774-af8e0115093e).

First, let's do some imports:

In [None]:
%matplotlib inline

import os
import torch
import random
import IPython.display

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from PIL import Image
from torchvision.utils import save_image

# Import utility functions from the file util.py in the src folder
from src.util import get_normalised_coordinate_grid, make_training_gif

##### Define hyperparameters

Here we define our hyperparameters. Try writing a comment for each one explaining what it does and why we need it:

In [None]:
#
device = 'cpu'
#
num_steps = 100000
#
batch_size = 100
#
learn_rate = 0.001
#
momentum = 0.9
#
num_channels = 3
#
image_shape = (128,128)

##### Load target image

Let's load the target image for this training process. There are a few in the `media` folder; you can try different ones here or load your own images into the code:

In [None]:
target_im_path = '../media/colour_wheel.png'
target_im = Image.open(target_im_path).convert('RGB')
resized_im = target_im.resize(image_shape)
resized_im

##### Create coordinate grid

Here we create a grid of normalised coordinates in the range [-1, 1] for every pixel in the 128x128 image, flattened into one long list of x and y coordinates. This gives us a 2-D tensor of shape (16384, 2): one row per pixel (128x128 = 16384) and one column each for the x and y coordinate values.

In [None]:
all_xy_coordinates = get_normalised_coordinate_grid(image_shape) 
all_xy_coordinates = torch.tensor(all_xy_coordinates, device=device, dtype=torch.float32)
print(f'coordinate grid shape: {all_xy_coordinates.shape}')
print(f'coordinate grid data: \n {all_xy_coordinates}')
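
`get_normalised_coordinate_grid` lives in `src/util.py`, which isn't shown here. As a rough idea of how such a grid can be built, here is a hypothetical NumPy sketch (`make_coordinate_grid_sketch` is an illustrative name, and the real helper may order or space its coordinates differently). It lays out one (x, y) pair per pixel in row-major order, so x changes fastest along each row:

```python
import numpy as np

def make_coordinate_grid_sketch(image_shape):
    """Hypothetical stand-in for src.util.get_normalised_coordinate_grid.

    Returns an array of shape (height * width, 2) holding (x, y) pairs
    evenly spaced in [-1, 1], one row per pixel in row-major order.
    """
    height, width = image_shape
    xs = np.linspace(-1.0, 1.0, width)   # one x value per column
    ys = np.linspace(-1.0, 1.0, height)  # one y value per row
    # With indexing='xy', both outputs have shape (height, width)
    grid_x, grid_y = np.meshgrid(xs, ys, indexing='xy')
    # Stack to (height, width, 2), then flatten to (height * width, 2)
    return np.stack([grid_x, grid_y], axis=-1).reshape(-1, 2)

grid_demo = make_coordinate_grid_sketch((128, 128))
print(grid_demo.shape)  # (16384, 2)
print(grid_demo[0], grid_demo[127])  # along the first row x runs from -1 to 1 while y stays at -1
```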

##### Create pixel tensor

Here we create the tensor containing our pixel values. Normally this would be a 3-D tensor of shape (height, width, channels), but we flatten it into a 2-D tensor of shape (16384, 3): one row per pixel (128x128 = 16384) and one column each for the red, green and blue values. We also divide by 255 to scale the values from [0, 255] to [0, 1], the same range as the sigmoid output of our network.

In [None]:
all_pixel_values = np.asarray(resized_im).reshape(-1, num_channels) / 255
all_pixel_values = torch.tensor(all_pixel_values, device=device, dtype=torch.float32)
print(f'image pixel tensor shape: {all_pixel_values.shape}')
print(f'image pixel tensor data: \n {all_pixel_values}')
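
To see exactly what the flattening does, here is a small NumPy sketch using a made-up 2x3 image so the values are easy to follow. Row `k` of the flattened array is the pixel at row `k // width`, column `k % width`, and reshaping reverses the operation, which is what turning the network's predictions back into an image relies on:

```python
import numpy as np

# A made-up 2x3 RGB "image" with distinct values, shape (height, width, channels)
height, width, channels = 2, 3, 3
image = np.arange(height * width * channels, dtype=np.uint8).reshape(height, width, channels)

# Flatten to one row per pixel, as in the cell above, and scale to [0, 1]
flat = image.reshape(-1, channels) / 255
print(flat.shape)  # (6, 3)

# Row k holds the pixel at (row k // width, column k % width)
k = 4
assert np.allclose(flat[k] * 255, image[k // width, k % width])

# Reshaping back and rescaling recovers the original image
restored = np.round(flat.reshape(height, width, channels) * 255).astype(np.uint8)
assert np.array_equal(restored, image)
```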

##### Define neural network

Here we define our CPPN neural network. Can you write comments for each line of code here?

Every network you create in PyTorch will inherit from the [torch.nn.Module](https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py) base class. This gives our network all the handy utility methods we need to train it on our data, such as `parameters()`, which the optimiser uses to find the weights it should update.

If you are unsure about anything, you can consult [the PyTorch neural network (torch.nn) reference](https://pytorch.org/docs/stable/nn.html) and [the W3Schools reference on Python inheritance](https://www.w3schools.com/python/python_inheritance.asp).

In [None]:
class CPPN(nn.Module):
    #
    def __init__(self):
        #
        super(CPPN, self).__init__()
        #
        self.fc1 = nn.Linear(2, 16)
        #
        self.fc2 = nn.Linear(16, 32)
        #
        self.fc3 = nn.Linear(32, num_channels)
    
    #
    def forward(self, x):
        #
        x = self.fc1(x)
        #
        x = F.relu(x)
        #
        x = self.fc2(x)
        #
        x = F.relu(x)
        #
        x = self.fc3(x)
        #
        x = F.sigmoid(x)
        #
        return x
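
It can help to count how many trainable numbers this network has. Each `nn.Linear(n_in, n_out)` layer holds an `n_in * n_out` weight matrix plus `n_out` biases. The sketch below rebuilds the same layer sizes with `nn.Sequential` (an equivalent way of writing this model, used here only for brevity) and checks the count:

```python
import torch.nn as nn

# Same layer sizes as the CPPN class above (3 output channels for RGB)
model = nn.Sequential(
    nn.Linear(2, 16), nn.ReLU(),
    nn.Linear(16, 32), nn.ReLU(),
    nn.Linear(32, 3), nn.Sigmoid(),
)

# Each nn.Linear(n_in, n_out) has n_in * n_out weights plus n_out biases:
#   fc1: 2 * 16 + 16 = 48
#   fc2: 16 * 32 + 32 = 544
#   fc3: 32 * 3 + 3 = 99
total = sum(p.numel() for p in model.parameters())
print(total)  # 691
```

Only 691 numbers have to encode the whole 128x128 image, which is part of why CPPNs train so quickly.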

##### Setup core objects

Here we set up our core objects: the model, the optimiser and the loss function.

In [None]:
cppn = CPPN()
cppn.to(device)
cppn.requires_grad_()

optimiser = torch.optim.SGD(cppn.parameters(), lr=learn_rate, momentum=momentum)
criterion = nn.MSELoss(reduction='sum')
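
To see what the learning rate and momentum do inside `torch.optim.SGD`, here is a small sketch that applies PyTorch's momentum update by hand (with the default `dampening=0` and `nesterov=False`: velocity = momentum * velocity + gradient, then parameter -= lr * velocity) and checks it against the optimiser. The single parameter and the toy loss f(w) = w² are illustrative assumptions, and the `_demo` names avoid overwriting the notebook's own variables:

```python
import torch

lr_demo, momentum_demo = 0.1, 0.9

# One scalar parameter and a toy loss f(w) = w ** 2, whose gradient is 2 * w
w = torch.tensor([1.0], requires_grad=True)
opt_demo = torch.optim.SGD([w], lr=lr_demo, momentum=momentum_demo)

w_manual, velocity = 1.0, 0.0
for step in range(5):
    # By hand: velocity is a running sum of past gradients, decayed by momentum
    grad = 2 * w_manual
    velocity = momentum_demo * velocity + grad
    w_manual -= lr_demo * velocity

    # The same step with PyTorch
    opt_demo.zero_grad()
    loss_demo = (w ** 2).sum()
    loss_demo.backward()
    opt_demo.step()

print(w.item(), w_manual)
```

After five steps `w` has overshot the minimum at 0 and become negative: the accumulated velocity keeps pushing it onward. This is the kind of behaviour to look out for when you raise `momentum` or `learn_rate` in Task 2.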

##### Training loop

Here is the training loop for our data.

In [None]:
num_coords = all_xy_coordinates.shape[0]
coord_indexes = list(range(0, num_coords))
losses = []
img_list = []
running_loss = 0.0

for i in range(num_steps):
    optimiser.zero_grad()
    cppn.zero_grad()

    # Sample a random batch of indexes from the list coord_indexes
    batch_indexes = torch.tensor(np.array(random.sample(coord_indexes, batch_size)))
    
    # Get the batch of corresponding xy coordinates
    coordinates_batch = all_xy_coordinates[batch_indexes]
    
    # And the corresponding pixel values
    pixel_values_batch = all_pixel_values[batch_indexes]
    
    # Process data with model
    approx_pixel_values = cppn(coordinates_batch)
    
    # Calculate the loss (prediction first, then target, as nn.MSELoss expects) and track it
    loss = criterion(approx_pixel_values, pixel_values_batch)
    running_loss += loss.item()
    losses.append(loss.item())
    
    if i % 1000 == 0:
        # Average loss since the last report (only one step has run at step 0)
        steps_since_report = 1 if i == 0 else 1000
        print(f'step {i}, loss {running_loss / steps_since_report:.3f}')
        running_loss = 0.0
        with torch.no_grad():
            prediction = cppn(all_xy_coordinates)
            prediction = torch.swapaxes(prediction, 0, 1)
            prediction = torch.reshape(prediction, (num_channels, image_shape[0], image_shape[1]))
            os.makedirs('training_ims', exist_ok=True)
            save_image(prediction, f'training_ims/im_{i // 1000:06}.png')
            # Keep a CPU copy so the frames can be converted to NumPy for plotting later
            img_list.append(prediction.cpu())
            
    # Update model: backpropagate the loss, then take an optimiser step
    loss.backward()
    optimiser.step()
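
The loop above builds each batch with `random.sample` on a Python list of 16384 indexes. A torch-only alternative, sketched here as an assumption rather than the notebook's method (`sample_batch_indexes` is a hypothetical helper), draws the same kind of batch with `torch.randperm`:

```python
import torch

def sample_batch_indexes(num_coords, batch_size, generator=None):
    """Hypothetical helper: batch_size distinct indexes, like random.sample above."""
    # randperm shuffles 0..num_coords-1, so its first batch_size entries
    # form a uniformly random subset with no repeats
    return torch.randperm(num_coords, generator=generator)[:batch_size]

demo_gen = torch.Generator().manual_seed(0)
demo_indexes = sample_batch_indexes(16384, 100, generator=demo_gen)
print(demo_indexes.shape)  # torch.Size([100])
```

`torch.randint(0, num_coords, (batch_size,))` would be cheaper, but it samples with replacement, so a batch could contain the same pixel twice. Either way, with `num_steps = 100000` and `batch_size = 100` the network sees 10,000,000 samples in total, roughly 610 passes over the 16384 pixels.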

##### Generate whole image

Here we generate an entire image in one go. This may seem counter-intuitive, but we are passing in our entire coordinate matrix, all 16384 coordinates, at once. This means we are processing our data with a batch size of 16384. Conceptually this is like **processing that data through 16384 copies of our neural network** all in one go; in practice PyTorch runs the single network on the whole batch using a few large matrix multiplications.

We can only do this because our network is so small and modern computers have so much memory available to store and process all that data. As you will see in later weeks, we normally run inference with our networks using a batch size of 1, not 16384!

In [None]:
with torch.no_grad():
    prediction = cppn(all_xy_coordinates)
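
To convince yourself that a batch is really handled by one network, here is a small sketch comparing one batched call with one call per coordinate. The stand-in network has random weights and the same input and output sizes as the CPPN; it is an assumption for illustration, not the trained model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in network: 2 inputs (x, y), 3 outputs (RGB), like the CPPN
net = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 3), nn.Sigmoid())

coords_demo = torch.rand(8, 2) * 2 - 1  # 8 random (x, y) pairs in [-1, 1)

with torch.no_grad():
    batched = net(coords_demo)  # one call on the whole batch, shape (8, 3)
    one_at_a_time = torch.cat([net(coords_demo[i:i + 1]) for i in range(len(coords_demo))])

print(batched.shape)  # torch.Size([8, 3])
print(torch.allclose(batched, one_at_a_time, atol=1e-6))  # True
```

Both give the same pixels; the batched call simply does the work in one pass, which is why it is so much faster.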


##### Make training GIF

Here let's make a GIF of our training progress over time. This is a nice visual way to see how training proceeds and whether our model converges smoothly or not.

In [None]:
im_folder_path = 'training_ims'
file_out = 'training_gif.gif'
make_training_gif(im_folder_path=im_folder_path, im_ext='png', file_out=file_out)
IPython.display.Image('training_gif.gif')
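
`make_training_gif` also lives in `src/util.py`, which isn't shown here. As a hypothetical sketch of how such a helper could work (`make_gif_sketch` is an illustrative name and may differ from the real one), Pillow can write an animated GIF by saving the first frame with `save_all=True` and passing the rest through `append_images`:

```python
import glob
import os
import tempfile
from PIL import Image

def make_gif_sketch(im_folder_path, im_ext='png', file_out='training_gif.gif', frame_ms=200):
    """Hypothetical stand-in for src.util.make_training_gif.

    Collects the frames in sorted filename order (zero-padded names such as
    im_000012.png keep this chronological) and writes a looping GIF.
    """
    paths = sorted(glob.glob(os.path.join(im_folder_path, f'*.{im_ext}')))
    if not paths:
        raise FileNotFoundError(f'no .{im_ext} files in {im_folder_path}')
    frames = [Image.open(p).convert('RGB') for p in paths]
    # save_all + append_images writes every frame; loop=0 repeats forever
    frames[0].save(file_out, save_all=True, append_images=frames[1:],
                   duration=frame_ms, loop=0)
    return file_out

# Usage demo on three synthetic frames in a temporary folder
demo_dir = tempfile.mkdtemp()
for n, colour in enumerate([(255, 0, 0), (0, 255, 0), (0, 0, 255)]):
    Image.new('RGB', (8, 8), colour).save(os.path.join(demo_dir, f'im_{n:06}.png'))
demo_gif = make_gif_sketch(demo_dir, file_out=os.path.join(demo_dir, 'demo.gif'))
```

The zero-padded filenames written by the training loop matter here: without padding, a plain sort would put `im_10.png` before `im_2.png`.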

##### Interactive training animation

We can also look at training interactively, which lets us zoom in on particular parts of the training process.

In [None]:
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(img.cpu().permute(1, 2, 0).numpy(), animated=True)] for img in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

IPython.display.HTML(ani.to_jshtml())

##### Side by side comparison

Here is our target image (left) side by side with our final predicted image (right).

In [None]:
# reshape from (16384, 3) back to (height, width, channels); .numpy() needs a CPU tensor
reconstructed_img = prediction.cpu().reshape(image_shape[0], image_shape[1], num_channels).numpy()
# scale the values from [0, 1] to [0, 255] and cast to uint8 for display
# (not done in place, so re-running this cell leaves `prediction` unchanged)
reconstructed_img = (reconstructed_img * 255).astype(np.uint8)
# looking at our creation next to the original!
fig, axes_array = plt.subplots(1,2, figsize=(20,10))
axes_array[0].imshow(resized_im)
axes_array[1].imshow(reconstructed_img)
plt.show()

##### Plot loss

Here is our training loss over time. What do you observe? Come back to this plot when you change the training hyperparameters to see whether it looks any different.

In [None]:
plt.figure(figsize=(10,5))
plt.title("Loss During Training")
plt.plot(losses,label="loss")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
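
Because each step uses a different random batch of pixels, the raw per-step loss is noisy. Also note that `nn.MSELoss(reduction='sum')` adds up the squared errors over the whole batch (`batch_size * 3` values), so the scale of the loss grows with `batch_size`; keep that in mind when comparing plots in Task 2. A moving average makes the trend easier to read. This small helper is a sketch, not part of the original notebook:

```python
import numpy as np

def moving_average(values, window):
    """Average each run of `window` consecutive values (a simple box filter)."""
    values = np.asarray(values, dtype=float)
    kernel = np.ones(window) / window
    # mode='valid' keeps only positions where the window fits entirely,
    # so the output has len(values) - window + 1 entries
    return np.convolve(values, kernel, mode='valid')

print(moving_average([1, 2, 3, 4, 5], window=2))  # [1.5 2.5 3.5 4.5]
```

In the plotting cell you could then add a line such as `plt.plot(moving_average(losses, window=1000), label="smoothed loss")`.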

### Tasks

**Task 1:** Write comments for [the hyperparameters](#define-hyperparameters) and [the code defining the network](#define-neural-network). If you get stuck on code you don't understand, don't get too bogged down: make a note and raise it as a question at the end of the session.

**Task 2:** Try [changing the hyperparameters](#define-hyperparameters) `learn_rate`, `momentum` and `batch_size` to see what effect they have. What is the highest learning rate you can use and still get a network that replicates the target image? What happens when you make the momentum or batch size very low?

**Task 3:** Try [loading a different image](#load-target-image) and see how the network does. Is the performance better or worse with a different image? Feel free to find your own image and load it into the code.

**Task 4:** Try [changing the architecture of the neural network](#define-neural-network). Add more layers, or increase (or decrease) the number of units in each fully connected layer. Change the activation functions to [one of the many other activation functions available in PyTorch](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions). How do all these things affect training?

**Task 5:** Try [changing the optimiser](#setup-core-objects) from SGD (stochastic gradient descent) to one of the [many other optimisers in PyTorch](https://pytorch.org/docs/stable/optim.html). Are there any new hyperparameters that you can adjust in your optimiser? What effect does changing these hyperparameters have?
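
If you want a scaffold for Task 5, a small hypothetical helper like `make_optimiser` below lets you switch optimisers from one place. `torch.optim.Adam` and `torch.optim.RMSprop` are real PyTorch optimisers; their extra hyperparameters are set to PyTorch's defaults here, and the SGD momentum matches this notebook's setting:

```python
import torch
import torch.nn as nn

def make_optimiser(name, params, learn_rate):
    """Hypothetical helper for Task 5: build an optimiser by name."""
    if name == 'sgd':
        return torch.optim.SGD(params, lr=learn_rate, momentum=0.9)
    if name == 'adam':
        # Adam adds betas: decay rates for its running averages of gradients
        return torch.optim.Adam(params, lr=learn_rate, betas=(0.9, 0.999))
    if name == 'rmsprop':
        # RMSprop adds alpha: the smoothing constant for squared gradients
        return torch.optim.RMSprop(params, lr=learn_rate, alpha=0.99)
    raise ValueError(f'unknown optimiser: {name}')

demo_model = nn.Linear(2, 3)
demo_opt = make_optimiser('adam', demo_model.parameters(), learn_rate=0.001)
print(type(demo_opt).__name__)  # Adam
```

In the setup cell this would replace the `torch.optim.SGD(...)` line with something like `optimiser = make_optimiser('adam', cppn.parameters(), learn_rate)`.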

#### Bonus tasks

**Task A:** Can you revisit the code from Week 1 and finish writing comments for all of the lines of code, based on what you have learnt from this week's lecture? Are there still any gaps?

**Task B:** Can you rewrite this notebook to use a custom dataset class and a dataloader? Then can you rewrite the training loop to use epochs instead of iterations? See [the PyTorch data loading tutorial](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html) for more details.
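
If you want a starting point for Task B, here is a minimal sketch of a `Dataset` that pairs each coordinate with its pixel value (`PixelDataset` is a hypothetical name). The random stand-in tensors below only have the same shapes as `all_xy_coordinates` and `all_pixel_values`; swap in the real ones:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PixelDataset(Dataset):
    """Hypothetical dataset pairing each (x, y) coordinate with its RGB value."""

    def __init__(self, coordinates, pixel_values):
        assert len(coordinates) == len(pixel_values)
        self.coordinates = coordinates
        self.pixel_values = pixel_values

    def __len__(self):
        # One sample per pixel
        return len(self.coordinates)

    def __getitem__(self, idx):
        return self.coordinates[idx], self.pixel_values[idx]

# Stand-in data with the same shapes as the notebook's tensors
coords_stub = torch.rand(16384, 2) * 2 - 1
pixels_stub = torch.rand(16384, 3)
loader = DataLoader(PixelDataset(coords_stub, pixels_stub), batch_size=100, shuffle=True)

# One epoch is one full pass over every pixel
seen = 0
for coordinates_batch, pixel_values_batch in loader:
    seen += len(coordinates_batch)  # forward pass, loss, backward and step go here
print(len(loader), seen)  # 164 16384
```

The last batch of each epoch holds only the 84 leftover pixels; pass `drop_last=True` to the `DataLoader` if you prefer every batch to be full.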

