In [None]:
"""Example script for running HMM inference on simulated HMM-MVN data.

The script simulates a 5-state, 11-channel hidden Markov model with
multivariate normal observations (zero means, random covariances).
It then trains an osl-dynamics HMM on the standardized data.

Under ``results/`` it saves:
- the trained model and its training histories (``results/model``),
- the inferred state probabilities and group-level HMM parameters
  (``results/inf_params``),
- summary statistics of the inferred state time course
  (``results/summary_stats``).

Finally it prints the dice coefficient between the inferred and
simulated state time courses.

This script should take less than a couple of minutes to run and
achieve a dice coefficient of ~0.99.
"""

# Progress message printed before the third-party imports below
print("Importing packages")

import os
import pickle
import numpy as np

from osl_dynamics.simulation import HMM_MVN
from osl_dynamics.data import Data
from osl_dynamics.models.hmm import Config, Model
from osl_dynamics.inference import modes, metrics

# Root output directory; every artifact this script writes lives below it.
# exist_ok=True makes re-running the script safe.
results_dir = "results"
os.makedirs(name=results_dir, mode=0o777, exist_ok=True)

#%% Simulate data

print("Simulating data")

# Ground-truth generative model:
# - a 5-state HMM with a "sequence" transition structure (stay probability 0.9),
# - 11-channel multivariate normal observations with zero means and random
#   covariances.
sim_settings = dict(
    n_samples=25600,
    n_states=5,
    n_channels=11,
    trans_prob="sequence",
    stay_prob=0.9,
    means="zero",
    covariances="random",
)
sim = HMM_MVN(**sim_settings)

# Wrap the simulated time series in a Data object for training, then standardize it
data = Data(sim.time_series)
data.standardize()

#%% Build model

# HMM hyperparameters. The state and channel counts must agree with the
# simulation above. The simulation uses means="zero", so only the
# covariances are learned.
model_config = dict(
    n_states=5,
    n_channels=11,
    sequence_length=200,
    batch_size=16,
    learning_rate=0.01,
    n_epochs=20,
    learn_means=False,
    learn_covariances=True,
)
config = Config(**model_config)

model = Model(config)
model.summary()

#%% Train model

# Initialization: run several short trainings (n_epochs each) from random
# state time courses. The library then selects one of them to continue
# from (see "Using initialization N" in the log).
init_history = model.random_state_time_course_initialization(data, n_init=3, n_epochs=2)

# Full training
history = model.fit(data)

# Save model
model_dir = f"{results_dir}/model"
model.save(model_dir)

# Calculate the free energy and store it alongside the training history
free_energy = model.free_energy(data)
history["free_energy"] = free_energy
print("Free energy:", free_energy)

# Save training history and free energy. The context managers close the
# files deterministically, so the pickles are fully flushed to disk rather
# than depending on garbage collection of an anonymous file object.
with open(f"{model_dir}/init_history.pkl", "wb") as file:
    pickle.dump(init_history, file)
with open(f"{model_dir}/history.pkl", "wb") as file:
    pickle.dump(history, file)

#%% Get inferred parameters

# Inferred state probabilities
alp = model.get_alpha(data)

# Group-level HMM parameters
means, covs = model.get_means_covariances()
initial_state_probs = model.get_initial_state_probs()
trans_prob = model.get_trans_prob()

# Save
inf_params_dir = f"{results_dir}/inf_params"
os.makedirs(inf_params_dir, exist_ok=True)

# Close the pickle file explicitly rather than leaking the handle
with open(f"{inf_params_dir}/alp.pkl", "wb") as file:
    pickle.dump(alp, file)
np.save(f"{inf_params_dir}/means.npy", means)
np.save(f"{inf_params_dir}/covs.npy", covs)
np.save(f"{inf_params_dir}/initial_state_probs.npy", initial_state_probs)
np.save(f"{inf_params_dir}/trans_prob.npy", trans_prob)

#%% Calculate summary statistics

# Hard (argmax) state time course derived from the inferred state probabilities
stc = modes.argmax_time_courses(alp)

# Summary statistics of the state time course
fo = modes.fractional_occupancies(stc)
lt = modes.mean_lifetimes(stc)
intv = modes.mean_intervals(stc)
sr = modes.switching_rates(stc)

# Save each statistic to <summary_stats_dir>/<name>.npy
summary_stats_dir = f"{results_dir}/summary_stats"
os.makedirs(summary_stats_dir, exist_ok=True)

