In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import netCDF4
import xarray as xr
import hmp
from pathlib import Path
from mne.io import read_info
from mne import read_epochs
from hmpai.data import StageFinder
import os
# Root directory for all datasets used in this notebook.
# Fail fast with a clear message if the DATA_PATH environment variable is
# missing or empty: Path(None) would raise an opaque TypeError, and Path("")
# would silently resolve to the current working directory.
if not os.environ.get("DATA_PATH"):
    raise RuntimeError(
        "Environment variable DATA_PATH is not set; point it at the data root directory."
    )
DATA_PATH = Path(os.environ["DATA_PATH"])

In [None]:
# Sensor layout for topographic plots: take the measurement info from one
# participant's epochs file and attach MNE's built-in "biosemi64" montage.
# Other candidate files (kept for reference):
#   DATA_PATH / "sat2/preprocessed_500hz/preprocessed_S1_raw.fif"
#   DATA_PATH / "sat1/preprocessed/processed_0001_epo.fif"
info_path = DATA_PATH / "sat2/preprocessed_500hz/S1_epo.fif"

# Only `.info` is used downstream, so don't load the epoch data into memory.
epoch = read_epochs(info_path, preload=False)
epoch.set_montage("biosemi64")
positions = epoch.info

#### HMP Fitting

In [None]:
# Number of worker processes handed to StageFinder.
cpus = 4

# Epoched data (netCDF) that the HMP models are fitted on.
epoched_data_path = DATA_PATH / "prp/file.nc"

# TODO: fill in the per-condition event labels before fitting.
# Keys are condition names. `conditions` is derived from the keys so the two
# can never drift out of sync (dict order is preserved).
labels = {"accuracy": [], "speed": []}
conditions = list(labels)

stage_finder = StageFinder(
    epoched_data_path,
    labels,
    conditions=conditions,
    cpus=cpus,
    # NOTE(review): the number of events is presumably derived from the
    # number of labels per condition rather than passed explicitly — confirm
    # in StageFinder.
    fit_function="fit",
    verbose=False,
    condition_variable="condition",
    condition_method="equal",
    # To reuse previously saved fits instead of refitting, pass e.g.:
    # fits_to_load=["250hz_accuracy_high.nc", "250hz_speed_high.nc"],
    n_comp=10,
    event_width=40,
)
# Fit the HMP model(s); the results are read from `stage_finder.fits`.
stage_finder.fit_model()

In [None]:
# Plot the fitted model(s) using the sensor layout in `positions`.
# `max_time` caps the plotted time axis (units assumed to be ms — TODO confirm).
stage_finder.visualize_model(
    positions,
    max_time=800,
)

In [None]:
# Persist every fitted model to the working directory as 250hz_<n>.nc, with n
# starting at 1. This yields the same filenames as one save call per fit but
# no longer hard-codes the number of fits.
for fit_number, fit in enumerate(stage_finder.fits, start=1):
    hmp.utils.save_fit(fit, f"250hz_{fit_number}.nc")

In [None]:
# Label the data using StageFinder's probabilistic labelling function. The
# result stays in memory and is written to `output_path` in the next cell.
output_path = DATA_PATH / "prp/stage_data_250hz.nc"
stage_data = stage_finder.label_model(
    probabilistic=True,
    label_fn=stage_finder.__label_model_probabilistic__,
)

In [5]:
# Write the labelled stage data to a netCDF file at `output_path`.
stage_data.to_netcdf(output_path)