Skip to content

Commit

Permalink
add cross validation option for R interface
Browse files Browse the repository at this point in the history
  • Loading branch information
QinbinLi committed Jun 27, 2018
1 parent a114e8b commit ef7e7b6
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 23 deletions.
3 changes: 3 additions & 0 deletions R/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ svm_predict_R(test_dataset = "../dataset/test_dataset.txt", model_file = "../dat

*class_weight*: {dict, ‘balanced’}, optional(default=None)\
set the parameter C of class i to weight*C, for C-SVC

*cv*: int, optional(default=-1)\
specify the number of folds in cross-validation, or -1 for no cross-validation.

*verbose*: bool(default=False)\
enable verbose output. Note that this setting takes advantage of a per-process runtime setting; if enabled, ThunderSVM may not work properly in a multithreaded context.
Expand Down
4 changes: 2 additions & 2 deletions R/svm.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ svm_train_R <-
function(
svm_type = 0, kernel = 2,degree = 3,gamma = 'auto',
coef0 = 0.0, nu = 0.5, cost = 1.0, epsilon = 0.1,
tol = 0.001, probability = FALSE, class_weight = 'None',
tol = 0.001, probability = FALSE, class_weight = 'None', cv = '-1',
verbose = FALSE, max_iter = -1, n_cores = -1, dataset = 'None', model_file = 'None'
)
{
res <- .C("train_R", as.character(dataset), as.integer(kernel), as.integer(svm_type),
as.integer(degree), as.character(gamma), as.double(coef0), as.double(nu),
as.double(cost), as.double(epsilon), as.double(tol), as.integer(probability),
as.character(class_weight), as.integer(length(class_weight)),
as.character(class_weight), as.integer(length(class_weight)), as.integer(cv),
as.integer(verbose), as.integer(max_iter), as.integer(n_cores), as.character(model_file))
}

Expand Down
58 changes: 37 additions & 21 deletions src/thundersvm/svm_R_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ extern "C" {
int* degree, char** gamma, double* coef0,
double* nu, double* cost, double* epsilon,
double* tol, int* probability,
char** class_weight, int* weight_length,
char** class_weight, int* weight_length,int* n_fold,
int* verbose, int* max_iter, int* n_cores, char **model_file){
int* succeed = new int[1];
succeed[0] = 1;
Expand Down Expand Up @@ -239,29 +239,45 @@ extern "C" {
ind++;
}
}
vector<float_type> predict_y, test_y;
model->train(train_dataset, param_cmd);
model->save_to_file(model_file_path);
LOG(INFO) << "evaluating training score";
predict_y = model->predict(train_dataset.instances(), -1);
Metric *metric = nullptr;
switch (param_cmd.svm_type) {
case SvmParam::C_SVC:
case SvmParam::NU_SVC: {
metric = new Accuracy();
break;
}
case SvmParam::EPSILON_SVR:
case SvmParam::NU_SVR: {
metric = new MSE();
break;

int nr_fold = *n_fold;
vector<float_type> predict_y;
bool do_cross_validation = false;
if (nr_fold != -1) {
do_cross_validation = true;
predict_y = model->cross_validation(train_dataset, param_cmd, nr_fold);
} else {
model->train(train_dataset, param_cmd);
LOG(INFO) << "training finished";
model->save_to_file(model_file_path);
// LOG(INFO) << "evaluating training score";
// predict_y = model->predict(train_dataset.instances(), -1);
}
// vector<float_type> predict_y, test_y;
// model->train(train_dataset, param_cmd);
// model->save_to_file(model_file_path);
// LOG(INFO) << "evaluating training score";
// predict_y = model->predict(train_dataset.instances(), -1);
if(do_cross_validation) {
Metric *metric = nullptr;
switch (param_cmd.svm_type) {
case SvmParam::C_SVC:
case SvmParam::NU_SVC: {
metric = new Accuracy();
break;
}
case SvmParam::EPSILON_SVR:
case SvmParam::NU_SVR: {
metric = new MSE();
break;
}
case SvmParam::ONE_CLASS: {
}
}
case SvmParam::ONE_CLASS: {
if (metric) {
std::cout << metric->name() << " = " << metric->score(predict_y, train_dataset.y()) << std::endl;
}
}
if (metric) {
std::cout << metric->name() << " = " << metric->score(predict_y, train_dataset.y()) << std::endl;
}
return succeed;
// model->train(train_dataset, param_cmd);
// model->save_to_file(model_file_path);
Expand Down

0 comments on commit ef7e7b6

Please sign in to comment.