In [1]:
%load_ext autoreload
%autoreload 2
import xarray as xr
import numpy as np
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset
from hmpai.data import SAT1_STAGES_ACCURACY, SAT_CLASSES_ACCURACY
from hmpai.visualization import plot_confusion_matrix
from hmpai.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
from mne.io import read_info
import os
# Root directory of the datasets, taken from the DATA_PATH environment variable.
# os.getenv returns None when the variable is unset, and Path(None) then fails
# with an opaque TypeError; fail early with a message that names the variable.
_data_path_env = os.getenv("DATA_PATH")
if _data_path_env is None:
    raise RuntimeError(
        "The DATA_PATH environment variable is not set; "
        "set it to the dataset root directory before running this notebook."
    )
DATA_PATH = Path(_data_path_env)

In [2]:
# Seed the run before anything else executes (set_global_seed comes from
# hmpai.pytorch.utilities).
set_global_seed(42)

# SAT2 stage data, loaded from the NetCDF file under DATA_PATH.
data_path_sat2 = DATA_PATH / "sat2/stage_data_100hz.nc"
dataset_sat2 = xr.load_dataset(data_path_sat2)

# SAT1 stage data. Its 'RT' variable is renamed to 'rt' right after loading --
# presumably so it matches the name used by SAT2 and by info_to_keep further
# down; verify against the SAT2 file.
data_path_sat1 = DATA_PATH / "sat1/stage_data_100hz.nc"
dataset_sat1 = xr.load_dataset(data_path_sat1).rename_vars({'RT': 'rt'})

# dataset_acc = dataset_sat1.where(dataset_sat1.event_name.str.contains("accuracy"), drop=True)
# dataset_sp = dataset_sat1.where(dataset_sat1.event_name.str.contains("speed"), drop=True)

In [3]:
shape_topological = False
info_to_keep = ['rt']

# Split each experiment into train/val/test with norm_min1_to_1 passed as the
# normalization argument. SAT2 is split before SAT1, as in the original cell;
# keep that order in case the split draws from the seeded RNG.
# NOTE(review): 60 is presumably the training percentage -- confirm against
# split_data_on_participants.
train_data_sat2, val_data_sat2, test_data_sat2 = split_data_on_participants(
    dataset_sat2, 60, norm_min1_to_1
)
train_data_sat1, val_data_sat1, test_data_sat1 = split_data_on_participants(
    dataset_sat1, 60, norm_min1_to_1
)

# Combine the matching partitions of both experiments (SAT2 listed first).
train_data = xr.merge([train_data_sat2, train_data_sat1])
val_data = xr.merge([val_data_sat2, val_data_sat1])
test_data = xr.merge([test_data_sat2, test_data_sat1])

# All three SAT1Dataset instances are configured identically, so the keyword
# arguments are written once and unpacked into each constructor call.
dataset_kwargs = dict(
    shape_topological=shape_topological,
    labels=SAT_CLASSES_ACCURACY,
    set_to_zero=True,
    info_to_keep=info_to_keep,
    order_by_rt=False,
)
train_dataset = SAT1Dataset(train_data, **dataset_kwargs)
val_dataset = SAT1Dataset(val_data, **dataset_kwargs)
test_dataset = SAT1Dataset(test_data, **dataset_kwargs)

In [5]:
# Architecture settings for the transformer; the number of output classes
# follows the label set used by the datasets.
model_config = dict(
    d_model=30,
    ff_dim=2048,
    num_heads=10,
    num_layers=6,
    num_classes=len(SAT_CLASSES_ACCURACY),
)
model = Seq2SeqTransformer(**model_config)

# Training settings forwarded to train_and_test as keyword arguments.
# (For a quick run, epochs=1 was used as an alternative in the original cell.)
training_config = dict(
    logs_path=Path("../logs/"),
    workers=4,
    batch_size=128,
    labels=SAT_CLASSES_ACCURACY,
    epochs=10,
    weight_decay=0.001,
    label_smoothing=0.01,
    lr=0.0001,
)

# Positional order is train, test, val. The call is the cell's last
# expression, so its return value is shown as the cell output.
train_and_test(
    model,
    train_dataset,
    test_dataset,
    val_dataset,
    **training_config,
)



  0%|          | 0/183 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.93      0.96     10749
           1       0.61      0.95      0.74       900
           2       0.85      0.65      0.74      2791
           3       0.58      0.85      0.69      1365
           4       0.58      0.96      0.73       413

    accuracy                           0.88     16218
   macro avg       0.72      0.87      0.77     16218
