In [1]:
import abc
from dataclasses import dataclass
import functools
from functools import partial
import itertools
import matplotlib.pyplot as plt
import numpy as np
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
import jax.random as jr

import distrax
import optax

import jsl

import inspect
import inspect as py_inspect
import rich
from rich import inspect as r_inspect
from rich import print as r_print

In [2]:
# State transition matrix: A[i, j] = p(z_t = j | z_{t-1} = i).
# State 0 is the fair die, state 1 the loaded die.
A = np.array([
    [0.95, 0.05],
    [0.10, 0.90]
])

# Observation matrix: B[k, v] = p(x_t = v | z_t = k), faces 0..5.
B = np.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
])

# Initial state distribution p(z_1).
pi = np.array([0.5, 0.5])

(nstates, nobs) = np.shape(B)

# Sanity checks: shapes agree and every row is a valid probability distribution.
assert A.shape == (nstates, nstates), "A must be (nstates, nstates)"
assert pi.shape == (nstates,), "pi must have one entry per state"
assert np.all(A >= 0) and np.allclose(A.sum(axis=1), 1.0), "rows of A must be distributions"
assert np.all(B >= 0) and np.allclose(B.sum(axis=1), 1.0), "rows of B must be distributions"
assert np.all(pi >= 0) and np.isclose(pi.sum(), 1.0), "pi must be a distribution"

In [3]:
import distrax
from distrax import HMM


# Build the HMM: `init_dist` is p(z_1); `trans_dist` and `obs_dist` hold one
# categorical per row of A and B respectively.
hmm = HMM(
    init_dist=distrax.Categorical(probs=pi),
    trans_dist=distrax.Categorical(probs=A),
    obs_dist=distrax.Categorical(probs=B),
)

# HMM has no custom repr, so this shows only the class path and object address.
print(hmm)

I0000 00:00:1706202843.073165       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


<distrax._src.utils.hmm.HMM object at 0x7febc2760e20>


Sample a sequence of hidden states and observations from the model.

In [4]:
seed = 314
n_samples = 300
z_hist, x_hist = hmm.sample(seed=jr.PRNGKey(seed), seq_len=n_samples)

# Shift to 1-based labels (die faces 1-6, states 1-2) and keep only the
# first 60 characters so each preview fits on one line.
x_hist_str, z_hist_str = (
    "".join(map(str, np.asarray(seq) + 1))[:60] for seq in (x_hist, z_hist)
)

print(
    "Printing sample observed/latent...",
    f"x: {x_hist_str}",
    f"z: {z_hist_str}",
    sep="\n",
)

Printing sample observed/latent...
x: 633665342652353616444236412331351246651613325161656366246242
z: 222222211111111111111111111111111111111222111111112222211111


For reference, the `HMM.sample` implementation (shown for reading, not meant to be run as part of the analysis):

In [None]:
## DO NOT RUN ##

def sample(self,
           *,
           seed: "chex.PRNGKey",
           seq_len: int) -> "Tuple[chex.Array, chex.Array]":
    """Sample from this HMM.

    Samples an observation sequence of the given length according to this
    Hidden Markov Model and returns the sequence of hidden states as well as
    the observations.

    The `chex` annotations are written as strings: annotations are evaluated
    when the function is defined, and `chex` is not imported in this notebook,
    so bare names would raise NameError as soon as this cell runs.

    Args:
      self: Object exposing `_init_dist`, `_trans_dist` and `_obs_dist`,
        each with a `sample(seed=...)` method.
      seed: Random key of shape (2,) and dtype uint32.
      seq_len: The length of the observation sequence; must be >= 1.

    Returns:
      Tuple of hidden state sequence and observation sequence, each of
      length `seq_len`.

    Raises:
      ValueError: If `seq_len` is less than 1.
    """
    if seq_len < 1:
        # Without this, jax.random.split below gets a negative count and
        # fails with an unhelpful error.
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    rng_key, rng_init = jax.random.split(seed)
    initial_state = self._init_dist.sample(seed=rng_init)

    def draw_state(prev_state, key):
        # Presumably the transition distribution is batched over the current
        # state (e.g. Categorical(probs=A)), so one candidate next state is
        # drawn per row and the row of `prev_state` is kept.
        state = self._trans_dist.sample(seed=key)[prev_state]
        return state, state

    rng_state, rng_obs = jax.random.split(rng_key)
    # The first state comes from the initial distribution, so only
    # seq_len - 1 transitions are needed.
    keys = jax.random.split(rng_state, seq_len - 1)

    # scan -> https://ericmjl.github.io/dl-workshop/02-jax-idioms/02-loopy-carry.html
    _, states = jax.lax.scan(draw_state, initial_state, keys)
    states = jnp.append(initial_state, states)

    def draw_obs(state, key):
        # Same batching assumption as draw_state: keep the row for `state`.
        return self._obs_dist.sample(seed=key)[state]

    keys = jax.random.split(rng_obs, seq_len)
    # vmap -> https://ericmjl.github.io/dl-workshop/02-jax-idioms/01-loopless-loops.html
    obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(states, keys)

    return states, obs_seq