-
Notifications
You must be signed in to change notification settings - Fork 9
/
Code6_20_ProfileHMM.py
98 lines (81 loc) · 2.34 KB
/
Code6_20_ProfileHMM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#Chunyu Zhao 20160227
import sys, numpy as np
from collections import OrderedDict
def profileHMM(theta,symbols,msa):
    """Estimate profile-HMM transition and emission tables from an alignment.

    A column of ``msa`` is an *insertion* column when its fraction of '-'
    characters is strictly greater than ``theta``; every other column is a
    *match* column.  Each aligned sequence is threaded through the states
    S, I0, M1, D1, I1, ..., Mk, Dk, Ik, E, the transitions and emissions it
    uses are counted, and both count tables are row-normalised with
    ``normalize``.

    Args:
        theta: gap-fraction threshold for declaring an insertion column.
        symbols: alphabet symbols, in the column order of the emission table.
        msa: aligned sequences (strings) using '-' for gaps; assumed to be
            of equal length.

    Returns:
        A tuple ``(transition, emission, names)`` where ``names`` is the list
        of state labels and the two tables are the results of
        ``normalize`` (rows ordered like ``names``).

    Raises:
        KeyError: if a sequence contains a non-gap character that is not in
            ``symbols``.
    """
    # Fraction of gap characters in every alignment column.
    gap_fraction = [float(col.count('-')) / len(col) for col in zip(*msa)]
    # A set, because membership is tested once per residue below.
    inserted_cols = set(
        idx for idx, frac in enumerate(gap_fraction) if frac > theta
    )
    matchnum = len(gap_fraction) - len(inserted_cols)

    # State order: S, I0, then (Mi, Di, Ii) for each match position, then E.
    names = ['S', 'I0']
    for i in range(1, matchnum + 1):
        for state in ('M', 'D', 'I'):
            names.append(state + str(i))
    names.append('E')

    transition = OrderedDict()
    emission = OrderedDict()
    for name in names:
        transition[name] = OrderedDict((target, 0.0) for target in names)
        emission[name] = OrderedDict((symbol, 0.0) for symbol in symbols)

    # Feasible transitions: (Mi, Di, Ii) -> (Ii, Mi+1, Di+1).
    for seq in msa:
        # Every path starts in S, even when the first column is an insertion
        # column; otherwise the S row would never receive any counts.
        prevstate = 'S'
        mi = 0  # number of match columns consumed so far
        for si, residue in enumerate(seq):
            if si in inserted_cols:
                if residue == '-':
                    # A gap in an insertion column visits no state.
                    continue
                currstate = 'I' + str(mi)
            else:
                mi += 1
                currstate = ('D' if residue == '-' else 'M') + str(mi)
            transition[prevstate][currstate] += 1
            if residue != '-':
                emission[currstate][residue] += 1
            prevstate = currstate
        # Use the last state actually visited by *this* sequence (S if none),
        # never a value left over from a previous sequence.
        transition[prevstate]['E'] += 1

    transition = normalize(transition, names)
    emission = normalize(emission, names)
    return transition, emission, names
def normalize(mat,names):
    """Return the rows of ``mat`` as lists, each scaled to sum to 1.

    Args:
        mat: mapping from row name to a mapping of column name -> count.
        names: row names; the first ``len(mat)`` entries select and order
            the rows of the result.

    Returns:
        A list of rows (lists of numbers).  A row containing at least one
        non-zero entry is divided by its sum; an all-zero row is returned
        unchanged so that it does not trigger a division by zero.
    """
    ret = []
    for ti in range(len(mat)):
        # list(...) because dict.values() is a view on Python 3.
        row = list(mat[names[ti]].values())
        if any(value != 0 for value in row):
            # Sum once per row; float() keeps integer counts from being
            # floor-divided on Python 2.
            total = float(sum(row))
            row = [value / total for value in row]
        ret.append(row)
    return ret
def printtransition(mat,names):
    """Print the transition matrix as a tab-separated table on stdout.

    Args:
        mat: sequence of rows of numbers, one row per state.
        names: list of state labels, used both for the header line and as
            the label at the start of each row.
    """
    # Header starts with an empty cell so columns line up under the labels.
    # print(<single expression>) behaves the same on Python 2 and Python 3.
    print("\t".join([''] + names))
    for ri, row in enumerate(mat):
        print(names[ri] + "\t" + "\t".join("%.3f" % value for value in row))
def printemission(mat,names,symbols):
    """Print the emission matrix as a tab-separated table on stdout.

    Args:
        mat: sequence of rows of numbers, one row per state.
        names: state labels; ``names[i]`` labels row ``i``.
        symbols: list of alphabet symbols, printed as the header line.
    """
    # Header starts with an empty cell so columns line up under the symbols.
    # print(<single expression>) behaves the same on Python 2 and Python 3.
    print("\t".join([''] + symbols))
    for ri, row in enumerate(mat):
        print(names[ri] + "\t" + "\t".join("%.3f" % value for value in row))
def main():
    """Read a problem file named on the command line and print both tables.

    Input layout (0-based line index): line 0 holds theta, line 2 the
    whitespace-separated alphabet, and lines 4 onward the aligned sequences.
    Lines 1 and 3 are ignored.  The transition table is printed first, then
    a ``--------`` separator, then the emission table.
    """
    if len(sys.argv) != 2:
        # Tell the user what is expected instead of silently doing nothing.
        sys.stderr.write("usage: %s <input_file>\n" % sys.argv[0])
        return

    with open(sys.argv[1]) as f:
        lines = f.read().splitlines()

    theta = float(lines[0])
    # split() accepts tabs or spaces and ignores a trailing separator.
    symbols = lines[2].split()
    # Skip blank lines so a trailing newline is not read as an empty sequence.
    msa = [line for line in lines[4:] if line.strip()]

    transition, emission, names = profileHMM(theta, symbols, msa)
    printtransition(transition, names)
    print("--------")
    printemission(emission, names, symbols)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()