In [1]:
%load_ext autoreload
%autoreload 2
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset
from hmpai.data import SAT1_STAGES_ACCURACY, SAT_CLASSES_ACCURACY
from hmpai.visualization import plot_confusion_matrix
from hmpai.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
from mne.io import read_info
import os
# Root directory of the datasets, read from the DATA_PATH environment variable.
# os.getenv() returns None when the variable is missing and Path(None) then
# fails with an opaque TypeError, so look it up strictly and raise an error
# that names the missing variable instead.
try:
    DATA_PATH = Path(os.environ["DATA_PATH"])
except KeyError as missing_var:
    raise RuntimeError(
        "The DATA_PATH environment variable is not set; "
        "set it to the root directory of the datasets before running this notebook."
    ) from missing_var

In [14]:
# Seed the run before anything else in the pipeline executes.
set_global_seed(42)

# NOTE(review): the variables are named *_sat1 but the file is read from the
# "sat2/" directory -- confirm which dataset is actually intended here.
data_path_sat1 = DATA_PATH.joinpath("sat2/stage_data_100hz.nc")
dataset_sat1 = xr.load_dataset(data_path_sat1)

# Disabled per-condition subsets, kept for reference:
# dataset_acc = dataset_sat1.where(dataset_sat1.event_name.str.contains("accuracy"), drop=True)
# dataset_sp = dataset_sat1.where(dataset_sat1.event_name.str.contains("speed"), drop=True)

In [15]:
# Shared by every SAT1Dataset built in this notebook.
shape_topological = False

# Split the loaded dataset into train / validation / test partitions.
# NOTE(review): 60 is presumably the training share and norm_min1_to_1 the
# normalisation function applied to the splits -- confirm against
# split_data_on_participants.
train_data_sat1, val_data_sat1, test_data_sat1 = split_data_on_participants(
    dataset_sat1,
    60,
    norm_min1_to_1,
)

# Wrap each partition with identical settings; constructed in train, val, test order.
train_dataset, val_dataset, test_dataset = (
    SAT1Dataset(
        partition,
        shape_topological=shape_topological,
        labels=SAT_CLASSES_ACCURACY,
        set_to_zero=True,
        interpolate_to=100,
    )
    for partition in (train_data_sat1, val_data_sat1, test_data_sat1)
)


In [19]:
# Variant without interpolation: rebinds train/val/test datasets, leaving
# interpolate_to at the SAT1Dataset default.
train_dataset = SAT1Dataset(
    train_data_sat1,
    shape_topological=shape_topological,
    labels=SAT_CLASSES_ACCURACY,
    set_to_zero=True,
)
val_dataset = SAT1Dataset(
    val_data_sat1,
    shape_topological=shape_topological,
    labels=SAT_CLASSES_ACCURACY,
    set_to_zero=True,
)
test_dataset = SAT1Dataset(
    test_data_sat1,
    shape_topological=shape_topological,
    labels=SAT_CLASSES_ACCURACY,
    set_to_zero=True,
)

In [22]:
# Interpolated variant: rebinds train/val/test datasets with interpolate_to=100.
train_dataset = SAT1Dataset(
    train_data_sat1,
    shape_topological=shape_topological,
    labels=SAT_CLASSES_ACCURACY,
    set_to_zero=True,
    interpolate_to=100,
)
val_dataset = SAT1Dataset(
    val_data_sat1,
    shape_topological=shape_topological,
    labels=SAT_CLASSES_ACCURACY,
    set_to_zero=True,
    interpolate_to=100,
)
test_dataset = SAT1Dataset(
    test_data_sat1,
    shape_topological=shape_topological,
    labels=SAT_CLASSES_ACCURACY,
    set_to_zero=True,
    interpolate_to=100,
)


In [25]:
# Transformer model with one output class per entry in SAT_CLASSES_ACCURACY.
model = Seq2SeqTransformer(
    d_model=30,
    ff_dim=2048,
    num_heads=10,
    num_layers=6,
    num_classes=len(SAT_CLASSES_ACCURACY),
)

# Train and evaluate; the call is the cell's last expression, so its return
# value is displayed as the cell output.
# NOTE(review): the datasets are passed positionally as (train, test, val) --
# confirm this matches the parameter order of train_and_test.
train_and_test(
    model,
    train_dataset,
    test_dataset,
    val_dataset,
    # Optimisation settings
    epochs=10,  # switch to epochs=1 for a quick smoke run
    lr=0.0001,
    weight_decay=0.001,
    label_smoothing=0.0001,
    # Data loading
    batch_size=64,
    workers=0,
    # Reporting
    labels=SAT_CLASSES_ACCURACY,
    logs_path=Path("../logs/"),
)



  0%|          | 0/323 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       0.99      0.80      0.89       797
           1       0.54      0.95      0.69       101
           2       0.80      0.61      0.69       272
           3       0.49      0.87      0.62       143
           4       0.48      0.78      0.59        37

    accuracy                           0.78      1350
   macro avg       0.66      0.80      0.70      1350
