Skip to content

Commit

Permalink
fix bug #101
Browse files Browse the repository at this point in the history
  • Loading branch information
blackjack201312 committed Sep 4, 2018
1 parent d4eb6e4 commit 4eef25a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
7 changes: 5 additions & 2 deletions src/thundersvm/model/svmmodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,16 @@ SvmModel::predict_dec_values(const DataSet::node2d &instances, SyncArray<float_t
vector<float_type> SvmModel::predict(const DataSet::node2d &instances, int batch_size = -1) {
// param.max_mem_size
dec_values.resize(instances.size() * n_binary_models);
vector<float_type> dec_values_vec(dec_values.size());
dec_values.set_host_data(dec_values_vec.data());
// vector<float_type> dec_values_vec(dec_values.size());
// dec_values.set_host_data(dec_values_vec.data());
#ifdef USE_CUDA
dec_values.to_device();//reserve space
#endif
predict_dec_values(instances, dec_values, batch_size);
dec_values.to_host();//copy back from device
float_type* dec_values_host = dec_values.host_data();
vector<float_type> dec_values_vec(dec_values.size());
memcpy(dec_values_vec.data(), dec_values_host, dec_values.size() * sizeof(float_type));
return dec_values_vec;
}

Expand Down
7 changes: 4 additions & 3 deletions src/thundersvm/thundersvm-scikit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,10 @@ extern "C" {
DataSet predict_dataset;
predict_dataset.load_from_dense(row_size, features, data, (float*) NULL);
model->predict(predict_dataset.instances(), -1);
SyncArray<float_type> dec_value_array(value_size);
dec_value_array.copy_from(model->get_dec_value());
float_type *dec_value_ptr = dec_value_array.host_data();
//SyncArray<float_type> dec_value_array(value_size);
//dec_value_array.copy_from(model->get_dec_value());
const SyncArray<float_type>& dec_value_array = model->get_dec_value();
const float_type *dec_value_ptr = dec_value_array.host_data();
for(int i = 0; i < dec_value_array.size(); i++){
dec_value[i] = dec_value_ptr[i];
}
Expand Down

0 comments on commit 4eef25a

Please sign in to comment.