In [1]:
%load_ext autoreload
%autoreload 2
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate
from hmpai.pytorch.utilities import DEVICE, set_global_seed
from hmpai.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results

In [4]:
# Seed all RNGs up front so the participant split below is reproducible.
set_global_seed(42)

# Preprocessed SAT1 stage data. An unprocessed 500 Hz variant is also available at
# "../data/sat1/split_stage_data_unprocessed_500hz.nc" — swap it in to compare.
data_path = Path("../data/sat1/split_stage_data.nc")
dataset = xr.load_dataset(data_path)

# NOTE(review): 60 is forwarded positionally to split_data_on_participants —
# presumably the share of participants assigned to training; confirm in hmpai.training.
train_data, val_data, test_data = split_data_on_participants(
    dataset,
    60,
    norm_dummy,
)

In [5]:
# Cross-validate SAT1Base on the full `dataset`; k_fold_cross_validate prints each
# fold's held-out group and metrics as it runs, then print_results summarises them.
model_kwargs = {
    "n_channels": len(dataset.channels),
    "n_samples": len(dataset.samples),
    "n_classes": len(dataset.labels),
}
# One dict for the training options, passed through as `train_kwargs` below
# (a separate unused copy previously shadowed this one).
train_kwargs = {
    "logs_path": Path("../logs/"),
    "additional_name": "CVtest",
}
n_folds = 5

results = k_fold_cross_validate(
    SAT1Base,
    model_kwargs,
    dataset,
    n_folds,
    batch_size=128,
    normalization_fn=norm_dummy,
    train_kwargs=train_kwargs,
)
print_results(results)

# Observed epoch time by batch size (reference for tuning `batch_size`):
#   128 -> 46 s/epoch
#    64 -> 51 s/epoch
#    32 -> 59 s/epoch
#    16 -> 1:20/epoch

Fold 1: test fold: ['0009' '0017' '0001' '0024' '0012']


  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

Fold 1: Accuracy: 0.8206388206388207
Fold 1: F1-Score: 0.817760061208966
Fold 2: test fold: ['0010' '0014' '0002' '0023' '0006']


  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

  0%|          | 0/125 [00:00<?, ? batch/s]

Fold 2: Accuracy: 0.8325219084712756
Fold 2: F1-Score: 0.8307538525782265
Fold 3: test fold: ['0003' '0013' '0016' '0004' '0005']


  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

Fold 3: Accuracy: 0.8104656985478217
Fold 3: F1-Score: 0.8110666703304543
Fold 4: test fold: ['0021' '0018' '0022' '0019' '0025']


  0%|          | 0/127 [00:00<?, ? batch/s]

  0%|          | 0/127 [00:00<?, ? batch/s]

  0%|          | 0/127 [00:00<?, ? batch/s]

  0%|          | 0/127 [00:00<?, ? batch/s]

  0%|          | 0/127 [00:00<?, ? batch/s]

  0%|          | 0/127 [00:00<?, ? batch/s]

  0%|          | 0/127 [00:00<?, ? batch/s]

  0%|          | 0/127 [00:00<?, ? batch/s]

Fold 4: Accuracy: 0.8399896265560166
Fold 4: F1-Score: 0.8426208995169441
Fold 5: test fold: ['0008' '0011' '0015' '0020' '0007']


  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

  0%|          | 0/126 [00:00<?, ? batch/s]

Fold 5: Accuracy: 0.8147315855181023
Fold 5: F1-Score: 0.8128425057588011
Average Accuracy: 0.8236695279464075
Average F1-Score: 0.8230087978786784


In [7]:
# Train a single SAT1Base model on the train/val/test participant split made above.
n_channels = len(train_data.channels)
n_samples = len(train_data.samples)
n_classes = len(train_data.labels)
model = SAT1Base(n_channels, n_samples, n_classes)

# NOTE(review): data is passed as (train, test, val) — confirm this matches the
# parameter order of hmpai.pytorch.training.train_and_test.
train_and_test(
    model,
    train_data,
    test_data,
    val_data,
    logs_path=Path("../logs/"),
)

  0%|          | 0/758 [00:00<?, ? batch/s]

  0%|          | 0/758 [00:00<?, ? batch/s]

KeyboardInterrupt: 

In [None]:
# Layer-by-layer overview of `model` for a dummy input batch of 16.
# NOTE(review): the (batch, 1, 147, 30) shape is hard-coded; presumably 147 samples x 30
# channels for this dataset — confirm and derive it from `dataset` if so.
summary(model, input_size=(16, 1, 147, 30))