In [None]:
# COLAB NECESSARY CELL
from google.colab import drive

# Clone repository with all classes
github_pat = input("Github Personal Access Token")
!git clone https://{github_pat}@github.com/rbngz/IMP-2023.git

# Navigate to repository source
%cd /content/IMP-2023/src

# Mount drive folder
drive.mount('/content/drive')

# Install additional libraries
!pip install -U wandb lightning

In [1]:
import os

import wandb
import torch
import numpy as np
import pandas as pd
import lightning as L
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize
from torch.nn import MSELoss, L1Loss
from lightning.pytorch.loggers import WandbLogger

from dataset_new import SentinelDataset
from utils import get_dataset_stats
from transforms import TargetNormalize
from model import UNet

In [2]:
# Seed used when shuffling the samples before the train/val/test split
SEED = 42

# Locations of the samples CSV and the data directory on the mounted drive
SAMPLES_PATH = "/content/drive/MyDrive/imp-data/samples/samples_S2S5P_2018_2020_eea.csv"
DATA_DIR = "/content/drive/MyDrive/imp-data/sentinel-2-eea"

# Training configuration
LEARNING_RATE = 1e-4
BATCH_SIZE = 8
PATCH_SIZE = 64
N_PATCHES = 4

In [None]:
# Load the samples table (indexed by "idx"), drop rows without an NO2
# measurement, then shuffle all rows reproducibly with SEED
samples_df = (
    pd.read_csv(SAMPLES_PATH, index_col="idx")
    .dropna(subset=["no2"])
    .sample(frac=1, random_state=SEED)
)

In [None]:
# Split samples dataframe to avoid sampling patches across sets.
# Positional slicing (70% / 15% / 15%) replaces np.split: calling np.split on a
# DataFrame goes through DataFrame.swapaxes, which newer pandas versions
# deprecate. The cut points are the same as before.
n_samples = len(samples_df)
train_end = int(0.7 * n_samples)
val_end = int(0.85 * n_samples)

df_train = samples_df.iloc[:train_end]
df_val = samples_df.iloc[train_end:val_end]
df_test = samples_df.iloc[val_end:]
print(f"Train set: {len(df_train)}, Validation set: {len(df_val)}, Test set: {len(df_test)}")

In [None]:
# Get statistics for normalization.
# Only the training split (df_train) is passed in, so the validation and test
# splits are not handed to the statistics computation.
stats_train = get_dataset_stats(df_train, DATA_DIR)
print(stats_train)

In [None]:
# Create normalizers for bands and NO2 measurements from the training statistics.
# band_normalize is torchvision's Normalize built from the "band_means"/"band_stds"
# entries; no2_normalize is the project's TargetNormalize built from
# "no2_mean"/"no2_std".
band_normalize = Normalize(stats_train["band_means"], stats_train["band_stds"])
no2_normalize = TargetNormalize(stats_train["no2_mean"], stats_train["no2_std"])

In [None]:
def _build_dataset(samples):
    """Wrap one split of the samples dataframe in a SentinelDataset.

    All splits use the same data directory, patch settings (N_PATCHES,
    PATCH_SIZE), pre_load=False and their own ToTensor transform instance.
    """
    return SentinelDataset(
        samples,
        DATA_DIR,
        n_patches=N_PATCHES,
        patch_size=PATCH_SIZE,
        pre_load=False,
        transform=ToTensor(),
    )


# Train, validation and test datasets, built in that order
dataset_train = _build_dataset(df_train)
dataset_val = _build_dataset(df_val)
dataset_test = _build_dataset(df_test)

print(
    f"Train dataset: {len(dataset_train)}, Validation dataset: {len(dataset_val)}, Test dataset: {len(dataset_test)}"
)

In [None]:
# Create Dataloaders.
# Cap the worker count at the number of CPUs actually available: requesting
# more workers than cores (12 was hard-coded) makes DataLoader warn about
# oversubscription and can slow loading down. On machines with 12+ cores the
# behavior is unchanged. `or 1` covers os.cpu_count() returning None and keeps
# num_workers > 0, which persistent_workers=True requires.
n_workers = min(12, os.cpu_count() or 1)
dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=n_workers, persistent_workers=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, num_workers=n_workers, persistent_workers=True)
# The test loader keeps the DataLoader default of loading in the main process
dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE)

In [None]:
# Define Pytorch lightning model

