-
Notifications
You must be signed in to change notification settings - Fork 0
/
hmm_trans_matrix.py
51 lines (44 loc) · 1.72 KB
/
hmm_trans_matrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import pickle
from collections import defaultdict
def bigrams(words):
    """Yield each pair of adjacent items from ``words``.

    Args:
        words: any iterable (a string yields pairs of characters).

    Yields:
        ``(previous, current)`` tuples, in order. Nothing is yielded when
        ``words`` has fewer than two items.
    """
    items = iter(words)
    # Prime the window with the first item instead of using None as a
    # "nothing seen yet" marker, so a None element inside ``words`` is treated
    # like any other value and the pair that follows it is not lost.
    try:
        previous = next(items)
    except StopIteration:
        return
    for current in items:
        yield (previous, current)
        previous = current
def hmm_prob(domain):
    """Return the product of transition probabilities along ``domain``.

    Adjacent items of ``domain`` are joined into two-character grams. The
    score starts at the probability stored under the empty-string row for the
    first gram and is then multiplied by the probability of every
    gram-to-next-gram step.

    NOTE(review): reads a module-level ``transitions`` mapping that is not
    defined in the visible part of this file - presumably the result of
    ``load_trans_matrix``; confirm against the caller.

    Raises:
        IndexError: if ``domain`` produces no gram at all (fewer than two
            items).
    """
    grams = [''.join(pair) for pair in bigrams(domain) if pair[0] is not None]
    # Row '' holds the distribution of the first gram of a word.
    score = transitions[''][grams[0]]
    # Walk consecutive grams: (g0, g1), (g1, g2), ...
    for current, following in zip(grams, grams[1:]):
        score *= transitions[current][following]
    return score
def train(data, trans_matrix_path, n_grams=2):
    """Estimate gram-to-gram transition probabilities and write them to disk.

    Every entry of ``data`` is stripped and lower-cased; empty entries are
    dropped, only the text before the first ``/`` is kept, and the result is
    wrapped in ``^`` (start) and ``$`` (end) markers. Transitions are counted
    between each ``n_grams``-long window and the window shifted one character
    to the right, then each row is normalized so its probabilities sum to 1.

    Args:
        data: iterable of strings, one word per item.
        trans_matrix_path: path of the output file; it is overwritten.
        n_grams: window length used for the grams (default 2).

    The output file has one ``gram<TAB>next_gram<TAB>probability`` line per
    observed transition. The row whose gram is the empty string holds the
    distribution of the first gram of a word.
    """
    stripped = (w.strip().lower() for w in data)
    words = ["^" + w.split('/')[0] + "$" for w in stripped if w != ""]

    transitions = defaultdict(lambda: defaultdict(float))
    n = n_grams
    for word in words:
        if len(word) >= n:
            # Count which gram the word starts with under the '' row.
            transitions[""][word[:n]] += 1.0
        # The range is empty when the word is shorter than n.
        for i in range(len(word) - n):
            gram = word[i:i + n]
            next_ = word[i + 1:i + n + 1]
            transitions[gram][next_] += 1.0

    # Normalize the counts of every row into probabilities.
    for row in transitions.values():
        total = sum(row.values())
        for next_ in row:
            row[next_] /= total

    with open(trans_matrix_path, mode='w') as fw:
        for key1, row in transitions.items():
            for key2, value in row.items():
                # %r keeps full float precision; %f would cut the value to six
                # decimals and store any probability below 5e-7 as 0.000000.
                fw.write('%s\t%s\t%r\n' % (key1, key2, value))
def load_trans_matrix(trans_matrix_path):
    """Read a tab-separated transition file into a nested mapping.

    Each line must hold exactly three tab-separated fields:
    ``gram``, ``next_gram`` and a probability parseable by ``float``.

    Args:
        trans_matrix_path: path of the file to read.

    Returns:
        A ``defaultdict`` of ``defaultdict(float)`` such that
        ``result[gram][next_gram]`` is the stored probability; pairs absent
        from the file read as ``0.0``.

    Raises:
        ValueError: if a line does not split into three fields or its last
            field is not a valid float.
    """
    # Load the matrix; unknown rows and columns default to 0.0.
    matrix = defaultdict(lambda: defaultdict(float))
    with open(trans_matrix_path, mode='r') as handle:
        for line in handle:
            # rstrip() only trims the right side, so a leading tab (the
            # empty-string gram) survives the split.
            source, target, raw_prob = line.rstrip().split('\t')
            matrix[source][target] = float(raw_prob)
    return matrix