In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from hmpai.pytorch.training import train_and_test
from hmpai.pytorch.utilities import set_global_seed
from hmpai.pytorch.generators import MultiXArrayProbaDataset
from hmpai.data import SAT_CLASSES_ACCURACY
from hmpai.pytorch.normalization import *
from torchvision.transforms import Compose
from hmpai.pytorch.transforms import *
from hmpai.pytorch.mamba import *
from hmpai.behaviour.sat2 import SAT2_SPLITS
import os
# Root directory for all datasets, read from the DATA_PATH environment variable.
# Fail early with an explicit message instead of the opaque TypeError that
# Path(None) raises when the variable is missing.
_data_path_env = os.getenv("DATA_PATH")
if _data_path_env is None:
    raise RuntimeError(
        "DATA_PATH environment variable is not set; point it at the directory "
        "that contains sat2/stage_data_250hz.nc"
    )
DATA_PATH = Path(_data_path_env)

Trains on the Weindel (SAT2) data, restricted to the electrodes that occur in both the SAT1 and SAT2 recordings (the `sat1_ch` list below).

In [None]:
# Electrode names kept for every data split (used as the channel subset below).
# Order matters: it is the channel order handed to the dataset, 30 entries total.
sat1_ch = (
    "Fp1 Fp2 AFz F7 F3 Fz F4 F8 T7 C3 "
    "Cz C4 T8 P7 P3 Pz P4 P8 O1 O2 "
    "FC1 FCz FC2 FC5 FC6 CP5 CP1 CPz CP2 CP6"
).split()

In [None]:
# Reproducibility: seed via hmpai's helper before any dataset or model is built.
# NOTE(review): which RNGs set_global_seed covers (torch/numpy/random) is defined
# in hmpai.pytorch.utilities — confirm there.
SEED = 42
set_global_seed(SEED)

# Input file (SAT2 stage data; the file name indicates 250 Hz).
data_paths = [DATA_PATH / "sat2/stage_data_250hz.nc"]

# Participant splits, indexed below as [0] = train, [1] = test, [2] = validation.
splits = SAT2_SPLITS
# Target classes; the model's n_classes is derived from this.
labels = SAT_CLASSES_ACCURACY
# Metadata fields passed to the datasets as `info_to_keep`.
info_to_keep = ['event_name', 'participant', 'epochs', 'rt']
subset_cond = None  # presumably: no condition filtering — confirm in MultiXArrayProbaDataset
add_negative = True
skip_samples = 0
cut_samples = 0
add_pe = True
# Channel subset applied to every split.
subset_channels = sat1_ch

In [None]:
# Normalisation function applied by every split. The normalisation variables are
# computed from the training split's statistics only and then re-used for the
# test and validation splits, so no information leaks from held-out participants.
norm_fn = norm_mad_zscore

# Keyword arguments shared by all three splits. `subset_channels` comes from the
# configuration cell so that changing it there actually takes effect.
common_dataset_kwargs = dict(
    normalization_fn=norm_fn,
    labels=labels,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    add_negative=add_negative,
    skip_samples=skip_samples,
    cut_samples=cut_samples,
    add_pe=add_pe,
    subset_channels=subset_channels,
)

# Start/end jitter augmentation is applied to the training split only.
# NOTE(review): the meaning of (62, 1.0) / (63, 1.0) is defined in
# hmpai.pytorch.transforms — presumably max shift in samples and probability; confirm.
train_transform = Compose([StartJitterTransform(62, 1.0), EndJitterTransform(63, 1.0)])

train_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[0],
    transform=train_transform,
    **common_dataset_kwargs,
)
norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
class_weights = train_data.statistics["class_weights"]

# Held-out splits: same settings, training-set normalisation, no augmentation.
test_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[1],
    norm_vars=norm_vars,
    **common_dataset_kwargs,
)
val_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[2],
    norm_vars=norm_vars,
    **common_dataset_kwargs,
)

In [None]:
# Model hyper-parameters. The input width is derived from the channel subset so it
# cannot silently drift out of sync with `subset_channels` in the configuration cell
# (the previous hard-coded value was 30 == len(sat1_ch)).
config = {
    "n_channels": len(subset_channels),
    "n_classes": len(labels),
    "n_mamba_layers": 5,
    "use_pointconv_fe": True,
    "spatial_feature_dim": 128,
    "use_conv": True,
    "conv_kernel_sizes": [3, 9],
    "conv_in_channels": [128, 128],
    "conv_out_channels": [256, 256],
    "conv_concat": True,
    "use_pos_enc": True,
}

model = build_mamba(config)

# NOTE(review): `class_weights` (computed from the training split in the dataset
# cell) is not passed to train_and_test — confirm whether that is intentional.
train_and_test(
    model,
    train_data,
    test_data,
    val_data,
    logs_path=Path("../logs/"),  # relative to the notebook's working directory
    workers=12,
    batch_size=32,
    lr=0.0001,
    epochs=50,
)