In [4]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from hmpai.training import split_participants, split_participants_custom
from hmpai.pytorch.training import train_and_test
from hmpai.pytorch.utilities import DEVICE, set_global_seed, load_model
from hmpai.pytorch.generators import MultiXArrayProbaDataset
from torch.utils.data import DataLoader
from hmpai.pytorch.normalization import *
from torchvision.transforms import Compose
from hmpai.pytorch.transforms import *
from hmpai.pytorch.mamba import *
import os
import pandas as pd
from tqdm.notebook import tqdm
DATA_PATH = Path(os.getenv("DATA_PATH"))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Plan: load the trained models and the validation-set data.

# Predict on the validation set, saving the embeddings and predicted probabilities.

# Store the results as xarray datasets?
# Include the source dataset as a variable (prp/task1, prp/task2, sat_weindel, etc.).

In [8]:
set_global_seed(42)

data_paths = [DATA_PATH / "prp/stage_data_250hz_t1.nc", DATA_PATH / "prp/stage_data_250hz_t2.nc"]
# 80/20 train/val split of participants (no test set); later cells use
# splits[0] as the training participants and splits[1] as the validation participants.
splits = split_participants_custom(data_paths, 0)

# Combined label set; index 0 is the "negative" class.
labels = ["negative", "t1_1", "t1_2", "t1_3", "t2_1", "t2_2", "t2_3"]
# Per-task label sets used by the prp/task1 and prp/task2 sections. Previously they
# existed only as comments, so those sections raised NameError on a fresh kernel.
# NOTE(review): an earlier version named these "prp_t1_*" / "prp_t2_*"; confirm the
# names below match the label names stored in the per-task .nc files.
labels_t1 = [labels[0]] + [name for name in labels if name.startswith("t1_")]
labels_t2 = [labels[0]] + [name for name in labels if name.startswith("t2_")]
info_to_keep = ["participant", "condition", "trial_index", "task"]
subset_cond = None
add_negative = True
skip_samples = 62  # presumably samples dropped at the start of each trial — TODO confirm
cut_samples = 63  # presumably samples dropped at the end of each trial — TODO confirm
add_pe = True  # also passed to the model as "use_pos_enc"

