In [2]:
%load_ext autoreload
%autoreload 2
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset
from hmpai.data import SAT1_STAGES_ACCURACY, COMMON_STAGES
from hmpai.visualization import plot_confusion_matrix
from hmpai.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
from mne.io import read_info
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
import os
DATA_PATH = Path(os.getenv("DATA_PATH"))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Train model on SAT2 Dataset

In [14]:
set_global_seed(42)
data_path_sat2 = DATA_PATH / "sat2/window_stage_data_100hz.nc"
dataset_sat2 = xr.load_dataset(data_path_sat2)

In [15]:
shape_topological = False
train_data_sat2, val_data_sat2, test_data_sat2 = split_data_on_participants(
    dataset_sat2, 60, norm_min1_to_1
)

In [16]:
chk_path = Path("../models/sat2_gru_100hz.pt")
checkpoint = load_model(chk_path)

model_kwargs = {
    "n_channels": len(train_data_sat2.channels),
    "n_samples": len(train_data_sat2.samples),
    "n_classes": len(train_data_sat2.labels),
}
model = SAT1GRU(**model_kwargs)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE)



In [17]:
train_dataset = SAT1Dataset(
    train_data_sat2, shape_topological=shape_topological, labels=SAT1_STAGES_ACCURACY[1:]
)
# Val and test were not used to train
val_dataset = SAT1Dataset(
    val_data_sat2, shape_topological=shape_topological, labels=SAT1_STAGES_ACCURACY[1:]
)
test_dataset = SAT1Dataset(
    test_data_sat2, shape_topological=shape_topological, labels=SAT1_STAGES_ACCURACY[1:]
)

In [18]:
test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True
)

In [58]:
torch.set_printoptions(sci_mode=False)
window_size = 11
model = model.eval()
# for trial in SAT2 HPT-test set
for batch in test_loader:
    for epoch, true in zip(batch[0], batch[1]):
        print(true)
        # epoch.size() = (250, 30), 250 = samples, 30 = channels
        # print(epoch.size())
        # rt_idx = torch.nonzero(torch.eq(epoch[:,0], MASKING_VALUE))[0,0].item()        
        # slices = [epoch[start: start + window_size] for start in range(rt_idx - window_size + 1)]
        # stacked = torch.stack(slices).to(DEVICE)
        # pred = model(stacked)
        # pred = torch.nn.Softmax(dim=1)(pred)
        # print(pred.size())
        # # print(pred)
        # plt.plot(pred.detach().cpu(), label=SAT1_STAGES_ACCURACY[1:])
        # plt.legend()
        # plt.show()
        # break
        # What is the plot showing? Each X point

tensor([ 0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1, -1, -1, -1, -1,
        -1, -1, -1, -1,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, 

KeyboardInterrupt: 