In [1]:
%load_ext autoreload
%autoreload 2
import xarray as xr
import numpy as np
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset
from hmpai.data import SAT1_STAGES_ACCURACY, SAT_CLASSES_ACCURACY
from hmpai.visualization import plot_confusion_matrix
from hmpai.normalization import *
from torchinfo import summary
from torchvision.transforms import Compose
from hmpai.pytorch.transforms import *
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
from mne.io import read_info
import os
DATA_PATH = Path(os.getenv("DATA_PATH"))

In [2]:
set_global_seed(42)
data_path_sat2 = DATA_PATH / "sat2/stage_data_100hz.nc"
dataset_sat2 = xr.load_dataset(data_path_sat2)
# data_path_sat1 = DATA_PATH / "sat1/stage_data_100hz.nc"
# dataset_sat1 = xr.load_dataset(data_path_sat1)
# dataset_sat1 = dataset_sat1.rename_vars({'RT': 'rt'})

dataset_sat2 = dataset_sat2.where(dataset_sat2.event_name.str.contains("accuracy"), drop=True)
# dataset_sp = dataset_sat1.where(dataset_sat1.event_name.str.contains("speed"), drop=True)

In [3]:
# shape_topological = False
# info_to_keep = ['rt']
# train_data_sat2, val_data_sat2, test_data_sat2 = split_data_on_participants(
#     dataset_sat2, 60, norm_min1_to_1
# )
# train_data_sat1, val_data_sat1, test_data_sat1 = split_data_on_participants(
#     dataset_sat1, 60, norm_min1_to_1
# )
# train_data = xr.merge([train_data_sat2, train_data_sat1])
# val_data = xr.merge([val_data_sat2, val_data_sat1])
# test_data = xr.merge([test_data_sat2, test_data_sat1])
# train_dataset = SAT1Dataset(train_data, shape_topological=shape_topological, labels=SAT_CLASSES_ACCURACY, set_to_zero=True, info_to_keep=info_to_keep, order_by_rt=False)
# val_dataset = SAT1Dataset(val_data, shape_topological=shape_topological, labels=SAT_CLASSES_ACCURACY, set_to_zero=True, info_to_keep=info_to_keep, order_by_rt=False)
# test_dataset = SAT1Dataset(test_data, shape_topological=shape_topological, labels=SAT_CLASSES_ACCURACY, set_to_zero=True, info_to_keep=info_to_keep, order_by_rt=False)

In [3]:
shape_topological = False
info_to_keep = ["rt"]
train_data_sat2, val_data_sat2, test_data_sat2 = split_data_on_participants(
    dataset_sat2, 60, norm_min1_to_1
)
# train_data_sat2 = train_data_sat2.where(train_data_sat2.labels != 0, drop=True)
# val_data_sat2 = val_data_sat2.where(val_data_sat2.labels != 0, drop=True)
# test_data_sat2 = test_data_sat2.where(test_data_sat2.labels != 0, drop=True)
# train_data_sat1, val_data_sat1, test_data_sat1 = split_data_on_participants(
#     dataset_sat1, 60, norm_min1_to_1
# )
# train_data = xr.merge([train_data_sat2, train_data_sat1])
# val_data = xr.merge([val_data_sat2, val_data_sat1])
# test_data = xr.merge([test_data_sat2, test_data_sat1])
train_dataset = SAT1Dataset(
    train_data_sat2,
    shape_topological=shape_topological,
    labels=SAT_CLASSES_ACCURACY,
    set_to_zero=True,
    info_to_keep=info_to_keep,
    order_by_rt=False,
    # transform=Compose([ShuffleOperations()]),
)
val_dataset = SAT1Dataset(
    val_data_sat2,
    shape_topological=shape_topological,
    labels=SAT_CLASSES_ACCURACY,
    set_to_zero=True,
    info_to_keep=info_to_keep,
    order_by_rt=False,
    # transform=Compose([RandomCropTransform()]),
)
test_dataset = SAT1Dataset(
    test_data_sat2,
    shape_topological=shape_topological,
    labels=SAT_CLASSES_ACCURACY,
    set_to_zero=True,
    info_to_keep=info_to_keep,
    order_by_rt=False,
    # transform=Compose([RandomCropTransform()]),
)

