This notebook produces visualizations of the model's interpretable parameters (or parameter-derived quantities) in space and time.

In [None]:
import jax
# Enable 64-bit precision. This has to run before any JAX array is created, hence its own first cell.
# Epoch times in seconds (~1e9) cannot be represented exactly with 32-bit floats.
jax.config.update("jax_enable_x64", True)

<contextlib.ExitStack at 0x7f85c85a9690>

In [None]:
from pathlib import Path

import dask
import diffrax as dfx
import equinox as eqx
import jax.numpy as jnp
import jax.random as jrd
from jaxtyping import Float, Key, Real
from hydra_zen import instantiate, load_from_yaml
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm
import xarray as xr

from pastax.gridded import Gridded
from pastax.simulator import DeterministicSimulator
from pastax.trajectory import Trajectory

from src.ec_mlp.data_driven_model import DataDrivenModel as StochDataDrivenModel
from src.ec_mlp.drift_model import DriftModel as StochDriftModel
from src.ec_mlp.trainer_module import TrainerModule as StochTrainerModule
from src.ls_mlp.data_driven_model import DataDrivenModel as DeterDataDrivenModel
from src.ls_mlp.drift_model import DriftModel as DeterStochModel
from src.ls_mlp.trainer_module import TrainerModule as DeterTrainerModule

In [None]:
# All paths are relative to the notebook's working directory.
ROOT = Path(".")

# Hydra multirun sweep directories and the selected run inside each sweep.
STOCH_RUNS_DIR = ROOT.joinpath("euler_criterion/mlp/multirun/2026-02-08/09-28-33")
STOCH_RUN_ID = "15"
DETER_RUNS_DIR = ROOT.joinpath("least_squares/mlp/multirun/2026-02-05/11-31-57")
DETER_RUN_ID = "25"

DATA_DIR = ROOT.joinpath("data")

## Load best models

In [None]:
def _hydra_config_path(runs_dir: Path, run_id: str) -> Path:
    """Location of the config Hydra saved for run ``run_id`` of the sweep in ``runs_dir``."""
    return runs_dir / run_id / ".hydra/config.yaml"


stoch_run_cfg = load_from_yaml(_hydra_config_path(STOCH_RUNS_DIR, STOCH_RUN_ID))
deter_run_cfg = load_from_yaml(_hydra_config_path(DETER_RUNS_DIR, DETER_RUN_ID))

In [None]:
# Stochastic (EC-MLP) drift model: rebuild the architecture described by the run's Hydra config,
# then let the Lightning trainer module restore it from the run's best checkpoint.
trunk, physical_head, mdn_head = (
    instantiate(getattr(stoch_run_cfg, component))()
    for component in ("trunk", "physical_head", "mdn_head")
)
stoch_data_driven_model = StochDataDrivenModel(trunk, physical_head, mdn_head)

# NOTE(review): Lightning warns that this checkpoint's state dict "contains no parameters";
# confirm the trained weights really end up in the returned ``drift_model``.
stoch_drift_model = StochTrainerModule.load_from_checkpoint(
    STOCH_RUNS_DIR / STOCH_RUN_ID / "best_model.ckpt",
    drift_model=StochDriftModel(
        data_driven_model=stoch_data_driven_model,
        stress_normalization=1.0,
        wind_normalization=1.0,
        delta_t=60.0 * 60.0,  # 1 hour in seconds
    ),
).drift_model

In [None]:
# Deterministic (LS-MLP) drift model. Its components are built from the deterministic run's own
# Hydra config (``deter_run_cfg``): the checkpoint restored below was produced by that run, so the
# architecture must match it, not the stochastic run's ``trunk``/``physical_head``/``mdn_head``
# (reusing those instances also made both models share the same module objects).
# NOTE(review): assumes the LS config exposes the same ``trunk``/``physical_head``/``mdn_head`` keys.
deter_trunk = instantiate(deter_run_cfg.trunk)()
deter_physical_head = instantiate(deter_run_cfg.physical_head)()
deter_mdn_head = instantiate(deter_run_cfg.mdn_head)()

deter_data_driven_model = DeterDataDrivenModel(deter_trunk, deter_physical_head, deter_mdn_head)
# ``DeterStochModel`` is the (misleadingly named) import alias of ``src.ls_mlp.drift_model.DriftModel``.
deter_drift_model = DeterStochModel(
    data_driven_model=deter_data_driven_model,
    stress_normalization=1.0,
    wind_normalization=1.0,
)

