In [9]:
import torch
import numpy as np
import lightning as L
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import ToTensor, Normalize
from torch.nn import MSELoss, L1Loss
from lightning.pytorch.loggers import WandbLogger

from dataset_new import SentinelDataset
from plotting import plot_patch, plot_coords_distribution
from utils import get_dataset_stats
from transforms import TargetNormalize
from model import UNet

In [10]:
# Build the dataset from the samples CSV and the Sentinel-2 data directory.
# Each item is passed through ToTensor(); the remaining options are forwarded
# to SentinelDataset unchanged (pre_load=False presumably defers reading
# imagery until an item is requested; confirm in dataset_new).
dataset = SentinelDataset(
    "../data/samples/samples_S2S5P_2018_2020_eea.csv",
    "../data/sentinel-2-eea",
    transform=ToTensor(),
    pre_load=False,
    patch_size=256,
    n_patches=4,
)

In [11]:
# Partition the dataset into train / validation / test subsets (70 / 15 / 15 %).
# A dedicated, explicitly seeded generator keeps the partition identical
# between runs regardless of the global RNG state.
generator = torch.Generator()
generator.manual_seed(42)
dataset_train, dataset_val, dataset_test = random_split(
    dataset,
    lengths=[0.70, 0.15, 0.15],
    generator=generator,
)

In [15]:
# Wrap each subset in a DataLoader. Only the training loader shuffles.
# Train and validation loaders use 9 worker processes that are kept alive
# between epochs; the test loader loads in the main process.
batch_size = 16
dataloader_train = DataLoader(
    dataset=dataset_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=9,
    persistent_workers=True,
)
dataloader_val = DataLoader(
    dataset=dataset_val,
    batch_size=batch_size,
    num_workers=9,
    persistent_workers=True,
)
dataloader_test = DataLoader(dataset=dataset_test, batch_size=batch_size)

In [12]:
# Normalization statistics for the training dataset.
# The values are hard-coded so the notebook does not need to rescan the data;
# presumably they come from an earlier run of the call below (re-run it to refresh):
# stats_train = get_dataset_stats(dataset_train)
stats_train = dict(
    # One mean per input band (12 bands).
    band_means=np.array([
        945.33649575, 883.58855028, 668.64586693, 2310.99404037,
        1278.33074938, 1972.10197878, 2223.2728232, 2376.76157865,
        2052.98432657, 1548.28168202, 562.55420622, 2380.19286924,
    ]),
    # One standard deviation per input band, in the same order as band_means.
    band_stds=np.array([
        530.8678687, 432.87239649, 405.77116269, 814.35871003,
        431.60858877, 556.86715088, 659.36416425, 705.73567854,
        566.20799311, 555.80665277, 216.07139462, 564.50506163,
    ]),
    # Scalar statistics of the NO2 target measurements.
    no2_mean=20.683807196968576,
    no2_std=11.520772291494632,
)

In [14]:
# Normalizers built from the training statistics:
#  - band_normalize standardizes each input band with its own mean / std,
#  - no2_normalize is applied to the NO2 targets (and offers .revert(), used by the model).
band_normalize = Normalize(
    mean=stats_train["band_means"],
    std=stats_train["band_stds"],
)
no2_normalize = TargetNormalize(
    stats_train["no2_mean"],
    stats_train["no2_std"],
)

In [22]:
# Define Pytorch lightning model

