### Import Libraries

In [None]:
%load_ext autoreload
%autoreload 2
import xarray as xr
import hmp
from pathlib import Path
from mne import read_epochs
from hmpai.data import StageFinder, SAT_CLASSES_ACCURACY, SAT_CLASSES_SPEED
from hmpai.pytorch.utilities import set_global_seed
from hmpai.training import split_participants
import os
# Root directory for all datasets used in this notebook, taken from the environment.
# Fail early with a clear message instead of the opaque TypeError that Path(None)
# raises when the variable is missing (an empty value would otherwise silently
# resolve to the current working directory).
_data_path_env = os.getenv("DATA_PATH")
if not _data_path_env:
    raise RuntimeError(
        "Environment variable DATA_PATH is not set; point it at the root data directory."
    )
DATA_PATH = Path(_data_path_env)

### Calculate PCA weights over the training set

In [None]:
# Get PCA weights.
# PCA is fitted on the training participants only. The resulting weights are saved
# so the HMP fit can reuse them via its `pca_weights` argument.
epoched_data_path = DATA_PATH / "sat1/data_250hz.nc"
data = xr.load_dataset(epoched_data_path)

# Seed before splitting so the participant split is reproducible
# (assumes split_participants draws from the globally seeded RNG — TODO confirm).
set_global_seed(42)
splits = split_participants([epoched_data_path], train_percentage=50)

# Keep only the first split (the training participants).
data = data.sel(participant=splits[0])

conditions = ["AC", "SP"]
labels = {"AC": SAT_CLASSES_ACCURACY, "SP": SAT_CLASSES_SPEED}

# StageFinder is constructed here only to obtain the 10-component PCA weights.
pca_finder = StageFinder(data, labels, conditions, n_comp=10)

# to_netcdf does not create missing directories, so make sure "files/" exists first.
Path("files").mkdir(parents=True, exist_ok=True)
pca_finder.hmp_data_offset.pca_weights.to_netcdf("files/train_pca.nc")

### Fit HMP

In [None]:
# Load one participant's epochs only for their channel metadata. Attach the
# BioSemi 64-channel montage and keep the resulting `info` as `positions`
# for the model visualization.
# preload=False: only the measurement info is needed, so the epoch data array is
# not read into memory (set_montage modifies only `info`).
info_path = DATA_PATH / "sat1/preprocessed_500hz/processed_500Hz_0001_epo.fif"
epoch = read_epochs(info_path, preload=False)
epoch.set_montage('biosemi64')
positions = epoch.info

In [None]:
# Number of CPU cores to use for multiprocessing. Capped at the number of cores
# actually available so smaller machines are not oversubscribed.
cpus = min(8, os.cpu_count() or 1)

epoched_data_path = DATA_PATH / "sat1/data_250hz.nc"
conditions = ["AC", "SP"]
labels = {"AC": SAT_CLASSES_ACCURACY, "SP": SAT_CLASSES_SPEED}
# Reuse the PCA weights computed on the training participants only.
pca_weights = xr.load_dataarray("files/train_pca.nc")

stage_finder = StageFinder(
    epoched_data_path,
    labels,
    conditions=conditions,
    cpus=cpus,
    # n_events is not passed explicitly. StageFinder derives it from the labels of
    # each condition (presumably len(labels[condition]) — TODO confirm).
    fit_function="fit_single",
    fit_args={"starting_points": 1},
    verbose=False,
    condition_variable="cue",
    condition_method="equal",
    n_comp=10,
    event_width=45,
    pca_weights=pca_weights,
    # Uncomment to load previously saved fits; leave commented out when fitting anew.
    # fits_to_load=["files/fits/accuracy_250hz.nc", "files/fits/speed_250hz.nc"],
)
stage_finder.fit_model()

In [None]:
# Plot the fitted model(s) using the channel info in `positions`.
# Limits: 1000 on the time axis; set_vlims=False is passed through unchanged.
fig, ax = stage_finder.visualize_model(
    positions,
    max_time=1000,
    figsize=(12, 3),
    set_vlims=False,
)


In [None]:
# Save the fitted models so they can later be reloaded through StageFinder's
# `fits_to_load` argument. The fits are indexed in the same order as `conditions`
# (AC -> accuracy, SP -> speed).
# Ensure the output directory exists before writing into it.
Path("files/fits").mkdir(parents=True, exist_ok=True)
hmp.utils.save_fit(stage_finder.fits[0], "files/fits/accuracy_250hz.nc")
hmp.utils.save_fit(stage_finder.fits[1], "files/fits/speed_250hz.nc")

In [None]:
# Label the data with the fitted models. label_fn and probabilistic=True select
# StageFinder's probabilistic labelling routine.
output_path = DATA_PATH / "sat1" / "stage_data_250hz.nc"
stage_data = stage_finder.label_model(
    label_fn=stage_finder.__label_model_probabilistic__,
    probabilistic=True,
)

In [None]:
# Save labelled dataset
# Persist the labelled dataset as netCDF at output_path.
stage_data.to_netcdf(path=output_path)