In [1]:
import numpy as np
from hmmlearn import hmm
from scipy.special import logsumexp

In [2]:
def logsumexp(x):
    c = x.max()
    return c + np.log(np.sum(np.exp(x - c)))

In [3]:
hmm_states = 2

model = hmm.CategoricalHMM(n_components=hmm_states)
model.emissionprob_ = np.array([[0.9,0.1],[0.1,0.9]])
model.transmat_ = np.array([[0.9,0.1],[0.1,0.9]])
model.startprob_ = np.array([0.4,0.6])
data,states = model.sample(n_samples = 1000,random_state=28)
data = np.squeeze(data)

In [4]:
n_states = 2
n_features = 2
n_obs = data.shape[0]
# emission_log = np.random.rand(n_obs,n_states)
# emission_log = emission_log/np.tile(np.expand_dims(np.sum(emission_log,axis=1),axis=1),(1,2))

# emission_log = np.random.rand(n_states,n_features)
# emission_log = emission_log/np.tile(np.expand_dims(np.sum(emission_log,axis=1),axis=1),(1,2))

emission_log = np.random.normal(0,1,(n_states,n_features)) # working in the log space
emission_log = emission_log - logsumexp(emission_log)
# emission_log = np.array([[0.96515917, 0.03484083],
#        [0.36020674, 0.63979326]])


# transition_log =  np.random.rand(n_states,n_states)
# transition_log = transition_log/np.tile(np.expand_dims(np.sum(transition_log,axis=1),axis=1),(1,2))

transition_log = np.random.normal(0,1,(n_states,n_states)) # working in the log space
transition_log = transition_log - logsumexp(transition_log)
# transition_log = np.array([[0.48577716, 0.51422284],
#        [0.41604097, 0.58395903]])


# scale_factors = np.zeros((n_obs))
forward_log = np.random.rand(n_obs,n_states)

backward_log = np.random.rand(n_obs,n_states)

init_prob_log = np.log(np.array([0.5,0.5]))


p_old = -10000
tol = 0.01
max_iter = 100

In [5]:
print(f'emission_log is {emission_log} and transition_log is {transition_log}')

emission_log is [[-2.78266652 -0.22593736]
 [-2.56593787 -2.75653353]] and transition_log is [[-0.47595801 -1.16080881]
 [-3.68496756 -3.20942782]]


In [6]:
for ite in range(max_iter):
    print(ite)
    forward_log = np.zeros((n_obs,n_states))
    backward_log = np.zeros((n_obs,n_states))

    forward_log[0,:] = init_prob_log + emission_log[:,data[0]]
    
    for t in range(n_obs-1):
        for i in range(n_states):
            temp = np.empty(n_states)
            for j in range(n_states):
                temp[j] = forward_log[t,j] + transition_log[j,i] 
            forward_log[t+1,i] = logsumexp(temp) + emission_log[i,data[t+1]]

        
        # print(f'temp is {temp} and the scale factor is {scale_factors[t+1]} and the forward_log is {forward_log[t+1]}')

    backward_log[-1,:] = 0
    for t in reversed(range(n_obs-1)):
        for i in range(n_states):
            temp = np.empty(n_states)
            for j in range(n_states):
                temp[j] = backward_log[t+1,j] + transition_log[i,j] + emission_log[j,data[t+1]]
            backward_log[t,i] = logsumexp(temp)

    log_x = logsumexp(forward_log[-1,:])

    a_log = np.zeros((n_obs,n_states))
    b_log = np.zeros((n_obs,n_states,n_states))
    for i in range(n_obs):
        for j in range(n_states):
            a_log[i,j] = forward_log[i,j]+backward_log[i,j]-log_x

    for t in range(n_obs-1):
        for i in range(n_states):
            for j in range(n_states):
                b_log[t,i,j] = forward_log[t,i]+backward_log[t+1,j]+transition_log[i,j]+emission_log[j,data[t+1]]-log_x

    for i in range(n_states):
        for j in range(n_states):
            transition_log[i,j] = logsumexp(b_log[0:-1,i,j]) - logsumexp(b_log[0:-1,i,:])
            # print(np.sum(b[0:-1,i,j]),np.sum(b[0:-1,i,:]))

    for i in range(n_states):
        init_prob_log[i] = a_log[0,i] - logsumexp(a_log[0,:])

    for i in range(n_states):
        for j in range(n_states):
            emission_log[j,i] = logsumexp(a_log[np.argwhere(data==i),j]) - logsumexp(a_log[:,j])
            
    p = log_x
    print(f'p is:{p},transition_log is {np.exp(transition_log)},emission_log is {np.exp(emission_log)},init is {np.exp(init_prob_log)}')
    # print(f'transition_log prob is: {transition_log}')
    # print(f'emission_log prob is:{emission_log}')
    if p>p_old and p - p_old < tol:
        break
    p_old = p

0
p is:-1953.4983390753987,transition_log is [[0.98615081 0.01384919]
 [0.9408459  0.0591541 ]],emission_log is [[0.48844237 0.51155763]
 [0.94143106 0.05856894]],init is [0.9967829 0.0032171]
1
p is:-692.8307323902625,transition_log is [[0.98597112 0.01402888]
 [0.91803573 0.08196427]],emission_log is [[0.4881212  0.5118788 ]
 [0.94565153 0.05434847]],init is [9.99645695e-01 3.54304938e-04]
2
p is:-692.6693294511035,transition_log is [[0.98564748 0.01435252]
 [0.88531282 0.11468718]],emission_log is [[0.48760723 0.51239277]
 [0.95157419 0.04842581]],init is [9.99964701e-01 3.52988975e-05]
3
p is:-692.396129831374,transition_log is [[0.98507901 0.01492099]
 [0.83734314 0.16265686]],emission_log is [[0.48673521 0.51326479]
 [0.95943967 0.04056033]],init is [9.99996974e-01 3.02606807e-06]
4
p is:-691.8768185860074,transition_log is [[0.98405118 0.01594882]
 [0.76517049 0.23482951]],emission_log is [[0.48513093 0.51486907]
 [0.96923671 0.03076329]],init is [9.99999794e-01 2.05742300e-07]
