# MnemoDyn HF Inference Tutorial

Load a foundation checkpoint from Hugging Face, run inference on one demo parcellated dtseries file, and visualize input vs output.

## Environment Setup

If you are starting from a fresh machine, clone this repo (replace `<your-org-or-user>` with the actual GitHub owner) and install the dependencies before running the inference cells below.

```bash
git clone https://github.com/<your-org-or-user>/MnemoDyn_Draft.git
cd MnemoDyn_Draft
python3 -m venv .venv
source .venv/bin/activate
pip install -U pip
pip install -r highdim_req.txt
pip install huggingface_hub
```


In [None]:
%%bash
# Optional: run the setup commands from the notebook (Linux/macOS; needs bash).
# `%%bash` runs the whole cell in ONE bash process, so `cd` and `source` below
# carry over to later lines (separate `!` lines would each get a fresh subshell).
# As written this is a dry run that only prints the commands. To execute them,
# replace <your-org-or-user> and delete the leading `echo` on every line.
# The URL is quoted because unquoted `<` and `>` are shell redirections.
# NOTE: packages installed into .venv are not visible to the already-running
# kernel; switch the notebook kernel to the .venv interpreter afterwards.
echo git clone 'https://github.com/<your-org-or-user>/MnemoDyn_Draft.git'
echo cd MnemoDyn_Draft
echo python3 -m venv .venv
echo source .venv/bin/activate
echo pip install -U pip
echo pip install -r highdim_req.txt
echo pip install huggingface_hub


In [None]:
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
import torchcde
from huggingface_hub import hf_hub_download

# Ensure the repository root is on sys.path when running the notebook directly.
# The root is the nearest directory, starting at the current working directory
# and walking upward, that contains a `coe` folder.
for repo_root in (Path.cwd(), *Path.cwd().parents):
    if (repo_root / 'coe').exists():
        if str(repo_root) not in sys.path:
            sys.path.insert(0, str(repo_root))
        break
else:
    # No `coe` folder found: leave sys.path untouched rather than inserting an
    # unrelated directory (e.g. the filesystem root). The import below still
    # succeeds if `coe` is installed as a package.
    print(f"Warning: no 'coe' directory found in {Path.cwd()} or its parents; "
          "relying on an installed `coe` package.")

from coe.light.model.main import LitORionModelOptimized


In [None]:
# ---- Configure HF model + demo data path ----
HF_MODEL_REPO_ID = 'vhluong/MnemoDyn'
HF_CKPT_PATH_IN_REPO = 'Orion_333/model.ckpt'  # checkpoint file path inside the HF repo
HF_REVISION = 'main'  # branch name, tag, or commit hash of the HF repo

# Demo parcellated dtseries file (local). The default is an author-specific path;
# set the MNEMODYN_DEMO_DTSERIES environment variable to use your own file
# without editing this cell (`~` is expanded).
DEMO_DTSERIES_PATH = Path(os.environ.get(
    'MNEMODYN_DEMO_DTSERIES',
    '/nas/vhluong/ds005747-download/dtseries/sub-011/sub-011_task-rest_space-MNI305_preproc.dtseries_Schaefer2018_400Parcels_7Networks_order_Tian_Subcortex_S3.dlabel_parcellated.dtseries.nii',
)).expanduser()

# Interpolation must match model expectation.
INTERPOL = 'spline'  # or 'linear'
# Validate here so a typo fails before the (potentially large) checkpoint download.
if INTERPOL not in ('linear', 'spline'):
    raise ValueError("INTERPOL must be 'linear' or 'spline'")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)


