Skip to content

Commit

Permalink
complete svm-train & svm-predict
Browse files Browse the repository at this point in the history
  • Loading branch information
shijiashuai committed Nov 2, 2017
1 parent 05b6cea commit a7fb32c
Show file tree
Hide file tree
Showing 21 changed files with 117 additions and 195 deletions.
2 changes: 1 addition & 1 deletion include/thundersvm/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#ifndef THUNDERSVM_DATASET_H
#define THUNDERSVM_DATASET_H

#include "thundersvm-train.h"
#include "thundersvm.h"
#include "syncdata.h"
class DataSet {
public:
Expand Down
2 changes: 1 addition & 1 deletion include/thundersvm/kernel/kernelmatrix_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#ifndef THUNDERSVM_KERNELMATRIX_KERNEL_H
#define THUNDERSVM_KERNELMATRIX_KERNEL_H

#include "thundersvm/thundersvm-train.h"
#include "thundersvm/thundersvm.h"
#include "thundersvm/clion_cuda.h"

__global__ void
Expand Down
2 changes: 1 addition & 1 deletion include/thundersvm/kernel/smo_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#define THUNDERSVM_SMO_KERNEL_H

#include <thundersvm/clion_cuda.h>
#include <thundersvm/thundersvm-train.h>
#include <thundersvm/thundersvm.h>

__host__ __device__

Expand Down
2 changes: 1 addition & 1 deletion include/thundersvm/kernelmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#define THUNDERSVM_KERNELMATRIX_H

#include <cusparse.h>
#include "thundersvm-train.h"
#include "thundersvm.h"
#include "syncdata.h"
#include "dataset.h"
#include "svmparam.h"
Expand Down
2 changes: 0 additions & 2 deletions include/thundersvm/model/oneclass_svc.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ class OneClassSVC : public SvmModel {
public:
void train(DataSet dataset, SvmParam param) override;

// void load_from_file(string path) override;

vector<real> predict(const DataSet::node2d &instances, int batch_size) override;

};
Expand Down
2 changes: 1 addition & 1 deletion include/thundersvm/model/svr.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#ifndef THUNDERSVM_SVR_H
#define THUNDERSVM_SVR_H

#include "thundersvm/thundersvm-train.h"
#include "thundersvm/thundersvm.h"
#include "svmmodel.h"
#include <map>

Expand Down
2 changes: 1 addition & 1 deletion include/thundersvm/solver/csmosolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#ifndef THUNDERSVM_CSMOSOLVER_H
#define THUNDERSVM_CSMOSOLVER_H

#include <thundersvm/thundersvm-train.h>
#include <thundersvm/thundersvm.h>
#include <thundersvm/kernelmatrix.h>

class CSMOSolver {
Expand Down
2 changes: 1 addition & 1 deletion include/thundersvm/svmparam.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#ifndef THUNDERSVM_SVMPARAM_H
#define THUNDERSVM_SVMPARAM_H

#include "thundersvm-train.h"
#include "thundersvm.h"

struct SvmParam {
SvmParam() {
Expand Down
2 changes: 1 addition & 1 deletion include/thundersvm/syncdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#ifndef THUNDERSVM_SYNCDATA_H
#define THUNDERSVM_SYNCDATA_H

#include "thundersvm-train.h"
#include "thundersvm.h"
#include "syncmem.h"

template<typename T>
Expand Down
2 changes: 1 addition & 1 deletion include/thundersvm/syncmem.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#ifndef THUNDERSVM_SYNCMEM_H
#define THUNDERSVM_SYNCMEM_H

#include "thundersvm-train.h"
#include "thundersvm.h"

class SyncMem {
public:
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions include/thundersvm/util/metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <thundersvm/thundersvm.h>

class Metric {
public:
virtual string name() = 0;

virtual real score(const vector<real> &predict_y, const vector<real> &ground_truth_y) = 0;
Expand Down
18 changes: 8 additions & 10 deletions src/test/test_nusvc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,15 @@ class NuSVCTest : public ::testing::Test {
}
};

//TEST_F(SVCTest, test_set) {
// EXPECT_NEAR(load_dataset_and_train(DATASET_DIR
// "test_dataset.txt", DATASET_DIR
// "test_dataset.txt", 100, 0.5), 0.98, 1e-5);
//TEST_F(NuSVCTest, test_set) {
// load_dataset_and_train(DATASET_DIR "test_dataset.txt", DATASET_DIR "test_dataset.txt", 100, 0.5, 0.1);
//}

TEST_F(NuSVCTest, a9a) {
EXPECT_NEAR(load_dataset_and_train(DATASET_DIR
"a9a", DATASET_DIR
"a9a.t", 100, 0.5, 0.1), 0.826608, 1e-3);
}
//TEST_F(NuSVCTest, a9a) {
// EXPECT_NEAR(load_dataset_and_train(DATASET_DIR
// "a9a", DATASET_DIR
// "a9a.t", 100, 0.5, 0.1), 0.826608, 1e-3);
//}
//TEST_F(NuSVCTest, a1a1024) {
// EXPECT_NEAR(load_dataset_and_train(DATASET_DIR
// "a1a1024", DATASET_DIR
Expand All @@ -62,7 +60,7 @@ TEST_F(NuSVCTest, a9a) {
//TEST_F(NuSVCTest, mnist) {
// load_dataset_and_train(DATASET_DIR "mnist.scale", DATASET_DIR "mnist.scale.t", 10, 0.125, 0.1);
//}
//

TEST_F(NuSVCTest, realsim) {
load_dataset_and_train(DATASET_DIR "real-sim", DATASET_DIR "real-sim", 4, 0.5, 0.1);
}
Expand Down
20 changes: 10 additions & 10 deletions src/test/test_svc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ class SVCTest : public ::testing::Test {
// "test_dataset.txt", 100, 0.5), 0.98, 1e-5);
//}

TEST_F(SVCTest, a9a) {
EXPECT_NEAR(load_dataset_and_train(DATASET_DIR
"a9a", DATASET_DIR
"a9a.t", 100, 0.5), 0.826608, 1e-3);
}
//TEST_F(SVCTest, a9a) {
// EXPECT_NEAR(load_dataset_and_train(DATASET_DIR
// "a9a", DATASET_DIR
// "a9a.t", 100, 0.5), 0.826608, 1e-3);
//}
//
//TEST_F(SVCTest, mnist) {
// load_dataset_and_train(DATASET_DIR "mnist.scale", DATASET_DIR "mnist.scale.t", 10, 0.125);
//}
//
//TEST_F(SVCTest, realsim) {
// EXPECT_NEAR(load_dataset_and_train(DATASET_DIR
// "real-sim", DATASET_DIR
// "real-sim", 4, 0.5), 0.997276, 1e-3);
//}
TEST_F(SVCTest, realsim) {
EXPECT_NEAR(load_dataset_and_train(DATASET_DIR
"real-sim", DATASET_DIR
"real-sim", 4, 0.5), 0.997276, 1e-3);
}
2 changes: 1 addition & 1 deletion src/thundersvm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
file(GLOB_RECURSE SRC *.c*)
list(REMOVE_ITEM SRC ${CMAKE_CURRENT_SOURCE_DIR}/thundersvm-train.cpp)
list(REMOVE_ITEM SRC ${CMAKE_CURRENT_SOURCE_DIR}/thundersvm-*.cpp)
cuda_add_library(${PROJECT_NAME}_lib ${SRC})
cuda_add_executable(${PROJECT_NAME}-train thundersvm-train.cpp ${COMMON_INCLUDES})
cuda_add_executable(${PROJECT_NAME}-predict thundersvm-predict.cpp ${COMMON_INCLUDES})
Expand Down
16 changes: 11 additions & 5 deletions src/thundersvm/cmdparser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ void CMDParser::parse_command_line(int argc, char **argv) {
break;
case 'u':
gpu_id = atoi(argv[i]);
break;
default:
fprintf(stderr, "Unknown option: -%c\n", argv[i - 1][1]);
HelpInfo_svmtrain();
Expand Down Expand Up @@ -154,6 +155,9 @@ void CMDParser::parse_command_line(int argc, char **argv) {
case 'b':
// predict_probability = atoi(argv[i]);
break;
case 'u':
gpu_id = atoi(argv[i]);
break;
default:
fprintf(stderr, "Unknown option: -%c\n", argv[i - 1][1]);
HelpInfo_svmpredict();
Expand All @@ -176,10 +180,12 @@ void CMDParser::parse_command_line(int argc, char **argv) {
strcpy(svmpredict_input_file, argv[i]);
strcpy(svmpredict_output_file, argv[i + 2]);
strcpy(svmpredict_model_file_name, argv[i + 1]);
} else {
printf("Usage: thundersvm [options] training_set_file [model_file]\n"
"or: thundersvm_predict [options] test_file model_file output_file\n"
"or: thundersvm_scale [options] data_filename\n");
exit(0);
}
// else {
//
// printf("Usage: thundersvm [options] training_set_file [model_file]\n"
// "or: thundersvm_predict [options] test_file model_file output_file\n"
// "or: thundersvm_scale [options] data_filename\n");
// exit(0);
// }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
#include <thundersvm/model/oneclass_svc.h>
#include <thundersvm/solver/csmosolver.h>

using namespace std;

void OneClassSVC::train(DataSet dataset, SvmParam param) {
int n_instances = dataset.total_count();
SyncData<real> alpha(n_instances);
Expand Down
39 changes: 39 additions & 0 deletions src/thundersvm/model/svr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//
// Created by jiashuai on 17-10-5.
//
#include <iostream>
#include <thundersvm/kernel/kernelmatrix_kernel.h>
#include <thundersvm/solver/csmosolver.h>
#include "thundersvm/model/svr.h"

void SVR::train(DataSet dataset, SvmParam param) {
int n_instances = dataset.total_count();

//duplicate instances
DataSet::node2d instances_2(dataset.instances());
instances_2.insert(instances_2.end(), dataset.instances().begin(), dataset.instances().end());

KernelMatrix kernelMatrix(instances_2, param);

SyncData<real> f_val(n_instances * 2);
SyncData<int> y(n_instances * 2);

for (int i = 0; i < n_instances; ++i) {
f_val[i] = param.p - dataset.y()[i];
y[i] = +1;
f_val[i + n_instances] = -param.p - dataset.y()[i];
y[i + n_instances] = -1;
}

SyncData<real> alpha_2(n_instances * 2);
alpha_2.mem_set(0);
int ws_size = min(max2power(n_instances) * 2, 1024);
CSMOSolver solver;
solver.solve(kernelMatrix, y, alpha_2, rho, f_val, param.epsilon, param.C, param.C, ws_size);
SyncData<real> alpha(n_instances);
for (int i = 0; i < n_instances; ++i) {
alpha[i] = alpha_2[i] - alpha_2[i + n_instances];
}
record_model(alpha, y, dataset.instances(), param);
}

129 changes: 0 additions & 129 deletions src/thundersvm/model/svr.cu

This file was deleted.

0 comments on commit a7fb32c

Please sign in to comment.