avoid specialisation of labels in stratified xvalidation (#4313)
karlnapf committed May 30, 2018
1 parent 7e9699e commit bf2f545
Showing 5 changed files with 59 additions and 53 deletions.
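
At its core, the change stops branching on the concrete label type (CBinaryLabels vs. CMulticlassLabels) inside the stratified splitter and instead enumerates classes generically, via CDenseLabels and a new SGVector::unique() member. Below is a minimal, hypothetical standalone sketch of that pattern; it is not part of the commit and only reuses calls that appear in the diffs that follow (as<CDenseLabels>(), get_labels().unique()) plus this era's init and refcounting helpers.

#include <shogun/base/init.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/labels/DenseLabels.h>
#include <shogun/lib/SGVector.h>

using namespace shogun;

int main()
{
    init_shogun_with_defaults();

    // Toy two-class labeling stored as dense float64_t values.
    SGVector<float64_t> y(6);
    y[0] = -1; y[1] = +1; y[2] = +1; y[3] = -1; y[4] = +1; y[5] = -1;
    CLabels* labels = new CBinaryLabels(y);
    SG_REF(labels);

    // The splitter used to special-case LT_BINARY / LT_MULTICLASS to list the
    // classes; after this commit a single generic path covers both.
    auto dense_labels = labels->as<CDenseLabels>();
    SGVector<float64_t> classes = dense_labels->get_labels().unique(); // {-1, +1}

    SG_SPRINT("number of classes: %d\n", (int32_t)classes.size());

    SG_UNREF(labels);
    exit_shogun();
    return 0;
}
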
@@ -2,14 +2,14 @@ CSVFile f_feats("../../data/fm_train_real.dat")
CSVFile f_labels("../../data/label_train_twoclass.dat")

Features feats = features(f_feats)
BinaryLabels labels(f_labels)
Labels labs = labels(f_labels)

Machine svm = machine("LibLinear")
svm.put("liblinear_solver_type", enum LIBLINEAR_SOLVER_TYPE.L2R_L2LOSS_SVC)

StratifiedCrossValidationSplitting splitting_strategy(labels, 5)
StratifiedCrossValidationSplitting splitting_strategy(labs, 5)
AccuracyMeasure evaluation_criterion()
CrossValidation cross(svm, feats, labels, splitting_strategy, evaluation_criterion)
CrossValidation cross(svm, feats, labs, splitting_strategy, evaluation_criterion)
cross.set_num_runs(2)

CrossValidationResult result = CrossValidationResult:obtain_from_generic(cross.evaluate())
74 changes: 24 additions & 50 deletions src/shogun/evaluation/StratifiedCrossValidationSplitting.cpp
@@ -4,9 +4,10 @@
* Authors: Heiko Strathmann, Soeren Sonnenburg, Thoralf Klein, Viktor Gal
*/

#include <shogun/base/range.h>
#include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/MulticlassLabels.h>

using namespace shogun;
@@ -25,44 +26,31 @@ CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting(
 * if a class has fewer labels than num_subsets, the class will not
 * appear in every subset, leading to subsets of only one class in the
 * extreme case of a two-class labeling. */
SGVector<float64_t> classes;

int32_t num_classes=2;
if (labels->get_label_type() == LT_MULTICLASS)
{
num_classes=((CMulticlassLabels*) labels)->get_num_classes();
classes=((CMulticlassLabels*) labels)->get_unique_labels();
}
else if (labels->get_label_type() == LT_BINARY)
{
classes=SGVector<float64_t>(2);
classes[0]=-1;
classes[1]=+1;
}
else
{
SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n")
}
auto dense_labels = labels->as<CDenseLabels>();
auto classes = dense_labels->get_labels().unique();

SGVector<index_t> labels_per_class(num_classes);
SGVector<index_t> labels_per_class(classes.size());

for (index_t i=0; i<num_classes; ++i)
for (auto i : range(classes.size()))
{
labels_per_class.vector[i]=0;
for (index_t j=0; j<labels->get_num_labels(); ++j)
labels_per_class[i] = 0;
for (auto j : range(labels->get_num_labels()))
{
if (classes.vector[i]==((CDenseLabels*) labels)->get_label(j))
labels_per_class.vector[i]++;
if (classes[i] == dense_labels->get_label(j))
labels_per_class[i]++;
}
}

for (index_t i=0; i<num_classes; ++i)
for (index_t i = 0; i < classes.size(); ++i)
{
if (labels_per_class.vector[i]<num_subsets)
if (labels_per_class[i] < num_subsets)
{
SG_WARNING("There are only %d labels of class %.18g, but %d "
"subsets. Labels of that class will not appear in every "
"subset!\n", labels_per_class.vector[i], classes.vector[i], num_subsets);
SG_WARNING(
"There are only %d labels of class %.18g, but %d "
"subsets. Labels of that class will not appear in every "
"subset!\n",
labels_per_class[i], classes[i], num_subsets);
}
}

@@ -75,35 +63,21 @@ void CStratifiedCrossValidationSplitting::build_subsets()
reset_subsets();
m_is_filled=true;

SGVector<float64_t> unique_labels;

if (m_labels->get_label_type() == LT_MULTICLASS)
{
unique_labels=((CMulticlassLabels*) m_labels)->get_unique_labels();
}
else if (m_labels->get_label_type() == LT_BINARY)
{
unique_labels=SGVector<float64_t>(2);
unique_labels[0]=-1;
unique_labels[1]=+1;
}
else
{
SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n")
}
auto dense_labels = m_labels->as<CDenseLabels>();
auto classes = dense_labels->get_labels().unique();

/* for every label, build set for indices */
CDynamicObjectArray label_indices;
for (index_t i=0; i<unique_labels.vlen; ++i)
for (auto i : range(classes.size()))
label_indices.append_element(new CDynamicArray<index_t> ());

/* fill set with indices, for each label type ... */
for (index_t i=0; i<unique_labels.vlen; ++i)
for (auto i : range(classes.size()))
{
/* ... iterate over all labels and add indices with same label to set */
for (index_t j=0; j<m_labels->get_num_labels(); ++j)
for (auto j : range(m_labels->get_num_labels()))
{
if (((CDenseLabels*) m_labels)->get_label(j)==unique_labels.vector[i])
if (dense_labels->get_label(j) == classes[i])
{
CDynamicArray<index_t>* current=(CDynamicArray<index_t>*)
label_indices.get_element(i);
@@ -127,7 +101,7 @@ void CStratifiedCrossValidationSplitting::build_subsets()

/* distribute labels to subsets for all label types */
index_t target_set=0;
for (index_t i=0; i<unique_labels.vlen; ++i)
for (auto i : range(classes.size()))
{
/* current index set for current label */
CDynamicArray<index_t>* current=(CDynamicArray<index_t>*)
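
For completeness, a hypothetical sketch of how the refactored splitter is driven from C++. The calls build_subsets() and generate_subset_indices() are assumed from the CSplittingStrategy base class as it existed in this era; the snippet is not part of the commit.

#include <shogun/base/init.h>
#include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
#include <shogun/labels/MulticlassLabels.h>

using namespace shogun;

int main()
{
    init_shogun_with_defaults();

    // Nine examples spread evenly over three classes (0, 1, 2).
    SGVector<float64_t> y(9);
    for (index_t i = 0; i < 9; ++i)
        y[i] = i % 3;
    auto labels = new CMulticlassLabels(y);
    SG_REF(labels);

    // Three stratified folds; every class has exactly three labels, so the
    // "fewer labels than subsets" warning above is not triggered.
    auto splitting = new CStratifiedCrossValidationSplitting(labels, 3);
    SG_REF(splitting);
    splitting->build_subsets();

    // Indices assigned to the first fold, roughly one example per class.
    SGVector<index_t> fold = splitting->generate_subset_indices(0);
    fold.display_vector("fold 0");

    SG_UNREF(splitting);
    SG_UNREF(labels);
    exit_shogun();
    return 0;
}
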
16 changes: 16 additions & 0 deletions src/shogun/lib/SGVector.cpp
@@ -884,6 +884,22 @@ int32_t SGVector<complex128_t>::unique(complex128_t* output, int32_t size)
return j;
}

template <class T>
SGVector<T> SGVector<T>::unique()
{
SGVector<T> result = clone();
auto new_size = unique(result.data(), result.size());
result.resize_vector(new_size);
return result;
}

template <>
SGVector<complex128_t> SGVector<complex128_t>::unique()
{
SG_SNOTIMPLEMENTED
return SGVector<complex128_t>();
}

template <class T>
SGVector<index_t> SGVector<T>::find(T elem)
{
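
The new member is a thin convenience wrapper over the existing static SGVector<T>::unique(T*, int32_t) shown in the hunk context above: it clones the vector, deduplicates the clone in place, and resizes it. A hypothetical snippet contrasting the two forms (assumes this era's API; not part of the commit):

#include <shogun/base/init.h>
#include <shogun/lib/SGVector.h>

using namespace shogun;

int main()
{
    init_shogun_with_defaults();

    SGVector<int32_t> v(7);
    int32_t data[] = {1, 4, 3, 1, 4, 3, 3};
    for (index_t i = 0; i < 7; ++i)
        v[i] = data[i];

    // Static form: sorts and deduplicates the buffer in place and returns the
    // new logical size, which the caller has to track separately.
    SGVector<int32_t> in_place = v.clone();
    int32_t num_unique = SGVector<int32_t>::unique(in_place.vector, in_place.vlen);
    SG_SPRINT("static unique: %d elements\n", num_unique); // starts with {1, 3, 4}

    // Member form added in this commit: clone, deduplicate and resize in one call.
    SGVector<int32_t> u = v.unique(); // {1, 3, 4}, u.size() == 3

    exit_shogun();
    return 0;
}
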
3 changes: 3 additions & 0 deletions src/shogun/lib/SGVector.h
@@ -473,6 +473,9 @@ template<class T> class SGVector : public SGReferencedData
*/
static int32_t unique(T* output, int32_t size);

/** Returns new vector with sorted unique elements of current */
SGVector<T> unique();

/** Display array size */
void display_size() const;

13 changes: 13 additions & 0 deletions tests/unit/lib/SGVector_unittest.cc
@@ -428,3 +428,16 @@ TEST(SGVectorTest,unique)
for (index_t i = 0; i < num_unique; ++i)
EXPECT_EQ(i+1, vec[i]);
}

TEST(SGVectorTest, unique_method)
{
SGVector<int32_t> vec{1, 4, 3, 1, 4, 3, 3};
auto vec_unique = vec.unique();

ASSERT_NE(vec_unique.data(), nullptr);
ASSERT_EQ(3, vec_unique.size());
EXPECT_NE(vec_unique.data(), vec.data());
EXPECT_EQ(vec_unique[0], 1);
EXPECT_EQ(vec_unique[1], 3);
EXPECT_EQ(vec_unique[2], 4);
}
