In [1]:
%load_ext autoreload
%autoreload 2
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str
from hmpai.pytorch.generators import SAT1Dataset
from hmpai.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D
from torch.utils.data import DataLoader

### Load datasets

In [2]:
# Seed the global RNGs (hmpai helper) before anything stochastic runs.
set_global_seed(42)

# Preprocessed SAT1 stage data at 100 Hz. An unprocessed 500 Hz variant,
# "split_stage_data_unprocessed_500hz.nc", can be swapped in here instead.
data_path = Path("../data/sat1") / "split_stage_data_100hz.nc"

# xr.load_dataset reads the whole file into memory and closes it.
dataset = xr.load_dataset(data_path)

In [3]:
# Flat (channels x samples) inputs; set True to map electrodes onto the 2-D scalp grid.
shape_topological = False

# NOTE(review): 60 is passed straight to split_data_on_participants — presumably the
# train share / participant count; TODO confirm. norm_dummy is the normalization fn.
train_data, val_data, test_data = split_data_on_participants(dataset, 60, norm_dummy)

# Wrap each split (built in train, val, test order) as a PyTorch-facing dataset.
train_dataset, val_dataset, test_dataset = (
    SAT1Dataset(split, shape_topological=shape_topological)
    for split in (train_data, val_data, test_data)
)

#### LSTM (Single run)

In [None]:
# Model sizes (channels, samples, classes) are read from the training split.
model = SAT1LSTM(
    len(train_data.channels),
    len(train_data.samples),
    len(train_data.labels),
)

# One train/validate/test run on the participant split built above.
train_and_test(
    model,
    train_dataset,
    test_dataset,
    val_dataset,
    batch_size=128,
    workers=4,
    logs_path=Path("../logs/"),
)

#### GRU (Single run, with manual test/train selection)

In [3]:
# Manual participant hold-out run: HELD_OUT_PARTICIPANT is the test set and every
# other participant is used for training. Dedicated dataset names keep this cell from
# overwriting the train/val/test datasets built by the split cell, which later cells reuse.
HELD_OUT_PARTICIPANT = "0014"

model = SAT1GRU(len(dataset.channels), len(dataset.samples), len(dataset.labels))
loso_test_dataset = SAT1Dataset(dataset.sel(participant=[HELD_OUT_PARTICIPANT]))
loso_train_dataset = SAT1Dataset(dataset.drop_sel(participant=[HELD_OUT_PARTICIPANT]))

# NOTE(review): the held-out participant also serves as the validation set (4th
# argument). If train_and_test uses validation data for early stopping or model
# selection, the reported test scores are optimistic. Consider holding out a separate
# validation participant.
train_and_test(
    model,
    loso_train_dataset,
    loso_test_dataset,
    loso_test_dataset,
    logs_path=Path("../logs/"),
    workers=0,
    batch_size=128,
)



  0%|          | 0/152 [00:00<?, ? batch/s]

  0%|          | 0/152 [00:00<?, ? batch/s]

  0%|          | 0/152 [00:00<?, ? batch/s]

  0%|          | 0/152 [00:00<?, ? batch/s]

  0%|          | 0/152 [00:00<?, ? batch/s]

  0%|          | 0/152 [00:00<?, ? batch/s]

{'0': {'precision': 0.8043478260869565,
  'recall': 0.8705882352941177,
  'f1-score': 0.8361581920903954,
  'support': 170.0},
 '1': {'precision': 0.8974358974358975,
  'recall': 0.8235294117647058,
  'f1-score': 0.8588957055214724,
  'support': 170.0},
 '2': {'precision': 0.9408284023668639,
  'recall': 0.9352941176470588,
  'f1-score': 0.9380530973451328,
  'support': 170.0},
 '3': {'precision': 0.9010989010989011,
  'recall': 0.9213483146067416,
  'f1-score': 0.9111111111111112,
  'support': 89.0},
 '4': {'precision': 0.9230769230769231,
  'recall': 0.9176470588235294,
  'f1-score': 0.9203539823008848,
  'support': 170.0},
 'accuracy': 0.8907672301690507,
 'macro avg': {'precision': 0.8933575900131083,
  'recall': 0.8936814276272307,
  'f1-score': 0.8929144176737992,
  'support': 769.0},
 'weighted avg': {'precision': 0.8925421853343707,
  'recall': 0.8907672301690507,
  'f1-score': 0.8909977308488028,
  'support': 769.0}}

