In [1]:
%load_ext autoreload
%autoreload 2
import netCDF4
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants, split_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test, calculate_global_class_weights, prepare_data
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset, MultiXArrayDataset, MultiXArrayProbaDataset
from hmpai.data import SAT1_STAGES_ACCURACY, SAT_CLASSES_ACCURACY
from hmpai.visualization import plot_confusion_matrix
from hmpai.pytorch.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
# from braindecode.models.eegconformer import EEGConformer
from mne.io import read_info
import os
# Root directory of all datasets, taken from the DATA_PATH environment variable.
# Fail loudly with an actionable message: Path(None) would raise an opaque TypeError,
# and Path("") would silently resolve to the current working directory.
_data_path_env = os.getenv("DATA_PATH")
if not _data_path_env:
    raise RuntimeError(
        "DATA_PATH environment variable is not set (or empty); "
        "set it to the root directory that contains the 'sat2' data folder."
    )
DATA_PATH = Path(_data_path_env)

In [2]:
# Input data: the SAT2 stage data is split across two netCDF part files under DATA_PATH.
paths = [
    DATA_PATH / fname
    for fname in (
        "sat2/stage_data_proba_250hz_part1.nc",
        "sat2/stage_data_proba_250hz_part2.nc",
    )
]
labels = SAT_CLASSES_ACCURACY

# Shared MambaModel hyperparameters; the number of output classes follows the label set.
model_params = dict(
    embed_dim=256,
    n_channels=19,
    n_classes=len(labels),
    n_layers=5,
    dropout=0.1,
)

# Settings forwarded to train_and_test in every experiment cell below.
workers = 8
batch_size = 64
lr = 1e-4

# Whole epoch, with negatives (`add_negative=True`) - both conditions

In [None]:
# Whole-epoch run with add_negative=True over both conditions.
# subset_cond=None is passed explicitly, matching the other "both conditions" cells,
# so this run does not depend on whatever default prepare_data uses for subset_cond.
# NOTE(review): the meaning of the positional 60 is not visible here - confirm against prepare_data.
whole_epoch = True
train_data, test_data, val_data, class_weights = prepare_data(
    paths,
    60,
    norm_mad_zscore,
    labels=labels,
    whole_epoch=whole_epoch,
    add_negative=True,
    subset_cond=None,
)
model = MambaModel(**model_params, global_pool=False)
# class_weights is forwarded, but use_class_weights=False, so presumably it is not applied here.
train_and_test(
    model,
    train_data,
    test_data,
    val_data,
    logs_path=Path("../logs/paper1"),
    workers=workers,
    batch_size=batch_size,
    labels=labels,
    lr=lr,
    do_spectral_decoupling=False,
    use_class_weights=False,
    class_weights=class_weights,
    whole_epoch=whole_epoch,
)

# Whole epoch, with negatives (`add_negative=True`) - accuracy condition only (AC, `subset_cond="accuracy"`)

In [3]:
# Whole-epoch run with add_negative=True, restricted to subset_cond="accuracy".
# NOTE(review): the meaning of the positional 60 is not visible here - confirm against prepare_data.
whole_epoch = True
train_data, test_data, val_data, class_weights = prepare_data(
    paths, 60, norm_mad_zscore,
    labels=labels,
    whole_epoch=whole_epoch,
    subset_cond="accuracy",
    add_negative=True,
)
model = MambaModel(global_pool=False, **model_params)
# class_weights is forwarded, but use_class_weights=False, so presumably it is not applied here.
train_and_test(
    model, train_data, test_data, val_data,
    logs_path=Path("../logs/paper1"),
    labels=labels,
    whole_epoch=whole_epoch,
    lr=lr, batch_size=batch_size, workers=workers,
    use_class_weights=False, class_weights=class_weights,
    do_spectral_decoupling=False,
)

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

[]

# Whole epoch, without negatives (`add_negative=False`) - both conditions

