In [None]:
import sys
import subprocess
import os
import json
from pathlib import Path

if 'google.colab' in sys.modules:
    print("Running on Colab")

    subprocess.run([
        'git', 'clone', 'https://github.com/walligot/don_thesis.git'
    ])

    os.chdir('/content/don_thesis')
    #%pip install git+https://github.com/mitchellostrow/DSA.git
    %pip install git+https://github.com/Melina-Jingting/foundational_ssm.git

    from google.colab import drive
    drive.mount('/content/drive')

    os.environ['HOME'] = '/content/drive/MyDrive/Thesis'
    ROOT_PATH = '/content/don_thesis'
    os.environ['ROOT_PATH'] = ROOT_PATH

    wandb_config_path = '/content/drive/MyDrive/Colab/wandb.config.json'

else:
    #%pip install git+https://github.com/mitchellostrow/DSA.git

    current_path = Path().resolve()
    ROOT_PATH = None
    for parent in [current_path] + list(current_path.parents):
        if "don_thesis" in parent.name.lower():
            ROOT_PATH = parent
            os.environ['ROOT_PATH'] = str(ROOT_PATH)
            break

    if not ROOT_PATH:
        raise FileNotFoundError("Directory with name 'don_thesis' not found.")

    print("Running locally or elsewhere")
    wandb_config_path = os.path.join(os.environ['ROOT_PATH'], 'config', 'wandb.config.json')

# Export the location of the W&B config file chosen by the branch above.
os.environ['WANDB_CONFIG_PATH'] = wandb_config_path

# Read the API key from the JSON file and export it as WANDB_API_KEY.
# The file must contain a top-level "WANDB_API_KEY" entry; a missing file or
# key raises here (FileNotFoundError / KeyError).
with open(wandb_config_path) as f:
    config = json.load(f)
    os.environ['WANDB_API_KEY'] = config['WANDB_API_KEY']

# NOTE(review): none of these installs is version-pinned, so a fresh run may
# resolve different versions than the ones this notebook was written against.
%pip install pynwb
%pip install equinox
# temporaldata is removed and then reinstalled from a specific fork branch.
%pip uninstall -y temporaldata
%pip install git+https://github.com/Melina-Jingting/temporaldata.git@melina-resample-irregular
#%pip install equinox==0.12.2 jax==0.7.0 jaxlib==0.7.0

print(f"Root path: {ROOT_PATH}")
print(f"WANDB config path: {wandb_config_path}")

Running on Colab
Collecting git+https://github.com/Melina-Jingting/foundational_ssm.git
  Cloning https://github.com/Melina-Jingting/foundational_ssm.git to /tmp/pip-req-build-fmjzcbk3
  Running command git clone --filter=blob:none --quiet https://github.com/Melina-Jingting/foundational_ssm.git /tmp/pip-req-build-fmjzcbk3
  Resolved https://github.com/Melina-Jingting/foundational_ssm.git to commit 0946f85db8cb835434ac12f50e22c187e8824699
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting pytorch_brain (from foundational_ssm==0.1.0)
  Downloading pytorch_brain-0.1.0-py3-none-any.whl.metadata (4.1 kB)
Collecting equinox (from foundational_ssm==0.1.0)
  Downloading equinox-0.13.0-py3-none-any.whl.metadata (18 kB)
Collecting jaxtyping>=0.2.20 (from equinox->foundational_ssm==0.1.0)
  Downloading jaxtyping-0.3.2-py3-none-any.whl.metadata (7.0 kB)
Collecting wadl

In [None]:
import foundational_ssm.models.s5 as s5  # adjust to actual package path

def patched_call_with_activations(self, x, state, layer_keys):
    """Forward pass of one block that also returns selected intermediates.

    Parameters
    ----------
    self : block exposing `norm`, `ssm` (with `call_with_activations`) and `glu`.
    x : input array; it is transposed before and after the norm layer.
    state : state object threaded through `self.norm`.
    layer_keys : container of names to record, or None to record everything.
        Recognised names: "ssm_x", "ssm_y", "ssm_post_gelu", "ssm_post_glu".

    Returns
    -------
    (post_glu, state, activations) where `post_glu` is the value fed to the
    next block and `activations` maps each requested name to its array.
    """
    # The norm layer is applied on the transposed input, then transposed back.
    normed_t, state = self.norm(x.T, state)
    ssm_input = normed_t.T

    # SSM forward pass, then GELU, then the GLU mapped over the leading axis.
    ssm_y, ssm_x = self.ssm.call_with_activations(ssm_input)
    post_gelu = jax.nn.gelu(ssm_y)
    post_glu = jax.vmap(self.glu)(post_gelu)

    recordable = {
        "ssm_x": ssm_x,
        "ssm_y": ssm_y,
        "ssm_post_gelu": post_gelu,
        "ssm_post_glu": post_glu,
    }
    activations = {
        name: value
        for name, value in recordable.items()
        if layer_keys is None or name in layer_keys
    }

    return post_glu, state, activations

# Monkey-patch at class level: the attribute is replaced on S5Block itself, so
# instances look up `call_with_activations` and get the function defined above.
s5.S5Block.call_with_activations = patched_call_with_activations

In [None]:
import wandb
import equinox as eqx
import os

# Foundational SSM imports
from omegaconf import OmegaConf
import tempfile
from foundational_ssm.models import SSMDownstreamDecoder, SSMFoundationalDecoder
from foundational_ssm.utils import h5_to_dict
from foundational_ssm.transform import smooth_spikes
import jax
import jax.numpy as jnp
import numpy as np
from typing import Any, BinaryIO


#%load_ext autoreload
#%autoreload 2

