cross-validation now predicts y in the original order of the dataset
shijiashuai committed Nov 17, 2017
1 parent: 2434965 · commit: 0906293
Showing 4 changed files with 20 additions and 20 deletions.
include/thundersvm/util/common.h (2 changes: 1 addition & 1 deletion)
@@ -35,7 +35,7 @@ const int NUM_BLOCKS = 32 * 56;
         cudaError_t error = condition; \
         CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
     } while (0)
-#define SAFE_KERNEL_LAUNCH(kernel_name, ...)\
+#define SAFE_KERNEL_LAUNCH(kernel_name, ...) \
     kernel_name<<<NUM_BLOCKS,BLOCK_SIZE>>>(__VA_ARGS__);\
     CUDA_CHECK(cudaPeekAtLastError())
 #define KERNEL_LOOP(i, n) \
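For readers unfamiliar with the macro pattern touched in this hunk, here is a plain-C++ sketch of the do { ... } while (0) error-check idiom that CUDA_CHECK uses. The Status enum, status_string helper, and STATUS_CHECK name are stand-ins invented for illustration; they are not the CUDA API or ThunderSVM code.

#include <cstdio>
#include <cstdlib>

// Stand-in status type and lookup, playing the role of cudaError_t / cudaGetErrorString.
enum Status { STATUS_OK = 0, STATUS_FAIL = 1 };

static const char *status_string(Status s) {
    return s == STATUS_OK ? "ok" : "failure";
}

// Wrap a status-returning call so the macro behaves like a single statement
// (safe after an unbraced if/else) and aborts with file/line context on failure.
#define STATUS_CHECK(condition)                                          \
    do {                                                                 \
        Status status_ = (condition);                                    \
        if (status_ != STATUS_OK) {                                      \
            std::fprintf(stderr, "%s:%d: %s\n", __FILE__, __LINE__,      \
                         status_string(status_));                        \
            std::abort();                                                \
        }                                                                \
    } while (0)

static Status do_work(bool ok) { return ok ? STATUS_OK : STATUS_FAIL; }

int main() {
    STATUS_CHECK(do_work(true));     // passes silently
    // STATUS_CHECK(do_work(false)); // would print file:line and abort
    std::printf("done\n");
    return 0;
}

Wrapping the body in do { ... } while (0) is what lets such a macro be used like an ordinary statement, which is why CUDA_CHECK above is written that way.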
src/thundersvm/kernelmatrix.cpp (10 changes: 6 additions & 4 deletions)
@@ -142,10 +142,12 @@ void KernelMatrix::get_dot_product(const DataSet::node2d &instances, SyncData<fl
     for (int i = 0; i < instances.size(); ++i) {
         float_type sum = 0;
         for (int j = 0; j < instances[i].size(); ++j) {
-            CHECK_LE(instances[i][j].index, n_features_)
-                << "the number of features in testing set is larger than training set";
-            dense_ins[(instances[i][j].index - 1) * instances.size() + i] = instances[i][j].value;
-            sum += instances[i][j].value * instances[i][j].value;
+            if (instances[i][j].index < n_features_) {
+                dense_ins[(instances[i][j].index - 1) * instances.size() + i] = instances[i][j].value;
+                sum += instances[i][j].value * instances[i][j].value;
+            } else {
+//                LOG(WARNING)<<"the number of features in testing set is larger than training set";
+            }
         }
     }
     dns_csr_mul(dense_ins, instances.size(), dot_product);
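As a minimal, self-contained sketch of the behavior this hunk switches to, out-of-range feature indices in a test instance are now skipped during densification instead of tripping a hard CHECK. The Node struct and densify function below are invented for illustration (not ThunderSVM's node2d or SyncData), and the boundary check is a simplified 1-based convention that may differ from the library's exact test.

#include <cstdio>
#include <vector>

// One sparse feature of an instance: 1-based index plus value (assumed layout).
struct Node { int index; float value; };

// Densify instances column-major into a (n_features x n_instances) buffer,
// silently dropping features the training set never saw.
std::vector<float> densify(const std::vector<std::vector<Node>> &instances, int n_features) {
    std::vector<float> dense(static_cast<size_t>(n_features) * instances.size(), 0.f);
    for (size_t i = 0; i < instances.size(); ++i) {
        for (const Node &n : instances[i]) {
            if (n.index >= 1 && n.index <= n_features) {             // in range: keep
                dense[(n.index - 1) * instances.size() + i] = n.value;
            }                                                        // out of range: skip, don't abort
        }
    }
    return dense;
}

int main() {
    // Training set had 3 features; the second test instance carries feature 5.
    std::vector<std::vector<Node>> test = {{{1, 0.5f}, {3, 1.0f}},
                                           {{2, 2.0f}, {5, 4.0f}}};
    std::vector<float> dense = densify(test, 3);
    for (float v : dense) std::printf("%.1f ", v);
    std::printf("\n");   // feature 5 of instance 2 is simply ignored
    return 0;
}

The column-major addressing (index - 1) * n_instances + i mirrors the indexing used in the hunk above.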
src/thundersvm/model/svmmodel.cpp (15 changes: 8 additions & 7 deletions)
@@ -28,13 +28,13 @@ void SvmModel::model_setup(const DataSet &dataset, SvmParam &param) {
 vector<float_type> SvmModel::cross_validation(DataSet dataset, SvmParam param, int n_fold) {
     dataset.group_classes(this->param.svm_type == SvmParam::C_SVC);//group classes only for classification

-    vector<float_type> y_test_all;
-    vector<float_type> y_predict_all;
+    vector<float_type> y_predict_all(dataset.n_instances());

     for (int k = 0; k < n_fold; ++k) {
         LOG(INFO) << n_fold << " fold cross-validation(" << k + 1 << "/" << n_fold << ")";
         DataSet::node2d x_train, x_test;
         vector<float_type> y_train, y_test;
+        vector<int> test_idx;
         for (int i = 0; i < dataset.n_classes(); ++i) {
             int fold_test_count = dataset.count()[i] / n_fold;
             vector<int> class_idx = dataset.original_index(i);
@@ -44,6 +44,7 @@ vector<float_type> SvmModel::cross_validation(DataSet dataset, SvmParam param, i
             for (int j: vector<int>(idx_begin, idx_end)) {
                 x_test.push_back(dataset.instances()[j]);
                 y_test.push_back(dataset.y()[j]);
+                test_idx.push_back(j);
             }
             class_idx.erase(idx_begin, idx_end);
             for (int j:class_idx) {
@@ -54,12 +55,12 @@ vector<float_type> SvmModel::cross_validation(DataSet dataset, SvmParam param, i
         DataSet train_dataset(x_train, dataset.n_features(), y_train);
         this->train(train_dataset, param);
         vector<float_type> y_predict = this->predict(x_test, 1000);
-        y_test_all.insert(y_test_all.end(), y_test.begin(), y_test.end());
-        y_predict_all.insert(y_predict_all.end(), y_predict.begin(), y_predict.end());
+        CHECK_EQ(y_predict.size(), test_idx.size());
+        for (int i = 0; i < y_predict.size(); ++i) {
+            y_predict_all[test_idx[i]] = y_predict[i];
+        }
     }
-    vector<float_type> test_predict = y_test_all;
-    test_predict.insert(test_predict.end(), y_predict_all.begin(), y_predict_all.end());
-    return test_predict;
+    return y_predict_all;
 }


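The heart of the commit is this hunk: each fold now records the original dataset index of every held-out instance and scatters the fold's predictions back into a full-size vector, so cross_validation returns predictions aligned with the dataset's original order instead of a concatenation of y_test and y_predict in fold order. Below is a minimal sketch of that pattern with toy data and a stand-in predictor; the round-robin fold split and fake_train_and_predict are invented for illustration, whereas the real code splits per class as shown above.

#include <cstdio>
#include <vector>

// Stand-in for training on x_train and predicting x_test: the "prediction"
// is just the test value negated, so the result is easy to inspect.
std::vector<double> fake_train_and_predict(const std::vector<double> & /*x_train*/,
                                           const std::vector<double> &x_test) {
    std::vector<double> y;
    for (double v : x_test) y.push_back(-v);
    return y;
}

std::vector<double> cross_validation(const std::vector<double> &x, int n_fold) {
    std::vector<double> y_predict_all(x.size());           // sized up front, filled by index
    for (int k = 0; k < n_fold; ++k) {
        std::vector<double> x_train, x_test;
        std::vector<int> test_idx;                          // original positions of this fold's test rows
        for (int i = 0; i < static_cast<int>(x.size()); ++i) {
            if (i % n_fold == k) { x_test.push_back(x[i]); test_idx.push_back(i); }
            else                 { x_train.push_back(x[i]); }
        }
        std::vector<double> y_predict = fake_train_and_predict(x_train, x_test);
        for (size_t i = 0; i < y_predict.size(); ++i)
            y_predict_all[test_idx[i]] = y_predict[i];      // scatter back to original order
    }
    return y_predict_all;
}

int main() {
    std::vector<double> x = {10, 11, 12, 13, 14, 15};
    std::vector<double> p = cross_validation(x, 3);
    for (double v : p) std::printf("%.0f ", v);             // -10 -11 -12 -13 -14 -15, dataset order
    std::printf("\n");
    return 0;
}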
src/thundersvm/thundersvm-train.cpp (13 changes: 5 additions & 8 deletions)
@@ -58,17 +58,13 @@ int main(int argc, char **argv) {
     CUDA_CHECK(cudaSetDevice(parser.gpu_id));
 #endif

-    vector<float_type> predict_y, test_y;
+    vector<float_type> predict_y;
     if (parser.do_cross_validation) {
-        vector<float_type> test_predict = model->cross_validation(train_dataset, parser.param_cmd, parser.nr_fold);
-        uint dataset_size = test_predict.size() / 2;
-        test_y.insert(test_y.end(), test_predict.begin(), test_predict.begin() + dataset_size);
-        predict_y.insert(predict_y.end(), test_predict.begin() + dataset_size, test_predict.end());
+        predict_y = model->cross_validation(train_dataset, parser.param_cmd, parser.nr_fold);
     } else {
         model->train(train_dataset, parser.param_cmd);
         model->save_to_file(parser.model_file_name);
         predict_y = model->predict(train_dataset.instances(), 10000);
-        test_y = train_dataset.y();
     }

     //perform svm testing
@@ -88,8 +84,9 @@ }
         }
     }
     if (metric) {
-        LOG(INFO) << metric->name() << " = " << metric->score(predict_y, test_y);
+        LOG(INFO) << metric->name() << " = " << metric->score(predict_y, train_dataset.y());
     }
-    return 0;
+    delete model;
+    delete metric;
 }
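Because the vector returned by cross_validation is now aligned with the dataset, the caller can score it directly against train_dataset.y(), which is why the old split of test_predict into two halves disappears here. Below is a small sketch of that element-wise scoring; the accuracy helper is a stand-in for ThunderSVM's Metric classes, not the library's API.

#include <cstdio>
#include <vector>

// Element-wise accuracy between aligned prediction and label vectors
// (playing the role of metric->score(predict_y, train_dataset.y()) above).
double accuracy(const std::vector<double> &predict_y, const std::vector<double> &y) {
    if (predict_y.size() != y.size() || y.empty()) return 0.0;
    size_t correct = 0;
    for (size_t i = 0; i < y.size(); ++i)
        if (predict_y[i] == y[i]) ++correct;
    return static_cast<double>(correct) / y.size();
}

int main() {
    std::vector<double> labels    = {1, -1, 1, 1, -1};
    std::vector<double> predicted = {1, -1, -1, 1, -1};    // aligned with labels, as after this commit
    std::printf("accuracy = %.2f\n", accuracy(predicted, labels)); // prints 0.80
    return 0;
}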
