diff --git a/src/shogun/labels/MultilabelLabels.cpp b/src/shogun/labels/MultilabelLabels.cpp index c54dcb55463..3ff49c79f29 100644 --- a/src/shogun/labels/MultilabelLabels.cpp +++ b/src/shogun/labels/MultilabelLabels.cpp @@ -41,14 +41,14 @@ CMultilabelLabels::CMultilabelLabels() } -CMultilabelLabels::CMultilabelLabels(int16_t num_classes) +CMultilabelLabels::CMultilabelLabels(int32_t num_classes) : CLabels() { init(0, num_classes); } -CMultilabelLabels::CMultilabelLabels(int32_t num_labels, int16_t num_classes) +CMultilabelLabels::CMultilabelLabels(int32_t num_labels, int32_t num_classes) : CLabels() { init(num_labels, num_classes); @@ -62,7 +62,7 @@ CMultilabelLabels::~CMultilabelLabels() void -CMultilabelLabels::init(int32_t num_labels, int16_t num_classes) +CMultilabelLabels::init(int32_t num_labels, int32_t num_classes) { REQUIRE(num_labels >= 0, "num_labels=%d should be >= 0", num_labels); REQUIRE(num_classes > 0, "num_classes=%d should be > 0", num_classes); @@ -86,7 +86,7 @@ CMultilabelLabels::init(int32_t num_labels, int16_t num_classes) m_num_labels = num_labels; m_num_classes = num_classes; - m_labels = new SGVector [m_num_labels]; + m_labels = new SGVector [m_num_labels]; } @@ -127,7 +127,7 @@ CMultilabelLabels::get_num_labels() const } -int16_t +int32_t CMultilabelLabels::get_num_classes() const { return m_num_classes; @@ -135,7 +135,7 @@ CMultilabelLabels::get_num_classes() const void -CMultilabelLabels::set_labels(SGVector * labels) +CMultilabelLabels::set_labels(SGVector * labels) { for (int32_t label_j = 0; label_j < m_num_labels; label_j++) { @@ -152,7 +152,7 @@ SGVector ** CMultilabelLabels::get_class_labels() const int32_t * num_label_idx = SG_MALLOC(int32_t, get_num_classes()); - for (int16_t class_i = 0; class_i < get_num_classes(); class_i++) + for (int32_t class_i = 0; class_i < get_num_classes(); class_i++) { num_label_idx[class_i] = 0; } @@ -161,14 +161,14 @@ SGVector ** CMultilabelLabels::get_class_labels() const { for (int32_t c_pos = 0; c_pos < m_labels[label_j].vlen; c_pos++) { - int16_t class_i = m_labels[label_j][c_pos]; + int32_t class_i = m_labels[label_j][c_pos]; REQUIRE(class_i < get_num_classes(), "class_i exceeded number of classes"); num_label_idx[class_i]++; } } - for (int16_t class_i = 0; class_i < get_num_classes(); class_i++) + for (int32_t class_i = 0; class_i < get_num_classes(); class_i++) { labels_list[class_i] = new SGVector (num_label_idx[class_i]); @@ -176,7 +176,7 @@ SGVector ** CMultilabelLabels::get_class_labels() const SG_FREE(num_label_idx); int32_t * next_label_idx = SG_MALLOC(int32_t, get_num_classes()); - for (int16_t class_i = 0; class_i < get_num_classes(); class_i++) + for (int32_t class_i = 0; class_i < get_num_classes(); class_i++) { next_label_idx[class_i] = 0; } @@ -186,7 +186,7 @@ SGVector ** CMultilabelLabels::get_class_labels() const for (int32_t c_pos = 0; c_pos < m_labels[label_j].vlen; c_pos++) { // get class_i of current position - int16_t class_i = m_labels[label_j][c_pos]; + int32_t class_i = m_labels[label_j][c_pos]; REQUIRE(class_i < get_num_classes(), "class_i exceeded number of classes"); // next free element in m_classes[class_i]: @@ -204,7 +204,7 @@ SGVector ** CMultilabelLabels::get_class_labels() const } -SGVector CMultilabelLabels::get_label(int32_t j) +SGVector CMultilabelLabels::get_label(int32_t j) { REQUIRE(j < get_num_labels(), "label index j=%d should be within [%d,%d[", @@ -231,19 +231,15 @@ SGVector CMultilabelLabels::to_dense template -SGVector CMultilabelLabels::to_dense -(SGVector *, int32_t, float64_t, float64_t); - -template -SGVector CMultilabelLabels::to_dense -(SGVector *, int32_t, int32_t, int32_t); +SGVector CMultilabelLabels::to_dense +(SGVector *, int32_t, int32_t, int32_t); template SGVector CMultilabelLabels::to_dense (SGVector *, int32_t, float64_t, float64_t); void -CMultilabelLabels::set_label(int32_t j, SGVector label) +CMultilabelLabels::set_label(int32_t j, SGVector label) { REQUIRE(j < get_num_labels(), "label index j=%d should be within [%d,%d[", @@ -255,13 +251,13 @@ CMultilabelLabels::set_label(int32_t j, SGVector label) void CMultilabelLabels::set_class_labels(SGVector ** labels_list) { - int16_t * num_class_idx = SG_MALLOC(int16_t , get_num_labels()); + int32_t * num_class_idx = SG_MALLOC(int32_t , get_num_labels()); for (int32_t label_j = 0; label_j < get_num_labels(); label_j++) { num_class_idx[label_j] = 0; } - for (int16_t class_i = 0; class_i < get_num_classes(); class_i++) + for (int32_t class_i = 0; class_i < get_num_classes(); class_i++) { for (int32_t l_pos = 0; l_pos < labels_list[class_i]->vlen; l_pos++) { @@ -280,13 +276,13 @@ CMultilabelLabels::set_class_labels(SGVector ** labels_list) } SG_FREE(num_class_idx); - int16_t * next_class_idx = SG_MALLOC(int16_t , get_num_labels()); + int32_t * next_class_idx = SG_MALLOC(int32_t , get_num_labels()); for (int32_t label_j = 0; label_j < get_num_labels(); label_j++) { next_class_idx[label_j] = 0; } - for (int16_t class_i = 0; class_i < get_num_classes(); class_i++) + for (int32_t class_i = 0; class_i < get_num_classes(); class_i++) { for (int32_t l_pos = 0; l_pos < labels_list[class_i]->vlen; l_pos++) { @@ -338,7 +334,7 @@ CMultilabelLabels::display() const { SG_PRINT(" y_{j=%d}", j); SGVector dense = - to_dense (&m_labels[j], get_num_classes(), + to_dense (&m_labels[j], get_num_classes(), +1, -1); dense.display_vector(""); } diff --git a/src/shogun/labels/MultilabelLabels.h b/src/shogun/labels/MultilabelLabels.h index 1171c8cd0f9..872a27cd121 100644 --- a/src/shogun/labels/MultilabelLabels.h +++ b/src/shogun/labels/MultilabelLabels.h @@ -55,14 +55,14 @@ class CMultilabelLabels : public CLabels * * @param num_classes number of (binary) class assignments per label */ - CMultilabelLabels(int16_t num_classes); + CMultilabelLabels(int32_t num_classes); /** constructor * * @param num_labels number of labels * @param num_classes number of (binary) class assignments per label */ - CMultilabelLabels(int32_t num_labels, int16_t num_classes); + CMultilabelLabels(int32_t num_labels, int32_t num_classes); /** destructor */ ~CMultilabelLabels(); @@ -100,13 +100,13 @@ class CMultilabelLabels : public CLabels * * @return number of classes */ - virtual int16_t get_num_classes() const; + virtual int32_t get_num_classes() const; /** set labels * * @param labels list of sparse labels */ - void set_labels(SGVector * labels); + void set_labels(SGVector * labels); /** get list of sparse class labels (one vector per class) * @@ -116,9 +116,9 @@ class CMultilabelLabels : public CLabels /** get sparse assignment for j-th label * - * @return SGVector sparse label + * @return SGVector sparse label */ - SGVector get_label(int32_t j); + SGVector get_label(int32_t j); /** Convert sparse label vector to dense. The dense vector * will be {d_true; d_false}^dense_dim. Indices in sparse @@ -137,9 +137,9 @@ class CMultilabelLabels : public CLabels /** set sparse assignment for j-th label * * @param int32_t label index - * @param SGVector sparse label + * @param SGVector sparse label */ - void set_label(int32_t j, SGVector label); + void set_label(int32_t j, SGVector label); /** assigning class labels */ void set_class_labels(SGVector ** labels_list); @@ -148,13 +148,13 @@ class CMultilabelLabels : public CLabels void display() const; private: - void init(int32_t num_labels, int16_t num_classes); + void init(int32_t num_labels, int32_t num_classes); protected: int32_t m_num_labels; - int16_t m_num_classes; - SGVector * m_labels; + int32_t m_num_classes; + SGVector * m_labels; }; } diff --git a/tests/unit/labels/MultilabelLabels_unittest.cc b/tests/unit/labels/MultilabelLabels_unittest.cc index 09b16bcf5af..7eb6db938fc 100644 --- a/tests/unit/labels/MultilabelLabels_unittest.cc +++ b/tests/unit/labels/MultilabelLabels_unittest.cc @@ -93,13 +93,13 @@ TEST(MultilabelLabels, clone) TEST(MultilabelLabels, to_dense) { - SGVector sparse(2); + SGVector sparse(2); sparse[0] = 2; sparse[1] = 5; EXPECT_EQ(2, sparse.size()); - SGVector dense = CMultilabelLabels::to_dense (&sparse, 20, +1, 0); + SGVector dense = CMultilabelLabels::to_dense (&sparse, 20, +1, 0); EXPECT_EQ(20, dense.size()); EXPECT_EQ(+1, dense[2]); EXPECT_EQ(+1, dense[5]); @@ -117,10 +117,10 @@ TEST(MultilabelLabels, get_label) for (int32_t i = 0; i < ml->get_num_labels(); i++) { - SGVector sparse = ml->get_label(i); + SGVector sparse = ml->get_label(i); EXPECT_EQ(0, sparse.size()); - SGVector dense = CMultilabelLabels::to_dense (&sparse, ml->get_num_labels(), +1, -1); + SGVector dense = CMultilabelLabels::to_dense (&sparse, ml->get_num_labels(), +1, -1); EXPECT_EQ(ml->get_num_labels(), dense.size()); }