Convert this notebook to a Python script once it runs from top to bottom in Jupyter without any manual interaction.

In [1]:
import os

# When running on NCC, switch the working directory to the project root
# (~/l3_project) so that the local packages imported in the next cell
# (dataset_processing, helpers) can be found.
os.chdir(os.path.expanduser("~/l3_project"))

In [2]:
from pathlib import Path
import platform
import typing as t

import dataset_processing
import helpers

import numpy as np
import torch
import torch.nn as nn

# Logger used throughout this notebook, requested from the project's helpers
# package under the name "main".
lg = helpers.logging.get_logger("main")
lg.debug("Successfully imported packages.")

  from tqdm.autonotebook import tqdm


In [3]:
# Choose the torch device in order of preference: CUDA, then Apple's MPS
# backend, then CPU.
if torch.cuda.is_available():
    torch_device = torch.device('cuda')
    lg.debug(f'Found {torch.cuda.get_device_name()} to use as a cuda device.')
elif platform.system() == 'Darwin' and torch.backends.mps.is_available():
    # Being on macOS is not enough: the MPS backend is unavailable on some
    # machines / torch builds, and selecting it anyway only fails later when
    # the first tensor is moved to the device. Fall through to CPU instead.
    torch_device = torch.device('mps')
else:
    torch_device = torch.device('cpu')
lg.info(f'Using {torch_device} as torch device.')

# Outside Linux, restrict torch to a single thread.
if platform.system() != 'Linux':
    torch.set_num_threads(1)
    lg.debug('Set number of threads to 1 as using a non-Linux machine.')

In [4]:
# Single source of truth for the random seed so the NumPy generator and torch's
# global RNG cannot drift apart when the seed is changed.
SEED = 42

# NumPy generator to be used for NumPy-side randomness in this notebook.
np_rng = np.random.default_rng(SEED)
# Assigned to `_` so the returned torch.Generator is not echoed as cell output.
_ = torch.manual_seed(SEED)

In [5]:
# Make sure a 'checkpoints' directory exists under the current working
# directory; an already-existing directory is left as it is.
checkpoints_path = Path.cwd().joinpath('checkpoints')
checkpoints_path.mkdir(exist_ok=True)

In [15]:
# Literal type of the dataset names accepted by get_num_classes and
# get_dataset_object; used to annotate their ``name`` parameter.
DATASET_NAMES = t.Literal["EuroSATRGB", "EuroSATMS"]

In [16]:
def get_num_classes(
        name: DATASET_NAMES
) -> int:
    """Return the number of classes of the dataset called ``name``.

    "EuroSATRGB" and "EuroSATMS" both map to 10 classes.

    :param name: dataset name to look up.
    :return: the number of classes for that dataset.
    :raises ValueError: if ``name`` is not a known dataset name; the problem
        is also logged at error level before raising.
    """
    # Reject unknown names up front so the happy path is a plain return.
    if name not in ("EuroSATRGB", "EuroSATMS"):
        lg.error(f"Invalid dataset name ({name}) provided to get_num_classes.")
        raise ValueError(f"Dataset {name} does not exist.")

    return 10

In [22]:
def get_dataset_object(
        name: DATASET_NAMES,
        split: t.Literal["train", "val", "test"],
        image_size: int,
        download: bool = False,
        do_transforms: bool = True,
):
    """Build the dataset object for ``name`` and log how many samples it has.

    :param name: which dataset to build ("EuroSATRGB" or "EuroSATMS").
    :param split: forwarded positionally to the dataset class.
    :param image_size: forwarded positionally to the dataset class.
    :param download: forwarded as the ``download`` keyword argument.
    :param do_transforms: forwarded as the ``do_transforms`` keyword argument.
    :return: the constructed dataset object.
    :raises ValueError: if ``name`` is not a known dataset name; the problem
        is also logged at error level before raising.
    """
    # Each branch only picks the class; the constructor call is shared below
    # because both classes are built with exactly the same arguments.
    if name == "EuroSATRGB":
        lg.debug("Loading EuroSATRGB dataset...")
        dataset_cls = dataset_processing.eurosat.EuroSATRGB
    elif name == "EuroSATMS":
        lg.debug("Loading EuroSATMS dataset...")
        dataset_cls = dataset_processing.eurosat.EuroSATMS
    else:
        lg.error(f"Invalid dataset name ({name}) provided to get_dataset_object.")
        raise ValueError(f"Dataset {name} does not exist.")

    loaded = dataset_cls(
        split,
        image_size,
        download=download,
        do_transforms=do_transforms,
    )
    lg.info(f"Dataset {name} loaded with {len(loaded)} samples.")
    return loaded

In [18]:
def get_model_object(
        name: t.Literal["ResNet50"],
        num_classes: int,
) -> nn.Module:
    """Build the model called ``name``.

    :param name: which model to build; only "ResNet50" is recognised.
    :param num_classes: passed to the model's constructor.
    :return: a ``helpers.models.FineTunedResNet50`` instance.
    :raises ValueError: if ``name`` is not a known model name; the problem is
        also logged at error level before raising.
    """
    # Guard clause: anything other than the single supported model is an error.
    if name != "ResNet50":
        lg.error(f"Invalid model name ({name}) provided to get_model_object.")
        raise ValueError(f"Model {name} does not exist.")

    lg.debug("Loading ResNet50 model...")
    return helpers.models.FineTunedResNet50(num_classes)

In [19]:
# Experiment configuration: which dataset and which model to use.
dataset_name = "EuroSATMS"
model_name = "ResNet50"
# Build the model with the number of classes reported for the chosen dataset.
model = get_model_object(model_name, get_num_classes(dataset_name))

224

In [155]:
# Load the "train" split of the chosen dataset, passing the model's
# expected_input_size as the image size.
training_dataset = get_dataset_object(dataset_name, "train", model.expected_input_size)