In [3]:
# Whole-epoch run without negatives, over both conditions (no subset_cond filter).
# NOTE(review): the meaning of the positional 60 is not visible here - confirm against prepare_data.
whole_epoch = True
train_data, test_data, val_data, class_weights = prepare_data(
    paths, 60, norm_mad_zscore,
    labels=labels,
    whole_epoch=whole_epoch,
    subset_cond=None,
    add_negative=False,
)
model = MambaModel(global_pool=False, **model_params)
# class_weights is forwarded, but use_class_weights=False, so presumably it is not applied here.
train_and_test(
    model, train_data, test_data, val_data,
    logs_path=Path("../logs/paper1"),
    labels=labels,
    whole_epoch=whole_epoch,
    lr=lr, batch_size=batch_size, workers=workers,
    use_class_weights=False, class_weights=class_weights,
    do_spectral_decoupling=False,
)

  0%|          | 0/323 [00:00<?, ? batch/s]

  0%|          | 0/323 [00:00<?, ? batch/s]

  0%|          | 0/323 [00:00<?, ? batch/s]

  0%|          | 0/323 [00:00<?, ? batch/s]

  0%|          | 0/323 [00:00<?, ? batch/s]

[]

# Whole epoch, without negatives (`add_negative=False`) - accuracy condition only (AC, `subset_cond="accuracy"`)

In [3]:
# Whole-epoch run without negatives, restricted to subset_cond="accuracy".
# NOTE(review): the meaning of the positional 60 is not visible here - confirm against prepare_data.
whole_epoch = True
train_data, test_data, val_data, class_weights = prepare_data(
    paths, 60, norm_mad_zscore,
    labels=labels,
    whole_epoch=whole_epoch,
    subset_cond="accuracy",
    add_negative=False,
)
model = MambaModel(global_pool=False, **model_params)
# class_weights is forwarded, but use_class_weights=False, so presumably it is not applied here.
train_and_test(
    model, train_data, test_data, val_data,
    logs_path=Path("../logs/paper1"),
    labels=labels,
    whole_epoch=whole_epoch,
    lr=lr, batch_size=batch_size, workers=workers,
    use_class_weights=False, class_weights=class_weights,
    do_spectral_decoupling=False,
)

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

  0%|          | 0/167 [00:00<?, ? batch/s]

[]

# Sliding window, categorical labels (`probabilistic_labels=False`) - both conditions

In [5]:
# Sliding-window run with categorical (non-probabilistic) labels over both conditions.
# The label mode is held in one variable and passed to BOTH prepare_data and train_and_test,
# as the probabilistic-label cells do, so data preparation and training/evaluation cannot
# silently disagree about the label format.
# NOTE(review): the meaning of the positional 60 is not visible here - confirm against prepare_data.
whole_epoch = False
probabilistic_labels = False
train_data, test_data, val_data, class_weights = prepare_data(
    paths,
    60,
    norm_mad_zscore,
    labels=labels,
    whole_epoch=whole_epoch,
    add_negative=False,
    subset_cond=None,
    window_size=(1, 11),
    jiggle=3,
    probabilistic_labels=probabilistic_labels,
)
# This categorical run uses global_pool=True and class weighting;
# the probabilistic-label runs use global_pool=False without class weights.
model = MambaModel(**model_params, global_pool=True)
train_and_test(
    model,
    train_data,
    test_data,
    val_data,
    logs_path=Path("../logs/paper1"),
    workers=workers,
    batch_size=batch_size,
    labels=labels,
    lr=lr,
    do_spectral_decoupling=False,
    use_class_weights=True,
    class_weights=class_weights,
    whole_epoch=whole_epoch,
    probabilistic_labels=probabilistic_labels,
)

  0%|          | 0/1136 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           1       0.58      0.76      0.66      8787
           2       0.60      0.30      0.40      8787
           3       0.40      0.68      0.50      4425
           4       0.99      0.83      0.91      8787

    accuracy                           0.64     30786
   macro avg       0.64      0.64      0.62     30786
weighted avg       0.68      0.64      0.63     30786



  0%|          | 0/1136 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           1       0.60      0.73      0.66      8787
           2       0.58      0.37      0.45      8787
           3       0.41      0.61      0.49      4425
           4       0.96      0.86      0.91      8787

    accuracy                           0.65     30786
   macro avg       0.64      0.64      0.63     30786
weighted avg       0.67      0.65      0.65     30786



  0%|          | 0/1136 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           1       0.56      0.80      0.66      8787
           2       0.56      0.29      0.38      8787
           3       0.42      0.53      0.47      4425
           4       0.95      0.87      0.91      8787

    accuracy                           0.64     30786
   macro avg       0.62      0.62      0.61     30786
