<a href="https://colab.research.google.com/github/peterchang0414/hmm-jax/blob/main/fixed_lag_smoother.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fixed Lag Smoother - Unit Tests

This notebook demonstrates the following: 

1. the correctness of a JAX implementation of the fixed-lag smoother, checked by comparing its full-lag smoothed posterior against the output of JSL's `hmm_forwards_backwards_jax`;

2. the improved "online" performance of a version that computes the $\beta$ values across its sliding window in a vectorized fashion, compared with a version that smooths backwards through the window iteratively, as implemented in Kevin Murphy's HMM Toolbox for Matlab.

The JAX State-Space Models Library (JSL) is available at:
https://github.com/probml/JSL

Kevin Murphy's Hidden Markov Model (HMM) Toolbox for Matlab is available at:
https://www.cs.ubc.ca/~murphyk/Software/HMM/hmm.html

Author: Peter G. Chang ([@peterchang0414](https://github.com/peterchang0414))

# 0. Imports

In [1]:
!pip install -q flax

In [2]:
!git clone https://github.com/probml/JSL.git

Cloning into 'JSL'...
remote: Enumerating objects: 1781, done.
remote: Counting objects: 100% (1726/1726), done.
remote: Compressing objects: 100% (1138/1138), done.
remote: Total 1781 (delta 1126), reused 1124 (delta 567), pack-reused 55
Receiving objects: 100% (1781/1781), 5.78 MiB | 30.66 MiB/s, done.
Resolving deltas: 100% (1145/1145), done.


In [4]:
import sys
sys.path.insert(0,'/content/JSL')
from jsl.hmm.hmm_lib import HMMJax, hmm_forwards_backwards_jax

In [6]:
from functools import partial

import jax  # needed for the jax.jit and jax.random.choice calls below
import jax.numpy as jnp
from jax import vmap
from jax import jit
from jax.random import PRNGKey, split
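
The cells below call a `normalize` helper that none of the cells shown here import or define. A minimal sketch of such a helper follows, assuming it returns the normalized array together with the normalizing constant and accepts an `axis` argument. Both properties are inferred from the call sites in this notebook, not taken from JSL's source, so treat the signature as hypothetical.

```python
import jax.numpy as jnp

def normalize(u, axis=0, eps=1e-15):
    """Scale `u` so that it sums to 1 along `axis`; return (normalized, constant).

    Hypothetical stand-in inferred from how this notebook calls `normalize`.
    """
    c = u.sum(axis=axis, keepdims=True)
    # Avoid dividing by zero when a slice is (numerically) all zeros
    c = jnp.where(c < eps, 1.0, c)
    return u / c, jnp.squeeze(c, axis=axis)

probs, const = normalize(jnp.array([0.45, 0.1]))
rows, _ = normalize(jnp.array([[1.0, 3.0], [2.0, 2.0]]), axis=1)
```

Keeping the summed axis with `keepdims=True` lets the division broadcast correctly for row-wise normalization (`axis=1`) of non-square arrays.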

In [4]:
# Naive (un-vectorized) version
@partial(jit, static_argnums=(1,))
def fixed_lag_smoother(params, win_len, alpha_win, obs_seq_win, obs, act=None):
    '''
    Advance the fixed-lag smoother by one step, recomputing the beta values
    with an explicit backward pass through the window.

    Parameters
    ----------
    params      : HMMJax
        Hidden Markov Model (with action-dependent transition)
    win_len     : int
        Desired window length (>= 2)
    alpha_win   : array
        Alpha values for the most recent win_len steps, excluding current step
    obs_seq_win : array
        Observations for the most recent win_len steps, excluding current step
    obs         : int
        New observation for the current step
    act         : array
        (optional) Actions for the most recent win_len steps, including current step

    Returns
    -------
    * array(win_len, n_states)
        Updated alpha values
    * array(win_len)
        Updated observations for the past win_len steps
    * array(win_len, n_states)
        Beta values for the past win_len steps
    * array(win_len, n_states)
        Smoothed posteriors for the past win_len steps
    '''
    if len(alpha_win.shape) < 2:
        alpha_win = jnp.expand_dims(alpha_win, axis=0)
    curr_len = alpha_win.shape[0]
    win_len = min(win_len, curr_len+1)
    assert win_len >= 2, "Must keep a window of length at least 2."

    trans_mat, obs_mat = params.trans_mat, params.obs_mat
    n_states, n_obs = obs_mat.shape
    
    # If trans_mat is independent of action, adjust shape
    if len(trans_mat.shape) < 3:
        trans_mat = jnp.expand_dims(trans_mat, axis=0)
        act = None
    if act is None:
        act = jnp.zeros(shape=(curr_len+1,), dtype=jnp.int8)

    # Shift window forward by 1
    if curr_len == win_len:
        alpha_win = alpha_win[1:]
        obs_seq_win = obs_seq_win[1:]
    new_alpha, _ = normalize(
        obs_mat[:, obs] * (alpha_win[-1][:, None] * trans_mat[act[-1]]).sum(axis=0)
    )
    alpha_win = jnp.concatenate((alpha_win, new_alpha[None, :]), axis=0)
    obs_seq_win = jnp.append(obs_seq_win, obs)

    # Smooth backwards inside the window
    beta_win = jnp.ones(shape=(win_len, n_states))
    gamma_win = jnp.array(alpha_win)
    for t in range(win_len-2, -1, -1):
        new_beta, _ = normalize(
            (beta_win[t+1,:] * obs_mat[:, obs_seq_win[t+1]] *
             trans_mat[act[t]]).sum(axis=1)
        )
        beta_win = beta_win.at[t, :].set(new_beta)

        new_gamma, _ = normalize(alpha_win[t, :]*beta_win[t, :])
        gamma_win = gamma_win.at[t, :].set(new_gamma)
    return alpha_win, obs_seq_win, beta_win, gamma_win

In [5]:
# Vectorized version
@partial(jit, static_argnums=(1,))
def fixed_lag_smoother_vectorized(params, win_len, alpha_win, bmatrix_win, obs, act=None):
    '''
    Advance the fixed-lag smoother by one step, updating every beta
    transformation in the window in parallel instead of looping backwards.

    Parameters
    ----------
    params      : HMMJax
        Hidden Markov Model (with action-dependent transition)
    win_len     : int
        Desired window length (>= 2)
    alpha_win   : array
        Alpha values for the most recent win_len steps, excluding current step
    bmatrix_win : array
        Beta transformations for the most recent win_len steps, excluding current step
    obs         : int
        New observation for the current step
    act         : array
        (optional) Actions for the most recent win_len steps, including current step

    Returns
    -------
    * array(win_len, n_states)
        Updated alpha values
    * array(win_len, n_states, n_states)
        Updated beta transformations
    * array(win_len, n_states)
        Smoothed posteriors for the past win_len steps
    '''
    if len(alpha_win.shape) < 2:
        alpha_win = jnp.expand_dims(alpha_win, axis=0)
    curr_len = alpha_win.shape[0]
    win_len = min(win_len, curr_len+1)
    assert win_len >= 2, "Must keep a window of length at least 2."

    trans_mat, obs_mat = params.trans_mat, params.obs_mat
    n_states, n_obs = obs_mat.shape
    
    # If trans_mat is independent of action, adjust shape
    if len(trans_mat.shape) < 3:
        trans_mat = jnp.expand_dims(trans_mat, axis=0)
        act = None
    if act is None:
        act = jnp.zeros(shape=(curr_len+1,), dtype=jnp.int8)

    # Shift window forward by 1
    if curr_len == win_len:
        alpha_win = alpha_win[1:]
        bmatrix_win = bmatrix_win[1:]
    # Perform one forward operation
    new_alpha, _ = normalize(
        obs_mat[:, obs] * (alpha_win[-1][:, None] * trans_mat[act[-1]]).sum(axis=0)
    )
    alpha_win = jnp.concatenate((alpha_win, new_alpha[None, :]))
    # Smooth inside the window in parallel
    def update_bmatrix(bmatrix):
        return (bmatrix @ trans_mat[act[-2]]) * obs_mat[:, obs]
    bmatrix_win = vmap(update_bmatrix)(bmatrix_win)
    bmatrix_win = jnp.concatenate((bmatrix_win, jnp.eye(n_states)[None, :]))
    # Compute beta values by row-summing bmatrices
    def get_beta(bmatrix):
        return normalize(bmatrix.sum(axis=1))[0]
    beta_win = vmap(get_beta)(bmatrix_win)
    gamma_win, _ = normalize(alpha_win * beta_win, axis=1)

    return alpha_win, bmatrix_win, gamma_win
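
The vectorized update relies on the fact that each backward message can be written as the row sums of a product of matrices, so extending the window by one observation is a single right-multiplication applied to every stored matrix. The sketch below checks that claim on the umbrella numbers used in this notebook. The variable names (`A`, `B`, `bmats`) are illustrative only, and normalization is left out so the two results can be compared directly.

```python
import jax.numpy as jnp

# Toy two-state HMM (same numbers as the umbrella example in this notebook)
A = jnp.array([[0.7, 0.3], [0.3, 0.7]])   # A[i, j] = p(z_{t+1}=j | z_t=i)
B = jnp.array([[0.9, 0.1], [0.2, 0.8]])   # B[j, o] = p(x=o | z=j)
obs = jnp.array([0, 0, 1, 1, 1])
T, K = obs.shape[0], A.shape[0]

# Reference: unnormalized backward recursion, one step at a time
beta_iter = [jnp.ones(K)]
for t in range(T - 2, -1, -1):
    beta_iter.insert(0, A @ (B[:, obs[t + 1]] * beta_iter[0]))
beta_iter = jnp.stack(beta_iter)

# Online: keep one K x K matrix per time step and extend all of them at once
bmats = jnp.eye(K)[None]                   # window initially holds only step 0
for o in obs[1:]:
    bmats = (bmats @ A) * B[:, o]          # right-multiply by A @ diag(B[:, o])
    bmats = jnp.concatenate((bmats, jnp.eye(K)[None]))
beta_mat = bmats.sum(axis=2)               # row sums recover beta_t

max_err = float(jnp.abs(beta_iter - beta_mat).max())
print(max_err)
```

The batched `bmats @ A` plays the role of the `vmap(update_bmatrix)` call in the function above; both apply the same update to every matrix in the window.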

In [6]:
# Umbrella-rain toy example (AI, A Modern Approach)
data = jnp.array([0, 0, 1, 1, 1])
transmat = jnp.array([[0.7, 0.3], [0.3, 0.7]])
obsmat = jnp.array([[0.9, 0.1], [0.2, 0.8]])
prior = jnp.array([0.5, 0.5])

hmm = HMMJax(trans_mat=transmat, obs_mat=obsmat, init_dist=prior)
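
As a quick sanity check on these parameters, the forward (filtering) step used inside both smoothers can be run by hand for the first two observations. This is a sketch under the assumption that state 0 stands for "rain" and observation 0 for "umbrella"; the notebook does not state that labelling. The helper `forward_step` is illustrative and not part of JSL.

```python
import jax.numpy as jnp

transmat = jnp.array([[0.7, 0.3], [0.3, 0.7]])
obsmat = jnp.array([[0.9, 0.1], [0.2, 0.8]])
prior = jnp.array([0.5, 0.5])

def forward_step(alpha, obs):
    """One predict-update step: propagate through transmat, then weight by likelihood."""
    unnorm = obsmat[:, obs] * (transmat.T @ alpha)
    return unnorm / unnorm.sum()

# Step 1: condition the prior on the first observation (index 0)
alpha1 = prior * obsmat[:, 0]
alpha1 = alpha1 / alpha1.sum()
# Step 2: the same observation again
alpha2 = forward_step(alpha1, 0)
print(alpha1, alpha2)   # roughly [0.818 0.182] and [0.883 0.117]
```

Here `transmat.T @ alpha` is the same quantity as `(alpha[:, None] * trans_mat).sum(axis=0)` in the smoother code.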



In [16]:
# Larger-scale example for time experiments
key = PRNGKey(0)
data = jax.random.choice(key, 2, shape=(100,))

In [17]:
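def get_fixed_lag_smoother_result(params, win_len, data, prior, act=None):
    # Note: `act` is accepted but not forwarded to the smoother
    assert data.size >= 2, "Complete observation set must be of size at least 2"
    # Initial filtered belief from the prior and the first observation
    alpha, _ = normalize(jnp.multiply(prior, params.obs_mat[:, data[0]]))
    obs_seq = jnp.array([data[0]])
    # Feed the remaining observations to the smoother one at a time
    for obs in data[1:]:
        alpha, obs_seq, beta, gamma = fixed_lag_smoother(params, win_len, alpha, obs_seq, obs)
    return alpha, beta, gamma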

In [18]:
def get_fixed_lag_smoother_result_vectorized(params, win_len, data, prior, act=None):
    # Note: `act` is accepted but not forwarded to the smoother
    assert data.size >= 2, "Complete observation set must be of size at least 2"
    trans_mat, obs_mat = params.trans_mat, params.obs_mat
    n_states, n_obs = obs_mat.shape
    # Initial filtered belief from the prior and the first observation
    alpha, _ = normalize(jnp.multiply(prior, obs_mat[:, data[0]]))
    # The beta transformation for a one-step window is the identity
    bmatrix = jnp.eye(n_states)[None, :]
    for obs in data[1:]:
        alpha, bmatrix, gamma = fixed_lag_smoother_vectorized(params, win_len, alpha, bmatrix, obs)
    return alpha, gamma

In [23]:
%%time
*_, gamma = get_fixed_lag_smoother_result(hmm, 20, data, prior)

CPU times: user 21.8 s, sys: 162 ms, total: 21.9 s
Wall time: 21.9 s


In [25]:
%%time
*_, gamma_vec = get_fixed_lag_smoother_result_vectorized(hmm, 20, data, prior)

CPU times: user 3.72 s, sys: 39 ms, total: 3.75 s
Wall time: 3.72 s
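
The `%%time` figures above include any tracing and compilation that `jit` performs during the call, and JAX dispatches work asynchronously. A sketch of one way to separate compilation cost from steady-state cost follows. The function `row_normalize` is a hypothetical stand-in for a smoother step, not code from this notebook.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def row_normalize(x):
    # Hypothetical workload standing in for one smoother step
    return x / x.sum(axis=1, keepdims=True)

x = jnp.ones((200, 200))

t0 = time.perf_counter()
first = row_normalize(x).block_until_ready()    # includes tracing and compilation
t_first = time.perf_counter() - t0

t0 = time.perf_counter()
second = row_normalize(x).block_until_ready()   # reuses the compiled function
t_second = time.perf_counter() - t0

print(f"first call: {t_first:.4f}s, second call: {t_second:.4f}s")
```

Because `jit` retraces whenever an input shape changes, the steps during which the window is still growing may be dominated by compilation rather than by the smoothing arithmetic itself.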


In [177]:
_, _, g, _ = hmm_forwards_backwards_jax(hmm, data)
print(g)

[[0.6852272  0.3147728 ]
 [0.104872   0.8951281 ]
 [0.08992866 0.9100713 ]
 [0.51482636 0.48517358]
 [0.08724132 0.9127587 ]
 [0.09187889 0.9081211 ]
 [0.56770897 0.43229103]
 [0.24101867 0.7589813 ]
 [0.80832404 0.19167593]
 [0.8642606  0.13573939]]


In [None]:
key = PRNGKey(0)
num_states = 3
num_obs = 5
num_timesteps = 15

data = jax.random.choice(key, num_obs, (num_timesteps,))
key, _ = split(key)

transmat = jax.random.uniform(key, shape=(num_states, num_states))
transmat, _ = normalize(transmat, axis=1)
print(transmat)
key, _ = split(key)

obsmat = jax.random.uniform(key, shape=(num_states, num_obs))
obsmat, _ = normalize(obsmat, axis=1)
print(obsmat)
key, _ = split(key)

prior = jnp.array([0.33, 0.33, 0.34])