In [None]:
def get_embeddings(
    model,
    loader,
    labels,
    participants,
    epochs_per_participant=1316,
    window_radius=6,
    task_label_slices=None,
):
    """Extract one embedding per event label around the model's predicted peaks.

    For every trial yielded by ``loader`` the model is run with
    ``return_embeddings=True``. Predictions are softmax-normalised over the class
    axis. For every non-negative label, the time step with the highest probability
    is taken as that event's peak. The embedding is then averaged over a window of
    ``2 * window_radius + 1`` samples centred on the peak, clipped to the trial.

    Args:
        model: Module whose forward accepts ``return_embeddings=True`` and returns
            ``(predictions, embeddings)``. It must expose ``mamba_dim``.
            Predictions are assumed to be (batch, time, n_classes) and embeddings
            (batch, time, mamba_dim) — TODO confirm.
        loader: Iterable of batches. ``batch[0]`` is the model input and
            ``batch[2][0]`` is a dict of per-trial lists with the keys
            "participant", "trial_index", "task" and "condition".
        labels: Label names. ``labels[0]`` is treated as the "negative" class and
            gets no embedding.
        participants: Participant IDs; they define the ``participant`` coordinate.
        epochs_per_participant: Length of the ``epochs`` axis. Every
            ``trial_index`` must lie in ``[0, epochs_per_participant)``.
        window_radius: Samples averaged on each side of a peak (6 -> 13 samples).
        task_label_slices: Optional mapping ``task -> slice`` into ``labels[1:]``.
            - When given, a trial of that task writes only the label positions in
              its slice; the other positions stay NaN. Trials whose task is not
              in the mapping get no embeddings.
            - Use this for a combined model, e.g.
              ``{"prp/t1": slice(0, 3), "prp/t2": slice(3, 6)}``.
            - When ``None`` (default), every trial writes all labels, which is what
              a single-task model needs.

    Returns:
        xr.Dataset with two variables:
        - ``embeddings`` (participant, epochs, labels, emb_dim): float32, NaN
          where nothing was written.
        - ``condition`` (participant, epochs): "" where no trial was seen.

    Raises:
        ValueError: if the model's class count does not match ``labels``, if a
            participant is not in ``participants``, or if a ``trial_index``
            falls outside the ``epochs`` axis.
    """
    torch.cuda.empty_cache()

    label_names = list(labels[1:])  # drop the "negative" class
    n_labels = len(label_names)
    emb_dim = model.mamba_dim

    # Fill plain numpy buffers and wrap them in an xr.Dataset once at the end;
    # element-wise DataArray assignment inside the loop is comparatively slow.
    embeddings = np.full(
        (len(participants), epochs_per_participant, n_labels, emb_dim),
        np.nan,
        dtype=np.float32,
    )
    conditions = np.full((len(participants), epochs_per_participant), "", dtype=object)

    # O(1) participant lookup; setdefault keeps the first occurrence like list.index.
    participant_pos = {}
    for pos, pid in enumerate(participants):
        participant_pos.setdefault(pid, pos)

    with torch.no_grad():
        for batch in tqdm(loader, total=len(loader)):
            info = batch[2][0]
            pred, emb = model(batch[0].to(DEVICE), return_embeddings=True)
            # Softmax over classes at each time step, then for every non-negative
            # class pick the time step at which it is most probable.
            probs = torch.softmax(pred, dim=2).cpu()
            peaks = probs[..., 1:].argmax(dim=1)  # (batch, n_classes - 1)
            if peaks.shape[1] != n_labels:
                raise ValueError(
                    f"model predicts {peaks.shape[1] + 1} classes but "
                    f"{len(labels)} labels were given"
                )

            # Copy to the host once instead of syncing per window; upcast so
            # .numpy() also works if the model returns half precision.
            emb = emb.float().cpu()
            batch_size, n_times, feat_dim = emb.shape

            peak_emb = torch.empty((batch_size, n_labels, feat_dim))
            for i in range(batch_size):
                for j in range(n_labels):
                    t = int(peaks[i, j])
                    start = max(0, t - window_radius)
                    end = min(n_times, t + window_radius + 1)  # slice end is exclusive
                    peak_emb[i, j] = emb[i, start:end].mean(dim=0)
            peak_emb = peak_emb.numpy()

            for i, pid in enumerate(info["participant"]):
                if pid not in participant_pos:
                    raise ValueError(f"participant {pid!r} is not in `participants`")
                row = participant_pos[pid]
                epoch_idx = int(info["trial_index"][i])
                # Explicit guard: a negative index would otherwise wrap around silently.
                if not 0 <= epoch_idx < epochs_per_participant:
                    raise ValueError(
                        f"trial_index {epoch_idx} of participant {pid!r} is outside "
                        f"[0, {epochs_per_participant}); increase epochs_per_participant"
                    )
                if task_label_slices is None:
                    sel = slice(None)
                else:
                    sel = task_label_slices.get(info["task"][i])
                if sel is not None:
                    # Same slice on both sides keeps each label at its own position.
                    embeddings[row, epoch_idx, sel] = peak_emb[i, sel]
                conditions[row, epoch_idx] = info["condition"][i]

    return xr.Dataset(
        data_vars={
            "embeddings": (("participant", "epochs", "labels", "emb_dim"), embeddings),
            "condition": (("participant", "epochs"), conditions),
        },
        coords={
            "participant": participants,
            "epochs": np.arange(epochs_per_participant),
            "labels": label_names,
            "emb_dim": np.arange(emb_dim),
        },
    )

### Combined

In [9]:
# Combined dataset: both PRP task files with the joint label set.
# NOTE(review): unlike the per-task sections, this keeps the training participants
# (splits[0]) and passes no norm_vars. The notebook's stated goal is val-set
# embeddings, so confirm whether splits[1] with training-set norm_vars was intended.
norm_fn = norm_mad_zscore
all_data = MultiXArrayProbaDataset(
    data_paths,
    labels=labels,
    participants_to_keep=splits[0],
    info_to_keep=info_to_keep,
    normalization_fn=norm_fn,
    add_negative=add_negative,
    add_pe=add_pe,
    subset_cond=subset_cond,
    skip_samples=skip_samples,
    cut_samples=cut_samples,
)

In [None]:
# Batched loader over the combined dataset (single process, pinned host memory).
loader = DataLoader(all_data, batch_size=32, num_workers=0, shuffle=True, pin_memory=True)

In [None]:
# Model for the combined label set.
# NOTE(review): this loads the same checkpoint as the prp/task1 section. That
# section builds the model with len(labels_t1) classes, while this one uses
# len(labels); both cannot match the stored weights. Confirm which checkpoint
# the combined model should use.
chk_path = Path("../models/t1_pe.pt")
checkpoint = load_model(chk_path)
config = dict(
    n_channels=64,
    n_classes=len(labels),
    n_mamba_layers=5,
    use_pointconv_fe=True,
    spatial_feature_dim=128,
    use_conv=True,
    conv_kernel_sizes=[3, 9],
    conv_in_channels=[128, 128],
    conv_out_channels=[256, 256],
    conv_concat=True,
    use_pos_enc=add_pe,
)

model = build_mamba(config)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE).eval()

### prp/task1

