In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from IPython import display
import matplotlib.pyplot as plt
import numpy as np
from enum import Enum

In [None]:
# Select the compute device: the first CUDA GPU when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'Running on: {device}')

In [None]:
# [STEP 1] - Getting dataset and creating dataloaders 

# Function to get the CIFAR10 dataset and create the dataloaders
# Batch size is set to 16 by default
def get_data(batch_size=16):
    """Create the CIFAR10 training and test dataloaders.

    Both datasets are requested from ``./data`` with ``download=True``.
    Training images go through a random horizontal flip and a random 32x32
    crop (padding 4) before being converted to tensors and normalised; test
    images are only converted and normalised with the same statistics.

    Args:
        batch_size: Number of samples per batch for both loaders (default 16).

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader shuffles
        its samples, the test loader does not.
    """
    # Per-channel mean / std shared by both pipelines so train and test inputs use the same scale
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)

    def tensor_steps():
        # Steps common to both pipelines: PIL image -> tensor -> normalised tensor
        return [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]

    # Augmentation is applied to the training split only
    augmentation_steps = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
    ]
    train_pipeline = transforms.Compose(augmentation_steps + tensor_steps())
    test_pipeline = transforms.Compose(tensor_steps())

    # Build dataset + loader for each split; only the training split is shuffled
    loaders = []
    for is_train, pipeline in ((True, train_pipeline), (False, test_pipeline)):
        split = CIFAR10(root='./data', train=is_train, download=True, transform=pipeline)
        loaders.append(DataLoader(split, batch_size=batch_size, shuffle=is_train))

    train_loader, test_loader = loaders
    return train_loader, test_loader

In [None]:
# [STEP 2] - Creating the model 

# Block
# This class defines the blocks of the backbone of the network
class Block(nn.Module):
  """Backbone block that mixes k parallel conv branches with input-dependent weights.

  A small head (spatial average pool -> flatten -> linear -> ReLU) turns the
  block input into k scalars per sample. Each of the k branches
  (Conv2d -> BatchNorm2d -> ReLU) is applied to the same input, scaled by its
  scalar, and the scaled branch outputs are summed. When ``max_pool`` is set
  the sum is finally passed through a 2x2, stride-2 max pooling layer.

  Every sub-module is moved to the module-level ``device`` as it is created.
  """

  def __init__(self,
               in_channels,     # Channels of the tensor entering the block
               out_channels,    # Channels produced by every conv branch
               conv_layer=3,    # Number of parallel conv branches (k)
               conv_kernel=3,   # Kernel size shared by all conv branches
               conv_stride=1,   # Stride shared by all conv branches
               conv_padding=1,  # Padding shared by all conv branches
               max_pool=False,  # Whether to max-pool the block output
               ):
    super().__init__()

    # k: how many conv branches (and therefore generated weights) the block has
    self.k = conv_layer

    # Weight-generation head: pool to 1x1, flatten, then map channels -> k weights
    self.sap = nn.AdaptiveAvgPool2d((1, 1)).to(device)
    self.Linear1 = nn.Linear(in_channels, self.k).to(device)

    # One in-place ReLU instance is reused for the generated weights and for every branch
    self.g = nn.ReLU(inplace=True).to(device)

    # Collapses the pooled (N, C, 1, 1) tensor so the linear layer can consume it
    self.flat = nn.Flatten().to(device)

    # Optional spatial down-sampling of the mixed branch output
    if max_pool:
      self.max_pool = nn.MaxPool2d(2, 2).to(device)
    else:
      self.max_pool = None

    # Conv branches are registered as conv0 .. conv{k-1}
    for idx in range(self.k):
      branch = nn.Conv2d(
          in_channels,
          out_channels,
          kernel_size=conv_kernel,
          stride=conv_stride,
          padding=conv_padding,
      )
      self.add_module(f'conv{idx}', branch.to(device))

    # One batch-norm layer per conv branch, matched by index
    self.bn_layers = nn.ModuleList(
        [nn.BatchNorm2d(out_channels).to(device) for _ in range(self.k)]
    )

  def forward(self, x):
    """Return the weighted sum of the conv branches applied to ``x`` (optionally max-pooled)."""
    # Per-sample branch weights: pool -> flatten -> linear -> ReLU, one column per branch
    gates = self.g(self.Linear1(self.flat(self.sap(x))))
    n_samples = gates.size(0)

    mixed = 0
    for idx, norm in enumerate(self.bn_layers):
      # Conv2d --> BatchNorm --> ReLU for branch idx
      branch = getattr(self, f'conv{idx}')
      activated = self.g(norm(branch(x)))
      # Reshape the idx-th weight column so it broadcasts over channels and space
      mixed += gates[:, idx].view(n_samples, 1, 1, 1) * activated

    # Reduce the spatial dimensions when the block was built with max pooling
    if self.max_pool is not None:
      mixed = self.max_pool(mixed)

    return mixed