deter_drift_model = DeterTrainerModule.load_from_checkpoint(
    DETER_RUNS_DIR / DETER_RUN_ID / "best_model.ckpt", drift_model=deter_drift_model
).drift_model

/home/bertrava/.local/share/mamba/envs/pastax_global_calibration/lib/python3.11/site-packages/lightning/pytorch/core/saving.py:96: The state dict in PosixPath('euler_criterion/mlp/multirun/2026-02-08/09-28-33/15/best_model.ckpt') contains no parameters.


## Reference trajectories

In [8]:
# Period bounds (ISO dates) that appear in the preprocessed data file names.
from_datetime_str, to_datetime_str = "1994-06-01", "2025-08-01"

In [None]:
test_traj_store = DATA_DIR / f"gdp_interp_clean_{from_datetime_str}_{to_datetime_str}_test_traj.zarr"
traj_ds = xr.open_zarr(test_traj_store)  # held-out (test) trajectories

## Forcings

In [None]:
# TODO(review): both store names are empty, so both datasets currently open the same path
# ``DATA_DIR / ".zarr"``. Fill in the DUACS store (providing ``ugos``/``vgos``) and the ERA5 store
# (providing ``eastward_stress``/``northward_stress``/``eastward_wind``/``northward_wind``).
duacs_ds = xr.open_zarr(DATA_DIR / f".zarr")
era5_ds = xr.open_zarr(DATA_DIR / f".zarr")

## Data loading

