## Train an image classifier using a pre-trained model

In this notebook we will look at how to train an image classifier in PyTorch, starting from a pre-trained model rather than from scratch.

**Before you go any further**, follow the instructions in the file `create-classification-dataset.ipynb` to make your dataset. You cannot train anything here without it.

Start with the other notebook, `train-image-classifier-from-scratch.ipynb`, before moving on to this one. Complete all the tasks there before working through this notebook.

Wait until the model in that notebook has finished training before running this one.

First, let's do some imports:

In [None]:
import torch
import torchvision

import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.utils as vutils
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split

If you get an import error for `sklearn`, you may need to uncomment the next line to install scikit-learn:

In [None]:
#!pip install scikit-learn

##### Hyperparameters

Now let's define our hyperparameters. **If your dataset has a different number of classes (for example, more than 3)**, make sure to change the parameter `num_classes`!

In [None]:
device = 'cpu'
momentum = 0.9
num_epochs = 10
num_classes = 3
val_size = 0.3
batch_size = 100
learn_rate = 0.001
freeze_lower_layers = True
data_path = '../data/my-data/my-classification-dataset/'

##### Training image transforms

Here we define our image transforms (and data augmentation) for our training data. There is a task in the other notebook to write comments for these. 

In [None]:
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomAffine(degrees=(-30,30),translate=(0.15,0.15),scale=(0.85,1.15)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
        transforms.RandomResizedCrop(size=(224, 224), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

##### Validation image transforms

Here we define our image transforms for our validation data. How do they differ from the training transforms?

In [None]:
val_transform = transforms.Compose(
    [
        transforms.Resize(224, antialias=True),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

##### Create datasets

Here we create our datasets. Because we are using different transforms for training and validation, we need to make two separate dataset objects over the same folder. We then split the indices of the data at random into two groups, and restrict each dataset to its own group.

When we do the split, by setting `random_state=42`, we are doing this in a deterministic way, such that we will always get the same 'random' split of data into the training and validation sets.

In [None]:
# Instantiate train and validation datasets with separate transforms
train_dataset = ImageFolder(data_path, transform=train_transform)
val_dataset = ImageFolder(data_path, transform=val_transform)

# Get length of dataset and indices
num_train = len(train_dataset)
indices = list(range(num_train))

# Get train / val split for data points
train_indices, val_indices = train_test_split(indices, test_size=val_size, random_state=42)

# Restrict each dataset to only the samples in its split
train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
val_dataset = torch.utils.data.Subset(val_dataset, val_indices)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
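
To see what the fixed `random_state` gives us, here is a minimal sketch that splits a made-up list of ten indices twice (the list is a stand-in for the real dataset indices). With the same `random_state`, both calls should return the same split, and the training and validation indices should never overlap.

```python
from sklearn.model_selection import train_test_split

# Stand-in for the dataset indices (the real list comes from len(train_dataset))
indices = list(range(10))

# Two independent calls with the same seed
first_train, first_val = train_test_split(indices, test_size=0.3, random_state=42)
second_train, second_val = train_test_split(indices, test_size=0.3, random_state=42)

print("same validation split both times:", first_val == second_val)
print("overlap between train and val:", set(first_train) & set(first_val))
print("total samples kept:", len(first_train) + len(first_val))
```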

##### Plot training images

Here we are plotting a sample of training images. See how the data augmentation transforms are manipulating the images, compared to the images in the validation set (visualised in the next cell).

In [None]:
# Plot some training images
real_batch = next(iter(train_loader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

##### Plot validation set images

In [None]:
# Plot some validation images
real_batch = next(iter(val_loader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Validation Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

#### Load pre-trained model

Instead of creating our own convolutional neural network from scratch, here we are downloading a pre-trained model from [the torchvision models library](https://pytorch.org/vision/stable/models.html). Here we are using a [ResNet](https://arxiv.org/abs/1512.03385) trained on [the imagenet dataset](https://www.image-net.org/), but feel free to change to one of the many other available models from the torchvision library if you want.

The boolean `freeze_lower_layers`, which you can change in the [hyperparameters](#hyperparameters) cell, determines whether we freeze the weights of most of the CNN. If this is set to `True` then we are performing **transfer learning** (only learning a new set of weights for the final layer). If this is set to `False` then we are performing **fine-tuning**, where we fine-tune the weights of the whole model from an initial set of pre-trained weights.

The following block of code is originally sourced from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

In [None]:
model = torchvision.models.resnet18(weights='IMAGENET1K_V1')

# Freeze weights
if freeze_lower_layers:
    for param in model.parameters():
        param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
model.to(device)

##### Set up loss and optimiser

In [None]:
# Cross entropy loss for training classification
criterion = nn.CrossEntropyLoss()

# Stochastic gradient descent optimiser (with momentum)
optimizer = optim.SGD(model.parameters(), lr=learn_rate, momentum=momentum)
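
For intuition about what `nn.CrossEntropyLoss` computes, here is a small sketch with made-up logits and labels. It takes the raw model outputs (no softmax applied beforehand) and integer class labels, and by default returns the mean negative log-probability assigned to the correct class. The manual calculation below should agree with it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up raw model outputs (logits) for 2 samples and 3 classes
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])  # integer class index for each sample

# Built-in loss, as used in the notebook
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)

# Manual version: log-softmax, pick the log-probability of the true class, average
log_probs = F.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(len(labels)), labels].mean()

print(f"built-in: {loss.item():.4f}, manual: {manual_loss.item():.4f}")
```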

##### Training loop

Here is the training loop for our data. Just like in the other notebook, look at how the training dataset and validation dataset are used.

What differences are there in the code when we cycle through each of these sets of data?

In [None]:
train_losses = []
val_losses = []

best_loss = 100000
for epoch in range(num_epochs): 
    train_loss = 0.0
    
    # Training loop: put the model in training mode
    model.train()
    for i, data in enumerate(train_loader, 0):
        # Get data
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Process data
        outputs = model(inputs)

        # Calculate loss
        loss = criterion(outputs, labels)

        # Update model weights
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()

    # Validation loop: switch to evaluation mode so batch norm uses its running statistics
    model.eval()
    with torch.no_grad():
        val_loss = 0.0
        for i, data in enumerate(val_loader, 0):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    
    # Average the accumulated losses over the number of batches
    train_loss = train_loss / len(train_loader)
    val_loss = val_loss / len(val_loader)
    
    # Add the average losses to lists for later display
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    print(f'Epoch {epoch + 1}, train loss: {train_loss:.3f}, val loss: {val_loss:.3f}')
    
    # If the validation loss is the lowest so far, save the model
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), 'best_finetuned_model.pt')
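
The saved file only contains the weights (the `state_dict`), so to use it later you first rebuild the same architecture and then load the weights into it. The sketch below shows that round trip and a simple accuracy calculation. It is illustrative only: a tiny `nn.Linear` stands in for the trained network, the batch is random data, and the file name `best_model_demo.pt` is made up for this example.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the trained network (the notebook uses a ResNet)
model = nn.Linear(4, 3)
checkpoint_path = os.path.join(tempfile.mkdtemp(), 'best_model_demo.pt')
torch.save(model.state_dict(), checkpoint_path)

# Rebuild the same architecture, then load the saved weights into it
restored = nn.Linear(4, 3)
restored.load_state_dict(torch.load(checkpoint_path))
restored.eval()

# Random stand-in batch: 8 samples with 4 features, labels from 3 classes
inputs = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))

# Accuracy: fraction of samples whose highest-scoring class matches the label
with torch.no_grad():
    predictions = restored(inputs).argmax(dim=1)
    accuracy = (predictions == labels).float().mean().item()

print(f"accuracy on the stand-in batch: {accuracy:.2f}")
```

In the notebook itself, the same pattern would apply to the ResNet and to batches drawn from `val_loader`.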

##### Plot training vs validation loss 

Let's plot our training vs validation loss over time. If you train for long enough you will usually see the validation loss start to get worse at some point while the training loss continues to get better. This occurs when the model starts to **overfit** to the training data and becomes worse at accurately classifying unseen data.

When our model is giving the best performance on our validation data (before the validation loss starts to increase) is when we would perform early stopping. If you haven't observed that here, you may need to re-run this code and train for more epochs. Given the limited time available in class, you may want to do this at home, running the code overnight (with your laptop plugged in!)
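
As a rough sketch of how early stopping could be automated, the function below scans a list of validation losses and reports the epoch at which training would stop once the loss has failed to improve for `patience` epochs in a row. The function name, the `patience` parameter and the example losses are all made up for illustration; this is not part of the training loop above.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return the index of the epoch where training would stop, or None if it never would."""
    best_loss = float('inf')
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: reset the counter
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Made-up validation losses: improving at first, then getting worse (overfitting)
example_losses = [1.00, 0.80, 0.65, 0.60, 0.62, 0.66, 0.71, 0.75]

stop_epoch = early_stopping_epoch(example_losses, patience=3)
best_epoch = example_losses.index(min(example_losses))

print(f"best epoch: {best_epoch}, training would stop at epoch: {stop_epoch}")
```

Because the notebook saves the model whenever the validation loss reaches a new minimum, the saved file already corresponds to the best epoch even if training carries on past it.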

**Compare the losses here with the other notebook.** How does the overall magnitude of the loss compare? Does this approach appear to learn faster or slower than the other notebook?

In [None]:
plt.figure(figsize=(10,5))
plt.title("Train vs validation loss")
plt.plot(train_losses,label="train")
plt.plot(val_losses,label="val")
plt.xlabel("epochs")
plt.ylabel("average loss per batch")
plt.legend()
plt.show()

### Tasks

**Task 1:** Run all the cells in this code to train a classifier on your custom dataset. 

**Task 2:** Once it has trained, compare the results with the notebook `train-image-classifier-from-scratch.ipynb`. Which approach gives you the lower overall loss, and which converges faster?

#### Bonus tasks
**Task A:** [Change the hyperparameter](#hyperparameters) `freeze_lower_layers` from `True` to `False` to perform **fine-tuning** instead of **transfer learning** and re-run the code (make a copy of this notebook if you want to have a direct comparison). How do the results compare? Is this approach noticeably slower?

**Task B:** [Change the CNN model loaded from the torchvision library](#load-pre-trained-model) from a ResNet18 to another ResNet model. How does that affect the training performance? (Again, you may want to make a copy of this notebook for a direct comparison.)

