<a href="https://colab.research.google.com/github/peterchang0414/hmm-jax/blob/main/fixed_lag_smoother.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install -q flax

In [37]:
# Inference and learning code for Hidden Markov Models using discrete observations.
# Has Jax version of each function. For the Numpy version, please see hmm_numpy_lib.py
# The Jax version of inference (not learning)
# has been upstreamed to https://github.com/deepmind/distrax/blob/master/distrax/_src/utils/hmm.py.
# This version is kept for historical purposes.
# Author: Gerardo Duran-Martin (@gerdm), Aleyna Kara (@karalleyna), Kevin Murphy (@murphyk)

from jax import lax
from jax.scipy.special import logit
from functools import partial

import jax.numpy as jnp
from scipy.special import softmax
from jax import vmap
from dataclasses import dataclass

import jax
import itertools
from jax import jit
from jax.nn import softmax
from jax.random import PRNGKey, split, normal
from jax.random import split, randint, PRNGKey, normal, permutation

import flax

'''
Hidden Markov Model class used in jax implementations of inference algorithms.
The functions of optimizers expect that the type of its parameters 
is pytree. So, they cannot work on a vanilla dataclass. To see more:
                https://github.com/google/jax/issues/2371
Since the flax.dataclass is registered pytree beforehand, it facilitates to use
jit, vmap and optimizers on the hidden markov model.
'''


@flax.struct.dataclass
class HMMJax:
    trans_mat: jnp.array  # A : (n_states, n_states)
    obs_mat: jnp.array  # B : (n_states, n_obs)
    init_dist: jnp.array  # pi : (n_states)


def normalize(u, axis=0, eps=1e-15):
    '''
    Normalizes the values within the axis in a way that they sum up to 1.
    Parameters
    ----------
    u : array
    axis : int
    eps : float
        Threshold for the alpha values
    Returns
    -------
    * array
        Normalized version of the given matrix
    * array(seq_len, n_hidden) :
        The values of the normalizer
    '''
    u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))
    c = u.sum(axis=axis, keepdims=True)
    c = jnp.where(c == 0, 1, c)
    return u / c, c


##############################
# Inference

def hmm_forwards_jax(params, obs_seq, length=None):
    '''
    Calculates a belief state
    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model
    obs_seq: array(seq_len)
        History of observable events
    Returns
    -------
    * float
        The loglikelihood giving log(p(x|model))
    * array(seq_len, n_hidden) :
        All alpha values found for each sample
    '''
    seq_len = len(obs_seq)

    if length is None:
        length = seq_len

    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist

    trans_mat = jnp.array(trans_mat)
    obs_mat = jnp.array(obs_mat)
    init_dist = jnp.array(init_dist)

    n_states, n_obs = obs_mat.shape

    def scan_fn(carry, t):
        (alpha_prev, log_ll_prev) = carry
        alpha_n = jnp.where(t < length,
                            obs_mat[:, obs_seq[t]] * (alpha_prev[:, None] * trans_mat).sum(axis=0),
                            jnp.zeros_like(alpha_prev))

        alpha_n, cn = normalize(alpha_n)
        carry = (alpha_n, jnp.log(cn) + log_ll_prev)

        return carry, alpha_n

    # initial belief state
    alpha_0, c0 = normalize(init_dist * obs_mat[:, obs_seq[0]])
    print("before normalize: ", init_dist * obs_mat[:, obs_seq[0]])
    print("after normalize: ", alpha_0)

    # setup scan loop
    init_state = (alpha_0, jnp.log(c0))
    ts = jnp.arange(1, seq_len)
    carry, alpha_hist = lax.scan(scan_fn, init_state, ts)

    # post-process
    alpha_hist = jnp.vstack([alpha_0.reshape(1, n_states), alpha_hist])
    (alpha_final, log_ll) = carry
    return log_ll, alpha_hist


@jit
def hmm_loglikelihood_jax(params, observations, lens):
    '''
    Finds the loglikelihood of each observation sequence parallel using vmap.
    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model
    observations: array(N, seq_len)
        Batch of observation sequences
    lens : array(N, seq_len)
        Consists of the valid length of each observation sequence
    Returns
    -------
    * array(N, seq_len)
        Consists of the loglikelihood of each observation sequence
    '''

    def forward_(params, x, length):
        return hmm_forwards_jax(params, x, length)[0]

    return vmap(forward_, in_axes=(None, 0, 0))(params, observations, lens)


