### Experiment: Sampling Frequency

**Question**: Does a model perform worse when it is trained on HMP data produced by an HMP model that was fitted at a lower sampling frequency (100 Hz instead of 500 Hz)?

**Hypothesis**: Yes, but the size of the drop depends on the model type. A CNN should be affected less than an RNN, because an RNN relies more heavily on the fine-grained temporal information that is lost at the lower sampling frequency.
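
To make "lost temporal information" concrete, the sketch below downsamples a toy two-channel signal from 500 Hz to 100 Hz by averaging non-overlapping blocks of five samples. This is an illustration only: block averaging is a crude low-pass filter chosen for simplicity, it is an assumption rather than how `split_stage_data_100hz.nc` was actually produced, and `downsample_mean` is a hypothetical helper.

```python
import numpy as np


def downsample_mean(x: np.ndarray, factor: int) -> np.ndarray:
    """Downsample along the last axis by averaging non-overlapping blocks.

    Averaging acts as a crude (boxcar) low-pass filter before decimation.
    Trailing samples that do not fill a complete block are dropped.
    """
    n_keep = x.shape[-1] // factor * factor
    blocks = x[..., :n_keep].reshape(*x.shape[:-1], n_keep // factor, factor)
    return blocks.mean(axis=-1)


fs_high, fs_low = 500, 100
factor = fs_high // fs_low  # 5: one 100 Hz sample per five 500 Hz samples
t = np.arange(fs_high) / fs_high  # 1 s of timestamps at 500 Hz

# Toy "EEG": channel 0 is a slow 10 Hz oscillation, channel 1 holds a short spike
signal = np.stack([np.sin(2 * np.pi * 10 * t), np.zeros_like(t)])
signal[1, 250:252] = 1.0  # two samples at 500 Hz = 4 ms

low = downsample_mean(signal, factor)
print(signal.shape, low.shape)  # (2, 500) (2, 100)
print(signal[1].max(), low[1].max())  # spike peak drops from 1.0 to 0.4
```

The 10 Hz oscillation survives because it is well below the 50 Hz Nyquist frequency of a 100 Hz signal, whereas the 4 ms spike is smeared into a single 10 ms sample at 40% of its original amplitude.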

**Result**:

In [1]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
import xarray as xr
from hmpai.utilities import print_results
from hmpai.pytorch.models import SAT1Base, SAT1Deep, SAT1GRU
from hmpai.pytorch.training import k_fold_cross_validate
from hmpai.normalization import norm_min1_to_1

n_folds = 25
logs_path = Path("../logs/exp_sampling_frequency/")
normalization_fn = norm_min1_to_1
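
The notebook normalizes the data with `norm_min1_to_1`, whose implementation is not shown here. Judging by its name, it presumably rescales values linearly into [-1, 1] via x' = 2(x - min) / (max - min) - 1. A minimal sketch of that min-max scaling, assuming one global minimum and maximum (the real function may instead use per-channel statistics, or statistics computed on the training folds only); `minmax_to_unit_range` is a hypothetical name:

```python
import numpy as np


def minmax_to_unit_range(x: np.ndarray) -> np.ndarray:
    """Linearly map x so that its minimum becomes -1 and its maximum becomes 1.

    Hypothetical stand-in for norm_min1_to_1 using one global min/max.
    NaN values (e.g. padding) are ignored when computing the range.
    Assumes x is not constant, otherwise the denominator is zero.
    """
    x_min, x_max = np.nanmin(x), np.nanmax(x)
    return 2 * (x - x_min) / (x_max - x_min) - 1


example = np.array([[0.0, 5.0, 10.0], [2.5, 7.5, 10.0]])
print(minmax_to_unit_range(example))
```

Here the minimum 0.0 maps to -1, the maximum 10.0 maps to 1, and 2.5 maps to -0.5.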

### Part 1: 500 Hz

In [2]:
data_path = Path("../data/sat1/split_stage_data.nc")
data = xr.load_dataset(data_path)

#### CNN

In [3]:
model = SAT1Deep
model_kwargs = {
    "n_channels": len(data.channels),
    "n_samples": len(data.samples),
    "n_classes": len(data.labels),
}
train_kwargs = {
    "logs_path": logs_path,
    "additional_info": {"sampling_frequency": "500hz_cnn"},
    "additional_name": "500hz_cnn",
}
results = k_fold_cross_validate(
    model,
    model_kwargs,
    data,
    n_folds,
    train_kwargs=train_kwargs,
    normalization_fn=normalization_fn,
)
print_results(results)

#### RNN (GRU)

In [None]:
model = SAT1GRU
model_kwargs = {
    "n_channels": len(data.channels),
    "n_samples": len(data.samples),
    "n_classes": len(data.labels),
}
train_kwargs = {
    "logs_path": logs_path,
    "additional_info": {"sampling_frequency": "500hz_rnn"},
    "additional_name": "500hz_rnn",
}
results = k_fold_cross_validate(
    model,
    model_kwargs,
    data,
    n_folds,
    train_kwargs=train_kwargs,
    normalization_fn=normalization_fn,
)
print_results(results)

### Part 2: 100 Hz

In [None]:
data_path = Path("../data/sat1/split_stage_data_100hz.nc")
data = xr.load_dataset(data_path)
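
Lowering the sampling rate changes how much time each model step covers. The arithmetic below is a sketch: the 5-sample kernel and the 1 s epoch are hypothetical values, not taken from `SAT1Base`, `SAT1Deep` or the dataset, and it assumes the GRU consumes one sample per time step.

```python
def span_ms(n_samples: int, fs_hz: float) -> float:
    """Duration covered by n_samples consecutive samples at fs_hz, in milliseconds."""
    return n_samples * 1000.0 / fs_hz


for fs in (500, 100):
    print(
        f"{fs} Hz: sample period {span_ms(1, fs):.0f} ms, "
        f"5-sample conv kernel spans {span_ms(5, fs):.0f} ms, "
        f"a 1 s epoch takes {fs} GRU steps"
    )
```

At 100 Hz a fixed-size convolution kernel therefore covers five times as much time per window, while a GRU processes five times fewer steps, each summarising 10 ms instead of 2 ms.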

#### CNN

In [None]:
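# Note: Part 1 uses SAT1Deep for the CNN. Use the same architecture in both parts
# so that sampling frequency is the only variable that differs between them.
model = SAT1Base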
model_kwargs = {
    "n_channels": len(data.channels),
    "n_samples": len(data.samples),
    "n_classes": len(data.labels),
}
train_kwargs = {
    "logs_path": logs_path,
    "additional_info": {"sampling_frequency": "100hz_cnn"},
    "additional_name": "100hz_cnn",
}
results = k_fold_cross_validate(
    model,
    model_kwargs,
    data,
    n_folds,
    train_kwargs=train_kwargs,
    normalization_fn=normalization_fn,
)
print_results(results)

#### RNN (GRU)

In [None]:
model = SAT1GRU
model_kwargs = {
    "n_channels": len(data.channels),
    "n_samples": len(data.samples),
    "n_classes": len(data.labels),
}
train_kwargs = {
    "logs_path": logs_path,
    "additional_info": {"sampling_frequency": "100hz_rnn"},
    "additional_name": "100hz_rnn",
}
results = k_fold_cross_validate(
    model,
    model_kwargs,
    data,
    n_folds,
    train_kwargs=train_kwargs,
    normalization_fn=normalization_fn,
)
print_results(results)
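
The four training cells differ only in the model class and the run tag. One way to reduce the duplication is a small helper that builds both keyword dictionaries. `build_kwargs` is a hypothetical name, it reproduces only the keys already used in this notebook, and the `SimpleNamespace` with made-up sizes merely stands in for the xarray Dataset so the helper can be tried without the data files.

```python
from pathlib import Path
from types import SimpleNamespace
from typing import Any


def build_kwargs(data: Any, tag: str, logs_path: Path) -> tuple[dict, dict]:
    """Return (model_kwargs, train_kwargs) for one sampling-frequency/model run.

    `data` only needs `channels`, `samples` and `labels` attributes that
    support len(), as the xarray Dataset loaded in this notebook does.
    """
    model_kwargs = {
        "n_channels": len(data.channels),
        "n_samples": len(data.samples),
        "n_classes": len(data.labels),
    }
    train_kwargs = {
        "logs_path": logs_path,
        "additional_info": {"sampling_frequency": tag},
        "additional_name": tag,
    }
    return model_kwargs, train_kwargs


# Made-up sizes standing in for a real dataset
fake_data = SimpleNamespace(channels=range(30), samples=range(100), labels=range(5))
model_kwargs, train_kwargs = build_kwargs(
    fake_data, "100hz_rnn", Path("../logs/exp_sampling_frequency/")
)
print(model_kwargs)  # {'n_channels': 30, 'n_samples': 100, 'n_classes': 5}
```

In the notebook it would be called as `model_kwargs, train_kwargs = build_kwargs(data, "100hz_rnn", logs_path)`, with both dictionaries then passed to `k_fold_cross_validate` exactly as in the cells above.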

In [None]:
# View results in TensorBoard (same directory as logs_path defined above)
!tensorboard --logdir ../logs/exp_sampling_frequency/