for stat_name, stat_values in (("fo", fo), ("lt", lt), ("intv", intv), ("sr", sr)):
    np.save(f"{summary_stats_dir}/{stat_name}.npy", stat_values)

#%% Compare inferred parameters to ground truth simulation

# The inferred time course may cover fewer samples than the simulation.
# Presumably this happens when n_samples is not a whole number of
# sequence_length sequences (assumption — confirm against osl-dynamics).
# Trim the simulated time course to the inferred length so both cover the
# same samples. With the current settings the lengths already agree, so
# this is a no-op.
n_inf_samples = stc.shape[0]
sim_stc_aligned = sim.state_time_course[:n_inf_samples]

# Re-order simulated state time courses to match inferred
inf_stc, sim_stc = modes.match_modes(stc, sim_stc_aligned)

# Calculate dice coefficient
dice = metrics.dice_coefficient(inf_stc, sim_stc)

print("Dice coefficient:", dice)

Importing packages
Simulating data


Loading files:   0%|          | 0/1 [00:00<?, ?it/s]

Standardize:   0%|          | 0/1 [00:00<?, ?it/s]

2026-01-23 13:13:28 INFO osl-dynamics [inf_mod_base.py:1435:random_state_time_course_initialization]: Random state time course initialization
2026-01-23 13:13:28 INFO osl-dynamics [inf_mod_base.py:1451:random_state_time_course_initialization]: Initialization 0
2026-01-23 13:13:28 INFO osl-dynamics [inf_mod_base.py:1491:set_random_state_time_course_initialization]: Setting random means and covariances


Epoch 1/2


2026-01-23 13:13:28.510158: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 24ms/step - ll_loss: 14.7726 - loss: 14.7726 - learning_rate: 0.0100 - rho: 0.2853
Epoch 2/2
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - ll_loss: 14.1195 - loss: 14.1195 - learning_rate: 0.0090 - rho: 0.1866


2026-01-23 13:13:31 INFO osl-dynamics [inf_mod_base.py:1451:random_state_time_course_initialization]: Initialization 1
2026-01-23 13:13:31 INFO osl-dynamics [inf_mod_base.py:1491:set_random_state_time_course_initialization]: Setting random means and covariances


Epoch 1/2


2026-01-23 13:13:31.940470: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 22ms/step - ll_loss: 14.7481 - loss: 14.7481 - learning_rate: 0.0100 - rho: 0.2853
Epoch 2/2
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - ll_loss: 13.9498 - loss: 13.9498 - learning_rate: 0.0090 - rho: 0.1866


2026-01-23 13:13:34 INFO osl-dynamics [inf_mod_base.py:1451:random_state_time_course_initialization]: Initialization 2
2026-01-23 13:13:34 INFO osl-dynamics [inf_mod_base.py:1491:set_random_state_time_course_initialization]: Setting random means and covariances


Epoch 1/2
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 24ms/step - ll_loss: 14.7725 - loss: 14.7725 - learning_rate: 0.0100 - rho: 0.2853
Epoch 2/2
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - ll_loss: 14.1351 - loss: 14.1351 - learning_rate: 0.0090 - rho: 0.1866


2026-01-23 13:13:37 INFO osl-dynamics [inf_mod_base.py:1477:random_state_time_course_initialization]: Using initialization 1


Epoch 1/20
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 25ms/step - ll_loss: 13.2794 - loss: 13.2794 - learning_rate: 0.0100 - rho: 0.2853
Epoch 2/20
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - ll_loss: 12.7227 - loss: 12.7227 - learning_rate: 0.0090 - rho: 0.1866
Epoch 3/20
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - ll_loss: 12.3052 - loss: 12.3052 - learning_rate: 0.0082 - rho: 0.1436
Epoch 4/20
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - ll_loss: 11.9788 - loss: 11.9788 - learning_rate: 0.0074 - rho: 0.1187
Epoch 5/20
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - ll_loss: 11.7141 - loss: 11.7141 - learning_rate: 0.0067 - rho: 0.1022
Epoch 6/20
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - ll_loss: 11.4932 - loss: 11.4932 - learning_rate: 0.0061 - rho: 0.0904
Epoch 7/20
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [

Getting free energy:   0%|          | 0/8 [00:00<?, ?it/s]

2026-01-23 13:13:45.527477: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2026-01-23 13:13:45 INFO osl-dynamics [inf_mod_base.py:1151:get_alpha]: Getting alpha


Free energy: 10.75791478552967
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 47ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 47ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 50ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 47ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 48ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step
Dice coefficient: 0.9986328125
