Skip to content

Commit

Permalink
fix issue #215
Browse files Browse the repository at this point in the history
  • Loading branch information
QinbinLi committed Apr 29, 2020
1 parent d3f3906 commit 250d5a8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
17 changes: 9 additions & 8 deletions python/thundersvm/thundersvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def load_from_string(self, data):
thundersvm.set_memory_size(c_void_p(self.model), self.max_mem_size)
thundersvm.load_from_string_scikit(c_void_p(self.model), data)
self._post_load_init()

def _post_load_init(self):
degree = (c_int * 1)()
gamma = (c_float * 1)()
Expand Down Expand Up @@ -483,11 +483,12 @@ def _post_load_init(self):
thundersvm.get_rho(rho, rho_size, c_void_p(self.model))
self.intercept_ = np.frombuffer(rho, dtype=np.float32).astype(float)

# if self.kernel == 'linear':
# coef = (c_float * (self.n_binary_model * self.n_sv))()
# thundersvm.get_linear_coef(coef, self.n_binary_model, self.n_features, c_void_p(self.model))
# self.coef_ = np.array([coef[index] for index in range(0, self.n_binary_model * self.n_features)]).astype(float)
# self.coef_ = np.reshape(self.coef_, (self.n_binary_model, self.n_features))
if self.kernel == 'linear':
coef = (c_float * (self.n_binary_model * self.n_features))()
thundersvm.get_linear_coef(coef, self.n_binary_model, self.n_features, c_void_p(self.model))
self.coef_ = np.frombuffer(coef, dtype=np.float32) \
.astype(float) \
.reshape((self.n_binary_model, self.n_features))

self.kernel = kernel.value.decode()
self.degree = degree[0]
Expand All @@ -508,8 +509,8 @@ def __getstate__(self):
def __setstate__(self, state):
self.__dict__.update(state)
if '_saved_as_str' in state:
self.load_from_string(state['_saved_as_str'])
self.load_from_string(state['_saved_as_str'])


class SVC(SvmModel, ClassifierMixin):
_impl = 'c_svc'
Expand Down
22 changes: 22 additions & 0 deletions src/thundersvm/model/svmmodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ string SvmModel::save_to_string() {
}
s_model << endl;
}
if (param.kernel_type == SvmParam::LINEAR) {
s_model << "coef " << endl;
const float_type *linear_coef_data = linear_coef.host_data();
for (int i = 0; i < linear_coef.size(); i++) {
s_model << setprecision(16) << linear_coef_data[i] << " ";
}
s_model << endl;
}
s_model << "SV " << endl;
const float_type *coef_data = coef.host_data();
for (int i = 0; i < sv.size(); i++) {
Expand Down Expand Up @@ -274,6 +282,20 @@ void SvmModel::load_from_string(string data) {
for (int i = 0; i < n_binary_models; ++i) {
iss >> probB[i];
}
} else if (feature == "coef"){
string line;
getline(iss, line);
getline(iss, line);
int size = 0;
for (int i = 0; line[i] !='\0'; i++){
if(line[i] == ' ')
size++;
}
linear_coef.resize(size);
stringstream ss(line);
for(int i = 0; i < size; i++){
ss >> linear_coef.host_data()[i];
}
} else if (feature == "SV") {
sv.clear();
coef.resize((n_classes - 1) * n_total_sv);
Expand Down

0 comments on commit 250d5a8

Please sign in to comment.