### Experiment: Generalization


**Question**: How well does each model (CNN, GRU, Transformer), trained on SAT1, generalize to other datasets (SAT2 and AR)?

**Hypothesis**: The models differ in how well they generalize, ranging from no generalization at all (chance-level performance) to above-chance performance.

**Result**:

How to perform this:
- Take the model trained on the 100 Hz SAT1 data
- Evaluate it on the SAT1 test set
- Evaluate it on the entire SAT2 and AR datasets:
  ```python
  train_data, val_data, test_data = split_data_on_participants(
      dataset, 100, norm_min1_to_1
  )
  ```

In [2]:
%load_ext autoreload
%autoreload 2
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset
from hmpai.data import SAT1_STAGES_ACCURACY
from hmpai.visualization import plot_confusion_matrix
from hmpai.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
# from braindecode.models.eegconformer import EEGConformer
from mne.io import read_info

In [3]:
# Location of the 100 Hz SAT1 stage data used throughout this notebook.
data_path = Path("../data/sat1/split_stage_data_100hz.nc")

# Seed before any model/data work; intended to make runs reproducible.
set_global_seed(42)

# xr.load_dataset reads the whole file into memory and closes it.
dataset = xr.load_dataset(data_path)

In [None]:
def test_model(
    model_fn,
    model_kwargs,
    data,
    *,
    k=25,
    logs_path=Path("../logs/exp_performance/"),
):
    """Cross-validate one model configuration, print the results and return them.

    Args:
        model_fn: Model class/factory handed to ``k_fold_cross_validate``; its
            ``__name__`` is used in the printed header and in the log metadata.
        model_kwargs: Keyword arguments handed to ``k_fold_cross_validate``
            together with ``model_fn`` (and recorded in the log metadata).
        data: Data to cross-validate on; presumably an ``xr.Dataset`` such as
            the SAT1 dataset loaded in this notebook (confirm).
        k: Number of cross-validation folds. Defaults to 25, the value that
            was previously hard-coded.
        logs_path: Directory forwarded to training as ``logs_path``. Defaults
            to the previously hard-coded ``../logs/exp_performance/``.

    Returns:
        The result of ``k_fold_cross_validate`` (the same object that is
        passed to ``print_results``).
    """
    print(f"Testing model: {model_fn.__name__}")
    train_kwargs = {
        "logs_path": logs_path,
        # Metadata stored alongside the logs so runs can be told apart.
        "additional_info": {
            "model_fn": model_fn.__name__,
            "model_kwargs": model_kwargs,
        },
        "additional_name": f"model_fn-{model_fn.__name__}",
        "labels": SAT1_STAGES_ACCURACY,
    }
    result = k_fold_cross_validate(
        model_fn,
        model_kwargs,
        data,
        k=k,
        normalization_fn=norm_min1_to_1,
        train_kwargs=train_kwargs,
    )
    print_results(result)
    # Return the result so callers can inspect it instead of only seeing printout.
    return result

#### CNN

In [None]:
# test_model requires (model_fn, model_kwargs, data); pass the dataset explicitly.
test_model(SAT1Base, {"n_classes": len(dataset.labels)}, dataset)

#### GRU

In [None]:
# test_model requires (model_fn, model_kwargs, data); pass the dataset explicitly.
test_model(
    SAT1GRU,
    {
        "n_channels": len(dataset.channels),
        "n_samples": len(dataset.samples),
        "n_classes": len(dataset.labels),
    },
    dataset,
)

#### Transformer

In [None]:
# TODO: choose the transformer model class and its kwargs, then call
# test_model(<TransformerModel>, {...}, dataset).
# A bare test_model() call always raised TypeError (model_fn, model_kwargs and
# data are required), so fail with an explicit message until this is filled in.
raise NotImplementedError(
    "Transformer model for the generalization experiment has not been specified yet"
)