In [None]:
import os
import pickle
import numpy as np
import mne
import matplotlib.pyplot as plt

from osl_dynamics.data import Data
from osl_dynamics.models.hmm import Config, Model
from osl_dynamics.models import load
from osl_dynamics.models.hmm import Model
from osl_dynamics.utils import plotting
from osl_dynamics.inference import modes, metrics
from preproc_funcs import plot_channel_time_series


In [None]:

# Preprocessed MEG recordings (one .fif file per session) and the directory
# where the trained model and inferred parameters are written.
# NOTE(review): BASE_DIR is an absolute, user-specific path — move it to a
# config cell or an environment variable before sharing this notebook.
BASE_DIR = "/Users/podlaskijacek/Documents/HMM-Analysis"
files = [
    os.path.join(BASE_DIR, f"s1_{session:02d}_preproc-raw.fif")
    for session in (1, 2, 3)
]
results_dir = os.path.join(BASE_DIR, "HMM_6.02")

# Load only the MEG channels; segments annotated as BAD are omitted.
data = Data(
    files,
    picks="meg",
    reject_by_annotation="omit",
)
print(data)

In [None]:
# Data preparation before fitting the HMM:
#  - "pca": reduce the MEG channels to 64 principal components
#  - "standardize": standardize the resulting channels (presumably z-scoring
#    each channel — confirm in the osl-dynamics Data.prepare docs)
# NOTE(review): `prepare` updates `data` in place; confirm whether re-running
# this cell re-applies PCA to already-prepared data in your osl-dynamics version.
methods = {
    "pca": {"n_pca_components": 64},
    "standardize": {},
}
data.prepare(methods)
print(data)
print(data.pca_components.shape)

In [None]:
# HMM hyperparameters.
config = Config(
    n_states=6,
    # Number of channels after preparation (the PCA components when PCA was applied).
    n_channels=data.n_channels,
    # Length (in samples) of each sequence the time series is cut into for
    # training; this is NOT the number of sequences.
    sequence_length=200,
    learn_means=False,  # state means are not learned, only covariances
    learn_covariances=True,
    batch_size=64,  # number of sequences per training batch
    learning_rate=0.01,
    n_epochs=20,
)

In [None]:
# Build the HMM from the config above and print the model summary.
model = Model(config)
model.summary()

In [None]:
# Train the HMM — this is slow.
# Initialization: n_init=20 short runs of n_epochs=5 each, starting from random
# state time courses (presumably the best run's parameters are kept — confirm in
# the osl-dynamics docs). `init_history` is not used further in this notebook.
# Then full training with `fit`; `history` is reused later when the free energy is saved.
# NOTE(review): no random seed is set, so results will differ between runs.
init_history = model.random_state_time_course_initialization(data, n_epochs=5, n_init=20)
history = model.fit(data)
model.save(results_dir)

In [None]:
# Restore the trained model from `results_dir`, so it can be used without
# re-running the training cell.
model = load(results_dir)
print(model)
model.summary()

In [None]:
# Free energy of the trained model on `data`, stored with the training history
# and saved to <results_dir>/history.pkl.
# NOTE(review): `history` is only defined by the training cell; after a kernel
# restart that goes straight to `load`, this cell raises NameError.
free_energy = model.free_energy(data)
history["free_energy"] = free_energy
with open(os.path.join(results_dir, "history.pkl"), "wb") as history_file:
    pickle.dump(history, history_file)

In [None]:
# Inferred state probabilities (alpha), saved to <results_dir>/inf_params/alp.pkl.
# Presumably one array per session (three files were loaded) — indexed as alp[0] below.
alp = model.get_alpha(data)
inf_params_dir = os.path.join(results_dir, "inf_params")
os.makedirs(inf_params_dir, exist_ok=True)
with open(os.path.join(inf_params_dir, "alp.pkl"), "wb") as alp_file:
    pickle.dump(alp, alp_file)

# State probability time course of the first session.
plotting.plot_alpha(alp[0])

In [None]:
# State covariance matrices of the trained model, converted to a NumPy array.
# The final expression displays its shape (presumably
# (n_states, n_channels, n_channels) — confirm).
covs = model.get_covariances()
covs_np = np.asarray(covs)
covs_np.shape

In [None]:
# Heatmap of the first state's covariance matrix. The original `plt.plot` on a
# 2-D matrix drew one line per column, which hides the matrix structure.
# The model was fit on the prepared (PCA) data, so axes are prepared channels.
fig, ax = plt.subplots()
im = ax.imshow(covs_np[0], cmap="viridis")
fig.colorbar(im, ax=ax, label="Covariance")
ax.set(
    title="Covariance of state 0",
    xlabel="Channel (after PCA)",
    ylabel="Channel (after PCA)",
)
plt.show()

In [None]:
# Group-level HMM parameters (left these for later)
# NOTE: this rebinds `covs` from the earlier covariance cell. `means` are not
# learned here (learn_means=False in the config).
# NOTE(review): confirm `get_initial_state_probs` exists in the installed
# osl-dynamics version; the getter name may differ between versions.
means, covs = model.get_means_covariances()
initial_state_probs = model.get_initial_state_probs()
trans_prob = model.get_trans_prob()