### Selected samples MMGN validation

**Author:** Jakub Walczak, PhD

This notebook contains validation of the MMGN method.

In [7]:
import csv
import shutil
from functools import partial
from pathlib import Path
from typing import Any, Callable

import xarray as xr
from rich.console import Console

import climatrix as cm

%load_ext rich

In [8]:
console = Console()

# Every run-wide setting is echoed so that it is recorded in the cell output.
NAN_POLICY = "resample"
console.print("[bold green]Using NaN policy: [/bold green]", NAN_POLICY)

SEED = 1
console.print("[bold green]Using seed: [/bold green]", SEED)

# The data directory is a sibling of the grandparent of the `__session__` path
# (presumably the notebook file path supplied by the kernel - TODO confirm).
DSET_PATH = Path(__session__).parent.parent / "data"
console.print("[bold green]Using dataset path: [/bold green]", DSET_PATH)

# Bounding box and the dense 0.1-step lat/lon grid spanned by it.
EUROPE_BOUNDS = dict(north=71, south=36, west=-24, east=35)
EUROPE_DOMAIN = cm.Domain.from_lat_lon(
    lat=slice(EUROPE_BOUNDS["south"], EUROPE_BOUNDS["north"], 0.1),
    lon=slice(EUROPE_BOUNDS["west"], EUROPE_BOUNDS["east"], 0.1),
    kind="dense",
)
cm.seed_all(SEED)

In [9]:
def get_all_dataset_idx() -> list[str]:
    """Return the sorted, de-duplicated sample ids found in ``DSET_PATH``.

    Every ``*.nc`` file directly inside ``DSET_PATH`` contributes the part of
    its stem that follows the last underscore.
    """
    unique_ids = set()
    for nc_file in DSET_PATH.glob("*.nc"):
        unique_ids.add(nc_file.stem.rsplit("_", 1)[-1])
    return sorted(unique_ids)

In [10]:
def run_single_method(
    d: str, i: int, method: str, reconstruct_dense: bool = True, **params
):
    """Fit ``method`` on one sample and reconstruct it on the validation domain.

    Parameters
    ----------
    d : str
        Sample id used to build the file names
        ``ecad_obs_europe_train_{d}.nc`` and ``ecad_obs_europe_val_{d}.nc``
        under ``DSET_PATH``.
    i : int
        Not used by the function; kept so existing positional calls still work.
    method : str
        Reconstruction method name forwarded to ``reconstruct``.
    reconstruct_dense : bool
        When true, additionally reconstruct on ``EUROPE_DOMAIN``.
    **params
        Extra keyword arguments forwarded unchanged to both ``reconstruct``
        calls.

    Returns
    -------
    tuple
        ``(val_dset, reconstructed_dset, reconstructed_dense)``. The last
        element is ``None`` when ``reconstruct_dense`` is false.
    """
    cm.seed_all(SEED)
    train_dset = xr.open_dataset(
        DSET_PATH / f"ecad_obs_europe_train_{d}.nc"
    ).cm
    val_dset = xr.open_dataset(DSET_PATH / f"ecad_obs_europe_val_{d}.nc").cm
    reconstructed_dset = train_dset.reconstruct(
        val_dset.domain,
        method=method,
        checkpoint="./checkpoint",
        overwrite_checkpoint=True,
        validation=val_dset,
        **params,
    )
    # Bind the name up-front: it used to be assigned only inside the branch
    # below, so `reconstruct_dense=False` failed with UnboundLocalError at the
    # return statement.
    reconstructed_dense = None
    if reconstruct_dense:
        # Same checkpoint path as above, this time without overwriting it.
        reconstructed_dense = train_dset.reconstruct(
            EUROPE_DOMAIN, method=method, checkpoint="./checkpoint", **params
        )
    return val_dset, reconstructed_dset, reconstructed_dense

In [11]:
dset_idx = get_all_dataset_idx()
# Report how many sample ids were discovered on disk.
console.print(
    "[bold green]There is [bold yellow]"
    f"{len(dset_idx)}"
    "[/bold yellow] samples available [/bold green]"
)

