Convert this notebook to a Python script only once it runs top to bottom in Jupyter without any manual interaction.

In [1]:
import os

# For when running on NCC: make the project root (~/l3_project) the working directory
# so that the local packages imported further down (dataset_processing, helpers) can be found.
# This must run before those imports.
os.chdir(os.path.expanduser("~/l3_project"))

In [29]:
from pathlib import Path
import platform
import typing as t
import time

import numpy as np
import torch
import torch.nn as nn
import wandb
from tqdm.autonotebook import tqdm
import safetensors.torch as st

import dataset_processing
import helpers

# Logger named "main", shared by every cell and function in this notebook.
lg = helpers.logging.get_logger("main")
lg.debug("Successfully imported packages.")

In [3]:
# Choose the device everything is trained on: CUDA if present, otherwise Apple's MPS
# backend when it is actually usable, otherwise the CPU.
if torch.cuda.is_available():
    torch_device = torch.device('cuda')
    lg.debug(f'Found {torch.cuda.get_device_name()} to use as a cuda device.')
elif platform.system() == 'Darwin' and torch.backends.mps.is_available():
    # Being on macOS is not enough: ask torch whether the MPS backend can be used,
    # otherwise moving the model to 'mps' later would fail.
    torch_device = torch.device('mps')
else:
    torch_device = torch.device('cpu')
lg.info(f'Using {torch_device} as torch device.')

# Restrict torch to a single CPU thread on anything that is not Linux.
if platform.system() != 'Linux':
    torch.set_num_threads(1)
    lg.debug('Set number of threads to 1 as using a non-Linux machine.')

In [4]:
# Seed used for both the NumPy generator and torch's global RNG.
random_seed = 42  # todo: turn these into command line args

In [5]:
# Seed NumPy (through a dedicated Generator object) and torch from the same value.
np_rng = np.random.default_rng(random_seed)
_ = torch.manual_seed(random_seed)  # the return value of manual_seed is not needed, so it is discarded
lg.debug(f'Random seed set to {random_seed}.')

In [6]:
# Name of the directory inside ~/l3_project under which model weights are saved.
checkpoints_root_name = "checkpoints_temp"  # todo: remove temp

In [7]:
# Root directory for saved weights; created here if it does not exist yet.
# Note: only this root is created (no parents=True), so ~/l3_project must already exist.
checkpoints_path = Path.home() / "l3_project" / checkpoints_root_name
checkpoints_path.mkdir(exist_ok=True)
lg.debug(f'Checkpoints directory set to {checkpoints_path.resolve()}.')

In [8]:
# Type alias listing the dataset names that get_dataset_object accepts.
DATASET_NAMES = t.Literal["EuroSATRGB", "EuroSATMS"]

In [9]:
def get_dataset_object(
        name: DATASET_NAMES,
        split: t.Literal["train", "val", "test"],
        image_size: int,
        download: bool = False,
        do_transforms: bool = True,
):
    """Construct the dataset registered under ``name`` for one split.

    Args:
        name: Either ``"EuroSATRGB"`` or ``"EuroSATMS"``.
        split: Which split to build; forwarded to the dataset class.
        image_size: Forwarded to the dataset class as ``image_size``.
        download: Forwarded to the dataset class as ``download``.
        do_transforms: Forwarded to the dataset class as ``do_transforms``.

    Returns:
        The constructed ``dataset_processing.eurosat`` dataset instance.

    Raises:
        ValueError: If ``name`` is not one of the supported dataset names.
    """
    # Reject unknown names up front so the happy path below stays flat.
    if name not in ("EuroSATRGB", "EuroSATMS"):
        lg.error(f"Invalid dataset name ({name}) provided to get_dataset_object.")
        raise ValueError(f"Dataset {name} does not exist.")

    if name == "EuroSATRGB":
        lg.debug("Loading EuroSATRGB dataset...")
        dataset_cls = dataset_processing.eurosat.EuroSATRGB
    else:
        lg.debug("Loading EuroSATMS dataset...")
        dataset_cls = dataset_processing.eurosat.EuroSATMS

    loaded = dataset_cls(
        split=split,
        image_size=image_size,
        download=download,
        do_transforms=do_transforms,
    )

    lg.info(f"Dataset {name} ({split}) loaded with {len(loaded)} samples.")
    return loaded

In [10]:
def get_model_type(
        name: t.Literal["ResNet50"],
) -> t.Type[helpers.models.FreezableModel]:
    """Return the model class (not an instance) registered under ``name``.

    Args:
        name: Model identifier; only ``"ResNet50"`` is supported.

    Returns:
        ``helpers.models.FineTunedResNet50`` for ``"ResNet50"``.

    Raises:
        ValueError: If ``name`` is not a supported model name.
    """
    # Guard clause: anything other than the single known name is an error.
    if name != "ResNet50":
        lg.error(f"Invalid model name ({name}) provided to get_model_type.")
        raise ValueError(f"Model {name} does not exist.")

    lg.debug("Returning ResNet50 model type...")
    return helpers.models.FineTunedResNet50

