# Fine-tuning a ResNet-50 model

## Imports

In [None]:
import os

# On NCC the notebook is not started from the project root, so switch to it first;
# otherwise the local packages imported below cannot be found.
project_root = os.path.expanduser("~/l3_project")
os.chdir(project_root)

In [None]:
from pathlib import Path
import platform
import time

import einops
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.models
import torchvision.transforms.v2 as transforms
import wandb
import safetensors.torch as st

import dataset_processing.eurosat

print(f'Using PyTorch {torch.__version__} on {platform.system()}')

# Pick the best available accelerator: CUDA first, then Apple's Metal backend (MPS), otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f'Found {torch.cuda.get_device_name()} to use as a cuda device.')
elif platform.system() == 'Darwin' and torch.backends.mps.is_available():
    # Running on macOS alone does not guarantee MPS support (it depends on the hardware, the OS
    # version and the PyTorch build), so check availability and fall back to the CPU otherwise.
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f'Using {device} as torch device.')

if platform.system() != 'Linux':
    torch.set_num_threads(1)  # significantly speeds up data loading processes with less loading overhead
    # see https://discuss.pytorch.org/t/pytorch-v2-high-cpu-consumption/205990 and https://discuss.pytorch.org/t/cpu-usage-far-too-high-and-training-inefficient/57228
    print('Set number of threads to 1 as using a non-Linux machine.')

In [None]:
# Use one fixed seed for both the NumPy generator and PyTorch's global RNG so runs are repeatable.
SEED = 42
np_rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

In [None]:
# Show where the datasets are read from and the current working directory.
# NOTE(review): relies on `dataset_processing.core` having been loaded as a side effect of
# importing `dataset_processing.eurosat` - confirm, or import it explicitly.
dataset_processing.core.get_dataset_root(), Path.cwd()

In [None]:
# Directory that holds this model's saved weights.
checkpoints_path = Path.cwd() / 'checkpoints' / 'resnet50'
# parents=True so a fresh checkout (no 'checkpoints' directory yet) does not raise FileNotFoundError.
checkpoints_path.mkdir(parents=True, exist_ok=True)

# General ResNet-50 model

