# Implementing discrete HMMs in NumPy

We start with a simple NumPy implementation.

In [1]:
# Install necessary libraries

try:
    import jax
except:
    # For cuda version, see https://github.com/google/jax#installation
    %pip install --upgrade "jax[cpu]" 
    import jax

try:
    import jsl
except:
    %pip install git+https://github.com/probml/jsl
    import jsl


In [2]:
import abc
from dataclasses import dataclass
import functools
import itertools

from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import inspect
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

def print_source(fname):
    """Print a header naming ``fname``, then its source code line by line.

    ``fname`` is any object accepted by ``inspect.getsourcelines``
    (despite the parameter name it is the object itself, not a string).
    """
    print('source code of ', fname)
    # getsourcelines also returns the starting line number, which is not needed here.
    source_lines = inspect.getsourcelines(fname)[0]
    for src_line in source_lines:
        # Each source line carries its own newline; drop it so print() does not double it.
        print(src_line.strip('\n'))

In [3]:
import jsl
import jsl.hmm.hmm_numpy_lib as hmm_lib_np
#import jsl.hmm.hmm_lib as hmm_lib_jax

Here are some handy utility functions we have already defined.

In [4]:
# Short alias for the library's normalizer; it is reused below to
# normalize the rows of the model matrices.
normalize = hmm_lib_np.normalize_numpy
print_source(normalize)

source code of  <function normalize_numpy at 0x7fbf2893f0d0>
def normalize_numpy(u, axis=0, eps=1e-15):
    '''
    Normalizes the values along the given axis so that they sum up to 1.

    Parameters
    ----------
    u : array
        Values to normalize
    axis : int
        Axis along which the values are summed
    eps : float
        Floor for tiny entries: every nonzero entry smaller than eps is
        replaced by eps before normalizing; exact zeros stay zero

    Returns
    -------
    * array
        Normalized version of the given array (same shape as u)

    * array :
        The values of the normalizer, i.e. the sums along axis (that axis
        is removed); a sum equal to zero is reported as 1
    '''
    u = np.where(u == 0, 0, np.where(u < eps, eps, u))
    c = u.sum(axis=axis)
    # Guard against division by zero for all-zero slices.
    c = np.where(c == 0, 1, c)
    # Re-insert the reduced axis so the division broadcasts along `axis`
    # instead of silently lining up with the last dimension.
    denom = c if axis is None else np.expand_dims(c, axis)
    return u / denom, c


We first create the "Occasionally dishonest casino" model from {cite}`Durbin98`.

```{figure} /figures/casino.png
:scale: 50%
:name: casino

Illustration of the casino HMM.
```



In [5]:

# State transition matrix: one row per current state.
A = np.array([[0.95, 0.05],
              [0.10, 0.90]])

# Observation matrix: one row per state, one column per die face.
B = np.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6],       # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]  # loaded die
])

# Initial state distribution.
pi = np.array([0.5, 0.5])

nstates, nobs = B.shape
# Normalize one row at a time, so every call works on a 1-D vector
# with the normalizer's default axis.
for row in range(nstates):
    A[row] = normalize(A[row])[0]
    B[row] = normalize(B[row])[0]

Let's bundle the parameters into a single structure.

In [6]:

class HMMNumpy(NamedTuple):
    """Parameters of an HMM with discrete observations, held as NumPy arrays."""
    # Fields are annotated with np.ndarray (the array type); np.array is
    # only the factory function and is not usable as a type.
    trans_mat: np.ndarray  # A : (n_states, n_states)
    obs_mat: np.ndarray  # B : (n_states, n_obs)
    init_dist: np.ndarray  # pi : (n_states)
        

# Bundle the casino parameters (positional order: trans_mat, obs_mat, init_dist).
params_numpy = HMMNumpy(A, B, pi)
print(params_numpy)

HMMNumpy(trans_mat=array([[0.95, 0.05],
       [0.1 , 0.9 ]]), obs_mat=array([[0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,
        0.16666667],
       [0.1       , 0.1       , 0.1       , 0.1       , 0.1       ,
        0.5       ]]), init_dist=array([0.5, 0.5]))


The following function samples a single sequence of hidden states and discrete observations.

In [7]:
# Short alias for the library's NumPy sampler, followed by a printout of its source.
hmm_sample = hmm_lib_np.hmm_sample_numpy
print_source(hmm_sample)

source code of  <function hmm_sample_numpy at 0x7fbf2a1ceee0>
def hmm_sample_numpy(params, seq_len, random_state=0):
    '''
    Samples an observation of given length according to the defined
    hidden markov model and gives the sequence of the hidden states
    as well as the observation.

    Parameters
    ----------
    params : HMMNumpy
        Hidden Markov Model

    seq_len: array(seq_len)
        The length of the observation sequence

    random_state : int
        Seed value

    Returns
    -------
    * array(seq_len,)
        Hidden state sequence

    * array(seq_len,) :
        Observation sequence
    '''

    def sample_one_step_(hist, a, p):
        x_t = np.random.choice(a=a, p=p)
        return np.append(hist, [x_t]), x_t

    seed(random_state)

    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist
    n_states, n_obs = obs_mat.shape

    state_seq = np.array([], dtype=int)
    obs_seq = np.array([], dtype=int)

    latent_states

Let's sample from this model.

In [8]:
# Draw one sequence of 20 steps; random_state is the seed value passed to the sampler.
seq_len = 20
state_seq, obs_seq = hmm_sample(params_numpy, seq_len, random_state=0)
# First line: hidden state sequence; second line: observation sequence.
print(state_seq)
print(obs_seq)

[1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 0 0 0]
[5 5 5 5 3 5 5 0 4 5 5 5 5 5 4 5 5 3 3 4]