#### CNN (Deep, single run; intended for 500 Hz data, but the load cell currently reads the 100 Hz file)

In [5]:
# Deep CNN sized from the training split: channels, samples, classes.
model = SAT1Deep(
    len(train_data.channels),
    len(train_data.samples),
    len(train_data.labels),
)

# NOTE(review): unlike the LSTM/GRU runs, no batch_size is passed here, so
# train_and_test's default batch size applies.
train_and_test(
    model,
    train_dataset,
    test_dataset,
    val_dataset,
    workers=4,
    logs_path=Path("../logs/"),
)

  0%|          | 0/97 [00:00<?, ? batch/s]

  0%|          | 0/97 [00:00<?, ? batch/s]

  0%|          | 0/97 [00:00<?, ? batch/s]

  0%|          | 0/97 [00:00<?, ? batch/s]

  0%|          | 0/97 [00:00<?, ? batch/s]

KeyboardInterrupt: 

#### CNN (Deep, topological input, single run; intended for 500 Hz data)

In [None]:
# The topological conv model presumably consumes electrodes laid out on the 2-D
# CHANNELS_2D grid. The shared train/val/test datasets from the split cell are built
# with the global `shape_topological` flag (False by default), so build topological
# datasets explicitly here instead of depending on that cell being hand-edited.
height, width = CHANNELS_2D.shape
topo_train_dataset, topo_val_dataset, topo_test_dataset = (
    SAT1Dataset(split, shape_topological=True)
    for split in (train_data, val_data, test_data)
)

model = SAT1TopologicalConv(
    width, height, len(train_data.samples), len(train_data.labels)
)
train_and_test(
    model,
    topo_train_dataset,
    topo_test_dataset,
    topo_val_dataset,
    logs_path=Path("../logs/"),
    workers=4,
)

#### CNN K-Fold CV (Deep)

In [None]:
# 25-fold cross-validation of the deep CNN over the full dataset. Model sizes come
# from the whole dataset because k_fold_cross_validate builds the folds itself.
model_kwargs = {
    "n_channels": len(dataset.channels),
    "n_samples": len(dataset.samples),
    "n_classes": len(dataset.labels),
}
# Forwarded to training: logs go to ../logs/CNN_performance, tagged "CNN_DEEP".
train_kwargs = {
    "logs_path": Path("../logs/CNN_performance"),
    "additional_name": "CNN_DEEP",
}
results = k_fold_cross_validate(
    SAT1Deep,
    model_kwargs,
    dataset,
    k=25,
    batch_size=128,
    normalization_fn=norm_dummy,
    train_kwargs=train_kwargs,
)
print_results(results)

#### GRU K-Fold CV

In [None]:
# 25-fold cross-validation of the GRU over the full dataset. Model sizes come from
# the whole dataset because k_fold_cross_validate builds the folds itself.
model_kwargs = {
    "n_channels": len(dataset.channels),
    "n_samples": len(dataset.samples),
    "n_classes": len(dataset.labels),
}
# Forwarded to training: logs go to ../logs/GRU_performance, tagged "GRU".
train_kwargs = {
    "logs_path": Path("../logs/GRU_performance"),
    "additional_name": "GRU",
}
results = k_fold_cross_validate(
    SAT1GRU,
    model_kwargs,
    dataset,
    k=25,
    batch_size=128,
    normalization_fn=norm_dummy,
    train_kwargs=train_kwargs,
)
print_results(results)