weighted avg       0.85      0.78      0.80      1350



  0%|          | 0/323 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       0.99      0.88      0.93      1092
           1       0.52      0.92      0.67       103
           2       0.79      0.62      0.69       263
           3       0.50      0.74      0.60       126
           4       0.42      0.89      0.57        36

    accuracy                           0.83      1620
   macro avg       0.65      0.81      0.69      1620
weighted avg       0.88      0.83      0.84      1620



  0%|          | 0/323 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       0.99      0.82      0.90       644
           1       0.60      0.92      0.72       112
           2       0.78      0.68      0.73       247
           3       0.52      0.67      0.58        99
           4       0.41      0.92      0.57        38

    accuracy                           0.79      1140
   macro avg       0.66      0.80      0.70      1140
weighted avg       0.84      0.79      0.81      1140



  0%|          | 0/323 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       0.99      0.93      0.96       882
           1       0.62      0.88      0.73       115
           2       0.90      0.66      0.76       341
           3       0.44      0.79      0.57        87
           4       0.51      0.89      0.65        45

    accuracy                           0.85      1470
   macro avg       0.69      0.83      0.73      1470
weighted avg       0.89      0.85      0.86      1470



  0%|          | 0/323 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       0.98      0.87      0.92       504
           1       0.60      0.92      0.72       110
           2       0.85      0.69      0.77       347
           3       0.52      0.62      0.57       104
           4       0.49      0.93      0.65        45

    accuracy                           0.80      1110
   macro avg       0.69      0.81      0.72      1110
weighted avg       0.84      0.80      0.81      1110



  0%|          | 0/323 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       0.99      0.96      0.98       830
           1       0.67      0.86      0.75       132
           2       0.87      0.68      0.76       333
           3       0.67      0.76      0.71       194
           4       0.51      0.98      0.67        41

    accuracy                           0.86      1530
   macro avg       0.74      0.85      0.77      1530
weighted avg       0.88      0.86      0.87      1530



  0%|          | 0/323 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       0.99      0.94      0.97       906
           1       0.52      0.94      0.67       102
           2       0.86      0.65      0.74       312
           3       0.64      0.72      0.68       139
           4       0.57      0.95      0.71        41

    accuracy                           0.86      1500
   macro avg       0.72      0.84      0.75      1500
weighted avg       0.89      0.86      0.86      1500



  0%|          | 0/323 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       0.99      0.92      0.95       676
           1       0.63      0.92      0.75       118
           2       0.87      0.74      0.80       288
           3       0.68      0.73      0.70       137
           4       0.56      0.98      0.71        41

    accuracy                           0.86      1260
   macro avg       0.75      0.86      0.78      1260
weighted avg       0.88      0.86      0.86      1260



  0%|          | 0/323 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       0.99      0.93      0.96       691
           1       0.58      0.94      0.71       109
           2       0.88      0.77      0.82       308
           3       0.86      0.80      0.83       232
           4       0.56      1.00      0.72        40

    accuracy                           0.87      1380
   macro avg       0.77      0.89      0.81      1380
weighted avg       0.90      0.87      0.88      1380



  0%|          | 0/323 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.92      0.96      2108
           1       0.54      0.89      0.67       104
           2       0.74      0.64      0.68       332
           3       0.48      0.80      0.60       176
           4       0.60      0.82      0.69        40

    accuracy                           0.88      2760
   macro avg       0.67      0.82      0.72      2760
weighted avg       0.91      0.88      0.89      2760



[{'0.0': {'precision': 0.994931696693724,
   'recall': 0.944043539284382,
   'f1-score': 0.9688198413933371,
   'support': 372629.0},
  '1.0': {'precision': 0.629037334277509,
   'recall': 0.8580589766293703,
   'f1-score': 0.7259128105598874,
   'support': 37269.0},
  '2.0': {'precision': 0.8425704192741754,
   'recall': 0.6503942154499776,
   'f1-score': 0.7341137412895752,
   'support': 97916.0},
  '3.0': {'precision': 0.5500146988287354,
   'recall': 0.8107467043744013,
   'f1-score': 0.6554016058703689,
   'support': 43846.0},
  '4.0': {'precision': 0.5894841166345001,
   'recall': 0.9314266875463612,
   'f1-score': 0.722016355737286,
   'support': 12133.0},
  'accuracy': 0.8767224850255324,
  'macro avg': {'precision': 0.7212076531417286,
   'recall': 0.8389340246568985,
   'f1-score': 0.7612528709700909,
   'support': 563793.0},
  'weighted avg': {'precision': 0.9009570464054989,
   'recall': 0.8767224850255324,
   'f1-score': 0.8823147104745482,
   'support': 563793.0}}]