In [1]:
%load_ext autoreload
%autoreload 2
from hmpai.pytorch.generators import SAT1DataLoader
import xarray as xr
from pathlib import Path
import torch
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate
from hmpai.pytorch.utilities import DEVICE, set_global_seed
from hmpai.normalization import *
import random
import numpy as np
from torchinfo import summary
from tqdm.notebook import tqdm, trange
from hmpai.utilities import print_results

In [2]:
# Fix every RNG up front so the split and training below are reproducible.
set_global_seed(42)

# Load the SAT1 stage-split dataset fully into memory.
data_path = Path("../data") / "sat1" / "split_stage_data.nc"
dataset = xr.load_dataset(data_path)

# Split by participant into train / validation / test sets.
# NOTE(review): the meaning of `60` (e.g. percentage of participants used for
# training) is defined in hmpai.training.split_data_on_participants — confirm.
train_data, val_data, test_data = split_data_on_participants(
    dataset, 60, norm_dummy
)

In [3]:
# Participant-level k-fold cross-validation of the SAT1Base CNN.
N_FOLDS = 5

# Model size parameters, taken from the dataset's dimensions.
model_kwargs = {
    "n_channels": len(dataset.channels),
    "n_samples": len(dataset.samples),
    "n_classes": len(dataset.labels),
}

# Training options forwarded to every fold. This is the single definition of
# the options actually used; an unused `train_kwargs` pointing at a different
# log directory would misdocument where results are written.
train_kwargs = {
    "logs_path": Path("../logs/cnn_performance"),
    "additional_name": "CNN",
}

results = k_fold_cross_validate(
    SAT1Base,
    model_kwargs,
    dataset,
    N_FOLDS,
    normalization_fn=norm_dummy,
    train_kwargs=train_kwargs,
)
print_results(results)

Fold 1: test fold: ['0009' '0017' '0001' '0024' '0012']


  0%|          | 0/997 [00:00<?, ? batch/s]

  0%|          | 0/997 [00:00<?, ? batch/s]

  0%|          | 0/997 [00:00<?, ? batch/s]

  0%|          | 0/997 [00:00<?, ? batch/s]

  0%|          | 0/997 [00:00<?, ? batch/s]

  0%|          | 0/997 [00:00<?, ? batch/s]

  0%|          | 0/997 [00:00<?, ? batch/s]

  0%|          | 0/997 [00:00<?, ? batch/s]

  0%|          | 0/997 [00:00<?, ? batch/s]

  0%|          | 0/997 [00:00<?, ? batch/s]

  0%|          | 0/997 [00:00<?, ? batch/s]

  0%|          | 0/997 [00:00<?, ? batch/s]

  0%|          | 0/997 [00:00<?, ? batch/s]

Fold 1: Accuracy: 0.8105314960629921
Fold 1: F1-Score: 0.8091714533539449
Fold 2: test fold: ['0010' '0014' '0002' '0023' '0006']


  0%|          | 0/995 [00:00<?, ? batch/s]

  0%|          | 0/995 [00:00<?, ? batch/s]

  0%|          | 0/995 [00:00<?, ? batch/s]

  0%|          | 0/995 [00:00<?, ? batch/s]

  0%|          | 0/995 [00:00<?, ? batch/s]

  0%|          | 0/995 [00:00<?, ? batch/s]

  0%|          | 0/995 [00:00<?, ? batch/s]

  0%|          | 0/995 [00:00<?, ? batch/s]

Fold 2: Accuracy: 0.814697265625
Fold 2: F1-Score: 0.8140069794331282
Fold 3: test fold: ['0003' '0013' '0016' '0004' '0005']


  0%|          | 0/1002 [00:00<?, ? batch/s]

  0%|          | 0/1002 [00:00<?, ? batch/s]

  0%|          | 0/1002 [00:00<?, ? batch/s]

  0%|          | 0/1002 [00:00<?, ? batch/s]

  0%|          | 0/1002 [00:00<?, ? batch/s]

  0%|          | 0/1002 [00:00<?, ? batch/s]

