In [1]:
from numba import jit

In [249]:
class HMM:
    import numpy as np
    def forward(A,B,pi,sequence):
        '''
        Perform the forward step in Baum-Welch Algorithm.
        
        Parameters
        ----------
        A: np.ndarray
            Stochastic transition matrix.
        B: np.ndarray
            Emission matrix.
        pi: np.1darray
            Initial state distribution.
        sequence: array-like
            The observed sequence.
            Need to be converted to integer coded.    
        '''
        
        N=A.shape[0]
        M=B.shape[1]
        T=len(sequence)
        alpha=np.zeros([T,N])
        alpha[0]=pi*B[:,sequence[0]]
        t=1
        while True:
            if t==T:
                break
            alpha[t]=alpha[t-1]@A*B[:,sequence[t]]
            t+=1
        return(np.sum(alpha[T-1]),alpha)
    
    def backward(A,B,pi,sequence):
        '''
        Perform the backward step in Baum-Welch Algorithm.
        
        Parameters
        ----------
        A: np.ndarray
            Stochastic transition matrix.
        B: np.ndarray
            Emission matrix.
        pi: np.1darray
            Initial state distribution.
        sequence: array-like
            The observed sequence.
            Need to be converted to integer coded.    
        '''
        
        N=A.shape[0]
        M=B.shape[1]
        T=len(sequence)
        beta=np.zeros([T,N])
        beta[T-1]=1
        t=T-2
        while True:
            if t<0:
                break
            beta[t]=A@(B[:,sequence[t+1]]*beta[t+1])
            t-=1
        return(np.sum(pi*B[:,sequence[0]]*beta[0]),beta)
    
    def Viterbi(A,B,pi,sequence):
        '''
        Viterbi decoding of HMM.
        
        Parameters
        ----------
        A: np.ndarray
            Stochastic transition matrix.
        B: np.ndarray
            Emission matrix.
        pi: np.1darray
            Initial state distribution.
        sequence: array-like
            The observed sequence.
            Need to be converted to integer coded.    
        '''
        
        N=A.shape[0]
        M=B.shape[1]
        T=len(sequence)
        delta=np.zeros([T,N])
        psi=np.zeros([T,N])
        delta[0]=pi*B[:,sequence[0]]
        t=1
        while True:
            if t==T:
                break
            delta_A=delta[t-1,np.newaxis].T*A
            delta[t]=np.max(delta_A,axis=0)*B[:,sequence[t]]
            psi[t]=np.argmax(delta_A,axis=0)
            t+=1
        psi=psi.astype(int)
        q=np.zeros(T).astype(int)
        q[T-1]=np.argmax(delta[T-1])
        t=T-2
        while True:
            if t<0:
                break
            q[t]=psi[t+1,q[t+1]]
            t-=1
        return(q)
    
    def Baum_Welch(A,B,pi,sequence,max_iter,threshold=1e-15):
        '''
        Baum-Welch algorithm of HMM. 
        See https://en.wikipedia.org/wiki/Baum%E2%80%93Welch_algorithm.
        
        Parameters
        ----------
        A: np.ndarray
            Initial stochastic transition matrix.
        B: np.ndarray
            Emission matrix.
        pi: np.1darray
            Initial state distribution.
        sequence: array-like
            The observed sequence.
            Need to be converted to integer coded.    
        '''
        
        N=A.shape[0]
        M=B.shape[1]
        T=len(sequence)
        likelihood,alpha=HMM.forward(A,B,pi,sequence)
        for i in range(max_iter):
            beta=HMM.backward(A,B,pi,sequence)[1]
            #temporary variables
            gamma=alpha*beta/np.sum(alpha*beta,axis=1).reshape((T,1))
            #Non-vectorized version for xi
            #xi=np.zeros([N,N,T-1])
            #for t in range(T-1):
            #    xi[:,:,t]=alpha[t].reshape((N,1))*A*beta[t+1]*B[:,sequence[t+1]]
            #    xi[:,:,t]=xi[:,:,t]/np.sum(xi[:,:,t])
            xi=alpha.T[:,np.newaxis,:-1]*A[:,:,np.newaxis]*(beta*B[:,sequence].T).T[np.newaxis,:,1:]
            xi=xi/np.sum(xi,axis=(0,1))
            pi=gamma[0]
            A=np.sum(xi,axis=2)/np.sum(gamma[:-1],axis=0).reshape([N,1])
            B=np.zeros([N,M])
            for t in range(T):
                B[:,sequence[t]]+=gamma[t]
            B=B/np.sum(gamma,axis=0).reshape([N,1])
            likelihood_new,alpha=HMM.forward(A,B,pi,sequence)
            if abs(likelihood-likelihood_new)<threshold:
                break
            likelihood=likelihood_new
        return(A,B,pi)
    
    def Baum_Welch_linear_memory(A,B,pi,sequence,max_iter,threshold=1e-15):
        '''
        Baum-Welch algorithm in linear memory.
        Implemented according to Churbanov, A., & Winters-Hilt, S. (2008).
        
        Parameters
        ----------
        A: np.ndarray
            Initial stochastic transition matrix.
        B: np.ndarray
            Emission matrix.
        pi: np.1darray
            Initial state distribution.
        sequence: array-like
            The observed sequence.
            Need to be converted to integer coded.  
        '''
        
        N=A.shape[0]
        M=B.shape[1]
        T=len(sequence)
        ###########################
        for z in range(max_iter):
            ##Beta_t+1
            beta_tilt_old=np.zeros(N)
            ##Beta_t
            beta_tilt_new=np.zeros(N)
            #T_t+1
            T_tilt_old=np.zeros([N,N,N])
            #T_t
            T_tilt_new=np.zeros([N,N,N])
            #E_t+1
            E_tilt_old=np.zeros([N,M,N])
            #E_t
            E_tilt_new=np.zeros([N,M,N])
            beta_tilt_old+=1
            d=1/np.sum(beta_tilt_old)
            beta_tilt_old=d*beta_tilt_old
            for m in range(N):
                for i in range(N):
                    for gamma in range(M):
                        E_tilt_old[i,gamma,m]=beta_tilt_old[i]*int(sequence[T-1]==gamma)
            for t in range(T-2,-1,-1):
                beta_tilt_new=A@(B[:,sequence[t+1]]*beta_tilt_old)
                dt=1/np.sum(beta_tilt_new)
                for m in range(N):
                    for i in range(N):
                        for j in range(N):
                            partial=0
                            for n in range(N):
                                partial+=A[m,n]*T_tilt_old[i,j,n]*B[n,sequence[t+1]]
                            T_tilt_new[i,j,m]=(beta_tilt_old[j]*A[m,j]*B[j,sequence[t+1]]*int(i==m)
                                               +partial)*dt
                        for gamma in range(M):
                            partial=0
                            for n in range(N):
                                partial+=B[n,sequence[t+1]]*A[m,n]*E_tilt_old[i,gamma,n]
                            E_tilt_new[i,gamma,m]=(partial+
                                        beta_tilt_new[m]*int(sequence[t]==gamma)*int(m==i))*dt
                beta_tilt_new=dt*beta_tilt_new
                beta_tilt_old=beta_tilt_new
                T_tilt_old=T_tilt_new
                E_tilt_old=E_tilt_new
            E_end=np.zeros([N,M])
            T_end=np.zeros([N,N])
            for m in range(N):
                E_end+=E_tilt_old[:,:,m]*pi[m]*B[m,sequence[0]]
                T_end+=T_tilt_old[:,:,m]*pi[m]*B[m,sequence[0]]
            alpha=pi*B[:,sequence[0]]
            pi=alpha*beta_tilt_old
            pi=pi/np.sum(pi)
            B=E_end/np.sum(E_end,axis=1).reshape((N,1))
            A=T_end/np.sum(T_end,axis=1).reshape((N,1))
        return(A,B,pi)