weighted avg       0.91      0.88      0.88     16218



  0%|          | 0/183 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.96      0.98     12769
           1       0.60      0.93      0.73       991
           2       0.82      0.65      0.73      2728
           3       0.64      0.74      0.68      1512
           4       0.59      0.98      0.74       462

    accuracy                           0.90     18462
   macro avg       0.73      0.85      0.77     18462
weighted avg       0.91      0.90      0.90     18462



  0%|          | 0/183 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.95      0.97     12898
           1       0.58      0.94      0.72       941
           2       0.84      0.68      0.75      2611
           3       0.65      0.81      0.72      1469
           4       0.64      0.96      0.77       441

    accuracy                           0.90     18360
   macro avg       0.74      0.87      0.78     18360
weighted avg       0.92      0.90      0.90     18360



  0%|          | 0/183 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.93      0.96      5622
           1       0.69      0.91      0.78      1001
           2       0.86      0.69      0.77      2752
           3       0.52      0.75      0.62       967
           4       0.67      0.98      0.80       470

    accuracy                           0.85     10812
   macro avg       0.75      0.85      0.79     10812
weighted avg       0.88      0.85      0.86     10812



  0%|          | 0/183 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.98      0.99     17697
           1       0.66      0.91      0.76      1019
           2       0.88      0.63      0.73      2916
           3       0.58      0.82      0.68      1467
           4       0.68      0.97      0.80       463

    accuracy                           0.92     23562
   macro avg       0.76      0.86      0.79     23562
weighted avg       0.94      0.92      0.93     23562



  0%|          | 0/183 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.93      0.97     13258
           1       0.55      0.98      0.70       907
           2       0.83      0.61      0.70      2599
           3       0.56      0.84      0.67      1369
           4       0.69      0.97      0.81       431

    accuracy                           0.88     18564
   macro avg       0.73      0.87      0.77     18564
weighted avg       0.91      0.88      0.89     18564



  0%|          | 0/183 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.94      0.97     10473
           1       0.60      0.84      0.70      1021
           2       0.81      0.65      0.72      2632
           3       0.60      0.83      0.69      1336
           4       0.70      0.97      0.81       450

    accuracy                           0.88     15912
   macro avg       0.74      0.85      0.78     15912
weighted avg       0.90      0.88      0.89     15912



  0%|          | 0/183 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.93      0.96      8555
           1       0.58      0.98      0.73       873
           2       0.85      0.68      0.76      2659
           3       0.63      0.79      0.70      1253
           4       0.65      0.99      0.78       430

    accuracy                           0.87     13770
   macro avg       0.74      0.88      0.79     13770
weighted avg       0.90      0.87      0.88     13770



  0%|          | 0/183 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.96      0.98     12839
           1       0.60      0.91      0.72      1014
           2       0.88      0.72      0.79      2761
           3       0.69      0.81      0.74      1487
           4       0.67      0.98      0.79       463

    accuracy                           0.91     18564
   macro avg       0.77      0.88      0.81     18564
weighted avg       0.93      0.91      0.91     18564



  0%|          | 0/183 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.96      0.98     10856
           1       0.62      0.94      0.75       972
           2       0.88      0.67      0.76      2623
           3       0.63      0.84      0.72      1196
           4       0.75      0.96      0.84       469

    accuracy                           0.91     16116
   macro avg       0.78      0.87      0.81     16116
weighted avg       0.92      0.91      0.91     16116



[{'0.0': {'precision': 0.999992075378995,
   'recall': 0.9576270043104902,
   'f1-score': 0.9783511260334806,
   'support': 1185944.0},
  '1.0': {'precision': 0.6461274991918677,
   'recall': 0.8782861523503235,
   'f1-score': 0.7445286712289141,
   'support': 106964.0},
  '2.0': {'precision': 0.8384234622006573,
   'recall': 0.6460505189841931,
   'f1-score': 0.729772188328856,
   'support': 265326.0},
  '3.0': {'precision': 0.5741204213523765,
   'recall': 0.8318130539887187,
   'f1-score': 0.6793505801212233,
   'support': 124100.0},
  '4.0': {'precision': 0.6947454679242758,
   'recall': 0.9704041111889745,
   'f1-score': 0.8097577139738417,
   'support': 42810.0},
  'accuracy': 0.8960538946314047,
  'macro avg': {'precision': 0.7506817852096345,
   'recall': 0.8568361681645401,
   'f1-score': 0.7883520559372632,
   'support': 1725144.0},
  'weighted avg': {'precision': 0.9149919803624195,
   'recall': 0.8960538946314047,
   'f1-score': 0.8999295612254895,
   'support': 1725144.0}}