In [1]:
%load_ext autoreload
%autoreload 2
import netCDF4
import xarray as xr
import numpy as np
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.pretraining import random_masking
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset
from hmpai.data import SAT1_STAGES_ACCURACY, SAT_CLASSES_ACCURACY
from hmpai.visualization import plot_confusion_matrix
from hmpai.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
from mne.io import read_info
import os
DATA_PATH = Path(os.getenv("DATA_PATH"))

In [2]:
# Try if pre-training on a pretext task of SAT2 works to predict SAT1

In [3]:
set_global_seed(42)
data_path_sat2 = DATA_PATH / "sat2/stage_data_100hz.nc"
dataset_sat2 = xr.load_dataset(data_path_sat2)
data_path_sat1 = DATA_PATH / "sat1/stage_data_100hz.nc"
dataset_sat1 = xr.load_dataset(data_path_sat1)
dataset_sat1 = dataset_sat1.rename_vars({'RT': 'rt'})

In [4]:
shape_topological = False
info_to_keep = []
train_data_sat2, val_data_sat2, test_data_sat2 = split_data_on_participants(
    dataset_sat2, 60, norm_min1_to_1
)
train_data_sat1, val_data_sat1, test_data_sat1 = split_data_on_participants(
    dataset_sat1, 60, norm_min1_to_1
)
train_dataset_sat2 = SAT1Dataset(train_data_sat2, shape_topological=shape_topological, labels=SAT_CLASSES_ACCURACY, set_to_zero=True, info_to_keep=info_to_keep, order_by_rt=False)
val_dataset_sat2 = SAT1Dataset(val_data_sat2, shape_topological=shape_topological, labels=SAT_CLASSES_ACCURACY, set_to_zero=True, info_to_keep=info_to_keep, order_by_rt=False)
test_dataset_sat2 = SAT1Dataset(test_data_sat2, shape_topological=shape_topological, labels=SAT_CLASSES_ACCURACY, set_to_zero=True, info_to_keep=info_to_keep, order_by_rt=False)
train_dataset_sat1 = SAT1Dataset(train_data_sat2, shape_topological=shape_topological, labels=SAT_CLASSES_ACCURACY, set_to_zero=True, info_to_keep=info_to_keep, order_by_rt=False)
val_dataset_sat1 = SAT1Dataset(val_data_sat2, shape_topological=shape_topological, labels=SAT_CLASSES_ACCURACY, set_to_zero=True, info_to_keep=info_to_keep, order_by_rt=False)
test_dataset_sat1 = SAT1Dataset(test_data_sat2, shape_topological=shape_topological, labels=SAT_CLASSES_ACCURACY, set_to_zero=True, info_to_keep=info_to_keep, order_by_rt=False)

In [5]:
chk_path = Path("../models/sat2_pretrained.pt")
checkpoint = load_model(chk_path)
model_kwargs = {
    "d_model": len(test_data_sat2.channels),
    "num_heads": 10,
    "ff_dim": 2048,
    "num_layers": 6,
    "num_classes": len(SAT_CLASSES_ACCURACY),
}
model = Seq2SeqTransformer(**model_kwargs)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE)



In [7]:
torch.cuda.empty_cache()
model = Seq2SeqTransformer(d_model=30, ff_dim=2048, num_heads=10, num_layers=6, num_classes=len(SAT_CLASSES_ACCURACY))

train_and_test(
    model,
    train_dataset_sat2,
    test_dataset_sat2,
    val_dataset_sat2,
    logs_path=Path("../logs/"),
    workers=0,
    batch_size=128,
    labels=SAT_CLASSES_ACCURACY,
    epochs=10,
    # epochs=1,
    weight_decay=0.001,
    label_smoothing=0.01,
    lr=0.0001,
    pretrain_fn=random_masking,
    use_class_weights=False,
)

  0%|          | 0/162 [00:00<?, ? batch/s]

  0%|          | 0/162 [00:00<?, ? batch/s]

  0%|          | 0/162 [00:00<?, ? batch/s]

  0%|          | 0/162 [00:00<?, ? batch/s]

  0%|          | 0/162 [00:00<?, ? batch/s]

  0%|          | 0/162 [00:00<?, ? batch/s]

[[0.00014582174480892718]]

In [11]:
model = Seq2SeqTransformer(d_model=30, ff_dim=2048, num_heads=10, num_layers=6, num_classes=len(SAT_CLASSES_ACCURACY))




In [8]:
# Fine tune
model.set_pretraining(False)
train_and_test(
    model,
    train_dataset_sat1,
    test_dataset_sat1,
    val_dataset_sat1,
    logs_path=Path("../logs/"),
    workers=0,
    batch_size=128,
    labels=SAT_CLASSES_ACCURACY,
    epochs=10,
    # epochs=1,
    weight_decay=0.001,
    label_smoothing=0.01,
    lr=0.00001,
    pretrain_fn=None,
    use_class_weights=True,
)

  0%|          | 0/162 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.92      0.96      7101
           1       0.61      0.91      0.73       913
           2       0.82      0.41      0.55      2383
           3       0.54      0.62      0.58      1245
           4       0.11      0.43      0.18       390

    accuracy                           0.77     12032
   macro avg       0.62      0.66      0.60     12032
