-
Notifications
You must be signed in to change notification settings - Fork 0
/
baum_welch.py
148 lines (106 loc) · 4.51 KB
/
baum_welch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import numpy as np
from util import memoized, hashdict
from fb import forward_prob_table, backward_prob_table
from lmatrix import LMatrix
def convergent(old_score, new_score, places = 4):
    """Whether two successive scores agree to `places` decimal places of relative change.

    old_score: score from the previous iteration (e.g. a total log-likelihood)
    new_score: score from the current iteration
    places: tolerance exponent; the scores are considered convergent when
        |old - new| / |old| < 10**-places

    Identical scores (including both being 0 or both -inf) are always
    convergent. When old_score is 0 the relative change is undefined, so the
    absolute difference is compared against the same tolerance instead of
    dividing by zero (which previously raised or produced NaN and never
    converged).
    """
    tolerance = 10 ** (-places)
    if old_score == new_score:
        return True
    if old_score == 0:
        # relative change undefined: fall back to an absolute comparison
        return np.abs(new_score - old_score) < tolerance
    return np.abs((old_score - new_score) / old_score) < tolerance
@memoized
def gamma(obs, j, from_state, to_state, A, B, pi):
    """
    (observation: sequence of str, time j: int, from state: str, to state: str, transition matrix: LMatrix, emission matrix: LMatrix, pi: hashdict init state vector) => float

    For observation `obs` under parameters (A, B, pi), the probability that the
    hidden state is `from_state` at time `j` and `to_state` at time `j`+1.
    Computed as alpha_j(from) * A[from, to] * B[to, obs[j+1]] * beta_{j+1}(to) / P(obs)
    (this quantity is often written xi in the literature).

    `j` must satisfy j + 1 < len(obs), because obs[j+1] is read.
    NOTE(review): @memoized presumably needs `obs` to be hashable (e.g. a tuple) — confirm in util.
    """
    alpha, forward_total = forward_prob_table(obs, A, B, pi)
    beta, backward_total = backward_prob_table(obs, A, B, pi)
    # Sanity check: the forward and backward passes must agree on P(obs).
    # NOTE(review): an absolute tolerance of 1e-5 is loose when P(obs) is tiny
    # (long sequences); a relative comparison would be stricter.
    assert np.abs(forward_total - backward_total) < 1e-5
    numerator = (alpha[from_state, j]
                 * A[from_state, to_state]
                 * B[to_state, obs[j + 1]]
                 * beta[to_state, j + 1])
    return numerator / forward_total
@memoized
def delta(obs, j, s, A, B, pi):
    """
    (observation: sequence of str, time j: int, state s: str, transition matrix: LMatrix, emission matrix: LMatrix, pi: hashdict init state vector) => float

    For observation `obs`, the posterior probability that its `j`th hidden
    state is `s`, given A, B and pi.

    For j < len(obs) - 1 this is gamma(obs, j, s, to_state, ...) summed over
    every state in A.rlabels. At the final time step there is no successor,
    so it is alpha_{T-1}(s) / P(obs) read from the forward table.
    """
    T = len(obs)
    if j < T - 1:
        # marginalise the pairwise posterior over every possible next state
        return sum(gamma(obs, j, s, to_state, A, B, pi) for to_state in A.rlabels)
    # last step: equivalent to alpha * beta / P(obs) with beta = 1 at time T-1
    ft, obs_prob = forward_prob_table(obs, A, B, pi)
    return ft[s, T - 1] / obs_prob
def _normalize_rows(M):
    """Divide each row of the 2-D matrix `M` by that row's sum."""
    # NOTE(review): a row summing to 0 (a state with no posterior mass)
    # yields NaN/inf entries under numpy division semantics; not handled here.
    rc, cc = M.shape
    return M / M.sum(1).reshape(rc, 1).repeat(cc, 1)


def _reestimate_pi(lst_of_obs, A, B, pi):
    """Expected occupancy of each state at time 0, normalized, as a hashdict keyed by state."""
    Q = A.rlabels
    # list comprehension (not map) so np.array gets real floats on Python 2 and 3
    counts = np.array([sum(delta(obs, 0, s, A, B, pi) for obs in lst_of_obs) for s in Q])
    return hashdict(zip(Q, counts / np.sum(counts)))


def _reestimate_A(lst_of_obs, A, B, pi):
    """Expected transition counts from every state to every state, row-normalized."""
    Q = A.rlabels
    # assumes a fresh LMatrix starts at zero so += accumulates — TODO confirm in lmatrix
    counts = LMatrix(rlabels = Q, clabels = Q)
    for obs in lst_of_obs:
        T = len(obs)
        for fs in Q:
            for ts in Q:
                # transitions exist only between times 0..T-2 and their successors
                counts[fs, ts] += sum(gamma(obs, j, fs, ts, A, B, pi) for j in range(T - 1))
    return _normalize_rows(counts)