In [11]:
# Dataset / model combination for this run; these names are passed to
# get_dataset_object and get_model_type and are also used to label the wandb run.
dataset_name = "EuroSATMS"
model_name = "ResNet50"

In [12]:
# Model class (not yet instantiated) for the chosen model name.
model_type = get_model_type(model_name)

In [13]:
# Build the train and val splits, sized to the input dimension the model class expects.
training_dataset = get_dataset_object(dataset_name, "train", model_type.expected_input_dim)
validation_dataset = get_dataset_object(dataset_name, "val", model_type.expected_input_dim)

In [14]:
# Instantiate the model with the dataset's band and class counts, then move it to the chosen device.
model = model_type(
    n_input_bands=training_dataset.N_BANDS,
    n_output_classes=training_dataset.N_CLASSES
).to(torch_device)

In [30]:
# Settings shared by the training and validation DataLoaders.
batch_size = 32
num_workers = 4

In [28]:
# Training batches are shuffled and a trailing partial batch is dropped;
# validation keeps the dataset order and every sample.
training_dataloader = torch.utils.data.DataLoader(
    training_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, drop_last=True
)
validation_dataloader = torch.utils.data.DataLoader(
    validation_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False
)

# Cycle over the dataloader (batches), not the raw dataset (single samples).
# train_model hands this iterator to helpers.ml.validation_step together with
# len(validation_dataloader), i.e. a number of *batches*, so the iterator has to yield batches
# for one validation pass to cover the validation set.
validation_iterator = iter(dataset_processing.core.cycle(validation_dataloader))

## Fine-tune the first and final layers

In [17]:
# Prepare the partially frozen training stage.
model.freeze_layers(1)  # freeze all but the last layer
if model.modified_input_layer:  # unfreeze the input layer if we need to train it too
    model.unfreeze_input_layers(model.input_layers_to_train)

In [21]:
# Training hyperparameters read by get_opt_and_scheduler and train_model.
loss_criterion = nn.CrossEntropyLoss()
frozen_lr = 0.01  # initial learning rate for the partially frozen stage
optimiser_name = "SGD"  # looked up as an attribute of torch.optim
lr_early_stop_threshold = 0.0001  # training stops once the scheduler's lr falls below this

In [22]:
def get_opt_and_scheduler(lr: float):
    """Create a fresh optimiser and plateau scheduler for the current ``model``.

    The optimiser class is looked up on ``torch.optim`` by ``optimiser_name`` and only
    receives the parameters of ``model`` that currently require gradients.

    Args:
        lr: Initial learning rate handed to the optimiser.

    Returns:
        A ``(optimiser, scheduler)`` tuple, where the scheduler is a
        ``ReduceLROnPlateau`` wrapped around that optimiser.
    """
    optimiser_cls = getattr(torch.optim, optimiser_name)
    trainable_params = (param for param in model.parameters() if param.requires_grad)
    optimiser: torch.optim.Optimizer = optimiser_cls(
        trainable_params,
        lr=lr,
        weight_decay=1e-6,
        momentum=0.9,
        nesterov=True,
    )

    # 10 ** (-1/4): four reductions in a row shrink the learning rate by a factor of ten
    decay_factor = np.float_power(10, -1 / 4)
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimiser, factor=decay_factor, patience=5, threshold=0.005
    )
    return optimiser, plateau_scheduler

In [24]:
# Whether train_model creates a Weights & Biases run to log metrics to.
wandb_track_run = True