weighted avg       0.65      0.64      0.63     30786



  0%|          | 0/1136 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           1       0.59      0.74      0.65      8787
           2       0.57      0.35      0.44      8787
           3       0.42      0.57      0.48      4425
           4       0.95      0.88      0.92      8787

    accuracy                           0.65     30786
   macro avg       0.63      0.64      0.62     30786
weighted avg       0.66      0.65      0.64     30786



[{'1.0': {'precision': 0.6495055561219288,
   'recall': 0.6968172372306682,
   'f1-score': 0.6723300970873787,
   'support': 9143.0},
  '2.0': {'precision': 0.6213510083204061,
   'recall': 0.4818987203324948,
   'f1-score': 0.5428113835160774,
   'support': 9143.0},
  '3.0': {'precision': 0.42442860653723274,
   'recall': 0.7447175506683915,
   'f1-score': 0.5407013149655604,
   'support': 4638.0},
  '4.0': {'precision': 0.9900412576468914,
   'recall': 0.7611287323635568,
   'f1-score': 0.8606232995300519,
   'support': 9143.0},
  'accuracy': 0.6608039417469673,
  'macro avg': {'precision': 0.6713316071566148,
   'recall': 0.6711405601487778,
   'f1-score': 0.6541165237747671,
   'support': 32067.0},
  'weighted avg': {'precision': 0.7060182949599738,
   'recall': 0.6608039417469673,
   'f1-score': 0.6700498981373666,
   'support': 32067.0}}]

# Sliding window, categorical labels (`probabilistic_labels=False`) - accuracy condition only (AC, `subset_cond="accuracy"`)

In [4]:
# Sliding-window run with categorical (non-probabilistic) labels, restricted to subset_cond="accuracy".
# The label mode is held in one variable and passed to BOTH prepare_data and train_and_test,
# as the probabilistic-label cells do, so data preparation and training/evaluation cannot
# silently disagree about the label format.
# NOTE(review): the meaning of the positional 60 is not visible here - confirm against prepare_data.
whole_epoch = False
probabilistic_labels = False
train_data, test_data, val_data, class_weights = prepare_data(
    paths,
    60,
    norm_mad_zscore,
    labels=labels,
    whole_epoch=whole_epoch,
    add_negative=False,
    subset_cond="accuracy",
    window_size=(1, 11),
    jiggle=3,
    probabilistic_labels=probabilistic_labels,
)
# This categorical run uses global_pool=True and class weighting;
# the probabilistic-label runs use global_pool=False without class weights.
model = MambaModel(**model_params, global_pool=True)
train_and_test(
    model,
    train_data,
    test_data,
    val_data,
    logs_path=Path("../logs/paper1"),
    workers=workers,
    batch_size=batch_size,
    labels=labels,
    lr=lr,
    do_spectral_decoupling=False,
    use_class_weights=True,
    class_weights=class_weights,
    whole_epoch=whole_epoch,
    probabilistic_labels=probabilistic_labels,
)

KeyboardInterrupt: 

# Sliding window, probabilistic labels (`probabilistic_labels=True`) - both conditions

In [3]:
# Sliding-window run with probabilistic labels, without negatives, over both conditions.
# NOTE(review): the meaning of the positional 60 is not visible here - confirm against prepare_data.
whole_epoch = False
train_data, test_data, val_data, class_weights = prepare_data(
    paths, 60, norm_mad_zscore,
    labels=labels,
    whole_epoch=whole_epoch,
    window_size=(1, 11),
    jiggle=3,
    probabilistic_labels=True,
    subset_cond=None,
    add_negative=False,
)
model = MambaModel(global_pool=False, **model_params)
# class_weights is forwarded, but use_class_weights=False, so presumably it is not applied here.
train_and_test(
    model, train_data, test_data, val_data,
    logs_path=Path("../logs/paper1"),
    labels=labels,
    whole_epoch=whole_epoch,
    probabilistic_labels=True,
    lr=lr, batch_size=batch_size, workers=workers,
    use_class_weights=False, class_weights=class_weights,
    do_spectral_decoupling=False,
)

  0%|          | 0/1136 [00:00<?, ? batch/s]

  0%|          | 0/1136 [00:00<?, ? batch/s]

  0%|          | 0/1136 [00:00<?, ? batch/s]

  0%|          | 0/1136 [00:00<?, ? batch/s]

  0%|          | 0/1136 [00:00<?, ? batch/s]

[]

# Sliding window, probabilistic labels (`probabilistic_labels=True`) - accuracy condition only (AC, `subset_cond="accuracy"`)

