In [1]:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt


# Enable 64-bit (double precision) arrays in JAX. Later cells explicitly request
# jnp.float64, which JAX only honours with this flag on; set it at startup, before any
# arrays are created.
jax.config.update("jax_enable_x64", True)
# Show printed arrays with 3 decimal places.
jnp.set_printoptions(precision=3)

In [2]:
seed = 0
n = 4  # Number of hidden states
m = 3  # Number of observation symbols

# The initial distribution below gives the last state zero weight, so with a single
# state mu would be 0/0 = NaN.
assert n >= 2, "need at least two states: the last state has zero initial probability"

key = jax.random.key(seed)
# Split three ways so that O_key (and hence O) stays the same; the third key is unused.
key, O_key, _mu_key = jax.random.split(key, 3)

# Tridiagonal transition matrix. Before row-normalisation every state keeps weight 0.5 on
# itself, 0.3 on the next state and 0.1 on the previous one. Boundary rows lack one
# neighbour, so after normalisation e.g. row 0 is [0.625, 0.375, 0, ...] and the last
# row is [..., 0, 1/6, 5/6].
T = (
    0.5 * jnp.eye(n, dtype=jnp.float64)
    + 0.3 * jnp.eye(n, k=1)
    + 0.1 * jnp.eye(n, k=-1)
)
T = T / jnp.sum(T, axis=-1, keepdims=True)

# Initial state distribution: state s gets weight exp(-5 * s), so earlier states are far
# more likely; the last state is given zero initial probability.
mu = jnp.exp(-5 * jnp.arange(n, dtype=jnp.float64))
mu = mu.at[n - 1].set(0)
mu = mu / jnp.sum(mu)

# Noisy observation matrix:
#  - uniform noise scaled per row by linspace(0, 1, n), so state 0 gets no noise and
#    later states progressively more;
#  - state s gets a fixed weight of 0.5 on observation (s * m) // n;
#  - the last state is overwritten to emit the last observation with probability 1.
# Rows are normalised at the end.
O = jax.random.uniform(O_key, (n, m), dtype=jnp.float64)
O = O * jnp.linspace(0, 1, n)[:, None]
state_idx = jnp.arange(n)
O = O.at[state_idx, (state_idx * m) // n].set(0.5)  # one distinct cell per row
O = O.at[n - 1].set(0.0)
O = O.at[n - 1, m - 1].set(1.0)
O = O / jnp.sum(O, axis=-1, keepdims=True)

print("T row sums:", jnp.sum(T, axis=-1), 
      "\nO row sums:", jnp.sum(O, axis=-1), 
      "\nmu sum:", jnp.sum(mu)[None])

# Fail loudly if any of the distributions is not normalised.
assert jnp.allclose(jnp.sum(T, axis=-1), 1.0)
assert jnp.allclose(jnp.sum(O, axis=-1), 1.0)
assert jnp.allclose(jnp.sum(mu), 1.0)

T row sums: [1. 1. 1. 1.] 
O row sums: [1. 1. 1. 1.] 
mu sum: [1.]


In [3]:
from generation import generate_sequence
from inference import forward_backward, forward_backward_log


# Sample a hidden-state sequence and its observations from the HMM defined above.
# Use the `key` left over from the split in the setup cell instead of re-creating
# `jax.random.key(0)`: that is the same root key (seed = 0) which was already split
# there, and JAX PRNG keys must not be reused. This also makes `seed` control the sample.
seq_len = 100  # NOTE(review): assumes the last argument is the sequence length — confirm
states, observations = generate_sequence(key, T, O, mu, seq_len)


In [4]:
# Run the linear-space and log-space implementations on the same sequence.
gamma, xi = forward_backward(observations, T, O, mu)
gamma_log, xi_log = forward_backward_log(observations, T, O, mu)

# The log-space variant presumably returns log-probabilities (it is exponentiated before
# comparison), so its results must match the linear-space ones up to float tolerance.
gamma_match = jnp.allclose(jnp.exp(gamma_log), gamma)
xi_match = jnp.allclose(jnp.exp(xi_log), xi)

print("Results match for gamma:", gamma_match)
print("Results match for xi:", xi_match)

# Fail loudly under Restart & Run All / nbval instead of only printing False.
assert gamma_match, "forward_backward and forward_backward_log disagree on gamma"
assert xi_match, "forward_backward and forward_backward_log disagree on xi"

Results match for gamma: True
Results match for xi: True