In [None]:
class Dataset(TorchDataset):
    """Map-style torch dataset yielding one test trajectory and its forcing patches per item.

    ``__getitem__(idx)`` returns ``(traj_arrays, duacs_arrays, era5_arrays)``:

    - ``traj_arrays`` is ``(lat, lon, time, id)`` for trajectory ``idx`` as flat NumPy arrays, with
      ``time`` in integer seconds since the Unix epoch;
    - ``duacs_arrays`` / ``era5_arrays`` are ``(variables, time, latitude, longitude)`` for a
      space-time patch of the forcing grid starting at the trajectory's first point; ``variables``
      maps each forcing variable name to its patch values and ``time`` is in integer seconds.

    The patch spans the trajectory duration in time and ``max_travel_distance`` degrees around the
    starting position in latitude and longitude. Indices outside the forcing grid are filled by edge
    padding, except in longitude when ``periodic_domain`` is True: the missing columns are then taken
    from the opposite side of the grid (assumed global).

    Parameters
    ----------
    traj_ds : xr.Dataset
        Trajectories along a ``traj`` dimension, with ``lat``, ``lon``, ``time`` and ``id`` variables.
    duacs_ds : xr.Dataset
        Forcing dataset providing ``ugos`` and ``vgos``.
    era5_ds : xr.Dataset
        Forcing dataset providing ``eastward_stress``, ``northward_stress``, ``eastward_wind`` and
        ``northward_wind``.
    periodic_domain : bool, default True
        Wrap longitude patches around the grid edges instead of edge-padding them.

    Both forcing datasets are assumed to be on regular ``time``/``latitude``/``longitude`` grids with
    variables laid out as (time, latitude, longitude); the padding and concatenation rely on it.
    """

    def __init__(self, traj_ds: xr.Dataset, duacs_ds: xr.Dataset, era5_ds: xr.Dataset, periodic_domain: bool = True):
        self.traj_ds = traj_ds
        self.duacs_ds = duacs_ds
        self.era5_ds = era5_ds
        self.periodic_domain = periodic_domain

        # Spatial half-width of the patches = drift-speed bound x trajectory duration.
        # NOTE(review): the duration is read from the first trajectory only, so all trajectories are
        # assumed to span the same number of days.
        max_travel_distance = .5  # in ° / day ; inferred from data
        traj_t0_t1 = traj_ds.time.isel(traj=0)[np.asarray([0, -1])]
        n_days = ((traj_t0_t1[-1] - traj_t0_t1[0]) / np.timedelta64(1, "D")).astype(int).item()
        self.max_travel_distance = max_travel_distance * n_days

        # Grid origin, spacing, size and patch half-extent (in grid cells) for each forcing dataset.
        self.duacs_nt = duacs_ds.time.size
        self.duacs_mint = duacs_ds.time.min()
        self.duacs_dt = duacs_ds.time[1] - duacs_ds.time[0]
        self.duacs_t_di = np.ceil(np.timedelta64(n_days, "D") / self.duacs_dt)
        self.duacs_nlat = duacs_ds.latitude.size
        self.duacs_minlat = duacs_ds.latitude.min()
        self.duacs_dlat = duacs_ds.latitude[1] - duacs_ds.latitude[0]  # regular grid
        self.duacs_lat_di = np.ceil(self.max_travel_distance / self.duacs_dlat)
        self.duacs_nlon = duacs_ds.longitude.size
        self.duacs_minlon = duacs_ds.longitude.min()
        self.duacs_dlon = duacs_ds.longitude[1] - duacs_ds.longitude[0]  # regular grid
        self.duacs_lon_di = np.ceil(self.max_travel_distance / self.duacs_dlon)

        self.era5_nt = era5_ds.time.size
        self.era5_mint = era5_ds.time.min()
        self.era5_dt = era5_ds.time[1] - era5_ds.time[0]
        self.era5_t_di = np.ceil(np.timedelta64(n_days, "D") / self.era5_dt)
        self.era5_nlat = era5_ds.latitude.size
        self.era5_minlat = era5_ds.latitude.min()
        self.era5_dlat = era5_ds.latitude[1] - era5_ds.latitude[0]  # regular grid
        self.era5_lat_di = np.ceil(self.max_travel_distance / self.era5_dlat)
        self.era5_nlon = era5_ds.longitude.size
        self.era5_minlon = era5_ds.longitude.min()
        self.era5_dlon = era5_ds.longitude[1] - era5_ds.longitude[0]  # regular grid
        self.era5_lon_di = np.ceil(self.max_travel_distance / self.era5_dlon)

    def __len__(self):
        """Number of trajectories."""
        return self.traj_ds.traj.size

    def __getitem__(
        self, idx: int
    ) -> tuple[
        tuple[
            Float[np.ndarray, "traj_length"], 
            Float[np.ndarray, "traj_length"], 
            Float[np.ndarray, "traj_length"], 
            Float[np.ndarray, "traj_length"]
        ],
        tuple[dict[str, Float[np.ndarray, "T N N"]], 
            Float[np.ndarray, "T"], 
            Float[np.ndarray, "N"], 
            Float[np.ndarray, "N"]
        ],
        tuple[dict[str, Float[np.ndarray, "T N N"]], 
            Float[np.ndarray, "T"], 
            Float[np.ndarray, "N"], 
            Float[np.ndarray, "N"]
        ]
    ]:
        """Return the trajectory ``idx`` and the DUACS and ERA5 patches around its start."""
        traj_arrays = self.__get_traj_arrays(idx)
        duacs_arrays = self.__get_duacs_arrays(*traj_arrays[:3])
        era5_arrays = self.__get_era5_arrays(*traj_arrays[:3])
        
        return traj_arrays, duacs_arrays, era5_arrays
    
    def __get_traj_arrays(
        self, idx: int
    ) -> tuple[
        Float[np.ndarray, "traj_length"], 
        Float[np.ndarray, "traj_length"], 
        Float[np.ndarray, "traj_length"], 
        Float[np.ndarray, "traj_length"]
    ]:
        """Flat ``(lat, lon, time, id)`` arrays of trajectory ``idx``; time in seconds since the epoch."""
        traj_subset = self.traj_ds.isel(traj=idx)
        
        traj_lat = traj_subset.lat.values.ravel()
        traj_lon = traj_subset.lon.values.ravel()
        traj_time = traj_subset.time.values.ravel().astype("datetime64[s]").astype(int)  # in seconds
        traj_id = traj_subset.id.values.ravel()
        
        return traj_lat, traj_lon, traj_time, traj_id
    
    def __get_duacs_arrays(
        self, 
        traj_lat: Float[np.ndarray, "traj_length"], 
        traj_lon: Float[np.ndarray, "traj_length"], 
        traj_time: Float[np.ndarray, "traj_length"]
    ) -> tuple[
        dict[str, Float[np.ndarray, "T N N"]], 
        Float[np.ndarray, "T"], 
        Float[np.ndarray, "N"], 
        Float[np.ndarray, "N"]
    ]:
        """DUACS ``ugos``/``vgos`` patch around the trajectory start."""
        return self.__get_forcing_arrays(
            traj_lat, traj_lon, traj_time,
            self.duacs_ds, ("ugos", "vgos"),
            self.duacs_nt, self.duacs_mint, self.duacs_dt, self.duacs_t_di,
            self.duacs_nlat, self.duacs_minlat, self.duacs_dlat, self.duacs_lat_di,
            self.duacs_nlon, self.duacs_minlon, self.duacs_dlon, self.duacs_lon_di
        )
    
    def __get_era5_arrays(
        self, 
        traj_lat: Float[np.ndarray, "traj_length"], 
        traj_lon: Float[np.ndarray, "traj_length"], 
        traj_time: Float[np.ndarray, "traj_length"]
    ) -> tuple[
        dict[str, Float[np.ndarray, "T N N"]], 
        Float[np.ndarray, "T"], 
        Float[np.ndarray, "N"], 
        Float[np.ndarray, "N"]
    ]:
        """ERA5 stress and wind patch around the trajectory start."""
        return self.__get_forcing_arrays(
            traj_lat, traj_lon, traj_time,
            self.era5_ds, ("eastward_stress", "northward_stress", "eastward_wind", "northward_wind"),
            self.era5_nt, self.era5_mint, self.era5_dt, self.era5_t_di,
            self.era5_nlat, self.era5_minlat, self.era5_dlat, self.era5_lat_di,
            self.era5_nlon, self.era5_minlon, self.era5_dlon, self.era5_lon_di
        )
    
    def __get_forcing_arrays(
        self, 
        traj_lat: Float[np.ndarray, "traj_length"], 
        traj_lon: Float[np.ndarray, "traj_length"], 
        traj_time: Float[np.ndarray, "traj_length"],
        ds: xr.Dataset, 
        vars_names: tuple[str, ...],
        nt: int, mint: np.datetime64, dt: np.timedelta64, t_di: int, 
        nlat: int, minlat: float, dlat: float, lat_di: float, 
        nlon: int, minlon: float, dlon: float, lon_di: float
    ) -> tuple[
        dict[str, Float[np.ndarray, "T N N"]], 
        Float[np.ndarray, "T"], 
        Float[np.ndarray, "N"], 
        Float[np.ndarray, "N"]
    ]:
        """Extract a padded patch of ``vars_names`` from ``ds`` starting at the trajectory's first point.

        The patch covers grid times ``[floor(t0), floor(t0) + t_di]`` and ``±lat_di`` / ``±lon_di``
        cells around the nearest grid point. Indices beyond the grid are edge-padded, or wrapped in
        longitude when ``self.periodic_domain`` is set.

        Returns ``(variables, time_in_seconds, latitude, longitude)``.
        """
        def get_latlon_minmax(latlon0_i, latlon_di):
            # inclusive index range of width ±latlon_di around the centre index
            min_i = (latlon0_i - latlon_di).astype(int).item()
            max_i = (latlon0_i + latlon_di).astype(int).item()
            return min_i, max_i
        
        def get_pads(min_i, max_i, n):
            # clamp the inclusive range to [0, n - 1] and report how much was cut on each side
            padleft = max(0, -min_i)
            min_i = max(0, min_i)
            padright = max(0, max_i - (n - 1))
            max_i = min(n - 1, max_i)
            return (padleft, padright), (min_i, max_i)
    
        t0 = traj_time[0].astype("datetime64[s]")
        lat0 = traj_lat[0]
        lon0 = traj_lon[0]

        t0_i = np.floor((t0 - mint) / dt)
        lat0_i = ((lat0 - minlat) / dlat).round()
        lon0_i = ((lon0 - minlon) / dlon).round()

        tmin_i = t0_i.astype(int).item()
        tmax_i = (t0_i + t_di).astype(int).item()
        latmin_i, latmax_i = get_latlon_minmax(lat0_i, lat_di)
        lonmin_i, lonmax_i = get_latlon_minmax(lon0_i, lon_di)

        (t_padleft, t_padright), (tmin_i, tmax_i) = get_pads(tmin_i, tmax_i, nt)
        (lat_padleft, lat_padright), (latmin_i, latmax_i) = get_pads(latmin_i, latmax_i, nlat)
        (lon_padleft, lon_padright), (lonmin_i, lonmax_i) = get_pads(lonmin_i, lonmax_i, nlon)

        patch = ds.isel(
            time=slice(tmin_i, tmax_i + 1),
            latitude=slice(latmin_i, latmax_i + 1), 
            longitude=slice(lonmin_i, lonmax_i + 1)
        )

        patch_vars = dict((var_name, patch[var_name]) for var_name in vars_names)
        patch_time = patch.time.astype("datetime64[s]").astype(int)  # in seconds
        patch_lat = patch.latitude
        patch_lon = patch.longitude

        # NOTE(review): the wrapped longitude columns keep their original coordinate values (no ±360
        # shift), so the concatenated longitudes are not monotonic; presumably ``Gridded`` handles
        # periodic longitudes — confirm.
        if self.periodic_domain:  # periodic global domain
            if lon_padleft != 0:
                patch_left = ds.isel(
                    time=slice(tmin_i, tmax_i + 1),
                    latitude=slice(latmin_i, latmax_i + 1), 
                    longitude=slice(nlon - lon_padleft, nlon)  # right part goes to the left
                )

                patch_vars_left = dict((var, patch_left[var]) for var in vars_names)
                lon_left = patch_left.longitude

                patch_vars = dict(
                    (var_name, np.concat([patch_vars_left[var_name], patch_vars[var_name]], axis=-1)) 
                    for var_name in vars_names
                )
                patch_lon = np.concat([lon_left, patch_lon])

                lon_padleft = 0

            if lon_padright != 0:
                patch_right = ds.isel(
                    time=slice(tmin_i, tmax_i + 1),
                    latitude=slice(latmin_i, latmax_i + 1), 
                    longitude=slice(0, lon_padright)  # left part goes to the right
                )

                # keyed by variable name, like ``patch_vars_left`` (a list here broke the lookup below)
                patch_vars_right = dict((var, patch_right[var]) for var in vars_names)
                lon_right = patch_right.longitude

                patch_vars = dict(
                    (var_name, np.concat([patch_vars[var_name], patch_vars_right[var_name]], axis=-1)) 
                    for var_name in vars_names
                )
                patch_lon = np.concat([patch_lon, lon_right])

                lon_padright = 0

        # Edge padding repeats the boundary values.
        # NOTE(review): it also repeats the boundary coordinates (time/lat/lon), producing duplicated
        # coordinate values in the padded region — confirm the interpolator tolerates this.
        patch_vars = dict(
            (
                var_name,
                np.pad(
                    patch_vars[var_name], 
                    ((t_padleft, t_padright), (lat_padleft, lat_padright), (lon_padleft, lon_padright)), 
                    mode="edge"
                )
            ) for var_name in vars_names
        )
        patch_time = np.pad(patch_time, (t_padleft, t_padright), mode="edge")
        patch_lat = np.pad(patch_lat, (lat_padleft, lat_padright), mode="edge")
        patch_lon = np.pad(patch_lon, (lon_padleft, lon_padright), mode="edge")
        
        return patch_vars, patch_time, patch_lat, patch_lon

