In [1]:
import jax
import jax.numpy as jnp

from baum_welch_jax import (
    HiddenMarkovParameters, 
    baum_welch,
    generate_sequence,
)

# Enable 64-bit floats before any arrays are created. Recommended by
# baum_welch_jax; presumably for numerical precision in the estimation
# recursions (float64 is generally *slower*, not faster, on accelerators).
jax.config.update('jax_enable_x64', True)

# Knobs for the synthetic experiment (change here, not inline).
SEED = 0               # PRNG seed -> reproducible sampled sequence
SEQUENCE_LENGTH = 100  # number of steps to sample from the true model

# Defining a hidden Markov model (2 hidden states, 2 observation symbols).
# Rows of T and O each sum to 1, i.e. each row is a distribution:
#   T[i, j] - probability of moving from state i to state j
#   O[i, k] - probability of emitting symbol k while in state i
#   mu[i]   - probability of starting in state i
T = jnp.array([[0.6, 0.4], [0.1, 0.9]])
O = jnp.array([[0.7, 0.3], [0.0, 1.0]])
mu = jnp.array([1.0, 0.0])

# Guard against typos in the hand-written parameters.
assert jnp.allclose(T.sum(axis=1), 1.0), 'rows of T must sum to 1'
assert jnp.allclose(O.sum(axis=1), 1.0), 'rows of O must sum to 1'
assert jnp.isclose(mu.sum(), 1.0), 'mu must sum to 1'

hmm = HiddenMarkovParameters(T, O, mu)

# Generating a sequence (hidden states and the observations they emit)
states, observations = generate_sequence(
    jax.random.key(SEED), hmm, length=SEQUENCE_LENGTH)

# Estimating model parameters.
# The starting point is deliberately slightly off-uniform: if every state
# had identical rows, the Baum-Welch update would keep them identical and
# the states could never separate.
initial_guess = HiddenMarkovParameters(
    jnp.array([[0.51, 0.49], [0.49, 0.51]]),
    jnp.array([[0.51, 0.49], [0.49, 0.51]]),
    jnp.ones_like(mu) / 2)

# Run Baum-Welch until convergence
estimation_result = baum_welch(observations, initial_guess)

print('Iterations:', estimation_result.iterations)
# NOTE(review): `to_prob()` suggests the fitted params are kept in another
# parameterisation (presumably log-space); it converts them back to plain
# probabilities for display -- confirm against baum_welch_jax docs.
print(estimation_result.params.to_prob())

Iterations: 85
T =
[[0.64785338 0.35214662]
 [0.09239689 0.90760311]]

O =
[[6.27292203e-01 3.72707797e-01]
 [7.41241657e-07 9.99999259e-01]]

mu=
[1.00000000e+000 8.74296522e-234]