Fold 3: Accuracy: 0.8024598393574297
Fold 3: F1-Score: 0.8055564202536625
Fold 4: test fold: ['0021' '0018' '0022' '0019' '0025']


  0%|          | 0/1011 [00:00<?, ? batch/s]

  0%|          | 0/1011 [00:00<?, ? batch/s]

  0%|          | 0/1011 [00:00<?, ? batch/s]

  0%|          | 0/1011 [00:00<?, ? batch/s]

  0%|          | 0/1011 [00:00<?, ? batch/s]

  0%|          | 0/1011 [00:00<?, ? batch/s]

  0%|          | 0/1011 [00:00<?, ? batch/s]

Fold 4: Accuracy: 0.8257261410788381
Fold 4: F1-Score: 0.8287114318793218
Fold 5: test fold: ['0008' '0011' '0015' '0020' '0007']


  0%|          | 0/1001 [00:00<?, ? batch/s]

  0%|          | 0/1001 [00:00<?, ? batch/s]

  0%|          | 0/1001 [00:00<?, ? batch/s]

  0%|          | 0/1001 [00:00<?, ? batch/s]

  0%|          | 0/1001 [00:00<?, ? batch/s]

  0%|          | 0/1001 [00:00<?, ? batch/s]

  0%|          | 0/1001 [00:00<?, ? batch/s]

  0%|          | 0/1001 [00:00<?, ? batch/s]

Fold 5: Accuracy: 0.8215
Fold 5: F1-Score: 0.8224701977486821
Average Accuracy: 0.8149829484248521
Average F1-Score: 0.8159832965337479


In [16]:
train_and_test(model, train_data, test_data, val_data, logs_path=Path("../logs/"))

  0%|          | 0/757 [00:00<?, ? batch/s]

  0%|          | 0/757 [00:00<?, ? batch/s]

  0%|          | 0/757 [00:00<?, ? batch/s]

  0%|          | 0/757 [00:00<?, ? batch/s]

  0%|          | 0/757 [00:00<?, ? batch/s]

  0%|          | 0/757 [00:00<?, ? batch/s]

  0%|          | 0/757 [00:00<?, ? batch/s]

  0%|          | 0/757 [00:00<?, ? batch/s]

  0%|          | 0/757 [00:00<?, ? batch/s]

  0%|          | 0/757 [00:00<?, ? batch/s]

  0%|          | 0/757 [00:00<?, ? batch/s]

  0%|          | 0/757 [00:00<?, ? batch/s]

{'confirmation': {'precision': 0.8700440528634361,
  'recall': 0.8605664488017429,
  'f1-score': 0.8652792990142387,
  'support': 459.0},
 'decision': {'precision': 0.8762254901960784,
  'recall': 0.8024691358024691,
  'f1-score': 0.837727006444054,
  'support': 891.0},
 'encoding': {'precision': 0.8008948545861297,
  'recall': 0.8035914702581369,
  'f1-score': 0.8022408963585435,
  'support': 891.0},
 'pre-attentive': {'precision': 0.7385786802030457,
  'recall': 0.6830985915492958,
  'f1-score': 0.7097560975609758,
  'support': 852.0},
 'response': {'precision': 0.7936046511627907,
  'recall': 0.9191919191919192,
  'f1-score': 0.8517940717628705,
  'support': 891.0},
 'accuracy': 0.8099899598393574,
 'macro avg': {'precision': 0.8158695458022962,
  'recall': 0.8137835131207127,
  'f1-score': 0.8133594742281364,
  'support': 3984.0},
 'weighted avg': {'precision': 0.8107518140522856,
  'recall': 0.8099899598393574,
  'f1-score': 0.8087438033903912,
  'support': 3984.0}}

In [None]:
# Layer-by-layer summary of `model` (defined in the training cell above).
# NOTE(review): the shape reads as (batch=16, 1, 147, 30); presumably
# (batch, 1, n_samples, n_channels) — confirm against `model_kwargs`.
summary(model, input_size=(16, 1, 147, 30))