In [None]:
class FineTunedResNet50(nn.Module):
    """
    ResNet-50 with ImageNet weights whose final fully-connected layer is replaced for a new task.

    "Layers" in the freezing helpers below are the direct children of the wrapped torchvision model
    (as returned by ``self.model.children()``), counted from the output end.
    """

    def __init__(self, num_classes: int):
        """
        Initialise a ResNet-50 model with the final linear layer replaced to output the desired number of classes.
        :param num_classes: Number of outputs of the replacement final linear layer.
        """

        super().__init__()
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the wrapped model on `x` and return its output unchanged.
        """
        return self.model(x)

    def freeze_layers(self, keep: int) -> None:
        """
        Freeze layers (requires_grad = False) from the first input layer leaving the last `keep` layers (inc. output).
        Layers that are already frozen stay frozen; call `unfreeze_layers` first to reset.
        :param keep: Number of layers from output (inc.) to keep unfrozen.
            e.g. keep=1 means only output layer is trainable.
        """

        # `children()` returns a generator, which `reversed` cannot consume directly, so materialise it.
        layers_from_output = reversed(list(self.model.children()))
        for dist_from_output, layer in enumerate(layers_from_output):
            if dist_from_output < keep:
                continue
            # NOTE(review): this only stops gradient updates; batch-norm running statistics
            # still change while the module is in train mode.
            for param in layer.parameters():
                param.requires_grad = False

    def unfreeze_layers(self) -> None:
        """
        Unfreeze all layers in the model.
        """

        for param in self.parameters():
            param.requires_grad = True

    def extra_repr(self) -> str:
        """
        Add additional detail on number of frozen layers.
        :return: Summary string with the count and class names of child layers that have at least
            one parameter with requires_grad = False.
        """
        frozen_layers = [
            layer for layer in self.model.children()
            if any(not param.requires_grad for param in layer.parameters())
        ]
        frozen_names = ', '.join(layer.__class__.__name__ for layer in frozen_layers)

        return f"> {len(frozen_layers)} layers frozen: {frozen_names} <"

In [None]:
# Display the preprocessing preset that ships with the pretrained weights used above.
# `transforms` is defined per weights entry (and must be called to build the preset); accessing it
# on the enum class itself only shows a bare property object.
torchvision.models.ResNet50_Weights.IMAGENET1K_V1.transforms()

# EuroSAT dataset

## Load dataset

In [None]:
# Deterministic preprocessing, used on its own for validation and as the first stage of training.
base_steps = [
    transforms.ToImage(),
    # no scaling here: the normalise step below takes care of it
    transforms.ToDtype(torch.float32, scale=False),
    # normalise to [0, 1] (based on 1st and 99th percentiles)
    dataset_processing.core.RSNormaliseTransform(),
    # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),  # shift to mean 0 and std 1

    # scale as expected by ResNet (see https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),

    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
]
base_transforms = transforms.Compose(base_steps)
wrapped_base_transforms = dataset_processing.core.tensor_dict_transform_wrapper(base_transforms)

# Training pipeline: the base preprocessing followed by random augmentation.
augmentation_steps = [
    transforms.RandomHorizontalFlip(p=0.5),
    # transforms.RandomAffine(0, shear=0.2),  # Shear with range 0.2
    # transforms.RandomAffine(0, scale=(1., 1.2)),  # Zoom in with range 0.2
]
training_transforms = transforms.Compose([base_transforms, *augmentation_steps])
wrapped_training_transforms = dataset_processing.core.tensor_dict_transform_wrapper(training_transforms)

In [None]:
# Load both splits without downloading; the training split uses the augmenting pipeline,
# the validation split only the base preprocessing.
eurosat_train_ds = dataset_processing.eurosat.get_dataset("train", transforms=wrapped_training_transforms,
                                                          download=False)
eurosat_val_ds = dataset_processing.eurosat.get_dataset("val", transforms=wrapped_base_transforms,
                                                        download=False)

print(f"There are {len(eurosat_train_ds)} training samples and {len(eurosat_val_ds)} validation samples.")

# Sanity-check the first training sample (fetched once per field, as before).
first_image_size = eurosat_train_ds[0]["image"].size()
first_label = eurosat_train_ds[0]["label"]
print("Image dimensions and label:", first_image_size, first_label)

### Visualise some images

In [None]:
# Preview a 5x5 grid of randomly chosen training samples, exactly as the training pipeline yields them.
# NOTE: the images have already been through `transforms.Normalize`, so values can fall outside
# [0, 1]; imshow clips them, which is why colours may look off.
random_indices = np_rng.choice(len(eurosat_train_ds), size=25, replace=False)
plt.figure(figsize=(10, 10), tight_layout=True)
for i, idx in enumerate(random_indices):
    plt.subplot(5, 5, i + 1)
    # channels-first tensor -> channels-last layout expected by imshow
    plt.imshow(einops.rearrange(eurosat_train_ds[idx]["image"], "c h w -> h w c"))
    plt.axis("off")
# Spacing is the same for every panel, so set it once instead of on every iteration.
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()

## Training/Fine-tuning

In [None]:
# Settings shared by both loaders.
loader_kwargs = {"batch_size": 32, "num_workers": 4}

# Training batches are shuffled and the trailing partial batch is dropped.
train_dataloader = torch.utils.data.DataLoader(eurosat_train_ds, shuffle=True, drop_last=True, **loader_kwargs)
# Validation keeps dataset order and every sample.
val_dataloader = torch.utils.data.DataLoader(eurosat_val_ds, shuffle=False, drop_last=False, **loader_kwargs)

### Initialise model

In [None]:
# One output per dataset class; start with only the output layer trainable.
num_classes = len(eurosat_train_ds.classes)
resnet50_model = FineTunedResNet50(num_classes=num_classes)
resnet50_model.freeze_layers(keep=1)
print(resnet50_model)

In [None]:
# Move the model's parameters and buffers onto the device selected at the top of the notebook.
model_to_train = resnet50_model.to(device=device)

### Set up training criteria and optimiser

In [None]:
# Loss function used for training.
criterion = nn.CrossEntropyLoss()

# SGD hyperparameters (consumed when the optimiser is built).
learning_rate, decay, momentum = 0.01, 1e-6, 0.9

In [None]:
# Only parameters that still require gradients (i.e. the unfrozen ones) are handed to the optimiser.
parameters_to_optimise = (param for param in model_to_train.parameters() if param.requires_grad)
optimiser = torch.optim.SGD(
    parameters_to_optimise,
    lr=learning_rate,
    momentum=momentum,
    weight_decay=decay,
    nesterov=True,
)

### Track with Weights & Biases

In [None]:
# Settings recorded with the run so it can be identified and compared later.
run_config = {
    "dataset": "EuroSAT",
    "transforms": repr(training_transforms),
    "batch_size": train_dataloader.batch_size,

    "model": {
        "name": type(model_to_train).__name__,
        "architecture": repr(model_to_train),
    },
    "training": {
        "optimiser": repr(optimiser),
        "learning_rate": learning_rate,
    },

    "initialisation_time": time.asctime(),
}

# Start a fresh Weights & Biases run; fill in name/notes/tags/id before launching.
run = wandb.init(
    project="evaluating_xAI_for_RS",
    save_code=True,
    name="",
    notes="",
    tags=[],
    id="",  # REMEMBER TO CHANGE
    resume="never",  # 'allow' to resume a crashed run
    config=run_config,
)

### Training loop

In [None]:
# Loss / accuracy histories handed to `train_step` on every step.
training_loss_arr = np.zeros(0)
training_acc_arr = np.zeros(0)

# NOTE(review): `num_epochs`, `train_step` and `validation_iterator` are not defined in the visible
# part of this notebook - they must exist before this cell is run.
for epoch in range(num_epochs):
    print(f"Epoch {epoch:03}")
    for i, data in enumerate(train_dataloader):
        # Re-enabled every step; presumably `train_step` may switch the model to eval mode for
        # validation - TODO confirm.
        model_to_train.train()
        images = data["image"].to(device)
        labels: torch.Tensor = data["label"].to(device)
        train_step(model_to_train, images, labels, criterion, optimiser, training_loss_arr, training_acc_arr,
                   validation_iterator, i)

    # Checkpoint every 20 epochs (skipping epoch 0) into the same directory as the final weights,
    # rather than a hard-coded path relative to the current working directory.
    if epoch != 0 and epoch % 20 == 0:
        checkpoint_file = checkpoints_path / f"resnet50_eurosat_epoch_{epoch:03}.safetensors"
        st.save_model(model_to_train, checkpoint_file, metadata={"epoch": str(epoch)})
        print(f"Model saved at epoch {epoch:03}.")

In [None]:
# Save the final weights in the checkpoints directory, in a file named after the model class.
final_weights_file = checkpoints_path / f"{model_to_train.__class__.__name__}_final_weights.st"
st.save_model(model_to_train, final_weights_file)

### Upload final model stats

In [None]:
# NOTE(review): these three lines only read the summary keys - nothing is assigned or uploaded here.
# They look like unfinished placeholders for `run.summary[...] = <value>`; the intended values are
# not visible in this notebook - TODO confirm and complete.
run.summary["steps_trained"]
run.summary["final_loss/validation"]
run.summary["final_accuracy/validation"]