Skip to content

Commit

Permalink
fixed #20
Browse files Browse the repository at this point in the history
initialize f_val in smaller batch
  • Loading branch information
shijiashuai committed Nov 5, 2017
1 parent 461dbf5 commit 468c1e8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/test/test_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Format, "%datetime %level %fbase:%line : %msg");
el::Loggers::addFlag(el::LoggingFlag::ColoredTerminalOutput);
cudaSetDevice(1);
cudaSetDevice(0);
return RUN_ALL_TESTS();
}
25 changes: 14 additions & 11 deletions src/thundersvm/solver/csmosolver.cu
Original file line number Diff line number Diff line change
Expand Up @@ -132,24 +132,27 @@ CSMOSolver::calculate_rho(const SyncData <real> &f_val, const SyncData<int> &y,

/**
 * Initialize f_val from the non-zero entries of alpha, working in small batches.
 *
 * For every i with alpha[i] != 0, the index i and the coefficient -alpha[i] * y[i] are queued.
 * As soon as the queue holds batch_size entries, or the last element of alpha has been visited
 * while the queue is non-empty, the queued kernel rows are fetched with k_mat.get_rows() and
 * handed to the update_f kernel together with f_val; the queue is then emptied.
 *
 * Batching keeps the temporary kernel_rows buffer at no more than
 * batch_size * k_mat.n_instances() elements, instead of one row per non-zero alpha.
 *
 * @param alpha  coefficients to scan; entries equal to zero are skipped
 * @param y      per-instance factors multiplied with alpha (presumably the labels - confirm)
 * @param k_mat  kernel matrix providing n_instances() and get_rows()
 * @param f_val  buffer whose device data is passed to update_f for every batch
 */
void CSMOSolver::init_f(const SyncData<real> &alpha, const SyncData<int> &y, const KernelMatrix &k_mat,
                        SyncData<real> &f_val) const {
    //todo auto set batch size
    const int batch_size = 100;
    vector<int> idx_vec;
    vector<real> alpha_diff_vec;
    for (int i = 0; i < alpha.size(); ++i) {
        if (alpha[i] != 0) {
            idx_vec.push_back(i);
            alpha_diff_vec.push_back(-alpha[i] * y[i]);
        }

        // Flush when the batch is full, or when this is the final element and work is still pending.
        const bool batch_full = static_cast<int>(idx_vec.size()) >= batch_size;
        const bool last_pending = (i == alpha.size() - 1) && !idx_vec.empty();
        if (batch_full || last_pending) {
            SyncData<int> idx(idx_vec.size());
            SyncData<real> alpha_diff(idx_vec.size());
            idx.copy_from(idx_vec.data(), idx_vec.size());
            alpha_diff.copy_from(alpha_diff_vec.data(), idx_vec.size());

            // One kernel row (n_instances values) per queued index.
            SyncData<real> kernel_rows(idx.size() * k_mat.n_instances());
            k_mat.get_rows(idx, kernel_rows);
            SAFE_KERNEL_LAUNCH(update_f, f_val.device_data(), idx.size(), alpha_diff.device_data(),
                               kernel_rows.device_data(), k_mat.n_instances());

            // Start the next batch from scratch; clear() keeps the vectors' capacity.
            idx_vec.clear();
            alpha_diff_vec.clear();
        }
    }
}

Expand Down
15 changes: 15 additions & 0 deletions src/thundersvm/thundersvm-train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ int main(int argc, char **argv) {
break;
}

//todo add this to check_parameter method
if (parser.param_cmd.svm_type == SvmParam::NU_SVC) {
train_dataset.group_classes();
for (int i = 0; i < train_dataset.n_classes(); ++i) {
int n1 = train_dataset.count()[i];
for (int j = i + 1; j < train_dataset.n_classes(); ++j) {
int n2 = train_dataset.count()[j];
if (parser.param_cmd.nu * (n1 + n2) / 2 > min(n1, n2)) {
printf("specified nu is infeasible\n");
return 1;
}
}
}
}

CUDA_CHECK(cudaSetDevice(parser.gpu_id));

if (parser.do_cross_validation) {
Expand Down

0 comments on commit 468c1e8

Please sign in to comment.