In [12]:
import numpy as np
import copy

# A Toy HMM
Jiaxin Shi, ishijiaxin@126.com

# 0. Implementation

This implementation has following algorithms for HMM.
- Forward/Backward (sum-product)
- Viterbi (max-product)
- Baum-Welch (EM with exact E)
- Gibbs (EM with approx E)

In [13]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import copy


class HMM:
    def __init__(self, states, transition, emission, init):
        self.state_names = copy.copy(states)
        self.n_states = len(states)
        self.A = transition.copy()
        self.B = emission.copy()
        self.n_emissions = self.B.shape[1]
        self.init = init

    def generate(self, length):
        state = self.init
        states = []
        ret = []
        for i in xrange(1, length + 1):
            state = np.random.choice(range(self.n_states), p=self.A[state])
            states.append(state)
            ret.append(
                np.random.choice(range(self.n_emissions), p=self.B[state]))
        print 'Generating by states:', ''.join(self.state_names[i] for i in states)
        ret = ''.join([str(i) for i in ret])
        return ret

    def _forward(self, seq_arr):
        T = len(seq_arr)
        alpha = np.zeros((T + 1, self.n_states))
        alpha[0, self.init] = 1
        log_px = 0.
        for t in xrange(1, T + 1):
            alpha[t] = self.B[:, seq_arr[t - 1]] * \
                       np.dot(alpha[t - 1], self.A)
            pt = alpha[t].sum()
            alpha[t] /= pt
            log_px += np.log(pt)
        return alpha, log_px

    def _backward(self, seq_arr):
        T = len(seq_arr)
        beta = np.zeros((T + 1, self.n_states))
        beta[T, :] = 1
        log_px = 0.
        for t in xrange(T, 0, -1):
            beta[t - 1] = np.dot(self.A, beta[t] * self.B[:, seq_arr[t - 1]])
            pt = beta[t - 1].sum()
            beta[t - 1] /= pt
            log_px += np.log(pt)
        log_px += np.log(beta[0, self.init])
        return beta, log_px

    def viterbi(self, seq):
        # := max-product
        seq_arr = np.array([int(i) for i in seq])
        T = len(seq_arr)
        T1 = np.zeros((self.n_states, T + 1))
        T1[self.init, 0] = 1
        T2 = np.zeros((self.n_states, T + 1), dtype='int')
        states = np.zeros(T + 1, dtype='int')
        for t in xrange(1, T + 1):
            for j in xrange(self.n_states):
                T1[j, t] = np.max(T1[:, t - 1] * self.A[:, j])
                T1[j, t] *= self.B[j, seq_arr[t - 1]]
                T2[j, t] = np.argmax(T1[:, t - 1] * self.A[:, j])
        states[T] = np.argmax(T1[:, T])
        for t in xrange(T, 1, -1):
            states[t - 1] = T2[states[t], t - 1]
        return ''.join([self.state_names[s] for s in states[1:]])

    def baum_welch(self, seq):
        # := EM
        seq_arr = np.array([int(i) for i in seq])
        T = len(seq_arr)
        kesi = np.zeros((T + 1, self.n_states, self.n_states))
        log_px = None
        iter = 0
        while True:
            iter += 1
            alpha, alpha_log_px = self._forward(seq_arr)
            print "Iter %d" % iter, "log p(x): %s" % alpha_log_px
            if log_px and (np.abs(
                    log_px - alpha_log_px) < np.abs(1e-6 * log_px)):
                print "Converged."
                break
            beta, beta_log_px = self._backward(seq_arr)
            try:
                assert np.abs(
                    alpha_log_px - beta_log_px) < np.abs(1e-6 * alpha_log_px)
            except AssertionError as e:
                print "alpha_log_px:", alpha_log_px
                print "beta_log_px:", beta_log_px
                raise e
            log_px = alpha_log_px
            gamma = alpha * beta
            gamma /= np.sum(gamma, axis=1, keepdims=True)
            for t in xrange(1, T):
                kesi[t] = np.outer(
                    alpha[t],
                    beta[t + 1] * self.B[:, seq_arr[t + 1 - 1]]) * self.A
            kesi[1:T] = kesi[1:T] / kesi[1:T].sum(axis=(1, 2), keepdims=True)
            self.A = kesi[1:T].sum(axis=0) / \
                     gamma[1:T].sum(axis=0)[:, np.newaxis]
            assert np.all(np.abs(1. - self.A.sum(axis=1)) < 1e-6)
            obs = np.zeros((T + 1, self.n_emissions))
            obs[range(1, T + 1), seq_arr] = 1
            self.B = np.dot(gamma[1:].T, obs[1:]) / \
                     gamma[1:].sum(axis=0)[:, np.newaxis]
        print "Estimate A:"
        print np.array_str(self.A, precision=3)
        print "Estimate B:"
        print np.array_str(self.B, precision=3)
        return log_px, self.A, self.B

    def gibbs(self, seq, steps=1, burn_in=0, max_iters=None):
        seq_arr = np.array([int(i) for i in seq])
        T = len(seq_arr)
        states = np.zeros(T + 1, dtype='int')
        iter = 0
        log_px = None
        states[0] = self.init
        while True:
            iter += 1
            alpha, alpha_log_px = self._forward(seq_arr)
            print "Iter %d" % iter, "log p(x): %s" % alpha_log_px
            if log_px and (np.abs(
                    log_px - alpha_log_px) < np.abs(1e-6 * log_px)):
                print "Converged."
                break
            log_px = alpha_log_px
            if max_iters and (iter >= max_iters):
                break
            A = np.zeros_like(self.A)
            B = np.zeros_like(self.B)
            for t in xrange(1, T + 1):
                states[t] = np.random.choice(range(3))
            for step in xrange(steps):
                for t in xrange(1, T + 1):
                    p_state_t = self.B[:, seq_arr[t - 1]] * \
                                self.A[states[t - 1]]
                    if t < T:
                        p_state_t *= self.A[:, states[t + 1]]
                    p_state_t /= p_state_t.sum()
                    states[t] = np.random.choice(range(3), p=p_state_t)
                if step >= burn_in:
                    for t in xrange(1, T + 1):
                        if t < T:
                            A[states[t], states[t + 1]] += 1
                        B[states[t], seq_arr[t - 1]] += 1
            A = np.maximum(1., A)
            B = np.maximum(1., B)
            self.A = A / A.sum(axis=1, keepdims=True)
            self.B = B / B.sum(axis=1, keepdims=True)
        print "Estimate A:"
        print np.array_str(self.A, precision=3)
        print "Estimate B:"
        print np.array_str(self.B, precision=3)
        return log_px

