### Experiment: Generalization


**Question**: How well does each of the models (CNN, GRU, Transformer) generalize to other datasets?

**Hypothesis**: The models differ in how well they generalize, ranging from no generalization at all (chance-level performance) to clearly above-chance performance.

**Result**:

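**Method**:
- Take a model trained on 100 Hz SAT1.
- Test it on the held-out test set (here: the left-out participant of each leave-one-participant-out fold).
- Test it on the entire SAT2 and AR datasets, obtained with a 100% participant split (all participants end up in the first returned subset):

```python
train_data, val_data, test_data = split_data_on_participants(
    dataset, 100, norm_min1_to_1
)
```

_Note_: in the run below the roles of SAT1 and SAT2 are swapped. The model is trained on SAT2, and SAT1 is loaded as the additional test set (stored in the `*_sat2` variables).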

In [1]:
%load_ext autoreload
%autoreload 2
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset
from hmpai.data import SAT1_STAGES_ACCURACY, SAT1_STAGES_SPEED
from hmpai.visualization import plot_confusion_matrix
from hmpai.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
# from braindecode.models.eegconformer import EEGConformer
from mne.io import read_info
import gc
import json

In [2]:
# Seed every RNG before anything stochastic happens, so folds/initialisation are reproducible.
SEED = 42
set_global_seed(SEED)

# Training data for this run is SAT2. The SAT1 alternative is
# Path("../data/sat1/split_stage_data_100hz.nc"), which is loaded as a test set further down.
data_path = Path("../data/sat2/split_stage_data.nc")
dataset = xr.load_dataset(data_path)

In [3]:
# Truncate to SAT1 length. xarray's label-based `sel` slice includes its end point, so this
# keeps labels 0..SAT1_N_SAMPLES-1 (161 samples, assuming `samples` is labelled 0, 1, 2, ...).
SAT1_N_SAMPLES = 161
dataset = dataset.sel(samples=slice(0, SAT1_N_SAMPLES - 1))

In [4]:
# NOTE: the `_sat2` names are historical. The training set (`dataset`) is currently SAT2,
# so this cell loads SAT1 (100 Hz) as the cross-dataset test set. The names are kept
# because the model cells below pass `test_dataset_sat2`.
# data_path_sat2 = Path("../data/sat2/split_stage_data.nc")
data_path_sat2 = Path("../data/sat1/split_stage_data_100hz.nc")
dataset_sat2 = xr.load_dataset(data_path_sat2)

# Presumably `100` is the train percentage, so every participant lands in the first
# returned split, which is used here as a test set covering the whole dataset.
# Samples are truncated to the training length so input shapes match.
test_data_sat2, _, _ = split_data_on_participants(
    dataset_sat2,
    100,
    normalization_fn=norm_min1_to_1,
    truncate_sample=len(dataset.samples),
)
test_dataset_sat2 = SAT1Dataset(test_data_sat2, labels=SAT1_STAGES_ACCURACY)

# Channel names of the AR recording, in the order the channels are stored in the file.
# They are assigned positionally below, so the order matters.
# TODO confirm against the AR montage, and check whether `AR_SAT1_CHANNELS`
# (imported at the top) is meant to replace this list.
AR_CHANNEL_NAMES = [
    "P8",
    "CP2",
    "P7",
    "FC1",
    "FCz",
    "P4",
    "T8",
    "F7",
    "CP5",
    "T7",
    "Fp2",
    "P3",
    "O1",
    "FC2",
    "FC6",
    "CPz",
    "Fp1",
    "CP1",
    "C3",
    "Cz",
    "F4",
    "F8",
    "CP6",
    "O2",
    "C4",
    "F3",
    "Pz",
    "AFz",
    "Fz",
    "FC5",
]

data_path_ar = Path("../data/ar/split_stage_data_new.nc")
dataset_ar = xr.load_dataset(data_path_ar)

test_data_ar, _, _ = split_data_on_participants(
    dataset_ar,
    100,
    normalization_fn=norm_min1_to_1,
    truncate_sample=len(dataset.samples),
)
# Subset only labels that exist in SAT1_STAGES_SPEED, these are the stages that exist in both SAT and AR experiments
test_data_ar = test_data_ar.sel(labels=SAT1_STAGES_SPEED)
test_data_ar = test_data_ar.assign_coords(channels=AR_CHANNEL_NAMES)

# `reindex` puts the AR channels into the training order, but it silently fills any
# training channel absent from AR with NaN. Report that instead of feeding NaNs unnoticed.
missing_ar_channels = sorted(
    str(ch) for ch in set(dataset.channels.values) - set(AR_CHANNEL_NAMES)
)
if missing_ar_channels:
    print(
        f"Warning: AR data lacks channels {missing_ar_channels}; "
        "they will be NaN after reindexing"
    )
test_data_ar = test_data_ar.reindex({"channels": dataset.channels.values}, copy=False)
# Accuracy since model is trained on SAT1_STAGES_ACCURACY
test_dataset_ar = SAT1Dataset(test_data_ar, labels=SAT1_STAGES_ACCURACY)

# Free the raw datasets; only the wrapped SAT1Dataset objects are needed from here on.
del test_data_ar, test_data_sat2, dataset_ar, dataset_sat2
gc.collect();  # trailing `;` keeps the collected-object count out of the cell output

63

In [5]:
def test_generalization(
    model_fn, model_kwargs, data, additional_test_data, additional_train_kwargs=None
):
    """Run leave-one-participant-out cross-validation plus cross-dataset evaluation.

    Calls ``k_fold_cross_validate`` with one fold per participant in ``data``.
    Each fold's held-out participant is the fold's test set, and
    ``additional_test_data`` is passed through so the extra datasets are also
    evaluated. The returned results are written as JSON to
    ``<logs_path>/results_<model_fn.__name__>.json`` and then printed with
    ``print_results``.

    Args:
        model_fn: Model class/factory. Its ``__name__`` labels the logs and results file.
        model_kwargs: Model keyword arguments forwarded to ``k_fold_cross_validate``
            (presumably used to construct ``model_fn``). Also logged as a string.
        data: Dataset with a ``participant`` dimension; its length sets the fold count.
        additional_test_data: Extra datasets forwarded as ``additional_test_data``.
        additional_train_kwargs: Optional overrides merged over the default training
            kwargs (e.g. ``weight_decay``, ``label_smoothing``, or ``logs_path``).
    """
    model_name = model_fn.__name__
    print(f"Testing model: {model_name}")
    train_kwargs = {
        "logs_path": Path("../logs/exp_generalization_datasets/"),
        "additional_info": {
            "model_fn": model_name,
            "model_kwargs": str(model_kwargs),
        },
        "additional_name": f"model_fn-{model_name}",
        "labels": SAT1_STAGES_ACCURACY,
    }
    if additional_train_kwargs is not None:
        train_kwargs.update(additional_train_kwargs)
    result = k_fold_cross_validate(
        model_fn,
        model_kwargs,
        data,
        k=len(data.participant),
        normalization_fn=norm_min1_to_1,
        train_kwargs=train_kwargs,
        additional_test_data=additional_test_data,
    )
    # `Path(...)` accepts a str override of `logs_path`. Creating the directory up front
    # avoids a FileNotFoundError here, after the whole cross-validation has already run.
    results_path = Path(train_kwargs["logs_path"]) / f"results_{model_name}.json"
    results_path.parent.mkdir(parents=True, exist_ok=True)
    with open(results_path, "w") as f:
        json.dump(result, f, indent=4)
    print_results(result)

#### CNN

In [6]:
# CNN baseline: cross-validate on the training set and evaluate on both cross-dataset test sets.
cnn_kwargs = {"n_classes": len(dataset.labels)}
test_generalization(
    model_fn=SAT1Base,
    model_kwargs=cnn_kwargs,
    data=dataset,
    additional_test_data=[test_dataset_sat2, test_dataset_ar],
    additional_train_kwargs=dict(weight_decay=0.001, label_smoothing=0.0001),
)

Testing model: SAT1Base
Fold 1: test fold: ['S1']




  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 1: Accuracy: 0.8992230604661637
Fold 1: F1-Score: 0.8920526091137259
Fold 2: test fold: ['S10']




  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 2: Accuracy: 0.9028331584470094
Fold 2: F1-Score: 0.8977066297053504
Fold 3: test fold: ['S18']




  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 3: Accuracy: 0.8926952141057934
Fold 3: F1-Score: 0.8830962040050316
Fold 4: test fold: ['S15']




  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 4: Accuracy: 0.908563736149943
Fold 4: F1-Score: 0.9019085798705937
Fold 5: test fold: ['S12']




  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 5: Accuracy: 0.872933159635234
Fold 5: F1-Score: 0.8684208051490483
Fold 6: test fold: ['S5']




  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 6: Accuracy: 0.9127353715326989
Fold 6: F1-Score: 0.9042840407625959
Fold 7: test fold: ['S8']




  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 7: Accuracy: 0.9217947254252769
Fold 7: F1-Score: 0.9213972098923344
Fold 8: test fold: ['S7']




  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 8: Accuracy: 0.8825887743413516
Fold 8: F1-Score: 0.8725097345020949
Fold 9: test fold: ['S3']




  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 9: Accuracy: 0.8831168831168831
Fold 9: F1-Score: 0.8729636136447707
Fold 10: test fold: ['S11']




  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 10: Accuracy: 0.8958451906659078
Fold 10: F1-Score: 0.8891605323848586
Fold 11: test fold: ['S2']




  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 11: Accuracy: 0.8432796174960155
Fold 11: F1-Score: 0.8383056998424838
Fold 12: test fold: ['S9']




  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 12: Accuracy: 0.9324504670978339
Fold 12: F1-Score: 0.9342785713670466
Fold 13: test fold: ['S13']




  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 13: Accuracy: 0.935286445269389
Fold 13: F1-Score: 0.9330954385365551
Fold 14: test fold: ['S4']




  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 14: Accuracy: 0.9357774644276374
Fold 14: F1-Score: 0.9276113155191996
Fold 15: test fold: ['S17']




  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 15: Accuracy: 0.9428785767870539
Fold 15: F1-Score: 0.9386931509265581
Fold 16: test fold: ['S20']




  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 16: Accuracy: 0.8915853935132683
Fold 16: F1-Score: 0.892453413223557
Fold 17: test fold: ['S6']




  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 17: Accuracy: 0.9259380453752182
Fold 17: F1-Score: 0.9178252287032505
Fold 18: test fold: ['S16']




  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

Fold 18: Accuracy: 0.9097866419294991
Fold 18: F1-Score: 0.9023790363589388
Test set 0
Accuracies
[0.8992230604661637, 0.9028331584470094, 0.8926952141057934, 0.908563736149943, 0.872933159635234, 0.9127353715326989, 0.9217947254252769, 0.8825887743413516, 0.8831168831168831, 0.8958451906659078, 0.8432796174960155, 0.9324504670978339, 0.935286445269389, 0.9357774644276374, 0.9428785767870539, 0.8915853935132683, 0.9259380453752182, 0.9097866419294991]
F1-Scores
[0.8920526091137259, 0.8977066297053504, 0.8830962040050316, 0.9019085798705937, 0.8684208051490483, 0.9042840407625959, 0.9213972098923344, 0.8725097345020949, 0.8729636136447707, 0.8891605323848586, 0.8383056998424838, 0.9342785713670466, 0.9330954385365551, 0.9276113155191996, 0.9386931509265581, 0.892453413223557, 0.9178252287032505, 0.9023790363589388]
Average Accuracy: 0.9049617736545655, std: 0.02487835078144667
Average F1-Score: 0.8993412118615554, std: 0.02597436025898659
Test set 1
Accuracies
[0.5768115942028985, 0.581

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


#### GRU

In [7]:
# GRU: same protocol as the CNN. The model needs the input shape as well as the class count.
gru_kwargs = dict(
    n_channels=len(dataset.channels),
    n_samples=len(dataset.samples),
    n_classes=len(dataset.labels),
)
test_generalization(
    model_fn=SAT1GRU,
    model_kwargs=gru_kwargs,
    data=dataset,
    additional_test_data=[test_dataset_sat2, test_dataset_ar],
    additional_train_kwargs=dict(weight_decay=0.001, label_smoothing=0.0001),
)

Testing model: SAT1GRU
Fold 1: test fold: ['S1']




  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 1: Accuracy: 0.91690124985925
Fold 1: F1-Score: 0.9105951784421353
Fold 2: test fold: ['S10']




  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 2: Accuracy: 0.9055613850996852
Fold 2: F1-Score: 0.9015411158698674
Fold 3: test fold: ['S18']




  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 3: Accuracy: 0.9048866498740554
Fold 3: F1-Score: 0.8984857535157594
Fold 4: test fold: ['S15']




  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 4: Accuracy: 0.9104276690483587
Fold 4: F1-Score: 0.9050528510268411
Fold 5: test fold: ['S12']




  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 5: Accuracy: 0.8859605170858803
Fold 5: F1-Score: 0.8803987149508155
Fold 6: test fold: ['S5']




  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 6: Accuracy: 0.9266045758250658
Fold 6: F1-Score: 0.9207644036306473
Fold 7: test fold: ['S8']




  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 7: Accuracy: 0.9310423564333828
Fold 7: F1-Score: 0.9318816519124189
Fold 8: test fold: ['S7']




  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 8: Accuracy: 0.9008018327605957
Fold 8: F1-Score: 0.8933654231196746
Fold 9: test fold: ['S3']




  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 9: Accuracy: 0.8931948051948052
Fold 9: F1-Score: 0.8848837095571704
Fold 10: test fold: ['S11']




  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 10: Accuracy: 0.8908366533864542
Fold 10: F1-Score: 0.8862439281728165
Fold 11: test fold: ['S2']




  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 11: Accuracy: 0.8523109615725164
Fold 11: F1-Score: 0.8505569227417322
Fold 12: test fold: ['S9']




  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 12: Accuracy: 0.9279334770557438
Fold 12: F1-Score: 0.9299048146009993
Fold 13: test fold: ['S13']




  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 13: Accuracy: 0.9386977024179793
Fold 13: F1-Score: 0.9391194562940663
Fold 14: test fold: ['S4']




  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 14: Accuracy: 0.9387258043840533
Fold 14: F1-Score: 0.9320719504920707
Fold 15: test fold: ['S17']




  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 15: Accuracy: 0.9408423534455043
Fold 15: F1-Score: 0.9376737319371602
Fold 16: test fold: ['S20']




  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 16: Accuracy: 0.9002041279201634
Fold 16: F1-Score: 0.9024602742991913
Fold 17: test fold: ['S6']




  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 17: Accuracy: 0.9261561954624782
Fold 17: F1-Score: 0.9207259921513661
Fold 18: test fold: ['S16']




  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

Fold 18: Accuracy: 0.9203385899814471
Fold 18: F1-Score: 0.9152851606386978
Test set 0
Accuracies
[0.91690124985925, 0.9055613850996852, 0.9048866498740554, 0.9104276690483587, 0.8859605170858803, 0.9266045758250658, 0.9310423564333828, 0.9008018327605957, 0.8931948051948052, 0.8908366533864542, 0.8523109615725164, 0.9279334770557438, 0.9386977024179793, 0.9387258043840533, 0.9408423534455043, 0.9002041279201634, 0.9261561954624782, 0.9203385899814471]
F1-Scores
[0.9105951784421353, 0.9015411158698674, 0.8984857535157594, 0.9050528510268411, 0.8803987149508155, 0.9207644036306473, 0.9318816519124189, 0.8933654231196746, 0.8848837095571704, 0.8862439281728165, 0.8505569227417322, 0.9299048146009993, 0.9391194562940663, 0.9320719504920707, 0.9376737319371602, 0.9024602742991913, 0.9207259921513661, 0.9152851606386978]
Average Accuracy: 0.9117459392670789, std: 0.022089359134384673
Average F1-Score: 0.9078339462974129, std: 0.022752005720570886
Test set 1
Accuracies
[0.582117415868337, 0.

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


#### Transformer

In [13]:
# Transformer run of the generalization experiment. The fold log below shows
# one held-out participant per fold on `dataset` ("Test set 0"), followed by
# scores for each entry of `additional_test_data`.
# NOTE(review): in the setup cells `dataset` is loaded from ../data/sat2/...,
# while `test_dataset_sat2` is built from ../data/sat1/split_stage_data_100hz.nc.
# The names look swapped relative to their contents; confirm which direction
# (SAT1 -> SAT2 or SAT2 -> SAT1) this run is meant to measure.
test_generalization(
    TransformerModel,
    # Model constructor arguments. Input sizes are derived from the
    # (truncated) training dataset so the model matches its dimensions.
    dict(
        n_features=len(dataset.channels),
        n_heads=10,  # presumably must divide n_features (attention width) — TODO confirm
        ff_dim=512,
        n_layers=6,
        n_samples=len(dataset.samples),
        n_classes=len(dataset.labels),
    ),
    dataset,
    additional_test_data=[test_dataset_sat2, test_dataset_ar],
    # Extra options forwarded to training (regularization settings).
    additional_train_kwargs=dict(weight_decay=0.001, label_smoothing=0.0001),
)

Testing model: TransformerModel
Fold 1: test fold: ['S1']




  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  0%|          | 0/1203 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 1: Accuracy: 0.886386668167999
Fold 1: F1-Score: 0.8762339007490066
Fold 2: test fold: ['S10']




  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 2: Accuracy: 0.8782791185729276
Fold 2: F1-Score: 0.8742333953580277
Fold 3: test fold: ['S18']




  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 3: Accuracy: 0.8604534005037784
Fold 3: F1-Score: 0.8510037324305741
Fold 4: test fold: ['S15']




  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 4: Accuracy: 0.8906492699596148
Fold 4: F1-Score: 0.8832034387714248
Fold 5: test fold: ['S12']




  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 5: Accuracy: 0.8232287804389218
Fold 5: F1-Score: 0.814470388185503
Fold 6: test fold: ['S5']




  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  0%|          | 0/1196 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 6: Accuracy: 0.811399068637376
Fold 6: F1-Score: 0.7590760128640459
Fold 7: test fold: ['S8']




  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 7: Accuracy: 0.897362712638429
Fold 7: F1-Score: 0.8988219477363147
Fold 8: test fold: ['S7']




  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 8: Accuracy: 0.8560137457044673
Fold 8: F1-Score: 0.8418563708180802
Fold 9: test fold: ['S3']




  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  0%|          | 0/1198 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 9: Accuracy: 0.8671168831168831
Fold 9: F1-Score: 0.8554171908556129
Fold 10: test fold: ['S11']




  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 10: Accuracy: 0.8523619806488333
Fold 10: F1-Score: 0.8455959156275927
Fold 11: test fold: ['S2']




  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  0%|          | 0/1229 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 11: Accuracy: 0.8087480077917478
Fold 11: F1-Score: 0.8029170816530071
Fold 12: test fold: ['S9']




  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  0%|          | 0/1197 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 12: Accuracy: 0.8464223385689355
Fold 12: F1-Score: 0.8430289144040138
Fold 13: test fold: ['S13']




  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  0%|          | 0/1195 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 13: Accuracy: 0.9134142670813685
Fold 13: F1-Score: 0.9100934438445274
Fold 14: test fold: ['S4']




  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  0%|          | 0/1212 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 14: Accuracy: 0.9196256890142289
Fold 14: F1-Score: 0.910921671762388
Fold 15: test fold: ['S17']




  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  0%|          | 0/1200 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 15: Accuracy: 0.7912335226663809
Fold 15: F1-Score: 0.733414808319469
Fold 16: test fold: ['S20']




  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  0%|          | 0/1204 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 16: Accuracy: 0.8565434338852348
Fold 16: F1-Score: 0.8550632377024631
Fold 17: test fold: ['S6']




  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  0%|          | 0/1201 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 17: Accuracy: 0.9010689354275742
Fold 17: F1-Score: 0.8927993313221654
Fold 18: test fold: ['S16']




  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

  0%|          | 0/1205 [00:00<?, ? batch/s]

Fold 18: Accuracy: 0.8673469387755102
Fold 18: F1-Score: 0.8597433716168495
Test set 0
Accuracies
[0.886386668167999, 0.8782791185729276, 0.8604534005037784, 0.8906492699596148, 0.8232287804389218, 0.811399068637376, 0.897362712638429, 0.8560137457044673, 0.8671168831168831, 0.8523619806488333, 0.8087480077917478, 0.8464223385689355, 0.9134142670813685, 0.9196256890142289, 0.7912335226663809, 0.8565434338852348, 0.9010689354275742, 0.8673469387755102]
F1-Scores
[0.8762339007490066, 0.8742333953580277, 0.8510037324305741, 0.8832034387714248, 0.814470388185503, 0.7590760128640459, 0.8988219477363147, 0.8418563708180802, 0.8554171908556129, 0.8455959156275927, 0.8029170816530071, 0.8430289144040138, 0.9100934438445274, 0.910921671762388, 0.733414808319469, 0.8550632377024631, 0.8927993313221654, 0.8597433716168495]
Average Accuracy: 0.8626474867555672, std: 0.0355743844634577
Average F1-Score: 0.8504385641122814, std: 0.04686634977488418
Test set 1
Accuracies
[0.4095799557848194, 0.400049

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