In [12]:
# Position in `dset_idx` of the sample used by the cells below.
IDX = 0

In [15]:
# Single MMGN run on the selected sample with hand-set hyperparameters.
# NOTE(review): lr is ~1.0 and the log printed under this cell shows the
# train/validation loss growing after the first epoch - confirm the value.
mmgn_val_dset, mmgn_reconstructed_dset, mmgn_reconstructed_dense = run_single_method(
    d=dset_idx[IDX],
    i=IDX,
    method="mmgn",
    lr=1.000398223348225,
    weight_decay=0.01,
    batch_size=1024,
    hidden_dim=512,
    latent=115,
    n_layers=1,
    scale=122,
    alpha=1.0,
    device="cuda",
)

11-09-2025 12:35:11 INFO | climatrix.reconstruct.nn.base_nn | Using checkpoint path: /home/jakub/projects/climatrix/experiments/jwalczak/01_Apr_02_compare_recon_method/notebooks/checkpoint
11-09-2025 12:35:11 INFO | climatrix.reconstruct.mmgn.mmgn | Initializing MMGN model...
11-09-2025 12:35:11 INFO | climatrix.reconstruct.mmgn.mmgn | Configuring Adam optimizer with learning rate: 1.000398
11-09-2025 12:35:11 INFO | climatrix.reconstruct.nn.base_nn | Configuring epoch schedulers...
11-09-2025 12:35:11 INFO | climatrix.reconstruct.nn.base_nn | Training MMGNet model...
11-09-2025 12:35:11 INFO | climatrix.reconstruct.nn.base_nn | Validation dataset is available. Using it for validation.
11-09-2025 12:35:11 INFO | climatrix.reconstruct.nn.base_nn | Epoch 1/1000: train loss = 0.3183 | val loss = 39.7533
11-09-2025 12:35:11 INFO | climatrix.reconstruct.nn.base_nn | Epoch 2/1000: train loss = 40.0149 | val loss = 233956.6406
11-09-2025 12:35:11 INFO | climatrix.reconstruct.nn.base_nn | Epoc

In [None]:
# Reconstruction first, validation data second - the same argument order that
# `run_single_experiment` uses, so both reports are computed the same way.
cm.Comparison(mmgn_reconstructed_dset, mmgn_val_dset).compute_report()

### After optimising hyperparameters

In [7]:
# Search space handed to the optimiser. Tuples look like (low, high) ranges and
# `hidden_dim` like a list of candidate values - TODO confirm against the
# optimiser's documentation.
BOUNDS = dict(
    lr=(1e-5, 10.0),
    weight_decay=(0, 1e-1),
    batch_size=(32, 4096),
    mse_loss_weight=(1e-5, 100),
    eikonal_loss_weight=(0.0, 10.0),
    laplace_loss_weight=(0.0, 10.0),
    scale=(0.01, 100.0),
    hidden_dim=[16, 64, 128, 256],
)
console.print("[bold green]Hyperparameter bounds: [/bold green]", BOUNDS)

OPTIM_INIT_POINTS: int = 50
console.print(
    "[bold green]Using nbr initial points for optimization: [/bold green]",
    OPTIM_INIT_POINTS,
)

OPTIM_N_ITERS: int = 100
console.print(
    "[bold green]Using iterations for optimization[/bold green]",
    OPTIM_N_ITERS,
)
# Echo the sample id the optimisation will run on.
console.print("[bold green]Dataset: [/bold green]", dset_idx[IDX])

