In [1]:
%load_ext autoreload
%autoreload 2
import netCDF4
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants, split_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test, calculate_global_class_weights
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset, MultiXArrayDataset, MultiXArrayProbaDataset
from hmpai.data import SAT1_STAGES_ACCURACY, SAT_CLASSES_ACCURACY
from hmpai.visualization import plot_confusion_matrix
from hmpai.pytorch.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
# from braindecode.models.eegconformer import EEGConformer
from mne.io import read_info
import os
# Root directory of all datasets, read from the DATA_PATH environment variable.
# Fail fast with an actionable message: Path(None) would raise an opaque
# TypeError, and an empty value would silently resolve to the current directory.
_data_path_env = os.getenv("DATA_PATH")
if not _data_path_env:
    raise EnvironmentError(
        "DATA_PATH environment variable is not set; export it to point at the dataset root directory."
    )
DATA_PATH = Path(_data_path_env)

In [2]:
# Configuration for the probability-labelled SAT2 stage data, stored as two NetCDF parts.
set_global_seed(42)
data_paths = [
    DATA_PATH / "sat2/stage_data_proba_250hz_part1.nc",
    DATA_PATH / "sat2/stage_data_proba_250hz_part2.nc",
]
data_path_1, data_path_2 = data_paths
# Participant-level split; later cells index it as [train, test, val].
# train_percentage=100 makes test and val 100 as well
splits = split_participants(data_paths, train_percentage=60)
labels = SAT_CLASSES_ACCURACY
# Window / jitter settings passed to MultiXArrayProbaDataset in the windowed-dataset cell.
window_size = (1, 11)
jiggle = 3

In [2]:
# Split: configuration for the split-stage SAT2 data, stored as two NetCDF parts.
# NOTE: this does not reset window_size / jiggle if the proba config cell ran earlier.
set_global_seed(42)
data_paths = [
    DATA_PATH / "sat2/split_stage_data_250hz_part1.nc",
    DATA_PATH / "sat2/split_stage_data_250hz_part2.nc",
]
data_path_1, data_path_2 = data_paths
# train_percentage=100 makes test and val 100 as well
splits = split_participants(data_paths, train_percentage=60)
labels = SAT_CLASSES_ACCURACY

In [4]:
# Plain (non-windowed) datasets for the train / test / val participant splits.
# Normalisation variables are derived from the training set's statistics only
# and then handed to the test and val datasets.
norm_fn = norm_mad_zscore
train_ids, test_ids, val_ids = splits[0], splits[1], splits[2]
train_data = MultiXArrayDataset(data_paths, participants_to_keep=train_ids, normalization_fn=norm_fn)
norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
class_weights = train_data.statistics['class_weights']
heldout_kwargs = {"normalization_fn": norm_fn, "norm_vars": norm_vars}
test_data = MultiXArrayDataset(data_paths, participants_to_keep=test_ids, **heldout_kwargs)
val_data = MultiXArrayDataset(data_paths, participants_to_keep=val_ids, **heldout_kwargs)

In [3]:

# TODO: Maybe try split stage? Or just proba :-)
# TODO: Negative samples??
# Windowed probability datasets; relies on window_size / jiggle from the proba config cell.
norm_fn = norm_mad_zscore
window_kwargs = {"window_size": window_size, "jiggle": jiggle}
train_data = MultiXArrayProbaDataset(
    data_paths, participants_to_keep=splits[0], normalization_fn=norm_fn, **window_kwargs
)
# Test/val reuse the training-set normalisation variables.
norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
class_weights = train_data.statistics['class_weights']
heldout_kwargs = {"normalization_fn": norm_fn, "norm_vars": norm_vars, **window_kwargs}
test_data = MultiXArrayProbaDataset(data_paths, participants_to_keep=splits[1], **heldout_kwargs)
val_data = MultiXArrayProbaDataset(data_paths, participants_to_keep=splits[2], **heldout_kwargs)

  sample_min = np.nanmin(data)
  sample_max = np.nanmax(data)


In [4]:
# Probability datasets built with whole_epoch=True for each participant split.
norm_fn = norm_mad_zscore
train_data = MultiXArrayProbaDataset(
    data_paths, participants_to_keep=splits[0], normalization_fn=norm_fn, whole_epoch=True
)
# Test/val reuse the normalisation variables derived from the training statistics.
norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
class_weights = train_data.statistics['class_weights']
heldout_kwargs = {"normalization_fn": norm_fn, "norm_vars": norm_vars, "whole_epoch": True}
test_data = MultiXArrayProbaDataset(data_paths, participants_to_keep=splits[1], **heldout_kwargs)
val_data = MultiXArrayProbaDataset(data_paths, participants_to_keep=splits[2], **heldout_kwargs)

  sample_min = np.nanmin(data)
  sample_max = np.nanmax(data)


