In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import xarray as xr
import hmp
from pathlib import Path
from mne.io import read_info
from mne import read_epochs
from hmpai.data import StageFinder, SAT_CLASSES_ACCURACY
import os
import itertools
import matplotlib.pyplot as plt
from hmpai.visualization import set_seaborn_style
DATA_PATH = Path(os.getenv("DATA_PATH"))

In [None]:
# Sensor information used for topographic plots: take the measurement info of
# one participant's epochs file and attach the standard 'biosemi64' montage.
info_path = DATA_PATH / "sat2/preprocessed_500hz/S1_epo.fif"
# Only `info` is needed below, so skip loading the epoch data array into memory.
epoch = read_epochs(info_path, preload=False)
epoch.set_montage('biosemi64')
positions = epoch.info

# Stage-labelled data (per-trial condition and tertile) used for all fits below.
data_paths = [DATA_PATH / "sat2/stage_data_250hz_tertile.nc"]
datasets = [xr.open_dataset(path) for path in data_paths]

In [None]:
# Figure axes: SAT conditions (columns) x tertiles (rows). An empty condition
# label or a NaN tertile marks trials without that label (presumably trials
# outside the test/val participants), which are dropped.
conditions = [cond for cond in np.unique(datasets[0].condition).tolist() if cond != '']
# Descending order (ordering for figure): the largest tertile value comes first
# and is paired with tert_names[0] == 'High'.
tertiles = [
    tert for tert in reversed(np.unique(datasets[0].tertile).tolist()) if not np.isnan(tert)
]
combinations = list(itertools.product(conditions, tertiles))
tert_names = ['High', 'Medium', 'Low']
# tert_names is matched to tertiles purely by position, so the counts must agree.
assert len(tertiles) == len(tert_names), (
    f"Expected {len(tert_names)} tertiles, found {tertiles}"
)

In [None]:
# Whole-condition HMP fits. Their leading event parameters/magnitudes are
# passed as `prepars`/`premags` to the per-tertile fits in the next section.
fit_ac = hmp.utils.load_fit("files/fits/accuracy_250hz.nc")
fit_sp = hmp.utils.load_fit("files/fits/speed_250hz.nc")

# How many leading events to carry over from each whole-condition fit.
n_events_to_fix = 1
leading = slice(None, n_events_to_fix)

prepars_ac, premags_ac = fit_ac.parameters.data[leading], fit_ac.magnitudes.data[leading]
prepars_sp, premags_sp = fit_sp.parameters.data[leading], fit_sp.magnitudes.data[leading]

### Fitting and saving models

In [None]:
# Fit one HMP model per (condition, tertile) subset, save each fit, and plot
# all of them in a tertiles x conditions grid.
pca_weights = xr.load_dataarray("files/train_pca.nc")
finder = StageFinder(
    datasets[0],
    labels=SAT_CLASSES_ACCURACY,
    fit_function="fit",
    event_width=45,
    n_comp=10,
    cpus=0,
    conditions=[],
    pca_weights=pca_weights,
)

# Fixed leading-event parameters/magnitudes per SAT condition.
fixed_events = {
    'accuracy': (prepars_ac, premags_ac),
    'speed': (prepars_sp, premags_sp),
}

# Grid shape follows the data instead of a hard-coded 3x2; squeeze=False keeps
# `ax` two-dimensional even for a single row or column.
n_rows, n_cols = len(tertiles), len(conditions)
fig, ax = plt.subplots(n_rows, n_cols, sharex=True, sharey=False, figsize=(12, 6), squeeze=False)
for i, (cond, tert) in enumerate(combinations):
    row, col = tertiles.index(tert), conditions.index(cond)
    tert_name = tert_names[row]
    print(f"SAT: {cond}, tertile: {tert_name}")
    if cond not in fixed_events:
        raise ValueError(f'cond {cond} not recognized.')
    prepars, premags = fixed_events[cond]
    # At this point, data is only labelled with tertile for participants in test/val sets
    finder.fit_model(
        fit_args={"prepars": prepars, "premags": premags},
        extra_split=[("condition", "equal", cond), ("tertile", "equal", tert)],
    )

    hmp.utils.save_fit(finder.fits[-1], f"files/fits/{cond}_{int(tert)}.nc")

    # NOTE(review): model_index=i assumes every fit_model call appends exactly
    # one entry to finder.fits, starting from empty — confirm in StageFinder.
    cur_ax = ax[row, col]
    finder.visualize_model(
        positions, max_time=800, ax=cur_ax, colorbar=col == n_cols - 1, cond_label=None, model_index=i
    )

    # Only the bottom row gets x labels and only the first column gets y labels.
    if row == n_rows - 1:
        cur_ax.set_xlabel(f"Time (in ms)\n{cond.capitalize()}")
    if col == 0:
        cur_ax.set_ylabel(tert_name)
    # A single blank tick keeps the row label vertically centred.
    cur_ax.set_yticks([0.5], [''])

fig.supylabel("Average confirmation probability tertile")
fig.supxlabel("Condition")
fig.tight_layout()
plt.show()

### Visualization
Reloads the saved per-subset fits and plots them. This takes only slightly less time than fitting all subsets.

In [None]:
# Load in all fitted models saved by the fitting section and redraw them as a
# publication-sized tertiles x conditions figure.
fit_path = Path("files/fits/")

# Apply the plotting style BEFORE creating the figure: matplotlib reads most
# rcParams when axes/artists are created, so styling afterwards does not fully
# restyle axes that already exist. (set_seaborn_style is not shown here;
# presumably it sets matplotlib/seaborn rcParams.)
set_seaborn_style()

finder = StageFinder(
    datasets[0],
    labels=SAT_CLASSES_ACCURACY,
    fit_function="fit",
    event_width=45,
    n_comp=10,
    cpus=0,
    conditions=[],
    pca_weights=xr.load_dataarray("files/train_pca.nc"),
)

# Grid shape follows the data; squeeze=False keeps `ax` two-dimensional.
n_rows, n_cols = len(tertiles), len(conditions)
fig, ax = plt.subplots(
    n_rows, n_cols, dpi=300, sharex=True, sharey=False, figsize=(7.09, 4), squeeze=False
)
for i, (cond, tert) in enumerate(combinations):
    row, col = tertiles.index(tert), conditions.index(cond)
    tert_name = tert_names[row]
    print(f"SAT: {cond}, tertile: {tert_name}")

    fit_file = fit_path / f"{cond}_{int(tert)}.nc"
    if not fit_file.exists():
        raise FileNotFoundError(f"Missing fit {fit_file}; run the fitting section first.")
    finder.fits_to_load = [fit_file]
    finder.fit_model(extra_split=[("condition", "equal", cond), ("tertile", "equal", tert)])

    # NOTE(review): model_index=i assumes each fit_model call appends one entry
    # to finder.fits, starting from empty — confirm in StageFinder.
    cur_ax = ax[row, col]
    finder.visualize_model(
        positions, max_time=800, ax=cur_ax, colorbar=col == n_cols - 1, cond_label=None, model_index=i
    )

    # Only the bottom row gets x labels and only the first column gets y labels.
    if row == n_rows - 1:
        cur_ax.set_xlabel(f"Time (in ms)\n{cond.capitalize()}")
    if col == 0:
        cur_ax.set_ylabel(tert_name)
    # A single blank tick keeps the row label vertically centred.
    cur_ax.set_yticks([0.5], [''])

# Set common labels and save figure (creating the output directory if needed).
fig.supylabel("Average confirmation probability tertile")
fig.supxlabel("Condition")
fig.tight_layout()
out_file = Path("../img/refit_stacked.svg")
out_file.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_file)