In [5]:
%load_ext autoreload
%autoreload 2
import netCDF4
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_participants
from hmpai.pytorch.utilities import set_global_seed
from hmpai.pytorch.generators import MultiXArrayProbaDataset
from hmpai.pytorch.normalization import *
from hmpai.pytorch.transforms import *
from hmpai.pytorch.mamba import *
from hmpai.pytorch.training import train_and_test

from mne.io import read_info
import os
# Root directory for all datasets, taken from the DATA_PATH environment variable.
# Fail fast with an actionable message: Path(None) would raise an opaque TypeError,
# and Path("") would silently resolve to the current working directory.
_data_path_env = os.getenv("DATA_PATH")
if not _data_path_env:
    raise RuntimeError(
        "Environment variable DATA_PATH is not set; point it at the data root directory."
    )
DATA_PATH = Path(_data_path_env)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# Seed before anything that may be random (participant split, model init).
set_global_seed(42)

# Input data: one combined file under DATA_PATH/prp/.
data_paths = [DATA_PATH / "prp" / "stage_data_250hz_long_combined.nc"]

# Participant-level split; index 0 feeds training, 1 test, 2 validation (see the
# dataset cell). NOTE(review): per the original note, train_percentage=100 makes
# the test and val splits 100% as well, presumably overlapping with train.
splits = split_participants(data_paths, train_percentage=60)

# Classes to predict: "negative" followed by the t1_* and t2_* stage labels.
labels = [
    "negative",
    "t1_1", "t1_2", "t1_3", "t1_4",
    "t2_1", "t2_2", "t2_3",
]

# Extra info fields to carry through the dataset (none for now).
# Candidates: ['event_name', 'participant', 'epochs', 'rt'] -- TODO: might not be present.
info_to_keep = []

# Dataset options shared by every split.
whole_epoch = True
subset_cond = ("condition", "equal", "long")  # presumably keeps condition == 'long' only
add_negative = True
skip_samples = 0  # previously tried: 62
cut_samples = 0  # previously tried: 63

In [11]:
norm_fn = norm_mad_zscore

# Keyword arguments shared by every split, kept in one place so the train,
# test and val datasets cannot drift apart (they were previously repeated by hand).
dataset_kwargs = dict(
    normalization_fn=norm_fn,
    whole_epoch=whole_epoch,
    labels=labels,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    add_negative=add_negative,
    skip_samples=skip_samples,
    cut_samples=cut_samples,
)

# Training split: normalization variables and class weights are derived from it.
train_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[0],
    # transform=Compose([StartJitterTransform(62, 1.0), EndJitterTransform(63, 1.0)]),
    **dataset_kwargs,
)
norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
class_weights = train_data.statistics["class_weights"]

# Held-out splits are normalized with the training-set norm_vars.
test_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[1],
    norm_vars=norm_vars,
    **dataset_kwargs,
)
val_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[2],
    norm_vars=norm_vars,
    **dataset_kwargs,
)

In [None]:
# NOTE(review): TestMamba's positional hyperparameters are not self-describing;
# len(labels) presumably sets the number of output classes. Consider passing keywords.
model = TestMamba(64, 128, 64, len(labels), 5, False)

train_and_test(
    model,
    train_data,
    test_data,
    val_data,
    logs_path=Path("../logs/"),
    workers=8,
    batch_size=32,
    labels=labels,
    lr=0.0001,
    # Alternatives tried previously:
    # lr=0.0005,
    # label_smoothing=0.1,
    # weight_decay=0.0001,
    do_spectral_decoupling=False,
    # class_weights is passed through, but with use_class_weights=False it is presumably unused.
    use_class_weights=False,
    class_weights=class_weights,
    # Use the same setting the datasets were built with (was hard-coded to True).
    whole_epoch=whole_epoch,
    epochs=20,
)