def default_deserialise_filter_spec(f: BinaryIO, x: Any) -> Any:
    """Tolerant filter specification for deserialising saved leaves.

    Behaves like a per-leaf loader for `eqx.tree_deserialise_leaves`, except
    that a leaf which fails to load is reported and replaced by the template
    value `x` instead of aborting the whole load.

    **Arguments**

    -   `f`: file-like object
    -   `x`: The leaf for which the data needs to be loaded.

    **Returns**

    The new value for datatype `x`, or `x` itself when loading failed or the
    leaf is not array-like.

    !!! info

        This function can be extended to customise the deserialisation behaviour for
        leaves.

    !!! example

        Skipping loading of jax.Array.

        ```python
        import jax.numpy as jnp
        import equinox as eqx

        tree = (jnp.array([4,5,6]), [1,2,3])
        new_filter_spec = lambda f,x: (
            x if isinstance(x, jax.Array) else eqx.default_deserialise_filter_spec(f, x)
        )
        new_tree = eqx.tree_deserialise_leaves("some_filename.eqx", tree, filter_spec=new_filter_spec)
        ```
    """  # noqa: E501
    try:
        if isinstance(x, (jax.Array, jax.ShapeDtypeStruct)):
            return jnp.load(f)
        elif isinstance(x, np.ndarray):
            # Important to use `np` here to avoid promoting NumPy arrays to JAX.
            return np.load(f)
        elif eqx.is_array_like(x):
            # np.generic gets deserialised directly as an array, so convert back to a scalar
            # type here.
            # See also https://github.com/google/jax/issues/17858
            out = np.load(f)
            if isinstance(x, jax.dtypes.bfloat16):
                out = out.view(jax.dtypes.bfloat16)
            if np.size(out) == 1:
                return type(x)(out.item())
            # Unchanged from the original: a non-scalar payload for a
            # scalar-like leaf yields None.
            return None
        else:
            return x
    except Exception as exc:
        # Only ordinary errors are treated as "this leaf could not be loaded";
        # KeyboardInterrupt / SystemExit now propagate instead of being swallowed.
        # NOTE(review): after a failed read the stream position may no longer
        # sit at the start of the next leaf -- confirm later leaves still load.
        print(
            "Failed to load data for leaf with shape/ value:",
            x.shape if hasattr(x, 'shape') else x,
            "-",
            repr(exc),
        )
        return x

def load_model_and_state_from_checkpoint_wandb(artifact_full_name, model_cls=SSMDownstreamDecoder, model_cfg=None):
    """Rebuild a model and its state from a W&B checkpoint artifact.

    Parameters
    ----------
    artifact_full_name : str
        Artifact reference passed to `wandb.Api().artifact(..., type="checkpoint")`.
    model_cls : callable
        Model constructor used to build the template that the checkpoint is
        loaded into.
    model_cfg : mapping or None
        Keyword arguments for `model_cls`. When None, the `model` section of
        the config of the run that logged the artifact is used (and the full
        run config is printed).

    Returns
    -------
    (model, state, meta)
        The model and state deserialised from "model.ckpt" / "state.ckpt" in
        the artifact, and the artifact's metadata.

    Raises
    ------
    FileNotFoundError
        If the artifact lookup fails for any reason; the underlying error is
        attached as the cause.
    """
    api = wandb.Api()
    try:
        artifact = api.artifact(artifact_full_name, type="checkpoint")
    except Exception as e:
        raise FileNotFoundError(f"Could not find checkpoint artifact: {artifact_full_name}") from e

    if model_cfg is None:
        run = artifact.logged_by()
        run_cfg = OmegaConf.create(run.config)
        print(run_cfg)
        model_cfg = OmegaConf.create(run_cfg.model)

    # Build an untrained template with the right tree structure to load into.
    model_template, state_template = eqx.nn.make_with_state(model_cls)(
        **model_cfg
    )

    # The downloaded files are only needed while deserialising, so they live
    # in a temporary directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as temp_dir:
        artifact.download(temp_dir)
        model = eqx.tree_deserialise_leaves(os.path.join(temp_dir, "model.ckpt"), model_template, default_deserialise_filter_spec)
        state = eqx.tree_deserialise_leaves(os.path.join(temp_dir, "state.ckpt"), state_template, default_deserialise_filter_spec)

    meta = artifact.metadata
    return model, state, meta

# Downstream Model

In [None]:
# Choose which downstream checkpoints to fetch; the artifact name is assembled
# from layer count, pretrain mode, train mode and alias.
layer = "2"
pretrain_mode = "scratch"
train_mode = "all"
alias = "best" # one of: latest, best, or epoch_{any value in range(0,1000,100)}
# epoch 0 now stores a fresh model.
artifact_full_name = f"melinajingting-ucl/foundational_ssm_rtt/l{layer}_{pretrain_mode}_{train_mode}_checkpoint:{alias}"
model_2_block, state_2_block, meta = load_model_and_state_from_checkpoint_wandb(artifact_full_name)

# Same settings with layer "4". `meta` and `artifact_full_name` are
# overwritten, so only the second artifact's values remain afterwards.
layer = "4"
artifact_full_name = f"melinajingting-ucl/foundational_ssm_rtt/l{layer}_{pretrain_mode}_{train_mode}_checkpoint:{alias}"
model_4_block, state_4_block, meta = load_model_and_state_from_checkpoint_wandb(artifact_full_name)

NameError: name 'load_model_and_state_from_checkpoint_wandb' is not defined

## Calling with activations (Downstream)

In [None]:
import copy
# Download mc_rtt_trialized from https://huggingface.co/datasets/MelinaLaimon/nlb_processed/tree/main
# Edit dataset_dir to your directory
dataset_dir = "/content/drive/MyDrive/Thesis/data/"
dataset_path = os.path.join(dataset_dir, "mc_rtt_trialized.h5")
data = h5_to_dict(dataset_path)
# Keep an untouched copy of the input before it is replaced by the smoothed version.
data["neural_input_raw"] = copy.deepcopy(data["neural_input"])
data["neural_input"] = smooth_spikes(data["neural_input"], kern_sd_ms=20, bin_size_ms=5, time_axis=1)
# NOTE(review): `input` shadows the Python builtin for the rest of the kernel session.
input = data["neural_input"]
target_vel = data["behavior_input"]
data["targets"] = copy.deepcopy(data["behavior_input"])

