Skip to content

Commit

Permalink
enable max_memory_size in scikit predict
Browse files Browse the repository at this point in the history
  • Loading branch information
QinbinLi committed Aug 4, 2018
1 parent efbf58b commit 158cdc5
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
7 changes: 4 additions & 3 deletions python/thundersvmScikit.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def __init__(self, kernel, degree,
self.max_mem_size = max_mem_size
thundersvm.model_new.restype = c_void_p
self.model = thundersvm.model_new(SVM_TYPE.index(self._impl))
if self.max_mem_size != -1:
thundersvm.set_memory_size(c_void_p(self.model), self.max_mem_size)

def label_validate(self, y):

Expand Down Expand Up @@ -238,7 +240,6 @@ def _validate_for_predict(self, X):
return X

def predict(self, X):

X = self._validate_for_predict(X)
predict = self._sparse_predict if self._sparse else self._dense_predict
return predict(X)
Expand Down Expand Up @@ -294,7 +295,7 @@ def _dense_predict(self, X):
thundersvm.dense_predict(
samples, features, data,
c_void_p(self.model),
self.predict_label_ptr)
self.predict_label_ptr, self.verbose)

self.predict_label = np.array([self.predict_label_ptr[index] for index in range(0, X.shape[0])])
return self.predict_label
Expand All @@ -310,7 +311,7 @@ def _sparse_predict(self, X):
thundersvm.sparse_predict(
X.shape[0], data, indptr, indices,
c_void_p(self.model),
self.predict_label_ptr)
self.predict_label_ptr, self.verbose)

predict_label = [self.predict_label_ptr[index] for index in range(0, X.shape[0])]

Expand Down
3 changes: 2 additions & 1 deletion src/thundersvm/model/svmmodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,5 +334,6 @@ int SvmModel::get_working_set_size(int n_instances, int n_features) {
}

void SvmModel::set_max_memory_size(size_t size) {
this->param.max_mem_size = size;
if(size > 0)
this->param.max_mem_size = static_cast<size_t>(size) << 20;
}
16 changes: 14 additions & 2 deletions src/thundersvm/thundersvm-scikit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,11 @@ extern "C" {
n_classes[0] = model->get_n_classes();
}

int sparse_predict(int row_size, float* val, int* row_ptr, int* col_ptr, SvmModel *model, float* predict_label){
int sparse_predict(int row_size, float* val, int* row_ptr, int* col_ptr, SvmModel *model, float* predict_label, int verbose){
if(verbose)
el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Enabled, "true");
else
el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Enabled, "false");
DataSet predict_dataset;
predict_dataset.load_from_sparse(row_size, val, row_ptr, col_ptr, (float *)NULL);
vector<float_type> predict_y;
Expand Down Expand Up @@ -219,7 +223,11 @@ extern "C" {

}

int dense_predict(int row_size, int features, float* data, SvmModel *model, float* predict_label){
int dense_predict(int row_size, int features, float* data, SvmModel *model, float* predict_label, int verbose){
if(verbose)
el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Enabled, "true");
else
el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Enabled, "false");
DataSet predict_dataset;
predict_dataset.load_from_dense(row_size, features, data, (float*) NULL);
vector<float_type> predict_y;
Expand Down Expand Up @@ -333,4 +341,8 @@ extern "C" {
void get_n_classes(SvmModel *model, int *n_classes){
n_classes[0] = model->get_n_classes();
}

void set_memory_size(SvmModel *model, int m_size){
model->set_max_memory_size(m_size);
}
}

0 comments on commit 158cdc5

Please sign in to comment.