In [None]:
def load_dtseries(file_path: Path, target_length: int, num_parcels: int) -> torch.Tensor:
    """Load a parcellated dtseries file as a fixed-length float32 tensor.

    The time axis is tiled (repeated from the start) when the recording is
    shorter than ``target_length`` and truncated when it is longer.

    Args:
        file_path: Path to a file readable by ``nibabel.load``.
        target_length: Number of time points in the returned tensor (must be > 0).
        num_parcels: Expected size of the second (parcel) axis of the data.

    Returns:
        Tensor of shape ``[1, target_length, num_parcels]`` with dtype float32.

    Raises:
        ValueError: if ``target_length`` is not positive, or the loaded data is
            not 2-D, has a parcel count other than ``num_parcels``, or has no
            time points.
    """
    # Cheap argument check first, before touching the file system.
    if target_length <= 0:
        raise ValueError(f'target_length must be positive, got {target_length}')

    arr = nib.load(str(file_path)).get_fdata().astype(np.float32)  # [T, D]
    if arr.ndim != 2:
        raise ValueError(f'Expected 2D dtseries array, got {arr.shape}')
    if arr.shape[1] != num_parcels:
        raise ValueError(f'Expected D={num_parcels}, got D={arr.shape[1]}')
    if arr.shape[0] == 0:
        # An empty series can never be tiled up to target_length (and would
        # otherwise divide by zero below).
        raise ValueError(f'dtseries has no time points: {file_path}')

    # Repeat the series enough times to cover target_length, then cut to size.
    reps = -(-target_length // arr.shape[0])  # integer ceil division
    arr = np.tile(arr, (reps, 1))[:target_length]
    return torch.from_numpy(arr).unsqueeze(0)  # [1, T, D]


# Fail fast on a missing demo file, before the checkpoint download and model load.
if not DEMO_DTSERIES_PATH.exists():
    raise FileNotFoundError(f'Demo dtseries not found: {DEMO_DTSERIES_PATH}')

# Download the foundation checkpoint (hf_hub_download reuses its local cache).
ckpt_path = hf_hub_download(
    repo_id=HF_MODEL_REPO_ID,
    filename=HF_CKPT_PATH_IN_REPO,
    revision=HF_REVISION,
)
print('Downloaded checkpoint:', ckpt_path)

lit = LitORionModelOptimized.load_from_checkpoint(ckpt_path, map_location=device)
lit.eval()
foundation = lit.model.to(device)
# Freeze all weights: this notebook only runs inference.
for p in foundation.parameters():
    p.requires_grad = False

# Input geometry, read from the checkpoint's saved hyperparameters.
seq_length = int(lit.hparams.seq_length)
num_parcels = int(lit.hparams.D)
duration = float(lit.hparams.duration)

x = load_dtseries(DEMO_DTSERIES_PATH, target_length=seq_length, num_parcels=num_parcels).to(device)

# Interpolation coefficients for the controlled path fed to the model.
if INTERPOL == 'linear':
    coeffs = torchcde.linear_interpolation_coeffs(x)
elif INTERPOL == 'spline':
    coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(x)
else:
    raise ValueError("INTERPOL must be 'linear' or 'spline'")

# Exactly `seq_length` evenly spaced times in [0, duration), one per sample of x.
# np.arange with a float step can return seq_length + 1 points due to rounding;
# linspace with endpoint=False always returns exactly seq_length.
time_step = torch.from_numpy(
    np.linspace(0.0, duration, seq_length, endpoint=False)
).float().to(device)

with torch.no_grad():
    U = foundation(x, coeffs, time_step)

print('Input shape :', tuple(x.shape))
print('Output shape:', tuple(U.shape))


In [None]:
# Visualize parcel 0 (input vs foundation output).
# NOTE(review): indexing assumes U is laid out like x, i.e. [batch, time, parcel];
# confirm for other model variants.
parcel_idx = 0
x_np = x[0, :, parcel_idx].detach().cpu().numpy()
u_np = U[0, :, parcel_idx].detach().cpu().numpy()

plt.figure(figsize=(10, 4))
plt.plot(x_np, label='input', alpha=0.8)
plt.plot(u_np, label='foundation output', alpha=0.8)
plt.title(f'Parcel {parcel_idx}: Input vs Foundation Output')
plt.xlabel('Time step (sample index)')  # signals are plotted against their index
plt.ylabel('Signal')
plt.legend()
plt.show()

# Optional: global reconstruction error summary. Only meaningful when the output
# has exactly the input's shape; otherwise `U - x` would broadcast into a
# misleading number (or raise), so skip it with a message instead.
if U.shape == x.shape:
    err = U - x
    mae = torch.mean(torch.abs(err)).item()
    mse = torch.mean(err ** 2).item()
    print(f'MAE: {mae:.6f}')
    print(f'MSE: {mse:.6f}')
else:
    print(f'Skipping MAE/MSE: output shape {tuple(U.shape)} != input shape {tuple(x.shape)}')
