In [None]:
import os
import pickle
import numpy as np

from osl_dynamics.data import Data
from osl_dynamics.models.hmm import Config, Model
from osl_dynamics.models.hmm import Model
from osl_dynamics.utils import plotting
from osl_dynamics.inference import modes, metrics
from preproc_funcs import plot_channel_time_series

# Root folder holding the preprocessed recording and the analysis outputs.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
base_dir = "/Users/podlaskijacek/Documents/HMM-Analysis"

# Load the preprocessed .fif recording and print the resulting Data object.
data = Data(f"{base_dir}/s_01_preproc-raw.fif")
print(data)

# Folder that a later cell saves the trained model into.
results_dir = f"{base_dir}/resultaty"

In [None]:
# Dimensionality reduction: keep 64 PCA components, then standardize.
# NOTE(review): assumes prepare() applies the steps in dictionary order
# (PCA first, standardization second); confirm against the library docs.
methods = {
    "pca": {"n_pca_components": 64},
    "standardize": {},
}
data.prepare(methods)

# Inspect the prepared data and the shape of the fitted PCA projection.
print(data)
print(data.pca_components.shape)

# Optional visual check of the channel time series (disabled):
#plot_channel_time_series(data, savebase=None, exclude_bads=False)

In [None]:
# HMM configuration used to build and train the model.
config = Config(
    # Model size: six states; channel count taken from the prepared data.
    n_states=6,
    n_channels=data.n_channels,
    # Observation model: only the state covariances are learned.
    learn_means=False,
    learn_covariances=True,
    # Training. sequence_length is the number of time points in each training
    # sequence the data is split into (it is not the number of sequences).
    sequence_length=200,
    batch_size=64,
    learning_rate=0.01,
    n_epochs=20,
)

In [None]:
# Build the HMM from `config` and display its summary.
model = Model(config)
model.summary()

In [None]:
# Initialize the model from random state time courses before the main fit.
# NOTE(review): presumably runs n_init=3 short trainings of n_epochs=1 each and
# keeps the best one; confirm against the osl-dynamics documentation.
init_history = model.random_state_time_course_initialization(data, n_epochs=1, n_init=3)

# Full training run; uses the n_epochs value set in `config`.
history = model.fit(data)

In [None]:
from osl_dynamics.models import load

# Write the trained model to disk (path is relative to the working directory).
model.save("results/model")

# Read the model back from the same folder and print it, replacing the
# in-memory `model` with the reloaded instance.
model = load("results/model")
print(model)

In [6]:
# Save the model into the "model1" subfolder of `results_dir`.
# Note: an earlier cell also saves the model to the relative path "results/model".
model_dir = f"{results_dir}/model1"
model.save(model_dir)

In [None]:
import pickle

# Evaluate the model's free energy on the data and store it alongside the
# training history returned by model.fit().
free_energy = model.free_energy(data)
history["free_energy"] = free_energy

# Make sure the target folder exists, then pickle the history. The context
# manager closes the file handle even if pickling fails.
os.makedirs("results/model", exist_ok=True)
with open("results/model/history.pkl", "wb") as history_file:
    pickle.dump(history, history_file)

In [None]:
# Inferred state probabilities for the data.
alp = model.get_alpha(data)

# Persist them to disk. The context manager closes the file handle even if
# pickling fails.
os.makedirs("results/inf_params/", exist_ok=True)
with open("results/inf_params/alp.pkl", "wb") as alpha_file:
    pickle.dump(alp, alpha_file)

In [None]:
from osl_dynamics.utils import plotting

# Plot the inferred state probabilities.
# NOTE(review): this plots the whole time course; consider limiting the number
# of samples shown if the recording is long.
plotting.plot_alpha(alp)

In [10]:
# Group-level HMM parameters read back from the model: state means and
# covariances, initial state probabilities and state transition probabilities.
# NOTE(review): the config sets learn_means=False, so `means` presumably holds
# the fixed (non-learned) values; confirm before interpreting them.
means, covs = model.get_means_covariances()
initial_state_probs = model.get_initial_state_probs()
trans_prob = model.get_trans_prob()