# Specify the layers you want to generate the activations of.
# NOTE(review): the names in the next comment differ from the keys requested
# below -- it looks stale; confirm and remove.
# ["post_encoder", "ssm_pre_activation", "ssm_post_activation"]
layer_keys = ["ssm_x", "ssm_y", "ssm_post_glu"]
inf_model = eqx.nn.inference_mode(model_2_block) # Switches off dropout
# vmap maps over axis 0 of `input` only; state and layer_keys are shared (in_axes None).
pred_vel, _, activations_2_block = jax.vmap(inf_model.call_with_activations, axis_name="batch", in_axes=(0, None, None))(input, state_2_block, layer_keys)
# Store inputs and targets alongside the activations so one dict is self-contained.
activations_2_block['neural_input_raw'] = data['neural_input_raw']
activations_2_block['neural_input'] = data['neural_input']
activations_2_block['targets'] = data['targets']
activations_2_block_dict = {}
activations_2_block_dict['mc_rtt'] = activations_2_block

# Same procedure for the 4-block model; `inf_model` and `pred_vel` are overwritten.
inf_model = eqx.nn.inference_mode(model_4_block) # Switches off dropout
pred_vel, _, activations_4_block = jax.vmap(inf_model.call_with_activations, axis_name="batch", in_axes=(0, None, None))(input, state_4_block, layer_keys)
activations_4_block['neural_input_raw'] = data['neural_input_raw']
activations_4_block['neural_input'] = data['neural_input']
activations_4_block['targets'] = data['targets']
activations_4_block_dict = {}
activations_4_block_dict['mc_rtt'] = activations_4_block

In [None]:
# Build the paths with os.path.join so they stay valid when `dataset_dir` has
# no trailing slash (a later cell rebinds it to such a value).
# NOTE(review): each saved value is a dict, which np.savez presumably stores as
# a pickled object array -- loading would then need allow_pickle=True; confirm.
np.savez(os.path.join(dataset_dir, "activations_rtt_2block_20250831_2.npz"), **activations_2_block_dict)
np.savez(os.path.join(dataset_dir, "activations_rtt_4block_20250831_2.npz"), **activations_4_block_dict)

In [None]:
# Show which entries were collected for the 2-block model.
activations_2_block_dict['mc_rtt'].keys()

dict_keys(['ssm_post_gelu_0', 'ssm_post_gelu_1', 'ssm_x_0', 'ssm_x_1', 'ssm_y_0', 'ssm_y_1', 'neural_input_raw', 'neural_input', 'targets'])

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import RidgeCV

def add_prev_timestep(data):
    """
    Pair every timestep with the features of the timestep before it.

    Parameters
    ----------
    data : np.ndarray, shape (n_trials, n_timesteps, n_dims)

    Returns
    -------
    out : np.ndarray, shape (n_trials, n_timesteps, 2*n_dims)
          The first n_dims columns hold the previous timestep (zeros at t=0),
          the last n_dims columns hold the current timestep. The shift never
          crosses a trial boundary.
    """
    # Unpacking doubles as the rank check: anything but 3-D raises ValueError.
    n_trials, n_steps, n_dims = data.shape

    # One zero frame followed by every frame except the last one.
    leading_zeros = np.zeros_like(data[:, :1, :])
    previous = np.concatenate([leading_zeros, data[:, :-1, :]], axis=1)

    return np.concatenate([previous, data], axis=2)

def simple_linear_decoder_by_trial(hidden_states, behaviour, test_size=0.2,
                                   alphas=(1e-3, 1e-2, 1e-1, 1), top_n=10):
    """
    Trains a linear decoder from hidden states to behaviour, splitting by trial.

    Whole trials are assigned to either the train or the test split (fixed
    random_state=0), so timesteps of one trial never appear on both sides.

    Parameters
    ----------
    hidden_states : array, shape (n_trials, n_timesteps, n_features)
    behaviour     : array, shape (n_trials, n_timesteps, n_outputs)
                    or (n_trials, n_timesteps) for single output
    test_size     : float, fraction of trials to hold out
    alphas        : sequence, RidgeCV regularisation strengths
    top_n         : int, number of top features to return

    Returns
    -------
    R2 : float, held-out R^2 score as returned by `RidgeCV.score`
    weights : np.ndarray, decoder coefficients (n_outputs, n_features)
              or (n_features,) if single output
    top_indices : np.ndarray, indices of top N contributing features, ranked
                  by the sum over outputs of the absolute coefficients

    Raises
    ------
    ValueError
        If `hidden_states` is not 3-D, if `behaviour` is neither 2-D nor 3-D,
        or if the two disagree on (n_trials, n_timesteps).
    """
    hs = np.asarray(hidden_states)
    beh = np.asarray(behaviour)
    N, T, H = hs.shape

    # Handle behaviour shape
    if beh.ndim == 2:  # (N, T) single output
        B = 1
        beh = beh[:, :, None]  # add output dim
    elif beh.ndim == 3:  # (N, T, B)
        B = beh.shape[2]
    else:
        raise ValueError("behaviour must be (N, T) or (N, T, B)")

    # Both arrays are indexed with the same trial indices and flattened over
    # time, so they must agree on trials and timesteps.
    if beh.shape[:2] != (N, T):
        raise ValueError(
            f"behaviour has (trials, timesteps)={beh.shape[:2]} but "
            f"hidden_states has {(N, T)}"
        )

    # Split by trial
    trial_indices = np.arange(N)
    train_trials, test_trials = train_test_split(
        trial_indices, test_size=test_size, random_state=0
    )

    # Flatten over timesteps within each split
    Xtr = hs[train_trials].reshape(-1, H)
    Xte = hs[test_trials].reshape(-1, H)
    ytr = beh[train_trials].reshape(-1, B)
    yte = beh[test_trials].reshape(-1, B)

    # Fit ridge regression
    decoder = RidgeCV(alphas=alphas).fit(Xtr, ytr)
    R2 = decoder.score(Xte, yte)

    # Feature importance
    importance = np.abs(decoder.coef_).sum(axis=0)  # sum over outputs if multi-output
    top_indices = np.argsort(importance)[::-1][:top_n]

    # Return shape for weights: squeeze if single output
    weights = decoder.coef_.squeeze() if B == 1 else decoder.coef_

    return R2, weights, top_indices

