Skip to content

Commit

Permalink
fix unchanged diff issue
Browse files Browse the repository at this point in the history
  • Loading branch information
zeyiwen committed Jun 21, 2018
1 parent 1fb7ad2 commit 574318b
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/thundersvm/solver/csmosolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ CSMOSolver::solve(const KernelMatrix &k_mat, const SyncArray<int> &y, SyncArray<
//avoid infinite loop of repeated local diff
int same_local_diff_cnt = 0;
float_type previous_local_diff = INFINITY;
int swap_local_diff_cnt = 0;
float_type last_local_diff = INFINITY;
float_type second_last_local_diff = INFINITY;

for (int iter = 0;; ++iter) {
//select working set
Expand Down Expand Up @@ -79,20 +82,33 @@ CSMOSolver::solve(const KernelMatrix &k_mat, const SyncArray<int> &y, SyncArray<
update_f(f_val, alpha_diff, k_mat_rows, k_mat.n_instances());
float_type *diff_data = diff.host_data();
local_iter += diff_data[1];

//track unchanged diff
if (fabs(diff_data[0] - previous_local_diff) < eps * 0.001) {
same_local_diff_cnt++;
} else {
same_local_diff_cnt = 0;
previous_local_diff = diff_data[0];
}

if (iter % 100 == 0)
//track unchanged swapping diff
if(fabs(diff_data[0] - second_last_local_diff) < eps * 0.001){
swap_local_diff_cnt++;
} else {
swap_local_diff_cnt = 0;
}
second_last_local_diff = last_local_diff;
last_local_diff = diff_data[0];

//if (iter % 100 == 0)
LOG(INFO) << "global iter = " << iter << ", total local iter = " << local_iter << ", diff = "
<< diff_data[0];
//todo find some other ways to deal unchanged diff
//training terminates in three conditions: 1. diff stays unchanged; 2. diff is closed to 0; 3. training reaches the limit of iterations.
//repeatedly swapping between two diffs
if ((same_local_diff_cnt >= 10 && fabs(diff_data[0] - 2.0) > eps) || diff_data[0] < eps ||
(out_max_iter != -1) && (iter == out_max_iter)) {
(out_max_iter != -1) && (iter == out_max_iter) ||
(swap_local_diff_cnt >= 10 && fabs(diff_data[0] - 2.0) > eps)) {
rho = calculate_rho(f_val, y, alpha, Cp, Cn);
LOG(INFO) << "global iter = " << iter << ", total local iter = " << local_iter << ", diff = "
<< diff_data[0];
Expand Down

0 comments on commit 574318b

Please sign in to comment.