In [None]:
torch_dataset = Dataset(traj_ds=traj_ds, duacs_ds=duacs_ds, era5_ds=era5_ds, periodic_domain=True)

In [None]:
# torch DataLoader settings
batch_size, dl_num_workers, prefetch_factor = 512, 32, 2
multiprocessing_context = "spawn"

# worker count for dask's threaded scheduler (set with dask.config.set below)
dask_n_workers = 8

In [None]:
# ``in_order=False`` lets batches arrive out of dataset order; the simulation loop stores each batch's
# reference trajectories together with its simulations, so pairs stay consistent.
# NOTE(review): with the "spawn" context, workers re-import ``__main__`` to unpickle ``Dataset``;
# a class defined in a notebook is usually not importable that way — confirm, or move it to a module.
xr_jax_dataloader = DataLoader(
    dataset=torch_dataset,
    batch_size=batch_size,
    shuffle=False,
    in_order=False,
    num_workers=dl_num_workers,
    multiprocessing_context=multiprocessing_context,
    persistent_workers=True,
    prefetch_factor=prefetch_factor,
    pin_memory=False,
)

In [11]:
def to_jax(arr: Float[np.ndarray, "..."]) -> Float[jax.Array, "..."]:
    """Convert an array-like (anything ``jnp.asarray`` accepts) to a JAX array.

    Dicts are converted value by value into a new dict: the collated forcing batches carry their
    variables as a ``dict[str, array]``, which ``jnp.asarray`` cannot convert directly. Every other
    input (including lists) goes straight to ``jnp.asarray``, as before.
    """
    if isinstance(arr, dict):
        return {name: to_jax(value) for name, value in arr.items()}
    return jnp.asarray(arr)