In [None]:
def train_model(lr: float, frozen: bool = False, max_epochs: int = 50):
    """Train the notebook-level ``model`` and save its weights as safetensors files.

    Every epoch runs each batch of ``training_dataloader`` through
    ``helpers.ml.train_step``, then calls ``helpers.ml.validation_step`` and steps the
    plateau scheduler on the validation loss. Training ends after ``max_epochs`` epochs,
    or earlier once the scheduler has lowered the learning rate below
    ``lr_early_stop_threshold``. Weights are saved every 10th epoch (epoch 0 excluded)
    and once more when training ends.

    Relies on notebook-level state: ``model``, ``training_dataset``,
    ``training_dataloader``, ``validation_dataloader``, ``validation_iterator``,
    ``loss_criterion``, ``checkpoints_path``, ``dataset_name``, ``model_name``,
    ``batch_size``, ``lr_early_stop_threshold`` and ``wandb_track_run``.

    Args:
        lr: Initial learning rate passed to ``get_opt_and_scheduler``.
        frozen: Selects the ``frozen_partial`` (True) or ``full`` (False) output
            sub-directory and the matching wandb run name/tags. It does not itself
            freeze or unfreeze any layers.
        max_epochs: Upper bound on the number of epochs (default 50).

    Raises:
        ValueError: If ``max_epochs`` is smaller than 1.
    """
    if max_epochs < 1:
        raise ValueError(f"max_epochs must be at least 1 (got {max_epochs}).")

    weights_save_path = checkpoints_path / training_dataset.__class__.__name__ / model.__class__.__name__
    weights_save_path /= "frozen_partial" if frozen else "full"
    # Only the checkpoints root is created outside this function, so create the
    # dataset/model/stage sub-directories here before anything is saved into them.
    weights_save_path.mkdir(parents=True, exist_ok=True)
    lg.debug(f"Output path set to {weights_save_path.resolve()}.")

    optimiser, scheduler = get_opt_and_scheduler(lr)
    lg.debug(f"Initialised optimiser (lr={lr}) and scheduler.")

    if wandb_track_run:
        wandb_run = wandb.init(
            save_code=True,
            project="evaluating_xAI_for_RS",
            name=f"{dataset_name}_{model_name}{'_frozen' if frozen else ''}",
            notes="",
            tags=[dataset_name, model_name, "frozen" if frozen else "full"],
            id="",  # REMEMBER TO CHANGE
            resume="never",  # 'allow' to resume a crashed run
            config={
                "dataset": dataset_name,
                "batch_size": batch_size,

                "model": model_name,
                "model_repr": repr(model),
                "training": {
                    "optimiser": repr(optimiser),
                    "scheduler": repr(scheduler),
                    "early_stopping_threshold": lr_early_stop_threshold,
                },

                "wandb_init_time": time.asctime(),
                "save_path": str(weights_save_path.resolve()),
            }
        )
        run_id = wandb_run.id
        lg.info(f"Initialised wandb run, id={run_id}.")
    else:
        wandb_run = None
        # Without wandb there is no run id, so name the saved files after the start time instead.
        run_id = time.strftime("local_%Y%m%d-%H%M%S")

    batches_per_epoch = len(training_dataloader)

    with tqdm(total=max_epochs, desc="Epochs") as prog_bar1:
        for epoch in range(max_epochs):
            # Per-batch metrics gathered since the last progress report.
            window_losses = []
            window_accs = []

            with tqdm(total=batches_per_epoch, desc="Batches") as prog_bar2:
                lg.debug(f"Starting{' frozen' if frozen else ''} "
                         f"{model.__class__.__name__} network training epoch {epoch:03}.")
                for i, data in enumerate(training_dataloader):
                    images: torch.Tensor = data["image"]
                    labels: torch.Tensor = data["label"]

                    loss, acc = helpers.ml.train_step(
                        model, images, labels, loss_criterion, optimiser
                    )
                    window_losses.append(loss)
                    window_accs.append(acc)

                    prog_bar2.update()

                    # Report the running means every 100 batches, then start a new window.
                    # Batches after the last multiple of 100 in an epoch are trained on
                    # but never appear in a logged mean.
                    if i > 0 and i % 100 == 0:
                        training_mean_loss = np.mean(window_losses)
                        training_mean_acc = np.mean(window_accs)

                        prog_bar2.set_postfix(train_loss=training_mean_loss, train_acc=training_mean_acc)
                        lg.debug(str(prog_bar2))

                        if wandb_run is not None:
                            wandb_run.log({
                                "loss/train": training_mean_loss,
                                "accuracy/train": training_mean_acc,
                                "total_steps_trained": (epoch * batches_per_epoch) + i,
                            })

                        window_losses = []
                        window_accs = []

            val_mean_loss, val_mean_acc = helpers.ml.validation_step(
                model, loss_criterion, validation_iterator, len(validation_dataloader)
            )

            # The scheduler lowers the learning rate when the validation loss plateaus.
            scheduler.step(val_mean_loss)
            current_lr = scheduler.get_last_lr()[0]

            prog_bar1.update()
            prog_bar1.set_postfix(val_loss=val_mean_loss, val_acc=val_mean_acc, lr=current_lr)
            lg.info(str(prog_bar1))

            if wandb_run is not None:
                wandb_run.log({
                    "loss/validation": val_mean_loss,
                    "accuracy/validation": val_mean_acc,
                    "learning_rate": current_lr,
                })

            # Periodic checkpoints are written whether or not wandb tracking is enabled.
            if epoch != 0 and epoch % 10 == 0:
                model_save_path = weights_save_path / f"{run_id}_epoch{epoch:03}.st"
                st.save_model(model, model_save_path)
                lg.info(f"Saved model at epoch {epoch} to {model_save_path}.")

            if current_lr < lr_early_stop_threshold:
                lg.info(
                    f"Early stopping on low learning rate {current_lr} (loss plateaued at {val_mean_loss} after lr reductions).")
                break

    model_save_path = weights_save_path / f"{run_id}_final_{val_mean_acc:.3f}.st"
    st.save_model(model, model_save_path)
    lg.info(f"Saved final model to {model_save_path}.")

    if wandb_run is not None:
        # `epoch` is a 0-based index, so the number of completed epochs is one more.
        wandb_run.summary["n_epochs"] = epoch + 1
        wandb_run.finish(0)
        lg.info(f"Finished wandb run, id={run_id}.")


# First stage: train with frozen=True (weights go to the "frozen_partial" directory) at frozen_lr.
train_model(frozen_lr, True)

In [None]:
# Second stage set-up: unfreeze the model's layers and use a learning rate ten times smaller than frozen_lr.
model.unfreeze_layers()
full_lr = 0.001

In [None]:
# Second stage: train with frozen=False (weights go to the "full" directory) at full_lr.
train_model(full_lr, False)