In [5]:
# Sliding-window run with probabilistic labels, without negatives, restricted to subset_cond="accuracy".
# NOTE(review): the meaning of the positional 60 is not visible here - confirm against prepare_data.
whole_epoch = False
train_data, test_data, val_data, class_weights = prepare_data(
    paths, 60, norm_mad_zscore,
    labels=labels,
    whole_epoch=whole_epoch,
    window_size=(1, 11),
    jiggle=3,
    probabilistic_labels=True,
    subset_cond="accuracy",
    add_negative=False,
)
model = MambaModel(global_pool=False, **model_params)
# class_weights is forwarded, but use_class_weights=False, so presumably it is not applied here.
train_and_test(
    model, train_data, test_data, val_data,
    logs_path=Path("../logs/paper1"),
    labels=labels,
    whole_epoch=whole_epoch,
    probabilistic_labels=True,
    lr=lr, batch_size=batch_size, workers=workers,
    use_class_weights=False, class_weights=class_weights,
    do_spectral_decoupling=False,
)

  0%|          | 0/668 [00:00<?, ? batch/s]

  0%|          | 0/668 [00:00<?, ? batch/s]

  0%|          | 0/668 [00:00<?, ? batch/s]

  0%|          | 0/668 [00:00<?, ? batch/s]

  0%|          | 0/668 [00:00<?, ? batch/s]

[]

# Sliding window, probabilistic labels, with negatives (`add_negative=True`) - both conditions

In [6]:
# Sliding-window run with probabilistic labels and add_negative=True, over both conditions.
# NOTE(review): the meaning of the positional 60 is not visible here - confirm against prepare_data.
whole_epoch = False
train_data, test_data, val_data, class_weights = prepare_data(
    paths, 60, norm_mad_zscore,
    labels=labels,
    whole_epoch=whole_epoch,
    window_size=(1, 11),
    jiggle=3,
    probabilistic_labels=True,
    subset_cond=None,
    add_negative=True,
)
model = MambaModel(global_pool=False, **model_params)
# class_weights is forwarded, but use_class_weights=False, so presumably it is not applied here.
train_and_test(
    model, train_data, test_data, val_data,
    logs_path=Path("../logs/paper1"),
    labels=labels,
    whole_epoch=whole_epoch,
    probabilistic_labels=True,
    lr=lr, batch_size=batch_size, workers=workers,
    use_class_weights=False, class_weights=class_weights,
    do_spectral_decoupling=False,
)

  0%|          | 0/1136 [00:00<?, ? batch/s]

  0%|          | 0/1136 [00:00<?, ? batch/s]

  0%|          | 0/1136 [00:00<?, ? batch/s]

  0%|          | 0/1136 [00:00<?, ? batch/s]

  0%|          | 0/1136 [00:00<?, ? batch/s]

[]

# Sliding window, probabilistic labels, with negatives (`add_negative=True`) - accuracy condition only (AC, `subset_cond="accuracy"`)

In [7]:
# Sliding-window run with probabilistic labels and add_negative=True, restricted to subset_cond="accuracy".
# NOTE(review): the meaning of the positional 60 is not visible here - confirm against prepare_data.
whole_epoch = False
train_data, test_data, val_data, class_weights = prepare_data(
    paths, 60, norm_mad_zscore,
    labels=labels,
    whole_epoch=whole_epoch,
    window_size=(1, 11),
    jiggle=3,
    probabilistic_labels=True,
    subset_cond="accuracy",
    add_negative=True,
)
model = MambaModel(global_pool=False, **model_params)
# class_weights is forwarded, but use_class_weights=False, so presumably it is not applied here.
train_and_test(
    model, train_data, test_data, val_data,
    logs_path=Path("../logs/paper1"),
    labels=labels,
    whole_epoch=whole_epoch,
    probabilistic_labels=True,
    lr=lr, batch_size=batch_size, workers=workers,
    use_class_weights=False, class_weights=class_weights,
    do_spectral_decoupling=False,
)

  0%|          | 0/668 [00:00<?, ? batch/s]

  0%|          | 0/668 [00:00<?, ? batch/s]

  0%|          | 0/668 [00:00<?, ? batch/s]

  0%|          | 0/668 [00:00<?, ? batch/s]

  0%|          | 0/668 [00:00<?, ? batch/s]

  0%|          | 0/668 [00:00<?, ? batch/s]

[]