Skip to content

Commit

Permalink
fixed #20
Browse files Browse the repository at this point in the history
add max_iter in smo_kernel
  • Loading branch information
shijiashuai committed Nov 5, 2017
1 parent b3c63ed commit 461dbf5
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 9 deletions.
7 changes: 3 additions & 4 deletions include/thundersvm/kernel/smo_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ bool is_free(float a, float y, float Cp, float Cn);
__global__ void
c_smo_solve_kernel(const int *label, real *f_values, real *alpha, real *alpha_diff, const int *working_set, int ws_size,
float Cp, float Cn, const float *k_mat_rows, const float *k_mat_diag, int row_len, real eps,
real *diff_and_bias);
real *diff_and_bias, int max_iter);

__global__ void
nu_smo_solve_kernel(const int *label, real *f_values, real *alpha, real *alpha_diff, const int *working_set,
int ws_size,
float C, const float *k_mat_rows, const float *k_mat_diag, int row_len, real eps,
real *diff_and_bias);
int ws_size, float C, const float *k_mat_rows, const float *k_mat_diag, int row_len, real eps,
real *diff_and_bias, int max_iter);

__global__ void update_f(real *f, int ws_size, const real *alpha_diff, const real *k_mat_rows, int n_instances);

Expand Down
6 changes: 4 additions & 2 deletions src/thundersvm/kernel/smo_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ __host__ __device__ bool is_free(float a, float y, float Cp, float Cn) {
__global__ void
c_smo_solve_kernel(const int *label, real *f_values, real *alpha, real *alpha_diff, const int *working_set, int ws_size,
float Cp, float Cn, const float *k_mat_rows, const float *k_mat_diag, int row_len, real eps,
real *diff_and_bias) {
real *diff_and_bias, int max_iter) {
//"row_len" equals to the number of instances in the original training dataset.
//allocate shared memory
extern __shared__ int shared_mem[];
Expand Down Expand Up @@ -114,14 +114,15 @@ c_smo_solve_kernel(const int *label, real *f_values, real *alpha, real *alpha_di
float kJ2wsI = k_mat_rows[row_len * j2 + wsi];//K[J2, wsi]
f -= l * (kJ2wsI - kIwsI);
numOfIter++;
if (numOfIter > max_iter) break;
}
}


__global__ void
nu_smo_solve_kernel(const int *label, real *f_values, real *alpha, real *alpha_diff, const int *working_set,
int ws_size, float C, const float *k_mat_rows, const float *k_mat_diag, int row_len, real eps,
real *diff_and_bias) {
real *diff_and_bias, int max_iter) {
//"row_len" equals to the number of instances in the original training dataset.
//allocate shared memory
extern __shared__ int shared_mem[];
Expand Down Expand Up @@ -248,6 +249,7 @@ nu_smo_solve_kernel(const int *label, real *f_values, real *alpha, real *alpha_d
float kJ2wsI = k_mat_rows[row_len * j2 + wsi];//K[J2, wsi]
f -= l * (kJ2wsI - kIwsI);
numOfIter++;
if (numOfIter > max_iter) break;
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/thundersvm/solver/csmosolver.cu
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,9 @@ void CSMOSolver::smo_kernel(const int *label, real *f_values, real *alpha, real
int ws_size, float Cp, float Cn, const float *k_mat_rows, const float *k_mat_diag,
int row_len, real eps, real *diff_and_bias) const {
size_t smem_size = ws_size * sizeof(real) * 3 + 2 * sizeof(float);
int max_iter = max(100000, ws_size > INT_MAX / 100 ? INT_MAX : 100 * ws_size);
c_smo_solve_kernel << < 1, ws_size, smem_size >> >
(label, f_values, alpha, alpha_diff,
working_set, ws_size, Cp, Cn, k_mat_rows,
k_mat_diag, row_len, eps, diff_and_bias);
k_mat_diag, row_len, eps, diff_and_bias, max_iter);
}
3 changes: 2 additions & 1 deletion src/thundersvm/solver/nusmosolver.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ void NuSMOSolver::smo_kernel(const int *label, real *f_values, real *alpha, real
int ws_size, float Cp, float Cn, const float *k_mat_rows, const float *k_mat_diag,
int row_len, real eps, real *diff_and_bias) const {
//Cn is not used but for compatibility with c-svc
int max_iter = max(100000, ws_size > INT_MAX / 100 ? INT_MAX : 100 * ws_size);
size_t smem_size = ws_size * sizeof(real) * 3 + 2 * sizeof(float);
nu_smo_solve_kernel << < 1, ws_size, smem_size >> >
(label, f_values, alpha, alpha_diff,
working_set, ws_size, Cp, k_mat_rows,
k_mat_diag, row_len, eps, diff_and_bias);
k_mat_diag, row_len, eps, diff_and_bias, max_iter);
}

void NuSMOSolver::select_working_set(vector<int> &ws_indicator, const SyncData<int> &f_idx2sort, const SyncData<int> &y,
Expand Down
1 change: 0 additions & 1 deletion src/thundersvm/thundersvm-train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ int main(int argc, char **argv) {

if (parser.do_cross_validation) {
model->cross_validation(train_dataset, parser.param_cmd, parser.nr_fold);
return 0;
} else {
model->train(train_dataset, parser.param_cmd);
model->save_to_file(parser.model_file_name);
Expand Down

0 comments on commit 461dbf5

Please sign in to comment.