Commit aedac40

Merge pull request #21 from GODqinbin/master

finish kernelmatrix_kernel.cpp; add save_to_file and load_from_file in tests; add svm_type in tests; add eigen submodule.

shijiashuai committed Nov 13, 2017
2 parents 0a51dfa + 359ae58

Showing 18 changed files with 265 additions and 25 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "eigen"]
path = eigen
url = https://github.com/RLovelett/eigen.git
18 changes: 17 additions & 1 deletion CMakeLists.txt
100644 → 100755
@@ -2,15 +2,31 @@ cmake_minimum_required(VERSION 2.8.12)
set(PROJECT_NAME_STR thundersvm)
project(${PROJECT_NAME_STR} C CXX)

-set(USE_CUDA ON CACHE BOOL "Compile with CUDA")
+set(USE_CUDA OFF CACHE BOOL "Compile with CUDA")

find_package(Threads REQUIRED)
+find_package(OpenMP)
+if (OPENMP_FOUND)
+set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
+set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+endif()

+#find_package(MKL REQUIRED)
+#include_directories(${MKL_INCLUDE_DIRS})
+#link_directories(${MKL_LIBRARIES})
+#target_link_libraries(<module>
+#mkl_intel_lp64
+#mkl_sequential
+#mkl_core
+#)

if (USE_CUDA)
message("Compile with CUDA")
find_package(CUDA REQUIRED QUIET)
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -std=c++11 -Wno-deprecated-gpu-targets)
set(LINK_LIBRARY ${CUDA_cusparse_LIBRARY})
else ()
+include_directories(${PROJECT_SOURCE_DIR}/eigen)
+message("Compile without CUDA")
endif ()
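Together with the config.h.in change below, the USE_CUDA switch is also exposed to C++ through the generated header. A rough sketch of the intended compile-time dispatch (illustrative only, not part of this diff; the include path assumes the generated config.h keeps the thundersvm/ prefix):

    // Illustrative: USE_CUDA is defined via #cmakedefine iff the option was ON.
    #include <thundersvm/config.h>
    #ifdef USE_CUDA
        // cuSPARSE-backed GPU kernels
    #else
        // Eigen/OpenMP CPU fallback added by this PR
    #endif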

1 change: 1 addition & 0 deletions eigen
Submodule eigen added at ca8aa3
2 changes: 1 addition & 1 deletion include/thundersvm/config.h.in
@@ -1,2 +1,2 @@
#cmakedefine DATASET_DIR "@DATASET_DIR@"
-#cmakedefine USE_CUDA
+#cmakedefine USE_CUDA
20 changes: 20 additions & 0 deletions include/thundersvm/kernelmatrix_kernel_openmp.h
@@ -0,0 +1,20 @@
#include <omp.h>
#include <iostream>
#include "thundersvm/thundersvm.h"
#include "thundersvm/clion_cuda.h"
//typedef float real;
void kernel_get_working_set_ins_openmp(const real *val, const int *col_ind, const int *row_ptr, const int *data_row_idx,
real *data_rows,
int m);
void kernel_sum_kernel_values_openmp(const real *k_mat, int n_instances, int n_sv_unique, int n_bin_models,
const int *sv_index, const real *coef, const int *sv_start,
const int *sv_count,
const real *rho, real *dec_values);
void kernel_RBF_kernel_openmp(const real *self_dot0, const real *self_dot1, real *dot_product, int m, int n, real gamma);
void kernel_RBF_kernel_openmp(const int *self_dot0_idx, const real *self_dot1, real *dot_product, int m, int n, real gamma);
void kernel_sum_kernel_values_openmp(const real *k_mat, int n_instances, int n_sv_unique, int n_bin_models,
const int *sv_index, const real *coef, const int *sv_start,
const int *sv_count,
const real *rho, real *dec_values);
void kernel_poly_kernel_openmp(real *dot_product, real gamma, real coef0, int degree, int mn);
void kernel_sigmoid_kernel_openmp(real *dot_product, real gamma, real coef0, int mn);
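As a rough standalone sketch of how these flat-array entry points behave (real taken as float, per the commented-out typedef above; all values are made up), the polynomial kernel maps each stored dot product d to powf(gamma * d + coef0, degree) in place:

    // Illustrative driver, not part of this diff.
    #include <math.h>
    #include <stdio.h>

    int main() {
        float dot_product[4] = {1.f, 2.f, 3.f, 4.f}; // x_i . x_j values
        float gamma = 0.5f, coef0 = 1.f;
        int degree = 2, mn = 4;
    #pragma omp parallel for
        for (int idx = 0; idx < mn; idx++)
            dot_product[idx] = powf(gamma * dot_product[idx] + coef0, degree);
        for (int idx = 0; idx < mn; idx++)
            printf("%.2f\n", dot_product[idx]); // 2.25 4.00 6.25 9.00
        return 0;
    }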
1 change: 1 addition & 0 deletions src/test/test_main.cpp
@@ -9,5 +9,6 @@ int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Format, "%datetime %level %fbase:%line : %msg");
el::Loggers::addFlag(el::LoggingFlag::ColoredTerminalOutput);
+//cudaSetDevice(1);
return RUN_ALL_TESTS();
}
7 changes: 5 additions & 2 deletions src/test/test_nusvc.cpp
@@ -25,11 +25,14 @@ class NuSVCTest : public ::testing::Test {
param.epsilon = 0.001;
param.nu = nu;
param.kernel_type = SvmParam::RBF;
+param.svm_type = SvmParam::NU_SVC;
// param.probability = 1;
SvmModel *model = new NuSVC();

model->train(train_dataset, param);
-predict_y = model->predict(test_dataset.instances(), 10000);
+model->save_to_file(train_filename+".model");
+SvmModel *new_model = new NuSVC();
+new_model->load_from_file(train_filename+".model");
+predict_y = new_model->predict(test_dataset.instances(), 10000);
int n_correct = 0;
for (unsigned i = 0; i < predict_y.size(); ++i) {
if (predict_y[i] == test_dataset.y()[i])
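The pattern introduced here repeats across the test suites below: train, serialize with save_to_file, deserialize into a fresh model with load_from_file, and predict with the reloaded copy, so each accuracy check now also covers model persistence. Condensed from the diff above:

    SvmModel *model = new NuSVC();
    model->train(train_dataset, param);
    model->save_to_file(train_filename + ".model");       // serialize
    SvmModel *new_model = new NuSVC();
    new_model->load_from_file(train_filename + ".model"); // reload
    predict_y = new_model->predict(test_dataset.instances(), 10000);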
7 changes: 5 additions & 2 deletions src/test/test_nusvr.cpp
@@ -17,11 +17,14 @@ TEST(NuSVRTest, train) {
param.epsilon = 0.001;
param.nu = 0.5;
param.kernel_type = SvmParam::RBF;
+param.svm_type = SvmParam::NU_SVR;
SvmModel *model = new NuSVR();
model->train(dataset, param);

+model->save_to_file(DATASET_DIR "test_dataset.txt.model2");
+SvmModel *new_model = new NuSVR();
+new_model->load_from_file(DATASET_DIR "test_dataset.txt.model2");
vector<real> predict_y;
-predict_y = model->predict(dataset.instances(), 100);
+predict_y = new_model->predict(dataset.instances(), 100);
real mse = 0;
for (int i = 0; i < predict_y.size(); ++i) {
mse += (predict_y[i] - dataset.y()[i]) * (predict_y[i] - dataset.y()[i]);
7 changes: 5 additions & 2 deletions src/test/test_oneclass_svc.cpp
@@ -13,10 +13,13 @@ TEST(OneClassSVCTest, train) {
param.nu = 0.1;
param.epsilon = 0.001;
param.kernel_type = SvmParam::RBF;
+param.svm_type = SvmParam::ONE_CLASS;
SvmModel *model = new OneClassSVC();
model->train(dataset, param);

-vector<real> predict_y = model->predict(dataset.instances(), 100);
+model->save_to_file(DATASET_DIR "test_dataset.txt.model");
+SvmModel *new_model = new OneClassSVC();
+new_model->load_from_file(DATASET_DIR "test_dataset.txt.model");
+vector<real> predict_y = new_model->predict(dataset.instances(), 100);
int n_pos = 0;
for (int i = 0; i < predict_y.size(); ++i) {
if (predict_y[i] > 0)
6 changes: 4 additions & 2 deletions src/test/test_svc.cpp
@@ -23,9 +23,11 @@ class SVCTest : public ::testing::Test {
param.kernel_type = SvmParam::RBF;
// param.probability = 1;
SvmModel *model = new SVC();

model->train(train_dataset, param);
-predict_y = model->predict(test_dataset.instances(), 10000);
+model->save_to_file(train_filename+".model");
+SvmModel *new_model = new SVC();
+new_model->load_from_file(train_filename+".model");
+predict_y = new_model->predict(test_dataset.instances(), 10000);
int n_correct = 0;
for (unsigned i = 0; i < predict_y.size(); ++i) {
if (predict_y[i] == test_dataset.y()[i])
7 changes: 5 additions & 2 deletions src/test/test_svr.cpp
@@ -17,11 +17,14 @@ TEST(SVRTest, train) {
param.epsilon = 0.001;
param.nu = 0.5;
param.kernel_type = SvmParam::RBF;
+param.svm_type = SvmParam::EPSILON_SVR;
SvmModel *model = new SVR();
model->train(dataset, param);

+model->save_to_file(DATASET_DIR "test_dataset.txt.model");
+SvmModel *new_model = new SVR();
+new_model->load_from_file(DATASET_DIR "test_dataset.txt.model");
vector<real> predict_y;
-predict_y = model->predict(dataset.instances(), 100);
+predict_y = new_model->predict(dataset.instances(), 100);
real mse = 0;
for (int i = 0; i < predict_y.size(); ++i) {
mse += (predict_y[i] - dataset.y()[i]) * (predict_y[i] - dataset.y()[i]);
113 changes: 104 additions & 9 deletions src/thundersvm/kernel/kernelmatrix_kernel.cpp
@@ -2,42 +2,137 @@
// Created by jiashuai on 17-11-7.
//
#include <thundersvm/kernel/kernelmatrix_kernel.h>

//#include <mkl.h>
#include <Eigen/Dense>
#include <Eigen/Sparse>
#include <iostream>
namespace svm_kernel {
void get_working_set_ins(const SyncData<real> &val, const SyncData<int> &col_ind, const SyncData<int> &row_ptr,
const SyncData<int> &data_row_idx, SyncData<real> &data_rows, int m) {

//std::cout << "val[0]" << val[0] << std::endl;
//std::cout << "val.host_data[0]" << val.host_data()[0] <<std::endl;
#pragma omp parallel for
for(int i = 0; i < m; i++) {
int row = data_row_idx[i];
for (int j = row_ptr[row]; j < row_ptr[row + 1]; ++j) {
int col = col_ind[j];
data_rows[col * m + i] = val[j]; // row-major for cuSPARSE
}
}
//std::cout << "data_rows[0]" << data_rows[0] << std::endl;
//std::cout << "data_rows.host_data()[0]" << data_rows.host_data()[0] << std::endl;
}

void
RBF_kernel(const SyncData<real> &self_dot0, const SyncData<real> &self_dot1, SyncData<real> &dot_product, int m,
int n, real gamma) {

#pragma omp parallel for
for(int idx = 0; idx < m * n; idx++){
int i = idx / n;//i is row id
int j = idx % n;//j is column id
dot_product[idx] = expf(-(self_dot0[i] + self_dot1[j] - dot_product[idx] * 2) * gamma);
}
}
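    // The update above uses the expansion
    //   ||x_i - x_j||^2 = x_i.x_i + x_j.x_j - 2 * (x_i . x_j),
    // so K(x_i, x_j) = exp(-gamma * (self_dot0[i] + self_dot1[j] - 2 * dot_product[idx]));
    // only the self-dot products and the dense dot-product matrix are needed.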

void
RBF_kernel(const SyncData<int> &self_dot0_idx, const SyncData<real> &self_dot1, SyncData<real> &dot_product, int m,
int n, real gamma) {
#pragma omp parallel for
for(int idx = 0; idx < m * n; idx++){
int i = idx / n;//i is row id
int j = idx % n;//j is column id
dot_product[idx] = expf(-(self_dot1[self_dot0_idx[i]] + self_dot1[j] - dot_product[idx] * 2) * gamma);
}

}

void poly_kernel(SyncData<real> &dot_product, real gamma, real coef0, int degree, int mn) {

#pragma omp parallel for
for(int idx = 0; idx < mn; idx++){
dot_product[idx] = powf(gamma * dot_product[idx] + coef0, degree);
}
}

void sigmoid_kernel(SyncData<real> &dot_product, real gamma, real coef0, int mn) {

#pragma omp parallel for
for(int idx = 0; idx < mn; idx++){
dot_product[idx] = tanhf(gamma * dot_product[idx] + coef0);
}
}

void sum_kernel_values(const SyncData<real> &coef, int total_sv, const SyncData<int> &sv_start,
const SyncData<int> &sv_count, const SyncData<real> &rho, const SyncData<real> &k_mat,
SyncData<real> &dec_values, int n_classes, int n_instances) {

#pragma omp parallel for
for(int idx = 0; idx < n_instances; idx++){
int k = 0;
int n_binary_models = n_classes * (n_classes - 1) / 2;
for (int i = 0; i < n_classes; ++i) {
for (int j = i + 1; j < n_classes; ++j) {
int si = sv_start[i];
int sj = sv_start[j];
int ci = sv_count[i];
int cj = sv_count[j];
const real *coef1 = &coef[(j - 1) * total_sv];
const real *coef2 = &coef[i * total_sv];
const real *k_values = &k_mat[idx * total_sv];
real sum = 0;
#pragma omp parallel for reduction(+:sum)
for (int l = 0; l < ci; ++l) {
sum += coef1[si + l] * k_values[si + l];
}
#pragma omp parallel for reduction(+:sum)
for (int l = 0; l < cj; ++l) {
sum += coef2[sj + l] * k_values[sj + l];
}
dec_values[idx * n_binary_models + k] = sum - rho[k];
k++;
}
}
}
}
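    // For a C-class problem there are C*(C-1)/2 one-vs-one binary models; for
    // binary model k over the class pair (i, j), the two reduction loops above
    // accumulate
    //   dec_k(x) = sum over SVs of class i of coef * K(sv, x)
    //            + sum over SVs of class j of coef * K(sv, x) - rho[k].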

void dns_csr_mul(int m, int n, int k, const SyncData<real> &dense_mat, const SyncData<real> &csr_val,
const SyncData<int> &csr_row_ptr, const SyncData<int> &csr_col_ind, int nnz,
SyncData<real> &result) {

/*
for(int row = 0; row < m; row ++){
int nz_value_num = csr_row_ptr[row + 1] - csr_row_ptr[row];
if(nz_value_num != 0){
for(int col = 0; col < n; col++){
real sum = 0;
for(int nz_value_index = csr_row_ptr[row]; nz_value_index < csr_row_ptr[row + 1]; nz_value_index++){
sum += csr_val[nz_value_index] * dense_mat[col + csr_col_ind[nz_value_index] * n];
}
result[row * n + col] = sum;
}
}
}
*/
/*
Eigen::Map<Eigen::Matrix<real, n, k, Eigen::ColMajor> > denseMat(dense_mat.host_data());
Eigen::Map<Eigen::SparseMatrix<real, Eigen::RowMajor> > sparseMat(m, k, nnz, csr_row_ptr.host_data(), csr_col_ind.host_data(), csr_val.host_data());
Eigen::Matrix<real, n, m, Eigen::RowMajor> retMat = denseMat * sparseMat.transpose();
Eigen::Map<Eigen::Matrix<real, n, m, Eigen::RowMajor> > (result.host_data(), retMat.rows(), retMat.cols()) = retMat;
*/
Eigen::Map<const Eigen::MatrixXf> denseMat(dense_mat.host_data(), n, k);
Eigen::Map<const Eigen::SparseMatrix<float, Eigen::RowMajor>> sparseMat(m, k, nnz, csr_row_ptr.host_data(), csr_col_ind.host_data(), csr_val.host_data());
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> retMat = denseMat * sparseMat.transpose();
Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> > (result.host_data(), retMat.rows(), retMat.cols()) = retMat;
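    // Shapes: dense_mat is mapped as an (n x k) column-major matrix, the CSR
    // arrays as an (m x k) row-major sparse matrix, and the (n x m) row-major
    // product denseMat * sparseMat^T is written directly into result's host buffer.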
/*
float one(1);
float zero(0);
char matdescra[4] = {'G', 0, 0, 'C'};
char transa = 'N';
const char* matdescra_ptr = &matdescra[0];
//dense_mat transpose
mkl_simatcopy();
const int* m_ptr = &m;
const int* n_ptr = &n;
const int* k_ptr = &k;
//BLAS_usmm
mkl_scsrmm(&transa, m_ptr, n_ptr, k_ptr, &one, matdescra_ptr, csr_val.host_data(), csr_col_ind.host_data(),
csr_row_ptr.host_data(), &csr_row_ptr.host_data()[1], dense_mat.host_data(), n_ptr, &zero, result.host_data(), m_ptr);
*/
}
}
2 changes: 2 additions & 0 deletions src/thundersvm/kernel/smo_kernel.cpp
@@ -1,3 +1,4 @@

//
// Created by jiashuai on 17-11-7.
//
@@ -26,3 +27,4 @@ namespace svm_kernel {

}
}

1 change: 1 addition & 0 deletions src/thundersvm/kernelmatrix.cpp
@@ -4,6 +4,7 @@
#include <thundersvm/svmparam.h>
#include "thundersvm/kernelmatrix.h"
#include "thundersvm/kernel/kernelmatrix_kernel.h"
#include "thundersvm/kernelmatrix_kernel_openmp.h"

using namespace svm_kernel;
KernelMatrix::KernelMatrix(const DataSet::node2d &instances, SvmParam param) {
77 changes: 77 additions & 0 deletions src/thundersvm/kernelmatrix_kernel_openmp.cu
@@ -0,0 +1,77 @@
#include "thundersvm/kernelmatrix_kernel_openmp.h"
#include <iostream>
void kernel_get_working_set_ins_openmp(const real *val, const int *col_ind, const int *row_ptr, const int *data_row_idx,
real *data_rows,
int m) {
#pragma omp parallel for
for(int i = 0; i < m; i++) {
int row = data_row_idx[i];
//#pragma omp parallel for
for (int j = row_ptr[row]; j < row_ptr[row + 1]; ++j) {
int col = col_ind[j];
data_rows[col * m + i] = val[j]; // row-major for cuSPARSE
}
}
}

void kernel_RBF_kernel_openmp(const real *self_dot0, const real *self_dot1, real *dot_product, int m, int n, real gamma) {
//m rows of kernel matrix, where m is the working set size; n is the number of training instances
#pragma omp parallel for
for(int idx = 0; idx < m * n; idx++) {
int i = idx / n;//i is row id
int j = idx % n;//j is column id
dot_product[idx] = expf(-(self_dot0[i] + self_dot1[j] - dot_product[idx] * 2) * gamma);
}
}

void kernel_RBF_kernel_openmp(const int *self_dot0_idx, const real *self_dot1, real *dot_product, int m, int n, real gamma) {
//compute m rows of kernel matrix, where m is the working set size and n is the number of training instances, according to idx
#pragma omp parallel for
for(int idx = 0; idx < m * n; idx++){
int i = idx / n;//i is row id
int j = idx % n;//j is column id
dot_product[idx] = expf(-(self_dot1[self_dot0_idx[i]] + self_dot1[j] - dot_product[idx] * 2) * gamma);
}
}


void kernel_sum_kernel_values_openmp(const real *k_mat, int n_instances, int n_sv_unique, int n_bin_models,
const int *sv_index, const real *coef, const int *sv_start,
const int *sv_count,
const real *rho, real *dec_values) {//compute decision values for n_instances

#pragma omp parallel for
for(int idx = 0; idx < n_instances * n_bin_models; idx++){
//one iteration uses a binary svm model to predict a decision value of an instance.
//#ifndef _OPENMP
//std::cout<<"no openmp"<<std::endl;
//#endif
int ins_id = idx / n_bin_models;
int model_id = idx % n_bin_models;
real sum = 0;
const real *kernel_row = k_mat + ins_id * n_sv_unique;//kernel values of this instance
int si = sv_start[model_id];
int ci = sv_count[model_id];
#pragma omp parallel for reduction(+:sum)
for (int i = 0; i < ci; ++i) {//TODO: improve by parallelism
sum += coef[si + i] * kernel_row[sv_index[si + i]];//sv_index maps uncompressed sv idx to compressed sv idx.
}
dec_values[idx] = sum - rho[model_id];
}

}
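// Note: the inner "parallel for reduction" above starts a nested parallel region;
// with nested parallelism disabled (the OpenMP default), it runs on a single thread
// of the outer team, so each instance's reduction is effectively sequential.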

void kernel_poly_kernel_openmp(real *dot_product, real gamma, real coef0, int degree, int mn) {
#pragma omp parallel for
for(int idx = 0; idx < mn; idx++){
dot_product[idx] = powf(gamma * dot_product[idx] + coef0, degree);
}
}

void kernel_sigmoid_kernel_openmp(real *dot_product, real gamma, real coef0, int mn) {
//KERNEL_LOOP(idx, mn) {
#pragma omp parallel for
for(int idx = 0; idx < mn; idx++){
dot_product[idx] = tanhf(gamma * dot_product[idx] + coef0);
}
}
