In [10]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
import xarray as xr
from hmpai.utilities import print_results
from hmpai.pytorch.models import *
from hmpai.pytorch.training import k_fold_cross_validate
from hmpai.normalization import *

# Number of cross-validation folds shared by every experiment in this notebook.
n_folds: int = 25

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Set up data

In [11]:
# Preprocessed SAT1 stage-split dataset, read fully into memory with xarray.
data_path = Path("../data") / "sat1" / "split_stage_data.nc"
data = xr.load_dataset(data_path)

In [12]:
# Directory that training logs for this RNN comparison are written to.
logs_path = Path("../logs") / "rnn_performance"

# Model constructor arguments, taken from the dimensions of the dataset loaded
# above. NOTE: these are snapshot values; if `data` is later reassigned to a
# different dataset they no longer describe it.
model_kwargs = dict(
    n_channels=len(data.channels),
    n_samples=len(data.samples),
    n_classes=len(data.labels),
)

# Input normalization applied during cross-validation
# (presumably rescales to [-1, 1], per its name — confirm in hmpai.normalization).
normalization_fn = norm_min1_to_1

In [None]:
# LSTM baseline: k-fold cross-validation on the preprocessed dataset.
model = SAT1LSTM
train_kwargs = dict(logs_path=logs_path, additional_name="LSTM")

results = k_fold_cross_validate(
    model,
    model_kwargs,
    data,
    n_folds,
    train_kwargs=train_kwargs,
    normalization_fn=normalization_fn,
)
print_results(results)

In [None]:
# GRU baseline: k-fold cross-validation on the preprocessed dataset.
model = SAT1GRU
train_kwargs = dict(logs_path=logs_path, additional_name="GRU")

results = k_fold_cross_validate(
    model,
    model_kwargs,
    data,
    n_folds,
    train_kwargs=train_kwargs,
    normalization_fn=normalization_fn,
)
print_results(results)

In [2]:
# Unprocessed SAT1 data (100 Hz, per the filename). This rebinds `data`,
# replacing the preprocessed dataset loaded earlier in the notebook.
data_path = Path("../data") / "sat1" / "split_stage_data_unprocessed_100hz.nc"
data = xr.load_dataset(data_path)

In [None]:
# Run unprocessed-100Hz data cell before use
# GRU trained on the unprocessed 100 Hz recordings.
# Fail loudly if the 100 Hz data cell was skipped; otherwise the preprocessed
# data would be trained and logged under the "unprocessed_100hz" name.
assert data_path.name == "split_stage_data_unprocessed_100hz.nc", (
    "`data` is not the unprocessed 100 Hz dataset; run the unprocessed-100Hz data cell first"
)

# The shared `model_kwargs` were computed from the preprocessed dataset. The
# unprocessed data can differ in sample count (and possibly channels), so the
# model dimensions are re-derived from the dataset actually being trained on.
# Any other keys in `model_kwargs` are kept unchanged.
gru_100hz_model_kwargs = {
    **model_kwargs,
    "n_channels": len(data.channels),
    "n_samples": len(data.samples),
    "n_classes": len(data.labels),
}

model = SAT1GRU
train_kwargs = {
    "logs_path": logs_path,
    "additional_name": "GRU_unprocessed_100hz",
}
results = k_fold_cross_validate(
    model,
    gru_100hz_model_kwargs,
    data,
    n_folds,
    normalization_fn=normalization_fn,
    train_kwargs=train_kwargs,
)
print_results(results)

In [2]:
# Unprocessed SAT1 data (500 Hz, per the filename). This rebinds `data`,
# replacing whichever dataset was loaded before.
data_path = Path("../data") / "sat1" / "split_stage_data_unprocessed_500hz.nc"
data = xr.load_dataset(data_path)

In [None]:
# Run unprocessed-500Hz data cell before use
# GRU trained on the unprocessed 500 Hz recordings.
# Fail loudly if the 500 Hz data cell was skipped; otherwise a different dataset
# would be trained and logged under the "unprocessed_500hz" name.
assert data_path.name == "split_stage_data_unprocessed_500hz.nc", (
    "`data` is not the unprocessed 500 Hz dataset; run the unprocessed-500Hz data cell first"
)

# The shared `model_kwargs` were computed from the preprocessed dataset. The
# 500 Hz data has a different sample count (and possibly channels), so the
# model dimensions are re-derived from the dataset actually being trained on.
# Any other keys in `model_kwargs` are kept unchanged.
gru_500hz_model_kwargs = {
    **model_kwargs,
    "n_channels": len(data.channels),
    "n_samples": len(data.samples),
    "n_classes": len(data.labels),
}

model = SAT1GRU
train_kwargs = {
    "logs_path": logs_path,
    "additional_name": "GRU_unprocessed_500hz",
}
results = k_fold_cross_validate(
    model,
    gru_500hz_model_kwargs,
    data,
    n_folds,
    normalization_fn=normalization_fn,
    train_kwargs=train_kwargs,
)
print_results(results)