In [253]:
M=3
N=3
pi=np.array([.3,.3,.4])
A=np.array([[.2,.3,.5],[.1,.5,.4],[.6,.1,.3]])
B=np.array([[0.1,0.5,0.4],[0.2,0.4,0.4],[0.3,0.6,0.1]])
sequence=[0,1,2,1,0]
print(HMM.forward(A,B,pi,sequence)[0])
print(HMM.forward(A,B,pi,sequence)[0])
print(HMM.Viterbi(A,B,pi,sequence))
print(HMM.Baum_Welch(A,B,pi,sequence,1))
print(HMM.Baum_Welch_linear_memory(A,B,pi,sequence,1))

0.0030591384
0.0030591384
[2 2 0 2 2]
(array([[ 0.21727437,  0.31716984,  0.46555579],
       [ 0.09400679,  0.47723982,  0.42875339],
       [ 0.5605604 ,  0.10263823,  0.33680136]]), array([[ 0.23722818,  0.4164501 ,  0.34632173],
       [ 0.38668796,  0.36020626,  0.25310578],
       [ 0.51583503,  0.41448613,  0.06967885]]), array([ 0.15082207,  0.28366026,  0.56551766]))
(array([[ 0.23486928,  0.32791706,  0.43721367],
       [ 0.09967307,  0.49309234,  0.40723459],
       [ 0.58368956,  0.10491812,  0.31139232]]), array([[ 0.56835923,  0.27576884,  0.15587193],
       [ 0.5083402 ,  0.27619717,  0.21546263],
       [ 0.50720294,  0.4059924 ,  0.08680466]]), array([ 0.15082207,  0.28366026,  0.56551766]))
