In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.pytorch.utilities import DEVICE, set_global_seed, load_model
from hmpai.pytorch.generators import MultiXArrayProbaDataset
from hmpai.data import SAT_CLASSES_ACCURACY
from hmpai.pytorch.normalization import *
from torch.utils.data import DataLoader
import os
DATA_PATH = Path(os.getenv("DATA_PATH"))
from hmpai.visualization import predict_with_auc
from hmpai.behaviour.sat2 import read_behavioural_info, SAT2_SPLITS
from hmpai.pytorch.mamba import *

#### Create a Low | Med | High (tertile) split for both conditions and save participant + epoch + tertile information

In [None]:
# Call the global seeding helper with a fixed seed before any dataset is built.
set_global_seed(42)

# --- Data selection ---------------------------------------------------------
data_paths = [DATA_PATH / "sat2/stage_data_250hz.nc"]
splits, labels = SAT2_SPLITS, SAT_CLASSES_ACCURACY
subset_cond = None

# Per-trial metadata carried through the dataset and into the prediction
# table; "participant" and "epochs" are later used as the pivot keys.
info_to_keep = [
    "rt",
    "participant",
    "epochs",
    "condition",
    "response",
    "side",
]

# --- Epoch handling (forwarded unchanged to MultiXArrayProbaDataset) ---------
whole_epoch, add_pe = True, True
# NOTE(review): presumably the number of samples trimmed from the epoch
# (at 250 Hz) - confirm the exact meaning against MultiXArrayProbaDataset.
skip_samples, cut_samples = 62, 63

In [None]:
# Normalisation function shared by both datasets.
norm_fn = norm_mad_zscore

# Keyword arguments that are identical for the train and the test/val dataset.
dataset_kwargs = dict(
    normalization_fn=norm_fn,
    whole_epoch=whole_epoch,
    labels=labels,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    skip_samples=skip_samples,
    cut_samples=cut_samples,
    add_pe=add_pe,
)

# Dataset over the first split; its statistics feed the normalisation below.
train_data = MultiXArrayProbaDataset(
    data_paths, participants_to_keep=splits[0], **dataset_kwargs
)

# Normalisation variables are derived from the first split's statistics and
# handed to the second dataset instead of being recomputed there.
norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
class_weights = train_data.statistics["class_weights"]

# Second and third split combined into a single dataset.
testval_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[1] + splits[2],
    norm_vars=norm_vars,
    **dataset_kwargs,
)

In [4]:
# Behavioural table for SAT2.
# NOTE(review): `behaviour_sat2` is not referenced by any other visible cell.
behaviour_sat2 = read_behavioural_info(DATA_PATH / "sat2/behavioural/df_full.csv")

# Unshuffled loader over the combined test/val dataset.
test_loader = DataLoader(
    testval_data,
    batch_size=128,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

In [None]:
# Load the saved checkpoint from disk.
chk_path = Path("../models/final.pt")
checkpoint = load_model(chk_path)

# Architecture settings handed to build_mamba.
# NOTE(review): these presumably have to match the settings the checkpoint
# was trained with, otherwise load_state_dict will fail - confirm.
config = {
    # Input / output sizes
    "n_channels": 64,
    "n_classes": len(labels),
    # Sequence model depth
    "n_mamba_layers": 5,
    # Spatial feature extraction
    "use_pointconv_fe": True,
    "spatial_feature_dim": 128,
    # Convolutional branches (one entry per kernel size)
    "use_conv": True,
    "conv_kernel_sizes": [3, 9],
    "conv_in_channels": [128, 128],
    "conv_out_channels": [256, 256],
    "conv_concat": True,
    # Positional encoding
    "use_pos_enc": True,
}

# Build the network, restore the weights, then move it to the target device
# and switch to evaluation mode (the assignment also suppresses cell output).
model = build_mamba(config)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE).eval()

In [6]:
# Run the model over the test/val loader and collect the predictions together
# with the per-trial metadata listed in `info_to_keep`.
# Uses `labels` (set in the config cell) rather than the constant directly so
# the class list stays in sync with the datasets and the model's n_classes.
data = predict_with_auc(model, test_loader, info_to_keep, labels)

In [9]:
# Load data
# Reads the first (and only) entry of `data_paths` fully into memory.
# NOTE(review): `xr` is never imported explicitly in this notebook; presumably
# it is provided by one of the wildcard imports - confirm, and prefer an
# explicit `import xarray as xr` in the import cell.
ds = xr.load_dataset(data_paths[0])

In [None]:
# Add tertile info to data
# `data` needs one tertile label per (participant, epochs) pair. No visible
# cell creates that column, so fail early with an actionable message instead
# of an opaque error from inside `pivot`.
if "tertile" not in data.columns:
    raise KeyError(
        "`data` has no 'tertile' column; compute the Low | Med | High split "
        "per trial before attaching it to the dataset."
    )

# Reshape the long table into a participant x epochs grid of tertile labels.
data_pivot = data.pivot(index="participant", columns="epochs", values="tertile")
da_tertile = xr.DataArray(
    data_pivot.values,
    dims=["participant", "epochs"],
    coords={"participant": data_pivot.index, "epochs": data_pivot.columns},
)
ds = ds.assign_coords(tertile=da_tertile)

# Drop participants that were not part of test/val sets, so where tertile is not assigned
ds = ds.where(ds["tertile"].notnull(), drop=True)

In [None]:
# Save data
# Writes the tertile-annotated dataset next to the original file under a new
# name, so the source file stage_data_250hz.nc is left untouched.
ds.to_netcdf(DATA_PATH / "sat2/stage_data_250hz_tertile.nc")