In [1]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from hmpai.training import split_participants, split_participants_custom
from hmpai.pytorch.training import train_and_test
from hmpai.pytorch.utilities import DEVICE, set_global_seed, load_model
from hmpai.pytorch.generators import MultiXArrayProbaDataset
from torch.utils.data import DataLoader
from hmpai.pytorch.normalization import *
from torchvision.transforms import Compose
from hmpai.pytorch.transforms import *
from hmpai.pytorch.mamba import *
import os
import pandas as pd
from tqdm.notebook import tqdm
# Root folder of the data files, read from the DATA_PATH environment variable.
# Indexing os.environ raises a KeyError naming the variable when it is unset,
# instead of the opaque TypeError that Path(None) would produce.
DATA_PATH = Path(os.environ["DATA_PATH"])

In [None]:
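# Plan: load the models and the data (validation set).

# Run predictions on the validation set, saving the embeddings and the predicted probabilities.

# Possibly store the results as an xarray Dataset?
# Include the source dataset as a variable (prp/task1, prp/task2, sat_weindel, etc.).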

In [2]:
# Seed first, so everything below (including the participant split) runs after seeding
set_global_seed(42)

# One file per PRP task
data_paths = [
    DATA_PATH / "prp/stage_data_250hz_t1.nc",
    DATA_PATH / "prp/stage_data_250hz_t2.nc",
]

# 80/20 train/val (no test)
splits = split_participants_custom(data_paths, 0)

# Labels per dataset (same order as data_paths)
data_labels = [
    ["negative", "t1_1", "t1_2", "t1_3"],
    ["negative", "t2_1", "t2_2", "t2_3"],
]
# Union of all labels; "negative" is listed once, followed by the t1 and t2 stages
labels = [
    "negative",
    "t1_1", "t1_2", "t1_3",
    "t2_1", "t2_2", "t2_3",
]

# Metadata fields carried along with every sample
info_to_keep = ["participant", "condition", "trial_index", "task"]

# No condition filter is applied; example of one: ('condition', 'equal', 'long')
subset_cond = None

# Dataset options forwarded to MultiXArrayProbaDataset
# NOTE(review): skip_samples / cut_samples presumably trim the start / end of each epoch — confirm
add_negative = True
skip_samples = 62
cut_samples = 63
add_pe = True

In [8]:
def get_embeddings(model, loader, labels, participants, window_size=6, epochs_per_participant=1316):
    """Collect peak-centred embeddings and confidences for every trial in ``loader``.

    Each batch is passed through ``model`` with ``return_embeddings=True``. For
    every non-negative label, the time step with the highest softmax probability
    is located; the embeddings and that label's probability are then averaged
    over ``window_size`` steps on either side of the peak (clipped to the
    sequence bounds). Results are stored per participant and trial.

    Parameters
    ----------
    model
        Model exposing ``mamba_dim`` and accepting the ``return_embeddings`` and
        ``task`` keyword arguments; it must return ``(pred, emb)`` with
        ``pred`` as (batch, time, n_classes) and ``emb`` as (batch, time, dim).
    loader
        Iterable of batches where ``batch[0]`` is the model input and
        ``batch[2][0]`` is an info dict with the keys ``"participant"``,
        ``"trial_index"``, ``"task"`` and ``"condition"``.
    labels
        All class labels; the first entry is treated as the negative class and
        is excluded from the output.
    participants
        Participant identifiers; their order defines the ``participant`` axis.
        Must support ``.index`` (e.g. a list).
    window_size
        Number of time steps averaged on each side of the peak.
    epochs_per_participant
        Length of the ``epochs`` axis. ``trial_index`` is used directly as the
        position on this axis, so it must be smaller than this value.
        Defaults to 1316, the value that was previously hard-coded.

    Returns
    -------
    xr.Dataset
        ``embeddings`` (participant, epochs, labels, emb_dim) and ``confidence``
        (participant, epochs, labels) as float32, NaN where nothing was written,
        plus ``condition`` (participant, epochs) as objects, ``""`` by default.

    Notes
    -----
    Only samples whose task is ``"prp1/t1"`` (first three non-negative labels)
    or ``"prp1/t2"`` (remaining labels) get embeddings and confidences; for any
    other task only ``condition`` is stored.
    """
    torch.cuda.empty_cache()

    n_labels = len(labels) - 1  # Exclude 'negative' label
    emb_dim = model.mamba_dim
    n_participants = len(participants)

    # Plain NumPy buffers: filling these per trial is much cheaper than assigning
    # through xarray indexing, and the Dataset is assembled once at the end.
    embeddings = np.full(
        (n_participants, epochs_per_participant, n_labels, emb_dim), np.nan, dtype=np.float32
    )
    confidence = np.full(
        (n_participants, epochs_per_participant, n_labels), np.nan, dtype=np.float32
    )
    condition = np.full((n_participants, epochs_per_participant), "", dtype=object)

    pre = post = window_size

    with torch.no_grad():
        for batch in tqdm(loader, total=len(loader)):
            info = batch[2][0]
            pred, emb = model(batch[0].to(DEVICE), return_embeddings=True, task=info["task"])
            # pred: (batch_size, time, n_classes)
            # emb: (batch_size, time, model_dim)
            # info: dict of key: list
            pred = torch.softmax(pred, dim=2).cpu()
            emb_cpu = emb.cpu()

            # Time index of the highest probability per non-negative label: (batch, n_labels)
            peaks = pred[..., 1:].argmax(dim=1)

            B, T, D = emb_cpu.shape
            window_emb = torch.zeros((B, n_labels, D), dtype=emb_cpu.dtype)
            window_conf = torch.zeros((B, n_labels), dtype=pred.dtype)

            for i in range(B):
                for lab in range(n_labels):
                    p = peaks[i, lab].item()
                    start = max(0, p - pre)
                    stop = min(T, p + post + 1)
                    # Average embedding and label probability over the window around the peak
                    window_emb[i, lab, :] = emb_cpu[i, start:stop, :].mean(dim=0)
                    window_conf[i, lab] = pred[i, start:stop, lab + 1].mean()

            # Both tensors already live on the CPU; convert for NumPy assignment
            window_emb_np = window_emb.numpy()
            window_conf_np = window_conf.numpy()

            for i, participant in enumerate(info["participant"]):
                row = participants.index(participant)
                # Plain Python int so the value is a valid positional index
                epoch_idx = int(info["trial_index"][i])
                task = info["task"][i]

                if task == "prp1/t1":
                    embeddings[row, epoch_idx, :3, :] = window_emb_np[i, :3]
                    confidence[row, epoch_idx, :3] = window_conf_np[i, :3]
                elif task == "prp1/t2":
                    embeddings[row, epoch_idx, 3:, :] = window_emb_np[i, 3:]
                    confidence[row, epoch_idx, 3:] = window_conf_np[i, 3:]
                condition[row, epoch_idx] = info["condition"][i]

    return xr.Dataset(
        data_vars={
            "embeddings": (("participant", "epochs", "labels", "emb_dim"), embeddings),
            "condition": (("participant", "epochs"), condition),
            "confidence": (("participant", "epochs", "labels"), confidence),
        },
        coords={
            "participant": participants,
            "epochs": np.arange(epochs_per_participant),
            "labels": labels[1:],
            "emb_dim": np.arange(emb_dim),
        },
    )

### prp

In [4]:
# Normalisation function applied by the dataset
norm_fn = norm_mad_zscore

# Dataset restricted to the participants in the first split (splits[0])
train_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[0],
    # label configuration
    labels=labels,
    data_labels=data_labels,
    add_negative=add_negative,
    # metadata and filtering
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    # preprocessing
    normalization_fn=norm_fn,
    skip_samples=skip_samples,
    cut_samples=cut_samples,
    add_pe=add_pe,
    # transform=Compose([StartJitterTransform(62, 1.0), EndJitterTransform(63, 1.0)]),
)

# Disabled: validation dataset built on splits[1], normalised with the
# statistics of the training data.
# norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
# class_weights = train_data.statistics["class_weights"]
# val_data = MultiXArrayProbaDataset(
#     data_paths,
#     participants_to_keep=splits[1],
#     normalization_fn=norm_fn,
#     norm_vars=norm_vars,
#     labels=labels,
#     data_labels=data_labels,
#     info_to_keep=info_to_keep,
#     subset_cond=subset_cond,
#     add_negative=add_negative,
#     skip_samples=skip_samples,
#     cut_samples=cut_samples,
#     add_pe=add_pe,
# )
# del train_data

In [5]:
# Shuffled batches of 32 samples, loaded by 12 workers into pinned memory
train_loader = DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
    num_workers=12,
    pin_memory=True,
)

In [6]:
# Architecture settings passed to build_mamba; the checkpoint's state dict is
# loaded into the resulting model, so the two have to be compatible.
config = {
    "n_channels": 64,
    "n_classes": len(labels),
    "n_mamba_layers": 5,
    "use_pointconv_fe": True,
    "spatial_feature_dim": 128,
    "use_conv": True,
    "conv_kernel_sizes": [3, 9],
    "conv_in_channels": [128, 128],
    "conv_out_channels": [256, 256],
    "conv_concat": True,
    "use_pos_enc": add_pe,
}

# Stored checkpoint holding the trained weights under "model_state_dict"
chk_path = Path("../models/cmb_shared.pt")
checkpoint = load_model(chk_path)

# Build the network, restore its weights, then move it to DEVICE in eval mode
model = build_mamba(config)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE).eval()

In [None]:
# Sanity check on the input shapes: gather the distinct batch shapes and print
# them once, instead of emitting one line per batch for the whole loader.
batch_shapes = set()
for batch in train_loader:
    batch_shapes.add(tuple(batch[0].shape))
print(sorted(batch_shapes))

In [9]:
# Extract embeddings for every trial of the participants in splits[0]
t1_embs_ds = get_embeddings(
    model=model,
    loader=train_loader,
    labels=labels,
    participants=splits[0],
)

  0%|          | 0/1464 [00:00<?, ?it/s]

In [12]:
# Create the output folder first; writing fails if "files/" does not exist yet
embeddings_path = Path("files/cmb_embeddings.nc")
embeddings_path.parent.mkdir(parents=True, exist_ok=True)
t1_embs_ds.to_netcdf(embeddings_path)