In [11]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import torch
import xarray as xr

from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate
from hmpai.pytorch.utilities import DEVICE, set_global_seed
from hmpai.pytorch.generators import SAT1Dataset
from hmpai.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D
from torch.utils.data import DataLoader

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Seed every RNG up front so the participant split and model init are repeatable.
set_global_seed(42)

# Alternative input (unprocessed, 500 Hz): ../data/sat1/split_stage_data_unprocessed_500hz.nc
data_path = Path("../data") / "sat1" / "split_stage_data.nc"
dataset = xr.load_dataset(data_path)

# NOTE(review): the meaning of the second argument (60) is defined by
# split_data_on_participants — confirm before changing it.
train_data, val_data, test_data = split_data_on_participants(
    dataset,
    60,
    norm_dummy,
)

In [None]:
# Default-layout (shape_topological not set) dataset for each split, built in
# train/val/test order. NOTE: the following cell rebinds these same names.
train_dataset, val_dataset, test_dataset = (
    SAT1Dataset(split) for split in (train_data, val_data, test_data)
)

In [3]:
# Topological-layout dataset for each split, built in train/val/test order.
# NOTE: this overwrites the names bound by the previous (default-layout) cell.
train_dataset, val_dataset, test_dataset = (
    SAT1Dataset(split, shape_topological=True)
    for split in (train_data, val_data, test_data)
)

In [None]:
# Release cached, unused CUDA memory held from earlier runs before building a new model.
torch.cuda.empty_cache()

# Build the default-layout (non-topological) splits locally. Both dataset cells above
# bind `train_dataset`/`val_dataset`/`test_dataset`, and under Run All the topological
# cell runs last, so relying on those globals would feed topological data to the GRU.
gru_train_dataset = SAT1Dataset(train_data)
gru_val_dataset = SAT1Dataset(val_data)
gru_test_dataset = SAT1Dataset(test_data)

model = SAT1GRU(
    len(train_data.channels), len(train_data.samples), len(train_data.labels)
)
# Set workers=0 when using debugger
train_and_test(
    model,
    gru_train_dataset,
    gru_test_dataset,
    gru_val_dataset,
    logs_path=Path("../logs/"),
    workers=0,
)

In [21]:
# Release cached, unused CUDA memory held from earlier runs before building a new model.
torch.cuda.empty_cache()

# Build the topological splits locally so this cell does not depend on which of the
# dataset cells above was executed last (both bind `train_dataset` & co.).
topo_train_dataset = SAT1Dataset(train_data, shape_topological=True)
topo_val_dataset = SAT1Dataset(val_data, shape_topological=True)
topo_test_dataset = SAT1Dataset(test_data, shape_topological=True)

height, width = CHANNELS_2D.shape
# NOTE(review): width is passed before height — confirm this matches the
# SAT1TopologicalConv constructor's argument order.
model = SAT1TopologicalConv(
    width, height, len(train_data.samples), len(train_data.labels)
)
# Set workers=0 when using debugger
train_and_test(
    model,
    topo_train_dataset,
    topo_test_dataset,
    topo_val_dataset,
    logs_path=Path("../logs/"),
    workers=4,
    batch_size=128,
)

  0%|          | 0/95 [00:00<?, ? batch/s]

  0%|          | 0/95 [00:00<?, ? batch/s]

  0%|          | 0/95 [00:00<?, ? batch/s]

  0%|          | 0/95 [00:00<?, ? batch/s]

  0%|          | 0/95 [00:00<?, ? batch/s]

{'0': {'precision': 0.7298701298701299,
  'recall': 0.6596244131455399,
  'f1-score': 0.6929716399506782,
  'support': 852.0},
 '1': {'precision': 0.7516629711751663,
  'recall': 0.7609427609427609,
  'f1-score': 0.7562744004461796,
  'support': 891.0},
 '2': {'precision': 0.8881578947368421,
  'recall': 0.7575757575757576,
  'f1-score': 0.8176862507571169,
  'support': 891.0},
 '3': {'precision': 0.8600917431192661,
  'recall': 0.8169934640522876,
  'f1-score': 0.8379888268156425,
  'support': 459.0},
 '4': {'precision': 0.735663082437276,
  'recall': 0.9214365881032548,
  'f1-score': 0.8181365221723966,
  'support': 891.0},
 'accuracy': 0.7808734939759037,
 'macro avg': {'precision': 0.793089164267736,
  'recall': 0.7833145967639201,
  'f1-score': 0.7846115280284028,
  'support': 3984.0},
 'weighted avg': {'precision': 0.7864426854217635,
  'recall': 0.7808734939759037,
  'f1-score': 0.7797207053775352,
  'support': 3984.0}}

In [None]:
# Cross-validate SAT1GRU over the full (unsplit) dataset.
# NOTE(review): assumed to be the fold count (positional argument after `dataset`) — confirm.
N_FOLDS = 25

model_kwargs = {
    "n_channels": len(dataset.channels),
    "n_samples": len(dataset.samples),
    "n_classes": len(dataset.labels),
}
# Single source of truth for the training options. Previously a separate
# `train_kwargs` with a different logs path was defined but never used.
train_kwargs = {
    "logs_path": Path("../logs/GRU_performance"),
    "additional_name": "GRU",
}
results = k_fold_cross_validate(
    SAT1GRU,
    model_kwargs,
    dataset,
    N_FOLDS,
    batch_size=128,
    normalization_fn=norm_dummy,
    train_kwargs=train_kwargs,
)
print_results(results)

In [None]:
# torchinfo summary of whichever model object is currently bound to `model`
# (i.e. the most recently executed model cell).
# NOTE(review): hard-coded input size — confirm it matches that model's expected input layout.
summary_input_size = (16, 1, 147, 30)
summary(model, summary_input_size)