In [1]:
import os
import torch
import numpy as np
import pandas as pd
import lightning as L

from torchvision.transforms.v2 import ToImage, Compose
from torch.nn import MSELoss, L1Loss, CrossEntropyLoss
from torch.utils.data import DataLoader
from torcheval.metrics import R2Score

from core.dataset import SentinelDataset
from core.model import UNet
from core.utils import normalize_rgb_bands
from core.transforms import BandNormalize, TargetNormalize

In [2]:
# Parameters
DATA_DIR = "data"
DATA_SOURCE = "eea"
SKIP_CONNECTIONS = False
LAND_COVER = False
CHECKPOINT = "models/unet_ae_s2s5p/ae_no2.ckpt"
PRED_SIZE = 8
PATCH_SIZE = 128

In [3]:
# Random seed for splitting
SEED = 42
L.seed_everything(42, workers=True)

Seed set to 42


42

In [4]:
SAMPLES_PATH = os.path.join(
    DATA_DIR, f"samples/samples_S2S5P_2018_2020_{DATA_SOURCE}.csv"
)

In [5]:
samples_df = pd.read_csv(SAMPLES_PATH, index_col="idx")
# Remove NA measurements
samples_df = samples_df[~samples_df["no2"].isna()]

# Exclude samples for which no valid land cover ground truth is present ~200
if DATA_SOURCE== "eea":
    valid_land_cover_stations = []
    land_cover_path = os.path.join(DATA_DIR, "worldcover")
    for file in os.listdir(land_cover_path):
        lc = np.load(os.path.join(land_cover_path, file))
        if lc.shape == (200, 200):
            valid_land_cover_stations.append(file[:-4])

    samples_df = samples_df.loc[
        samples_df["AirQualityStation"].isin(valid_land_cover_stations)
    ]

In [6]:
# Random shuffle
samples_df = samples_df.sample(frac=1)
# Split samples dataframe to avoid sampling patches across sets
if DATA_SOURCE== "eea":
    df_train, df_val, df_test = np.split(
        samples_df, [int(0.7 * len(samples_df)), int(0.85 * len(samples_df))]
    )
else:
    df_test = samples_df

  return bound(*args, **kwds)


In [7]:
stats_train = {
    "band_means": np.array(
        [
            9.48530855e02,
            8.85735776e02,
            6.69150641e02,
            2.31917082e03,
            1.28305729e03,
            1.97936703e03,
            2.23097768e03,
            2.38542771e03,
            2.06128535e03,
            1.55304775e03,
            5.61707384e02,
            2.38793057e03,
            2.78282474e15,
        ]
    ),
    "band_stds": np.array(
        [
            6.69703046e02,
            5.34104387e02,
            4.92931981e02,
            1.00871675e03,
            5.90429095e02,
            7.41831336e02,
            8.81957446e02,
            9.48550975e02,
            8.09822953e02,
            7.57415252e02,
            3.27955892e02,
            8.74079961e02,
            1.36616750e15,
        ]
    ),
    "no2_mean": 20.973578214241755,
    "no2_std": 11.575741710970245,
}



# Create normalizers for bands and NO2 measurements
band_normalize = BandNormalize(stats_train["band_means"], stats_train["band_stds"])
no2_normalize = TargetNormalize(stats_train["no2_mean"], stats_train["no2_std"])

# Create transforms for images and measurements
s2_transform = Compose([ToImage(), band_normalize])
no2_transform = no2_normalize

In [9]:
# Create Test Dataset
dataset_test = SentinelDataset(
    samples_df,
    DATA_DIR,
    n_patches=1,
    patch_size=PATCH_SIZE,
    pred_size=PRED_SIZE,
    pre_load=False,
    s2_transform=s2_transform,
    no2_transform=no2_transform,
    data_source=DATA_SOURCE
)

In [10]:
dataloader_test = DataLoader(dataset_test, batch_size=8, shuffle=False)

In [11]:
class_weights = [
    1.88037399e02,
    1.64854777e00,
    5.85467415e01,
    3.36591268e00,
    6.51197767e00,
    1.00000000e00,
    9.56108398e01,
    3.37424575e06,
    1.38434725e01,
    3.68938629e02,
    2.18333547e05,
]