In [8]:
def find_hyperparameters(
    train_dset: cm.BaseClimatrixDataset,
    val_dset: cm.BaseClimatrixDataset,
    bounds: dict[str, tuple],
    n_init_points: int = 30,
    n_iter: int = 200,
    seed: int = 0,
    verbose: int = 2,
) -> dict[str, Any]:
    """Run the ``"sinet"`` hyperparameter search and return its result.

    Parameters
    ----------
    train_dset, val_dset
        Training and validation datasets handed to ``cm.optim.HParamFinder``.
    bounds
        Search space forwarded as ``bounds``.
    n_init_points
        Currently not forwarded to the finder; kept for call compatibility.
    n_iter
        Number of optimisation iterations, forwarded as ``n_iters``.
    seed
        Forwarded as ``random_seed``.
    verbose
        Currently not forwarded to the finder; kept for call compatibility.

    Returns
    -------
    dict[str, Any]
        Whatever ``finder.optimize()`` returns; callers in this notebook index
        it with ``"best_params"`` and ``"best_score"``.
    """
    # Use the arguments rather than the module-level BOUNDS / OPTIM_N_ITERS /
    # SEED, which were previously read here and silently overrode the caller.
    finder = cm.optim.HParamFinder(
        "sinet",
        train_dset,
        val_dset,
        metric="mae",
        n_iters=n_iter,
        bounds=bounds,
        random_seed=seed,
        exclude=["num_epochs"],
    )
    return finder.optimize()


def run_single_experiment(d: str):
    """Tune hyperparameters on sample ``d`` and evaluate the tuned model.

    The function loads the train/validation files of sample ``d``, runs
    ``find_hyperparameters``, prints the chosen values, reconstructs the
    validation domain with method ``"sinet"`` and compares the result with the
    validation data.

    Parameters
    ----------
    d : str
        Sample id used to build the file names under ``DSET_PATH``.

    Returns
    -------
    tuple
        ``(metrics, hyperparams)``: the comparison report extended with
        ``"dataset_id"``, and a dict holding ``"dataset_id"``, one entry per
        reported hyperparameter (``None`` when the optimiser did not return
        it) and ``"opt_loss"``.
    """
    # (key in best_params, console label), in reporting order.
    reported = (
        ("lr", "Learning rate (lr):"),
        ("num_epochs", "Number of epochs:"),
        ("scale", "Scale:"),
        ("batch_size", "Batch size:"),
        ("mse_loss_weight", "MSE loss weight:"),
        ("eikonal_loss_weight", "Eikonal loss weight:"),
        ("laplace_loss_weight", "Laplace loss weight:"),
        ("patience", "Early stopping patience:"),
        ("hidden_dim", "Hidden dimension:"),
        ("weight_decay", "Weight decay:"),
    )

    train_dset = xr.open_dataset(
        DSET_PATH / f"ecad_obs_europe_train_{d}.nc"
    ).cm
    val_dset = xr.open_dataset(DSET_PATH / f"ecad_obs_europe_val_{d}.nc").cm
    result = find_hyperparameters(
        train_dset,
        val_dset,
        BOUNDS,
        n_init_points=OPTIM_INIT_POINTS,
        n_iter=OPTIM_N_ITERS,
        seed=SEED,
        verbose=2,
    )
    # Assumes `best_params` is a plain mapping - TODO confirm.
    best_params = result["best_params"]

    console.print("[bold yellow]Optimized parameters:[/bold yellow]")
    for key, label in reported:
        # `num_epochs` is excluded from the search and `patience` has no entry
        # in BOUNDS, so they may be missing; direct indexing would raise
        # KeyError before any model is trained.
        console.print(
            f"[yellow]{label}[/yellow]", best_params.get(key, "not tuned")
        )
    console.print("[yellow]Best loss:[/yellow]", result["best_score"])

    # Forward only what the optimiser actually returned; anything else is left
    # to the defaults of `reconstruct`.
    tuned_kwargs = {
        key: best_params[key] for key, _ in reported if key in best_params
    }
    reconstructed_dset = train_dset.reconstruct(
        val_dset.domain,
        method="sinet",
        device="cuda",
        num_workers=0,
        checkpoint="./mmgn_checkpoint.pth",
        overwrite_checkpoint=True,
        **tuned_kwargs,
    )
    cmp = cm.Comparison(reconstructed_dset, val_dset)
    metrics = cmp.compute_report()
    metrics["dataset_id"] = d

    # Fixed set of keys so that records of different samples line up.
    hyperparams = {"dataset_id": d}
    hyperparams.update((key, best_params.get(key)) for key, _ in reported)
    hyperparams["opt_loss"] = result["best_score"]
    return (metrics, hyperparams)