In [5]:
model = MambaModel(embed_dim=256, n_channels=30, n_classes=5, n_layers=3, global_pool=False)
# model.set_pretraining(False)
train_and_test(
    model,
    train_dataset,
    test_dataset,
    val_dataset,
    logs_path=Path("../logs/"),
    workers=0,
    batch_size=64,
    labels=SAT_CLASSES_ACCURACY,
    epochs=100,
    # epochs=1,
    weight_decay=0.001,
    label_smoothing=0.1,
    lr=0.00001
)

  0%|          | 0/153 [00:00<?, ? batch/s]

  return F.conv1d(input, weight, bias, self.stride,


              precision    recall  f1-score   support

           0       1.00      0.97      0.99    772193
           1       0.24      0.42      0.30     41315
           2       0.40      0.39      0.40    113216
           3       0.42      0.35      0.38    110763
           4       0.57      0.90      0.70     18745

    accuracy                           0.82   1056232
   macro avg       0.53      0.61      0.55   1056232
weighted avg       0.84      0.82      0.83   1056232



  0%|          | 0/153 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.99      0.99    772193
           1       0.60      0.88      0.71     41315
           2       0.73      0.43      0.54    113216
           3       0.66      0.88      0.75    110763
           4       0.99      0.90      0.95     18745

    accuracy                           0.91   1056232
   macro avg       0.80      0.82      0.79   1056232
weighted avg       0.92      0.91      0.91   1056232



  0%|          | 0/153 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.99      0.99    772193
           1       0.63      0.95      0.76     41315
           2       0.68      0.79      0.73    113216
           3       0.87      0.65      0.74    110763
           4       0.99      0.90      0.94     18745

    accuracy                           0.93   1056232
   macro avg       0.83      0.86      0.83   1056232
weighted avg       0.94      0.93      0.93   1056232



  0%|          | 0/153 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.99      0.99    772193
           1       0.65      0.95      0.77     41315
           2       0.81      0.69      0.75    113216
           3       0.80      0.85      0.83    110763
           4       0.99      0.90      0.94     18745

    accuracy                           0.94   1056232
   macro avg       0.85      0.88      0.86   1056232
weighted avg       0.94      0.94      0.94   1056232



  0%|          | 0/153 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.99      0.99    772193
           1       0.73      0.90      0.81     41315
           2       0.77      0.80      0.79    113216
           3       0.83      0.79      0.81    110763
           4       0.92      0.90      0.91     18745

    accuracy                           0.94   1056232
   macro avg       0.85      0.88      0.86   1056232
weighted avg       0.95      0.94      0.94   1056232



  0%|          | 0/153 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.99      0.99    772193
           1       0.71      0.92      0.80     41315
           2       0.79      0.77      0.78    113216
           3       0.82      0.82      0.82    110763
           4       0.90      0.90      0.90     18745

    accuracy                           0.94   1056232
   macro avg       0.84      0.88      0.86   1056232
weighted avg       0.95      0.94      0.94   1056232



  0%|          | 0/153 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.99      0.99    772193
           1       0.72      0.92      0.81     41315
           2       0.82      0.74      0.78    113216
           3       0.80      0.86      0.83    110763
           4       0.99      0.90      0.94     18745

    accuracy                           0.95   1056232
   macro avg       0.87      0.88      0.87   1056232
weighted avg       0.95      0.95      0.95   1056232



  0%|          | 0/153 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.99      0.99    772193
           1       0.70      0.94      0.80     41315
           2       0.81      0.74      0.77    113216
           3       0.81      0.85      0.83    110763
           4       0.98      0.90      0.94     18745

    accuracy                           0.94   1056232
   macro avg       0.86      0.88      0.87   1056232
weighted avg       0.95      0.94      0.94   1056232



  0%|          | 0/153 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.99      0.99    772193
           1       0.67      0.96      0.79     41315
           2       0.82      0.70      0.76    113216
           3       0.80      0.86      0.83    110763
           4       0.98      0.90      0.94     18745

    accuracy                           0.94   1056232
   macro avg       0.85      0.88      0.86   1056232
weighted avg       0.95      0.94      0.94   1056232



  0%|          | 0/153 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.99      0.99    772193
           1       0.66      0.96      0.78     41315
           2       0.81      0.71      0.76    113216
           3       0.81      0.84      0.83    110763
           4       0.98      0.90      0.94     18745

    accuracy                           0.94   1056232
   macro avg       0.85      0.88      0.86   1056232
weighted avg       0.95      0.94      0.94   1056232



  0%|          | 0/153 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.99      0.99    772193
           1       0.54      0.99      0.70     41315
           2       0.84      0.55      0.66    113216
           3       0.78      0.89      0.83    110763
           4       0.97      0.90      0.94     18745

    accuracy                           0.93   1056232
   macro avg       0.83      0.86      0.83   1056232
weighted avg       0.94      0.93      0.93   1056232



  return F.conv1d(input, weight, bias, self.stride,


[{'0.0': {'precision': 1.0,
   'recall': 0.988918158609363,
   'f1-score': 0.9944282064384261,
   'support': 737964.0},
  '1.0': {'precision': 0.7187680713339676,
   'recall': 0.8771198980172551,
   'f1-score': 0.7900877624671916,
   'support': 43929.0},
  '2.0': {'precision': 0.7910087274203369,
   'recall': 0.7182837712063548,
   'f1-score': 0.7528941306017184,
   'support': 108517.0},
  '3.0': {'precision': 0.7828979686813234,
   'recall': 0.8557003042548582,
   'f1-score': 0.8176818455943905,
   'support': 106161.0},
  '4.0': {'precision': 0.9807854475364414,
   'recall': 0.9047619047619048,
   'f1-score': 0.9412410763316859,
   'support': 17997.0},
  'accuracy': 0.9396984726504286,
  'macro avg': {'precision': 0.8546920429944139,
   'recall': 0.8689568073699473,
   'f1-score': 0.8592666042866824,
   'support': 1014568.0},
  'weighted avg': {'precision': 0.9424120252477841,
   'recall': 0.9396984726504286,
   'f1-score': 0.940308813884276,
   'support': 1014568.0}}]

In [6]:
model = Seq2SeqTransformer(d_model=30, ff_dim=1024, num_heads=8, num_layers=6, num_classes=len(SAT_CLASSES_ACCURACY), emb_dim=512)
model.set_pretraining(False)
train_and_test(
    model,
    train_dataset,
    test_dataset,
    val_dataset,
    logs_path=Path("../logs/"),
    workers=4,
    batch_size=64,
    labels=SAT_CLASSES_ACCURACY,
    epochs=10,
    # epochs=1,
    weight_decay=0.001,
    label_smoothing=0.1,
    lr=0.00001
)

  0%|          | 0/167 [00:00<?, ? batch/s]

ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
 ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
 ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
 ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
 ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
 ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
 ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
 ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
 

RuntimeError: DataLoader worker (pid(s) 166091, 166096, 166097, 166098, 166099) exited unexpectedly