def _reestimate_B(lst_of_obs, A, B, pi):
    """Expected emission counts of each symbol from each state, row-normalized."""
    Q = A.rlabels
    counts = LMatrix(rlabels = Q, clabels = B.clabels)
    for obs in lst_of_obs:
        for j, ob in enumerate(obs):
            for s in Q:
                counts[s, ob] += delta(obs, j, s, A, B, pi)
    return _normalize_rows(counts)


def one_iter(lst_of_obs, A, B, pi):
    """
    Perform one expectation-maximization step of Baum-Welch.

    Given a list of observation sequences and the current HMM configuration
    (transition matrix A, emission matrix B, initial state vector pi), compute
    the expected state/transition/emission counts and re-normalize them.

    Returns (new_A, new_B, new_pi). new_A has rows and columns labelled by
    A.rlabels, new_B has rows A.rlabels and columns B.clabels, and new_pi is a
    hashdict mapping each state to its initial probability.
    """
    new_pi = _reestimate_pi(lst_of_obs, A, B, pi)
    new_A = _reestimate_A(lst_of_obs, A, B, pi)
    new_B = _reestimate_B(lst_of_obs, A, B, pi)
    return new_A, new_B, new_pi
def clear_memoization():
    """Reset the memo caches of the module's memoized helpers so old (A, B, pi) entries can be freed.

    Each function's `.cache` attribute is rebound to a brand-new empty dict,
    in the order gamma, delta, forward_prob_table, backward_prob_table.
    NOTE(review): this only takes effect if the `memoized` decorator looks up
    `.cache` on the wrapper at call time — confirm against util.memoized.
    """
    for fn in (gamma, delta, forward_prob_table, backward_prob_table):
        fn.cache = {}
def take_snapshot(iteration, A, B, pi, directory="param_snapshot"):
    """Pickle the current HMM parameters into `directory`.

    Files written (the directory is created if it does not exist):
      Q.vec              -- A.rlabels (overwritten on every call)
      V.vec              -- B.clabels (overwritten on every call)
      <iteration>_A.mat  -- transition matrix A
      <iteration>_B.mat  -- emission matrix B
      <iteration>_pi.mat -- initial state vector pi

    Files are opened in binary mode and closed deterministically. Pickle
    output is bytes, which Python 3's pickle requires and which avoids
    newline translation on Windows.
    """
    import os
    try:
        from cPickle import dump  # Python 2 C pickler
    except ImportError:
        from pickle import dump  # Python 3
    if not os.path.isdir(directory):
        os.makedirs(directory)
    targets = [
        ("Q.vec", A.rlabels),
        ("V.vec", B.clabels),
        ("%d_A.mat" % iteration, A),
        ("%d_B.mat" % iteration, B),
        ("%d_pi.mat" % iteration, pi),
    ]
    for filename, obj in targets:
        with open(os.path.join(directory, filename), "wb") as fh:
            dump(obj, fh)
def baum_welch(lst_of_obs, A, B, pi, max_iter=None, places=4):
    """
    The Baum-Welch (EM) algorithm for fitting HMM parameters.

    lst_of_obs: list of observation sequences (tuples of str)
    A: initial transition prob matrix (rows and columns labelled by state)
    B: initial emission prob matrix (rows labelled by state, columns by symbol)
    pi: initial state vector (hashdict)
    max_iter: stop after this many re-estimation steps even if not converged;
        None (default) means iterate until convergence
    places: convergence tolerance passed to `convergent`; stops when the
        relative change in total log-likelihood is below 10**-places

    Returns (A, B, pi): the parameters from the converging step, or the latest
    parameters if max_iter is reached first.

    Side effects: snapshots each iteration's input parameters via
    take_snapshot, prints the score per iteration, and clears memo caches
    between iterations.
    """
    iteration = 0
    # total log-likelihood of all sequences under the current parameters;
    # carried forward each step rather than recomputed with another forward pass
    old_score = sum(np.log(forward_prob_table(obs, A, B, pi)[1]) for obs in lst_of_obs)
    while max_iter is None or iteration < max_iter:
        take_snapshot(iteration, A, B, pi)
        new_A, new_B, new_pi = one_iter(lst_of_obs, A, B, pi)
        new_score = sum(np.log(forward_prob_table(obs, new_A, new_B, new_pi)[1]) for obs in lst_of_obs)
        print("iteration %d, score %f" % (iteration, new_score))
        if convergent(old_score, new_score, places=places):
            return new_A, new_B, new_pi
        A, B, pi = new_A, new_B, new_pi
        old_score = new_score
        # to prevent memory usage from exploding
        clear_memoization()
        iteration += 1
    return A, B, pi