-
Notifications
You must be signed in to change notification settings - Fork 213
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23 from GODqinbin/master
Improve docs. Add python interface.
- Loading branch information
Showing
4 changed files
with
209 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#!/usr/bin/env python | ||
|
||
from ctypes import * | ||
from ctypes.util import find_library | ||
from os import path | ||
import sys | ||
|
||
|
||
# Directory containing this wrapper script; the native library is located
# relative to it rather than to the current working directory.
dirname = path.dirname(path.abspath(__file__))

# Handle to the native ThunderSVM shared library used by svm_train/svm_predict.
libsvm = cdll.LoadLibrary(path.join(dirname, '../build/lib/libthundersvm-lib.so'))
|
||
def svm_train(param):
    """Call the native ``thundersvm_train`` with a command-line style string.

    ``param`` is split on whitespace and passed to the library as an
    ``argc``/``argv`` pair, with ``'thundersvm-train'`` inserted as
    ``argv[0]``.
    """
    args = ['thundersvm-train'] + param.split()
    # c_char_p elements must be bytes on Python 3; on Python 2 ``str`` is
    # already bytes and is passed through untouched.
    encoded = [a if isinstance(a, bytes) else a.encode('utf-8') for a in args]
    argv = (c_char_p * len(encoded))(*encoded)
    libsvm.thundersvm_train(len(encoded), argv)
|
||
def svm_predict(param):
    """Call the native ``thundersvm_predict`` with a command-line style string.

    ``param`` is split on whitespace and passed to the library as an
    ``argc``/``argv`` pair, with ``'thundersvm-predict'`` inserted as
    ``argv[0]``.
    """
    args = ['thundersvm-predict'] + param.split()
    # c_char_p elements must be bytes on Python 3; on Python 2 ``str`` is
    # already bytes and is passed through untouched.
    encoded = [a if isinstance(a, bytes) else a.encode('utf-8') for a in args]
    argv = (c_char_p * len(encoded))(*encoded)
    libsvm.thundersvm_predict(len(encoded), argv)
