Skip to content

Commit

Permalink
fix bug #104
Browse files Browse the repository at this point in the history
  • Loading branch information
QinbinLi committed Dec 9, 2018
1 parent 193b7d0 commit 7cbb543
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 0 deletions.
3 changes: 3 additions & 0 deletions include/thundersvm/model/svmmodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class SvmModel {

//return prob_predict
const vector<float> &get_prob_predict() const;

//get the params, for scikit load params
void get_param(char* kernel_type, int* degree, float* gamma, float* coef0, int* probability);
protected:

/**
Expand Down
13 changes: 13 additions & 0 deletions python/thundersvmScikit.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,19 @@ def save_to_file(self, path):

def load_from_file(self, path):
thundersvm.load_from_file_scikit(c_void_p(self.model), path.encode('utf-8'))
degree = (c_int * 1)()
gamma = (c_float * 1)()
coef0 = (c_float * 1)()
probability = (c_int * 1)()
kernel = (c_char * 20)()
thundersvm.init_model_param(kernel, degree, gamma,
coef0, probability,c_void_p(self.model))
self.kernel = kernel.value
self.degree = degree[0]
if gamma[0] != 0.0:
self.gamma = gamma[0]
self.coef0 = coef0[0]
self.probability = probability[0]


class SVC(SvmModel, ClassifierMixin):
Expand Down
25 changes: 25 additions & 0 deletions src/thundersvm/model/svmmodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,28 @@ void SvmModel::set_max_memory_size_Byte(size_t size) {
if(size > 0)
this->param.max_mem_size = static_cast<size_t>(size);
}

void SvmModel::get_param(char* kernel_type, int* degree, float* gamma, float* coef0, int* probability){
switch(param.kernel_type){
case 0:
strcpy(kernel_type, "linear");
break;
case 1:
strcpy(kernel_type, "polynomial");
break;
case 2:
strcpy(kernel_type, "rbf");
kernel_type = (char *)"rbf";
break;
case 3:
strcpy(kernel_type, "sigmoid");
break;
case 4:
strcpy(kernel_type, "precomputed");
break;
}
*degree = param.degree;
*gamma = param.gamma;
*coef0 = param.coef0;
*probability = param.probability;
}
4 changes: 4 additions & 0 deletions src/thundersvm/thundersvm-scikit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ extern "C" {
}
}

void init_model_param(char* kernel_type, int* degree, float* gamma, float* coef0, int* probability, SvmModel* model){
model->get_param(kernel_type, degree, gamma, coef0, probability);
}

void sparse_model_scikit(int row_size, float* val, int* row_ptr, int* col_ptr, float* label,
int svm_type, int kernel_type, int degree, float gamma, float coef0,
float cost, float nu, float epsilon, float tol, int probability,
Expand Down

0 comments on commit 7cbb543

Please sign in to comment.