# Animal Identification
***
## Table of Contents
1. [Introduction](#1-introduction)
1. [Device-Agnostic Code](#2-device-agnostic-code)
1. [Loading Custom Dataset](#3-loading-custom-dataset)
1. [Understanding Data](#4-understanding-data)
1. [Data Preprocessing](#5-data-preprocessing)
    - [Normalisation](#normalisation)
    - [Preparing DataLoaders](#preparing-dataloaders)
1. [Convolutional Neural Network (CNN) Architectures](#6-convolutional-neural-network-cnn-architectures)
    - [ResNet-50](#resnet-50)
    - [Structure](#structure)
1. [Evaluation Metrics](#7-evaluation-metrics)
1. [Loss Function](#8-loss-function)
    - [Cross-Entropy Loss](#cross-entropy-loss)
1. [Optimiser](#9-optimiser)
1. [Training and Evaluation](#10-training-and-evaluation)
    - [Training Steps](#training-steps)
    - [Validation Steps](#validation-steps)
1. [Results (Custom ResNet-50)](#11-results-custom-resnet-50)
    - [Overall Performance](#overall-performance)
    - [Classifications](#classifications)
    - [Missclassifications](#missclassifications)
    - [Confusion Matrix](#confusion-matrix)
    - [Conclusion (Custom ResNet-50)](#conclusion-custom-resnet-50)
1. [Transfer Learning](#12-transfer-learning)
1. [Results (Pre-Trained ResNet-50)](#13-results-pre-trained-resnet-50)
    - [Overall Performance](#overall-performance)
    - [Classifications](#classifications)
    - [Missclassifications](#missclassifications)
    - [Confusion Matrix](#confusion-matrix)
    - [Conclusion (Pre-Trained ResNet-50)](#conclusion-pre-trained-resnet-50)
1. [References](#14-references)
***

In [None]:
import torch
import os
import random
from torch import nn
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
from typing import List
import pandas as pd
import seaborn as sns
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset
import matplotlib.pyplot as plt

## 1. Introduction

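## 2. Device-Agnostic Code
Device-agnostic code selects the best available hardware at runtime, so the same notebook runs on different machines. On Macs, GPU acceleration through the `mps` backend delivers a significant speed-up over the CPU for deep learning tasks, especially for large models and batch sizes. On machines with an NVIDIA GPU (e.g. on Windows or Linux), the `cuda` backend is used instead.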

In [None]:
# Set device: prefer an NVIDIA GPU (cuda), then Apple's Metal backend (mps), else the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device

## 3. Loading Custom Dataset
Retrieved from [Kaggle - Animal-10](https://www.kaggle.com/datasets/alessiocorrado99/animals10)

In [None]:
# Set a seed for reproducibility.
random_seed = 2
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)

In [None]:
data_path = Path("_datasets/animal-10/raw-img")

if data_path.is_dir():
    print(f"{data_path} directory exists.")
else:
    print(f"{data_path} directory does not exist!")

In [None]:
translation_it_en = {
    "cane": "dog",
    "cavallo": "horse",
    "elefante": "elephant",
    "farfalla": "butterfly",
    "gallina": "chicken",
    "gatto": "cat",
    "mucca": "cow",
    "pecora": "sheep",
    "ragno": "spider",
    "scoiattolo": "squirrel",
}

In [None]:
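# Sort the class folders so label indices are deterministic (alphabetical) on every OS and match
# the sorted class names used for plotting; skip stray files such as .DS_Store
class_dirs = sorted(path for path in data_path.iterdir() if path.is_dir())
classes_to_int = {path.name: i for i, path in enumerate(class_dirs)}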
classes_to_str = {i: name for i, name in enumerate(classes_to_int)}

In [None]:
def walk_through_dir(dir_path):
    for (
        directory_path,
        directory_names,
        file_names,
    ) in os.walk(dir_path):
        print(
            f"{len(directory_names)} directories and {len(file_names)} images found in {directory_path}"
        )


In [None]:
walk_through_dir(data_path)

In [None]:
class CustomDataset(Dataset):
    def __init__(self, img_path, transform=None) -> None:
        self.img_path = img_path
        self.transform = transform
        self.all_paths = [path for path in img_path.glob("*/*") if path.is_file()]

    def __len__(self):
        return len(self.all_paths)

    def __getitem__(self, index):
        single_file_path = self.all_paths[index]
        img = Image.open(single_file_path).convert("RGB")
        label = classes_to_int[single_file_path.parent.name]

        if self.transform:
            img = self.transform(img)
        return img, label


## 4. Understanding Data

In [None]:
def display_raw_samples(dataset: Dataset) -> None:
    figure = plt.figure(figsize=(9, 9))
    cols, rows = 3, 3
    for i in range(1, cols * rows + 1):
        sample_idx = torch.randint(0, len(dataset), (1,)).item()
        img, label = dataset[sample_idx]
        figure.add_subplot(rows, cols, i)
        plt.title(classes_to_str[label])
        plt.axis("off")
        plt.tight_layout()
        plt.imshow(img)
    plt.show()

In [None]:
custom_ds = CustomDataset(data_path)
display_raw_samples(custom_ds)

In [None]:
print(f"Data size: {len(custom_ds)}")

In [None]:
sample_idx = [0, 25, 50]
for i in sample_idx:
    img, label = custom_ds[i]
    img_array = np.array(img)
    print(f"Image of the index {i} has the shape {img_array.shape}")

Images come in different heights and widths (e.g., `(225, 300, 3)` is `(height, width, colour channels)`), so they must be resized to a common size before they can be stacked into batches.

In [None]:
# Read the class names from the folder names instead of loading and decoding every image
class_names_targets = [path.parent.name for path in custom_ds.all_paths]
class_names = np.unique(class_names_targets)
print(class_names)

The dataset contains 10 different animal categories with an imbalanced number of samples.

In [None]:
unique_vals, counts = np.unique(class_names_targets, return_counts=True)
df_dist = pd.DataFrame({"Class Label": unique_vals, "Count": counts})

plt.figure(figsize=(10, 6))
sns.barplot(data=df_dist, x="Class Label", y="Count", hue="Class Label", palette="Set2")
plt.xlabel("Class Label")
plt.ylabel("Count")
plt.title("Distribution of Class Labels")
plt.tight_layout()
plt.show()

## 5. Data Preprocessing

### Normalisation
For simple models (shallow networks, logistic regression, etc.), `ToTensor()` is often sufficient, as it rescales pixel values to the range $[0, 1]$. For deeper architectures such as ResNet, however, it is recommended to standardise the inputs so that each colour channel has zero mean and unit variance. Many pretrained models were trained on inputs standardised in this way, so this approach tends to yield better results than simple rescaling. Furthermore, centring inputs around zero generally leads to more stable training and faster convergence, particularly with activation functions such as tanh or with weight initialisation schemes that assume zero-centred inputs.
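To make the two steps concrete, the sketch below reproduces in plain NumPy what `transforms.ToTensor()` followed by `transforms.Normalize(mean, std)` do to a single image: reorder `(height, width, channels)` to `(channels, height, width)`, rescale 8-bit values to $[0, 1]$, then standardise each channel. The random image and the helper names `to_tensor_like` and `normalize_like` are illustrative assumptions, not part of the notebook's pipeline.

```python
import numpy as np

def to_tensor_like(img_hwc_uint8):
    """Mimic ToTensor(): HWC uint8 in [0, 255] -> CHW float32 in [0, 1]."""
    return img_hwc_uint8.transpose(2, 0, 1).astype(np.float32) / 255.0

def normalize_like(img_chw, mean, std):
    """Mimic Normalize(): (x - mean[c]) / std[c] for every channel c."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (img_chw - mean) / std

rng = np.random.default_rng(0)
# Hypothetical RGB image with the (height, width, channels) layout of a PIL image array
img = rng.integers(0, 256, size=(225, 300, 3)).astype(np.uint8)

x = to_tensor_like(img)
print(x.shape, float(x.min()), float(x.max()))

# Standardise with this image's own per-channel statistics
channel_mean = x.mean(axis=(1, 2))
channel_std = x.std(axis=(1, 2))
z = normalize_like(x, channel_mean, channel_std)
print(np.round(z.mean(axis=(1, 2)), 4), np.round(z.std(axis=(1, 2)), 4))
```

In the notebook, the statistics come from the whole training split rather than a single image, so individual images are only approximately zero-mean and unit-variance after normalisation.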

Let:
- $X_{n, c, h, w}$: Pixel value for image $n$, channel $c$, height $h$, and width $w$.
- $B$: Number of images in a batch.
- $C$: Number of channels (RGB = 3).
- $H, W$: Height and width of an image.
- $n_{\text{batches}}$: Number of batches.

For each batch of images with shape $\left[B, C, H, W\right]$ (here $C = 3$ and $H = W = 224$):
- Mean per channel:
$$
\mu_{\text{batch}, c} = \dfrac{1}{B \cdot H \cdot W}\sum^{B}_{n=1} \sum^{H}_{h=1} \sum^{W}_{w=1} X_{n, c, h, w}
$$

- Mean of squares per channel:

$$
s_{\text{batch}, c} = \dfrac{1}{B \cdot H \cdot W}\sum^{B}_{n=1} \sum^{H}_{h=1} \sum^{W}_{w=1} X^2_{n, c, h, w}
$$

Across all batches:
- Mean per channel:
$$
\mu_c = \dfrac{\sum_{\text{batches}} \mu_{\text{batch}, c}}{n_{\text{batches}}}
$$

Using the identity $\text{Var}(X) = E\left[(X - \mu)^2\right] = E\left[X^2\right] - (E\left[X\right])^2$:
- Standard deviation per channel:

\begin{align*}
\sigma_c &= \sqrt{E\left[X_c^2\right] - (E\left[X_c\right])^2} \\
 &= \sqrt{\dfrac{\sum_{\text{batches}} s_{\text{batch}, c}}{n_{\text{batches}}} - \mu_c^2}
\end{align*}
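As a numerical check of this derivation, the sketch below splits a synthetic array of shape `[N, 3, H, W]` into equal-size batches, combines the per-batch statistics exactly as in the formulas above (which is also what `get_mean_and_std()` does), and compares the result with the per-channel mean and standard deviation computed over all pixels at once. The data and the batch size of 8 are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(42)
data = rng.random((32, 3, 16, 16))  # synthetic "dataset": [N, C, H, W] with values in [0, 1)
batch_size = 8                      # 32 images -> 4 equal-size batches

batch_means, batch_sq_means = [], []
for start in range(0, len(data), batch_size):
    batch = data[start:start + batch_size]
    batch_means.append(batch.mean(axis=(0, 2, 3)))            # mu_{batch, c}
    batch_sq_means.append((batch ** 2).mean(axis=(0, 2, 3)))  # s_{batch, c}

mu = np.mean(batch_means, axis=0)                           # average of batch means
sigma = np.sqrt(np.mean(batch_sq_means, axis=0) - mu ** 2)  # sqrt(E[X^2] - mu^2)

# Reference: population statistics over every pixel of each channel at once
print(np.allclose(mu, data.mean(axis=(0, 2, 3))), np.allclose(sigma, data.std(axis=(0, 2, 3))))
```

The match is exact only because every batch has the same size. When the final batch is smaller, the plain average of batch means gives its pixels slightly more weight, so the result becomes a close approximation rather than the exact dataset statistic.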

In [None]:
def get_mean_and_std(train_loader):
    channel_sum, channel_squared_sum, n_batches = 0, 0, 0
    for data, _ in train_loader:
        # Shape: [batch_size, channel=3, height=*, width=*]
        channel_sum += torch.mean(data, dim=(0, 2, 3))
        channel_squared_sum += torch.mean(data**2, dim=(0, 2, 3))
        n_batches += 1
    mean = channel_sum / n_batches
    std = (channel_squared_sum / n_batches - mean**2) ** 0.5
    print(f"Mean: {mean.tolist()}\nStd: {std.tolist()}")
    return mean.tolist(), std.tolist()


### Preparing DataLoaders
`torch.utils.data.DataLoader()` splits a dataset into smaller chunks called **batches** (or **mini-batches**) and iterates over them. The size of each batch is controlled by the hyperparameter `batch_size`. Processing data in batches keeps memory usage manageable and allows the parameters to be updated once per batch rather than once per full pass over the dataset (epoch). This typically speeds up training compared with full-batch gradient descent, while averaging over a batch gives less noisy gradient estimates than single-sample updates.
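The number of batches per epoch is the number of samples divided by `batch_size`, rounded up, because `DataLoader` keeps the final, smaller batch by default (`drop_last=False`). The helper `num_batches` and the sample count of 1,000 below are hypothetical, chosen only to illustrate the arithmetic.

```python
import math

def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a map-style dataset."""
    if drop_last:
        return num_samples // batch_size  # the incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)  # the final batch may be smaller

print(num_batches(1000, 32))                  # 32: 31 full batches + 1 batch of 8
print(num_batches(1000, 32, drop_last=True))  # 31: only the full batches
```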

The `prepare_dataloaders()` function follows these steps:

1. Apply stratified splitting.
    - The dataset is partitioned into training (80%), validation (10%) and test (10%) subsets.
    - Stratified splitting ensures each class is proportionally represented in all subsets.
1. Calculate the mean and standard deviation of the training subset.
    - The mean and standard deviation are computed ONLY on the training split to avoid data leakage.
1. Define the normalisation and data augmentation transforms.
    - The computed mean and standard deviation are used to normalise all data subsets.
    - Data augmentation (here, random horizontal flipping) is applied only to the training data to increase data diversity and improve generalisation.
1. Create dataloaders.
    - Each split is wrapped in a `Subset` of a dataset that carries the appropriate transform.
    - Only the training dataloader shuffles its samples (`shuffle=True`); `SubsetRandomSampler` is used only when computing the training statistics in step 2.
    - Return the dataloaders.
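The notebook delegates step 1 to scikit-learn's `train_test_split(..., stratify=labels)`. To show what stratification achieves, the sketch below hand-rolls a simplified version in NumPy (not scikit-learn's exact allocation algorithm) and applies the same two-stage split to hypothetical imbalanced labels: first holding out 10% for testing, then taking $1/9$ of the remaining 90% for validation. The function name `stratified_split` is an assumption for illustration.

```python
import numpy as np

def stratified_split(labels, test_fraction, seed=0):
    """Split indices so every class keeps roughly the same share in both parts.

    A simplified illustration of stratification, not sklearn's exact algorithm.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    train_idx, test_idx = [], []
    for cls in np.unique(labels):
        cls_idx = np.flatnonzero(labels == cls)  # positions belonging to this class
        rng.shuffle(cls_idx)
        n_test = int(round(len(cls_idx) * test_fraction))
        test_idx.extend(cls_idx[:n_test].tolist())
        train_idx.extend(cls_idx[n_test:].tolist())
    return sorted(train_idx), sorted(test_idx)

# Hypothetical imbalanced labels: 700 samples of class 0, 300 of class 1
labels = [0] * 700 + [1] * 300

# Stage 1: hold out 10% of the whole dataset for testing
train_val, test = stratified_split(labels, test_fraction=0.1, seed=2)

# Stage 2: 1/9 of the remaining 90% equals 10% of the whole dataset
train_val_labels = [labels[i] for i in train_val]
train_pos, val_pos = stratified_split(train_val_labels, test_fraction=1 / 9, seed=2)
train = [train_val[i] for i in train_pos]  # map positions back to original indices
val = [train_val[i] for i in val_pos]

for name, idx in [("train", train), ("validation", val), ("test", test)]:
    share = sum(labels[i] == 1 for i in idx) / len(idx)
    print(f"{name}: {len(idx)} samples, class-1 share = {share:.2f}")
```

Each subset ends up with 800/100/100 samples and keeps the 70/30 class ratio of the full label list, which is the property stratification is meant to guarantee.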

In [None]:
def prepare_dataloaders(dataset, batch_size, img_size):
    # ! 1. Apply stratified splitting and divide the dataset into train (80%), validation (10%) and test (10%) subsets.
    # Split into: Training+Validation / Test
    labels = [classes_to_int[path.parent.name] for path in dataset.all_paths]
    train_val_indices, test_indices = train_test_split(
        list(range(len(labels))),
        test_size=0.1,
        stratify=labels,
        random_state=random_seed,
    )
    # Split into: Training / Validation
    train_val_labels = [labels[i] for i in train_val_indices]
    train_indices, validation_indices = train_test_split(
        train_val_indices,
        test_size=0.111111,  # 0.1/0.9 for 80/10/10
        stratify=train_val_labels,
        random_state=random_seed,
    )

    # ! 2. Calculate mean and std of train dataset.
    transform_for_stats = transforms.Compose(
        [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ]
    )
    stats_dataset = CustomDataset(data_path, transform_for_stats)
    stats_sampler = SubsetRandomSampler(train_indices)
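    # Pass the sampler by keyword: DataLoader's third positional parameter is `shuffle`,
    # so passing the sampler positionally would silently shuffle the WHOLE dataset instead
    stats_loader = DataLoader(stats_dataset, batch_size=batch_size, sampler=stats_sampler)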

    mean, std = get_mean_and_std(stats_loader)

    # ! 3. Prepare transformations for train and test subsets.
    train_transform = transforms.Compose(
        [
            transforms.Resize(size=(img_size, img_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    validation_transform = transforms.Compose(
        [
            transforms.Resize(size=(img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )

    train_dataset = CustomDataset(data_path, train_transform)
    validation_dataset = CustomDataset(data_path, validation_transform)
    test_dataset = CustomDataset(data_path, validation_transform)

    train_subset = Subset(train_dataset, train_indices)
    validation_subset = Subset(validation_dataset, validation_indices)
    test_subset = Subset(test_dataset, test_indices)

    # ! 4. Wrap each split in a Subset and create the dataloaders (only the training loader shuffles).
    train_dataloader = DataLoader(
        dataset=train_subset, batch_size=batch_size, shuffle=True
    )
    validation_dataloader = DataLoader(
        dataset=validation_subset, batch_size=batch_size, shuffle=False
    )
    test_dataloader = DataLoader(
        dataset=test_subset, batch_size=batch_size, shuffle=False
    )
    print(
        f"Length of train_dataloader: {len(train_indices)}/{batch_size} = {len(train_dataloader)}"
    )
    print(
        f"Length of validation_dataloader: {len(validation_indices)}/{batch_size} = {len(validation_dataloader)}"
    )
    print(
        f"Length of test_dataloader: {len(test_indices)}/{batch_size} = {len(test_dataloader)}"
    )
    return train_dataloader, validation_dataloader, test_dataloader

In [None]:
BATCH_SIZE = 32
IMAGE_SIZE = 224

train_dataloader, validation_dataloader, test_dataloader = prepare_dataloaders(
    custom_ds, BATCH_SIZE, IMAGE_SIZE
)

## 6. Convolutional Neural Network (CNN) Architectures
### ResNet-50
ResNet-50 is a deep convolutional neural network (CNN) from the Residual Network (ResNet) family. ResNets were developed to address the difficulty of training very deep networks: as plain networks grow deeper, gradients struggle to propagate (the *vanishing gradient problem*) and accuracy can degrade rather than improve. The core innovation of the ResNet architecture is the residual block with skip (shortcut) connections, which lets each block learn a residual mapping (the difference from its input) rather than a direct mapping; this stabilises training and makes much deeper architectures feasible.

ResNet-50 comprises 50 weight layers (convolutional and fully connected), making it substantially deeper and more expressive than shallower variants such as ResNet-18. The architecture consists of an initial convolutional layer, followed by four stages of bottleneck residual blocks (each using batch normalisation and rectified linear unit (ReLU) activations), global average pooling, and a final fully connected (FC) layer for classification. Its greater depth and capacity make ResNet-50 well suited to tasks that demand strong feature extraction, such as large-scale image classification.
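A short worked equation makes the role of the skip connection concrete. Following the usual residual-learning notation (an assumption of this sketch), let $\mathcal{H}(\mathbf{x})$ be the mapping a block should learn and $\mathcal{F}(\mathbf{x})$ the output of its stacked convolutional layers. For an identity shortcut, and ignoring the final ReLU applied after the addition, the block learns the residual $\mathcal{F}(\mathbf{x}) = \mathcal{H}(\mathbf{x}) - \mathbf{x}$ and outputs:

$$
\mathbf{y} = \mathcal{F}(\mathbf{x}) + \mathbf{x}
$$

By the chain rule, the Jacobian of the block output with respect to its input is:

$$
\dfrac{\partial \mathbf{y}}{\partial \mathbf{x}} = \dfrac{\partial \mathcal{F}}{\partial \mathbf{x}} + \mathbf{I}
$$

The identity term $\mathbf{I}$ gives the gradient a direct path through every block, so it does not vanish simply because $\partial \mathcal{F} / \partial \mathbf{x}$ is small.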

### Structure

| Layer/Block                    | Details (filter counts in brackets)                                                        |
|--------------------------------|--------------------------------------------------------------------------------------------|
| Initial Conv Layer             | 7×7, 64 filters, stride 2, padding 3                                                       |
| Batch Normalisation & ReLU     | Applied after the initial convolutional layer                                              |
| Max Pooling                    | 3×3, stride 2, padding 1                                                                   |
| Residual Stage 1 (Layer1)      | 3 bottleneck blocks: 1×1 (64), 3×3 (64), 1×1 (256) convolutions                            |
| Residual Stage 2 (Layer2)      | 4 bottleneck blocks: 1×1 (128), 3×3 (128), 1×1 (512); first block downsamples (stride 2)   |
| Residual Stage 3 (Layer3)      | 6 bottleneck blocks: 1×1 (256), 3×3 (256), 1×1 (1024); first block downsamples (stride 2)  |
| Residual Stage 4 (Layer4)      | 3 bottleneck blocks: 1×1 (512), 3×3 (512), 1×1 (2048); first block downsamples (stride 2)  |
| Global Average Pooling         | Reduces each feature map to 1×1                                                            |
| Fully Connected Layer          | Outputs class scores (1,000 for ImageNet; one per animal class in this notebook)          |

**Total**: 1 (initial conv) + (3 + 4 + 6 + 3) (bottleneck blocks) × 3 (convs per block) + 1 (FC) = 50 layers. The 1×1 projection convolutions on the shortcut paths are conventionally not counted.
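The spatial size of the feature maps can be traced through the network with the standard output-size formula $\lfloor (n + 2p - k) / s \rfloor + 1$ for input size $n$, kernel size $k$, padding $p$ and stride $s$. The sketch below applies it to a 224×224 input, assuming (as in the `BottleneckBlock` defined below) that each stage's downsampling stride sits on the 3×3 convolution of its first block; the helper name `conv_output_size` is hypothetical.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Output height/width of a convolution or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 224  # IMAGE_SIZE used in this notebook
trace = [("input", size)]

size = conv_output_size(size, kernel_size=7, stride=2, padding=3)  # initial 7x7 conv
trace.append(("7x7 conv, stride 2", size))
size = conv_output_size(size, kernel_size=3, stride=2, padding=1)  # 3x3 max pooling
trace.append(("3x3 max pool, stride 2", size))

# Only the 3x3 conv in the first block of each stage can change the size
# (stride 1 in stage 1, stride 2 in stages 2-4); 1x1 convs use stride 1 and padding 0.
for stage, stride in enumerate([1, 2, 2, 2], start=1):
    size = conv_output_size(size, kernel_size=3, stride=stride, padding=1)
    trace.append((f"stage {stage}", size))

for name, s in trace:
    print(f"{name:>24}: {s}x{s}")
```

The last stage therefore outputs 2048 feature maps of size 7×7, which global average pooling reduces to a 2048-dimensional vector for the fully connected layer. A 1×1 shortcut convolution with the same stride produces the same size, so the shortcut and main paths can be added.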

In [None]:
class BottleneckBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, shortcut, stride=1):
        super().__init__()
        self.conv_1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )  # 1x1, 1st conv
        self.bn_1 = nn.BatchNorm2d(mid_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv_2 = nn.Conv2d(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
        )  # 3x3, 2nd conv
        self.bn_2 = nn.BatchNorm2d(mid_channels)
        self.conv_3 = nn.Conv2d(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )  # 1x1, 3rd conv
        self.bn_3 = nn.BatchNorm2d(out_channels)

        if shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=stride,
                    padding=0,
                ),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        identity = self.shortcut(x)  # Shortcut path
        out = self.relu(self.bn_1(self.conv_1(x)))  # 1x1 conv + BN + ReLU
        out = self.relu(self.bn_2(self.conv_2(out)))  # 3x3 conv + BN + ReLU
        out = self.bn_3(self.conv_3(out))  # 1x1 conv + BN (Without ReLU)
        out += identity  # Add shortcut
        return self.relu(out)  # Final ReLU Activation


In [None]:
class ResNet50(nn.Module):
    def __init__(self, num_classes: int) -> None:
        super().__init__()
        # Layer 0: Conv -> MaxPool -> BN -> ReLU
        self.layer_0 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=64,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False,
            ),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
        )
        # 3 -> 4 -> 6 -> 3 BRBs.
        # Layer 1: 3 Bottleneck Residual Blocks, 256 filters
        self.layer_1 = nn.Sequential(
            BottleneckBlock(
                in_channels=64,
                mid_channels=64,
                out_channels=256,
                shortcut=True,
                stride=1,
            ),
            BottleneckBlock(
                in_channels=256,
                mid_channels=64,
                out_channels=256,
                shortcut=False,
                stride=1,
            ),
            BottleneckBlock(
                in_channels=256,
                mid_channels=64,
                out_channels=256,
                shortcut=False,
                stride=1,
            ),
        )
        # Layer 2: 4 Bottleneck Residual Blocks, 512 filters
        self.layer_2 = nn.Sequential(
            BottleneckBlock(
                in_channels=256,
                mid_channels=128,
                out_channels=512,
                shortcut=True,
                stride=2,
            ),
            BottleneckBlock(
                in_channels=512,
                mid_channels=128,
                out_channels=512,
                shortcut=False,
                stride=1,
            ),
            BottleneckBlock(
                in_channels=512,
                mid_channels=128,
                out_channels=512,
                shortcut=False,
                stride=1,
            ),
            BottleneckBlock(
                in_channels=512,
                mid_channels=128,
                out_channels=512,
                shortcut=False,
                stride=1,
            ),
        )
        # Layer 3: 6 Bottleneck Residual Blocks, 1024 filters
        self.layer_3 = nn.Sequential(
            BottleneckBlock(
                in_channels=512,
                mid_channels=256,
                out_channels=1024,
                shortcut=True,
                stride=2,
            ),
            BottleneckBlock(
                in_channels=1024,
                mid_channels=256,
                out_channels=1024,
                shortcut=False,
                stride=1,
            ),
            BottleneckBlock(
                in_channels=1024,
                mid_channels=256,
                out_channels=1024,
                shortcut=False,
                stride=1,
            ),
            BottleneckBlock(
                in_channels=1024,
                mid_channels=256,
                out_channels=1024,
                shortcut=False,
                stride=1,
            ),
            BottleneckBlock(
                in_channels=1024,
                mid_channels=256,
                out_channels=1024,
                shortcut=False,
                stride=1,
            ),
            BottleneckBlock(
                in_channels=1024,
                mid_channels=256,
                out_channels=1024,
                shortcut=False,
                stride=1,
            ),
        )
        # Layer 4: 3 Bottleneck Residual Blocks, 2048 filters
        self.layer_4 = nn.Sequential(
            BottleneckBlock(
                in_channels=1024,
                mid_channels=512,
                out_channels=2048,
                shortcut=True,
                stride=2,
            ),
            BottleneckBlock(
                in_channels=2048,
                mid_channels=512,
                out_channels=2048,
                shortcut=False,
                stride=1,
            ),
            BottleneckBlock(
                in_channels=2048,
                mid_channels=512,
                out_channels=2048,
                shortcut=False,
                stride=1,
            ),
        )
        # Classifier (AdaptiveAvgPool -> Flatten -> Dropout -> Fully connected layer)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(p=0.5),
            nn.Linear(2048, num_classes),
        )

    def forward(self, x):
        x = self.layer_0(x)
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = self.layer_4(x)
        x = self.classifier(x)
        return x


In [None]:
resnet_50 = ResNet50(num_classes=len(class_names)).to(device)

In [None]:
from torchinfo import summary

summary(
    resnet_50,
    input_size=(
        1,
        3,
        IMAGE_SIZE,
        IMAGE_SIZE,
    ),  # (batch_size, colour channels, height, width)
    verbose=0,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

## 7. Evaluation Metrics
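We use the following evaluation metrics:
- `torchmetrics.Accuracy`: the proportion of correctly classified images.
- `torchmetrics.F1Score` with `average="macro"`: the unweighted mean of the per-class F1 scores. Because every class counts equally, it is more informative than accuracy on an imbalanced dataset such as this one.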
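The sketch below computes both metrics by hand with NumPy on hypothetical labels, to show what `torchmetrics` reports: accuracy is the share of correct predictions, and macro F1 averages the per-class F1 scores without weighting by class size. It is an illustration only; `torchmetrics`' handling of edge cases (e.g. a class that never appears) may differ from the zero-score convention assumed here.

```python
import numpy as np

def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))

def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 scores, using F1 = 2TP / (2TP + FP + FN)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    scores = []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))  # predicted c, truly c
        fp = np.sum((y_pred == c) & (y_true != c))  # predicted c, truly another class
        fn = np.sum((y_pred != c) & (y_true == c))  # truly c, predicted another class
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom > 0 else 0.0)  # assumed 0 if class absent
    return float(np.mean(scores))

# Hypothetical labels for 3 classes
y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 2]
print(round(accuracy(y_true, y_pred), 4))     # 5 of 6 correct
print(round(macro_f1(y_true, y_pred, 3), 4))  # mean of per-class F1: 0.8, 0.8, 1.0
```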

In [None]:
from torchmetrics import Accuracy, F1Score

n_classes = len(class_names)

calculate_accuracy = Accuracy(task="multiclass", num_classes=n_classes).to(device)

calculate_f1 = F1Score(task="multiclass", num_classes=n_classes, average="macro").to(
    device
)

metrics = [calculate_accuracy, calculate_f1]

## 8. Loss Function
### Cross-Entropy Loss
Cross-entropy loss is the standard loss function for classification problems in which the model's raw outputs (logits) are turned into class probabilities with a softmax function. It measures the difference between the true label distribution and the predicted probability distribution.

For a single data point, the cross-entropy loss is defined as:

\begin{align*}
    L = - \sum^{k}_{i=1}y_{i}\log{(\hat y_{i})}
\end{align*}

where:
- $y_i$: True label for the $i$-th class. With one-hot encoding, $y_{i} = 1$ for the correct class and $y_{i} = 0$ otherwise.
- $\hat y_i$: Predicted probability for the $i$-th class.
- $k$: Number of classes.

For a batch of $m$ data points:

\begin{align*}
    C = \dfrac{1}{m} \sum^{m}_{j=1} \left (- \sum^{k}_{i=1}y_{j, i}\log{(\hat y_{j, i})} \right)
\end{align*}

where:
- $C$: Average cross-entropy loss over the batch.
- $m$: Number of training examples (batch size).
- $k$: Number of classes.
- $y_{j, i} \in \{0, 1\}$: Indicator that the true class of sample $j$ is class $i$.
- $\hat y_{j, i} \in [0, 1]$: Predicted probability that sample $j$ belongs to class $i$.

In PyTorch:
- Use `nn.CrossEntropyLoss()` directly with raw logits.
- Do not apply `Softmax()` or `LogSoftmax()` manually before the loss.
- Internally, `nn.CrossEntropyLoss()` is equivalent to `nn.LogSoftmax(dim=1)` followed by `nn.NLLLoss()` (negative log-likelihood loss).
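To see the equivalence in numbers, the NumPy sketch below (a re-implementation for illustration, not PyTorch's code) computes the batch loss $C$ in two ways on hypothetical logits: by taking a log-softmax and picking the entry of the true class (the `LogSoftmax` + `NLLLoss` route), and by applying the one-hot formula for $C$ directly. Subtracting the row maximum before exponentiating is the standard trick for numerical stability.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the class dimension (axis=1)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def cross_entropy(logits, targets):
    """Mean over the batch of -log(predicted probability of the true class)."""
    log_probs = log_softmax(logits)
    return -log_probs[np.arange(len(targets)), targets].mean()

# Hypothetical raw logits for a batch of m=2 samples and k=3 classes
logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.3]])
targets = np.array([0, 1])  # true class indices

# Same value via the one-hot formula: C = (1/m) * sum_j ( -sum_i y_ji * log(y_hat_ji) )
probs = np.exp(log_softmax(logits))
one_hot = np.eye(3)[targets]
loss_one_hot = np.mean(-np.sum(one_hot * np.log(probs), axis=1))

print(cross_entropy(logits, targets), loss_one_hot)
```

Because the one-hot vector zeroes every term except the true class, the inner sum collapses to a single $-\log \hat y$ per sample, which is exactly the indexing performed by `cross_entropy` above.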

In [None]:
loss_function = nn.CrossEntropyLoss()

## 9. Optimiser
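An optimiser adjusts the parameters (weights and biases) of a model during training to minimise the loss; without one, the model would not learn from the data. Here, stochastic gradient descent (SGD) is used with an initial learning rate of 0.005, momentum of 0.9 and weight decay (L2 regularisation) of 0.0005. A `ReduceLROnPlateau` scheduler halves the learning rate (`factor=0.5`) whenever the validation loss fails to improve for more than 3 consecutive epochs (`patience=3`).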

In [None]:
optimiser = torch.optim.SGD(
    resnet_50.parameters(), lr=0.005, weight_decay=0.0005, momentum=0.9
)
scheduler = ReduceLROnPlateau(optimiser, mode="min", factor=0.5, patience=3)
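The sketch below spells out one parameter update of SGD with momentum and weight decay, following the update rule given in the PyTorch documentation for `torch.optim.SGD` with its default `dampening=0` and `nesterov=False`. The toy objective $f(\theta) = \theta^2$ and the function name `sgd_momentum_step` are illustrative assumptions; the hyperparameters match the cell above.

```python
def sgd_momentum_step(theta, grad, velocity, lr, momentum, weight_decay):
    """One SGD step with momentum and L2 weight decay (dampening=0, no Nesterov)."""
    g = grad + weight_decay * theta     # weight decay adds lambda * theta to the gradient
    velocity = momentum * velocity + g  # momentum buffer accumulates past gradients
    theta = theta - lr * velocity       # step against the accumulated direction
    return theta, velocity

# Toy problem: minimise f(theta) = theta**2, whose gradient is 2 * theta
theta, velocity = 5.0, 0.0
history = [theta]
for _ in range(200):
    theta, velocity = sgd_momentum_step(
        theta, grad=2 * theta, velocity=velocity, lr=0.005, momentum=0.9, weight_decay=0.0005
    )
    history.append(theta)

print(f"after 1 step: {history[1]:.7f} | after {len(history) - 1} steps: {history[-1]:.2e}")
```

With a zero-initialised velocity, the first step is plain gradient descent ($5 - 0.005 \times 10.0025 = 4.9499875$); later steps reuse 90% of the previous velocity, which speeds progress along consistent gradient directions.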

## 10. Training and Evaluation
1. Iterate through epochs.
1. For each epoch, iterate through the training batches, perform the training steps, and calculate the training loss and evaluation metrics.
1. For each epoch, iterate through the validation batches, perform the validation steps, and calculate the validation loss and evaluation metrics.
1. Store the results and update the learning-rate scheduler with the validation loss.

### Training Steps
1. Forward Pass
    - Pass inputs through the model to obtain predictions.
1. Calculate Loss Per Batch
    - Measure how far the predictions deviate from the true labels, using a loss function.
1. Zero the Gradients
    - Clear the previously stored gradients to prevent accumulation from multiple backward passes.
1. Backward Pass
    - Compute the gradients of the loss with respect to the model's parameters via backpropagation.
1. Optimiser Step
    - Update the parameters $\theta$ using the gradients just computed, typically following an update rule such as:
    $$
        \theta \leftarrow \theta - \eta \dfrac{\partial \mathcal{L}}{\partial \theta}
    $$
    where $\eta$ is the learning rate.
1. Average Training Loss
    - Compute the mean training loss across all batches in the epoch.

In [None]:
def train_step(
    model: nn.Module,
    data_loader: DataLoader,
    loss_function: nn.Module,
    optimiser: torch.optim.Optimizer,
    metrics: List[nn.Module],
    device: torch.device = device,
):
    model.to(device)
    model.train()  # Set the model to training mode
    train_loss = 0
    for metric in metrics:
        metric.reset()
    for X, y in tqdm(data_loader):
        X, y = X.to(device), y.to(device)

        # 1. Forward pass
        y_pred = model(X)

        # 2. Calculate loss per batch
        loss = loss_function(y_pred, y)
        train_loss += loss.item()

        for metric in metrics:
            metric.update(y_pred, y)

        # 3. Optimiser zero grad
        optimiser.zero_grad()

        # 4. Loss backward
        loss.backward()

        # 5. Optimiser step
        optimiser.step()

    # Divide total train loss by length of train dataloader (average per batch per epoch)
    train_loss /= len(data_loader)
    # Compute every metric passed in (here: accuracy, F1) as a percentage
    train_metrics = [metric.compute().item() * 100 for metric in metrics]
    train_acc = train_metrics[0]
    print(f"Train loss: {train_loss:.5f} | Train accuracy: {train_acc:.2f}%")
    return train_loss, train_metrics

### Validation Steps
1. Forward pass
    - Set the model to evaluation mode with `model.eval()`, which disables dropout and makes batch normalisation use its running statistics. Gradient tracking is disabled separately by running the loop inside `torch.inference_mode()`.
    - Pass inputs through the model to obtain predictions.
1. Calculate Loss Per Batch
    - Measure how far the predictions deviate from the true labels, using a loss function.
1. Update and Compute Metrics
    - Update the metric states (accuracy and F1 score) with each batch, and compute the overall metrics after all validation batches.
1. Average Validation Loss
    - Compute the mean validation loss across all batches for the epoch.
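`model.eval()` and `torch.inference_mode()` play different roles: the first switches layers such as dropout to their inference behaviour, the second stops autograd from recording operations. The sketch below illustrates both on a standalone `nn.Dropout` and `nn.Linear` layer; the 8-element input and the layer sizes are arbitrary assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

dropout.train()         # training mode: each element is zeroed with probability 0.5,
train_out = dropout(x)  # and surviving elements are scaled by 1 / (1 - p) = 2
dropout.eval()          # evaluation mode: dropout becomes the identity function
eval_out = dropout(x)
print(train_out, eval_out)

# eval() does NOT disable gradient tracking; inference_mode() does
layer = nn.Linear(8, 2)
layer.eval()
print(layer(x).requires_grad)  # True: autograd still records the operation
with torch.inference_mode():
    print(layer(x).requires_grad)  # False: no computation graph is built
```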

In [None]:
def validation_step(
    model: nn.Module,
    data_loader: DataLoader,
    loss_function: nn.Module,
    metrics: List[nn.Module],
    device: torch.device = device,
):
    model.to(device)
    model.eval()
    validation_loss = 0
    for metric in metrics:
        metric.reset()
    with torch.inference_mode():
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)

            # 1. Forward pass
            validation_pred = model(X)

            # 2. Calculate loss
            validation_loss += loss_function(validation_pred, y).item()

            # 3. Calculate metrics
            for metric in metrics:
                metric.update(validation_pred, y)

        # 4. Average the validation loss over all batches
        validation_loss /= len(data_loader)
    # Compute every metric passed in (here: accuracy, F1) as a percentage
    validation_metrics = [metric.compute().item() * 100 for metric in metrics]
    validation_acc = validation_metrics[0]
    print(
        f"Validation loss: {validation_loss:.5f} | Validation accuracy: {validation_acc:.2f}%\n"
    )
    return validation_loss, validation_metrics

In [None]:
EPOCHS = 10
epochs_range = range(1, EPOCHS + 1)
train_losses, train_accuracies, train_f1s = (
    [],
    [],
    [],
)
validation_losses, validation_accuracies, validation_f1s = (
    [],
    [],
    [],
)

for epoch in epochs_range:
    print(f"Epoch: {epoch}\n==========")
    train_loss, train_metrics = train_step(
        data_loader=train_dataloader,
        model=resnet_50,
        loss_function=loss_function,
        optimiser=optimiser,
        metrics=metrics,
        device=device,
    )
    train_losses.append(train_loss)
    train_accuracies.append(train_metrics[0])
    train_f1s.append(train_metrics[1])

    validation_loss, validation_metrics = validation_step(
        data_loader=validation_dataloader,
        model=resnet_50,
        loss_function=loss_function,
        metrics=metrics,
        device=device,
    )
    validation_losses.append(validation_loss)
    validation_accuracies.append(validation_metrics[0])
    validation_f1s.append(validation_metrics[1])

    scheduler.step(validation_loss)

## 11. Results (Custom ResNet-50)
### Overall Performance

In [None]:
train_metrics = {
    "Loss": train_losses,
    "Accuracy": train_accuracies,
    "F1 Score": train_f1s,
}

validation_metrics = {
    "Loss": validation_losses,
    "Accuracy": validation_accuracies,
    "F1 Score": validation_f1s,
}

In [None]:
def plot_results(epochs_range, train_metrics, validation_metrics) -> None:
    metric_names = list(train_metrics.keys())
    n_metrics = len(metric_names)
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
    axes = axes.flatten()

    for i, metric_name in enumerate(metric_names):
        ax = axes[i]
        ax.plot(
            epochs_range, train_metrics[metric_name], label=f"Train {metric_name}"
        )  # Train metric
        ax.plot(
            epochs_range,
            validation_metrics[metric_name],
            label=f"Validation {metric_name}",
        )  # Validation metric
        ax.set_title(f"{metric_name} Over Epochs", fontsize=15)
        ax.legend()
        ax.set_xlabel("Epoch")
        if metric_name == "Loss":
            ax.set_ylabel("Loss")
        else:
            ax.set_ylabel(f"{metric_name} (%)")

    if n_metrics < len(axes):
        for j in range(n_metrics, len(axes)):
            plt.delaxes(axes[j])
    plt.tight_layout()
    plt.show()

In [None]:
plot_results(epochs_range, train_metrics, validation_metrics)

### Classifications

In [None]:
from torch.utils.data import DataLoader


def make_all_predictions(
    model: nn.Module,
    data_loader: DataLoader,
    loss_function: nn.Module,
    calculate_accuracy: nn.Module,
    device: torch.device = device,
):
    """Run `model` over `data_loader`, print the average loss and accuracy,
    and return all predicted and true labels as CPU tensors."""
    y_preds = []
    y_labels = []
    model.to(device)
    model.eval()  # Disable dropout and use BatchNorm running statistics
    test_loss = 0
    calculate_accuracy.reset()  # Clear state left over from earlier evaluations
    with torch.inference_mode():
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)

            # 1. Forward pass: logits -> probabilities -> predicted class
            test_pred = model(X)
            y_prob = torch.softmax(test_pred, dim=1)
            y_pred = y_prob.argmax(dim=1)
            y_preds.append(y_pred.cpu())
            y_labels.append(y.cpu())

            # 2. Accumulate the batch loss
            test_loss += loss_function(test_pred, y).item()

            # 3. Update the running accuracy
            calculate_accuracy.update(test_pred, y)

    # 4. Average the loss over batches and compute the final accuracy
    test_loss /= len(data_loader)
    test_acc = calculate_accuracy.compute().item() * 100
    print(f"Test loss: {test_loss:.5f} | Test accuracy: {test_acc:.2f}%\n")
    y_preds_tensor = torch.cat(y_preds)
    y_labels_tensor = torch.cat(y_labels)
    return y_preds_tensor, y_labels_tensor
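`make_all_predictions` applies `torch.softmax` before `argmax`. Because softmax is strictly increasing, it does not change which class scores highest; the probabilities are mainly useful as confidence scores, for example to look for confidently wrong predictions. The sketch below illustrates this with a hypothetical helper, `predicted_labels`, on synthetic logits (not outputs of the notebook's model).

```python
from typing import Tuple

import torch


def predicted_labels(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (predicted labels, confidences) for logits of shape (N, C)."""
    probs = torch.softmax(logits, dim=1)
    confidences, labels = probs.max(dim=1)  # Highest probability and its class index
    return labels, confidences


demo_gen = torch.Generator().manual_seed(0)  # Local generator: global RNG state is untouched
demo_logits = torch.randn(4, 10, generator=demo_gen)  # Synthetic logits: 4 images, 10 classes
demo_pred_labels, demo_confidences = predicted_labels(demo_logits)
# Softmax preserves the ordering of the logits, so the top class is unchanged
print(torch.equal(demo_pred_labels, demo_logits.argmax(dim=1)))  # True
print(demo_confidences)
```

Inside `make_all_predictions`, the same `y_prob.max(dim=1)` call could be used to collect a confidence score alongside each prediction.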

In [None]:
all_preds, all_labels = make_all_predictions(
    model=resnet_50,
    data_loader=test_dataloader,
    loss_function=loss_function,
    calculate_accuracy=calculate_accuracy,
    device=device,
)
print(all_preds)

### Missclassifications

In [None]:
wrong_indices = (all_preds != all_labels).nonzero(as_tuple=True)[0]
print(
    f"Number of failed predictions: {len(wrong_indices)}/{len(all_labels)} ({100.0 * len(wrong_indices) / len(all_labels):.2f}%)"
)
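The overall error rate does not show which classes are confused with each other. A minimal sketch, assuming `preds` and `labels` are 1-D integer tensors like `all_preds` and `all_labels`, that counts the most frequent (true, predicted) pairs among the errors. `top_confusions` is a hypothetical helper and the tensors below are synthetic, not results from this notebook.

```python
from collections import Counter
from typing import List, Tuple

import torch


def top_confusions(
    preds: torch.Tensor, labels: torch.Tensor, names: List[str], k: int = 5
) -> List[Tuple[str, str, int]]:
    """Return the k most frequent (true class, predicted class, count) error pairs."""
    wrong = preds != labels  # Boolean mask of misclassified samples
    pairs = Counter(zip(labels[wrong].tolist(), preds[wrong].tolist()))
    return [(names[t], names[p], n) for (t, p), n in pairs.most_common(k)]


# Synthetic example with three of the dataset's label names
demo_class_names = ["cane", "cavallo", "farfalla"]
demo_labels = torch.tensor([0, 0, 1, 1, 2, 2, 2])
demo_preds = torch.tensor([0, 1, 1, 2, 2, 1, 1])
for true_name, pred_name, count in top_confusions(demo_preds, demo_labels, demo_class_names, k=3):
    print(f"{true_name} -> {pred_name}: {count}")
```

In the notebook, `top_confusions(all_preds, all_labels, class_names)` could list the most common error pairs directly, as a complement to reading them off the confusion matrix.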

### Confusion Matrix

In [None]:
from torchmetrics import ConfusionMatrix
from mlxtend.plotting import plot_confusion_matrix

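# Both are already CPU tensors returned by make_all_predictions;
# the clone is a defensive copy so later changes cannot alter all_preds
true_labels = all_labels
pred_tensor = all_preds.detach().clone()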

# Compute confusion matrix
conf_matrix = ConfusionMatrix(num_classes=len(class_names), task="multiclass")
conf_matrix_tensor = conf_matrix(pred_tensor, true_labels)

# Plot
fig, ax = plot_confusion_matrix(
    conf_mat=conf_matrix_tensor.numpy(), class_names=class_names, figsize=(10, 7)
)
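Accuracy and per-class scores can be read straight from the confusion matrix. In torchmetrics' `ConfusionMatrix`, rows correspond to true labels and columns to predicted labels, so the diagonal holds correct predictions, row sums give the number of images per true class (the recall denominator) and column sums give the number of predictions per class (the precision denominator). A sketch with a hypothetical helper, `metrics_from_confusion_matrix`, on a synthetic 2×2 matrix:

```python
import torch


def metrics_from_confusion_matrix(conf_mat: torch.Tensor):
    """Derive accuracy and per-class recall/precision from a confusion matrix
    whose rows are true classes and whose columns are predicted classes."""
    conf_mat = conf_mat.float()
    correct = conf_mat.diag()  # Correct predictions per class
    accuracy = (correct.sum() / conf_mat.sum()).item()
    recall = correct / conf_mat.sum(dim=1).clamp(min=1)  # Divide by true-class totals
    precision = correct / conf_mat.sum(dim=0).clamp(min=1)  # Divide by predicted-class totals
    return accuracy, recall, precision


demo_conf_mat = torch.tensor([[8, 2], [1, 9]])  # Synthetic counts
demo_acc, demo_recall, demo_precision = metrics_from_confusion_matrix(demo_conf_mat)
print(f"Accuracy: {demo_acc:.2f}")
print("Recall per class:", demo_recall.tolist())
print("Precision per class:", demo_precision.tolist())
```

Applied to `conf_matrix_tensor`, the accuracy should match the printed test accuracy when `calculate_accuracy` uses micro averaging, and the per-class recall pairs up with `class_names` index by index.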


### Conclusion (Custom ResNet-50)
After 10 epochs, our custom ResNet-50 model achieved an accuracy of $75.82$% on the test dataset (misclassification rate of $24.18$%). The confusion matrix shows how correct and incorrect predictions are distributed across the classes.

A detailed inspection reveals the following misclassification patterns:

- Horses (cavallo) were sometimes misclassified as butterflies (farfalla).
- Dogs (cane) were occasionally mistaken for other animals, such as chickens (gallina), cats (gatto), cows (mucca) and sheep (pecora).

Around epoch 10, the model began to overfit: training accuracy kept increasing while validation accuracy fluctuated or decreased.
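One way to pin down the onset of overfitting is to find the epoch with the best validation accuracy and track the gap between training and validation accuracy. A minimal sketch with a hypothetical `overfitting_report` helper, assuming the per-epoch metrics are plain numbers and that epochs are numbered from `first_epoch`; the values below are synthetic.

```python
from typing import List, Tuple


def overfitting_report(
    train_acc: List[float], val_acc: List[float], first_epoch: int = 1
) -> Tuple[int, List[float]]:
    """Return the epoch with the highest validation accuracy and the
    per-epoch gap between training and validation accuracy."""
    best_idx = max(range(len(val_acc)), key=lambda i: val_acc[i])
    gaps = [t - v for t, v in zip(train_acc, val_acc)]
    return first_epoch + best_idx, gaps


# Synthetic accuracies in % (not this notebook's results)
demo_train_acc = [50.0, 60.0, 70.0, 80.0, 90.0]
demo_val_acc = [48.0, 57.0, 65.0, 64.0, 63.0]
demo_best_epoch, demo_gaps = overfitting_report(demo_train_acc, demo_val_acc)
print(f"Best validation accuracy at epoch {demo_best_epoch}")  # 3
print("Train - validation gap:", demo_gaps)  # [2.0, 3.0, 5.0, 16.0, 27.0]
```

With this notebook's lists it could be called as `overfitting_report(train_accuracies, validation_accuracies, first_epoch=epochs_range[0])`, assuming `epochs_range` is indexable (a `range` or a list). A widening gap after the best epoch is the pattern described above.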

## 12. Transfer Learning
Transfer learning is the use of a model pre-trained on one task as the starting point for a different but related task. Starting from pre-trained weights can improve performance while reducing both the amount of training data and the training time required. Here, we start from a ResNet-50 pre-trained on ImageNet.

Setting `param.requires_grad = False` for every parameter (weights and biases) of the pre-trained ResNet-50 tells PyTorch not to compute gradients for them, so the optimiser never updates them and they stay frozen during training.

In [None]:
from torchvision import models

pretrained_resnet_50 = models.resnet50(
    weights="IMAGENET1K_V1"
)  # Pre-trained ResNet-50 Model
for param in pretrained_resnet_50.parameters():
    param.requires_grad = False
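Freezing affects parameters only. BatchNorm layers, which ResNet-50 uses throughout, also keep running mean and variance *buffers*; these are not parameters, so `requires_grad = False` does not stop them from updating whenever the model runs in `train()` mode. The sketch below checks both effects on a tiny, hypothetical stand-in backbone (`freezing_demo` is illustrative, not ResNet-50 itself).

```python
import torch
from torch import nn


def freezing_demo():
    """Freeze a tiny stand-in backbone and check what still changes in train mode."""
    # Hypothetical stand-in for a CNN backbone (not ResNet-50 itself)
    backbone = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4), nn.ReLU())
    for param in backbone.parameters():
        param.requires_grad = False
    head = nn.Linear(4, 2)  # Created after freezing, so it is trainable by default

    batch_norm = backbone[1]
    running_mean_before = batch_norm.running_mean.clone()

    backbone.train()
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(8, 3, 8, 8, generator=gen)  # Synthetic batch of 8 RGB 8x8 images
    features = backbone(x).mean(dim=(2, 3))  # Global average pooling -> shape (8, 4)
    head(features).sum().backward()

    frozen_have_no_grad = all(p.grad is None for p in backbone.parameters())
    head_has_grad = head.weight.grad is not None
    stats_changed = not torch.equal(running_mean_before, batch_norm.running_mean)
    return frozen_have_no_grad, head_has_grad, stats_changed


print(freezing_demo())  # (True, True, True)
```

If those running statistics should also stay at their ImageNet values, one common option is to call `.eval()` on the BatchNorm modules after each `model.train()` call.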

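Freezing the backbone means that only the new, custom layers are trained, here the replacement classifier head (`fc`). This reduces the risk of overfitting and speeds up training. Because the new layers are created after the freezing loop, their parameters have `requires_grad=True` by default and remain trainable.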

In [None]:
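n_features = pretrained_resnet_50.fc.in_features  # 2048 for ResNet-50
hidden_dim = 256
# Replace the 1000-class ImageNet classifier with a small trainable head
pretrained_resnet_50.fc = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(n_features, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, len(class_names)),  # One output per animal class
)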

In [None]:
from torchinfo import summary

summary(
    pretrained_resnet_50,
    input_size=(
        1,
        3,
        IMAGE_SIZE,
        IMAGE_SIZE,
    ),  # (batch_size, colour channels, height, width)
    verbose=0,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)
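The `trainable` column of the summary should confirm that only the new head is trainable. A sketch of a hypothetical `count_parameters` helper, applied to a small stand-in model rather than ResNet-50: a frozen linear "backbone" producing 2048 features (the input width of ResNet-50's `fc` layer) followed by the same head layout as above, assuming 10 classes.

```python
from typing import Tuple

from torch import nn


def count_parameters(model: nn.Module) -> Tuple[int, int]:
    """Return (trainable, total) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


def build_demo_model(n_features: int = 2048, hidden_dim: int = 256, num_classes: int = 10) -> nn.Module:
    # Hypothetical stand-in for the frozen backbone: one linear layer with n_features outputs
    backbone = nn.Linear(16, n_features)
    for param in backbone.parameters():
        param.requires_grad = False
    # Same head layout as the replacement `fc` (num_classes = 10 is an assumption)
    head = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(n_features, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, num_classes),
    )
    return nn.Sequential(backbone, head)


demo_trainable, demo_total = count_parameters(build_demo_model())
print(f"Trainable: {demo_trainable:,} | Total: {demo_total:,}")  # Trainable: 527,114 | Total: 561,930
```

For the notebook's model, `count_parameters(pretrained_resnet_50)` should report a trainable count equal to the head's parameter count: $2048 \cdot 256 + 256 + 256 \cdot 10 + 10 = 527{,}114$ if `len(class_names)` is 10.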

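The optimiser must be re-initialised, because the previous one still references the custom model's parameters. Passing it only the unfrozen parameters (those with `requires_grad=True`) means it updates just the new head. The scheduler is re-created as well so that it tracks the new optimiser.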

In [None]:
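# Pass only the parameters that still require gradients (the new `fc` head)
optimiser = torch.optim.SGD(
    filter(lambda p: p.requires_grad, pretrained_resnet_50.parameters()),
    lr=0.005,
    weight_decay=0.0005,
    momentum=0.9,
)

# Halve the learning rate once validation loss has not improved for more than 3 epochs
scheduler = ReduceLROnPlateau(optimiser, mode="min", factor=0.5, patience=3)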
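`ReduceLROnPlateau` only lowers the learning rate after the monitored value has failed to improve (by more than its default relative threshold of 1e-4) for more than `patience` consecutive calls to `scheduler.step()`. A minimal simulation with the same settings (`factor=0.5`, `patience=3`) on a dummy parameter and synthetic validation losses; `simulate_plateau_schedule` is a hypothetical helper, and its local optimiser does not touch the notebook's `optimiser`.

```python
from typing import List

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau


def simulate_plateau_schedule(val_losses: List[float], lr: float = 0.005) -> List[float]:
    """Return the learning rate after each scheduler.step() for a sequence of validation losses."""
    # A single dummy parameter, just so the optimiser has a param group to manage
    dummy = torch.nn.Parameter(torch.zeros(1))
    optimiser = torch.optim.SGD([dummy], lr=lr, weight_decay=0.0005, momentum=0.9)
    scheduler = ReduceLROnPlateau(optimiser, mode="min", factor=0.5, patience=3)
    lrs = []
    for loss in val_losses:
        scheduler.step(loss)
        lrs.append(optimiser.param_groups[0]["lr"])
    return lrs


# Synthetic losses: a baseline, one improvement, then four epochs without improvement
print(simulate_plateau_schedule([1.0, 0.9, 0.9, 0.9, 0.9, 0.9]))
# [0.005, 0.005, 0.005, 0.005, 0.005, 0.0025]
```

The learning rate is halved only on the fourth consecutive epoch without improvement, which is worth keeping in mind when choosing `patience` relative to the total number of epochs.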

In [None]:
train_losses, train_accuracies, train_f1s = [], [], []
validation_losses, validation_accuracies, validation_f1s = [], [], []

for epoch in epochs_range:
    print(f"Epoch: {epoch}\n==========")
    # Per-epoch results get their own names so they do not overwrite the
    # `train_metrics` / `validation_metrics` dictionaries used for plotting
    train_loss, train_epoch_metrics = train_step(
        data_loader=train_dataloader,
        model=pretrained_resnet_50,  # Use pre-trained ResNet-50
        loss_function=loss_function,
        optimiser=optimiser,
        metrics=metrics,
        device=device,
    )
    train_losses.append(train_loss)
    train_accuracies.append(train_epoch_metrics[0])  # Accuracy
    train_f1s.append(train_epoch_metrics[1])  # F1 score

    validation_loss, validation_epoch_metrics = validation_step(
        data_loader=validation_dataloader,
        model=pretrained_resnet_50,  # Use pre-trained ResNet-50
        loss_function=loss_function,
        metrics=metrics,
        device=device,
    )
    validation_losses.append(validation_loss)
    validation_accuracies.append(validation_epoch_metrics[0])  # Accuracy
    validation_f1s.append(validation_epoch_metrics[1])  # F1 score

    # Lower the learning rate when validation loss plateaus
    scheduler.step(validation_loss)

## 13. Results (Pre-Trained ResNet-50)
### Overall Performance

In [None]:
train_metrics = {
    "Loss": train_losses,
    "Accuracy": train_accuracies,
    "F1 Score": train_f1s,
}

validation_metrics = {
    "Loss": validation_losses,
    "Accuracy": validation_accuracies,
    "F1 Score": validation_f1s,
}

In [None]:
plot_results(epochs_range, train_metrics, validation_metrics)

### Classifications

In [None]:
all_preds, all_labels = make_all_predictions(
    model=pretrained_resnet_50,  # Use pre-trained ResNet-50
    data_loader=test_dataloader,
    loss_function=loss_function,
    calculate_accuracy=calculate_accuracy,
    device=device,
)
print(all_preds)

### Missclassifications

In [None]:
wrong_indices = (all_preds != all_labels).nonzero(as_tuple=True)[0]
print(
    f"Number of failed predictions: {len(wrong_indices)}/{len(all_labels)} ({100.0 * len(wrong_indices) / len(all_labels):.2f}%)"
)

### Confusion Matrix

In [None]:
from torchmetrics import ConfusionMatrix
from mlxtend.plotting import plot_confusion_matrix

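# Both are already CPU tensors returned by make_all_predictions;
# the clone is a defensive copy so later changes cannot alter all_preds
true_labels = all_labels
pred_tensor = all_preds.detach().clone()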

# Compute confusion matrix
conf_matrix = ConfusionMatrix(num_classes=len(class_names), task="multiclass")
conf_matrix_tensor = conf_matrix(pred_tensor, true_labels)

# Plot
fig, ax = plot_confusion_matrix(
    conf_mat=conf_matrix_tensor.numpy(), class_names=class_names, figsize=(10, 7)
)

### Conclusion (Pre-Trained ResNet-50)

## 14. References

1. Aditi Rastogi. (2022). *ResNet50*.<br>
https://blog.devgenius.io/resnet50-6b42934db431

1. Bader Dammak. (2023). *ResNet-50 training from scratch on Animal-10 Dataset*.<br>
https://github.com/Darkmyter/Popular-models-implemented-in-Pytorch/blob/main/Computer-vision/image-classification/resnet-50-animal-10.ipynb

1. Dhruv Matani. (2023). *A Practical Guide to Transfer Learning using PyTorch*.<br>
https://www.kdnuggets.com/2023/06/practical-guide-transfer-learning-pytorch.html

1. He, K., Zhang, X., Ren, S., Sun, J. (2015). *Deep Residual Learning for Image Recognition*. arXiv preprint arXiv:1512.03385.<br>
https://arxiv.org/abs/1512.03385

1. PyTorch Docs. (n.d.). *torchvision.models.resnet*.<br>
https://docs.pytorch.org/vision/stable/_modules/torchvision/models/resnet.html#ResNet50_Weights

1. PyTorch Docs. (n.d.). *Models and pre-trained weights*.<br>
https://docs.pytorch.org/vision/0.21/models.html#models-and-pre-trained-weights