weighted avg       0.86      0.77      0.80     12032



  0%|          | 0/162 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.92      0.96      8004
           1       0.62      0.95      0.75       837
           2       0.83      0.66      0.74      2284
           3       0.52      0.45      0.48      1070
           4       0.28      0.82      0.42       401

    accuracy                           0.84     12596
   macro avg       0.65      0.76      0.67     12596
weighted avg       0.88      0.84      0.85     12596



  0%|          | 0/162 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.91      0.95      6827
           1       0.64      0.91      0.75       885
           2       0.82      0.68      0.74      2255
           3       0.57      0.58      0.57       913
           4       0.37      0.96      0.54       400

    accuracy                           0.84     11280
   macro avg       0.68      0.80      0.71     11280
weighted avg       0.88      0.84      0.85     11280



  0%|          | 0/162 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.96      0.98      6681
           1       0.65      0.95      0.77       848
           2       0.91      0.65      0.76      2587
           3       0.64      0.68      0.66      1018
           4       0.43      0.96      0.59       428

    accuracy                           0.87     11562
   macro avg       0.73      0.84      0.75     11562
weighted avg       0.90      0.87      0.87     11562



  0%|          | 0/162 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.93      0.96      7088
           1       0.63      0.94      0.75       816
           2       0.87      0.66      0.75      2551
           3       0.51      0.67      0.58       781
           4       0.47      0.99      0.64       420

    accuracy                           0.86     11656
   macro avg       0.69      0.84      0.74     11656
weighted avg       0.89      0.86      0.86     11656



  0%|          | 0/162 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.92      0.96      7028
           1       0.63      0.94      0.76       822
           2       0.85      0.68      0.75      2520
           3       0.59      0.72      0.65      1161
           4       0.48      0.99      0.65       407

    accuracy                           0.85     11938
   macro avg       0.71      0.85      0.75     11938
weighted avg       0.89      0.85      0.86     11938



  0%|          | 0/162 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.95      0.97      7265
           1       0.66      0.96      0.78       849
           2       0.87      0.70      0.77      2274
           3       0.66      0.66      0.66      1038
           4       0.48      1.00      0.65       418

    accuracy                           0.88     11844
   macro avg       0.73      0.85      0.77     11844
weighted avg       0.90      0.88      0.88     11844



  0%|          | 0/162 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.94      0.97      7239
           1       0.67      0.93      0.77       887
           2       0.87      0.70      0.77      2335
           3       0.64      0.70      0.67      1047
           4       0.50      0.99      0.67       430

    accuracy                           0.87     11938
   macro avg       0.74      0.85      0.77     11938
weighted avg       0.90      0.87      0.88     11938



  0%|          | 0/162 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.97      0.98     16340
           1       0.65      0.94      0.77       853
           2       0.85      0.71      0.77      2516
           3       0.68      0.76      0.72      1588
           4       0.51      0.98      0.67       417

    accuracy                           0.92     21714
   macro avg       0.74      0.87      0.78     21714
weighted avg       0.94      0.92      0.93     21714



  0%|          | 0/162 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           0       1.00      0.97      0.98     16600
           1       0.62      0.95      0.75       820
           2       0.83      0.67      0.74      2543
           3       0.62      0.75      0.68      1350
           4       0.52      0.97      0.67       401

    accuracy                           0.92     21714
   macro avg       0.72      0.86      0.77     21714
weighted avg       0.93      0.92      0.92     21714



[{'0.0': {'precision': 1.0,
   'recall': 0.9615188592322816,
   'f1-score': 0.9803819674806596,
   'support': 1102696.0},
  '1.0': {'precision': 0.6869078307185414,
   'recall': 0.8667564219936453,
   'f1-score': 0.7664226972509393,
   'support': 92845.0},
  '2.0': {'precision': 0.8420760593608501,
   'recall': 0.6581997765984869,
   'f1-score': 0.7388698276218428,
   'support': 244403.0},
  '3.0': {'precision': 0.5597795188241004,
   'recall': 0.760514689952054,
   'f1-score': 0.64488739611248,
   'support': 106161.0},
  '4.0': {'precision': 0.5277597604228481,
   'recall': 0.9796096111481121,
   'f1-score': 0.6859613090700365,
   'support': 39038.0},
  'accuracy': 0.8961853914757217,
  'macro avg': {'precision': 0.7233046338652681,
   'recall': 0.845319871784916,
   'f1-score': 0.7633046395071916,
   'support': 1585143.0},
  'weighted avg': {'precision': 0.9161995553109753,
   'recall': 0.8961853914757217,
   'f1-score': 0.9008930047710214,
   'support': 1585143.0}}]