class Model(L.LightningModule):
    """Lightning wrapper around the project's UNet for NO2 regression.

    Each step normalizes the input patches and target measurements with the
    module-level ``band_normalize`` / ``no2_normalize`` objects, runs the UNet,
    reads channel 0 of the prediction at one coordinate per sample, and
    computes an MSE loss on normalized values plus an MAE on reverted
    (unnormalized) values.

    Args:
        n_patches: Number of patches per sample (recorded as a hyperparameter).
        patch_size: Patch size (recorded as a hyperparameter).
        batch_size: Batch size (recorded as a hyperparameter).
        lr: Learning rate passed to the Adam optimizer.

    The defaults are the notebook's configuration constants, evaluated when the
    class is defined, so ``Model()`` keeps working unchanged.
    """

    def __init__(
        self,
        n_patches=N_PATCHES,
        patch_size=PATCH_SIZE,
        batch_size=BATCH_SIZE,
        lr=LEARNING_RATE,
    ):
        super().__init__()

        # save_hyperparameters() records the arguments of __init__. With the
        # previous zero-argument constructor there was nothing to record, so
        # the hyperparameters never reached self.hparams or the logger.
        self.save_hyperparameters()

        # Set model
        self.model = UNet(
            enc_chs=(12, 64, 128, 256, 512),
            dec_chs=(512, 256, 128, 64)
            )

        # Set loss / metric and keep hyperparameters as plain attributes
        self.loss = MSELoss()
        self.mae = L1Loss()
        self.n_patches = n_patches
        self.patch_size = patch_size
        self.batch_size = batch_size
        self.lr = lr

    def training_step(self, batch, batch_idx):
        """Run one training step; log and return the normalized MSE loss."""
        loss, mae = self._step(batch)
        self.log("train_loss", loss)
        self.log("train_mae", mae)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation step; log and return the normalized MSE loss."""
        loss, mae = self._step(batch)
        self.log("val_loss", loss)
        self.log("val_mae", mae)
        return loss

    def test_step(self, batch, batch_idx):
        """Run one test step; log and return the normalized MSE loss."""
        loss, mae = self._step(batch)
        self.log("test_loss", loss)
        self.log("test_mae", mae)
        return loss

    def _step(self, batch):
        """Shared forward pass for the train/val/test steps.

        Args:
            batch: A ``(patches, measurements, coords)`` triple.

        Returns:
            ``(loss, mae)``: loss on normalized values, MAE on reverted values.
        """
        # Unpack batch
        patches, measurements, coords = batch

        # Normalize band and measurement data
        patches_norm = band_normalize(patches)
        measurements_norm = no2_normalize(measurements)

        # Get normalized predictions
        predictions_norm = self.model(patches_norm)

        # Extract, for sample i, channel 0 at (coords[0][i], coords[1][i]).
        # Indexing with an explicit per-sample index yields the same values as
        # the former torch.diag(predictions_norm[:, 0, coords[0], coords[1]])
        # without materializing a batch x batch matrix first.
        # Assumes coords[0] and coords[1] are 1-D index tensors with one entry
        # per sample (the diag-based form relied on this too) - TODO confirm.
        sample_idx = torch.arange(
            predictions_norm.shape[0], device=predictions_norm.device
        )
        target_values_norm = predictions_norm[sample_idx, 0, coords[0], coords[1]]

        # Compute loss on normalized data
        loss = self.loss(target_values_norm, measurements_norm)

        # Compute Mean Absolute Error on unnormalized data
        target_values = no2_normalize.revert(target_values_norm)
        mae = self.mae(target_values, measurements)

        return loss, mae

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [None]:
# Instantiate Model.
# Model._step reads the module-level band_normalize and no2_normalize objects,
# so the normalizer cell must have been run before training starts.
model = Model()

In [None]:
# Get logger for weights & biases; runs are logged under the "IMP-2023" project
wandb_logger = WandbLogger(project="IMP-2023")

In [None]:
# Train model for at most 100 epochs, validating on dataloader_val.
# Metrics go to the Weights & Biases logger; log_every_n_steps=400 sets the
# step interval at which logged values are written.
trainer = L.Trainer(max_epochs=100, logger=wandb_logger, log_every_n_steps=400)
trainer.fit(model=model, train_dataloaders=dataloader_train, val_dataloaders=dataloader_val)

In [None]:
# Mark the Weights & Biases run as finished
wandb.finish()