# Network - Backbone and Classifier
# This class defines the network architecture, generating the backbone and classifier
class Net(nn.Module):
    """Classification network: a backbone of ``Block`` modules followed by a classifier head.

    Args:
        blocks: Iterable of 6-tuples, one per backbone block, in the order
            ``(conv_layer, out_channels, kernel, stride, padding, max_pool)``.
        num_classes: Number of logits produced by the classifier. Defaults to
            10, the number of CIFAR10 classes.
        in_channels: Number of channels of the input images. Defaults to 3.

    Attributes:
        num_classes: Number of output logits.
        out_channels: Output channels of the last backbone block (equal to
            ``in_channels`` when ``blocks`` is empty).
        backbone: ``nn.Sequential`` holding the blocks in order.
        classifier: Pooling + two-layer classifier head.
    """

    def __init__(self, blocks, num_classes=10, in_channels=3):
        super(Net, self).__init__()
        self.num_classes = num_classes

        # ---- GENERATING BLOCKS AND BACKBONE ----
        backbone_blocks = []
        # Channels fed to the next block; starts at the image channels and is
        # replaced by each block's output channels as the chain is built
        channel_pass = in_channels
        for conv_layer, out_channels, k, s, p, mxpool in blocks:
            backbone_blocks.append(
                Block(
                    channel_pass,
                    out_channels,
                    conv_layer=conv_layer,
                    conv_kernel=k,
                    conv_stride=s,
                    conv_padding=p,
                    max_pool=mxpool
                ).to(device)
            )
            channel_pass = out_channels

        self.out_channels = channel_pass  # Output channels of the last block

        # Chain the blocks so the backbone can be called as a single module
        self.backbone = nn.Sequential(*backbone_blocks)

        # ---- CREATING CLASSIFIER ----
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),       # Spatial average pool of the last block's output
            nn.Flatten(),                       # Flatten the pooled features for the linear layers
            nn.Linear(self.out_channels, 512),  # Hidden layer
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, self.num_classes),   # One logit per class
        )

    def forward(self, x):
        """Run ``x`` through the backbone and return the classifier logits."""
        return self.classifier(self.backbone(x))

In [None]:
# [STEP 4 - part of] - Calculating and tracking the loss and accuracy, and plotting graphs

# IPython magic: render matplotlib figures inline in the notebook output
%matplotlib inline

# Enum for the type of run - training or evaluation
class RunType(Enum):
    """Identifies which pass a statistic belongs to: training or evaluation."""

    TRAIN = 1  # Training pass
    EVAL = 2   # Evaluation (test / validation) pass