## 1.1 Generation
We comment out long sequence (1000/10000) code here (visual burden). To run long sequence, following the comments in the code.

In [14]:
np.random.seed(1236)
states = ['A', 'B', 'C']

print "1.1 Generation\n"
transition = np.array([
    [0.8, 0.2, 0.0],
    [0.1, 0.7, 0.2],
    [0.1, 0.0, 0.9]
])
emission = np.array([
    [0.9, 0.1],
    [0.5, 0.5],
    [0.1, 0.9]
])
init = 0
hmm = HMM(states, transition, emission, init)
seqs = []
for seq_len in [100, 1000, 10000]:
    seq = hmm.generate(seq_len)
    seqs.append(seq)
    print "Inferred optimal state series:"
    print hmm.viterbi(seq)
    # NOTE: To run chains with various length, REMOVE this break
    break

1.1 Generation

Generating by states: AAAAAABBBBCCCCCCCCCCCCCCAAAAAAAABBBBBBBCCAABBBCCCCCCCCCCAABBBBBCCCCCCCCCCCCCCCCABBBABBAAAAAAAAAAAAAA
Inferred optimal state series:
AAAAAABBBBBCCCCCCCCCCCCCCCAAAAAAAAAABBBBBBBBBBBCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCAAAAAAAAAAAAAAAAAA


## 1.2 / 1.3 Baum Welch
We comment out variance computation code here (time consuming). To run variance computation, following the comments in the code.
So we put the variance computed here:

#### Variance for chain with length 100
```
Var(A):
[[ 0.09052057  0.06006296  0.04894882]
 [ 0.06632638  0.1313508   0.05201452]
 [ 0.06691535  0.0444691   0.08084979]]
Var(B):
[[ 0.06158714  0.06158714]
 [ 0.14604622  0.14604622]
 [ 0.0887852   0.0887852 ]]
```

#### Variance for chain with length 1000
```
Var(A):
[[ 0.08107341  0.0389215   0.02910226]
 [ 0.07240883  0.10532198  0.07765103]
 [ 0.04485569  0.07224511  0.08158469]]
Var(B):
[[ 0.08268158  0.08268158]
 [ 0.1226752   0.1226752 ]
 [ 0.10824732  0.10824732]]
```

#### Variance for chain with length 10000
```
Var(A):
[[ 0.02264952  0.01980305  0.02054278]
 [ 0.02993416  0.06203205  0.02176658]
 [ 0.03476353  0.03295623  0.07207851]]
Var(B):
[[ 0.13777976  0.13777976]
 [ 0.10384683  0.10384683]
 [ 0.12860774  0.12860774]]
```

The conclusion is that the longer chain, the less variance in estimation of transition matrix (A). No similar results for observation matrix (B).

In [15]:
print "\n1.2/1.3 Baum Welch"
for seq in seqs:
    print "\nSequence length:", len(seq)
    As = []
    Bs = []
    for run in xrange(10):
        print "Run", run
        transition2 = np.random.random((3, 3))
        transition2 /= transition2.sum(axis=1, keepdims=True)
        emission2 = np.random.random((3, 2))
        emission2 /= emission2.sum(axis=1, keepdims=True)
        print "Init transition:"
        print transition2
        print "Init emission:"
        print emission2
        hmm2 = HMM(states, transition2, emission2, init)
        log_px, A, B = hmm2.baum_welch(seq)
        As.append(A)
        Bs.append(B)
        print "Final log p(x):", log_px
        print "Optimal state series:", hmm2.viterbi(seq)
        # NOTE: To calculate variance, REMOVE this break
        break
    print "Variance of estimated:"
    print "Var(A):"
    print np.var(As, axis=0)
    print "Var(B):"
    print np.var(Bs, axis=0)


