## Imports

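The following cells install the third-party libraries used in this notebook that are likely missing from your environment (omegaconf, torch-spatiotemporal, scoringrules, the PyTorch Geometric extensions, wandb and xarray). You can install them using pip or conda.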


In [None]:
import locale


def getpreferredencoding(do_setlocale=True):
    """Report "UTF-8" as the preferred encoding, whatever ``do_setlocale`` is.

    Installed below as ``locale.getpreferredencoding`` so every later caller in
    this kernel sees UTF-8.
    NOTE(review): presumably a workaround for a non-UTF-8 default locale in the
    hosting environment (e.g. shell/pip commands failing) — confirm it is still needed.
    """
    return "UTF-8"


setattr(locale, "getpreferredencoding", getpreferredencoding)
!pip install omegaconf

In [None]:
!pip install torch-spatiotemporal

In [None]:
!pip install scoringrules

In [None]:
import torch
!pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster  --y
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

In [None]:
!pip install wandb
!pip install xarray

# Classes

The following cells define the probabilistic output layers, the loss functions and the data handlers used to train the model. Each group lives in its own cell, so it can be modified or extended independently.

## Probabilistic Layers

In [4]:
import torch.nn as nn
import torch
from abc import ABC, abstractmethod
from typing import Literal
import sys
from inspect import getmembers, isclass


class SoftplusWithEps(nn.Module):
    """Softplus activation shifted up by a constant ``eps``.

    The output is always at least ``eps``, so it can serve as a strictly positive
    scale parameter computed from an unconstrained input.

    Args:
        eps: Constant added to the softplus output. Defaults to ``1e-5``.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.softplus = nn.Softplus()

    def forward(self, x):
        positive = self.softplus(x)
        return positive + self.eps


class DistributionLayer(nn.Module, ABC):
    """Abstract head that maps features to a ``torch.distributions`` object.

    Subclasses provide:
      * ``name`` – the class name inside ``torch.distributions`` (e.g. ``'Normal'``),
      * ``num_parameters`` – output width of the linear encoder,
      * ``process_params`` – turns the encoder output into a distribution instance.

    Args:
        input_size: Size of the last dimension of the input features.
    """

    def __init__(self, input_size):
        super().__init__()
        # Look the distribution class up by name, e.g. torch.distributions.LogNormal.
        self.distribution = getattr(torch.distributions, self.name)
        self.encoder = nn.Linear(input_size, self.num_parameters)

    @property
    @abstractmethod
    def num_parameters(self):
        """Number of raw values produced by the linear encoder."""

    @property
    @abstractmethod
    def name(self):
        """Name of the distribution class in ``torch.distributions``."""

    @abstractmethod
    def process_params(self, x):
        """Build a distribution from the raw encoder output ``x``."""

    def forward(self, x, return_type: Literal['samples', 'distribution']='distribution', reparametrized=True, num_samples=1):
        """Encode ``x`` and return either the distribution or samples from it.

        Args:
            x: Input features; the last dimension must equal ``input_size``.
            return_type: ``'distribution'`` returns the distribution object; any
                other value returns samples.
            reparametrized: Draw with ``rsample`` (differentiable) if True,
                otherwise with ``sample``.
            num_samples: Number of draws; becomes the leading dimension of the samples.
        """
        dist = self.process_params(self.encoder(x))
        if return_type == 'distribution':
            return dist
        draw = dist.rsample if reparametrized else dist.sample
        return draw((num_samples,))


class LogNormalLayer(DistributionLayer):
    """Head producing a ``torch.distributions.LogNormal``.

    The encoder emits two values per position. The first is used as ``loc``
    (mean of the underlying normal). The second goes through ``SoftplusWithEps``
    to obtain a positive ``scale``. Both keep a trailing dimension of size 1.
    """

    _name = 'LogNormal'

    def __init__(self, input_size):
        super().__init__(input_size=input_size)
        self.get_positive_std = SoftplusWithEps()

    @property
    def name(self):
        return self._name

    @property
    def num_parameters(self):
        return 2

    def process_params(self, x):
        """Map the raw ``(..., 2)`` encoder output to a LogNormal with ``(..., 1)`` loc/scale."""
        loc = x[..., 0:1]
        scale = self.get_positive_std(x[..., 1:2])
        return self.distribution(loc, scale)


class NormalLayer(DistributionLayer):
    """Head producing a ``torch.distributions.Normal``.

    The encoder emits two values per position. The first is the mean. The second
    goes through ``SoftplusWithEps`` to obtain a positive standard deviation. Both
    keep a trailing dimension of size 1.
    """

    _name = 'Normal'

    def __init__(self, input_size):
        super().__init__(input_size=input_size)
        self.get_positive_std = SoftplusWithEps()

    @property
    def name(self):
        return self._name

    @property
    def num_parameters(self):
        return 2

    def process_params(self, x):
        """Map the raw ``(..., 2)`` encoder output to a Normal with ``(..., 1)`` loc/scale."""
        mean = x[..., 0:1]
        std = self.get_positive_std(x[..., 1:2])
        return self.distribution(mean, std)


# All concrete DistributionLayer subclasses visible in this module's namespace,
# and a lookup from their `_name` (the torch.distributions class name) to the class.
prob_layers = [
    cls
    for cls_name, cls in getmembers(sys.modules[__name__], isclass)
    if issubclass(cls, DistributionLayer) and cls_name != 'DistributionLayer'
]
dist_to_layer = {layer_cls._name: layer_cls for layer_cls in prob_layers}


## Losses
The `MaskedCRPSLogNormal` class is a custom loss function that computes the closed-form Continuous Ranked Probability Score (CRPS) of the log-normal predictive distribution produced by the Multi-Scale Graph WaveNet. The CRPS measures the accuracy of probabilistic forecasts and is particularly useful for evaluating probabilistic models. The remaining classes were adapted from the provided code.

In [3]:
import torch
import torch.nn as nn
import scoringrules as sr
from torch.distributions import Normal, Independent, LogNormal
import math


class MaskedCRPSNormal(nn.Module):
    """Closed-form CRPS of a Normal predictive distribution, ignoring NaN targets.

    For ``Y ~ N(mu, sigma)``, observation ``y`` and ``z = (y - mu) / sigma``::

        CRPS = sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi))

    where ``Phi`` / ``phi`` are the standard normal CDF / PDF.
    """

    def __init__(self):
        super(MaskedCRPSNormal, self).__init__()

    def forward(self, pred, y):
        """Return the mean CRPS over the non-NaN entries of ``y``.

        Args:
            pred: Distribution exposing ``loc`` and ``scale`` tensors that can be
                indexed with a boolean mask of ``y``'s shape (assumed same shape as ``y``).
            y: Observations; NaN entries are excluded.

        Returns:
            Scalar tensor (NaN when ``y`` has no valid entries).
        """
        mask = ~torch.isnan(y)
        y = y[mask]
        mu = pred.loc[mask].flatten()
        sigma = pred.scale[mask].flatten()

        std_normal = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(sigma))
        z = (y - mu) / sigma
        Phi = std_normal.cdf(z)
        phi = torch.exp(std_normal.log_prob(z))

        crps = sigma * (z * (2 * Phi - 1) + 2 * phi - 1.0 / math.sqrt(math.pi))
        return crps.mean()

class MaskedCRPSLogNormal(nn.Module):
    """Closed-form CRPS of a Log-Normal predictive distribution over non-NaN targets.

    Source: Baran and Lerch (2015), "Log-normal distribution based Ensemble Model
    Output Statistics models for probabilistic wind-speed forecasting". With
    ``omega = (log y - mu) / sigma``::

        CRPS = y * (2*Phi(omega) - 1)
               - 2*exp(mu + sigma^2/2) * (Phi(omega - sigma) + Phi(sigma/sqrt(2)) - 1)

    Layout: ``pred.loc`` and ``pred.scale`` must hold ``B*N*L`` values ordered as
    ``(B, N, L, 1)`` (node-major). The target ``y`` is ``(B, L, N, 1)``, or
    ``(B, 1, N, 1)`` when a single lead time ``t`` is scored.
    """

    def __init__(self):
        super(MaskedCRPSLogNormal, self).__init__()
        self.i = 0        # number of forward calls
        self.eps = 1e-5   # added to targets: log(0) is undefined for Y ~ LogNormal

    def forward(self, pred, y, t=None, L=None):
        """Return the mean CRPS over valid (non-NaN) targets.

        Args:
            pred: Distribution exposing ``loc`` and ``scale`` (layout in class docstring).
            y: Targets; NaN entries are ignored.
            t: Optional lead-time index. When given, only that lead time of the
                prediction is scored against ``y``.
            L: Number of lead times in ``pred``; only read when ``t`` is given.
                Defaults to ``pred.loc.numel() // (B * N)``.

        Returns:
            Scalar tensor (NaN if ``y`` has no valid entries).
        """
        if t is None:
            B, L, N, _ = y.shape
        else:
            B, _, N, _ = y.shape
            if L is None:
                # Infer the number of lead times instead of assuming a fixed value.
                L = pred.loc.numel() // (B * N)

        mask = ~torch.isnan(y)
        # Boolean indexing returns a copy, so the caller's tensor is not modified.
        y = y[mask] + self.eps

        # (B, N, L, 1) -> (B, L, N, 1) so predictions line up with the targets.
        mu = pred.loc.reshape(B, N, L, 1).permute(0, 2, 1, 3)
        sigma = pred.scale.reshape(B, N, L, 1).permute(0, 2, 1, 3)
        if t is not None:
            mu = mu[:, t].unsqueeze(1)
            sigma = sigma[:, t].unsqueeze(1)
        mu = mu[mask].flatten()
        sigma = sigma[mask].flatten()

        std_normal = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(sigma))
        omega = (torch.log(y) - mu) / sigma

        # exp(mu + sigma^2/2) is the Log-Normal mean. The exponent is clamped at 15
        # (e^15 ≈ 3.3e6) for numerical stability, far above plausible wind speeds.
        ex = 2 * torch.exp(torch.clamp(mu + (sigma ** 2) / 2, max=15))
        self.i += 1

        crps = y * (2 * std_normal.cdf(omega) - 1.0) - ex * (
            std_normal.cdf(omega - sigma) + std_normal.cdf(sigma / (2 ** 0.5)) - 1.0)
        return crps.mean()




class MaskedCRPSEnsemble(nn.Module):
    """Ensemble CRPS (``scoringrules.crps_ensemble``) averaged over non-NaN observations."""

    def __init__(self):
        super(MaskedCRPSEnsemble, self).__init__()

    def forward(self, samples, y):
        """Return the mean ensemble CRPS over entries whose observation is not NaN.

        Args:
            samples: Ensemble forecasts with the member axis last after dropping a
                trailing singleton dimension — assumed ``[batch, time, station, sample]``.
            y: Observations, optionally with a trailing singleton dimension.
        """
        obs = y.squeeze(-1)
        losses = sr.crps_ensemble(obs, samples.squeeze(-1))
        # crps_ensemble reduces the ensemble-member axis (the last one by default),
        # so `losses` has the shape of `obs`. Build the mask from `obs` so both align.
        return losses[~torch.isnan(obs)].mean()


class MaskedMAE(nn.Module):
    """Mean absolute error computed only where the target is finite."""

    def __init__(self):
        super(MaskedMAE, self).__init__()

    def forward(self, pred, target):
        """Return ``(mae, valid)``.

        Args:
            pred: Predictions, same shape as ``target``.
            target: Targets; NaN and +/-inf entries are ignored.

        Returns:
            ``mae``: scalar tensor, NaN when no target is finite.
            ``valid``: 0-dim tensor with the number of finite targets.
            The tuple shape is the same in both cases, so callers can always unpack it.
        """
        mask = torch.isfinite(target)
        valid = mask.sum()

        if valid == 0:
            return torch.tensor(float('nan'), device=pred.device), valid

        abs_err = (pred[mask] - target[mask]).abs()
        return abs_err.mean(), valid

## Data

The following functions load and preprocess the data. Compared with the original code, they were modified to allow changing the graph structure and to mask anomalous target values.

In [7]:
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from tsl.ops.similarities import top_k
import xarray as xr
from typing import Union, Optional
from omegaconf import ListConfig
import numpy as np

class XarrayDataset(Dataset):
    """Torch dataset over xarray predictors/targets laid out as (time, lead, station, var).

    Inputs are standardized per variable over time, lead time and stations.
    Targets are kept in physical units: if they were standardized they could be
    negative, which is incompatible with the Log-Normal CRPS.

    NOTE(review): each split normalizes with its own statistics, i.e. val/test
    are not standardized with the training mean/std — confirm this is intended.

    Args:
        input_data: Predictor dataset; ``to_array().data`` is expected to be
            ``(var, forecast_reference_time, lead_time, station)``.
        target_data: Target dataset with the same layout.
        anomalous: If True, target values outside ``[min_speed, max_speed]`` are set to NaN.
        min_speed: Lower bound of valid targets when ``anomalous`` is True.
        max_speed: Upper bound of valid targets when ``anomalous`` is True.
    """

    def __init__(self, input_data, target_data, anomalous=False, min_speed=0.2, max_speed=10.0):
        # Move the variable axis last: (var, t, l, s) -> (t, l, s, var).
        self.input_data, self.input_denormalizer = self.normalize(
            np.transpose(input_data.to_array().data, (1, 2, 3, 0)))

        # NOTE No transformation is applied to targets. If normalized ~(0,1), they can be negative
        # which is incompatible with CRPS_LogNormal.
        raw_y = np.transpose(target_data.to_array().data, (1, 2, 3, 0))
        if anomalous:
            raw_y = self.mask_anomalous_targets(
                torch.from_numpy(raw_y).float(),
                min_speed=min_speed,
                max_speed=max_speed,
            ).numpy()
            print("masked anomalous data")
        self.target_data = raw_y
        self.target_denormalizer = lambda x: x

        self.t, self.l, self.s, self.f = self.input_data.shape
        self.tg = self.target_data.shape[-1]

    def normalize(self, data):
        """Standardize each variable over axes (0, 1, 2), ignoring NaNs.

        Variables with zero spread are only centered (their std is treated as 1),
        so they become 0 instead of 0/0 = NaN.

        Returns:
            ``(standardized_data, denormalizer)``. ``denormalizer`` inverts the
            transform for numpy arrays or torch tensors (on the tensor's device).
        """
        data_mean = np.nanmean(data, axis=(0, 1, 2), keepdims=True)
        data_std = np.nanstd(data, axis=(0, 1, 2), keepdims=True)
        data_std = np.where(data_std == 0, 1.0, data_std)
        standardized_data = (data - data_mean) / data_std

        def denormalizer(x):  # closure over the statistics of this split
            if isinstance(x, torch.Tensor):
                return (x * torch.Tensor(data_std).to(x.device)) + torch.Tensor(data_mean).to(x.device)
            return (x * data_std) + data_mean

        return standardized_data, denormalizer

    def mask_anomalous_targets(self, y, min_speed, max_speed):
        """Return a copy of ``y`` (any shape) with values outside ``[min_speed, max_speed]`` set to NaN.

        Works element-wise, so the input shape is preserved (including several targets).
        """
        y_clean = y.clone()
        y_clean[(y < min_speed) | (y > max_speed)] = float('nan')
        return y_clean

    def get_baseline_score(self, score_fn):
        """Placeholder; not implemented (returns None)."""
        pass

    @property
    def stations(self):
        return self.s

    @property
    def forecasting_times(self):
        return self.t

    @property
    def lead_times(self):
        return self.l

    @property
    def features(self):
        return self.f

    @property
    def targets(self):
        return self.tg

    def __len__(self):
        return self.input_data.shape[0]  # Number of forecast_reference_time

    def __getitem__(self, idx):
        """Return ``(x, y)`` float tensors for one forecast reference time."""
        sample_x = self.input_data[idx]
        sample_y = self.target_data[idx]
        return torch.tensor(sample_x, dtype=torch.float), torch.tensor(sample_y, dtype=torch.float)

    def __str__(self):
        return f"Dataset: [time={self.t}, lead_time={self.l}, stations={self.s}, features={self.f}] | target dim={self.tg}\n"



def get_graph(lat, lon, knn=10, threshold=None, theta=None):
    """Build a station graph from great-circle distances and a Gaussian kernel.

    Edge weights are ``exp(-(d / theta)^2)`` for the haversine distance ``d`` in km.
    ``top_k`` then keeps ``knn`` entries per node (self-loops included), and
    weights below ``threshold`` are optionally zeroed.

    Args:
        lat: 1-D sequence of station latitudes in degrees.
        lon: 1-D sequence of station longitudes in degrees (same length as ``lat``).
        knn: Number of entries per node kept by ``top_k``.
        threshold: Optional minimum edge weight, applied after ``top_k``.
        theta: Kernel bandwidth in km, or one of these rules:
            ``None`` / ``"std"`` (std of all pairwise distances, diagonal included),
            ``"median"`` (median off-diagonal distance),
            ``"factormedian"`` (half of that median).

    Returns:
        The ``(n, n)`` adjacency returned by ``top_k``, thresholded if requested.

    Raises:
        ValueError: If ``theta`` is a string other than the rules above.
    """

    def haversine_matrix(lat_deg, lon_deg, radius=6371):
        # Pairwise haversine distances (km) via NumPy broadcasting instead of a
        # Python double loop over all station pairs.
        phi = np.radians(np.asarray(lat_deg, dtype=np.float64))
        lam = np.radians(np.asarray(lon_deg, dtype=np.float64))
        dphi = phi[None, :] - phi[:, None]
        dlam = lam[None, :] - lam[:, None]
        a = np.sin(dphi / 2) ** 2 + np.cos(phi)[:, None] * np.cos(phi)[None, :] * np.sin(dlam / 2) ** 2
        # Guard against rounding pushing `a` marginally outside [0, 1].
        a = np.clip(a, 0.0, 1.0)
        return radius * (2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a)))

    def bandwidth(dist):
        # Resolve `theta` to a numeric kernel width.
        if theta is None or theta == "std":
            return dist.std()
        if isinstance(theta, str):
            rows, cols = np.triu_indices(dist.shape[0], k=1)  # strictly off-diagonal
            median = np.median(dist[rows, cols])
            if theta == "median":
                return median
            if theta == "factormedian":
                return median * 0.5
            raise ValueError(
                f"Unknown theta {theta!r}; expected a number, None, 'std', 'median' or 'factormedian'")
        return theta

    dist = haversine_matrix(lat, lon)
    adj = np.exp(-np.square(dist / bandwidth(dist)))

    adj = top_k(adj, knn, include_self=True, keep_values=True)

    if threshold is not None:
        adj[adj < threshold] = 0

    return adj

class PostprocessDatamodule:
    """Bundle of the train/val/test datasets plus an optional adjacency matrix.

    ``num_edges`` is the number of non-zero entries of ``adj_matrix`` (0 when no
    matrix is given).
    """

    def __init__(self, train_dataset: XarrayDataset,
                 val_dataset: XarrayDataset,
                 test_dataset: XarrayDataset,
                 adj_matrix: np.ndarray = None):
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.adj_matrix = adj_matrix
        if adj_matrix is None:
            self.num_edges = 0
        else:
            self.num_edges = (self.adj_matrix != 0).astype(np.int32).sum()

    def __str__(self) -> str:
        pieces = ["Data Module: \n\n"]
        for header, split in (("Train:\n", self.train_dataset),
                              ("Val:\n", self.val_dataset),
                              ("Test:\n", self.test_dataset)):
            pieces.append(header)
            pieces.append(str(split))
        pieces.append(f"Number of edges = {self.num_edges}")
        return "".join(pieces)

def get_datamodule(ds: xr.Dataset,
                   ds_targets: xr.Dataset,
                   predictors: Union[list, ListConfig],
                   lead_time_hours: int,
                   val_split: float,
                   target_var: str,
                   test_start_date: str,
                   train_val_end_date: Optional[str] = None,
                   return_graph=True,
                   graph_kwargs=None,
                   anomalous = False) -> PostprocessDatamodule:
    """Build chronological train/val/test datasets and, optionally, a station graph.

    Args:
        ds (xr.Dataset): The input dataset.
        ds_targets (xr.Dataset): The target dataset.
        predictors (Union[list, ListConfig]): The variable names to be used as predictors.
        lead_time_hours (int): Lead times up to and including this many hours are kept.
        val_split (float): Fraction in [0, 1) of the train/val period used for
            validation. The validation set is the later part in index order.
        target_var (str): The (single) target variable.
        test_start_date (str): First forecast reference date of the test set.
        train_val_end_date (Optional[str], optional): Last forecast reference date of
            train and validation. Defaults to ``test_start_date`` minus one day. Pick a
            date that leaves a large enough gap to avoid data leakage.
        return_graph (bool, optional): If True, build an adjacency matrix with
            ``get_graph`` from ``ds.latitude`` / ``ds.longitude``. Defaults to True.
        graph_kwargs (dict, optional): Keyword arguments forwarded to ``get_graph``.
            Defaults to None (no extra arguments).
        anomalous (bool, optional): Forwarded to every ``XarrayDataset`` to mask
            out-of-range targets. Defaults to False.

    Returns:
        PostprocessDatamodule: A datamodule with the train/val/test splits.
    """
    if isinstance(predictors, ListConfig):
        predictors = list(predictors)
    test_datetime = np.datetime64(test_start_date)
    if train_val_end_date is None:
        train_val_datetime = test_datetime - np.timedelta64(1, 'D')
    else:
        train_val_datetime = np.datetime64(train_val_end_date)

    print(f'Train&Val sets end at {train_val_datetime}')
    print(f'Test set starts at {test_datetime}')

    max_lead = np.timedelta64(lead_time_hours, 'h')
    train_val_period = slice(None, train_val_datetime)
    test_period = slice(test_datetime, None)

    # Inputs and targets restricted to the same lead-time window.
    input_data = ds[predictors].sel(lead_time=slice(None, max_lead))
    target_data = ds_targets[[target_var]].sel(lead_time=slice(None, max_lead))

    input_train_val = input_data.sel(forecast_reference_time=train_val_period)
    target_train_val = target_data.sel(forecast_reference_time=train_val_period)

    # The split index is derived from the inputs and reused for the targets.
    # NOTE(review): assumes inputs and targets share forecast_reference_time values.
    split_index = int(len(input_train_val['forecast_reference_time']) * (1.0 - val_split))
    train_idx = slice(0, split_index)
    val_idx = slice(split_index, None)

    adj_matrix = None
    if return_graph:
        # graph_kwargs defaults to None; unpacking None directly would raise TypeError.
        adj_matrix = get_graph(lat=ds.latitude.data, lon=ds.longitude.data, **(graph_kwargs or {}))

    return PostprocessDatamodule(
        train_dataset=XarrayDataset(input_data=input_train_val.isel(forecast_reference_time=train_idx),
                                    target_data=target_train_val.isel(forecast_reference_time=train_idx),
                                    anomalous=anomalous),
        val_dataset=XarrayDataset(input_data=input_train_val.isel(forecast_reference_time=val_idx),
                                  target_data=target_train_val.isel(forecast_reference_time=val_idx),
                                  anomalous=anomalous),
        test_dataset=XarrayDataset(input_data=input_data.sel(forecast_reference_time=test_period),
                                   target_data=target_data.sel(forecast_reference_time=test_period),
                                   anomalous=anomalous),
        adj_matrix=adj_matrix)

# Models

## Multi-Scale Graph WaveNet

The following classes implement the Multi-Scale Graph WaveNet (MSGWN) model.

In [5]:
"""
Multi-Scale Graph WaveNet (MSGWN) — distribution ready
=======================================================
Spatio-temporal model combining multi-kernel dilated temporal convolutions with
graph diffusion convolutions. It emits either a point forecast or a Log-Normal
predictive distribution. The distribution head is ``LogNormalLayer``, defined in
the "Probabilistic Layers" cell, which must be run before this one.

Arguments
---------
* ``output_dist`` – ``None`` (default, or any value other than ``'LogNormal'``):
                    returns a tensor; for ``horizon=1`` its shape is ``(B, T, N)``.
                  – ``'LogNormal'``: returns a ``torch.distributions.LogNormal``
                    whose ``loc``/``scale`` have shape ``(B, N, T, 1)``;
                    ``horizon`` is then ignored.

* ``adj_matrix`` / ``learnable_adj`` – if ``adj_matrix`` is given and
  ``learnable_adj=False`` it is used as a fixed buffer. Otherwise a
  ``LearnableAdjacency`` is used, initialized from ``adj_matrix`` when provided
  and randomly otherwise.

* ``history_len`` is accepted but not used by the model.

Example
-------
```python
model = MultiScaleGraphWaveNet(
            num_nodes=N, in_channels=F, history_len=T,
            adj_matrix=dm.adj_matrix, output_dist='LogNormal')

dist    = model(torch.randn(B, F, N, T))  # LogNormal, loc shape (B, N, T, 1)
samples = dist.rsample()                  # (B, N, T, 1)
```
"""
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import LogNormal
################################################################################
# ─────────────────── Distribution helper ──────────────────────────────────────
################################################################################
class SoftplusWithEps(nn.Softplus):
    """``nn.Softplus`` whose output is shifted up by ``eps`` (default ``1e-6``).

    NOTE(review): this re-defines ``SoftplusWithEps`` from the Probabilistic Layers
    cell (an ``nn.Module`` with ``eps=1e-5`` there). ``LogNormalLayer`` resolves the
    global name when it is constructed. Layers built after this cell has run
    therefore use this version with ``eps=1e-6``. Consider keeping a single definition.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        shifted = F.softplus(x, self.beta, self.threshold)
        return shifted + self.eps

################################################################################
# Graph modules (unchanged up to STBlock)
################################################################################
class LearnableAdjacency(nn.Module):
    """Adaptive adjacency ``softmax(E1 @ E2.T, dim=1)`` from two node-embedding tables.

    Args:
        n: Number of nodes.
        emb_dim: Width of the embeddings ``E1`` and ``E2`` (shape ``(n, emb_dim)``).
        init_A: Optional ``(n, n)`` non-negative prior adjacency. When given, the
            embeddings come from a truncated SVD of ``log(init_A + 1e-9)``:
            ``E1 = U_k sqrt(S_k)`` and ``E2 = V_k sqrt(S_k)``. Then
            ``E1 @ E2.T`` is the rank-k approximation of that log-matrix, and the
            initial output approximates the row-normalized prior. When
            ``emb_dim > n`` the extra embedding columns are zero.
    """

    def __init__(self, n, emb_dim=10, init_A: Optional[np.ndarray] = None):
        super().__init__()
        self.E1 = nn.Parameter(torch.randn(n, emb_dim))
        self.E2 = nn.Parameter(torch.randn(n, emb_dim))
        if init_A is not None:
            A = np.asarray(init_A, dtype=np.float64) + 1e-9  # avoid log(0)
            U, S, Vt = np.linalg.svd(np.log(A), full_matrices=False)
            k = min(emb_dim, S.shape[0])
            root_s = np.sqrt(S[:k])
            # Left and right factors must differ: U S U^T only equals log(A) when
            # log(A) is symmetric positive semi-definite, which is not generally true.
            left = np.zeros((n, emb_dim))
            right = np.zeros((n, emb_dim))
            left[:, :k] = U[:, :k] * root_s
            right[:, :k] = Vt[:k].T * root_s
            self.E1.data.copy_(torch.tensor(left, dtype=torch.float32))
            self.E2.data.copy_(torch.tensor(right, dtype=torch.float32))

    def forward(self):
        """Return the row-stochastic ``(n, n)`` adjacency."""
        return F.softmax(self.E1 @ self.E2.T, dim=1)

class GraphConv(nn.Module):
    """K-support diffusion graph convolution.

    For every support ``S_k`` (k < K) the node signal is diffused along the graph
    (``x @ S_k``) and then mixed across channels with its own weight ``theta[k]``.
    The K results are summed.

    Args:
        c_in: Input channels.
        c_out: Output channels.
        K: Number of supports used; ``supports`` must contain at least K matrices.
    """

    def __init__(self, c_in, c_out, K=2):
        super().__init__()
        self.K = K
        self.theta = nn.Parameter(torch.empty(K, c_in, c_out))
        nn.init.xavier_uniform_(self.theta)

    def forward(self, x, supports: List[torch.Tensor]):
        """x: (B, C_in, N); supports[k]: (N, N); returns (B, C_out, N)."""
        total = 0.0
        for hop in range(self.K):
            diffused = torch.einsum('bcn,nm->bcm', x, supports[hop])        # (B, C_in, N)
            total = total + torch.einsum('bcm,cd->bdm', diffused, self.theta[hop])  # (B, C_out, N)
        return total


class HistoryDropout(nn.Module):
    """During training, zero one random contiguous block of time steps per sample.

    The mask is shared across channels and nodes of a sample. In eval mode, or
    when ``p <= 0``, the input is returned unchanged.
    """

    def __init__(self, p: float = 0.05, block_size: int = 12):
        """
        p          – probability *per sample* of zeroing out a block
        block_size – length of the contiguous time-block to mask (clamped to the
                     sequence length T if it is longer)
        """
        super().__init__()
        self.p = p
        self.block_size = block_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, C, N, T) -> same shape."""
        if not self.training or self.p <= 0:
            return x

        B, _, _, T = x.shape
        # A block longer than the sequence would make randint's range empty; clamp it.
        block = min(self.block_size, T)
        mask = x.new_ones((B, 1, 1, T))
        for b in range(B):
            if torch.rand((), device=x.device) < self.p:
                # Random start index so the whole block fits inside [0, T).
                t0 = int(torch.randint(0, T - block + 1, (), device=x.device))
                mask[b, :, :, t0:t0 + block] = 0.0
        return x * mask

class KernelDilatedTCN(nn.Module):
    """Gated dilated temporal convolutions for one kernel size and several dilations.

    After a 1×1 channel pre-mix, each dilation ``d`` in ``dil`` has a branch
    (BatchNorm2d → Dropout2d → (1, k) conv with dilation d). The branch output is
    split into tanh/sigmoid gated halves. The time axis is padded symmetrically
    (NOT causally), so every branch keeps length T. The gated outputs are
    concatenated, merged by a 1×1 conv and passed through ReLU. A 1×1 residual
    projection of the input is then added.

    Args:
        c_in: Input channels.
        c_out: Output channels.
        k: Temporal kernel size.
        dil: Tuple of dilations, one branch each.
        drop: Dropout2d probability inside each branch.
    """

    def __init__(self, c_in: int, c_out: int, k: int, dil: tuple, drop: float):
        super().__init__()
        self.kernel_size = k
        self.pre = nn.Conv2d(c_in, c_out, kernel_size=1)  # 1×1 pre-mix
        branch_modules = []
        for d in dil:
            branch_modules.append(nn.Sequential(
                nn.BatchNorm2d(c_out),
                nn.Dropout2d(drop),
                # 2*c_out outputs: split into filter and gate halves in forward().
                nn.Conv2d(c_out, 2 * c_out, (1, k), dilation=(1, d), padding=0),
            ))
        self.branches = nn.ModuleList(branch_modules)
        self.dil = dil
        self.merge = nn.Conv2d(len(dil) * c_out, c_out, 1)  # one c_out slice per dilation
        self.res = nn.Conv2d(c_in, c_out, kernel_size=1)    # residual projection

    def forward(self, x: torch.Tensor):
        """x: (B, C_in, N, T) -> (B, C_out, N, T)."""
        z = self.pre(x)
        gated = []
        for branch, d in zip(self.branches, self.dil):
            total_pad = (self.kernel_size - 1) * d
            left = total_pad // 2
            # Split the padding between both ends of the time axis.
            padded = F.pad(z, (left, total_pad - left))
            filt, gate = branch(padded).chunk(2, dim=1)
            gated.append(torch.tanh(filt) * torch.sigmoid(gate))
        merged = F.relu(self.merge(torch.cat(gated, dim=1)))
        return merged + self.res(x)

class DynamicKernelDilatedTCN(nn.Module):
    """Gated dilated TCN whose dilation branches are mixed by input-dependent weights.

    Each dilation has a BatchNorm2d → Dropout2d → (1, k) conv branch with
    tanh/sigmoid gating and symmetric (non-causal) padding that keeps length T.
    A gate (global average pool → 1×1 conv → ReLU → 1×1 conv) applied to the
    pre-mixed input yields one logit per branch. The softmaxed logits weight the
    sum of the branches. This is followed by a 1×1 projection, a 1×1 residual
    projection of the input and a final ReLU.

    Args:
        c_in: Input channels.
        c_out: Output channels.
        k: Temporal kernel size.
        dilations: Tuple of dilations, one branch each.
        drop: Dropout2d probability inside each branch.
        r: Reduction factor of the gate bottleneck, which has
            ``max(1, c_out // r)`` channels.
    """

    def __init__(self, c_in, c_out, k, dilations: tuple, drop: float = 0.1, r: int = 8):
        super().__init__()
        self.kernel_size = k
        self.dilations = dilations

        # 1×1 pre-mix
        self.pre = nn.Conv2d(c_in, c_out, 1)

        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.BatchNorm2d(c_out),
                nn.Dropout2d(drop),
                nn.Conv2d(c_out, 2 * c_out, (1, k), dilation=(1, d), padding=0)
            )
            for d in dilations
        ])

        # Keep at least one hidden channel: with c_out < r, c_out // r is 0 and the
        # gate logits would reduce to constant biases, independent of the input.
        hidden = max(1, c_out // r)
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),                # (B, c_out, 1, 1)
            nn.Conv2d(c_out, hidden, 1),            # squeeze
            nn.ReLU(),
            nn.Conv2d(hidden, len(dilations), 1)    # one logit per branch
        )

        self.project = nn.Conv2d(c_out, c_out, 1)
        self.res = nn.Conv2d(c_in, c_out, 1)

    def forward(self, x):
        """x: (B, c_in, N, T) -> (B, c_out, N, T)."""
        z = self.pre(x)

        # 1) gated branch outputs, each (B, c_out, N, T)
        outs = []
        for conv, d in zip(self.branches, self.dilations):
            pad = (self.kernel_size - 1) * d
            z_pad = F.pad(z, (pad // 2, pad - pad // 2))
            p, q = conv(z_pad).chunk(2, dim=1)
            outs.append(torch.tanh(p) * torch.sigmoid(q))

        # 2) per-sample branch weights, softmax over the branch dim: (B, num_branches, 1, 1)
        gate_weights = F.softmax(self.gate(z), dim=1)

        # 3) weighted sum of branches; each weight broadcasts over C, N, T
        y = 0
        for i, out_i in enumerate(outs):
            y = y + out_i * gate_weights[:, i:i + 1, :, :]

        # 4) final projection + residual
        y = self.project(y)
        return F.relu(y + self.res(x))

class MultiScaleTCN(nn.Module):
    """Parallel temporal columns, one per kernel size, fused by a 1×1 conv.

    Each column is a ``DynamicKernelDilatedTCN`` (``dynamic=True``) or a
    ``KernelDilatedTCN``. Column outputs are concatenated on the channel axis,
    fused and passed through ReLU. A 1×1 residual projection of the input is then added.
    """

    def __init__(self, c_in, c_out, kernels: tuple, dil, drop, dynamic=False):
        super().__init__()
        column_cls = DynamicKernelDilatedTCN if dynamic else KernelDilatedTCN
        self.cols = nn.ModuleList([column_cls(c_in, c_out, k, dil, drop) for k in kernels])
        self.fuse = nn.Conv2d(len(kernels) * c_out, c_out, 1)
        self.res = nn.Conv2d(c_in, c_out, 1)

    def forward(self, x):
        """x: (B, c_in, N, T) -> (B, c_out, N, T)."""
        stacked = torch.cat([column(x) for column in self.cols], dim=1)
        return F.relu(self.fuse(stacked)) + self.res(x)

class STBlock(nn.Module):
    """Spatio-temporal block: multi-scale TCN followed by a graph convolution.

    Returns ``(relu(gcn_out + res(x)), skip(gcn_out))``.
    """

    def __init__(self, c_in, channels, kernels, dil, drop, dynamic):
        super().__init__()
        self.tcn = MultiScaleTCN(c_in, channels, kernels, dil, drop, dynamic)
        self.gcn = GraphConv(channels, channels)
        self.skip = nn.Conv2d(channels, channels, 1)
        self.res = nn.Conv2d(c_in, channels, 1)

    def forward(self, x, supports):
        """x: (B, c_in, N, T); supports: list of (N, N) matrices."""
        temporal = self.tcn(x)
        B, C, N, T = temporal.shape
        # Fold time into the batch so the graph conv sees (B*T, C, N).
        per_step = temporal.permute(0, 3, 1, 2).reshape(B * T, C, N)
        spatial = self.gcn(per_step, supports).view(B, T, C, N).permute(0, 2, 3, 1)
        return F.relu(spatial + self.res(x)), self.skip(spatial)
################################################################################
# ─────────────────── MSGWN main class (distribution‑aware) ────────────────────
################################################################################
class MultiScaleGraphWaveNet(nn.Module):
    """Multi-Scale Graph WaveNet over station time series.

    The input ``x`` is ``(B, in_channels, num_nodes, T)``.
      1. Learned node embeddings are concatenated to the input channels.
      2. The result is projected to ``channels``.
      3. It passes through ``layers`` ``STBlock`` layers.
      4. Their skip outputs are concatenated and fused into the output head.

    Args:
        num_nodes: Number of graph nodes (stations); must equal N of the input.
        in_channels: Input features per node and time step.
        history_len: Accepted but not used by the model.
        horizon: Output channels of the point-forecast head (ignored for 'LogNormal').
        layers: Number of ``STBlock`` layers.
        channels: Hidden channel width.
        emb_dim: Embedding size of the learnable adjacency.
        node_emb_dim: Size of the node embeddings concatenated to the input.
        adj_matrix: Optional ``(N, N)`` prior adjacency.
        learnable_adj: If False and ``adj_matrix`` is given, use it as a fixed buffer.
            Otherwise learn the adjacency, initialized from ``adj_matrix`` when given.
        kernels, dil, drop, dynamic: Temporal-convolution settings for every ``STBlock``.
        edge_drop_p: Probability of zeroing each adjacency entry while training.
        history_dropout_p, history_block: ``HistoryDropout`` settings.
        output_dist: ``'LogNormal'`` for a predictive distribution; any other value
            (default ``None``) gives the point-forecast head.
    """

    def __init__(self,
                 num_nodes: int,
                 in_channels: int,
                 history_len: int,
                 horizon: int = 1, # only for point forecasts
                 layers: int = 4,
                 channels: int = 32,
                 emb_dim: int = 10,
                 node_emb_dim: int = 10,
                 adj_matrix: Optional[np.ndarray] = None,
                 learnable_adj: bool = True,
                 kernels = (1, 3, 5, 7),
                 dil = (1, 2, 4, 8),
                 drop = 0.1,
                 edge_drop_p = 0.0,
                 history_dropout_p: float = 0.05,
                 history_block:   int   = 12,
                 dynamic=True,
                 output_dist: Optional[str] = None):
        super().__init__()
        self.horizon = horizon
        self.edge_drop_p = edge_drop_p
        self.num_nodes = num_nodes
        self.output_dist = output_dist

        # Node embeddings, concatenated to the input features in forward().
        self.node_embeddings = nn.Embedding(num_nodes, node_emb_dim)
        nn.init.xavier_uniform_(self.node_embeddings.weight)

        # Adjacency: fixed buffer, or learnable (optionally initialized from adj_matrix).
        if adj_matrix is not None and not learnable_adj:
            self.register_buffer('A_fixed', torch.tensor(adj_matrix, dtype=torch.float32))
            self.adj = None
        else:
            self.A_fixed = None
            self.adj = LearnableAdjacency(num_nodes, emb_dim, init_A=adj_matrix)

        self.history_dropout = HistoryDropout(p=history_dropout_p,
                                              block_size=history_block)
        self.in_proj = nn.Conv2d(in_channels + node_emb_dim, channels, 1)
        self.blocks = nn.ModuleList([STBlock(channels, channels, kernels, dil, drop, dynamic) for _ in range(layers)])
        self.skip_proj = nn.Conv2d(layers * channels, channels, 1)
        self.end1 = nn.Conv2d(channels, channels, 1)

        # Output head (built exactly once).
        if output_dist == 'LogNormal':
            # Project to model channels; LogNormalLayer maps them to (mu, sigma).
            self.param_conv = nn.Conv2d(channels, channels, 1)
            self.dist_layer = LogNormalLayer(input_size=channels)
        else:
            # Point forecast: `horizon` output channels.
            self.param_conv = nn.Conv2d(channels, horizon, 1)

    def _supports(self):
        """Return ``[A, A @ A]`` (first- and second-order supports).

        While training with ``edge_drop_p > 0``, each entry of A is zeroed with
        probability ``edge_drop_p`` before A @ A is formed.
        """
        A = self.A_fixed if self.A_fixed is not None else self.adj()
        if self.training and self.edge_drop_p > 0:
            keep = torch.bernoulli(A.new_full(A.shape, 1.0 - self.edge_drop_p))
            A = A * keep
        return [A, A @ A]

    def forward(self, x):
        """Run the model.

        Args:
            x: ``(B, C_in, N, T)`` with ``N == num_nodes``.

        Returns:
            ``'LogNormal'``: a LogNormal distribution with loc/scale ``(B, N, T, 1)``.
            Otherwise: ``(B, T, N)`` when ``horizon == 1``, else ``(B, T, N, horizon)``.
        """
        B, _, _, T = x.shape

        # Node embeddings (N, E) -> (B, E, N, T) so they can be stacked on channels.
        node_idx = torch.arange(self.num_nodes, device=x.device)
        n_emb = self.node_embeddings(node_idx).unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, T)
        n_emb = n_emb.permute(0, 2, 1, 3)

        x = self.history_dropout(x)
        x = self.in_proj(torch.cat([x, n_emb], dim=1))

        # Computed once per call so every block shares the same (edge-dropped) graph.
        supports = self._supports()
        skips = []
        for blk in self.blocks:
            x, s = blk(x, supports)
            skips.append(s)
        x = F.relu(self.skip_proj(torch.cat(skips, dim=1)))
        x = F.relu(self.end1(x))
        params = self.param_conv(x)  # (B, out_channels, N, T)

        if self.output_dist == 'LogNormal':
            # (B, channels, N, T) -> (B, N, T, channels): features last for the Linear encoder.
            return self.dist_layer(params.permute(0, 2, 3, 1))

        # (B, horizon, N, T) -> (B, T, N, horizon); drop the last axis when horizon == 1.
        out = params.permute(0, 3, 2, 1)
        return out.squeeze(-1) if self.horizon == 1 else out

# Load Data

In [9]:
import os
import xarray as xr

# ------------------------------------------------------------------
# 1. Basic parameters (formerly in the YAML)
# ------------------------------------------------------------------
nwp_model      = "ch2"
d_map          = {"ch2": 96}          # hours that correspond to each NWP model
hours_leadtime = d_map[nwp_model]     # passed to the DataModule as lead_time_hours

val_split            = 0.20
test_start_date      = "2024-05-16"
train_val_end_date   = "2023-09-30"
target_var           = "obs:wind_speed"

predictors = [
    f"{nwp_model}:wind_speed_ensavg",
    f"{nwp_model}:wind_speed_ensstd",
    f"{nwp_model}:mslp_difference_GVE_GUT_ensavg",
    f"{nwp_model}:mslp_difference_BAS_LUG_ensavg",
    "time:sin_hourofday",
    "time:cos_hourofday",
    "time:sin_dayofyear",
    "time:cos_dayofyear",
    "terrain:elevation_50m",
    "terrain:distance_to_alpine_ridge",
    "terrain:tpi_2000m",
    "terrain:std_2000m",
    "terrain:valley_norm_2000m",
    "terrain:sn_derivative_500m",
    "terrain:sn_derivative_2000m",
    "terrain:we_derivative_500m",
    "terrain:we_derivative_2000m",
    "terrain:sn_derivative_100000m",
]

graph_kwargs = {
    "knn": 5,
    "threshold": 0.6,
    "theta": "std"
}

# ------------------------------------------------------------------
# 2. Location of the NetCDF feature / target files
# ------------------------------------------------------------------
# The folder must contain features.nc and targets.nc. It is taken from the
# DATA_BASE_FOLDER environment variable when that is set, and defaults to ./data.
data_path = os.environ.get("DATA_BASE_FOLDER", "./data")

# ------------------------------------------------------------------
# 3. Load datasets and build the DataModule
# ------------------------------------------------------------------
ds = xr.open_dataset(f"{data_path}/features.nc")
ds_targets = xr.open_dataset(f"{data_path}/targets.nc")

dm = get_datamodule(
    ds=ds,
    ds_targets=ds_targets,
    val_split=val_split,
    test_start_date=test_start_date,
    train_val_end_date=train_val_end_date,
    lead_time_hours=hours_leadtime,
    predictors=predictors,
    target_var=target_var,
    return_graph=True,
    graph_kwargs=graph_kwargs,
)


  ds = xr.open_dataset(f"{data_path}/features.nc")
  ds_targets = xr.open_dataset(f"{data_path}/targets.nc")


Train&Val sets end at 2023-09-30
Test set starts at 2024-05-16


Outer Loop Progress: 100%|██████████| 152/152 [00:00<00:00, 521.86it/s]


In [10]:
import xarray as xr
from tsl.ops.connectivity import adj_to_edge_index
import torch
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf

# Graph connectivity in COO form; the edge weights (w_ij) are not used yet.
adj_matrix = dm.adj_matrix
edge_index, edge_weight = adj_to_edge_index(adj=torch.tensor(adj_matrix))

# One loader per split; only the training split is shuffled.
batch_size = 32
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split, batch_size=batch_size, shuffle=shuffle)
    for split, shuffle in (
        (dm.train_dataset, True),
        (dm.val_dataset, False),
        (dm.test_dataset, False),
    )
)

# Sanity check: every split must cover the same stations.
assert dm.train_dataset.stations == dm.val_dataset.stations == dm.test_dataset.stations


# Hyperparameter search
The following cells contain the hyperparameter search code. Each trial is tracked with Weights and Biases (wandb), so the user needs a wandb account and must specify the project and entity to which the runs are logged.

In [None]:
from typing import List, Optional, Dict, Any 
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import LogNormal
from tqdm.auto import tqdm
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts, SequentialLR
import wandb
import itertools 
import random


# Lead times (indices along the target's time axis) that are scored separately
# during validation.
_VAL_LEAD_TIMES = (1, 24, 48, 96)

# The "Sequential" schedule runs OneCycleLR for this many epochs, then switches
# to one-epoch cosine warm restarts.
_SEQUENTIAL_ONECYCLE_EPOCHS = 100

# Alternative spellings accepted for config['scheduler'] (the grid search used
# the class name for the cosine schedule).
_SCHEDULER_ALIASES = {"CosineAnnealingWarmRestarts": "Cosine"}


def _build_scheduler(name, optimizer, max_lr, epochs, steps_per_epoch):
    """Create the learning-rate scheduler called ``name``.

    Args:
        name: "OneCycleLR", "Reduce", "Cosine", "Sequential", or
            None / "None" / "constant" for a fixed learning rate.
        optimizer: optimizer whose parameter groups are scheduled.
        max_lr: peak learning rate of the one-cycle schedules.
        epochs: number of training epochs (horizon of "OneCycleLR").
        steps_per_epoch: optimizer steps (batches) per epoch.

    Returns:
        The scheduler, ``[one_cycle, cosine]`` for "Sequential", or ``None``
        for a constant learning rate.

    Raises:
        ValueError: for an unknown scheduler name (previously such names were
            silently trained with a constant learning rate).
    """
    if name in (None, "None", "constant"):
        return None
    if name == "OneCycleLR":
        return OneCycleLR(
            optimizer,
            max_lr=max_lr,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            div_factor=25,
            pct_start=0.5
        )
    if name == "Reduce":
        return ReduceLROnPlateau(
            optimizer,
            factor=0.7,
            patience=5,            # give it more time before cutting LR
            cooldown=2,            # wait 2 epochs after a LR drop
            min_lr=1e-7,
            threshold=5e-4,        # slightly looser threshold
            threshold_mode='rel'
        )
    if name == "Cosine":
        # Stepped once per batch, so T_0 = steps_per_epoch gives one-epoch cycles.
        return CosineAnnealingWarmRestarts(
            optimizer,
            T_0=steps_per_epoch,
            T_mult=1,               # keep each cycle the same length
            eta_min=1e-5,           # floor learning-rate
        )
    if name == "Sequential":
        one_cycle = OneCycleLR(
            optimizer,
            max_lr=max_lr,
            epochs=_SEQUENTIAL_ONECYCLE_EPOCHS,
            steps_per_epoch=steps_per_epoch,
            pct_start=0.5,
            anneal_strategy='cos'
        )
        cosine = CosineAnnealingWarmRestarts(
            optimizer,
            eta_min=1e-6,
            T_0=steps_per_epoch,    # 1-epoch cycles (stepped per batch)
            T_mult=1
        )
        return [one_cycle, cosine]
    raise ValueError(f"Unknown scheduler: {name!r}")


def _step_batch_scheduler(name, scheduler, epoch):
    """Advance the per-batch schedulers by one optimizer step.

    "Reduce" is stepped once per epoch on the validation loss instead, and a
    constant learning rate has nothing to step.
    """
    if name in ("OneCycleLR", "Cosine"):
        scheduler.step()
    elif name == "Sequential":
        one_cycle, cosine = scheduler
        # Strict `<`: OneCycleLR raises once stepped past epochs * steps_per_epoch.
        if epoch < _SEQUENTIAL_ONECYCLE_EPOCHS:
            one_cycle.step()
        else:
            cosine.step()


def _predictive_mean_btn(pred_dist, batch, lead, stations):
    """Return the predictive mean of ``pred_dist`` reshaped to (B, T, N).

    The model's distribution is laid out station-first (B, N, T[, 1]) while the
    targets are time-first (B, T, N, 1).
    """
    return pred_dist.mean.reshape(batch, stations, lead).permute(0, 2, 1)


def _safe_ratio(total, count):
    """``total / count`` or NaN when nothing was counted."""
    return total / count if count else float("nan")


def train_evaluate_model(config: Dict[str, Any],
                         dm: Any,
                         train_loader: torch.utils.data.DataLoader,
                         val_loader: torch.utils.data.DataLoader,
                         global_settings: Dict[str, Any]):
    """Train one model for a hyper-parameter configuration and return its best val MAE.

    The DataModule and loaders are rebuilt from ``global_settings['data_path']``
    because the graph depends on ``config`` (knn / threshold / theta / anomalous).
    The ``dm``, ``train_loader`` and ``val_loader`` arguments are therefore not
    used; they are kept for interface compatibility.

    Each epoch logs train / validation CRPS and MAE (overall and at the lead
    times in ``_VAL_LEAD_TIMES``) to Weights & Biases. The weights with the
    lowest validation MAE are saved to ``f"{run_name}_best_model.pt"``.

    Args:
        config: hyper-parameters of this trial (see the grid-search space).
        global_settings: fixed settings: 'num_stations', 'num_predictors',
            'hours_leadtime', 'batch_size', 'data_path', optional 'epochs'
            (default 100) and optional 'wandb_entity' / 'wandb_project'.

    Returns:
        The best epoch-averaged validation MAE.

    Raises:
        ValueError: if ``config['scheduler']`` is not a known scheduler name.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == 'cpu':
        torch.set_num_threads(16)

    N = global_settings['num_stations']
    P = global_settings['num_predictors']
    L = global_settings['hours_leadtime']

    seed = config['seed']
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # for deterministic reductions
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # ---- data: rebuilt per trial because the graph is a hyper-parameter ----
    graph_kwargs = {
        "knn": config["knn"],
        "threshold": config["threshold"],
        "theta": config["theta"]
    }
    data_path = global_settings["data_path"]
    ds = xr.open_dataset(f"{data_path}/features.nc")
    ds_targets = xr.open_dataset(f"{data_path}/targets.nc")
    dm = get_datamodule(
        ds=ds,
        ds_targets=ds_targets,
        val_split=val_split,
        test_start_date=test_start_date,
        train_val_end_date=train_val_end_date,
        lead_time_hours=hours_leadtime,
        predictors=predictors,
        target_var=target_var,
        return_graph=True,
        graph_kwargs=graph_kwargs,
        anomalous=config["anomalous"]
    )

    batch_size = global_settings["batch_size"]
    train_dataloader = DataLoader(dm.train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(dm.val_dataset, batch_size=batch_size, shuffle=False)
    # Schedulers must be sized with the loader actually iterated below.
    steps_per_epoch = len(train_dataloader)

    # ---- model / optimisation ----
    model = MultiScaleGraphWaveNet(
        num_nodes=N,
        in_channels=P,
        history_len=L,
        layers=config['layers'],
        channels=config['channels'],
        emb_dim=config['emb_dim'],
        adj_matrix=dm.adj_matrix,
        output_dist='LogNormal',
        drop=config['drop'],
        edge_drop_p=config['edge_dropout'],
        kernels=config['kernels'],
        dil=config['dil'],
        history_dropout_p=config['hist_drop'],
        history_block=config['history_block'],
        dynamic=config['dynamic'],
        node_emb_dim=config['node_emb_dim']
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=config['lr'], weight_decay=1e-5)
    criterion = MaskedCRPSLogNormal()
    mae = MaskedMAE()

    epochs = global_settings.get('epochs', 100)
    max_lr_scheduler = config.get('max_lr_scheduler', config['lr'] * 10)  # Example: 10x initial LR
    sched_name = _SCHEDULER_ALIASES.get(config['scheduler'], config['scheduler'])
    scheduler = _build_scheduler(sched_name, optimizer, max_lr_scheduler, epochs, steps_per_epoch)

    # Single quotes inside the f-strings: nested same-type quotes need Python 3.12+.
    run_name = (
        f"MSGWN_"
        f"emb{config['emb_dim']}_"
        f"ch{config['channels']}_"
        f"lay{config['layers']}_"
        f"lr{config['lr']:.0e}_"
        f"drop{config['drop']}_"
        f"edge_drop{config['edge_dropout']}_"
        f"dil{config['dil']}_"
        f"kernels{config['kernels']}_"
        f"histdrop{config['hist_drop']}_"
        f"history_block{config['history_block']}_"
        f"scheduler{config['scheduler']}_"
        f"dynamic{config['dynamic']}_"
        f"node_emb_dim{config['node_emb_dim']}_"
        f"seed{config['seed']}_"
        f"knn{config['knn']}_"
        f"threshold{config['threshold']}_"
        f"theta{config['theta']}_"
        f"anomalous{config['anomalous']}"
    )

    # 'wandb_entity' / 'wandb_project' are the keys the settings dict provides;
    # the old 'your_entity' / 'your_project' keys are still honoured.
    wandb.init(entity=global_settings.get('wandb_entity', global_settings.get('your_entity', "your_entity")),
               project=global_settings.get('wandb_project', global_settings.get('your_project', "your_project")),
               name=run_name,
               config=dict(
                   model="MultiScalGraphWavenet",
                   lr=config['lr'],
                   epochs=epochs,
                   max_lr=max_lr_scheduler,
                   n_samples=20,
                   emb_dim=config['emb_dim'],
                   channels=config['channels'],
                   layers=config['layers'],
                   drop=config["drop"],
                   edge_drop=config["edge_dropout"],
                   dil=config["dil"],
                   kernels=config["kernels"],
                   histdrop=config['hist_drop'],
                   history_block=config['history_block'],
                   scheduler=config['scheduler'],
                   dynamic=config['dynamic'],
                   node_emb_dim=config['node_emb_dim'],
                   seed=config['seed'],
                   knn=config["knn"],
                   threshold=config["threshold"],
                   theta=config["theta"],
                   anomalous=config["anomalous"]
               ),
               reinit=True)

    best_val_mae_overall = float("inf")  # MAE over all lead times
    ckpt_path = f"{run_name}_best_model.pt"

    print(f"Training started for: {run_name}. Check WandB.")

    try:
        wandb.watch(model, log="all", log_freq=100)

        for epoch in range(epochs):
            # ---------------- TRAIN ----------------
            model.train()
            train_loss_sum = 0.0
            train_error_sum = 0.0
            train_bar = tqdm(train_dataloader, desc=f"Epoch {epoch} [Train]", leave=False)
            for x_batch, y_batch in train_bar:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                optimizer.zero_grad()

                # Batches arrive time-first; the model expects (B, C, N, T).
                preds = model(x_batch.permute(0, 3, 2, 1))
                loss = criterion(preds, y_batch)

                # MAE of the predictive mean, as a point-forecast diagnostic.
                B, T_out, N_st, _ = y_batch.shape
                pred_mean = _predictive_mean_btn(preds, B, T_out, N_st)
                error, _ = mae(pred_mean, y_batch.squeeze(-1))

                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                _step_batch_scheduler(sched_name, scheduler, epoch)

                train_loss_sum += loss.item()
                train_error_sum += error.item()
                train_bar.set_postfix(
                    loss=f"{loss.item():.4f}",
                    mae=f"{error.item():.4f}"
                )

            wandb.log({
                "loss/train_crps": train_loss_sum / len(train_dataloader),
                "mae/train": train_error_sum / len(train_dataloader),
            }, step=epoch)

            # ---------------- VALIDATION ----------------
            model.eval()
            val_loss_sum = 0.0
            val_mae_sum = 0.0
            sum_crps = dict.fromkeys(_VAL_LEAD_TIMES, 0.0)
            sum_mae = dict.fromkeys(_VAL_LEAD_TIMES, 0.0)
            count_crps = dict.fromkeys(_VAL_LEAD_TIMES, 0.0)
            count_mae = dict.fromkeys(_VAL_LEAD_TIMES, 0.0)

            val_bar = tqdm(val_dataloader, desc=f"Epoch {epoch} [ Valid ]", leave=False)
            with torch.no_grad():
                for x_batch, y_batch in val_bar:
                    x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                    preds = model(x_batch.permute(0, 3, 2, 1))

                    val_loss = criterion(preds, y_batch)
                    B, T_out, N_st, _ = y_batch.shape
                    pred_mean = _predictive_mean_btn(preds, B, T_out, N_st)  # (B, T, N)
                    val_mae, _ = mae(pred_mean, y_batch.squeeze(-1))

                    val_loss_sum += val_loss.item()
                    val_mae_sum += val_mae.item()
                    val_bar.set_postfix(
                        loss=f"{val_loss.item():.4f}",
                        mae=f"{val_mae.item():.4f}"
                    )

                    # Per-lead-time scores, weighted by the number of finite targets.
                    for h in _VAL_LEAD_TIMES:
                        y_h = y_batch[:, h].unsqueeze(1)  # (B, 1, N, 1)
                        crps_h = criterion(preds, y_h, t=h).item()
                        valid = torch.isfinite(y_h).sum().item()
                        sum_crps[h] += crps_h * valid
                        count_crps[h] += valid

                        err_h, n_h = mae(pred_mean[:, h].unsqueeze(1), y_h.squeeze(-1))
                        # Accumulate Python floats so wandb receives plain numbers.
                        sum_mae[h] += float(err_h) * float(n_h)
                        count_mae[h] += float(n_h)

            avg_val_loss = val_loss_sum / len(val_dataloader)
            avg_val_mae = val_mae_sum / len(val_dataloader)
            if sched_name == "Reduce":
                scheduler.step(avg_val_loss)

            metrics = {
                "loss/val_crps": avg_val_loss,
                "mae/val": avg_val_mae,
                "lr": optimizer.param_groups[0]["lr"],
            }
            for h in _VAL_LEAD_TIMES:
                metrics[f"crps_t{h}"] = _safe_ratio(sum_crps[h], count_crps[h])
                metrics[f"mae_t{h}"] = _safe_ratio(sum_mae[h], count_mae[h])
            wandb.log(metrics, step=epoch)

            # ---------------- CHECKPOINT (best only) ----------------
            if avg_val_mae < best_val_mae_overall:
                best_val_mae_overall = avg_val_mae
                torch.save(model.state_dict(), ckpt_path)
                wandb.run.summary["best_epoch"] = epoch
                wandb.run.summary["best_val_mae"] = best_val_mae_overall
                print(f"✅  new best @ {epoch:03d}  MAE={best_val_mae_overall:.4f}")
    finally:
        # Close the run even when training fails, so the next trial starts clean.
        wandb.finish()

    return best_val_mae_overall


def _grid_point(keys, value_lists, index):
    """Return the ``index``-th combination in ``itertools.product`` order as a dict."""
    chosen = [None] * len(keys)
    # product() varies the last iterable fastest, so decode mixed-radix from the end.
    for pos in range(len(keys) - 1, -1, -1):
        index, digit = divmod(index, len(value_lists[pos]))
        chosen[pos] = value_lists[pos][digit]
    return dict(zip(keys, chosen))


def hyperparameter_grid_search(dm, train_loader, val_loader, global_settings,
                               search_space: Optional[Dict[str, List[Any]]] = None,
                               max_trials: Optional[int] = None,
                               sample_seed: int = 0,
                               evaluate_fn=None):
    """Evaluate hyper-parameter combinations and return the best one.

    Combinations of ``search_space`` are visited in ``itertools.product`` order
    and scored with ``evaluate_fn``; a lower validation MAE is better.

    Args:
        dm, train_loader, val_loader, global_settings: forwarded unchanged to
            ``evaluate_fn`` for every trial.
        search_space: mapping hyper-parameter name -> candidate values.
            Defaults to the grid below.
        max_trials: if given and smaller than the grid size, evaluate only this
            many combinations, drawn uniformly without replacement (the default
            grid has ~10^6 points). Sampled points are visited in grid order.
        sample_seed: seed of the private RNG used for that sampling.
        evaluate_fn: callable with ``train_evaluate_model``'s keyword signature
            that returns a validation MAE. Defaults to ``train_evaluate_model``.

    Returns:
        ``(best_hyperparams, best_overall_mae)``; ``(None, inf)`` if no trial ran.
    """
    if search_space is None:
        search_space = {
            'emb_dim': [8,16,32],
            'channels': [16,32,64],
            'layers': [2,3,4],
            'lr': [1e-3,1e-4,5e-4],
            'drop': [0.1, 0.2],
            "edge_dropout": [0.1, 0.2],
            "dil": [(1,2,4,8), (1,2,4)],
            "kernels": [(1,3,5,7), (1,3,5)],
            "hist_drop": [0.05, 0.07, 0.1],
            "history_block": [12, 14, 16],
            # Names must match the branches of train_evaluate_model.
            "scheduler": ["OneCycleLR", "Cosine", "Reduce"],
            "dynamic": [True, False],
            "node_emb_dim": [10, 20, 30],
            'seed': [0,1,2,3,4],
            "knn": [5],
            "threshold": [0.6],
            "theta": ["factormedian"],
            "anomalous": [False]
        }
    if evaluate_fn is None:
        evaluate_fn = train_evaluate_model

    keys = list(search_space)
    value_lists = [list(search_space[k]) for k in keys]
    total = 1
    for values in value_lists:
        total *= len(values)

    # Generate combinations lazily: the full grid is far too large to hold as dicts.
    if max_trials is None or max_trials >= total:
        n_trials = total
        trials = (dict(zip(keys, combo)) for combo in itertools.product(*value_lists))
    else:
        n_trials = max(0, max_trials)
        rng = random.Random(sample_seed)  # private RNG: leaves global random state alone
        picked = sorted(rng.sample(range(total), n_trials))
        trials = (_grid_point(keys, value_lists, idx) for idx in picked)

    best_overall_mae = float('inf')
    best_hyperparams = None

    print(f"Starting grid search with {n_trials} combinations.")

    for i, params in enumerate(trials):
        print(f"\n--- Running trial {i+1}/{n_trials} ---")
        print(f"Parameters: {params}")

        val_mae = evaluate_fn(
            config=params.copy(),  # each trial gets its own config dict
            dm=dm,
            train_loader=train_loader,
            val_loader=val_loader,
            global_settings=global_settings
        )

        if val_mae < best_overall_mae:
            best_overall_mae = val_mae
            best_hyperparams = params
            print(f"🎉 New overall best MAE: {best_overall_mae:.4f} with params: {best_hyperparams}")

    print("\n--- Grid Search Complete ---")
    print(f"Best MAE found: {best_overall_mae:.4f}")
    print(f"Best hyperparameters: {best_hyperparams}")
    return best_hyperparams, best_overall_mae


if __name__ == '__main__':

    # Fixed (non-searched) settings shared by every grid-search trial. Values
    # come from the data-loading cells, so they cannot drift out of sync.
    global_run_settings = {
        "num_stations": dm.train_dataset.stations,
        "num_predictors": len(predictors),
        "hours_leadtime": hours_leadtime,  # input sequence length (history_len) of the model
        "epochs": 100,                     # training epochs per trial; lower it for quick smoke tests
        "wandb_entity": "your_entity",     # YOUR WANDB ENTITY
        "wandb_project": "your_project",   # YOUR WANDB PROJECT
        "plot_every_epochs": 5,
        "batch_size": batch_size,
        "data_path": data_path,
    }

    best_params, best_score = hyperparameter_grid_search(
        dm=dm,
        train_loader=train_dataloader,
        val_loader=val_dataloader,
        global_settings=global_run_settings
    )

    print(f"\nFinal Best Hyperparameters: {best_params}")
    print(f"Final Best Validation MAE (Overall): {best_score}")

# Testing
The following cells contain the testing code. The user can specify the checkpoint path of a specific model to test (the checkpoint name is derived from the config below), together with the data path.

In [11]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
import json

import matplotlib.pyplot as plt
import numpy as np
import torch

import numpy as np
import torch
import matplotlib.pyplot as plt


def make_rank_histograms(pred_dist: torch.distributions.Distribution,
                         y_true     : torch.Tensor,
                         lead_ids   : list[int],
                         n_samples  : int = 20,
                         epoch      : int | str = "final",
                         root_dir   : str = "histograms"):
    """Build, plot and save rank (Talagrand) histograms for selected lead times.

    For each lead time ``t`` in ``lead_ids`` every finite observation is ranked
    among ``n_samples`` draws from ``pred_dist``. The rank is the number of
    samples strictly below the observation (0 … n_samples), which gives
    ``n_samples + 1`` bins. Non-finite observations are ignored.

    Writes one PNG per lead time (``rank_hist_t{t}.png``) and one combined grid
    (``rank_hist_all_leadtimes_ep{epoch}.png``) into ``root_dir``.

    Args:
        pred_dist: predictive distribution. Its samples are assumed to be
            (S, B, N, T[, 1]), i.e. batch shape (B, N, T[, 1]).
        y_true: observations, (B, T, N, 1).
        lead_ids: lead-time indices along T (e.g. [1, 24, 48, 96]).
        n_samples: number of samples drawn per point.
        epoch: tag used in the combined figure's file name.
        root_dir: output folder (created if missing).

    Returns:
        List of ``(t, hist)`` pairs, where ``hist`` is an int array of length
        ``n_samples + 1``.
    """
    y_true = y_true.to(pred_dist.loc.device)
    n_bins = n_samples + 1
    os.makedirs(root_dir, exist_ok=True)

    # Sample without building an autograd graph: (S,B,N,T[,1]) -> (S,B,T,N).
    with torch.no_grad():
        samples = pred_dist.rsample((n_samples,))
    samples = samples.squeeze(-1).permute(0, 1, 3, 2)
    obs = y_true.squeeze(-1)  # (B, T, N)

    hist_list = []
    for t in lead_ids:
        # Flatten batch & station dims: obs (B*N,), samples (S, B*N).
        obs_1d = obs[:, t, :].flatten()
        samp_2d = samples[:, :, t, :].reshape(n_samples, -1)

        # Drop missing observations.
        mask = torch.isfinite(obs_1d)
        obs_1d, samp_2d = obs_1d[mask], samp_2d[:, mask]

        ranks = (samp_2d < obs_1d).sum(dim=0).cpu().numpy()  # values in 0 … n_samples
        hist = np.bincount(ranks, minlength=n_bins)
        hist_list.append((t, hist))

        fig, ax = plt.subplots(figsize=(6, 4))
        ax.bar(np.arange(n_bins), hist, width=0.9, color='steelblue')
        ax.set_xlabel(f"Rank (0 … {n_samples})")
        ax.set_ylabel("Count")
        ax.set_title(f"Rank histogram  –  t = {t}h ")
        fig.tight_layout()
        fname = os.path.join(root_dir, f"rank_hist_t{t}.png")
        fig.savefig(fname)
        plt.close(fig)
        print(f"Saved histogram → {fname}")

    if not hist_list:
        return hist_list

    # Combined figure: two columns, as many rows as needed for all lead times.
    cols = 2
    rows = -(-len(hist_list) // cols)  # ceil division
    fig, axs = plt.subplots(rows, cols, figsize=(6 * cols, 4 * rows), squeeze=False)
    axs = axs.ravel()
    for ax, (t, hist) in zip(axs, hist_list):
        ax.bar(np.arange(n_bins), hist, width=0.9, color="steelblue")
        ax.set_title(f"Rank histogram — lead = {t}h")
        ax.set_xlabel(f"Rank (0 … {n_bins-1})")
        ax.set_ylabel("Count")
    for ax in axs[len(hist_list):]:
        ax.axis("off")  # hide the unused cell of an odd-sized grid
    fig.tight_layout()

    fname = os.path.join(root_dir, f"rank_hist_all_leadtimes_ep{epoch}.png")
    fig.savefig(fname, dpi=150)
    plt.close(fig)
    print(f"Saved combined rank‐histogram → {fname}")

    return hist_list


def log_prediction_plots(x, y, pred_dist,
                         example_indices, stations,
                         epoch, input_denormalizer,
                         root_dir: str = "predictions"):
    """Save a fan chart of predictive quantiles for selected (batch, station) pairs.

    Each panel shows, along the lead-time axis: input feature 0 after
    denormalisation (labelled 'ens_mean'), the 5–95 % and 25–75 % predictive
    bands, the median and the observation. One panel is drawn per pair in
    ``zip(example_indices, stations)``. The grid has two columns and at least
    two rows, growing when more than four pairs are given.

    Args:
        x: inputs, presumably (B, T, N, F). Feature 0 is plotted as 'ens_mean'
            (assumed to be the ensemble-mean predictor; confirm).
        y: observations, (B, T, N, 1).
        pred_dist: predictive distribution supporting ``icdf``. Its batch shape
            is assumed to be (B, N, T, 1), matching the reshapes below.
        example_indices: batch indices to plot.
        stations: station indices, paired element-wise with ``example_indices``.
        epoch: tag used in the output file name.
        input_denormalizer: callable mapping normalised inputs back to physical units.
        root_dir: output folder (created if missing).
    """
    x = input_denormalizer(x).detach().cpu().numpy()
    y = y.detach().cpu().numpy()
    T = y.shape[1]

    # Five predictive quantiles: (B,N,T,1,5) -> (B,N,T,5) -> (B,T,N,5).
    probs = torch.tensor([0.05, 0.25, 0.5, 0.75, 0.95],
                         device=pred_dist.loc.device)
    with torch.no_grad():
        quant = torch.cat([pred_dist.icdf(p).unsqueeze(-1) for p in probs], dim=-1).squeeze(-2)
    quant = quant.permute(0, 2, 1, 3).detach().cpu().numpy()

    pairs = list(zip(example_indices, stations))
    cols = 2
    rows = max(2, -(-len(pairs) // cols))  # ceil division; keeps the 2x2 default layout
    fig, axs = plt.subplots(rows, cols, figsize=(15, 4 * rows), squeeze=False)
    axs = axs.ravel()
    time = np.arange(T)

    for ax, (b_idx, st) in zip(axs, pairs):
        ax.plot(time, x[b_idx, :, st, 0], label='ens_mean', color='forestgreen')
        ax.fill_between(time, quant[b_idx, :, st, 0], quant[b_idx, :, st, 4],
                        alpha=0.15, color='steelblue', label='5–95 %')
        ax.fill_between(time, quant[b_idx, :, st, 1], quant[b_idx, :, st, 3],
                        alpha=0.35, color='steelblue', label='25–75 %')
        ax.plot(time, quant[b_idx, :, st, 2], ls='--', color='black', label='median')
        ax.plot(time, y[b_idx, :, st, 0], color='crimson', label='observed')
        ax.set_title(f'Station {st}   batch {b_idx}')
        ax.set_xlabel('Lead time'); ax.set_ylabel('Wind speed')

    # Legend on the last panel that actually holds data (never an empty axis).
    if pairs:
        axs[len(pairs) - 1].legend(loc='upper left')
    for ax in axs[len(pairs):]:
        ax.axis('off')

    fig.suptitle('Predictions of test data')
    fig.tight_layout()

    os.makedirs(root_dir, exist_ok=True)
    fname = os.path.join(root_dir, f"predictions_epoch_{epoch}.png")
    fig.savefig(fname)
    plt.close(fig)
    print(f"Saved fan-chart → {fname}")


# Hyper-parameters of the trained model(s) to evaluate, as picked by the grid search.
# The list-valued entries ('dynamic', 'seed', 'knn', 'theta') are looped over by the
# checkpoint loop below; every other entry is shared by all evaluated runs.
# NOTE: the key is spelled 'treshold' (sic) because that is what the loop reads.
# NOTE(review): the loop's model_kwargs pass config['dynamic'] (the whole list)
# rather than the loop variable `dynamic` — confirm this is intended.
config = dict(
    emb_dim=16,
    channels=16,
    layers=3,
    lr=1e-3,
    drop=0.2,
    edge_dropout=0.2,
    dil=(1, 2, 4, 8),
    kernels=(1, 3, 5, 7),
    hist_drop=0.07,
    history_block=12,
    scheduler='OneCycleLR',
    dynamic=[True],
    node_emb_dim=20,
    seed=[0],
    knn=[5],
    treshold=0.6,
    theta=["std"],
    anomalous=False,
)

# Data / sequence settings; 'history_len' and 'horizon' feed the model kwargs.
# NOTE(review): the checkpoint loop hard-codes "./data" and batch_size=32 instead
# of reading 'data_path' / 'batch_size' from here — confirm which is intended.
data_cfg = dict(
    data_path='/path/to/data',  # adjust to your dataset location
    batch_size=64,
    num_workers=4,
    history_len=96,             # input sequence length
    horizon=96,                 # output sequence length
)

# Probabilistic loss and point-error metric: the same pair used during training.
criterion, mae = MaskedCRPSLogNormal(), MaskedMAE()


# Build the run name to locate checkpoint
for dynamic in config["dynamic"]:
    for seed in  config["seed"]:
        for knn in config["knn"]:
            for theta in config["theta"]:
                run_name = (
                    f"MSGWN_"
                    f"emb{config['emb_dim']}_"
                    f"ch{config['channels']}_"
                    f"lay{config['layers']}_"
                    f"lr{config['lr']:.0e}_"
                    f"drop{config['drop']}_"
                    f"edge_drop{config['edge_dropout']}_"
                    f"dil{config['dil']}_"
                    f"kernels{config['kernels']}_"
                    f"histdrop{config['hist_drop']}_"
                    f"history_block{config['history_block']}_"
                    f"scheduler{config['scheduler']}_"
                    f"dynamic{dynamic}_"
                    f"node_emb_dim{config['node_emb_dim']}_"
                    f"seed{seed}"
                    #f"knn{knn}_"
                    #f"threshold{config["treshold"]}_"
                    #f"theta{theta}"
                    #f"anomalous{config["anomalous"]}"
                )
                # One checkpoint file per hyper-parameter configuration, keyed by run_name.
                checkpoint_path = f"{run_name}_best_model(1).pt"
                # Echo the resolved checkpoint file and graph setting so each run is traceable
                # in the cell output (print the variable's value, not its name).
                print(f"checkpoint_path: {checkpoint_path}")
                print(f"theta: {theta}")
                graph_kwargs = {
                    "knn": knn,
                    "threshold": config["treshold"],
                    "theta": theta
                }

                data_path = "./data"

                # NOTE(review): this code appears to run inside a loop (header not visible),
                # so both datasets are reopened each iteration and never explicitly closed.
                # Consider opening them once before the loop if get_datamodule allows it.
                ds = xr.open_dataset(f"{data_path}/features.nc")
                ds_targets = xr.open_dataset(f"{data_path}/targets.nc")

                dm = get_datamodule(
                    ds=ds,
                    ds_targets=ds_targets,
                    val_split=val_split,
                    test_start_date=test_start_date,
                    train_val_end_date=train_val_end_date,
                    lead_time_hours=hours_leadtime,
                    predictors=predictors,
                    target_var=target_var,
                    return_graph=True,
                    graph_kwargs=graph_kwargs,
                    anomalous=config["anomalous"]
                )

                test_dataloader = DataLoader(dm.test_dataset, batch_size=32, shuffle=False)

                # Prepare model instantiation kwargs
                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                model_kwargs = {
                    'num_nodes': dm.train_dataset.stations,
                    'in_channels': len(predictors),
                    'history_len': data_cfg['history_len'],
                    'horizon': data_cfg['horizon'],
                    'layers': config['layers'],
                    'channels': config['channels'],
                    'emb_dim': config['emb_dim'],
                    'node_emb_dim': config['node_emb_dim'],
                    'adj_matrix': dm.adj_matrix,
                    'learnable_adj': True,
                    'kernels': config['kernels'],
                    'dil': config['dil'],
                    'drop': config['drop'],
                    'edge_drop_p': config['edge_dropout'],
                    'history_dropout_p': config['hist_drop'],
                    'history_block': config['history_block'],
                    # NOTE(review): run_name is built from the variable `dynamic`, but the model
                    # is built from config['dynamic'] — confirm the two always agree.
                    'dynamic': config['dynamic'],
                    'output_dist': 'LogNormal'
                }
                
                # Define helper functions
                
                def load_model(path, device, model_kwargs, strict=False):
                    """Build a ``MultiScaleGraphWaveNet`` and load weights from ``path``.

                    Args:
                        path: checkpoint file whose loaded contents are passed directly to
                            ``load_state_dict``.
                        device: device the model is moved to and the checkpoint is mapped onto.
                        model_kwargs: keyword arguments forwarded to the model constructor.
                        strict: forwarded to ``load_state_dict``. Defaults to ``False``
                            (unchanged behavior). When non-strict, any missing or unexpected
                            keys are reported via ``warnings.warn`` instead of being dropped.

                    Returns:
                        The model with weights loaded, switched to eval mode.
                    """
                    import warnings

                    model = MultiScaleGraphWaveNet(**model_kwargs).to(device)
                    state = torch.load(path, map_location=device)
                    incompatible = model.load_state_dict(state, strict=strict)
                    # With strict=False a partially matching checkpoint loads silently and any
                    # unmatched layers keep their random initialisation, which would corrupt
                    # evaluation metrics. Surface the mismatch so it is visible.
                    if incompatible.missing_keys or incompatible.unexpected_keys:
                        warnings.warn(
                            f"Checkpoint {path} did not match the model exactly: "
                            f"missing keys={list(incompatible.missing_keys)}, "
                            f"unexpected keys={list(incompatible.unexpected_keys)}"
                        )
                    model.eval()
                    return model
                
                
                def test_model(model, loader, device, lead_times=(1, 24, 48, 96)):
                    """Evaluate ``model`` on every batch of ``loader`` and aggregate metrics.

                    Runs under ``torch.no_grad()``. For each batch it collects the predicted
                    distribution's ``loc``/``scale``, accumulates the enclosing-scope
                    ``criterion`` and ``mae`` over the full target, and accumulates CRPS and
                    MAE separately at each index in ``lead_times``. The first batch is also
                    passed to ``log_prediction_plots``.

                    Args:
                        model: network called as ``model(x)``. It is expected to return a
                            distribution exposing ``loc``, ``scale`` and ``mean``.
                        loader: sized iterable of ``(x_batch, y_batch)`` pairs.
                        device: torch device the batches are moved to.
                        lead_times: indices into the target's second axis at which CRPS and
                            MAE are reported separately (also used as the result dict keys).

                    Returns:
                        Tuple ``(means, stds, val_loss, avg_val_acc, avg_crps, avg_mae, ys)``:
                        - ``means``/``stds``: concatenated ``dist.loc``/``dist.scale`` arrays.
                        - ``val_loss``/``avg_val_acc``: unweighted averages of the per-batch values.
                        - ``avg_crps``/``avg_mae``: dicts mapping each lead time to a float.
                          The value is NaN if no valid targets were seen.
                        - ``ys``: all target batches concatenated along dim 0.

                    Raises:
                        ValueError: if ``loader`` contains no batches.
                    """
                    if len(loader) == 0:
                        raise ValueError("test_model: loader contains no batches")

                    means, stds, all_y = [], [], []
                    val_loss_sum = 0.0
                    val_acc_sum = 0.0
                    sum_crps = {h: 0 for h in lead_times}
                    count_crps = {h: 0 for h in lead_times}
                    sum_mae = {h: 0 for h in lead_times}
                    count_mae = {h: 0 for h in lead_times}

                    with torch.no_grad():
                        for batch_idx, (x_batch, y_batch) in enumerate(loader):
                            x_batch = x_batch.to(device)
                            y_batch = y_batch.to(device)
                            # Swap axes 1 and 3 of the loader layout before calling the model.
                            dist = model(x_batch.permute(0, 3, 2, 1))
                            # Only the targets and the CPU copies of loc/scale are kept. Holding
                            # every batch's distribution would pin its device tensors for the
                            # entire test set.
                            all_y.append(y_batch)
                            means.append(dist.loc.cpu().numpy())
                            stds.append(dist.scale.cpu().numpy())
                            val_loss = criterion(dist, y_batch)

                            # Reshape the predicted mean to the target layout once per batch.
                            # It is invariant across the lead-time loop below.
                            B, L, N, _ = y_batch.shape
                            pred_mean = dist.mean.view(B, N, L, 1).permute(0, 2, 1, 3).squeeze(-1)  # (B, L, N)
                            val_acc, _ = mae(pred_mean, y_batch.squeeze(-1))

                            val_loss_sum += val_loss.item()
                            val_acc_sum += val_acc.item()

                            for h in lead_times:
                                # Select one time step but keep a singleton time axis: (B, 1, N, 1)
                                y_h = y_batch[:, h].unsqueeze(1)
                                crps_mean = criterion(dist, y_h, t=h).item()
                                valid = torch.isfinite(y_h).sum().item()
                                # criterion yields a mean; weight by the valid count to get a sum.
                                sum_crps[h] += crps_mean * valid
                                count_crps[h] += valid

                                error, elements = mae(pred_mean[:, h].unsqueeze(1), y_h.squeeze(-1))
                                sum_mae[h] += error.item() * elements
                                count_mae[h] += elements

                            if batch_idx == 0:
                                # Fan-chart for the first batch only.
                                # NOTE(review): example_indices has 5 entries but stations has 4;
                                # confirm this is intended.
                                log_prediction_plots(
                                    x=x_batch,
                                    y=y_batch,
                                    pred_dist=dist,
                                    example_indices=[0, 0, 0, 0, 0],
                                    stations=[1, 2, 3, 4],
                                    epoch=0,
                                    input_denormalizer=dm.test_dataset.input_denormalizer
                                )

                    n_batches = len(loader)
                    avg_val_loss = val_loss_sum / n_batches
                    avg_val_acc = val_acc_sum / n_batches

                    # Leads with no finite targets get NaN instead of a ZeroDivisionError.
                    avg_crps = {
                        h: float(sum_crps[h] / count_crps[h]) if count_crps[h] else float("nan")
                        for h in lead_times
                    }
                    avg_mae = {
                        h: float(sum_mae[h] / count_mae[h]) if count_mae[h] else float("nan")
                        for h in lead_times
                    }

                    ys = torch.cat(all_y, dim=0)
                    return (
                        np.concatenate(means, axis=0),
                        np.concatenate(stds, axis=0),
                        avg_val_loss,
                        avg_val_acc,
                        avg_crps,
                        avg_mae,
                        ys,
                    )
                
                # Evaluate the selected checkpoint on the held-out test split.
                model = load_model(f"./models/{checkpoint_path}", device, model_kwargs)
                (mean_vals, std_vals, val_loss_sum, avg_val_acc,
                 avg_crps, avg_mae, ys) = test_model(model, test_dataloader, device)

                # Rank histograms at fixed lead times, sampling from a LogNormal rebuilt
                # from the collected loc/scale arrays.
                make_rank_histograms(
                    epoch=0,
                    n_samples=20,
                    lead_ids=[1, 24, 48, 96],
                    y_true=ys,
                    pred_dist=torch.distributions.LogNormal(
                        torch.as_tensor(mean_vals, device=device),
                        torch.as_tensor(std_vals, device=device),
                    ),
                )

                # Persist the distribution parameters for every test sample.
                os.makedirs('outputs', exist_ok=True)
                output_file = os.path.join('outputs', f'{run_name}_test_outputs.npz')
                np.savez_compressed(output_file, mean=mean_vals, std=std_vals)

                # Summary metrics for this run.
                summary = dict(
                    run_name=run_name,
                    val_loss=val_loss_sum,
                    avg_val_acc=avg_val_acc,
                    avg_crps=avg_crps,
                    avg_mae=avg_mae,
                )
                print(summary)

                # One JSON results file per run, named after run_name.
                json_path = os.path.join('outputs', f"{run_name}_results.json")
                with open(json_path, 'w') as jf:
                    json.dump(summary, jf, indent=2)

                print(f"Saved test outputs to {output_file}")


  ds = xr.open_dataset(f"{data_path}/features.nc")
  ds_targets = xr.open_dataset(f"{data_path}/targets.nc")


checkpoint_path
std
Train&Val sets end at 2023-09-30
Test set starts at 2024-05-16


Outer Loop Progress: 100%|██████████| 152/152 [00:00<00:00, 496.08it/s]


Saved fan-chart → predictions/predictions_epoch_0.png
Saved histogram → histograms/rank_hist_t1.png
Saved histogram → histograms/rank_hist_t24.png
Saved histogram → histograms/rank_hist_t48.png
Saved histogram → histograms/rank_hist_t96.png
Saved combined rank‐histogram → histograms/rank_hist_all_leadtimes_ep0.png
{'run_name': 'MSGWN_emb16_ch16_lay3_lr1e-03_drop0.2_edge_drop0.2_dil(1, 2, 4, 8)_kernels(1, 3, 5, 7)_histdrop0.07_history_block12_schedulerOneCycleLR_dynamicTrue_node_emb_dim20_seed0', 'val_loss': 0.6078862925060093, 'avg_val_acc': 0.8730309940874577, 'avg_crps': {1: 0.560420460416879, 24: 0.5798298490246374, 48: 0.6014024943476386, 96: 0.669989733690849}, 'avg_mae': {1: 0.8108240962028503, 24: 0.8318228721618652, 48: 0.8641216158866882, 96: 0.9646117687225342}}
Saved test outputs to outputs/MSGWN_emb16_ch16_lay3_lr1e-03_drop0.2_edge_drop0.2_dil(1, 2, 4, 8)_kernels(1, 3, 5, 7)_histdrop0.07_history_block12_schedulerOneCycleLR_dynamicTrue_node_emb_dim20_seed0_test_outputs.npz