@jit
def hmm_backwards_jax(params, obs_seq, length=None):
    '''
    Computes the backwards probabilities
    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model
    obs_seq: array(seq_len,)
        History of observable events
    length : array(seq_len,)
        The valid length of the observation sequence
    Returns
    -------
    * array(seq_len, n_states)
       Beta values
    '''
    seq_len = len(obs_seq)

    if length is None:
        length = seq_len

    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist

    trans_mat = jnp.array(trans_mat)
    obs_mat = jnp.array(obs_mat)
    init_dist = jnp.array(init_dist)

    n_states, n_obs = obs_mat.shape

    beta_t = jnp.ones((n_states,))

    def scan_fn(beta_prev, t):
        beta_t = jnp.where(t > length,
                           jnp.zeros_like(beta_prev),
                           normalize((beta_prev * obs_mat[:, obs_seq[-t + 1]] * trans_mat).sum(axis=1))[0])
        return beta_t, beta_t

    ts = jnp.arange(2, seq_len + 1)
    _, beta_hist = lax.scan(scan_fn, beta_t, ts)

    beta_hist = jnp.flip(jnp.vstack([beta_t.reshape(1, n_states), beta_hist]), axis=0)

    return beta_hist


@jit
def hmm_forwards_backwards_jax(params, obs_seq, length=None):
    '''
    Computes, for each time step, the marginal conditional probability that the Hidden Markov Model was
    in each possible state given the observations that were made at each time step, i.e.
    P(z[i] | x[0], ..., x[num_steps - 1]) for all i from 0 to num_steps - 1
    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model
    obs_seq: array(seq_len)
        History of observed states
    Returns
    -------
    * array(seq_len, n_states)
        Alpha values
    * array(seq_len, n_states)
        Beta values
    * array(seq_len, n_states)
        Marginal conditional probability
    * float
        The loglikelihood giving log(p(x|model))
    '''
    seq_len = len(obs_seq)
    if length is None:
        length = seq_len

    def gamma_t(t):
        gamma_t = jnp.where(t < length,
                            alpha[t] * beta[t - length],
                            jnp.zeros((n_states,)))
        return gamma_t

    ll, alpha = hmm_forwards_jax(params, obs_seq, length)
    n_states = alpha.shape[1]

    beta = hmm_backwards_jax(params, obs_seq, length)

    ts = jnp.arange(seq_len)
    gamma = vmap(gamma_t, (0))(ts)
    # gamma = alpha * jnp.roll(beta, -seq_len + length, axis=0) #: Alternative
    gamma = vmap(lambda x: normalize(x)[0])(gamma)
    return alpha, beta, gamma, ll


In [38]:
key = PRNGKey(0)
num_states = 3
num_obs = 5
num_timesteps = 15

data = jax.random.choice(key, num_obs, (num_timesteps,))
key, _ = split(key)

transmat = jax.random.uniform(key, shape=(num_states, num_states))
transmat, _ = normalize(transmat, axis=1)
print(transmat)
key, _ = split(key)

obsmat = jax.random.uniform(key, shape=(num_states, num_obs))
obsmat, _ = normalize(obsmat, axis=1)
print(obsmat)
key, _ = split(key)

prior = jnp.array([0.33, 0.33, 0.34])

[[0.00327667 0.372277   0.62444633]
 [0.03723689 0.5034511  0.459312  ]
 [0.38789576 0.27844316 0.33366108]]
[[0.19659741 0.04710761 0.2822843  0.16964358 0.30436713]
 [0.26830336 0.08532328 0.22020769 0.18710755 0.23905815]
 [0.1802366  0.16051903 0.21636975 0.33137584 0.11149877]]


In [39]:
hmm = HMMJax(trans_mat=transmat, obs_mat=obsmat, init_dist=prior)

In [40]:
print(hmm_forwards_jax(hmm, data))
print(hmm_forwards_backwards_jax(hmm, data))

before normalize:  [0.10044116 0.07888919 0.03790958]
after normalize:  [0.46235126 0.36314315 0.17450559]
(DeviceArray([-26.463226], dtype=float32), DeviceArray([[0.46235126, 0.36314315, 0.17450559],
             [0.03226263, 0.28504366, 0.68269366],
             [0.1256727 , 0.2855128 , 0.5888145 ],
             [0.2185694 , 0.44159764, 0.33983293],
             [0.23735477, 0.4984296 , 0.26421568],
             [0.08197626, 0.30641595, 0.61160785],
             [0.11105373, 0.28693265, 0.60201365],
             [0.22325665, 0.4403048 , 0.33643854],
             [0.23550224, 0.49924642, 0.26525134],
             [0.19819315, 0.5257243 , 0.27608263],
             [0.08609609, 0.30977526, 0.6041286 ],
             [0.17386994, 0.27751768, 0.5486124 ],
             [0.09731788, 0.28144982, 0.6212323 ],
             [0.11233708, 0.28358203, 0.6040809 ],
             [0.36626592, 0.41382396, 0.21991009]], dtype=float32))
before normalize:  Traced<ShapedArray(float32[3])>with<DynamicJaxprT

In [None]:
def fixed_lag_smoother(d, alpha, obslik, obsvec, transmt, act):
    '''
    Description...

    Parameters
    ----------
    d         : int
        The desired window length (must be >= 2).
    alpha     : array
        length d window, excluding t0 (columns indexed 1,..., d)
    obslilk   : array
        length d window
    obsvec    : array
        likelihood vector for current observation
    transfmat : array
        transition matrix
    act       : 
        ()

    '''
    return 0