-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
streaming_onlineliblinear_sparse.cpp
136 lines (106 loc) · 4.2 KB
/
streaming_onlineliblinear_sparse.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Thoralf Klein, Viktor Gal, Dawei Chen, Vladimir Perić,
* Sergey Lisitsyn, Bjoern Esser
*/
#include <shogun/base/init.h>
#include <shogun/lib/common.h>
#include <shogun/lib/Time.h>
#include <shogun/classifier/svm/OnlineLibLinear.h>
#include <shogun/io/streaming/StreamingAsciiFile.h>
#include <shogun/features/streaming/StreamingSparseFeatures.h>
#include <shogun/labels/BinaryLabels.h>
#ifdef _WIN32
#include <io.h>
#include <fcntl.h>
#endif
using namespace shogun;
int main(int argc, char* argv[])
{
init_shogun_with_defaults();
float64_t C = 1.0;
char *train_file_name = (char*)"../data/train_sparsereal.light";
char *test_file_name = (char*)"../data/test_sparsereal.light";
char filename_tmp[] = "test_labels.XXXXXX";
#ifdef _WIN32
int err = _mktemp_s(filename_tmp, strlen(filename_tmp)+1);
ASSERT(err == 0);
#else
int fd = mkstemp(filename_tmp);
ASSERT(fd != -1);
int retval = close(fd);
ASSERT(retval != -1);
#endif
char *test_labels_file_name = filename_tmp;
if (argc > 4) {
int32_t idx = 1;
C = atof(argv[idx++]);
train_file_name = argv[idx++];
test_file_name = argv[idx++];
test_labels_file_name = argv[idx++];
ASSERT(idx <= argc);
}
fprintf(stderr, "*** training file %s with C %g\n", train_file_name, C);
// Create an OnlineLiblinear object from the features. The first parameter is 'C'.
COnlineLibLinear *svm = new COnlineLibLinear(C);
svm->set_bias_enabled(true);
{
CTime train_time;
train_time.start();
// Create a StreamingAsciiFile from the training data
CStreamingAsciiFile *train_file = new CStreamingAsciiFile(train_file_name);
SG_REF(train_file);
// The bool value is true if examples are labelled.
// 1024 is a good standard value for the number of examples for the parser to hold at a time.
CStreamingSparseFeatures < float32_t > *train_features =
new CStreamingSparseFeatures < float32_t > (train_file, true, 1024);
SG_REF(train_features);
svm->set_features(train_features);
svm->train();
train_file->close();
SG_UNREF(train_file);
SG_UNREF(train_features);
train_time.stop();
SGVector<float32_t> w_now = svm->get_w().clone();
float32_t w_now_norm = SGVector<float32_t>::twonorm(w_now.vector, w_now.vlen);
uint64_t train_time_int = train_time.cur_time_diff();
fprintf(stderr,
"*** total training time: %llum%llus (or %.1f sec), #dim = %d, ||w|| = %f\n",
train_time_int / 60, train_time_int % 60, train_time.cur_time_diff(),
w_now.vlen, w_now_norm
);
}
{
CTime test_time;
test_time.start();
// Now we want to test on holdout data
CStreamingAsciiFile *test_file = new CStreamingAsciiFile(test_file_name);
SG_REF(test_file);
// Set second parameter to 'false' if the file contains unlabelled examples
CStreamingSparseFeatures < float32_t > *test_features =
new CStreamingSparseFeatures < float32_t > (test_file, true, 1024);
SG_REF(test_features);
// Apply on all examples and return a CBinaryLabels*
CBinaryLabels *test_binary_labels = svm->apply_binary(test_features);
SG_REF(test_binary_labels);
test_time.stop();
uint64_t test_time_int = test_time.cur_time_diff();
fprintf(stderr, "*** testing took %llum%llus (or %.1f sec)\n",
test_time_int / 60, test_time_int % 60, test_time.cur_time_diff());
SG_UNREF(test_features);
SG_UNREF(test_file);
// Writing labels for evaluation
fprintf(stderr, "*** writing labels to file %s\n", test_labels_file_name);
FILE* fh = fopen(test_labels_file_name, "wb");
ASSERT(fh);
for (int32_t j = 0; j < test_binary_labels->get_num_labels(); j++)
fprintf(fh, "%d\n", test_binary_labels->get_int_label(j));
fclose(fh);
SG_UNREF(test_binary_labels);
unlink(test_labels_file_name);
}
SG_UNREF(svm);
exit_shogun();
return 0;
}