In [4]:
# prp/task1: normalisation statistics come from the training participants and are
# applied to the validation participants, whose embeddings are extracted below.
norm_fn = norm_mad_zscore
train_data = MultiXArrayProbaDataset(
    [data_paths[0]],
    participants_to_keep=splits[0],
    normalization_fn=norm_fn,
    labels=labels_t1,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    add_negative=add_negative,
    # Tie the jitter ranges to the config values instead of repeating 62/63 literally.
    transform=Compose([StartJitterTransform(skip_samples, 1.0), EndJitterTransform(cut_samples, 1.0)]),
    skip_samples=skip_samples,
    cut_samples=cut_samples,
    add_pe=add_pe,
)
norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
class_weights = train_data.statistics["class_weights"]
val_data = MultiXArrayProbaDataset(
    [data_paths[0]],
    participants_to_keep=splits[1],
    normalization_fn=norm_fn,
    norm_vars=norm_vars,
    labels=labels_t1,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    add_negative=add_negative,
    skip_samples=skip_samples,
    cut_samples=cut_samples,
    add_pe=add_pe,
)
# Only the training statistics are needed; free the training dataset.
del train_data

In [5]:
# Batched loader over the prp/task1 validation participants.
val_loader = DataLoader(val_data, batch_size=32, num_workers=0, shuffle=True, pin_memory=True)

In [6]:
# Task-1 model: same architecture as the combined model, restricted to the T1 labels.
chk_path = Path("../models/t1_pe.pt")
checkpoint = load_model(chk_path)
config = dict(
    n_channels=64,
    n_classes=len(labels_t1),
    n_mamba_layers=5,
    use_pointconv_fe=True,
    spatial_feature_dim=128,
    use_conv=True,
    conv_kernel_sizes=[3, 9],
    conv_in_channels=[128, 128],
    conv_out_channels=[256, 256],
    conv_concat=True,
    use_pos_enc=add_pe,
)

model = build_mamba(config)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE).eval()

In [7]:
t1_embs_ds = get_embeddings(model=model, loader=val_loader, labels=labels_t1, participants=splits[1])

  0%|          | 0/70 [00:00<?, ?it/s]

In [8]:
t1_out_path = Path("files/prp_t1_embeddings.nc")
# to_netcdf does not create missing directories, so make sure files/ exists.
t1_out_path.parent.mkdir(parents=True, exist_ok=True)
t1_embs_ds.to_netcdf(t1_out_path)

### prp/task2

In [9]:
# prp/task2: normalisation statistics come from the training participants and are
# applied to the validation participants, whose embeddings are extracted below.
norm_fn = norm_mad_zscore
train_data = MultiXArrayProbaDataset(
    [data_paths[1]],
    participants_to_keep=splits[0],
    normalization_fn=norm_fn,
    labels=labels_t2,  # was labels_t1: copy-paste bug from the task1 section
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    add_negative=add_negative,
    # Tie the jitter ranges to the config values instead of repeating 62/63 literally.
    transform=Compose([StartJitterTransform(skip_samples, 1.0), EndJitterTransform(cut_samples, 1.0)]),
    skip_samples=skip_samples,
    cut_samples=cut_samples,
    add_pe=add_pe,
)
norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
class_weights = train_data.statistics["class_weights"]
val_data = MultiXArrayProbaDataset(
    [data_paths[1]],
    participants_to_keep=splits[1],
    normalization_fn=norm_fn,
    norm_vars=norm_vars,
    labels=labels_t2,  # was labels_t1: copy-paste bug from the task1 section
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    add_negative=add_negative,
    skip_samples=skip_samples,
    cut_samples=cut_samples,
    add_pe=add_pe,
)
# Only the training statistics are needed; free the training dataset.
del train_data

In [10]:
# Batched loader over the prp/task2 validation participants.
val_loader = DataLoader(val_data, batch_size=32, num_workers=0, shuffle=True, pin_memory=True)

In [None]:
# Task-2 model: same architecture as the combined model, restricted to the T2 labels.
chk_path = Path("../models/t2_pe.pt")
checkpoint = load_model(chk_path)
config = dict(
    n_channels=64,
    n_classes=len(labels_t2),
    n_mamba_layers=5,
    use_pointconv_fe=True,
    spatial_feature_dim=128,
    use_conv=True,
    conv_kernel_sizes=[3, 9],
    conv_in_channels=[128, 128],
    conv_out_channels=[256, 256],
    conv_concat=True,
    use_pos_enc=add_pe,
)

model = build_mamba(config)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE).eval()

In [12]:
t2_embs_ds = get_embeddings(model=model, loader=val_loader, labels=labels_t2, participants=splits[1])

  0%|          | 0/68 [00:00<?, ?it/s]

In [13]:
t2_out_path = Path("files/prp_t2_embeddings.nc")
# to_netcdf does not create missing directories, so make sure files/ exists.
t2_out_path.parent.mkdir(parents=True, exist_ok=True)
t2_embs_ds.to_netcdf(t2_out_path)