-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
balanced_conditional_probability_tree.cpp
105 lines (87 loc) · 3.11 KB
/
balanced_conditional_probability_tree.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2011 Shashwat Lal Das
* Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
*
* This example demonstrates use of the balanced conditional probability tree
* multiclass classifier on streaming dense features.
*/
#include <iostream>
#include <shogun/base/init.h>
#include <shogun/lib/common.h>
#include <shogun/io/streaming/StreamingAsciiFile.h>
#include <shogun/features/streaming/StreamingDenseFeatures.h>
#include <shogun/multiclass/tree/BalancedConditionalProbabilityTree.h>
using namespace shogun;
int main(int argc, char **argv)
{
init_shogun_with_defaults();
try {
const char* train_file_name = "../data/7class_example4_train.dense";
const char* test_file_name = "../data/7class_example4_test.dense";
CStreamingAsciiFile* train_file = new CStreamingAsciiFile(train_file_name);
SG_REF(train_file);
CStreamingDenseFeatures<float32_t>* train_features = new CStreamingDenseFeatures<float32_t>(train_file, true, 1024);
SG_REF(train_features);
CBalancedConditionalProbabilityTree *cpt = new CBalancedConditionalProbabilityTree();
cpt->set_num_passes(1);
cpt->set_features(train_features);
if (argc > 1)
{
float64_t alpha = 0.5;
sscanf(argv[1], "%lf", &alpha);
SG_SPRINT("Setting alpha to %.2lf\n", alpha);
cpt->set_alpha(alpha);
}
cpt->train();
cpt->print_tree();
CStreamingAsciiFile* test_file = new CStreamingAsciiFile(test_file_name);
SG_REF(test_file);
CStreamingDenseFeatures<float32_t>* test_features = new CStreamingDenseFeatures<float32_t>(test_file, true, 1024);
SG_REF(test_features);
CMulticlassLabels *pred = cpt->apply_multiclass(test_features);
test_features->reset_stream();
SG_SPRINT("num_labels = %d\n", pred->get_num_labels());
SG_UNREF(test_features);
SG_UNREF(test_file);
test_file = new CStreamingAsciiFile(test_file_name);
SG_REF(test_file);
test_features = new CStreamingDenseFeatures<float32_t>(test_file, true, 1024);
SG_REF(test_features);
CMulticlassLabels *gnd = new CMulticlassLabels(pred->get_num_labels());
SG_REF(gnd);
test_features->start_parser();
for (int32_t i=0; i < pred->get_num_labels(); ++i)
{
test_features->get_next_example();
gnd->set_int_label(i, test_features->get_label());
test_features->release_example();
}
test_features->end_parser();
int32_t n_correct = 0;
for (index_t i=0; i < pred->get_num_labels(); ++i)
{
if (pred->get_int_label(i) == gnd->get_int_label(i))
n_correct++;
//SG_SPRINT("%d-%d ", pred->get_int_label(i), gnd->get_int_label(i));
}
SG_SPRINT("\n");
SG_SPRINT("Multiclass Accuracy = %.2f%%\n", 100.0*n_correct / gnd->get_num_labels());
SG_UNREF(gnd);
SG_UNREF(train_features);
SG_UNREF(test_features);
SG_UNREF(train_file);
SG_UNREF(test_file);
SG_UNREF(cpt);
SG_UNREF(pred);
} catch (const ShogunException& e) {
std::cout << "got shogun exception: " << e.get_exception_string() << std::endl;
} catch(...) {
std::cout << "unknown exception!" << std::endl;
}
exit_shogun();
return 0;
}