Skip to content

Commit

Permalink
fix predict batch size in scikit
Browse files Browse the repository at this point in the history
  • Loading branch information
QinbinLi committed Jul 14, 2018
1 parent 07c812e commit cae0a26
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/thundersvm/thundersvm-scikit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ extern "C" {
void sparse_decision(int row_size, float* val, int* row_ptr, int* col_ptr, SvmModel *model, int value_size, float* dec_value){
DataSet predict_dataset;
predict_dataset.load_from_sparse(row_size, val, row_ptr, col_ptr, (float *)NULL);
model->predict(predict_dataset.instances(), 10000);
model->predict(predict_dataset.instances(), -1);
SyncArray<float_type> dec_value_array(value_size);
dec_value_array.copy_from(model->get_dec_value());
float_type *dec_value_ptr = dec_value_array.host_data();
Expand All @@ -301,7 +301,7 @@ extern "C" {
void dense_decision(int row_size, int features, float* data, SvmModel *model, int value_size, float* dec_value){
DataSet predict_dataset;
predict_dataset.load_from_dense(row_size, features, data, (float*) NULL);
model->predict(predict_dataset.instances(), 10000);
model->predict(predict_dataset.instances(), -1);
SyncArray<float_type> dec_value_array(value_size);
dec_value_array.copy_from(model->get_dec_value());
float_type *dec_value_ptr = dec_value_array.host_data();
Expand Down

0 comments on commit cae0a26

Please sign in to comment.