In [None]:
from sklearn.metrics import r2_score
# Apply the model's decoder to every timestep of every trial.
encode_t = jax.vmap(model_2_block.decoder, in_axes=0)  # over time
encode_bt = jax.vmap(encode_t, in_axes=0)  # over trials
decoded = encode_bt(activations_2_block['ssm_post_glu_1'])
# Flatten (trial, time) into one sample axis for scoring.
decoded_flat = decoded.reshape(-1, decoded.shape[-1])
input_flat = data['behavior_input'].reshape(-1, data['behavior_input'].shape[-1])

# r2_score expects the ground truth first and the prediction second; R^2 is not
# symmetric, so passing them the other way round scores the wrong quantity.
r2_score(y_true=input_flat, y_pred=decoded_flat)

-137.2387528681632

In [None]:
from sklearn.metrics import r2_score
# NOTE(review): this call fails -- the recorded output is
# "ValueError: Found array with dim 3", i.e. r2_score rejects these 3-D arrays.
# It also compares the neural input directly with behaviour; the intent is not
# clear from here. Fix or delete before sharing the notebook.
r2_score(data["neural_input"], data['behavior_input'])

ValueError: Found array with dim 3. None expected <= 2.

## Example: Plotting Output

In [None]:
import pandas as pd
from foundational_ssm.plotting import aggregate_bin_label_results, plot_pred_vs_targets_by_angle_bin


# Download mc_rtt_trialized from https://huggingface.co/datasets/MelinaLaimon/nlb_processed/tree/main
# Edit dataset_dir to your directory
# NOTE(review): this rebinds `dataset_dir` (now without a trailing slash) and
# `data`, both of which earlier cells also use.
dataset_dir = "../../data/foundational_ssm/processed/nlb"
trial_info = pd.read_csv(os.path.join(dataset_dir, "mc_rtt_trialized.csv"))
dataset_path = os.path.join(dataset_dir, "mc_rtt_trialized.h5")
data = h5_to_dict(dataset_path)
data["neural_input"] = smooth_spikes(data["neural_input"], kern_sd_ms=20, bin_size_ms=5, time_axis=1)
# NOTE(review): `input` shadows the Python builtin.
input = data["neural_input"]
target_vel = data["behavior_input"]

# Specify the layers you want to generate the activations of.
# ["post_encoder", "ssm_pre_activation", "ssm_post_activation"]
layer_keys = ["ssm_pre_activation"]
# NOTE(review): `model` and `state` are not assigned by any cell above this one
# in the visible notebook, so this presumably relies on leftover kernel state;
# a later cell also rebinds `model` to a string. Confirm which model is meant.
inf_model = eqx.nn.inference_mode(model) # Switches off dropout
# vmap maps over axis 0 of `input` only; state and layer_keys are shared.
pred_vel, _, activations = jax.vmap(inf_model.call_with_activations, axis_name="batch", in_axes=(0, None, None))(input, state, layer_keys)

results_df = aggregate_bin_label_results(trial_info, target_vel, pred_vel)
fig = plot_pred_vs_targets_by_angle_bin(results_df)
fig.show()

# Foundational Model

In [None]:

# NOTE(review): `model` is a plain string here (part of the artifact name).
# Other cells pass a variable called `model` as a model object, so the name is
# overloaded across the notebook.
model = "l2"
#dataset = "reaching_normalized"
dataset = "reaching"
#alias = "best"
alias = "latest"

artifact_full_name = f"melinajingting-ucl/foundational_ssm_pretrain/{model}_{dataset}_checkpoint:{alias}"
foundational_model_2block, foundational_state_2block, meta = load_model_and_state_from_checkpoint_wandb(artifact_full_name, model_cls=SSMFoundationalDecoder)

