In [50]:
import os
import csv
import tempfile

import numpy as np
import matplotlib.pyplot as plt

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torchgeo.datamodules import CycloneDataModule
from torchgeo.datasets import Sentinel2
from torchgeo.trainers import RegressionTask

In [51]:
# we set a flag to check to see whether the notebook is currently being run by PyTest, if this is the case then we'll
# skip the expensive training.
in_tests = "PYTEST_CURRENT_TEST" in os.environ

In [52]:
# Radiant MLHub API key (generated at https://mlhub.earth/profile; it does not
# expire, so the same key can be reused).
# The key is read from the MLHUB_API_KEY environment variable and must never be
# pasted into the notebook. The earlier version passed the key itself as the
# *name* of the variable to look up, which effectively always returned None and
# exposed the secret to anyone able to read this file.
# Export it before starting Jupyter, e.g. `export MLHUB_API_KEY=<your key>`.
# The value is None when the variable is not set.
MLHUB_API_KEY = os.environ.get("MLHUB_API_KEY")

In [53]:
# The dataset lives in <system temp dir>/cyclone_data; the datamodule is asked
# to download it there.
# NOTE (original author): the download was observed not to happen; the cause
# still has to be investigated.
data_dir = os.path.join(tempfile.gettempdir(), "cyclone_data")

datamodule = CycloneDataModule(
    root_dir=data_dir,
    api_key=MLHUB_API_KEY,
    download=True,
    batch_size=64,
    num_workers=6,
    seed=1337,
)

In [54]:
# A simple, clean way to define a regression task: RegressionTask is configured
# entirely through keyword arguments.
task = RegressionTask(
    learning_rate=0.1,
    learning_rate_schedule_patience=5,
    model="resnet18",
    pretrained=True,
)

In [55]:
# Logger output, experiment results and callback checkpoints all go under
# <system temp dir>/cyclone_results.
experiment_dir = os.path.join(tempfile.gettempdir(), "cyclone_results")

# Checkpoints are written to experiment_dir and are selected on "val_loss".
checkpoint_callback = ModelCheckpoint(
    dirpath=experiment_dir,
    monitor="val_loss",
    save_last=True,
    save_top_k=1,
)

# Early stopping watches the same metric: the validation loss.
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    patience=10,
    min_delta=0.00,
)

# CSV logger rooted at experiment_dir, under the "tutorial_logs" name.
csv_logger = CSVLogger(name="tutorial_logs", save_dir=experiment_dir)

In [56]:
# The Trainer is the object that wires together the callbacks, the logger and
# the run parameters, and on which .fit() is later called.
# fast_dev_run is switched on only when the notebook runs under pytest.
# NOTE(review): the early-stopping patience (10) equals max_epochs (10) --
# confirm that early stopping can actually trigger with these settings.
trainer = pl.Trainer(
    default_root_dir=experiment_dir,
    logger=[csv_logger],
    callbacks=[checkpoint_callback, early_stopping_callback],
    fast_dev_run=in_tests,
    max_epochs=10,
    min_epochs=1,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [57]:
# Fit the regression task on the data provided by the cyclone datamodule.
trainer.fit(datamodule=datamodule, model=task)

RuntimeError: Dataset not found or corrupted. You can use download=True to download it

In [None]:
# From here on it is all visualisation.
# Parse the metrics CSV written by the CSV logger into two series
# (row index, RMSE): one for training and one for validation.
if not in_tests:
    train_steps = []
    train_rmse = []

    val_steps = []
    val_rmse = []

    # Ask the logger for the directory of *this* run instead of hard-coding
    # "version_0": a fixed path reads stale (or missing) metrics as soon as the
    # logger has created a version_1, version_2, ... directory.
    metrics_path = os.path.join(csv_logger.log_dir, "metrics.csv")

    # (CSV column, list receiving the row index, list receiving the value)
    series = (
        ("train_RMSE", train_steps, train_rmse),
        ("val_RMSE", val_steps, val_rmse),
    )

    # newline="" is the mode the csv module documents for files it reads.
    with open(metrics_path, "r", newline="") as f:
        csv_reader = csv.DictReader(f, delimiter=",")
        for i, row in enumerate(csv_reader):
            for column, steps, values in series:
                try:
                    value = float(row[column])
                except (TypeError, ValueError):
                    # ValueError: the cell is empty or not a number on this row.
                    # TypeError: the row is shorter than the header, so
                    # DictReader filled the field with None.
                    continue
                values.append(value)
                steps.append(i)

In [None]:
# Draw the training and validation RMSE curves on one set of axes, using the
# series parsed from the metrics CSV. Skipped when running under pytest.
if not in_tests:
    fig, ax = plt.subplots()
    ax.plot(train_steps, train_rmse, label="Train RMSE")
    ax.plot(val_steps, val_rmse, label="Validation RMSE")
    ax.legend(fontsize=15)
    ax.set_xlabel("Batches", fontsize=15)
    ax.set_ylabel("RMSE", fontsize=15)
    plt.show()
    plt.close(fig)

In [None]:
# Run the trainer's test loop for the task on the cyclone datamodule.
trainer.test(datamodule=datamodule, model=task)