-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
StratifiedCrossValidationSplitting.cpp
126 lines (106 loc) · 3.58 KB
/
StratifiedCrossValidationSplitting.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Heiko Strathmann, Soeren Sonnenburg, Thoralf Klein, Viktor Gal
*/
#include <shogun/base/range.h>
#include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/MulticlassLabels.h>

#include <unordered_map>
using namespace shogun;
/** Default constructor: no labels and no subset count are supplied.
 *
 * The only state set here is the random source. m_rng is what
 * build_subsets() passes to every shuffle() call, and it defaults to the
 * global sg_rand generator.
 */
CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting()
	: CSplittingStrategy()
{
	this->m_rng = sg_rand;
}
/** Construct a stratified cross-validation splitting.
 *
 * @param labels      labels to stratify by. They are forwarded to
 *                    CSplittingStrategy and read here through
 *                    as<CDenseLabels>().
 * @param num_subsets number of subsets (folds) to split into.
 *
 * For every distinct label value (class) that occurs fewer than num_subsets
 * times, a warning is emitted and construction still succeeds. Such a class
 * cannot appear in every subset.
 */
CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting(
	CLabels* labels, index_t num_subsets) :
	CSplittingStrategy(labels, num_subsets)
{
	/* check for "stupid" combinations of label numbers and num_subsets.
	 * if a class has fewer labels than num_subsets, the class will not
	 * appear in every subset, leading to subsets of only one class in the
	 * extreme case of a two class labeling. */
	auto dense_labels = labels->as<CDenseLabels>();
	auto classes = dense_labels->get_labels().unique();

	/* Map every class value to its position in `classes` so all labels can
	 * be counted in one pass. A nested class x label scan is quadratic when
	 * nearly every label value is distinct. unordered_map matches keys with
	 * ==, so a label is counted for a class exactly when label == class
	 * (a NaN label never matches anything). */
	std::unordered_map<float64_t, index_t> class_index;
	SGVector<index_t> labels_per_class(classes.size());
	for (auto i : range(classes.size()))
	{
		class_index.emplace(classes[i], i);
		labels_per_class[i] = 0;
	}

	for (auto j : range(labels->get_num_labels()))
	{
		auto found = class_index.find(dense_labels->get_label(j));
		if (found != class_index.end())
			++labels_per_class[found->second];
	}

	for (auto i : range(classes.size()))
	{
		if (labels_per_class[i] < num_subsets)
		{
			SG_WARNING(
				"There are only %d labels of class %.18g, but %d "
				"subsets. Labels of that class will not appear in every "
				"subset!\n",
				labels_per_class[i], classes[i], num_subsets);
		}
	}

	m_rng = sg_rand;
}
void CStratifiedCrossValidationSplitting::build_subsets()
{
/* ensure that subsets are empty and set flag to filled */
reset_subsets();
m_is_filled=true;
auto dense_labels = m_labels->as<CDenseLabels>();
auto classes = dense_labels->get_labels().unique();
/* for every label, build set for indices */
CDynamicObjectArray label_indices;
for (auto i : range(classes.size()))
label_indices.append_element(new CDynamicArray<index_t> ());
/* fill set with indices, for each label type ... */
for (auto i : range(classes.size()))
{
/* ... iterate over all labels and add indices with same label to set */
for (auto j : range(m_labels->get_num_labels()))
{
if (dense_labels->get_label(j) == classes[i])
{
CDynamicArray<index_t>* current=(CDynamicArray<index_t>*)
label_indices.get_element(i);
current->append_element(j);
SG_UNREF(current);
}
}
}
/* shuffle created label sets */
for (index_t i=0; i<label_indices.get_num_elements(); ++i)
{
CDynamicArray<index_t>* current=(CDynamicArray<index_t>*)
label_indices.get_element(i);
// external random state important for threads
current->shuffle(m_rng);
SG_UNREF(current);
}
/* distribute labels to subsets for all label types */
index_t target_set=0;
for (auto i : range(classes.size()))
{
/* current index set for current label */
CDynamicArray<index_t>* current=(CDynamicArray<index_t>*)
label_indices.get_element(i);
for (index_t j=0; j<current->get_num_elements(); ++j)
{
CDynamicArray<index_t>* next=(CDynamicArray<index_t>*)
m_subset_indices->get_element(target_set++);
next->append_element(current->get_element(j));
target_set%=m_subset_indices->get_num_elements();
SG_UNREF(next);
}
SG_UNREF(current);
}
/* finally shuffle to avoid that subsets with low indices have more
* elements, which happens if the number of class labels is not equal to
* the number of subsets (external random state important for threads) */
m_subset_indices->shuffle(m_rng);
}