Skip to content

Commit

Permalink
make to_multiclass convert rather than assert contiguous labels (#4251)
Browse files Browse the repository at this point in the history
  • Loading branch information
karlnapf committed Apr 14, 2018
1 parent 8ca2ee4 commit afe105d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 15 deletions.
44 changes: 30 additions & 14 deletions src/shogun/labels/MulticlassLabels.cpp
Expand Up @@ -191,20 +191,36 @@ namespace shogun
{
/** Build multiclass labels from dense labels, converting the label values
 * to the contiguous form [0, 1, ..., num_classes-1] if they are not already
 * in that form.
 *
 * The distinct label values are collected as int32_t in a sorted set; when
 * conversion is needed every label is replaced by the position of its value
 * in that sorted set (e.g. {0, 1, 3} becomes {0, 1, 2} and {-1, 1, 1}
 * becomes {0, 1, 1}). A warning listing the old -> new mapping is printed
 * so users can see what happened.
 *
 * @param orig dense labels to convert; must not be NULL
 * @return newly created multiclass labels holding either the original label
 * vector (already contiguous, or empty) or the converted one
 */
SG_FORCED_INLINE Some<CMulticlassLabels> to_multiclass(CDenseLabels* orig)
{
	auto result_vector = orig->get_labels();
	std::set<int32_t> unique(result_vector.begin(), result_vector.end());

	// potentially convert to [0,1, ..., num_classes-1] if not in that form
	// TODO: remove this once multiclass labels can be any discrete set
	//
	// std::set is ordered, so its first/last elements are the min/max.
	// An empty label vector has nothing to convert; checking for it first
	// also avoids dereferencing the end iterator of an empty set.
	const bool is_contiguous =
		unique.empty() ||
		(*unique.begin() == 0 &&
		 *unique.rbegin() == static_cast<index_t>(unique.size()) - 1);

	if (!is_contiguous)
	{
		// print conversion table for users
		SG_SWARNING(
			"Converting non-contiguous multiclass labels to "
			"contiguous version:\n");

		// iterating the sorted set visits labels in rank order, so a plain
		// counter yields the new label; it is an int32_t to match "%d"
		int32_t new_label = 0;
		for (const int32_t old_label : unique)
		{
			SG_SWARNING("Converting %d to %d.\n", old_label, new_label);
			++new_label;
		}

		SGVector<float64_t> converted(result_vector.vlen);
		std::transform(
			result_vector.begin(), result_vector.end(), converted.begin(),
			[&unique](float64_t old_label) -> float64_t {
				// same truncation to int32_t as used when filling the set,
				// so the lookup always succeeds
				const auto position =
					unique.find(static_cast<int32_t>(old_label));
				return static_cast<float64_t>(
					std::distance(unique.begin(), position));
			});
		result_vector = converted;
	}
	return some<CMulticlassLabels>(result_vector);
}

Some<CMulticlassLabels> multiclass_labels(CLabels* orig)
Expand Down
9 changes: 8 additions & 1 deletion tests/unit/labels/MulticlassLabels_unittest.cc
Expand Up @@ -92,5 +92,12 @@ TEST_F(MulticlassLabels, multiclass_labels_from_dense_not_contiguous)
// i.e. [0,1,2,3,4,...], anymore
auto labels = some<CDenseLabels>(labels_true.size());
labels->set_labels({0, 1, 3});
EXPECT_THROW(multiclass_labels(labels), ShogunException);
auto converted = multiclass_labels(labels);
ASSERT_NE(converted, nullptr);
EXPECT_TRUE(converted->get_labels().equals({0, 1, 2}));

labels->set_labels({-1, 1, 1});
auto converted2 = multiclass_labels(labels);
ASSERT_NE(converted2, nullptr);
EXPECT_TRUE(converted2->get_labels().equals({0, 1, 1}));
}

0 comments on commit afe105d

Please sign in to comment.