# Define Pytorch lightning model
class Model(L.LightningModule):
    def __init__(self, model, lr, include_lc, lc_loss_weight, lc_class_weights):
        super().__init__()

        # Set model
        self.model = model

        # Set hyperparameters
        self.no2_loss = MSELoss()
        self.no2_mae = L1Loss()
        self.lc_loss = CrossEntropyLoss(weight=torch.tensor(lc_class_weights))
        self.include_lc = include_lc
        self.lc_loss_weight = lc_loss_weight
        self.lr = lr

    def training_step(self, batch, batch_idx):
        no2_loss, no2_mae, lc_loss = self._step(batch)
        total_loss = no2_loss + (self.lc_loss_weight * lc_loss)
        self.log("train_no2_loss", no2_loss)
        self.log("train_no2_mae", no2_mae)
        self.log("train_lc_loss", lc_loss)
        self.log("train_total_loss", total_loss)
        loss = total_loss if self.include_lc else no2_loss
        return loss

    def validation_step(self, batch, batch_idx):
        no2_loss, no2_mae, lc_loss = self._step(batch, batch_idx == 0)
        total_loss = no2_loss + (self.lc_loss_weight * lc_loss)
        self.log("val_no2_loss", no2_loss)
        self.log("val_no2_mae", no2_mae)
        self.log("val_lc_loss", lc_loss)
        self.log("val_total_loss", total_loss)
        loss = total_loss if self.include_lc else no2_loss
        return loss

    def test_step(self, batch, batch_idx):
        no2_loss, no2_mae, lc_loss, no2_mse, r2_score = self._step(batch,test=True)
        total_loss = no2_loss + (self.lc_loss_weight * lc_loss)
        self.log("test_no2_loss", no2_loss)
        self.log("test_no2_mae", no2_mae)
        self.log("test_no2_mse", no2_mse)
        self.log("test_no2_r2", r2_score)
        self.log("test_lc_loss", lc_loss)
        self.log("test_total_loss", total_loss)
        loss = total_loss if self.include_lc else no2_loss
        return loss

    def _step(self, batch, log_predictions=False, test=False):
        # Unpack batch
        patches_norm, lc_truth, measurements_norm, coords = batch

        # Get normalized predictions
        predictions_norm, land_cover_pred = self.model(patches_norm)

        # Extract values in coordinate location
        target_values_norm = torch.diag(predictions_norm[:, 0, coords[0], coords[1]])

        # Compute loss on normalized data
        no2_loss = self.no2_loss(target_values_norm, measurements_norm)

        # Center crop
        # lc_truth = CenterCrop(land_cover_pred.shape[-2:])(lc_truth)
        lc_loss = self.lc_loss(land_cover_pred, lc_truth)

        # Compute Mean Absolute Error on unnormalized data
        measurements = no2_normalize.revert(measurements_norm)
        target_values = no2_normalize.revert(target_values_norm)
        no2_mae = self.no2_mae(target_values, measurements)

        if test:
            no2_mse = self.no2_loss(target_values, measurements)
            metric = R2Score()
            metric.update(target_values, measurements)
            r2_score = metric.compute()
            return no2_loss, no2_mae, lc_loss, no2_mse, r2_score

        if log_predictions:
            self.logger.log_image(
                "images",
                [
                    torch.moveaxis(normalize_rgb_bands(im.cpu()), 0, 2).numpy()
                    for im in patches_norm[:, :3]
                ],
            )
            self.logger.log_image(
                "predictions", list(no2_normalize.revert(predictions_norm))
            )

        return no2_loss, no2_mae, lc_loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


# Instantiate Model
unet = UNet(
    (13, 64, 128, 256, 512, 1024),
    (1024, 512, 256, 128, 64),
    SKIP_CONNECTIONS,
)

model = Model.load_from_checkpoint(
    checkpoint_path=CHECKPOINT,
    model=unet,
    lr=0.000005,
    include_lc=LAND_COVER,
    lc_loss_weight=0.1,
    lc_class_weights=class_weights,
)


In [12]:
trainer = L.Trainer()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
trainer.test(model, dataloader_test)

Missing logger folder: /Users/rubengonzalez/Coding/IMP-2023/lightning_logs
/Users/rubengonzalez/miniconda3/envs/imp/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 359/359 [03:30<00:00,  1.71it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_lc_loss          2.4139034748077393
      test_no2_loss         0.3697924315929413
      test_no2_mae          4.9932475090026855
      test_no2_mse          49.551387786865234
       test_no2_r2          0.4637247920036316
     test_total_loss        0.6111828088760376
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_no2_loss': 0.3697924315929413,
  'test_no2_mae': 4.9932475090026855,
  'test_no2_mse': 49.551387786865234,
  'test_no2_r2': 0.4637247920036316,
  'test_lc_loss': 2.4139034748077393,
  'test_total_loss': 0.6111828088760376}]