<a href="https://colab.research.google.com/github/probml/ssm-jax/blob/main/ssm_jax/hmm/demos/casino_hmm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# "Occasionally dishonest Casino" HMM

Based on https://github.com/probml/JSL/blob/main/jsl/demos/hmm_casino.py


# Setup

In [1]:
try:
    import optax
except ModuleNotFoundError:
    %pip install -qq optax
    import optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting optax
  Downloading optax-0.1.2-py3-none-any.whl (140 kB)
[?25l[K     |██▎                             | 10 kB 25.6 MB/s eta 0:00:01[K     |████▋                           | 20 kB 29.6 MB/s eta 0:00:01[K     |███████                         | 30 kB 34.2 MB/s eta 0:00:01[K     |█████████▎                      | 40 kB 35.8 MB/s eta 0:00:01[K     |███████████▋                    | 51 kB 24.8 MB/s eta 0:00:01[K     |██████████████                  | 61 kB 27.1 MB/s eta 0:00:01[K     |████████████████▎               | 71 kB 27.7 MB/s eta 0:00:01[K     |██████████████████▋             | 81 kB 28.7 MB/s eta 0:00:01[K     |█████████████████████           | 92 kB 30.8 MB/s eta 0:00:01[K     |███████████████████████▎        | 102 kB 32.4 MB/s eta 0:00:01[K     |█████████████████████████▋      | 112 kB 32.4 MB/s eta 0:00:01[K     |████████████████████████████    |

In [2]:
try:
    import ssm_jax
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/ssm-jax.git
    import ssm_jax 

[?25l[K     |█▏                              | 10 kB 23.6 MB/s eta 0:00:01[K     |██▍                             | 20 kB 31.0 MB/s eta 0:00:01[K     |███▋                            | 30 kB 30.8 MB/s eta 0:00:01[K     |████▉                           | 40 kB 17.1 MB/s eta 0:00:01[K     |██████                          | 51 kB 14.4 MB/s eta 0:00:01[K     |███████▎                        | 61 kB 16.4 MB/s eta 0:00:01[K     |████████▍                       | 71 kB 17.2 MB/s eta 0:00:01[K     |█████████▋                      | 81 kB 17.2 MB/s eta 0:00:01[K     |██████████▉                     | 92 kB 18.7 MB/s eta 0:00:01[K     |████████████                    | 102 kB 19.2 MB/s eta 0:00:01[K     |█████████████▎                  | 112 kB 19.2 MB/s eta 0:00:01[K     |██████████████▌                 | 122 kB 19.2 MB/s eta 0:00:01[K     |███████████████▋                | 133 kB 19.2 MB/s eta 0:00:01[K     |████████████████▉               | 143 kB 19.2 MB/s eta 0:

In [3]:
import jax.numpy as jnp
import jax.random as jr
from jax import jit, value_and_grad

import matplotlib.pyplot as plt
import numpy as np

from ssm_jax.hmm.models import CategoricalHMM
import ssm_jax.hmm.learning as learning
import ssm_jax.hmm.hmm_plot_utils as plot_utils

In [4]:
# Silence WARNING:root:The use of `check_types` is deprecated and does not have any effect.
# https://github.com/tensorflow/probability/issues/1523
import logging


class CheckTypesFilter(logging.Filter):
    """Logging filter that rejects any record whose message mentions `check_types`."""

    def filter(self, record):
        """Return False (drop the record) when its formatted message contains "check_types"."""
        message = record.getMessage()
        if "check_types" in message:
            return False
        return True


# Attach the filter to the root logger, which is where the warning above is emitted.
logger = logging.getLogger()
logger.addFilter(CheckTypesFilter())

# Model

In [17]:
# Transition matrix between the two hidden states (each row sums to one).
A = jnp.array([[0.95, 0.05], [0.10, 0.90]])

# Emission probabilities over the six die faces, one row per hidden state.
fair_die = [1 / 6] * 6
loaded_die = [1 / 10] * 5 + [5 / 10]
B = jnp.array([fair_die, loaded_die])

# Both hidden states are equally likely at the first time step.
init_state_dist = jnp.ones(2) / 2

hmm = CategoricalHMM(init_state_dist, A, B)

print(hmm)
  


<ssm_jax.hmm.models.CategoricalHMM object at 0x7f1d1f570390>


# Generate samples

In [19]:
n_samples = 300
key = jr.PRNGKey(0)
z_hist, x_hist = hmm.sample(key, n_samples)


def _one_based_preview(values, max_chars=60):
    """Shift `values` by one, concatenate them as text, and keep the first `max_chars` characters."""
    shifted = np.array(values) + 1
    return "".join(shifted.astype(str))[:max_chars]


z_hist_str = _one_based_preview(z_hist)
x_hist_str = _one_based_preview(x_hist)

print("Printing sample observed/latent...")
print(f"obs: {x_hist_str}")
print(f"hid: {z_hist_str}")

Printing sample observed/latent...
obs: 135553452655336631635155152623211211346222126326426542234464
hid: 111111111122221111111111111111111111222222222122221111111111


# Inference

In [9]:
# Sanity check: display the shape of the sampled observation sequence.
x_hist.shape

(300,)

In [10]:
# Add a trailing singleton axis so the observations form a (num_timesteps, 1) array.
num_timesteps = x_hist.shape[0]
emissions = x_hist.reshape((num_timesteps, 1))
print(emissions.shape)

(300, 1)


In [11]:

# Run the smoother on the emissions and unpack the quantities used by the later cells.
# NOTE(review): the output saved with this cell is a TypeError, so this call did not
# complete in the recorded run. The cause is not visible here -- confirm the emissions
# shape/dtype and the `smoother` signature against the installed ssm_jax version.
posterior = hmm.smoother(emissions)
loglik =  posterior.marginal_log_lkhd
# `alpha` is plotted below under the title "Filtered".
alpha = posterior.filtered_probs
# `gamma` is plotted below under the title "Smoothed".
gamma = posterior.smoothed_probs

print(f"Loglikelihood: {loglik}")




TypeError: ignored

In [None]:
# Most likely hidden-state sequence for the emissions; plotted below under the title "Viterbi".
z_map = hmm.most_likely_states(emissions)
print(z_map)

# Plot results

In [None]:
def find_dishonest_intervals(z_hist):
    """
    Find the spans of timesteps during which the
    simulated system is in state 1
    Parameters
    ----------
    z_hist: array(n_samples)
        Result of running the system with two
        latent states (values 0 and 1)
    Returns
    -------
    list of (start, end) tuples, one per maximal run of state 1.
    Both ends are 0-based indices into z_hist and are inclusive.
    A run that is still open at the last timestep is included.
    """
    spans = []
    n_steps = len(z_hist)
    if n_steps == 0:
        return spans

    # Start index of the current run of state 1; 0 covers a sequence that begins in state 1.
    x_init = 0
    for t in range(n_steps - 1):
        if z_hist[t] == 1 and z_hist[t + 1] == 0:
            # Leaving state 1: the run ends at t.
            spans.append((x_init, t))
        elif z_hist[t] == 0 and z_hist[t + 1] == 1:
            # Entering state 1: a new run starts at t + 1.
            x_init = t + 1

    # A run that lasts until the final timestep has no 1 -> 0 transition to close it,
    # so it must be closed here or it would be dropped.
    if z_hist[-1] == 1:
        spans.append((x_init, n_steps - 1))
    return spans

In [None]:
def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False):
    """
    Draw a filtering/smoothing/MAP estimate of a hidden-state sequence on `ax`.
    Gray vertical bands mark the spans returned by `find_dishonest_intervals(z_hist)`,
    and the line shows `inference_values` against the observation number (starting at 1).
    Parameters
    ----------
    inference_values: array(n_samples, state_size), or array(n_samples) when map_estimate is True
        Values to plot
    z_hist: array(n_samples)
        Latent simulation used to compute the shaded spans
    ax: matplotlib.axes
        Axes to draw on
    state: int
        Column of `inference_values` to plot (ignored when map_estimate is True)
    map_estimate: bool
        If True, draw `inference_values` as a step plot; otherwise draw
        the selected column as a regular line plot
    """
    num_steps = len(inference_values)
    timesteps = np.arange(1, num_steps + 1)
    shaded_spans = find_dishonest_intervals(z_hist)

    if not map_estimate:
        ax.plot(timesteps, inference_values[:, state])
    else:
        ax.step(timesteps, inference_values, where="post")

    # NOTE(review): spans are passed to axvspan exactly as returned, while the x-axis
    # starts at 1 -- confirm that the two indexings are meant to line up.
    for start, stop in shaded_spans:
        ax.axvspan(start, stop, alpha=0.5, facecolor="tab:gray", edgecolor="none")

    ax.set_xlim(1, num_steps)
    # Small margin around [0, 1] so lines at the extremes stay visible.
    ax.set_ylim(-0.1, 1.1)
    ax.set_xlabel("Observation number")

In [None]:

dict_figures = {}

# Plot results: one figure per inference result.
# Each entry is (values, y-axis label, title, key in dict_figures, draw as MAP step plot).
panels = [
    (alpha, "p(loaded)", "Filtered", "hmm_casino_filter", False),
    (gamma, "p(loaded)", "Smoothed", "hmm_casino_smooth", False),
    (z_map, "MAP state", "Viterbi", "hmm_casino_map", True),
]

for values, ylabel, title, fig_name, is_map in panels:
    fig, ax = plt.subplots()
    plot_inference(values, z_hist, ax, map_estimate=is_map)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    dict_figures[fig_name] = fig

plt.show()

# Draw model

In [None]:
def savedotfile(dotfiles):
    """
    Render each entry of `dotfiles` into the directory named by the
    FIGDIR environment variable. Does nothing when FIGDIR is not set.
    Parameters
    ----------
    dotfiles: dict
        Maps an output file name to an object exposing `render(path)`
        (presumably a graphviz graph -- confirm against the caller)
    """
    # Imported locally so the function does not depend on a notebook-level `import os`.
    import os

    figdir = os.environ.get("FIGDIR")
    if figdir is None:
        return
    for name, dot in dotfiles.items():
        fname_full = os.path.join(figdir, name)
        dot.render(fname_full)
        print(f"saving dot file to {fname_full}")