In [5]:
# Build the 256-dim Mamba backbone (n_classes=0 — presumably no classification
# head; a head is attached in a later cell).
# Previously the checkpoint was always read from disk, but load_state_dict was
# commented out, so the weights were never used. Loading is now an explicit
# toggle; the default False keeps the original from-scratch behaviour.
LOAD_PRETRAINED = False
chk_path = Path("../models/tueg_mamba_70G.pt")

model_kwargs = {
    "embed_dim": 256,
    "n_channels": 19,
    "n_classes": 0,
    "n_layers": 5,
    "global_pool": False,
    "dropout": 0.1,
}
model = MambaModel(**model_kwargs)
if LOAD_PRETRAINED:
    checkpoint = load_model(chk_path)
    model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE)

In [13]:
# Alternative, larger Mamba model with a 5-way output, built without a checkpoint.
# NOTE: running this cell overwrites both `model` and `model_kwargs` from the cell above.
model_kwargs = {
    "embed_dim": 512,
    "n_channels": 19,
    "n_classes": 5,
    "n_layers": 10,
    "global_pool": False,
    "dropout": 0.1,
}
# Move to DEVICE for consistency with the other model cell. Previously this
# model was left wherever MambaModel allocates it.
model = MambaModel(**model_kwargs).to(DEVICE)

In [9]:
# Switch the model to supervised classification: disable the pretraining path,
# attach a fresh 5-way linear head and enable global pooling.
model.pretraining = False
# Size the head from the active model's embed_dim instead of a hard-coded 256,
# so it also fits the 512-dim variant. This assumes linear_out consumes
# embed_dim features — TODO confirm in MambaModel.
# Create the head on the same device as the existing parameters; a bare
# nn.Linear would otherwise land on the default device.
_model_device = next(model.parameters()).device
model.linear_out = nn.Linear(model_kwargs["embed_dim"], 5).to(_model_device)
model.global_pool = True


In [10]:
# Train and evaluate on the three participant splits built above.
# Use the `labels` config variable rather than repeating SAT_CLASSES_ACCURACY,
# so changing the label set in the config cell propagates here.
train_and_test(
    model,
    train_data,
    test_data,
    val_data,
    logs_path=Path("../logs/"),
    workers=8,
    batch_size=64,
    labels=labels,
    lr=0.00001,
    # Disabled regularisation options kept for experimentation:
    # label_smoothing=0.1,
    # weight_decay=0.01,
    do_spectral_decoupling=False,
    # class_weights is passed but presumably ignored while use_class_weights=False.
    use_class_weights=False,
    class_weights=class_weights,
)
# TODO: Save probability distribution model and run/visualize on testset

  0%|          | 0/623 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           1       0.43      0.50      0.46      4099
           2       0.42      0.46      0.44      4099
           3       0.41      0.42      0.41      4043
           4       0.84      0.62      0.71      4099

    accuracy                           0.50     16340
   macro avg       0.53      0.50      0.51     16340
weighted avg       0.53      0.50      0.51     16340



  0%|          | 0/623 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           1       0.47      0.44      0.45      4099
           2       0.43      0.47      0.45      4099
           3       0.39      0.48      0.43      4043
           4       0.85      0.63      0.72      4099

    accuracy                           0.50     16340
   macro avg       0.53      0.50      0.51     16340
weighted avg       0.53      0.50      0.51     16340



  0%|          | 0/623 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           1       0.47      0.43      0.45      4099
           2       0.44      0.48      0.46      4099
           3       0.39      0.49      0.43      4043
           4       0.84      0.63      0.72      4099

    accuracy                           0.51     16340
   macro avg       0.54      0.51      0.51     16340
weighted avg       0.54      0.51      0.52     16340



  0%|          | 0/623 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           1       0.47      0.42      0.45      4099
           2       0.44      0.48      0.46      4099
           3       0.40      0.50      0.44      4043
           4       0.84      0.63      0.72      4099

    accuracy                           0.51     16340
   macro avg       0.54      0.51      0.52     16340
weighted avg       0.54      0.51      0.52     16340



  0%|          | 0/623 [00:00<?, ? batch/s]

              precision    recall  f1-score   support

           1       0.48      0.41      0.44      4099
           2       0.45      0.46      0.46      4099
           3       0.39      0.54      0.45      4043
           4       0.83      0.63      0.72      4099

    accuracy                           0.51     16340
   macro avg       0.54      0.51      0.52     16340
weighted avg       0.54      0.51      0.52     16340



  0%|          | 0/623 [00:00<?, ? batch/s]

KeyboardInterrupt: 