<a href="https://colab.research.google.com/github/peterchang0414/hmm-jax/blob/main/fixed_lag_smoother.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q flax

[K     |████████████████████████████████| 184 kB 6.6 MB/s 
[K     |████████████████████████████████| 140 kB 12.5 MB/s 
[K     |████████████████████████████████| 72 kB 604 kB/s 
[?25h

In [2]:
# Inference and learning code for Hidden Markov Models using discrete observations.
# Has Jax version of each function. For the Numpy version, please see hmm_numpy_lib.py
# The Jax version of inference (not learning)
# has been upstreamed to https://github.com/deepmind/distrax/blob/master/distrax/_src/utils/hmm.py.
# This version is kept for historical purposes.
# Author: Gerardo Duran-Martin (@gerdm), Aleyna Kara (@karalleyna), Kevin Murphy (@murphyk)

from jax import lax
from jax.scipy.special import logit
from functools import partial

import jax.numpy as jnp
from scipy.special import softmax
from jax import vmap
from dataclasses import dataclass

import jax
import itertools
from jax import jit
from jax.nn import softmax
from jax.random import PRNGKey, split, normal
from jax.random import split, randint, PRNGKey, normal, permutation

import flax

'''
Hidden Markov Model class used in jax implementations of inference algorithms.
The functions of optimizers expect that the type of its parameters 
is pytree. So, they cannot work on a vanilla dataclass. To see more:
                https://github.com/google/jax/issues/2371
Since the flax.dataclass is registered pytree beforehand, it facilitates to use
jit, vmap and optimizers on the hidden markov model.
'''


@flax.struct.dataclass
class HMMJax:
    trans_mat: jnp.array  # A : (n_states, n_states)
    obs_mat: jnp.array  # B : (n_states, n_obs)
    init_dist: jnp.array  # pi : (n_states)


def normalize(u, axis=0, eps=1e-15):
    '''
    Normalizes the values within the axis in a way that they sum up to 1.
    Parameters
    ----------
    u : array
    axis : int
    eps : float
        Threshold for the alpha values
    Returns
    -------
    * array
        Normalized version of the given matrix
    * array(seq_len, n_hidden) :
        The values of the normalizer
    '''
    u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))
    c = u.sum(axis=axis, keepdims=True)
    c = jnp.where(c == 0, 1, c)
    return u / c, c


##############################
# Inference

def hmm_forwards_jax(params, obs_seq, length=None):
    '''
    Calculates a belief state
    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model
    obs_seq: array(seq_len)
        History of observable events
    Returns
    -------
    * float
        The loglikelihood giving log(p(x|model))
    * array(seq_len, n_hidden) :
        All alpha values found for each sample
    '''
    seq_len = len(obs_seq)

    if length is None:
        length = seq_len

    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist

    trans_mat = jnp.array(trans_mat)
    obs_mat = jnp.array(obs_mat)
    init_dist = jnp.array(init_dist)

    n_states, n_obs = obs_mat.shape

    def scan_fn(carry, t):
        (alpha_prev, log_ll_prev) = carry
        alpha_n = jnp.where(t < length,
                            obs_mat[:, obs_seq[t]] * (alpha_prev[:, None] * trans_mat).sum(axis=0),
                            jnp.zeros_like(alpha_prev))

        alpha_n, cn = normalize(alpha_n)
        carry = (alpha_n, jnp.log(cn) + log_ll_prev)

        return carry, alpha_n

    # initial belief state
    alpha_0, c0 = normalize(init_dist * obs_mat[:, obs_seq[0]])

    # setup scan loop
    init_state = (alpha_0, jnp.log(c0))
    ts = jnp.arange(1, seq_len)
    carry, alpha_hist = lax.scan(scan_fn, init_state, ts)

    # post-process
    alpha_hist = jnp.vstack([alpha_0.reshape(1, n_states), alpha_hist])
    (alpha_final, log_ll) = carry
    return log_ll, alpha_hist


@jit
def hmm_loglikelihood_jax(params, observations, lens):
    '''
    Finds the loglikelihood of each observation sequence parallel using vmap.
    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model
    observations: array(N, seq_len)
        Batch of observation sequences
    lens : array(N, seq_len)
        Consists of the valid length of each observation sequence
    Returns
    -------
    * array(N, seq_len)
        Consists of the loglikelihood of each observation sequence
    '''

    def forward_(params, x, length):
        return hmm_forwards_jax(params, x, length)[0]

    return vmap(forward_, in_axes=(None, 0, 0))(params, observations, lens)


def hmm_backwards_jax(params, obs_seq, length=None):
    '''
    Computes the backwards probabilities
    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model
    obs_seq: array(seq_len,)
        History of observable events
    length : array(seq_len,)
        The valid length of the observation sequence
    Returns
    -------
    * array(seq_len, n_states)
       Beta values
    '''
    seq_len = len(obs_seq)

    if length is None:
        length = seq_len

    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist

    trans_mat = jnp.array(trans_mat)
    obs_mat = jnp.array(obs_mat)
    init_dist = jnp.array(init_dist)

    n_states, n_obs = obs_mat.shape

    beta_t = jnp.ones((n_states,))

    def scan_fn(beta_prev, t):
        beta_t = jnp.where(t > length,
                           jnp.zeros_like(beta_prev),
                           normalize((beta_prev * obs_mat[:, obs_seq[-t + 1]] * trans_mat).sum(axis=1))[0])
        return beta_t, beta_t

    ts = jnp.arange(2, seq_len + 1)
    _, beta_hist = lax.scan(scan_fn, beta_t, ts)

    beta_hist = jnp.flip(jnp.vstack([beta_t.reshape(1, n_states), beta_hist]), axis=0)

    return beta_hist


def hmm_forwards_backwards_jax(params, obs_seq, length=None):
    '''
    Computes, for each time step, the marginal conditional probability that the Hidden Markov Model was
    in each possible state given the observations that were made at each time step, i.e.
    P(z[i] | x[0], ..., x[num_steps - 1]) for all i from 0 to num_steps - 1
    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model
    obs_seq: array(seq_len)
        History of observed states
    Returns
    -------
    * array(seq_len, n_states)
        Alpha values
    * array(seq_len, n_states)
        Beta values
    * array(seq_len, n_states)
        Marginal conditional probability
    * float
        The loglikelihood giving log(p(x|model))
    '''
    seq_len = len(obs_seq)
    if length is None:
        length = seq_len

    def gamma_t(t):
        gamma_t = jnp.where(t < length,
                            alpha[t] * beta[t - length],
                            jnp.zeros((n_states,)))
        return gamma_t

    ll, alpha = hmm_forwards_jax(params, obs_seq, length)
    n_states = alpha.shape[1]

    beta = hmm_backwards_jax(params, obs_seq, length)

    ts = jnp.arange(seq_len)
    gamma = vmap(gamma_t, (0))(ts)
    # gamma = alpha * jnp.roll(beta, -seq_len + length, axis=0) #: Alternative
    gamma = vmap(lambda x: normalize(x)[0])(gamma)
    return alpha, beta, gamma, ll


In [None]:
key = PRNGKey(0)
num_states = 3
num_obs = 5
num_timesteps = 15

data = jax.random.choice(key, num_obs, (num_timesteps,))
key, _ = split(key)

transmat = jax.random.uniform(key, shape=(num_states, num_states))
transmat, _ = normalize(transmat, axis=1)
print(transmat)
key, _ = split(key)

obsmat = jax.random.uniform(key, shape=(num_states, num_obs))
obsmat, _ = normalize(obsmat, axis=1)
print(obsmat)
key, _ = split(key)

prior = jnp.array([0.33, 0.33, 0.34])



[[0.00327667 0.372277   0.62444633]
 [0.03723689 0.5034511  0.459312  ]
 [0.38789576 0.27844316 0.33366108]]
[[0.19659741 0.04710761 0.2822843  0.16964358 0.30436713]
 [0.26830336 0.08532328 0.22020769 0.18710755 0.23905815]
 [0.1802366  0.16051903 0.21636975 0.33137584 0.11149877]]


In [11]:
# Umbrella-rain toy example (AI, A Modern Approach)
data = jnp.array([0, 0, 1, 1, 1])
transmat = jnp.array([[0.7, 0.3], [0.3, 0.7]])
obsmat = jnp.array([[0.9, 0.1], [0.2, 0.8]])
prior = jnp.array([0.5, 0.5])

hmm = HMMJax(trans_mat=transmat, obs_mat=obsmat, init_dist=prior)

print(hmm_forwards_backwards_jax(hmm, data))

(DeviceArray([[0.81818175, 0.18181819],
             [0.88335705, 0.11664297],
             [0.19066793, 0.8093321 ],
             [0.07011891, 0.92988104],
             [0.05751521, 0.94248486]], dtype=float32), DeviceArray([[0.5727641 , 0.42723584],
             [0.32267347, 0.6773265 ],
             [0.32465208, 0.67534786],
             [0.34444445, 0.65555555],
             [1.        , 1.        ]], dtype=float32), DeviceArray([[0.8578096 , 0.14219043],
             [0.7829768 , 0.21702313],
             [0.10172988, 0.89827013],
             [0.03811033, 0.9618896 ],
             [0.0575152 , 0.94248474]], dtype=float32), DeviceArray([-3.300516], dtype=float32))


In [3]:
@flax.struct.dataclass
class HMMWithActionJax:
    trans_mat: jnp.array  # A : (n_actions, n_states, n_states)
    obs_mat: jnp.array  # B : (n_states, n_obs)
    init_dist: jnp.array  # pi : (n_states)

In [69]:
def fixed_lag_smoother(params, win_len, alpha_win, obs_seq_win, obs, act=None):
    '''
    Description...

    Parameters
    ----------
    params      : HMMWithActionJax
        Hidden Markov Model (with action-dependent transition)
    win_len     : int
        Desired window length (>= 2)
    alpha_win   : array
        Alpha values for the most recent win_len steps, excluding current step
    obs_seq_win : array
        Observations for the most recent win_len steps, excluding current step
    obs         : int
        New observation for the current step
    act         : array
        (optional) Actions for the most recent win_len steps, including current step
    Returns
    -------
    * array(win_len, n_states)
        Updated alpha values
    * array(win_len)
        Updated observations for the past d steps
    * array(win_len, n_states)
        Smoothed posteriors for the past d steps
    * float
        The loglikelihood of the past d steps
    '''
    curr_len = alpha_win.shape[1]
    win_len = min(win_len, curr_len+1)
    assert win_len < 2,"Must keep a window of length at least 2."

    trans_mat, obs_mat = params.trans_mat, params.obs_mat
    n_states, n_obs = obs_mat.shape
    
    # If trans_mat is independent of action, adjust shape
    if len(trans_mat.shape) < 3:
        trans_mat = jnp.expand_dims(trans_mat, axis=0)
        act = None
    if act is None:
        act = jnp.zeros(shape=(1, curr_len+1))

    # Shift window by 1
    if curr_len < win_len:
        alpha_win = alpha_win[1:]
        obs_seq_win = obs_seq_win[1:]
    new_alpha = obs_mat[:, obs] * (alpha_win[-1][:, None] * 
                                   trans_mat[act[-1]]).sum(axis=0)
    alpha_win = jnp.append(alpha_win, new_alpha)
    obs_seq_win = jnp.append(obs_seq_win, obs)

    # Smooth backwards inside the window
    beta = jnp.ones(n_states, win_len)


In [53]:
X = jnp.arange(1000)
window_len = 10
window = X[jnp.arange(window_len)]

In [60]:
X = jnp.array([[1, 1]])

In [68]:
X.shape

(1, 2)

In [67]:
X[:, None].shape

(1, 1, 2)

In [46]:
X.size

1000

In [56]:
%%time
X = jnp.arange(1000)
window_len = 10
window = X[jnp.arange(window_len)]

for i in range(X.size-window_len):
    window = window[1:]
    window = jnp.append(window, X[i+window_len])

CPU times: user 1.84 s, sys: 45.3 ms, total: 1.88 s
Wall time: 1.89 s


In [57]:
%%time
X = jnp.arange(1000)
window_len = 10
window = X[jnp.arange(window_len)]

for i in range(X.size-window_len):
    window = window.at[:-1].set(window[1:])
    window = window.at[-1].set(X[i+window_len])

CPU times: user 4.15 s, sys: 75.9 ms, total: 4.23 s
Wall time: 4.33 s


In [None]:
%%timeit
X = jnp.arange(1000)
window_len = 10
window = X[jnp.arange(window_len)]

for i in range(X.size-window_len):
    window = window[1:]
    window = jnp.append(window, X[i+window_len])

In [31]:
window

DeviceArray([], dtype=int32)

In [14]:
fixed_lag_smoother(hmm, 5, jnp.array([[1]]), jnp.array([[1]]), jnp.array([1]), 1)

AssertionError: ignored

In [4]:
trans_mat = jnp.array([[2,2],[2,2]])



In [8]:
jnp.expand_dims(trans_mat, axis=0).shape

(1, 2, 2)