In [0]:
import numpy as np
import copy

# Demonstration of a hidden Markov Model
by Jiaxin Shi, ishijiaxin@126.com (from https://github.com/thjashin/hmm)

modified by Lior Pachter

In [61]:
!date

Tue Jan 14 17:09:49 UTC 2020


# 0. Implementation

This notebook implements the following hidden Markov model (HMM) algorithms:
- Forward-Backward (sum-product)
- Viterbi (max-product)
- Baum-Welch (expectation-maximization algorithm)

In [0]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import copy


class HMM:
    def __init__(self, states, transition, emission, init):
        self.state_names = copy.copy(states)
        self.n_states = len(states)
        self.A = transition.copy()
        self.B = emission.copy()
        self.n_emissions = self.B.shape[1]
        self.init = init

    def generate(self, length):
        state = self.init
        states = []
        ret = []
        for i in xrange(1, length + 1):
            state = np.random.choice(range(self.n_states), p=self.A[state])
            states.append(state)
            ret.append(
                np.random.choice(range(self.n_emissions), p=self.B[state]))
        print 'Generating by states:', ''.join(self.state_names[i] for i in states)
        ret = ''.join([str(i) for i in ret])
        return ret

    def _forward(self, seq_arr):
        T = len(seq_arr)
        alpha = np.zeros((T + 1, self.n_states))
        alpha[0, self.init] = 1
        log_px = 0.
        for t in xrange(1, T + 1):
            alpha[t] = self.B[:, seq_arr[t - 1]] * \
                       np.dot(alpha[t - 1], self.A)
            pt = alpha[t].sum()
            alpha[t] /= pt
            log_px += np.log(pt)
        return alpha, log_px

    def _backward(self, seq_arr):
        T = len(seq_arr)
        beta = np.zeros((T + 1, self.n_states))
        beta[T, :] = 1
        log_px = 0.
        for t in xrange(T, 0, -1):
            beta[t - 1] = np.dot(self.A, beta[t] * self.B[:, seq_arr[t - 1]])
            pt = beta[t - 1].sum()
            beta[t - 1] /= pt
            log_px += np.log(pt)
        log_px += np.log(beta[0, self.init])
        return beta, log_px

    def viterbi(self, seq):
        # := max-product
        seq_arr = np.array([int(i) for i in seq])
        T = len(seq_arr)
        T1 = np.zeros((self.n_states, T + 1))
        T1[self.init, 0] = 1
        T2 = np.zeros((self.n_states, T + 1), dtype='int')
        states = np.zeros(T + 1, dtype='int')
        for t in xrange(1, T + 1):
            for j in xrange(self.n_states):
                T1[j, t] = np.max(T1[:, t - 1] * self.A[:, j])
                T1[j, t] *= self.B[j, seq_arr[t - 1]]
                T2[j, t] = np.argmax(T1[:, t - 1] * self.A[:, j])
        states[T] = np.argmax(T1[:, T])
        for t in xrange(T, 1, -1):
            states[t - 1] = T2[states[t], t - 1]
        return ''.join([self.state_names[s] for s in states[1:]])

    def baum_welch(self, seq):
        # := EM
        seq_arr = np.array([int(i) for i in seq])
        T = len(seq_arr)
        kesi = np.zeros((T + 1, self.n_states, self.n_states))
        log_px = None
        iter = 0
        while True:
            iter += 1
            alpha, alpha_log_px = self._forward(seq_arr)
            print "Iter %d" % iter, "log p(x): %s" % alpha_log_px
            if log_px and (np.abs(
                    log_px - alpha_log_px) < np.abs(1e-6 * log_px)):
                print "Converged."
                break
            beta, beta_log_px = self._backward(seq_arr)
            try:
                assert np.abs(
                    alpha_log_px - beta_log_px) < np.abs(1e-6 * alpha_log_px)
            except AssertionError as e:
                print "alpha_log_px:", alpha_log_px
                print "beta_log_px:", beta_log_px
                raise e
            log_px = alpha_log_px
            gamma = alpha * beta
            gamma /= np.sum(gamma, axis=1, keepdims=True)
            for t in xrange(1, T):
                kesi[t] = np.outer(
                    alpha[t],
                    beta[t + 1] * self.B[:, seq_arr[t + 1 - 1]]) * self.A
            kesi[1:T] = kesi[1:T] / kesi[1:T].sum(axis=(1, 2), keepdims=True)
            self.A = kesi[1:T].sum(axis=0) / \
                     gamma[1:T].sum(axis=0)[:, np.newaxis]
            assert np.all(np.abs(1. - self.A.sum(axis=1)) < 1e-6)
            obs = np.zeros((T + 1, self.n_emissions))
            obs[range(1, T + 1), seq_arr] = 1
            self.B = np.dot(gamma[1:].T, obs[1:]) / \
                     gamma[1:].sum(axis=0)[:, np.newaxis]
        print "Estimate A:"
        print np.array_str(self.A, precision=3)
        print "Estimate B:"
        print np.array_str(self.B, precision=3)
        return log_px, self.A, self.B

    def gibbs(self, seq, steps=1, burn_in=0, max_iters=None):
        seq_arr = np.array([int(i) for i in seq])
        T = len(seq_arr)
        states = np.zeros(T + 1, dtype='int')
        iter = 0
        log_px = None
        states[0] = self.init
        while True:
            iter += 1
            alpha, alpha_log_px = self._forward(seq_arr)
            print "Iter %d" % iter, "log p(x): %s" % alpha_log_px
            if log_px and (np.abs(
                    log_px - alpha_log_px) < np.abs(1e-6 * log_px)):
                print "Converged."
                break
            log_px = alpha_log_px
            if max_iters and (iter >= max_iters):
                break
            A = np.zeros_like(self.A)
            B = np.zeros_like(self.B)
            for t in xrange(1, T + 1):
                states[t] = np.random.choice(range(3))
            for step in xrange(steps):
                for t in xrange(1, T + 1):
                    p_state_t = self.B[:, seq_arr[t - 1]] * \
                                self.A[states[t - 1]]
                    if t < T:
                        p_state_t *= self.A[:, states[t + 1]]
                    p_state_t /= p_state_t.sum()
                    states[t] = np.random.choice(range(3), p=p_state_t)
                if step >= burn_in:
                    for t in xrange(1, T + 1):
                        if t < T:
                            A[states[t], states[t + 1]] += 1
                        B[states[t], seq_arr[t - 1]] += 1
            A = np.maximum(1., A)
            B = np.maximum(1., B)
            self.A = A / A.sum(axis=1, keepdims=True)
            self.B = B / B.sum(axis=1, keepdims=True)
        print "Estimate A:"
        print np.array_str(self.A, precision=3)
        print "Estimate B:"
        print np.array_str(self.B, precision=3)
        return log_px

## 1.1 Hidden state inference 
This code generates sequences of lengths 100, 1000 and 10000 (see comment in code) on an alphabet of size 3 from 3 hidden states, and then the Viterbi algorithm is used to infer the most likely hidden states that generated the sequence(s).

In [63]:
np.random.seed(1236)
states = ['A', 'B', 'C']

print "1.1 Generation\n"
transition = np.array([
    [0.8, 0.2, 0.0],
    [0.1, 0.7, 0.2],
    [0.1, 0.0, 0.9]
])
emission = np.array([
    [0.9, 0.1],
    [0.5, 0.5],
    [0.1, 0.9]
])
init = 0
hmm = HMM(states, transition, emission, init)
seqs = []
for seq_len in [100, 1000, 10000]:
    seq = hmm.generate(seq_len)
    seqs.append(seq)
    print "Inferred optimal state series:"
    print hmm.viterbi(seq)
    # NOTE: To run chains with various length, REMOVE this break
    break


1.1 Generation

Generating by states: AAAAAABBBBCCCCCCCCCCCCCCAAAAAAAABBBBBBBCCAABBBCCCCCCCCCCAABBBBBCCCCCCCCCCCCCCCCABBBABBAAAAAAAAAAAAAA
Inferred optimal state series:
AAAAAABBBBBCCCCCCCCCCCCCCCAAAAAAAAAABBBBBBBBBBBCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCAAAAAAAAAAAAAAAAAA


## 1.2 Baum Welch


In [64]:
print "\n1.2 Baum Welch"
for seq in seqs:
    print "\nSequence length:", len(seq)
    As = []
    Bs = []
    for run in xrange(10):
        print "Run", run
        transition2 = np.random.random((3, 3))
        transition2 /= transition2.sum(axis=1, keepdims=True)
        emission2 = np.random.random((3, 2))
        emission2 /= emission2.sum(axis=1, keepdims=True)
        print "Init transition:"
        print transition2
        print "Init emission:"
        print emission2
        hmm2 = HMM(states, transition2, emission2, init)
        log_px, A, B = hmm2.baum_welch(seq)
        As.append(A)
        Bs.append(B)
        print "Final log p(x):", log_px
        break
   


1.2 Baum Welch

Sequence length: 100
Run 0
Init transition:
[[0.41141391 0.43933319 0.1492529 ]
 [0.64224114 0.283969   0.07378985]
 [0.21495536 0.31890642 0.46613823]]
Init emission:
[[0.63570068 0.36429932]
 [0.09348741 0.90651259]
 [0.32872192 0.67127808]]
Iter 1 log p(x): -72.19318633406532
Iter 2 log p(x): -69.69171314424933
Iter 3 log p(x): -68.77981999052288
Iter 4 log p(x): -67.66782656008486
Iter 5 log p(x): -65.82203787983642
Iter 6 log p(x): -62.79602914246915
Iter 7 log p(x): -59.44891713760399
Iter 8 log p(x): -57.56877412931889
Iter 9 log p(x): -56.90826500507927
Iter 10 log p(x): -56.64218616802919
Iter 11 log p(x): -56.50497639816005
Iter 12 log p(x): -56.4228378284552
Iter 13 log p(x): -56.36503427752464
Iter 14 log p(x): -56.3170377131918
Iter 15 log p(x): -56.27174756008821
Iter 16 log p(x): -56.225296064295726
Iter 17 log p(x): -56.17512103660135
Iter 18 log p(x): -56.119191411817994
Iter 19 log p(x): -56.0558369205126
Iter 20 log p(x): -55.9839597360092
Iter 21 lo

In [65]:
!date

Tue Jan 14 17:09:51 UTC 2020
