-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
evaluation_cross_validation_multiclass_mkl.cpp
123 lines (99 loc) · 3.46 KB
/
evaluation_cross_validation_multiclass_mkl.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Heiko Strathmann, Soeren Sonnenburg, Giovanni De Toni, Soumyajit De,
* Viktor Gal, Thoralf Klein, Alexander Binder, Sergey Lisitsyn
*/
#include <shogun/base/init.h>
#include <shogun/io/CSVFile.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/kernel/LinearKernel.h>
#include <shogun/kernel/PolyKernel.h>
#include <shogun/kernel/CombinedKernel.h>
#include <shogun/classifier/mkl/MKLMulticlass.h>
#include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
#include <shogun/evaluation/CrossValidation.h>
#include <shogun/evaluation/MulticlassAccuracy.h>
using namespace shogun;
/* Cross-validation setup: n_folds is the number of subsets handed to the
 * stratified splitting, n_runs is the number of repeated cross-validation
 * runs requested from CCrossValidation. */
const index_t n_folds=2;
const index_t n_runs=2;
/* Input data files read via CCSVFile: a real-valued feature matrix and the
 * matching multiclass label vector. The paths are relative -- presumably
 * resolved against the working directory the example is started from. */
const char fname_feats[]="../data/fm_train_real.dat";
const char fname_labels[]="../data/label_train_multiclass.dat";
void test_multiclass_mkl_cv()
{
CMath::init_random(12);
/* dense features from matrix */
CCSVFile* feature_file = new CCSVFile(fname_feats);
SGMatrix<float64_t> mat=SGMatrix<float64_t>();
mat.load(feature_file);
SG_UNREF(feature_file);
CDenseFeatures<float64_t>* features=new CDenseFeatures<float64_t>(mat);
SG_REF(features);
/* labels from vector */
CCSVFile* label_file = new CCSVFile(fname_labels);
SGVector<float64_t> label_vec;
label_vec.load(label_file);
SG_UNREF(label_file);
CMulticlassLabels* labels=new CMulticlassLabels(label_vec);
SG_REF(labels);
/* combined features and kernel */
CCombinedFeatures *cfeats=new CCombinedFeatures();
CCombinedKernel *cker=new CCombinedKernel();
SG_REF(cfeats);
SG_REF(cker);
/** 1st kernel: gaussian */
cfeats->append_feature_obj(features);
cker->append_kernel(new CGaussianKernel(features, features, 1.2, 10));
/** 2nd kernel: linear */
cfeats->append_feature_obj(features);
cker->append_kernel(new CLinearKernel(features, features));
/** 3rd kernel: poly */
cfeats->append_feature_obj(features);
cker->append_kernel(new CPolyKernel(features, features, 2, true, 10));
cker->init(cfeats, cfeats);
/* create mkl instance */
CMKLMulticlass* mkl=new CMKLMulticlass(1.2, cker, labels);
SG_REF(mkl);
mkl->set_epsilon(0.00001);
mkl->parallel->set_num_threads(1);
mkl->set_mkl_epsilon(0.001);
mkl->set_mkl_norm(1.5);
/* train to see weights */
mkl->train();
cker->get_subkernel_weights().display_vector("weights");
CMulticlassAccuracy* eval_crit=new CMulticlassAccuracy();
CStratifiedCrossValidationSplitting* splitting=
new CStratifiedCrossValidationSplitting(labels, n_folds);
splitting->set_seed(12);
CCrossValidation *cross=new CCrossValidation(mkl, cfeats, labels, splitting,
eval_crit);
cross->set_autolock(false);
cross->set_num_runs(n_runs);
// cross->set_conf_int_alpha(0.05);
/* perform x-val and print result */
CCrossValidationResult* result=(CCrossValidationResult*)cross->evaluate();
SG_SPRINT(
"mean of %d %d-fold x-val runs: %f\n", n_runs, n_folds,
result->get_mean());
/* assert high accuracy */
ASSERT(result->get_mean() > 0.81);
/* clean up */
SG_UNREF(features);
SG_UNREF(labels);
SG_UNREF(cfeats);
SG_UNREF(cker);
SG_UNREF(mkl);
SG_UNREF(cross);
SG_UNREF(result);
}
/** Program entry point: initialises Shogun with its defaults, runs the
 * multiclass MKL cross-validation example and shuts Shogun down again.
 * The command-line arguments are not used.
 */
int main(int argc, char** argv)
{
    shogun::init_shogun_with_defaults();

    // sg_io->set_loglevel(MSG_DEBUG);

    /* cross-validation of a multiclass MKL machine */
    test_multiclass_mkl_cv();

    exit_shogun();
    return 0;
}