|
||
#libsvm.thundersvm_train(15, "./thundersvm-train -s 1 -t 2 -g 0.5 -c 100 -n 0.1 -e 0.001 dataset/test_dataset.txt dataset/test_dataset.txt.model"); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
//functions for python interface | ||
|
||
#include <thundersvm/util/log.h>
#include <thundersvm/model/svc.h>
#include <thundersvm/model/svr.h>
#include <thundersvm/model/oneclass_svc.h>
#include <thundersvm/model/nusvc.h>
#include <thundersvm/model/nusvr.h>
#include <thundersvm/util/metric.h>
#include "thundersvm/cmdparser.h"
#include <cstring>
#include <iostream>
#include <memory>
|
||
INITIALIZE_EASYLOGGINGPP | ||
extern "C" { | ||
void thundersvm_train(int argc, char **argv) { | ||
CMDParser parser; | ||
parser.parse_command_line(argc, argv); | ||
/* | ||
parser.param_cmd.svm_type = SvmParam::NU_SVC; | ||
parser.param_cmd.kernel_type = SvmParam::RBF; | ||
parser.param_cmd.C = 100; | ||
parser.param_cmd.gamma = 0; | ||
parser.param_cmd.nu = 0.1; | ||
parser.param_cmd.epsilon = 0.001; | ||
*/ | ||
|
||
DataSet train_dataset; | ||
char input_file_path[1024] = DATASET_DIR; | ||
char model_file_path[1024] = DATASET_DIR; | ||
strcat(input_file_path, parser.svmtrain_input_file_name); | ||
strcat(model_file_path, parser.model_file_name); | ||
train_dataset.load_from_file(input_file_path); | ||
SvmModel *model = nullptr; | ||
switch (parser.param_cmd.svm_type) { | ||
case SvmParam::C_SVC: | ||
model = new SVC(); | ||
break; | ||
case SvmParam::NU_SVC: | ||
model = new NuSVC(); | ||
break; | ||
case SvmParam::ONE_CLASS: | ||
model = new OneClassSVC(); | ||
break; | ||
case SvmParam::EPSILON_SVR: | ||
model = new SVR(); | ||
break; | ||
case SvmParam::NU_SVR: | ||
model = new NuSVR(); | ||
break; | ||
} | ||
|
||
//todo add this to check_parameter method | ||
if (parser.param_cmd.svm_type == SvmParam::NU_SVC) { | ||
train_dataset.group_classes(); | ||
for (int i = 0; i < train_dataset.n_classes(); ++i) { | ||
int n1 = train_dataset.count()[i]; | ||
for (int j = i + 1; j < train_dataset.n_classes(); ++j) { | ||
int n2 = train_dataset.count()[j]; | ||
if (parser.param_cmd.nu * (n1 + n2) / 2 > min(n1, n2)) { | ||
printf("specified nu is infeasible\n"); | ||
return; | ||
} | ||
} | ||
} | ||
} | ||
|
||
#ifdef USE_CUDA | ||
CUDA_CHECK(cudaSetDevice(parser.gpu_id)); | ||
#endif | ||
|
||
vector<float_type> predict_y, test_y; | ||
if (parser.do_cross_validation) { | ||
vector<float_type> test_predict = model->cross_validation(train_dataset, parser.param_cmd, parser.nr_fold); | ||
uint dataset_size = test_predict.size() / 2; | ||
test_y.insert(test_y.end(), test_predict.begin(), test_predict.begin() + dataset_size); | ||
predict_y.insert(predict_y.end(), test_predict.begin() + dataset_size, test_predict.end()); | ||
} else { | ||
model->train(train_dataset, parser.param_cmd); | ||
model->save_to_file(model_file_path); | ||
//predict_y = model->predict(train_dataset.instances(), 10000); | ||
//test_y = train_dataset.y(); | ||
} | ||
/* | ||
//perform svm testing | ||
Metric *metric = nullptr; | ||
switch (parser.param_cmd.svm_type) { | ||
case SvmParam::C_SVC: | ||
case SvmParam::NU_SVC: { | ||
metric = new Accuracy(); | ||
break; | ||
} | ||
case SvmParam::EPSILON_SVR: | ||
case SvmParam::NU_SVR: { | ||
metric = new MSE(); | ||
break; | ||
} | ||
case SvmParam::ONE_CLASS: { | ||
} | ||
} | ||
if (metric) { | ||
LOG(INFO) << metric->name() << " = " << metric->score(predict_y, test_y); | ||
} | ||
*/ | ||
return; | ||
} | ||
|
||
void thundersvm_predict(int argc, char **argv){ | ||
CMDParser parser; | ||
parser.parse_command_line(argc, argv); | ||
|
||
char model_file_path[1024] = DATASET_DIR; | ||
char predict_file_path[1024] = DATASET_DIR; | ||
char output_file_path[1024] = DATASET_DIR; | ||
strcat(model_file_path, parser.svmpredict_model_file_name); | ||
strcat(predict_file_path, parser.svmpredict_input_file); | ||
strcat(output_file_path, parser.svmpredict_output_file); | ||
std::fstream file; | ||
file.open(model_file_path, std::fstream::in); | ||
string feature, svm_type; | ||
file >> feature >> svm_type; | ||
CHECK_EQ(feature, "svm_type"); | ||
SvmModel *model = nullptr; | ||
Metric *metric = nullptr; | ||
if (svm_type == "c_svc") { | ||
model = new SVC(); | ||
metric = new Accuracy(); | ||
} else if (svm_type == "nu_svc") { | ||
model = new NuSVC(); | ||
metric = new Accuracy(); | ||
} else if (svm_type == "one_class") { | ||
model = new OneClassSVC(); | ||
//todo determine a metric | ||
} else if (svm_type == "epsilon_svr") { | ||
model = new SVR(); | ||
metric = new MSE(); | ||
} else if (svm_type == "nu_svr") { | ||
model = new NuSVR(); | ||
metric = new MSE(); | ||
} | ||
|
||
#ifdef USE_CUDA | ||
CUDA_CHECK(cudaSetDevice(parser.gpu_id)); | ||
#endif | ||
|
||
model->load_from_file(model_file_path); | ||
file.close(); | ||
file.open(output_file_path); | ||
DataSet predict_dataset; | ||
predict_dataset.load_from_file(predict_file_path); | ||
vector<float_type> predict_y; | ||
predict_y = model->predict(predict_dataset.instances(), 10000); | ||
for (int i = 0; i < predict_y.size(); ++i) { | ||
file << predict_y[i] << std::endl; | ||
} | ||
file.close(); | ||
|
||
if (metric) { | ||
LOG(INFO) << metric->name() << " = " << metric->score(predict_y, predict_dataset.y()); | ||
} | ||
} | ||
} |