@eqx.filter_jit  # compiling the batched construction improves performance
def to_trajectories(
    traj_arrays: tuple[
        Float[jax.Array, "batch traj_length"], 
        Float[jax.Array, "batch traj_length"], 
        Float[jax.Array, "batch traj_length"], 
        Float[jax.Array, "batch traj_length"]
    ]
) -> Trajectory:
    """Build a batch of pastax ``Trajectory`` objects from collated trajectory arrays.

    ``traj_arrays`` is ``(lat, lon, time, id)``, each with a leading batch axis. Latitude and
    longitude are stacked on a trailing axis in (lat, lon) order to form the trajectory values.
    """
    lat, lon, time, ids = traj_arrays
    latlon = jnp.stack((lat, lon), axis=-1)

    def _build_one(latlon_i, time_i, id_i):
        # one trajectory from its (traj_length, 2) positions, times and id
        return Trajectory.from_array(values=latlon_i, times=time_i, id=id_i)

    return eqx.filter_vmap(_build_one)(latlon, time, ids)


@eqx.filter_jit  # compiling the batched construction improves performance
def to_gridded(
    forcings_arrays: tuple[
        dict[str, Float[jax.Array, "batch T N N"]], 
        Float[jax.Array, "batch T"], 
        Float[jax.Array, "batch N"], 
        Float[jax.Array, "batch N"]
    ]
) -> Gridded:
    """Build a batch of pastax ``Gridded`` forcing fields from collated patch arrays.

    ``forcings_arrays`` is ``(variables, time, latitude, longitude)``, every array carrying a leading
    batch axis; ``Gridded.from_array`` is mapped over that axis.
    """
    variables, time, latitude, longitude = forcings_arrays
    batched_from_array = eqx.filter_vmap(Gridded.from_array)
    return batched_from_array(variables, time, latitude, longitude)

