-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
serialization_string_kernels.py
218 lines (162 loc) · 5.51 KB
/
serialization_string_kernels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
#!/usr/bin/env python
from shogun import WeightedDegreeStringKernel, LinearKernel, PolyKernel, GaussianKernel, CTaxonomy
from shogun import CombinedKernel, WeightedDegreeRBFKernel
from shogun import StringCharFeatures, RealFeatures, CombinedFeatures, StringWordFeatures, SortWordString
from shogun import DNA, PROTEIN, Labels
from shogun import WeightedDegreeStringKernel, CombinedKernel, WeightedCommWordStringKernel, WeightedDegreePositionStringKernel
from shogun import StringCharFeatures, DNA, StringWordFeatures, CombinedFeatures
from shogun import MSG_DEBUG
from shogun import RealFeatures, BinaryLabels, DNA, Alphabet
from shogun import WeightedDegreeStringKernel, GaussianKernel
try:
from shogun import SVMLight
except ImportError:
print("SVMLight is not available")
exit(0)
from numpy import concatenate, ones
from numpy.random import randn, seed
import numpy
import sys
import types
import random
import bz2
import pickle
import inspect
###################################################
# Random Data
###################################################
def generate_random_string(length, number):
    """
    Generate `number` random DNA-like strings of `length` characters each,
    drawn uniformly from the alphabet AGTC.
    """
    nucleotides = "AGTC"
    return ["".join(random.choice(nucleotides) for _ in range(length))
            for _ in range(number)]
def generate_random_data(number):
    """
    Draw `number` random fixed-length (22 char) DNA examples together with
    uniformly random +1/-1 labels.
    """
    label_values = [-1.0, 1.0]
    labels = numpy.array([random.choice(label_values) for _ in range(number)])
    examples = numpy.array(generate_random_string(22, number))
    return examples, labels
def save(filename, myobj):
    """
    Save an object to a bz2-compressed file using pickle.

    @param filename: name of destination file
    @type filename: str
    @param myobj: object to save (has to be pickleable)
    @type myobj: obj
    """
    try:
        f = bz2.BZ2File(filename, 'wb')
    except IOError as details:
        sys.stderr.write('File ' + filename + ' cannot be written\n')
        # str(details): file.write() only accepts strings, not exception objects
        sys.stderr.write(str(details) + '\n')
        return
    # context manager closes the file even if pickling raises
    with f:
        pickle.dump(myobj, f, protocol=2)
def load(filename):
    """
    Load a pickled object from a bz2-compressed file.

    @param filename: name of file to load from
    @type filename: str
    @return: the unpickled object, or None if the file cannot be opened
    """
    try:
        f = bz2.BZ2File(filename, 'rb')
    except IOError as details:
        sys.stderr.write('File ' + filename + ' cannot be read\n')
        # str(details): file.write() only accepts strings, not exception objects
        sys.stderr.write(str(details) + '\n')
        return
    # context manager closes the file even if unpickling raises
    with f:
        return pickle.load(f)
def get_spectrum_features(data, order=3, gap=0, reverse=True):
    """
    Build the sorted word-string feature object consumed by the
    spectrum kernel.
    """
    raw_chars = StringCharFeatures(data, DNA)
    words = StringWordFeatures(raw_chars.get_alphabet())
    words.obtain_from_char(raw_chars, order - 1, order, gap, reverse)
    sorter = SortWordString()
    sorter.fit(words)
    return sorter.apply(words)
def get_wd_features(data, feat_type="dna"):
"""
create feature object for wdk
"""
if feat_type == "dna":
feat = StringCharFeatures(DNA)
elif feat_type == "protein":
feat = StringCharFeatures(PROTEIN)
else:
raise Exception("unknown feature type")
feat.set_features(data)
return feat
def construct_features(features):
    """
    Assemble a combined feature object: one weighted-degree feature over
    each whole string plus two spectrum features over the first 15
    characters and the remainder.
    """
    whole = [seq for seq in features]
    left = [seq[:15] for seq in features]
    right = [seq[15:] for seq in features]
    combined = CombinedFeatures()
    combined.append_feature_obj(get_wd_features(whole))
    combined.append_feature_obj(get_spectrum_features(left, order=3))
    combined.append_feature_obj(get_spectrum_features(right, order=3))
    return combined
# default example arguments: [n_data, num_shifts, size]
parameter_list = [[200, 1, 100]]
def serialization_string_kernels(n_data, num_shifts, size):
    """
    Train an SVMLight classifier on a combined string kernel, serialize it
    to a bz2 pickle file, reload it, and check both models predict the same.

    @param n_data: number of random train/test examples to generate
    @param num_shifts: shift allowed at every position of the WD kernel
    @param size: kernel cache size
    @return: tuple of prediction arrays (original svm, deserialized svm)
    """
    ##################################################
    # set up toy data and svm
    train_xt, train_lt = generate_random_data(n_data)
    test_xt, test_lt = generate_random_data(n_data)
    feats_train = construct_features(train_xt)
    feats_test = construct_features(test_xt)
    max_len = len(train_xt[0])
    kernel_wdk = WeightedDegreePositionStringKernel(size, 5)
    # allow the same shift at every position of the string
    shifts_vector = numpy.ones(max_len, dtype=numpy.int32)*num_shifts
    kernel_wdk.set_shifts(shifts_vector)
    ########
    # set up spectrum
    use_sign = False
    kernel_spec_1 = WeightedCommWordStringKernel(size, use_sign)
    kernel_spec_2 = WeightedCommWordStringKernel(size, use_sign)
    ########
    # combined kernel
    kernel = CombinedKernel()
    kernel.append_kernel(kernel_wdk)
    kernel.append_kernel(kernel_spec_1)
    kernel.append_kernel(kernel_spec_2)
    # init kernel
    labels = BinaryLabels(train_lt)
    svm = SVMLight(1.0, kernel, labels)
    svm.train(feats_train)
    ##################################################
    # serialize to file
    fn = "serialized_svm.bz2"
    save(fn, svm)
    ##################################################
    # unserialize and sanity check
    svm2 = load(fn)
    out = svm.apply(feats_test).get_labels()
    out2 = svm2.apply(feats_test).get_labels()
    # assert outputs are close; the original closed abs() after the
    # comparison — abs(diff < eps) — so the magnitude was never checked
    # and a large negative divergence would pass
    for i in range(len(out)):
        assert abs(out[i] - out2[i]) < 0.000001
    return out, out2
if __name__=='__main__':
    # run the example with the default toy parameters
    serialization_string_kernels(*parameter_list[0])