### Experiment: Generalization


**Question**: How well does each of the models (CNN, GRU, Transformer) generalize to other datasets?

**Hypothesis**: The models differ in how well they generalize, ranging from no generalization at all to above-chance performance on data they were not trained on.

**Result**:

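How to perform this?
- Take the model trained on 100 Hz SAT1
- Test it on the test set
- Test it on the entire set for SAT2 and AR:

```python
train_data, val_data, test_data = split_data_on_participants(
    dataset, 100, norm_min1_to_1
)
```

Note: the cells below run a condition-level variant of this procedure: each model is trained and cross-validated on the SAT2 accuracy trials and additionally tested on the SAT2 speed trials.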

In [1]:
%load_ext autoreload
%autoreload 2
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset
from hmpai.data import SAT1_STAGES_ACCURACY, SAT1_STAGES_SPEED
from hmpai.visualization import plot_confusion_matrix
from hmpai.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
# from braindecode.models.eegconformer import EEGConformer
from mne.io import read_info
import gc
import json

In [2]:
# Set the global random seed through the project helper so repeated runs are comparable.
set_global_seed(42)

# SAT1 (100 Hz) alternative, currently disabled:
# data_path = Path("../data/sat1/split_stage_data_100hz.nc")
# Active input: the SAT2 stage-split NetCDF file (relative path).
data_path = Path("../data/sat2/split_stage_data.nc")

# Load the complete file as an xarray Dataset.
dataset = xr.load_dataset(data_path)

In [3]:
# Split into accuracy ("acc") and speed subsets by matching on `event_name`.
# Use same participants for speed and acc (both subsets are derived from the same dataset).
# `drop=True` drops coordinate labels for which the condition is False everywhere.
dataset_acc = dataset.where(dataset.event_name.str.contains("accuracy"), drop=True)
dataset_sp = dataset.where(dataset.event_name.str.contains("speed"), drop=True)
# Release the combined dataset now that both subsets exist.
# NOTE: because of this, re-running this cell without re-running the loading cell
# raises a NameError on `dataset`.
del dataset

In [4]:
def test_generalization(
    model_fn, model_kwargs, data, additional_test_data, additional_train_kwargs=None
):
    """Cross-validate a model with one fold per participant and save the results.

    Runs ``k_fold_cross_validate`` with ``k`` equal to the number of entries in
    ``data.participant`` (presumably leave-one-participant-out), normalising
    with ``norm_min1_to_1``. The returned result is written to
    ``<logs_path>/results_<model name>.json`` and then printed with
    ``print_results``.

    Args:
        model_fn: Model class or factory; its ``__name__`` is used in log
            metadata and in the results file name.
        model_kwargs: Keyword arguments forwarded to ``k_fold_cross_validate``
            together with ``model_fn``; also stored as a string in the log
            metadata.
        data: Dataset to cross-validate on; must expose a ``participant``
            attribute supporting ``len()``.
        additional_test_data: Forwarded unchanged to ``k_fold_cross_validate``
            as ``additional_test_data``.
        additional_train_kwargs: Optional dict merged over the default training
            kwargs (``logs_path``, ``additional_info``, ``additional_name``,
            ``labels``); entries here take precedence.

    Returns:
        None. The results are persisted to disk and printed.
    """
    print(f"Testing model: {model_fn.__name__}")
    train_kwargs = {
        "logs_path": Path("../logs/exp_generalization_condition/"),
        "additional_info": {
            "model_fn": model_fn.__name__,
            "model_kwargs": str(model_kwargs),
        },
        "additional_name": f"model_fn-{model_fn.__name__}",
        "labels": SAT1_STAGES_ACCURACY,
    }
    if additional_train_kwargs is not None:
        train_kwargs.update(additional_train_kwargs)

    # Resolve and create the output directory BEFORE the long cross-validation
    # run: a missing directory would otherwise only surface when the results
    # are written, discarding the whole run. Wrapping in Path also accepts a
    # plain-string override of "logs_path".
    logs_path = Path(train_kwargs["logs_path"])
    logs_path.mkdir(parents=True, exist_ok=True)

    result = k_fold_cross_validate(
        model_fn,
        model_kwargs,
        data,
        k=len(data.participant),
        normalization_fn=norm_min1_to_1,
        train_kwargs=train_kwargs,
        additional_test_data=additional_test_data,
    )
    with open(logs_path / f"results_{model_fn.__name__}.json", "w") as f:
        json.dump(result, f, indent=4)
    print_results(result)

#### CNN

In [5]:
# CNN (SAT1Base): cross-validate on the accuracy-condition trials and pass the
# speed-condition trials as additional test data. The number of classes is
# taken from the dataset's `labels`; weight decay and label smoothing are
# forwarded as extra training kwargs.
test_generalization(
    SAT1Base,
    {"n_classes": len(dataset_acc.labels)},
    dataset_acc,
    additional_test_data=[dataset_sp],
    additional_train_kwargs={"weight_decay": 0.001, "label_smoothing": 0.0001},
)

Testing model: SAT1Base
Fold 1: test fold: ['S1']




  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 1: Accuracy: 0.8874900079936051
Fold 1: F1-Score: 0.8880845341864723
Fold 2: test fold: ['S10']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 2: Accuracy: 0.9106973209196276
Fold 2: F1-Score: 0.9099336844859641
Fold 3: test fold: ['S18']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 3: Accuracy: 0.9122304903644343
Fold 3: F1-Score: 0.9121378130583327
Fold 4: test fold: ['S15']




  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 4: Accuracy: 0.8814964983413196
Fold 4: F1-Score: 0.8822853048278854
Fold 5: test fold: ['S12']




  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 5: Accuracy: 0.8710398807305255
Fold 5: F1-Score: 0.8686345446523571
Fold 6: test fold: ['S5']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 6: Accuracy: 0.9053645636569702
Fold 6: F1-Score: 0.9060673578010231
Fold 7: test fold: ['S8']




  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 7: Accuracy: 0.9169625246548323
Fold 7: F1-Score: 0.9159497884059592
Fold 8: test fold: ['S7']




  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 8: Accuracy: 0.8639870077141697
Fold 8: F1-Score: 0.864444019030073
Fold 9: test fold: ['S3']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 9: Accuracy: 0.8706458294106572
Fold 9: F1-Score: 0.8707967180387142
Fold 10: test fold: ['S11']




  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 10: Accuracy: 0.873080204778157
Fold 10: F1-Score: 0.8738556790307236
Fold 11: test fold: ['S2']




  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 11: Accuracy: 0.8451939291736931
Fold 11: F1-Score: 0.8467775805356916
Fold 12: test fold: ['S9']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 12: Accuracy: 0.940154804606381
Fold 12: F1-Score: 0.9393356074011244
Fold 13: test fold: ['S13']




  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 13: Accuracy: 0.9337304542069993
Fold 13: F1-Score: 0.9339871139427658
Fold 14: test fold: ['S4']




  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 14: Accuracy: 0.9161142857142857
Fold 14: F1-Score: 0.9168522295728412
Fold 15: test fold: ['S17']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 15: Accuracy: 0.9455194439226
Fold 15: F1-Score: 0.9457067104153303
Fold 16: test fold: ['S20']




  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 16: Accuracy: 0.911491935483871
Fold 16: F1-Score: 0.9113122727704976
Fold 17: test fold: ['S6']




  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 17: Accuracy: 0.9189303011160245
Fold 17: F1-Score: 0.9195166981738137
Fold 18: test fold: ['S16']




  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

Fold 18: Accuracy: 0.8689155535431288
Fold 18: F1-Score: 0.869425349710372
Test set 0
Accuracies
[0.8874900079936051, 0.9106973209196276, 0.9122304903644343, 0.8814964983413196, 0.8710398807305255, 0.9053645636569702, 0.9169625246548323, 0.8639870077141697, 0.8706458294106572, 0.873080204778157, 0.8451939291736931, 0.940154804606381, 0.9337304542069993, 0.9161142857142857, 0.9455194439226, 0.911491935483871, 0.9189303011160245, 0.8689155535431288]
F1-Scores
[0.8880845341864723, 0.9099336844859641, 0.9121378130583327, 0.8822853048278854, 0.8686345446523571, 0.9060673578010231, 0.9159497884059592, 0.864444019030073, 0.8707967180387142, 0.8738556790307236, 0.8467775805356916, 0.9393356074011244, 0.9339871139427658, 0.9168522295728412, 0.9457067104153303, 0.9113122727704976, 0.9195166981738137, 0.869425349710372]
Average Accuracy: 0.8985025020184045, std: 0.028189059735665656
Average F1-Score: 0.8986168336688856, std: 0.027982892738533598
Test set 1
Accuracies
[0.9223626515346918, 0.930616

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


#### GRU

In [6]:
# GRU (SAT1GRU): cross-validate on the accuracy-condition trials and pass the
# speed-condition trials as additional test data. The model dimensions are
# taken from the dataset's `channels`, `samples` and `labels`; weight decay and
# label smoothing are forwarded as extra training kwargs.
test_generalization(
    SAT1GRU,
    {
        "n_channels": len(dataset_acc.channels),
        "n_samples": len(dataset_acc.samples),
        "n_classes": len(dataset_acc.labels),
    },
    dataset_acc,
    additional_test_data=[dataset_sp],
    additional_train_kwargs={"weight_decay": 0.001, "label_smoothing": 0.0001},
)

Testing model: SAT1GRU
Fold 1: test fold: ['S1']




  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 1: Accuracy: 0.9052757793764988
Fold 1: F1-Score: 0.9060584956254578
Fold 2: test fold: ['S10']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 2: Accuracy: 0.9120273608208246
Fold 2: F1-Score: 0.9114861083482728
Fold 3: test fold: ['S18']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 3: Accuracy: 0.9126120969280672
Fold 3: F1-Score: 0.9120515001977955
Fold 4: test fold: ['S15']




  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 4: Accuracy: 0.8931072613343163
Fold 4: F1-Score: 0.8951465468086797
Fold 5: test fold: ['S12']




  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 5: Accuracy: 0.8950801341781588
Fold 5: F1-Score: 0.8923201690002454
Fold 6: test fold: ['S5']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 6: Accuracy: 0.9191537589724216
Fold 6: F1-Score: 0.9198297170767173
Fold 7: test fold: ['S8']




  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 7: Accuracy: 0.9159763313609467
Fold 7: F1-Score: 0.9146843221512043
Fold 8: test fold: ['S7']




  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 8: Accuracy: 0.8893625659764515
Fold 8: F1-Score: 0.8901758755815458
Fold 9: test fold: ['S3']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 9: Accuracy: 0.8721521370739974
Fold 9: F1-Score: 0.8733514199053249
Fold 10: test fold: ['S11']




  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 10: Accuracy: 0.7391211604095563
Fold 10: F1-Score: 0.7402790517790672
Fold 11: test fold: ['S2']




  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 11: Accuracy: 0.8529510961214165
Fold 11: F1-Score: 0.8550089806067651
Fold 12: test fold: ['S9']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 12: Accuracy: 0.9248631300736265
Fold 12: F1-Score: 0.9235211797187455
Fold 13: test fold: ['S13']




  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 13: Accuracy: 0.9393149664929263
Fold 13: F1-Score: 0.9395014424886583
Fold 14: test fold: ['S4']




  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 14: Accuracy: 0.9309714285714286
Fold 14: F1-Score: 0.9311516938523223
Fold 15: test fold: ['S17']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 15: Accuracy: 0.9398835243283863
Fold 15: F1-Score: 0.9394466527515769
Fold 16: test fold: ['S20']




  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 16: Accuracy: 0.9159274193548387
Fold 16: F1-Score: 0.9151266313551906
Fold 17: test fold: ['S6']




  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 17: Accuracy: 0.9286165508528111
Fold 17: F1-Score: 0.9290876729284999
Fold 18: test fold: ['S16']




  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

Fold 18: Accuracy: 0.9021960606746661
Fold 18: F1-Score: 0.9027902102012744
Test set 0
Accuracies
[0.9052757793764988, 0.9120273608208246, 0.9126120969280672, 0.8931072613343163, 0.8950801341781588, 0.9191537589724216, 0.9159763313609467, 0.8893625659764515, 0.8721521370739974, 0.7391211604095563, 0.8529510961214165, 0.9248631300736265, 0.9393149664929263, 0.9309714285714286, 0.9398835243283863, 0.9159274193548387, 0.9286165508528111, 0.9021960606746661]
F1-Scores
[0.9060584956254578, 0.9114861083482728, 0.9120515001977955, 0.8951465468086797, 0.8923201690002454, 0.9198297170767173, 0.9146843221512043, 0.8901758755815458, 0.8733514199053249, 0.7402790517790672, 0.8550089806067651, 0.9235211797187455, 0.9395014424886583, 0.9311516938523223, 0.9394466527515769, 0.9151266313551906, 0.9290876729284999, 0.9027902102012744]
Average Accuracy: 0.8993662646056299, std: 0.04459296754765765
Average F1-Score: 0.89950098168763, std: 0.044136894244405575
Test set 1
Accuracies
[0.92004126902244, 0.92

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


#### Transformer

In [8]:
# Transformer run: dataset_acc (accuracy-condition events) is the primary data and
# dataset_sp (speed-condition events) is supplied as the additional test set.
test_generalization(
    model_fn=TransformerModel,
    model_kwargs={
        # Dimensions are read from the dataset so they follow whichever file was loaded.
        "n_features": len(dataset_acc.channels),
        "n_heads": 10,
        "ff_dim": 512,
        "n_layers": 6,
        "n_samples": len(dataset_acc.samples),
        "n_classes": len(dataset_acc.labels),
    },
    data=dataset_acc,
    additional_test_data=[dataset_sp],
    # Extra training options; presumably merged into the default train kwargs
    # inside test_generalization — confirm against its definition.
    additional_train_kwargs={"weight_decay": 0.001, "label_smoothing": 0.0001},
)

Testing model: TransformerModel
Fold 1: test fold: ['S1']




  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 1: Accuracy: 0.8677058353317346
Fold 1: F1-Score: 0.8689372900374132
Fold 2: test fold: ['S10']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 2: Accuracy: 0.9002470074102223
Fold 2: F1-Score: 0.9001277760402411
Fold 3: test fold: ['S18']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 3: Accuracy: 0.8784583094829231
Fold 3: F1-Score: 0.8781092890157616
Fold 4: test fold: ['S15']




  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 4: Accuracy: 0.8875783265757464
Fold 4: F1-Score: 0.8887251604606015
Fold 5: test fold: ['S12']




  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 5: Accuracy: 0.8317182258665673
Fold 5: F1-Score: 0.8241451596932088
Fold 6: test fold: ['S5']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 6: Accuracy: 0.8675859463543635
Fold 6: F1-Score: 0.8691268921871007
Fold 7: test fold: ['S8']




  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 7: Accuracy: 0.8593688362919132
Fold 7: F1-Score: 0.8569738175833346
Fold 8: test fold: ['S7']




  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  0%|          | 0/658 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 8: Accuracy: 0.8223710921640276
Fold 8: F1-Score: 0.8223208940280223
Fold 9: test fold: ['S3']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 9: Accuracy: 0.8514404067030691
Fold 9: F1-Score: 0.8516950330657533
Fold 10: test fold: ['S11']




  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  0%|          | 0/660 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 10: Accuracy: 0.6819539249146758
Fold 10: F1-Score: 0.6217590937764896
Fold 11: test fold: ['S2']




  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  0%|          | 0/673 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 11: Accuracy: 0.8263069139966274
Fold 11: F1-Score: 0.8280214236628518
Fold 12: test fold: ['S9']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 12: Accuracy: 0.8763451010005664
Fold 12: F1-Score: 0.8727961606486361
Fold 13: test fold: ['S13']




  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  0%|          | 0/654 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 13: Accuracy: 0.9180938198064036
Fold 13: F1-Score: 0.9188519055576905
Fold 14: test fold: ['S4']




  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 14: Accuracy: 0.9138285714285714
Fold 14: F1-Score: 0.9140924833061861
Fold 15: test fold: ['S17']




  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  0%|          | 0/655 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 15: Accuracy: 0.9177155739244787
Fold 15: F1-Score: 0.9176603329151355
Fold 16: test fold: ['S20']




  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  0%|          | 0/657 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 16: Accuracy: 0.8602822580645161
Fold 16: F1-Score: 0.8578540636618808
Fold 17: test fold: ['S6']




  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  0%|          | 0/659 [00:00<?, ? batch/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Fold 17: Accuracy: 0.8949252474205096
Fold 17: F1-Score: 0.8957739891746572
Fold 18: test fold: ['S16']




  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

  0%|          | 0/662 [00:00<?, ? batch/s]

Fold 18: Accuracy: 0.8632556033506905
Fold 18: F1-Score: 0.8638125038123426
Test set 0
Accuracies
[0.8677058353317346, 0.9002470074102223, 0.8784583094829231, 0.8875783265757464, 0.8317182258665673, 0.8675859463543635, 0.8593688362919132, 0.8223710921640276, 0.8514404067030691, 0.6819539249146758, 0.8263069139966274, 0.8763451010005664, 0.9180938198064036, 0.9138285714285714, 0.9177155739244787, 0.8602822580645161, 0.8949252474205096, 0.8632556033506905]
F1-Scores
[0.8689372900374132, 0.9001277760402411, 0.8781092890157616, 0.8887251604606015, 0.8241451596932088, 0.8691268921871007, 0.8569738175833346, 0.8223208940280223, 0.8516950330657533, 0.6217590937764896, 0.8280214236628518, 0.8727961606486361, 0.9188519055576905, 0.9140924833061861, 0.9176603329151355, 0.8578540636618808, 0.8957739891746572, 0.8638125038123426]
Average Accuracy: 0.8621767222270893, std: 0.05216370114336046
Average F1-Score: 0.8583768482570726, std: 0.0643863180399913
Test set 1
Accuracies
[0.8906370905339179, 0.

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