# Calculating and tracking the loss and accuracy, and plotting graphs
class Stats():
    """Tracks per-epoch loss and accuracy for training and evaluation, and plots them.

    Per-batch values are accumulated with ``update_loss_accuracy`` and folded
    into per-epoch averages by ``calculate_loss_accuracy``. ``update`` and
    ``finished`` append a summary line to ``print_list`` and redraw the graphs.

    Args:
        total_epochs: Planned number of epochs; used as the x-axis limit of
            the graphs while training is still running.
        target_acc: Evaluation accuracy at which ``has_reached_target``
            becomes true (default 0.91).
    """

    def __init__(self, total_epochs, target_acc=0.91):

        # Per-epoch history: one entry is appended at the end of every epoch
        self.train_loss = []
        self.train_acc = []
        self.test_loss = []   # Evaluation loss
        self.test_acc = []    # Evaluation accuracy

        self.total_epochs = total_epochs
        self.target = target_acc

        # Summary lines (accuracy, loss, learning rate) shown under the graphs
        self.print_list = []

        # Running loss, correct predictions and total predictions for training
        self.average_loss = 0.0
        self.running_loss = 0.0
        self.correct = 0
        self.total = 0

        # Running loss, correct predictions and total predictions for evaluation
        self.eval_average_loss = 0.0
        self.eval_running_loss = 0.0
        self.eval_correct = 0
        self.eval_total = 0

    @staticmethod
    def _batch_counts(y_hat, y):
        """Return ``(correct predictions, number of samples)`` for one batch."""
        _, pred = torch.max(y_hat, 1)  # Index of the largest score = predicted class
        return (pred == y).sum().item(), y.size(0)

    # Referenced from: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
    def update_loss_accuracy(self, run_type, run_loss, y_hat, y):
        """Accumulate one batch into the running counters of the given run type.

        Args:
            run_type: ``RunType.TRAIN`` or ``RunType.EVAL``; any other value is ignored.
            run_loss: Loss of the current batch as a plain number.
            y_hat: Predicted class scores, one row per sample.
            y: Ground-truth class labels.
        """
        if run_type == RunType.TRAIN:
            n_correct, n_seen = self._batch_counts(y_hat, y)
            self.running_loss += run_loss
            self.total += n_seen
            self.correct += n_correct
        elif run_type == RunType.EVAL:
            n_correct, n_seen = self._batch_counts(y_hat, y)
            self.eval_running_loss += run_loss
            self.eval_total += n_seen
            self.eval_correct += n_correct

    def calculate_loss_accuracy(self, run_type, num_batches):
        """Close the current epoch for ``run_type`` and return ``(avg_loss, avg_acc)``.

        The average loss is the accumulated loss divided by ``num_batches``; the
        accuracy is correct / total predictions. Both are appended to the
        matching history lists and the running counters are reset. For an
        unknown ``run_type`` nothing is recorded and ``(0, 0)`` is returned.
        """
        avg_loss, avg_acc = 0, 0
        if run_type == RunType.TRAIN:
            self.average_loss = self.running_loss / num_batches
            avg_loss = self.average_loss
            avg_acc = self.correct / self.total
            self.train_loss.append(avg_loss)
            self.train_acc.append(avg_acc)
            # Reset for the next epoch
            self.correct, self.total, self.running_loss = 0, 0, 0
        elif run_type == RunType.EVAL:
            self.eval_average_loss = self.eval_running_loss / num_batches
            avg_loss = self.eval_average_loss
            avg_acc = self.eval_correct / self.eval_total
            self.test_loss.append(avg_loss)
            self.test_acc.append(avg_acc)
            # Reset for the next epoch
            self.eval_correct, self.eval_total, self.eval_running_loss = 0, 0, 0

        return avg_loss, avg_acc

    def _plot_metric(self, metric, epochs, train_values, test_values, y_max, x_max, y_ticks=None):
        """Draw one training-vs-validation graph for ``metric`` and return its figure."""
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.plot(epochs, train_values, label=f'Training {metric}')
        ax.plot(epochs, test_values, label=f'Validation {metric}')
        ax.set_title(f'Training and Validation {metric}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)
        ax.legend()
        ax.set_ylim(0, y_max)
        ax.set_xlim(0, x_max)
        if y_ticks is not None:
            ax.set_yticks(y_ticks)
        ax.grid(True)
        return fig

    def print_stats(self, final_epoch=None):
        """Plot the loss and accuracy graphs and display the collected summary lines.

        Args:
            final_epoch: When given (training has finished), the x-axis ends at
                ``final_epoch + 10`` instead of ``total_epochs``; the extra 10
                is padding for better visualisation.
        """
        epochs = range(1, len(self.train_loss) + 1)
        x_max = self.total_epochs if final_epoch is None else final_epoch + 10

        loss_fig = self._plot_metric('Loss', epochs, self.train_loss, self.test_loss,
                                     y_max=3, x_max=x_max)
        acc_fig = self._plot_metric('Accuracy', epochs, self.train_acc, self.test_acc,
                                    y_max=1, x_max=x_max, y_ticks=np.arange(0, 1+0.1, 0.1))

        plt.show()
        # Close the figures explicitly so they do not pile up across epochs
        # when the active backend does not close them on show()
        plt.close(loss_fig)
        plt.close(acc_fig)

        display.display(self.print_list)
        # wait=True defers the clear until the next output arrives, so the
        # graphs stay visible until they are replaced
        display.clear_output(wait=True)

    def update(self, epoch, lr):
        """Record the summary line for ``epoch`` (with the next learning rate) and redraw."""
        self.print_list.append(f'Epoch: {epoch}, Loss: {self.average_loss}, Test Loss: {self.eval_average_loss}, Train Accuracy: {self.train_acc[-1]}, Test (Evaluation) Accuracy: {self.test_acc[-1]},  Next LR: {lr}')
        self.print_stats()

    def finished(self, final_epoch):
        """Record the final summary line and redraw the graphs with the final x-axis limit."""
        self.print_list.append(f'Finished training at epoch: {final_epoch} --> Evaluation accuracy: {self.test_acc[-1]*100}%')
        self.print_stats(final_epoch)

    def is_poor_performance(self):
        """Return True only when exactly one epoch is recorded and its training accuracy is below 0.2.

        Used to restart and reinitialise the model after a poor first epoch.
        Returns False when no epoch has been recorded yet.
        """
        # Length is checked first so an empty history cannot raise IndexError
        return len(self.train_acc) == 1 and self.train_acc[0] < 0.2

    def has_reached_target(self):
        """Return True when the latest evaluation accuracy is at least the target.

        Returns False when no evaluation has been recorded yet.
        """
        return bool(self.test_acc) and self.test_acc[-1] >= self.target

In [None]:
# [STEP 4 - part of] - Training and evaluation script

# Training the model        
def _train(net, train_iter, loss, optimiser, device, stats):
    """Run one training epoch over ``train_iter`` and record its statistics.

    Args:
        net: Model to train; switched to training mode first.
        train_iter: Iterable of ``(inputs, labels)`` batches that supports ``len``.
        loss: Loss function called as ``loss(predictions, labels)``.
        optimiser: Optimiser updating the parameters of ``net``.
        device: Device the batches are moved to before the forward pass.
        stats: Stats object receiving the per-batch and per-epoch values.
    """
    net.train()

    for inputs, targets in train_iter:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Clear the gradients left over from the previous batch
        optimiser.zero_grad()

        predictions = net(inputs)
        batch_loss = loss(predictions, targets)

        batch_loss.backward()
        optimiser.step()

        # Accumulate this batch's loss and correct-prediction count
        stats.update_loss_accuracy(RunType.TRAIN, batch_loss.item(), predictions, targets)

    # Fold the accumulated values into the epoch's average loss and accuracy
    n_batches = len(train_iter)
    stats.calculate_loss_accuracy(RunType.TRAIN, n_batches)

# Evaluate the model
# net - model, test_iter - evaluation data, loss - loss function, device - CPU or GPU, stats - stats object to keep track of loss, accuracy and learning rate
def _evaluate_accuracy(net, test_iter, loss, device, stats):
    """Evaluate ``net`` on ``test_iter`` and return the average evaluation loss.

    The model is switched to evaluation mode and no gradients are tracked.
    Per-batch loss and accuracy are accumulated in ``stats`` and closed into
    the epoch's evaluation values once all batches have been processed.
    """
    net.eval()

    # Gradients are not needed for evaluation
    with torch.no_grad():
        for inputs, targets in test_iter:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = net(inputs)
            batch_loss = loss(predictions, targets)

            # Accumulate this batch's evaluation loss and correct-prediction count
            stats.update_loss_accuracy(RunType.EVAL, batch_loss.item(), predictions, targets)

        # Close the epoch: average loss / accuracy over all evaluation batches
        epoch_loss, _ = stats.calculate_loss_accuracy(RunType.EVAL, len(test_iter))

    return epoch_loss

# Train Model
# net - model, train_iter - training data, test_iter - evaluation data, loss - loss function, optimiser - optimiser, scheduler - learning rate scheduler, device - CPU or GPU, epochs - number of epochs
def train_model(net, train_iter, test_iter, loss, optimiser, scheduler, device, epochs=60):
    """Train and evaluate ``net`` for up to ``epochs`` epochs.

    After every epoch the scheduler is stepped with the evaluation loss and
    the current statistics are displayed. Training stops early once the
    target evaluation accuracy tracked by ``Stats`` is reached.

    Returns:
        False when the first epoch's training accuracy is considered poor
        (the caller is expected to reinitialise and restart); True otherwise,
        after the final statistics have been displayed.
    """
    stats = Stats(epochs)  # Tracks loss, accuracy and the displayed summary lines
    last_epoch = 0         # Remembered for the final summary

    for epoch in range(epochs):
        last_epoch = epoch

        # One pass over the training data, then one over the evaluation data
        _train(net, train_iter, loss, optimiser, device, stats)
        epoch_eval_loss = _evaluate_accuracy(net, test_iter, loss, device, stats)

        # The scheduler adapts the learning rate from the evaluation loss
        scheduler.step(epoch_eval_loss)

        # Show the statistics of this epoch together with the next learning rate
        next_lr = scheduler.get_last_lr()
        stats.update(epoch, next_lr)

        # Reject the run when the initial training accuracy is poor
        # (the check can only be true after the first epoch)
        if stats.is_poor_performance():
            return False

        # Stop as soon as the target accuracy has been reached (minimum 91%)
        if stats.has_reached_target():
            break

    # Finished training - show the final statistics
    stats.finished(last_epoch)

    return True


In [None]:
# [STEP 3, 4, 5] -  Training

def start_training():
    """Build the data loaders, model, loss, optimiser and scheduler, then train.

    The backbone uses 6 blocks of 3 conv layers each (64, 64, 128, 128, 256,
    256 channels) with max pooling after every second block. Training runs for
    at most 80 epochs and stops early at the target accuracy.

    If ``train_model`` rejects a run because the first epoch's training
    accuracy is below 0.2, everything is discarded, rebuilt from scratch and
    training starts again. There is no limit on the number of restarts.
    """
    # 6 BLOCKS for the backbone
    # Parameters: num_conv_layer, conv_channels, conv_kernel, conv_stride, conv_padding, do_max_pooling
    blocks = (
        (3, 64, 3, 1, 1, False),
        (3, 64, 3, 1, 1, True),
        (3, 128, 3, 1, 1, False),
        (3, 128, 3, 1, 1, True),
        (3, 256, 3, 1, 1, False),
        (3, 256, 3, 1, 1, True),
    )

    batch_size = 16
    lr, wd = 0.0001, 1e-5  # Learning rate and weight decay
    max_epochs = 80        # Upper bound; training stops early at the target accuracy

    def _fresh_setup():
        """Create new data loaders, model, loss, optimiser and LR scheduler."""
        new_trainloader, new_testloader = get_data(batch_size)
        new_model = Net(blocks).to(device)
        new_loss = nn.CrossEntropyLoss().to(device)
        new_optimiser = torch.optim.Adam(new_model.parameters(), lr=lr, weight_decay=wd)
        new_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(new_optimiser, mode='min', patience=5)
        return new_trainloader, new_testloader, new_model, new_loss, new_optimiser, new_sched

    # [STEP 3] - Create loss and optimiser (plus model, data loaders and scheduler)
    trainloader, testloader, model, loss, optimiser, sched = _fresh_setup()

    # ------ IMPORTANT ------
    # Occasionally the model starts with an accuracy close to 0.1 and, rarely, stays stuck there.
    # The lower learning rate should prevent this, but it may still happen.
    # Custom weight initialisation made it worse, so PyTorch's default initialisation is used.
    # A first-epoch training accuracy of at least 20% is required as a safety margin;
    # otherwise this loop rebuilds everything and restarts the training.
    # If it keeps restarting, restart the kernel and run the training again.
    while not train_model(model, trainloader, testloader, loss, optimiser, sched, device, max_epochs):
        print('Poor performance start on first epoch training accuracy (< 0.2)')
        print('Reinitialising model, loss, optimiser, scheduler, data loaders')

        # Drop the old objects before building new ones to relieve memory (RAM / GPU VRAM)
        del model
        del loss
        del optimiser
        del sched
        del trainloader
        del testloader

        trainloader, testloader, model, loss, optimiser, sched = _fresh_setup()

        print('Reinitialisation complete, restarting training')

In [None]:
# Start training: runs start_training(), which builds the loaders, model, loss,
# optimiser and scheduler and trains until train_model reports success
start_training()