class Model(L.LightningModule):
    """Lightning wrapper that trains a UNet against point NO2 measurements.

    The network produces a prediction map per patch; only the pixel at each
    sample's coordinate (channel 0) is compared with the measurement.
    The loss (MSE) is computed on normalized values, while the reported MAE is
    computed after reverting the normalization.

    Relies on the notebook-level ``band_normalize`` and ``no2_normalize``
    objects being defined before any step is executed.

    Args:
        lr: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, lr=1e-3):
        super().__init__()

        # Set model
        self.model = UNet()

        # Set hyperparameters
        self.loss = MSELoss()
        self.mae = L1Loss()
        self.lr = lr

        # Record the constructor arguments (here: lr) with the checkpoint/logger.
        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        """Run one training batch, log ``train_loss`` / ``train_mae``, return the loss."""
        return self._logged_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, log ``val_loss`` / ``val_mae``, return the loss."""
        return self._logged_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Run one test batch, log ``test_loss`` / ``test_mae``, return the loss."""
        return self._logged_step(batch, "test")

    def _logged_step(self, batch, stage):
        """Compute loss and MAE for ``batch`` and log them under ``<stage>_loss`` / ``<stage>_mae``."""
        loss, mae = self._step(batch)
        self.log(f"{stage}_loss", loss)
        self.log(f"{stage}_mae", mae)
        return loss

    def _step(self, batch):
        """Forward a batch and return ``(loss, mae)``.

        Args:
            batch: Tuple ``(patches, measurements, coords)``. ``coords[0]`` and
                ``coords[1]`` are used as per-sample indices into the last two
                axes of the prediction map (assumed to be one index tensor of
                length batch-size each, as produced by default collation —
                TODO confirm against SentinelDataset).

        Returns:
            loss: MSE between normalized predictions and normalized measurements.
            mae: Mean absolute error between de-normalized predictions and the
                raw measurements.
        """
        # Unpack batch
        patches, measurements, coords = batch

        # Normalize band and measurement data
        patches_norm = band_normalize(patches)
        measurements_norm = no2_normalize(measurements)

        # Get normalized predictions
        predictions_norm = self.model(patches_norm)

        # Extract, for sample i, the channel-0 value at (coords[0][i], coords[1][i]).
        # Pairing an explicit batch index with the coordinates gathers exactly one
        # value per sample, instead of building the full batch x batch matrix of
        # every sample at every coordinate and discarding all but its diagonal.
        sample_idx = torch.arange(
            predictions_norm.shape[0], device=predictions_norm.device
        )
        target_values_norm = predictions_norm[sample_idx, 0, coords[0], coords[1]]

        # Compute loss on normalized data
        loss = self.loss(target_values_norm, measurements_norm)

        # Compute Mean Absolute Error on unnormalized data
        target_values = no2_normalize.revert(target_values_norm)
        mae = self.mae(target_values, measurements)

        return loss, mae

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [23]:
# Instantiate Model
# Seed the Python, NumPy and torch RNGs first so that weight initialisation
# (and subsequent shuffling) is reproducible across kernel restarts;
# workers=True lets Lightning derive deterministic seeds for DataLoader workers.
L.seed_everything(42, workers=True)
model = Model()

In [20]:
# Weights & Biases logger; metrics logged by the model are sent to the
# "IMP-2023" project.
wandb_logger = WandbLogger(
    project="IMP-2023",
)

In [21]:
# Short training run: a single epoch capped at 100 training batches,
# with validation triggered every 20 training batches and metrics
# logged on every step to Weights & Biases.
trainer = L.Trainer(
    logger=wandb_logger,
    max_epochs=1,
    limit_train_batches=100,
    val_check_interval=20,
    log_every_n_steps=1,
)
trainer.fit(model, dataloader_train, dataloader_val)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrubengaviles[0m ([33mimp-2023[0m). Use [1m`wandb login --relogin`[0m to force relogin



  | Name  | Type    | Params
----------------------------------
0 | model | UNet    | 31.0 M
1 | loss  | MSELoss | 0     
2 | mae   | L1Loss  | 0     
----------------------------------
31.0 M    Trainable params
0         Non-trainable params
31.0 M    Total params
124.148   Total estimated model params size (MB)


Epoch 0:   1%|          | 1/100 [00:21<35:58,  0.05it/s, v_num=fwq7]       

/Users/rubengonzalez/miniconda3/envs/imp/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
