In [None]:
%load_ext autoreload
%autoreload 2
import xarray as xr
import hmp
from pathlib import Path
from mne.io import read_info
from mne import read_epochs
from hmpai.data import StageFinder, SAT_CLASSES_ACCURACY, SAT_CLASSES_SPEED
import os
from hmpai.behaviour.sat2 import SAT2_SPLITS
# Root directory of all datasets used in this notebook, taken from the
# DATA_PATH environment variable. Fail early with an explicit message:
# Path(None) would otherwise raise an opaque TypeError.
_data_path_env = os.getenv("DATA_PATH")
if _data_path_env is None:
    raise RuntimeError(
        "Environment variable DATA_PATH is not set; "
        "point it at the root directory of the data."
    )
DATA_PATH = Path(_data_path_env)

In [None]:
# Channel information handed to `visualize_model` further down.
# It is taken from the epochs of a single participant after applying the
# 'biosemi64' montage. The former `read_info(info_path)` result was
# overwritten immediately and has been dropped to avoid reading the file twice.
info_path = DATA_PATH / "sat2/preprocessed_500hz/S1_epo.fif"

epoch = read_epochs(info_path)
epoch.set_montage('biosemi64')
positions = epoch.info

### Calculate PCA weights over train set

In [None]:
# Load the epoched data and keep only the participants of the first split
# (presumably the training split, per the section title).
epoched_data_path = DATA_PATH / "sat2/data_250hz.nc"
data = xr.load_dataset(epoched_data_path).sel(participant=SAT2_SPLITS[0])

# Per-condition class labels handed to the StageFinder.
labels = {"accuracy": SAT_CLASSES_ACCURACY, "speed": SAT_CLASSES_SPEED}
conditions = ["accuracy", "speed"]

# Build the finder on the selected participants with 10 components and
# display the resulting PCA weights as the cell output.
pca_finder = StageFinder(data, labels, conditions, n_comp=10)
pca_finder.hmp_data_offset.pca_weights

In [None]:
# Persist the PCA weights so that later fits can load them from disk.
# The `pca_finder` built in the previous cell is reused: constructing a second,
# identical StageFinder from the same `data`, `labels` and `conditions` would
# only recompute the weights from the same inputs.
pca_weights_path = Path("files/train_pca.nc")
# Make sure the target directory exists before writing.
pca_weights_path.parent.mkdir(parents=True, exist_ok=True)
pca_finder.hmp_data_offset.pca_weights.to_netcdf(pca_weights_path)

### Fit HMP

In [None]:
# Inputs for the fit: the epoched data, the PCA weights saved in the previous
# section, and the per-condition class labels.
epoched_data_path = DATA_PATH / "sat2/data_250hz.nc"
pca_weights = xr.load_dataarray("files/train_pca.nc")
labels = {"accuracy": SAT_CLASSES_ACCURACY, "speed": SAT_CLASSES_SPEED}
conditions = ["accuracy", "speed"]

# Number of cores to use (whether multiprocessing is used depends on fit_function).
cpus = 4

stage_finder = StageFinder(
    epoched_data_path,
    labels,
    conditions=conditions,
    # How trials are assigned to a condition.
    condition_variable="condition",
    condition_method="equal",
    # Fitting: n_events is extracted from the number of labels per entry in `conditions`.
    fit_function="fit_single",
    fit_args={"starting_points": 1},
    # Comment out `fits_to_load` if fitting anew.
    # NOTE(review): these names lack the "fits/" prefix used when the fits are
    # saved further down; presumably StageFinder resolves them itself. Confirm.
    fits_to_load=["accuracy_250hz.nc", "speed_250hz.nc"],
    cpus=cpus,
    verbose=False,
    # Data transformation settings.
    n_comp=10,
    pca_weights=pca_weights,
    event_width=45,
    behaviour_path=DATA_PATH / "sat2/behavioural/df_full.csv",
)
stage_finder.fit_model()

In [None]:
# Plot the fitted model using the channel information in `positions`
# and export the figure as SVG.
figure_path = "../img/hmp_fit.svg"
fig, ax = stage_finder.visualize_model(
    positions,
    max_time=800,
    figsize=(7.09, 2),
)
fig.savefig(figure_path)

In [None]:
# Save the first two fits, in order, to their NetCDF files.
# Presumably `stage_finder.fits` follows the order of `conditions`
# (accuracy first, then speed). TODO confirm.
fit_paths = ("fits/accuracy_250hz.nc", "fits/speed_250hz.nc")
for fit_index, fit_path in enumerate(fit_paths):
    hmp.utils.save_fit(stage_finder.fits[fit_index], fit_path)

In [None]:
# Label the data with the fitted model, using the finder's probabilistic
# labelling function.
stage_data = stage_finder.label_model(
    label_fn=stage_finder.__label_model_probabilistic__,
    probabilistic=True,
)

# Destination of the labelled data; the file itself is written in the next cell.
output_path = DATA_PATH / "sat2/stage_data_250hz.nc"

In [None]:
# Write the labelled stage data to `output_path` (defined in the previous cell).
stage_data.to_netcdf(output_path)

### Event width comparison

In [None]:
# Channel information handed to `visualize_model` in the event-width loop below.
# It is taken from the epochs of a single participant after applying the
# 'biosemi64' montage. The former `read_info(info_path)` result was
# overwritten immediately and has been dropped to avoid reading the file twice.
info_path = DATA_PATH / "sat2/preprocessed_500hz/S1_epo.fif"

epoch = read_epochs(info_path)
epoch.set_montage('biosemi64')
positions = epoch.info

In [None]:
# Refit the model for a range of event widths and save one figure per width so
# the fits can be compared visually.
# NOTE: `cpus` is not defined in this cell; it comes from the "Fit HMP" cell.
epoched_data_path = DATA_PATH / "sat2/data_250hz.nc"
# Load the weights from the location they are written to ("files/train_pca.nc"),
# the same path the "Fit HMP" cell reads; the bare "train_pca.nc" used here
# before did not match it.
pca_weights = xr.load_dataarray("files/train_pca.nc")

conditions = ["accuracy", "speed"]
labels = {"accuracy": SAT_CLASSES_ACCURACY, "speed": SAT_CLASSES_SPEED}

# Make sure the figure directory exists before the (long) fitting loop starts.
figure_dir = Path("../img/event_width")
figure_dir.mkdir(parents=True, exist_ok=True)

ew_values = [20, 25, 30, 35, 40, 45, 50, 55, 60]
for event_width in ew_values:
    # A fresh finder per width; no `fits_to_load`, so each model is fitted anew.
    stage_finder = StageFinder(
        epoched_data_path,
        labels,
        conditions=conditions,
        cpus=cpus,
        fit_function="fit",
        verbose=False,
        condition_variable="condition",
        condition_method="equal",
        n_comp=10,
        event_width=event_width,
        behaviour_path=DATA_PATH / "sat2/behavioural/df_full.csv",
        pca_weights=pca_weights,
    )
    stage_finder.fit_model()
    fig, ax = stage_finder.visualize_model(positions, max_time=800, figsize=(12, 3))
    fig.savefig(figure_dir / f"hmp_fit_{event_width}.svg")