In [None]:
# Global dask config for this (main) process: threaded scheduler with ``dask_n_workers`` threads.
# NOTE(review): DataLoader workers started with "spawn" have their own, default dask config,
# so this setting probably does not apply where the forcing patches are actually loaded.
dask.config.set({"scheduler": "threads", "num_workers": dask_n_workers});

## Drift integration

In [None]:
def month_from_epoch(seconds):
    """Calendar month (1 = January ... 12 = December, proleptic Gregorian, UTC) of a Unix timestamp.

    ``seconds`` is a time in seconds since 1970-01-01T00:00:00. Only floor divisions and arithmetic
    are used, so the function works elementwise on NumPy/JAX arrays and scalars, including inside
    traced vector fields. Integer input gives an integer month; floating-point input gives a
    float-valued month.

    Based on H. Hinnant's ``civil_from_days`` algorithm.
    """
    # Whole days since the Unix epoch (floor division keeps pre-1970 instants on the right day).
    days = seconds // 86400

    # Shift to days since 0000-03-01: with years starting in March, the leap day is a year's last day.
    z = days + 719468

    # The leap-year pattern repeats every 400 years (146097 days); one such period is an "era".
    # Python/JAX ``//`` already floors toward -inf, so negative ``z`` needs no special case (the C
    # original only needs one because C division truncates toward zero).
    era = z // 146097
    doe = z - era * 146097  # day of era, in [0, 146096]
    yoe = (doe - doe // 1460 + doe // 36524 - doe // 146096) // 365  # year of era, in [0, 399]
    doy = doe - (365 * yoe + yoe // 4 - yoe // 100)  # day of the March-based year, in [0, 365]

    mp = (5 * doy + 2) // 153  # March-based month index: 0 = March, ..., 11 = February
    # Shift back so that January is month 1.
    return mp + 3 - 12 * (mp // 10)


def _interp_forcings(
    duacs: Gridded,
    era5: Gridded,
    t: Real[jax.Array, ""],
    lat: Float[jax.Array, ""],
    lon: Float[jax.Array, ""],
) -> tuple[Float[jax.Array, ""], ...]:
    """Interpolate every forcing used by the drift models at ``(t, lat, lon)``.

    Returns ``(ugos, vgos, eastward_stress, northward_stress, eastward_wind, northward_wind)``, the
    positional order in which the drift models below take them.
    """
    duacs_forcing = duacs.interp("ugos", "vgos", time=t, latitude=lat, longitude=lon)
    era5_forcing = era5.interp(
        "eastward_stress", "northward_stress", "eastward_wind", "northward_wind", time=t, latitude=lat, longitude=lon
    )
    return (
        duacs_forcing["ugos"], duacs_forcing["vgos"],
        era5_forcing["eastward_stress"], era5_forcing["northward_stress"],
        era5_forcing["eastward_wind"], era5_forcing["northward_wind"],
    )


def stoch_dynamics(
    t: Real[jax.Array, ""], 
    x: Float[jax.Array, "2"], 
    args: tuple[StochDriftModel, Float[jax.Array, ""], Gridded, Gridded, Real[jax.Array, ""], Key[jax.Array, "n_steps"]]
) -> Float[jax.Array, "2"]:
    """Vector field of one realization of the stochastic drift model.

    ``args`` is ``(drift_model, delta_t, duacs, era5, t0, keys)``. The key used at time ``t`` is
    ``keys[i]`` where ``i`` is the index of the ``delta_t``-long interval (counted from ``t0``)
    containing ``t``, so all evaluations within one interval share the same random key.
    Returns the sampled velocity stacked as ``(v, u)``, matching the ``(lat, lon)`` state order.
    """
    lat, lon = x
    drift_model, delta_t, duacs, era5, t0, keys = args

    # Out-of-range indices (e.g. evaluations exactly at the final time) use the last key; this clip
    # makes explicit what JAX's clamped indexing already did.
    # NOTE(review): Tsit5 has stages at the step end (c = 1); depending on float rounding they may
    # draw the next interval's key — confirm this is acceptable or use a step-start-based index.
    step_index = jnp.clip(((t - t0) / delta_t).astype(int), 0, keys.shape[0] - 1)
    key = keys[step_index]

    u, v = drift_model.sample_velocity(
        *_interp_forcings(duacs, era5, t, lat, lon),
        month_from_epoch(t), lat, lon,
        key,
        delta_t
    )

    return jnp.stack((v, u), axis=-1)


def mode_dynamics(
    t: Real[jax.Array, ""], 
    x: Float[jax.Array, "2"], 
    args: tuple[StochDriftModel, Gridded, Gridded]
) -> Float[jax.Array, "2"]:
    """Vector field using the stochastic model's mode velocity (no sampling).

    ``args`` is ``(drift_model, duacs, era5)``; returns ``(v, u)`` matching the ``(lat, lon)`` state.
    """
    lat, lon = x
    drift_model, duacs, era5 = args

    u, v = drift_model.estimate_mode_velocity(
        *_interp_forcings(duacs, era5, t, lat, lon),
        month_from_epoch(t), lat, lon
    )

    return jnp.stack((v, u), axis=-1)


def deter_dynamics(
    t: Real[jax.Array, ""], 
    x: Float[jax.Array, "2"], 
    args: tuple[DeterStochModel, Gridded, Gridded]
) -> Float[jax.Array, "2"]:
    """Vector field of the deterministic drift model.

    ``args`` is ``(drift_model, duacs, era5)``. The model returns the velocity as a complex number
    ``u + i v``; it is returned stacked as ``(v, u)`` to match the ``(lat, lon)`` state.
    """
    lat, lon = x
    drift_model, duacs, era5 = args

    uv = drift_model(
        *_interp_forcings(duacs, era5, t, lat, lon),
        month_from_epoch(t), lat, lon
    )

    return jnp.stack((uv.imag, uv.real), axis=-1)

In [None]:
delta_t = jnp.asarray(15 * 60.0)  # integration step: 15 minutes in seconds
# NOTE(review): should match the duration of the test trajectories (``Dataset`` infers it from the data);
# otherwise the fixed step grid in ``simulate`` no longer has steps of ``delta_t``.
sim_days = 6
# A Python int: it is used as a static size (key split, step grid length, diffrax ``max_steps``).
# Floor division avoids losing a step when the float quotient lands just below an integer.
n_steps = int((sim_days * 24 * 60 * 60) // float(delta_t))

n_samples = 100  # stochastic realizations per trajectory

In [None]:
# pastax ODE integrator used by ``simulate`` for the stochastic, mode and deterministic runs; the
# stochastic runs differ only through the per-step random keys passed in ``args``.
simulator = DeterministicSimulator()

In [None]:
def simulate(
    trajectory: Trajectory, 
    duacs: Gridded, 
    era5: Gridded, 
    key: Key[jax.Array, ""]
) -> tuple[Trajectory, Trajectory, Trajectory]:
    """Re-simulate one observed trajectory with the three drift models.

    Integration starts at ``trajectory.origin`` and outputs are saved at the trajectory's observation
    times, using Tsit5 with ``n_steps`` equal fixed steps between the first and last observation time.

    Returns
    -------
    stoch_sol
        ``n_samples`` realizations of the stochastic model (leading sample axis), each driven by its
        own ``n_steps`` per-step keys derived from ``key``.
    mode_sol
        Integration of the stochastic model's mode velocity.
    deter_sol
        Integration of the deterministic model.
    """
    x0 = trajectory.origin
    ts = trajectory.time.value

    # ``dfx.StepTo`` takes the explicit array of step times (its only field); passing
    # ``(t0, t1, n)`` raised a TypeError.
    step_ts = jnp.linspace(ts[0], ts[-1], n_steps + 1)

    def _solve(dynamics, args):
        # Fresh solver / SaveAt / controller objects for every solve, as before.
        return simulator(
            dynamics=dynamics,
            args=args,
            x0=x0,
            ts=ts,
            solver=dfx.Tsit5(),
            saveat=dfx.SaveAt(ts=ts),
            stepsize_controller=dfx.StepTo(ts=step_ts),
            max_steps=n_steps,
        )

    sample_keys = jax.random.split(key, n_samples)
    stoch_sol = eqx.filter_vmap(
        lambda sample_key: _solve(
            stoch_dynamics,
            (stoch_drift_model, delta_t, duacs, era5, ts[0], jax.random.split(sample_key, n_steps)),
        )
    )(sample_keys)

    mode_sol = _solve(mode_dynamics, (stoch_drift_model, duacs, era5))
    deter_sol = _solve(deter_dynamics, (deter_drift_model, duacs, era5))

    return stoch_sol, mode_sol, deter_sol

In [12]:
# Fixed seed so the stochastic ensembles are reproducible (``key`` was previously never defined).
key = jrd.PRNGKey(0)

reference_trajectories = []
stoch_results = []
mode_results = []
deter_results = []

for trajectory_batch, duacs_batch, era5_batch in tqdm(xr_jax_dataloader):
    # Collated batches are nested containers of tensors (each forcing tuple starts with a dict of
    # variables), so convert leaf by leaf rather than element by element.
    trajectory_batch = jax.tree_util.tree_map(to_jax, trajectory_batch)
    duacs_batch = jax.tree_util.tree_map(to_jax, duacs_batch)
    era5_batch = jax.tree_util.tree_map(to_jax, era5_batch)
    # The last batch may be smaller than ``batch_size``; keep the config value intact.
    n_in_batch = trajectory_batch[0].shape[0]

    trajectory_batch = to_trajectories(trajectory_batch)
    duacs_batch = to_gridded(duacs_batch)
    era5_batch = to_gridded(era5_batch)

    key, subkey = jrd.split(key, 2)
    batch_keys = jrd.split(subkey, n_in_batch)

    # filter_vmap (as in to_trajectories / to_gridded) only maps array leaves of the pastax objects.
    stoch_sol, mode_sol, deter_sol = eqx.filter_vmap(simulate)(
        trajectory_batch, duacs_batch, era5_batch, batch_keys
    )

    reference_trajectories.append(trajectory_batch)
    stoch_results.append(stoch_sol)
    mode_results.append(mode_sol)
    deter_results.append(deter_sol)