In [1]:
%load_ext autoreload
%autoreload 2
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, load_model
from hmpai.pytorch.generators import SAT1Dataset
from hmpai.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, REINDEX_CHANNELS_AR
from torch.utils.data import DataLoader
from hmpai.data import SAT1_STAGES_ACCURACY, AR_STAGES
import scipy

### Load datasets

In [2]:
# NetCDF file with the AR data already split into stages.
data_path = Path("../data/ar/split_stage_data.nc")

# Seed the project's RNGs before any stochastic step later in the notebook
# (e.g. the participant split and DataLoader construction).
set_global_seed(42)

# Read the whole dataset into memory (xarray `load_dataset`).
dataset = xr.load_dataset(data_path)

In [3]:
# Re-order channels to be as similar as possible to SAT1, then drop the two
# placeholder channels ("trash1", "trash2") produced by the reindex.
dataset = dataset.reindex(channels=REINDEX_CHANNELS_AR).drop_sel(
    channels=["trash1", "trash2"]
)

# Keep only labels present in both SAT1 and AR, in SAT1's order.
# Intersecting two sets (the previous approach) yields an arbitrary order; for
# string labels (presumably the case here) that order depends on the
# per-process hash seed, so the label -> class-index mapping could change
# between kernel restarts and silently disagree with the pretrained
# checkpoint. dict.fromkeys keeps the first occurrence and de-duplicates,
# matching the set semantics of the original.
ar_stages = set(AR_STAGES)
common_labels = list(
    dict.fromkeys(label for label in SAT1_STAGES_ACCURACY if label in ar_stages)
)
dataset = dataset.sel(labels=common_labels)

In [7]:
# Keep only the first MAX_SAMPLES entries along the `samples` dimension.
# The model cells derive `n_samples` from `len(dataset.samples)`, so this
# length presumably has to match what the checkpoints were trained with.
# TODO confirm.
MAX_SAMPLES = 161
dataset = dataset.isel(samples=slice(None, MAX_SAMPLES))

In [8]:
# Split by participant and normalise with `norm_min1_to_1`.
# NOTE(review): `60` is passed positionally to split_data_on_participants;
# presumably a train-set percentage. Confirm against its signature.
train_data, val_data, test_data = split_data_on_participants(
    dataset, 60, norm_min1_to_1
)

# Wrap each split (train, val, test, in that order) in a SAT1Dataset that
# shares the same `shape_topological` flag.
shape_topological = False
train_dataset, val_dataset, test_dataset = (
    SAT1Dataset(split, shape_topological=shape_topological)
    for split in (train_data, val_data, test_data)
)

In [17]:
# Restore the pretrained GRU from its checkpoint.
# NOTE(review): the CNN cell that follows reassigns `model`, `checkpoint`,
# `chk_path` and `model_kwargs`. Running top to bottom therefore evaluates the
# CNN, not this GRU.
chk_path = Path("../models/gru100/checkpoint.pt")
checkpoint = load_model(chk_path)

# Architecture sizes come from the filtered dataset. They must agree with the
# parameter shapes stored in the checkpoint, or load_state_dict will reject it.
model_kwargs = dict(
    n_channels=len(dataset.channels),
    n_samples=len(dataset.samples),
    n_classes=len(dataset.labels),
)

model = SAT1GRU(**model_kwargs)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE)



In [9]:
# Restore the pretrained CNN (SAT1Base) from its checkpoint.
# NOTE(review): this overwrites `model`, `checkpoint`, `chk_path` and
# `model_kwargs` set by the GRU cell above. This CNN is the model the
# evaluation cell below uses.
chk_path = Path("../models/cnn100/checkpoint.pt")
checkpoint = load_model(chk_path)

# Architecture sizes come from the filtered dataset. They must agree with the
# parameter shapes stored in the checkpoint, or load_state_dict will reject it.
model_kwargs = dict(
    n_channels=len(dataset.channels),
    n_samples=len(dataset.samples),
    n_classes=len(dataset.labels),
)

model = SAT1Base(**model_kwargs)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE)



In [10]:
# Evaluate `model` on the held-out test split.
# The reported metrics are aggregate (classification-report style), so sample
# order does not matter. Keep the loader unshuffled so batch order is
# reproducible, and so any per-sample output lines up with `test_dataset`.
# NOTE(review): the third argument to `test` is passed as None; its meaning is
# defined in hmpai.pytorch.training.test. Confirm.
test_loader = DataLoader(
    test_dataset, 128, shuffle=False, num_workers=4, pin_memory=True
)
test(model, test_loader, None)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'0': {'precision': 0.3184110970996217,
  'recall': 0.385643375334097,
  'f1-score': 0.3488171300293559,
  'support': 2619.0},
 '1': {'precision': 0.01386481802426343,
  'recall': 0.008833271991166729,
  'f1-score': 0.01079136690647482,
  'support': 2717.0},
 '2': {'precision': 0.4594663930220626,
  'recall': 0.6591829223408171,
  'f1-score': 0.5414965986394558,
  'support': 2717.0},
 '3': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0.0},
 '4': {'precision': 0.09214501510574018,
  'recall': 0.022451232977548766,
  'f1-score': 0.03610535661438295,
  'support': 2717.0},
 'accuracy': 0.26796657381615596,
 'macro avg': {'precision': 0.17677746465033756,
  'recall': 0.21522216052872595,
  'f1-score': 0.18744209043793386,
  'support': 10770.0},
 'weighted avg': {'precision': 0.22008519682071243,
  'recall': 0.26796657381615596,
  'f1-score': 0.2332606053720014,
  'support': 10770.0}}

In [None]:
# Peek at a single test batch. A DataLoader is an iterable, not an iterator:
# it has no `__next__`, so calling it directly raises AttributeError. Create an
# iterator first.
next(iter(test_loader))