# MnemoDyn Foundation Inference Tutorial

This notebook covers only foundation-model inference; downstream training and classification are out of scope.

In [1]:
import re
from pathlib import Path

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
import torchcde

from coe.light.model.main import LitORionModelOptimized


def extract_val_mae(name: str) -> float:
    """Parse the ``val_mae=<number>`` value embedded in a checkpoint file name.

    Args:
        name: File name (or any string) that may contain ``val_mae=<float>``,
            e.g. ``"epoch=3-val_mae=0.1234.ckpt"``.

    Returns:
        The parsed value, or ``float("inf")`` when no ``val_mae=`` number is
        present, so names without a score lose when minimizing.
    """
    # Match only a well-formed float literal (optional sign, digits with an
    # optional decimal point, optional exponent). A loose class such as
    # ``[-+eE0-9.]+`` would also swallow trailing separators like the "." of
    # ".ckpt" or the "-e" of a following "-epoch=" field, making float() fail.
    m = re.search(r"val_mae=([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)", name)
    if m is None:
        return float("inf")
    # The pattern above only admits strings float() accepts.
    return float(m.group(1))


def find_best_checkpoint(version_dir: Path) -> Path:
    """Pick the ``*.ckpt`` file in ``<version_dir>/checkpoints`` with the lowest val_mae.

    Each checkpoint is scored by ``extract_val_mae`` on its file name. When
    several share the lowest score, the first one in sorted path order wins.

    Raises:
        FileNotFoundError: If no ``*.ckpt`` files are found in the
            ``checkpoints`` subdirectory.
    """
    checkpoint_dir = version_dir / "checkpoints"
    candidates = sorted(checkpoint_dir.glob("*.ckpt"))
    if len(candidates) == 0:
        raise FileNotFoundError(f"No checkpoints found under: {checkpoint_dir}")

    def _score(path: Path) -> float:
        return extract_val_mae(path.name)

    # min() keeps the first minimum it encounters; iterating a sorted list
    # therefore makes tie-breaking deterministic across filesystems.
    return min(candidates, key=_score)


def load_dtseries(file_path: Path, target_length: int, num_parcels: int) -> torch.Tensor:
    """Load a parcellated dtseries file and fit it to a fixed sequence length.

    Args:
        file_path: Path to a file readable by ``nibabel.load`` whose data is a
            2-D array laid out as [T, D] (time points x parcels).
        target_length: Number of time points the returned tensor must have.
        num_parcels: Required size of the second (parcel) axis.

    Returns:
        A float32 tensor of shape [1, target_length, num_parcels].

    Raises:
        ValueError: If the data is not 2-D, has the wrong parcel count, or
            contains no time points.
    """
    arr = nib.load(str(file_path)).get_fdata().astype(np.float32)  # [T, D]
    if arr.ndim != 2:
        raise ValueError(f"Expected 2D dtseries array, got shape {arr.shape}")
    if arr.shape[1] != num_parcels:
        raise ValueError(f"Expected D={num_parcels}, got D={arr.shape[1]}")

    n_timepoints = arr.shape[0]
    if n_timepoints == 0:
        # Without this guard the repeat-count computation divides by zero.
        raise ValueError(f"dtseries has no time points: {file_path}")

    # Repeat the whole series end-to-end until it covers the checkpoint's
    # sequence length, then trim; series that are already long enough are
    # simply truncated.
    # NOTE(review): tiling creates a jump where the last sample wraps back to
    # the first — confirm this is acceptable for the model.
    reps = -(-target_length // n_timepoints)  # integer ceil division
    arr = np.tile(arr, (reps, 1))[:target_length]
    return torch.from_numpy(arr).unsqueeze(0)  # [1, T, D]


ModuleNotFoundError: No module named 'torch'

In [None]:
# ---- Configure paths ----
version_dir = Path("/nas/vhluong/Result/Orion_450_ukbiobank/debug_GordonHCP/version_2")
dtseries_path = Path("sub-011_task-rest_space-MNI305_preproc.dtseries_Schaefer2018_400Parcels_7Networks_order_Tian_Subcortex_S3.dlabel_parcellated.dtseries.nii")
interpol = "spline"  # or "linear"

# Validate the interpolation choice up front, so a typo fails before the
# checkpoint is loaded rather than after.
if interpol not in ("linear", "spline"):
    raise ValueError("interpol must be 'linear' or 'spline'")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt_path = find_best_checkpoint(version_dir)

# Load the Lightning wrapper, switch it to eval mode, and freeze the wrapped
# foundation model: this cell only runs inference.
lit = LitORionModelOptimized.load_from_checkpoint(str(ckpt_path), map_location=device)
lit.eval()
foundation = lit.model.to(device)
foundation.requires_grad_(False)

# Input geometry comes from the hyperparameters stored in the checkpoint.
seq_length = int(lit.hparams.seq_length)
num_parcels = int(lit.hparams.D)
duration = float(lit.hparams.duration)

x = load_dtseries(dtseries_path, target_length=seq_length, num_parcels=num_parcels).to(device)  # [1, seq_length, num_parcels]

# Build the control-path coefficients torchcde expects for the chosen
# interpolation (interpol was validated above).
if interpol == "linear":
    coeffs = torchcde.linear_interpolation_coeffs(x)
else:
    coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(x)

# seq_length evenly spaced sample times covering [0, duration).
# np.linspace(endpoint=False) always returns exactly seq_length points, whereas
# np.arange with a float step can return one extra point through rounding.
time_step = torch.from_numpy(np.linspace(0.0, duration, seq_length, endpoint=False)).float().to(device)

with torch.no_grad():
    U = foundation(x, coeffs, time_step)

print(f"Checkpoint: {ckpt_path}")
print(f"Input shape:  {tuple(x.shape)}")
print(f"Output shape: {tuple(U.shape)}")

# quick visualization for parcel 0
# NOTE(review): assumes U is laid out like x, i.e. [batch, time, parcel] — confirm.
x_np = x[0, :, 0].detach().cpu().numpy()
u_np = U[0, :, 0].detach().cpu().numpy()

plt.figure(figsize=(10, 4))
plt.plot(x_np, label="input (parcel 0)", alpha=0.8)
plt.plot(u_np, label="foundation output (parcel 0)", alpha=0.8)
plt.title("Foundation Inference: Input vs Output")
plt.xlabel("Time")
plt.ylabel("Signal")
plt.legend()
plt.show()