# Same dataset/alias for the "l4" artifact; `meta` and `artifact_full_name`
# are overwritten.
model = "l4"
artifact_full_name = f"melinajingting-ucl/foundational_ssm_pretrain/{model}_{dataset}_checkpoint:{alias}"
foundational_model_4block, foundational_state_4block, meta = load_model_and_state_from_checkpoint_wandb(artifact_full_name, model_cls=SSMFoundationalDecoder)

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdavekk[0m ([33mdavekk-ucl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'output_dim': 2, 'ssm_io_dim': 256, 'context_dim': 0, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'masking'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_pretrain', 'resume_run_id': None}, 'rng_seed': 42, 'training': {'epochs': 501, 'log_val_every': 50, 'checkpoint_every': 1}, 'model_cfg': 'configs/model/l2_no_context.yaml', 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'val_loader': {'sampler': 'TrialSampler', 'dataloader_args': {'batch_size': 512, 'num_workers': 4, 'persistent_workers': True}}, 'dataset_args': {'config': 'configs/dataset/reaching.yaml'}, 'train_loader': {'sampler': 'RandomVariableWindowSampler', 'sampler_args': {'drop_short': True, 'max_window_length': 5, 'min_window_length': 1}, 'dataloader_args': {'batch_size': 512, 'num_workers': 20, 'persistent_workers': True}}, 'sampling_rate': 200, 'prep

[34m[1mwandb[0m:   3 of 3 files downloaded.  


{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.01, 'output_dim': 2, 'ssm_io_dim': 256, 'context_dim': 0, 'ssm_num_layers': 4, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'masking'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_pretrain', 'resume_run_id': None}, 'rng_seed': 42, 'training': {'epochs': 501, 'log_val_every': 50, 'checkpoint_every': 1}, 'model_cfg': 'configs/model/l4.yaml', 'optimizer': {'lr': 0.001, 'mode': 'all', 'weight_decay': 0.01}, 'val_loader': {'sampler': 'TrialSampler', 'dataloader_args': {'batch_size': 512, 'num_workers': 4, 'persistent_workers': True}}, 'dataset_args': {'config': 'configs/dataset/reaching.yaml'}, 'train_loader': {'sampler': 'RandomVariableWindowSampler', 'sampler_args': {'drop_short': True, 'max_window_length': 5, 'min_window_length': 1}, 'dataloader_args': {'batch_size': 512, 'num_workers': 20, 'persistent_workers': True}}, 'sampling_rate': 200, 'prepend_history

[34m[1mwandb[0m:   3 of 3 files downloaded.  


In [None]:
# Sanity check: the first SSM block of the loaded model should resolve
# `call_with_activations` to the patched function.
blk = foundational_model_2block.ssm_blocks[0]
print(blk.call_with_activations.__func__ is patched_call_with_activations)  # True
print(blk.call_with_activations.__name__)

True
patched_call_with_activations


## Loading the dataset

In [None]:
from functools import partial
from torch.utils.data import DataLoader

from foundational_ssm.constants import DATA_ROOT, MAX_NEURAL_UNITS, DATASET_GROUP_INFO
from foundational_ssm.dataset import TorchBrainDataset
from foundational_ssm.transform import transform_brainsets_regular_time_series_smoothed, parse_session_id
from foundational_ssm.collate import pad_collate
import foundational_ssm.samplers as samplers
import numpy as np

from typing import Dict, Any, Tuple

import numpy as np
import torch
import re

from foundational_ssm.constants import (
    DATASET_GROUP_TO_IDX,
    MAX_NEURAL_UNITS,
    MAX_BEHAVIOR_DIM,
    DATASET_IDX_TO_STD
)
from foundational_ssm.spikes import bin_spikes, smooth_spikes

def parse_session_id(session_id: str) -> Tuple[str, str, str]:
    """Split a session id of the form "<dataset>/<rest>" into (dataset, subject, task).

    The text before the first "/" selects a dataset-specific pattern. For two
    datasets the task is a constant rather than parsed from the id.

    Note: this definition rebinds the `parse_session_id` name imported from
    `foundational_ssm.transform` earlier in the same cell.

    Raises
    ------
    ValueError
        If the dataset prefix is unknown or the id does not match its pattern.
    """
    session_regexes = {
        "churchland_shenoy_neural_2012": r"([^/]+)/([^_]+)_[0-9]+_(.+)",
        "flint_slutzky_accurate_2012": r"([^/]+)/monkey_([^_]+)_e1_(.+)",
        "odoherty_sabes_nonhuman_2017": r"([^/]+)/([^_]+)_[0-9]{8}_[0-9]+",
        "pei_pandarinath_nlb_2021": r"([^/]+)/([^_]+)_(.+)",
        "perich_miller_population_2018": r"([^/]+)/([^_]+)_[0-9]+_(.+)",
    }
    # Datasets whose task label is fixed regardless of what the id contains.
    constant_task = {
        "odoherty_sabes_nonhuman_2017": "random_target_reaching",
        "flint_slutzky_accurate_2012": "center_out_reaching",
    }

    dataset, _, _ = session_id.partition('/')
    if dataset not in session_regexes:
        raise ValueError(f"Unknown dataset: {dataset}")

    parsed = re.match(session_regexes[dataset], session_id)
    if parsed is None:
        raise ValueError(f"Could not parse session_id: {session_id!r}")

    fields = parsed.groups()
    if dataset in constant_task:
        # The second captured group is the subject in both constant-task patterns.
        return dataset, fields[1], constant_task[dataset]
    return fields

def _ensure_dim(arr: np.ndarray, target_dim: int, *, axis: int = 1) -> np.ndarray:
    """Crop or zero-pad *arr* along *axis* to match *target_dim*.

    This is a thin wrapper around :pymod:`numpy` slicing and :func:`numpy.pad` that
    avoids several conditional blocks in the main routine.
    """
    current_dim = arr.shape[axis]
    if current_dim == target_dim:
        return arr  # nothing to do
    if current_dim > target_dim:
        # Crop
        slicer = [slice(None)] * arr.ndim
        slicer[axis] = slice(None, target_dim)
        return arr[tuple(slicer)]
    # Pad (current_dim < target_dim)
    pad_width = [(0, 0)] * arr.ndim
    pad_width[axis] = (0, target_dim - current_dim)
    return np.pad(arr, pad_width, mode="constant")

def transform_brainsets_regular_time_series_raw(
    data: Any,
    *,
    max_neural_units: int = MAX_NEURAL_UNITS,
    sampling_rate: int = 200,
) -> Dict[str, torch.Tensor | str]:
    """
    Like `transform_brainsets_regular_time_series_smoothed` but WITHOUT smoothing.
    Produces raw binned spike counts at `sampling_rate`, aligned with behaviour.

    Parameters
    ----------
    data : sample object exposing `spikes`, `session.id` and either
        `vel_regular` or a `hand` / `cursor` attribute with a "vel" array.
    max_neural_units : width the neural array is cropped / zero-padded to.
    sampling_rate : rate passed to `get_regular_time_series_array`.

    Returns
    -------
    dict with tensors "neural_input", "behavior_input" (float32) and
    "dataset_group_idx", "session_date" (int32).
    NOTE(review): the return annotation allows `str` values, but every value
    built here is a tensor.
    """
    # ----------------------------
    # 1) Raw binned spikes (no smoothing)
    # ----------------------------
    # Always bin from spike indices to avoid using any pre-smoothed fields.
    binned_spikes, _ = data.spikes.get_regular_time_series_array(
        sampling_rate=sampling_rate,
        raw_array_name="unit_index",
        is_index=True,            # yields raw spike counts per bin
    )  # shape: (timesteps, units)

    # ----------------------------
    # 2) Behaviour (cursor/hand velocity)
    # ----------------------------
    # Prefer an already-regular velocity if the sample carries one; otherwise
    # resample "vel" from `hand` (pei_pandarinath_nlb_2021 sessions) or `cursor`.
    if hasattr(data, "vel_regular"):
        behavior_input = data.vel_regular.data
    else:
        if data.session.id.startswith("pei_pandarinath_nlb_2021"):
            behavior_input, _ = data.hand.get_regular_time_series_array(
                sampling_rate=sampling_rate,
                raw_array_name="vel",
            )
        else:
            behavior_input, _ = data.cursor.get_regular_time_series_array(
                sampling_rate=sampling_rate,
                raw_array_name="vel",
            )

    # ----------------------------
    # 3) Drop timesteps with invalid behaviour
    # ----------------------------
    # A timestep is dropped from both arrays if any behaviour column is inf or NaN.
    # assumes binned_spikes and behavior_input have the same number of rows -- TODO confirm
    valid_mask = ~(np.isinf(behavior_input).any(axis=1) | np.isnan(behavior_input).any(axis=1))
    if not valid_mask.all():
        behavior_input = behavior_input[valid_mask]
        binned_spikes = binned_spikes[valid_mask]

    # ----------------------------
    # 4) Group/index info, padding & normalisation
    # ----------------------------
    dataset, subject, task = parse_session_id(data.session.id)
    group_tuple = (dataset, subject, task)
    try:
        group_idx = DATASET_GROUP_TO_IDX[group_tuple]
    except Exception:
        # NOTE(review): any lookup failure falls back to the hard-coded index 9;
        # what group 9 stands for is not visible here -- confirm.
        group_idx = 9

    # Concatenate every run of digits found after the first "/" of the session
    # id into one integer; 0 when there are none.
    match = re.findall(r"\d+", data.session.id.split("/")[1])
    session_date = int("".join(match)) if len(match) > 0 else 0

    # pad/crop neural to max_neural_units
    neural_input = _ensure_dim(binned_spikes, max_neural_units, axis=1)

    # normalise behaviour by dataset std, then pad/crop to MAX_BEHAVIOR_DIM
    behavior_input = behavior_input / DATASET_IDX_TO_STD[group_idx]
    behavior_input = _ensure_dim(behavior_input, MAX_BEHAVIOR_DIM, axis=1)

    # ----------------------------
    # 5) Pack tensors
    # ----------------------------
    return {
        "neural_input": torch.as_tensor(neural_input, dtype=torch.float32),
        "behavior_input": torch.as_tensor(behavior_input, dtype=torch.float32),
        "dataset_group_idx": torch.as_tensor(group_idx, dtype=torch.int32),
        "session_date": torch.as_tensor(session_date, dtype=torch.int32),
    }

def get_brainset_data_loader_raw(
    dataset_args,
    dataloader_args,
    sampler,
    split = None,
    sampler_args = None,
    sampling_rate = 200,
    prepend_history = 0,
    data_root = DATA_ROOT,
):
    """Build a dataset and DataLoader that yield raw (unsmoothed) binned spikes.

    Parameters
    ----------
    dataset_args : dict of extra keyword arguments for `TorchBrainDataset`.
    dataloader_args : dict of extra keyword arguments for `DataLoader`.
    sampler : str, name of a class in `foundational_ssm.samplers`.
    split : forwarded to `TorchBrainDataset`.
    sampler_args : dict of keyword arguments for the sampler class, or None
        for no extra arguments.
    sampling_rate : rate used by the transform and to size padded sequences.
    prepend_history : forwarded to the sampler and added to the window length.
    data_root : root directory where .h5 files are found.

    Returns
    -------
    (dataset, loader, max_neural_units)
    """
    # Normalise once so both uses below see a dict; previously only the sampler
    # construction tolerated None and the window-length lookup crashed.
    sampler_args = dict(sampler_args or {})

    dataset = TorchBrainDataset(
        root=data_root,                # root directory where .h5 files are found
        **dataset_args,
        split=split
    )

    sampling_intervals = dataset.get_sampling_intervals()
    sampler_cls = getattr(samplers, sampler)
    sampler_obj = sampler_cls(
        sampling_intervals=sampling_intervals,
        **sampler_args,
        prepend_history=prepend_history
    )

    # Widest unit count among the sessions that are actually sampled.
    max_neural_units = int(np.max(
        [DATASET_GROUP_INFO[parse_session_id(k)]["max_num_units"] for k in sampling_intervals.keys()]
    ))
    dataset.transform = partial(
        transform_brainsets_regular_time_series_raw,
        sampling_rate=sampling_rate,
        max_neural_units=max_neural_units,
    )

    # 'window_length' wins over 'max_window_length'; defaults to 1 if neither is given.
    total_window_length = sampler_args.get('window_length', sampler_args.get('max_window_length', 1)) + prepend_history
    loader = DataLoader(
        dataset=dataset,
        sampler=sampler_obj,
        # Every batch is padded to a fixed number of timesteps.
        collate_fn=partial(pad_collate, fixed_seq_len=int(total_window_length*sampling_rate)),
        pin_memory=True,
        **dataloader_args
    )
    return dataset, loader, max_neural_units

In [None]:
import multiprocessing as mp

# Foundational SSM core imports
from foundational_ssm.loaders import get_brainset_data_loader, get_brainset_train_val_loaders
#from foundational_ssm.constants import DATA_ROOT
from foundational_ssm.samplers import TrialSampler
import os

data_root = '/content/drive/MyDrive/Thesis/data/foundational_ssm/processed'


def _val_trial_loader_kwargs():
    """Return a fresh copy of the settings shared by the smoothed and raw loaders.

    Both loaders must be built with identical settings so that their samples
    can be compared; a new dict is returned on every call so neither loader
    function can affect the other's arguments.
    """
    return dict(
        dataset_args={
            'keep_files_open': False,
            'lazy': True,
            'config': '/content/drive/MyDrive/Thesis/data/reaching.yaml',
        },
        dataloader_args={
            'batch_size': 1500,
            'num_workers': 0,
            'persistent_workers': False,
        },
        sampler='TrialSampler',
        sampler_args={
            'max_window_length': 5.0,
        },
        data_root=data_root,
        prepend_history=0.3,
        sampling_rate=200,
        # train_trial overlaps with training data. use val_trial, test_trial
        # if you want to observe what happens outside
        split='val_trial',
    )


# Library loader.
dataset, loader, max_neural_input = get_brainset_data_loader(**_val_trial_loader_kwargs())

# Notebook-local loader that yields raw binned spikes for the same settings.
dataset_raw, loader_raw, max_neural_input_raw = get_brainset_data_loader_raw(**_val_trial_loader_kwargs())

In [None]:
from foundational_ssm.constants import DATASET_IDX_TO_GROUP_SHORT
from tqdm import tqdm

skip_timesteps=56

# Gather the raw input of every sample, grouped by integer dataset index.
raw_trials_by_idx = {}
for batch in tqdm(loader_raw, desc="Batches"):
    batch = {k: jax.device_put(np.array(v)) for k, v in batch.items()}
    dataset_group_idxs = batch["dataset_group_idx"]
    inputs = batch["neural_input"]
    for i in range(inputs.shape[0]):
        ds_id = int(dataset_group_idxs[i])
        # Drop the first `skip_timesteps` steps of each sample.
        raw_trials_by_idx.setdefault(ds_id, []).append(inputs[i][skip_timesteps:])

# One stacked NumPy array per dataset, kept inside a dict so the entry can be
# indexed as [...]['neural_input_raw'].
raw_inputs_by_idx = {
    ds_id: {"neural_input_raw": np.asarray(jnp.stack(trials, axis=0))}
    for ds_id, trials in raw_trials_by_idx.items()
}

# On disk the entries are keyed by the stringified dataset index.
np.savez(
    "/content/drive/MyDrive/Thesis/data/f_input_raw_20250910_1.npz",
    **{str(ds_id): entry for ds_id, entry in raw_inputs_by_idx.items()},
)

# In memory the entries are keyed by group short name and remain plain dicts.
# The previous version wrapped each dict in a 0-d object array keyed by
# str(index), which cannot be indexed as input_by_dataset[<short name>]['neural_input_raw'].
input_by_dataset = {
    DATASET_IDX_TO_GROUP_SHORT[ds_id]: entry for ds_id, entry in raw_inputs_by_idx.items()
}

Batches: 100%|██████████| 1/1 [00:15<00:00, 15.47s/it]


In [None]:
# Sparsity check for one group: non-zero entries vs. total entries.
# Expects `input_by_dataset` to be keyed by group short name, with a
# 'neural_input_raw' array stored under each group.
print(np.count_nonzero(input_by_dataset['pm_c_co']['neural_input_raw']))
print(input_by_dataset['pm_c_co']['neural_input_raw'].size)
#input_by_dataset.keys()

29050
11210000


In [None]:
from tqdm import tqdm
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx

def collect_activations_by_dataset(val_loader, model, state, layer_keys, skip_timesteps=0):
    """Run the model one sample at a time and group the results by dataset index.

    Parameters
    ----------
    val_loader : iterable of dict batches that also exposes `.sampler`. Each
        batch must provide "dataset_group_idx", "neural_input",
        "behavior_input" and "mask".
    model : model whose `call_with_activations(input, state, dataset_group_idx,
        layer_keys)` returns `(prediction, state, activations_dict)`.
    state : initial state handed to the first sample of every batch.
    layer_keys : forwarded to `call_with_activations`.
    skip_timesteps : number of leading timesteps removed from every array.

    Returns
    -------
    dict mapping the integer dataset index to a dict of arrays stacked over
    samples: one entry per activation key plus "targets", "neural_input" and
    "mask".
    """
    inference_model = eqx.nn.inference_mode(model)

    activations_by_dataset = {}

    # Debug aid: the sampler yields one entry per sample, so this is a sample
    # count, not a batch count.
    sample_indices = list(val_loader.sampler)
    print("samples announced by sampler:", len(sample_indices))

    for batch in tqdm(val_loader, desc="Batches"):
        batch = {k: jax.device_put(np.array(v)) for k, v in batch.items()}
        dataset_group_idxs = batch["dataset_group_idx"]
        inputs = batch["neural_input"]
        targets = batch["behavior_input"]
        # Trailing axis added so the mask can be sliced like the 2-D arrays below.
        mask = batch["mask"][..., None]

        # NOTE(review): the state returned for one sample is passed to the next
        # sample of the same batch -- confirm this carry-over is intended.
        state_out = state

        for i in range(inputs.shape[0]):
            _, state_out, acts = inference_model.call_with_activations(
                inputs[i], state_out, dataset_group_idxs[i], layer_keys
            )

            ds_id = int(dataset_group_idxs[i])

            # First sample of a dataset: create one list per collected key.
            if ds_id not in activations_by_dataset:
                activations_by_dataset[ds_id] = {k: [] for k in acts.keys()}
                activations_by_dataset[ds_id]["targets"] = []
                activations_by_dataset[ds_id]["neural_input"] = []
                activations_by_dataset[ds_id]["mask"] = []

            per_dataset = activations_by_dataset[ds_id]

            # Drop the leading timesteps from everything that is stored.
            for k, v in acts.items():
                per_dataset[k].append(v[skip_timesteps:, :])
            per_dataset["targets"].append(targets[i][skip_timesteps:, :])
            per_dataset["neural_input"].append(inputs[i][skip_timesteps:, :])
            per_dataset["mask"].append(mask[i][skip_timesteps:, :])

    # Stack into arrays
    for ds_id in activations_by_dataset:
        activations_by_dataset[ds_id] = {
            k: jnp.stack(v, axis=0) for k, v in activations_by_dataset[ds_id].items()
        }

    return activations_by_dataset


In [None]:
# Collect the same three activation types from both foundational models, using
# the same loader and the same number of skipped leading timesteps.
layer_keys = ["ssm_x", "ssm_y", "ssm_post_glu"]
activations_2block = collect_activations_by_dataset(
    val_loader=loader,   # your DataLoader
    model=foundational_model_2block,          # model with patched call_with_activations
    state=foundational_state_2block,              # state loaded with the checkpoint
    layer_keys=layer_keys,
    skip_timesteps=56         # leading timesteps to drop, or 0 for none
)

activations_4block = collect_activations_by_dataset(
    val_loader=loader,   # your DataLoader
    model=foundational_model_4block,          # model with patched call_with_activations
    state=foundational_state_4block,              # state loaded with the checkpoint
    layer_keys=layer_keys,
    skip_timesteps=56         # leading timesteps to drop, or 0 for none
)

batches announced: 357


Batches: 100%|██████████| 1/1 [01:15<00:00, 75.55s/it]


batches announced: 357


Batches: 100%|██████████| 1/1 [01:43<00:00, 103.76s/it]


In [None]:
from foundational_ssm.constants import DATASET_IDX_TO_GROUP_SHORT

# np.savez takes keyword names, so the integer dataset indices are stringified.
# NOTE(review): each value is a dict of arrays, which np.savez presumably
# stores as a pickled object array -- loading would need allow_pickle=True; confirm.
# NOTE(review): `activations_group_2block` / `_4block` are rebound in the next
# cell with different keys (group short names).
activations_group_2block = {str(k): v for k, v in activations_2block.items()}
np.savez("/content/drive/MyDrive/Thesis/data/activations_reaching_2block_20250910_1.npz", **activations_group_2block)
activations_group_4block = {str(k): v for k, v in activations_4block.items()}
np.savez("/content/drive/MyDrive/Thesis/data/activations_reaching_4block_20250910_1.npz", **activations_group_4block)


In [None]:
# Re-key the in-memory results from integer dataset index to group short name.
activations_group_2block = {DATASET_IDX_TO_GROUP_SHORT[k]: v for k, v in activations_2block.items()}
activations_group_4block = {DATASET_IDX_TO_GROUP_SHORT[k]: v for k, v in activations_4block.items()}

In [None]:
# Shape check for dataset index 0 (activations_2block is keyed by integer index).
print(activations_2block[0]['ssm_y_1'].shape)
print(activations_2block[0]['neural_input'].shape)
print(activations_2block[0]['targets'].shape)

In [None]:
# Raw-input shape per group; `input_by_dataset` must share its keys (group
# short names) with `activations_group_2block`.
[input_by_dataset[k]['neural_input_raw'].shape for k in activations_group_2block.keys()]

[(78, 944, 625),
 (19, 944, 625),
 (20, 944, 625),
 (12, 944, 625),
 (17, 944, 625),
 (47, 944, 625),
 (40, 944, 625),
 (124, 944, 625)]

In [None]:
# Number of non-zero entries in the model input stored for group 'pm_c_co'.
np.count_nonzero(activations_group_2block['pm_c_co']['neural_input'])

355518

In [None]:
# Debug aid: pull batches one by one to see how many the loader really yields
# and whether fetching one raises. The range is only an upper bound.
# NOTE(review): the recorded run failed with "NameError: name 'loader' is not
# defined", so this was executed before `loader` existed in that kernel.
it = iter(loader)
for b in range(1000000):
    try:
        batch = next(it)
    except StopIteration:
        print("actually yielded:", b)  # b batches were fetched before exhaustion
        break
    except Exception as e:
        print("error at batch", b, "->", repr(e))
        break

NameError: name 'loader' is not defined

In [None]:
import multiprocessing as mp

# Foundational SSM core imports
from foundational_ssm.loaders import get_brainset_data_loader, get_brainset_train_val_loaders
#from foundational_ssm.constants import DATA_ROOT
from foundational_ssm.samplers import TrialSampler
import os

mp.set_start_method("spawn", force=True) # otherwise causes deadlock on jax.

# NOTE(review): `data_root` is computed here but the call below passes
# DATA_ROOT instead -- confirm which one is intended.
data_root = '/content/drive/MyDrive/Thesis/data/' + DATA_ROOT # change to the folder holding the brainsets
#DATA_ROOT = '/content/drive/MyDrive/Thesis/data/foundational_ssm/processed'
config_dir = '/content/drive/MyDrive/Thesis/data/' # change
dataset_args = {
    'keep_files_open': False,
    'lazy': True,
    'config' : os.path.join(config_dir, 'reaching_analysis.yaml'),
    #'split': 'val' # or 'train'
}
dataloader_args = {
    'batch_size': 128, # Adjust per your system capacity
    'num_workers': 4,
    'persistent_workers': False
}


sampler = 'SequentialFixedWindowSampler'
sampler_args = {
                'window_length': 3.279,
                'drop_short': True
                }

# The same loader configuration is used for both the train and the val loader below.
loader_cfg = {'sampler': sampler, 'sampler_args': sampler_args, 'dataloader_args': dataloader_args}

#dataset, data_loader = get_brainset_data_loader(
#    dataset_args=dataset_args,
#    sampler = sampler,
#    sampler_args = sampler_args,
#    dataloader_args = dataloader_args,
#    sampling_rate = 200,
    #dataset_cfg = os.path.join(config_dir, 'reaching_analysis.yaml'),
#    data_root = data_root,
#    split = 'val'
#)

#{dataset_cfg = os.path.join(config_dir, 'reaching_new.yml'),


# }


# NOTE(review): the recorded run of this cell failed with
# "TypeError: SequentialFixedWindowSampler.__init__() got an unexpected keyword
# argument 'prepend_history'". The return value is also discarded, so no loader
# from this call is available afterwards.
get_brainset_train_val_loaders(
    dataset_args,
    loader_cfg,
    loader_cfg,
    prepend_history=0,
    data_root=DATA_ROOT,
    )

# NOTE(review): `dataset` is not assigned in this cell; this uses whatever an
# earlier cell left bound to that name.
sessions = dataset.get_session_ids() # list of sessions in your dataset
sampling_intervals = dataset.get_sampling_intervals() # list of sampling intervals for each session

TypeError: SequentialFixedWindowSampler.__init__() got an unexpected keyword argument 'prepend_history'

## Validation

In [None]:
from foundational_ssm.utils.pretrain_utils import validate_one_epoch

# NOTE(review): `data_loader` appears only in commented-out code, `model` is
# bound to a string by an earlier cell, and `state` is not assigned in any
# visible cell -- this cell presumably relies on leftover kernel state. Confirm
# which loader/model/state produced the recorded metrics.
metrics = validate_one_epoch(
    data_loader, model, state, skip_timesteps=56 # only when computing R2, we would keep this for analysis
)
metrics



{'val/r2_pm_c_co': 0.9161773920059204,
 'val/r2_pm_c_rt': 0.8242394924163818,
 'val/r2_pm_m_rt': 0.8003662824630737,
 'val/r2_pm_m_co': 0.8416915535926819,
 'val/r2_os_i_rt': 0.7526258230209351,
 'val/r2_os_l_rt': 0.5253342390060425,
 'val/r2_cs_j_co': 0.90641850233078,
 'val/r2_cs_n_co': 0.9437413215637207,
 'val/r2_avg': 0.813824325799942,
 'val/r2_all': 0.9150385856628418,
 'val/time': 26.957586765289307}