In [None]:
# Tune on the selected sample; keep both the comparison report and the
# hyperparameters that were chosen.
metrics, hyperparams = run_single_experiment(dset_idx[IDX])

10-09-2025 16:11:59 INFO | climatrix.optim.bayesian | Starting Bayesian optimization for method 'sinet'
10-09-2025 16:11:59 INFO | climatrix.optim.bayesian | Bounds: OrderedDict({'batch_size': (32, 4096, <class 'int'>), 'eikonal_loss_weight': (0.0, 10.0, <class 'float'>), 'hidden_dim': [16, 64, 128, 256], 'laplace_loss_weight': (0.0, 10.0, <class 'float'>), 'lr': (1e-05, 10.0, <class 'float'>), 'mse_loss_weight': (1e-05, 100, <class 'float'>), 'scale': (0.01, 100.0, <class 'float'>), 'weight_decay': (0, 0.1, <class 'float'>)})
10-09-2025 16:11:59 INFO | climatrix.optim.bayesian | Using 100 iterations


  from .autonotebook import tqdm as notebook_tqdm
  sampler = optuna.samplers.GPSampler(
[I 2025-09-10 16:11:59,498] A new study created in memory with name: sinet_study
  0%|                                                                                                             | 0/100 [00:00<?, ?it/s]

10-09-2025 16:11:59 INFO | climatrix.optim.bayesian | Suggested parameters for trial 0: {'batch_size': 1727, 'eikonal_loss_weight': 7.203244934421581, 'hidden_dim': 64, 'laplace_loss_weight': 1.862602113776709, 'lr': 3.455613814823207, 'mse_loss_weight': 39.67675345539225, 'scale': 53.886285232995654, 'weight_decay': 0.041919451440329485}
10-09-2025 16:11:59 INFO | climatrix.reconstruct.sinet.sinet | Initializing SiNET model...
10-09-2025 16:11:59 INFO | climatrix.reconstruct.sinet.sinet | Configuring Adam optimizer with learning rate: 3.455614
10-09-2025 16:12:00 INFO | climatrix.reconstruct.nn.base_nn | Training SiNET model...


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [None]:
# Re-run with the complete tuned configuration. Previously only lr,
# weight_decay, num_epochs, batch_size, mse_loss_weight and hidden_dim were
# forwarded, so scale, the eikonal/laplace loss weights and patience fell back
# to defaults and the run did not match the optimised model.
tuned_params = {
    name: hyperparams[name]
    for name in (
        "lr",
        "weight_decay",
        "num_epochs",
        "batch_size",
        "scale",
        "mse_loss_weight",
        "eikonal_loss_weight",
        "laplace_loss_weight",
        "patience",
        "hidden_dim",
    )
    # Skip anything that is absent or unset rather than passing None through.
    if hyperparams.get(name) is not None
}
mmgn_val_dset, mmgn_reconstructed_dset, mmgn_reconstructed_dense = (
    run_single_method(
        dset_idx[IDX],
        IDX,
        "sinet",
        num_workers=0,
        device="cuda",
        **tuned_params,
    )
)

In [None]:
# Reconstruction on EUROPE_DOMAIN returned by the latest run_single_method call.
mmgn_reconstructed_dense.plot()

In [None]:
# Bind the dataset accessor to `train_dset` and plot it in a separate step;
# the name used to end up holding the return value of `.plot()` instead.
train_dset = xr.open_dataset(
    DSET_PATH / f"ecad_obs_europe_train_{dset_idx[IDX]}.nc"
).cm
train_dset.plot()

In [None]:
# Reconstruction on the validation domain returned by the latest
# run_single_method call.
mmgn_reconstructed_dset.plot()

In [None]:
# Reconstruction first, validation data second - the same argument order that
# `run_single_experiment` uses, so both reports are computed the same way.
cm.Comparison(mmgn_reconstructed_dset, mmgn_val_dset).compute_report()