-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
evaluation_cross_validation_classification.cpp
134 lines (106 loc) · 3.84 KB
/
evaluation_cross_validation_classification.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Heiko Strathmann, Soeren Sonnenburg, Jacob Walker, Evgeniy Andreev,
* Soumyajit De, Sergey Lisitsyn
*/
#include <shogun/base/init.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/classifier/svm/LibSVM.h>
#include <shogun/evaluation/CrossValidation.h>
#include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
#include <shogun/evaluation/ContingencyTableEvaluation.h>
using namespace shogun;
/* Writes a log message verbatim to the given output stream.
 * Used as the message/warning/error callback for init_shogun(). */
void print_message(FILE* target, const char* str)
{
	fputs(str, target);
}
void test_cross_validation()
{
/* data matrix dimensions */
index_t num_vectors=40;
index_t num_features=5;
auto m_rng = std::unique_ptr<CRandom>(new CRandom());
/* data means -1, 1 in all components, std deviation of 3 */
SGVector<float64_t> mean_1(num_features);
SGVector<float64_t> mean_2(num_features);
SGVector<float64_t>::fill_vector(mean_1.vector, mean_1.vlen, -1.0);
SGVector<float64_t>::fill_vector(mean_2.vector, mean_2.vlen, 1.0);
float64_t sigma=3;
SGVector<float64_t>::display_vector(mean_1.vector, mean_1.vlen, "mean 1");
SGVector<float64_t>::display_vector(mean_2.vector, mean_2.vlen, "mean 2");
/* fill data matrix around mean */
SGMatrix<float64_t> train_dat(num_features, num_vectors);
for (index_t i=0; i<num_vectors; ++i)
{
for (index_t j=0; j<num_features; ++j)
{
float64_t mean=i<num_vectors/2 ? mean_1.vector[0] : mean_2.vector[0];
train_dat.matrix[i * num_features + j] =
m_rng->normal_random(mean, sigma);
}
}
/* training features */
CDenseFeatures<float64_t>* features=
new CDenseFeatures<float64_t>(train_dat);
SG_REF(features);
/* training labels +/- 1 for each cluster */
SGVector<float64_t> lab(num_vectors);
for (index_t i=0; i<num_vectors; ++i)
lab.vector[i]=i<num_vectors/2 ? -1.0 : 1.0;
CBinaryLabels* labels=new CBinaryLabels(lab);
/* gaussian kernel */
int32_t kernel_cache=100;
int32_t width=10;
CGaussianKernel* kernel=new CGaussianKernel(kernel_cache, width);
kernel->init(features, features);
/* create svm via libsvm */
float64_t svm_C=10;
float64_t svm_eps=0.0001;
CLibSVM* svm=new CLibSVM(svm_C, kernel, labels);
svm->set_epsilon(svm_eps);
/* train and output */
svm->train(features);
CBinaryLabels* output=CLabelsFactory::to_binary(svm->apply(features));
for (index_t i=0; i<num_vectors; ++i)
SG_SPRINT("i=%d, class=%f,\n", i, output->get_label(i));
/* evaluation criterion */
CContingencyTableEvaluation* eval_crit=
new CContingencyTableEvaluation(ACCURACY);
/* evaluate training error */
float64_t eval_result=eval_crit->evaluate(output, labels);
SG_SPRINT("training error: %f\n", eval_result);
SG_UNREF(output);
/* assert that regression "works". this is not guaranteed to always work
* but should be a really coarse check to see if everything is going
* approx. right */
ASSERT(eval_result<2);
/* splitting strategy */
index_t n_folds=5;
CStratifiedCrossValidationSplitting* splitting=
new CStratifiedCrossValidationSplitting(labels, n_folds);
/* cross validation instance, 10 runs, 95% confidence interval */
CCrossValidation* cross=new CCrossValidation(svm, features, labels,
splitting, eval_crit);
cross->set_num_runs(10);
// cross->set_conf_int_alpha(0.05);
/* actual evaluation */
CCrossValidationResult* result=(CCrossValidationResult*)cross->evaluate();
if (result->get_result_type() != CROSSVALIDATION_RESULT)
SG_SERROR("Evaluation result is not of type CrossValidationResult!");
result->print_result();
/* clean up */
SG_UNREF(result);
SG_UNREF(cross);
SG_UNREF(features);
}
/* Entry point: initializes shogun with stdout logging at debug level,
 * runs the cross-validation demo, and shuts shogun down again. */
int main(int argc, char **argv)
{
	/* command-line arguments are not used by this example */
	(void)argc;
	(void)argv;

	/* route messages, warnings and errors through the same printer */
	init_shogun(&print_message, &print_message, &print_message);
	sg_io->set_loglevel(MSG_DEBUG);

	test_cross_validation();

	exit_shogun();
	return 0;
}