1.2/1.3 Baum Welch

Sequence length: 100
Run 0
Init transition:
[[ 0.41141391  0.43933319  0.1492529 ]
 [ 0.64224114  0.283969    0.07378985]
 [ 0.21495536  0.31890642  0.46613823]]
Init emission:
[[ 0.63570068  0.36429932]
 [ 0.09348741  0.90651259]
 [ 0.32872192  0.67127808]]
Iter 1 log p(x): -72.1931863341
Iter 2 log p(x): -69.6917131442
Iter 3 log p(x): -68.7798199905
Iter 4 log p(x): -67.6678265601
Iter 5 log p(x): -65.8220378798
Iter 6 log p(x): -62.7960291425
Iter 7 log p(x): -59.4489171376
Iter 8 log p(x): -57.5687741293
Iter 9 log p(x): -56.9082650051
Iter 10 log p(x): -56.642186168
Iter 11 log p(x): -56.5049763982
Iter 12 log p(x): -56.4228378285
Iter 13 log p(x): -56.3650342775
Iter 14 log p(x): -56.3170377132
Iter 15 log p(x): -56.2717475601
Iter 16 log p(x): -56.2252960643
Iter 17 log p(x): -56.1751210366
Iter 18 log p(x): -56.1191914118
Iter 19 log p(x): -56.0558369205
Iter 20 log p(x): -55.983959736
Iter 21 log p(x): -55.9035071474
Iter 22 log p(x): -55.8159897143
Iter 

### Compare with true parameters/states
Find estimated paramters/states in above results (only one run presented here).
The estimated parameters in most run is very similar as the true parameter. But some run also got stuck in local optimal.

## 2.1 Gibbs

In [16]:
print "\n2.1 Gibbs"
transition3 = np.random.random((3, 3))
transition3 /= transition3.sum(axis=1, keepdims=True)
emission3 = np.random.random((3, 2))
emission3 /= emission3.sum(axis=1, keepdims=True)
print "Init transition:"
print transition3
print "Init emission:"
print emission3
print "Sequence length:", len(seqs[0])
hmm3 = HMM(states, transition3, emission3, init)
log_px = hmm3.gibbs(seqs[0], steps=200, burn_in=100, max_iters=100)
print "Final log p(x):", log_px
print "Optimal state series:", hmm3.viterbi(seqs[0])


2.1 Gibbs
Init transition:
[[ 0.34110724  0.39255971  0.26633305]
 [ 0.24355523  0.37788669  0.37855808]
 [ 0.32565081  0.21240733  0.46194186]]
Init emission:
[[ 0.507358    0.492642  ]
 [ 0.37109127  0.62890873]
 [ 0.506693    0.493307  ]]
Sequence length: 100
Iter 1 log p(x): -68.9635658108
Iter 2 log p(x): -68.9547155068
Iter 3 log p(x): -68.9419872001
Iter 4 log p(x): -68.9193644246
Iter 5 log p(x): -68.8234956815
Iter 6 log p(x): -68.7058509947
Iter 7 log p(x): -68.5225882733
Iter 8 log p(x): -68.2710770794
Iter 9 log p(x): -67.955095736
Iter 10 log p(x): -67.147018193
Iter 11 log p(x): -65.453949634
Iter 12 log p(x): -62.5769123357
Iter 13 log p(x): -59.1442762918
Iter 14 log p(x): -57.2899042789
Iter 15 log p(x): -56.7861824909
Iter 16 log p(x): -56.7136273509
Iter 17 log p(x): -56.530892125
Iter 18 log p(x): -56.4306142307
Iter 19 log p(x): -56.3663842661
Iter 20 log p(x): -56.3058989836
Iter 21 log p(x): -56.2953490606
Iter 22 log p(x): -56.2905485485
Iter 23 log p(x): -56.3

### Compare with true parameters/states
Find estimated paramters/states in above results (only one run presented here).
The estimated parameters in most run is very similar as the true parameter (get similar likelihood with Baum Welch). But some run also got stuck in local optimal.

## 2.2 Compare results (Gibbs vs. Baum-Welch)

- Baum-Welch is an EM algorithm with exact inference for latent states. It is more likely to get stuck in local optimal. And due to exact inference, there is no chance for Baum-Welch to jump out of the local optimal.
- EM with Gibbs inference for latent states is also an EM, but with E-step inference by gibbs sampling, which is approximating posterior by samples drawn. 
    - With short chain: This method has stochastic nature and can jump out of some local optimals caused by EM and get better results.
    - With long chain: This method is approximating the posterior as well as exact inference by Baum Welch. So their results are similar.

## 2.3 MEMM: Gibbs pesudo code

```
E: Gibbs sampling the latent states
y[t] ~ p(y[t]|y[t-1], x[t])p(y[t+1]|y[t], x[t+1])
M: Re-estimate parameters
Restimate parameters of p(y[t]|y[t-1], x[t]) using sampled y[t] and y[t-1].
```