Skip to content

Commit

Permalink
fix nu-svc bugs
Browse files Browse the repository at this point in the history
rename dataset.h to config.h
  • Loading branch information
shijiashuai committed Nov 2, 2017
1 parent b36d9c4 commit b3c63ed
Show file tree
Hide file tree
Showing 16 changed files with 31 additions and 31 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
.idea
.git
.*
!.gitignore
build
dataset/*
!dataset/*.sh
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ set(COMMON_INCLUDES ${PROJECT_SOURCE_DIR}/include ${CMAKE_CURRENT_BINARY_DIR})
set(LINK_LIBRARY ${CUDA_cusparse_LIBRARY})

set(DATASET_DIR ${PROJECT_SOURCE_DIR}/dataset/)
configure_file(include/test/dataset.h.in dataset.h)
configure_file(include/thundersvm/config.h.in config.h)

include_directories(${COMMON_INCLUDES})
add_subdirectory(${PROJECT_SOURCE_DIR}/src/thundersvm)


set(PROJECT_TEST_NAME ${PROJECT_NAME}_test)
set(PROJECT_TEST_NAME ${PROJECT_NAME}-test)
add_subdirectory(${PROJECT_SOURCE_DIR}/src/test)
add_custom_target(runtest
COMMAND ${PROJECT_TEST_NAME})
Expand Down
File renamed without changes.
3 changes: 2 additions & 1 deletion include/thundersvm/model/svc.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class SVC : public SvmModel {

protected:
virtual void train_binary(const DataSet &dataset, int i, int j, int k);
void record_binary_model(int k, const SyncData<real> &alpha, const SyncData<int> &y, real rho,

virtual void record_binary_model(int k, const SyncData<real> &alpha, const SyncData<int> &y, real rho,
const vector<int> &original_index, const DataSet::node2d &original_instance);

private:
Expand Down
2 changes: 1 addition & 1 deletion src/test/test_cross_validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Created by jiashuai on 17-10-13.
//
#include <gtest/gtest.h>
#include <dataset.h>
#include <config.h>
#include <thundersvm/model/svc.h>
#include <thundersvm/model/svr.h>

Expand Down
2 changes: 1 addition & 1 deletion src/test/test_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//
#include "gtest/gtest.h"
#include "thundersvm/dataset.h"
#include <dataset.h>
#include <config.h>
TEST(SvmProblemTest, load_dataset){
DataSet dataSet;
dataSet.load_from_file(DATASET_DIR "test_dataset.txt");
Expand Down
2 changes: 1 addition & 1 deletion src/test/test_kernelmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//
#include "gtest/gtest.h"
#include "thundersvm/kernelmatrix.h"
#include <dataset.h>
#include <config.h>
real rbf_kernel(const DataSet::node2d &instances, int x, int y, real gamma) {
real sum = 0;
auto i = instances[x].begin();
Expand Down
3 changes: 2 additions & 1 deletion src/test/test_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ INITIALIZE_EASYLOGGINGPP
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Format, "%datetime %level %fbase:%line : %msg");
cudaSetDevice(0);
el::Loggers::addFlag(el::LoggingFlag::ColoredTerminalOutput);
cudaSetDevice(1);
return RUN_ALL_TESTS();
}
17 changes: 6 additions & 11 deletions src/test/test_nusvc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <thundersvm/svmparam.h>
#include <thundersvm/model/svmmodel.h>
#include <thundersvm/model/nusvc.h>
#include <dataset.h>
#include <config.h>

//
// Created by jiashuai on 17-10-30.
Expand Down Expand Up @@ -46,16 +46,11 @@ class NuSVCTest : public ::testing::Test {
// load_dataset_and_train(DATASET_DIR "test_dataset.txt", DATASET_DIR "test_dataset.txt", 100, 0.5, 0.1);
//}

//TEST_F(NuSVCTest, a9a) {
// EXPECT_NEAR(load_dataset_and_train(DATASET_DIR
// "a9a", DATASET_DIR
// "a9a.t", 100, 0.5, 0.1), 0.826608, 1e-3);
//}
//TEST_F(NuSVCTest, a1a1024) {
// EXPECT_NEAR(load_dataset_and_train(DATASET_DIR
// "a1a1024", DATASET_DIR
// "a1a1024", 100, 0.5), 0.826608, 1e-3);
//}
TEST_F(NuSVCTest, a9a) {
EXPECT_NEAR(load_dataset_and_train(DATASET_DIR
"a9a", DATASET_DIR
"a9a.t", 100, 0.5, 0.1), 0.826608, 1e-3);
}
//
//TEST_F(NuSVCTest, mnist) {
// load_dataset_and_train(DATASET_DIR "mnist.scale", DATASET_DIR "mnist.scale.t", 10, 0.125, 0.1);
Expand Down
4 changes: 2 additions & 2 deletions src/test/test_nusvr.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <src/test/gtest/src/googletest/googletest/include/gtest/gtest.h>
#include <gtest/gtest.h>
#include <thundersvm/dataset.h>
#include <dataset.h>
#include <config.h>
#include <thundersvm/model/svmmodel.h>
#include <thundersvm/model/nusvr.h>

Expand Down
2 changes: 1 addition & 1 deletion src/test/test_oneclass_svc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//
#include <gtest/gtest.h>
#include <thundersvm/model/oneclass_svc.h>
#include <dataset.h>
#include <config.h>

TEST(OneClassSVCTest, train) {
DataSet dataset;
Expand Down
2 changes: 1 addition & 1 deletion src/test/test_svc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//
#include <thundersvm/model/svc.h>
#include "gtest/gtest.h"
#include "dataset.h"
#include "config.h"

class SVCTest : public ::testing::Test {
protected:
Expand Down
2 changes: 1 addition & 1 deletion src/test/test_svr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Created by jiashuai on 17-10-5.
//
#include <gtest/gtest.h>
#include <dataset.h>
#include <config.h>
#include <thundersvm/model/nusvr.h>

TEST(SVRTest, train) {
Expand Down
6 changes: 3 additions & 3 deletions src/thundersvm/kernel/smo_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,14 @@ nu_smo_solve_kernel(const int *label, real *f_values, real *alpha, real *alpha_d

if (numOfIter == 0) {
local_eps = max(eps, 0.1f * local_diff);
if (tid == 0) {
diff_and_bias[0] = local_diff;
}
}

if (local_diff < local_eps) {
alpha[wsi] = a;
alpha_diff[tid] = -(a - aold) * y;
if (tid == 0) {
diff_and_bias[0] = local_diff;
}
break;
}
__syncthreads();
Expand Down
6 changes: 2 additions & 4 deletions src/thundersvm/model/svmmodel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,10 @@ void SvmModel::save_to_file(string path) {
fs_model << "total_sv " << sv.size() << endl;
fs_model << "rho " << rho << endl;
fs_model << "SV " << endl;
vector<real> sv_coef = this->coef;
vector<vector<DataSet::node>> SV = this->sv;
for (int i = 0; i < sv.size(); i++) {
fs_model << setprecision(16) << sv_coef[i] << " ";
fs_model << setprecision(16) << coef[i] << " ";

vector<DataSet::node> p = SV[sv_index[i]];
vector<DataSet::node> p = sv[sv_index[i]];
int k = 0;
// if (param.kernel_type == SvmParam::PRECOMPUTED)
// fs_model << "0:" << p[k].value << " ";
Expand Down
6 changes: 5 additions & 1 deletion src/thundersvm/solver/csmosolver.cu
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,16 @@ void CSMOSolver::solve(const KernelMatrix &k_mat, const SyncData<int> &y, SyncDa
//update f
SAFE_KERNEL_LAUNCH(update_f, f_val.device_data(), ws_size, alpha_diff.device_data(), k_mat_rows.device_data(),
n_instances);
LOG_EVERY_N(10, INFO) << "diff=" << diff_and_bias[0];
if (iter % 10 == 0) {
printf(".");
std::cout.flush();
}
if (diff_and_bias[0] < eps) {
rho = calculate_rho(f_val, y, alpha, Cp, Cn);
break;
}
}
printf("\n");
}

void CSMOSolver::select_working_set(vector<int> &ws_indicator, const SyncData<int> &f_idx2sort, const SyncData<int> &y,
Expand Down

0 comments on commit b3c63ed

Please sign in to comment.