Skip to content

Commit

Permalink
Merge pull request #2077 from tklein23/change_multilabel_labeltype
Browse files Browse the repository at this point in the history
Changed int16_t to int32_t in MultilabelLabels.
  • Loading branch information
tklein23 committed Mar 25, 2014
2 parents a68d792 + ec90ad9 commit b493737
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 39 deletions.
44 changes: 20 additions & 24 deletions src/shogun/labels/MultilabelLabels.cpp
Expand Up @@ -41,14 +41,14 @@ CMultilabelLabels::CMultilabelLabels()
}


CMultilabelLabels::CMultilabelLabels(int16_t num_classes)
CMultilabelLabels::CMultilabelLabels(int32_t num_classes)
: CLabels()
{
init(0, num_classes);
}


CMultilabelLabels::CMultilabelLabels(int32_t num_labels, int16_t num_classes)
CMultilabelLabels::CMultilabelLabels(int32_t num_labels, int32_t num_classes)
: CLabels()
{
init(num_labels, num_classes);
Expand All @@ -62,7 +62,7 @@ CMultilabelLabels::~CMultilabelLabels()


void
CMultilabelLabels::init(int32_t num_labels, int16_t num_classes)
CMultilabelLabels::init(int32_t num_labels, int32_t num_classes)
{
REQUIRE(num_labels >= 0, "num_labels=%d should be >= 0", num_labels);
REQUIRE(num_classes > 0, "num_classes=%d should be > 0", num_classes);
Expand All @@ -86,7 +86,7 @@ CMultilabelLabels::init(int32_t num_labels, int16_t num_classes)

m_num_labels = num_labels;
m_num_classes = num_classes;
m_labels = new SGVector <int16_t>[m_num_labels];
m_labels = new SGVector <int32_t>[m_num_labels];
}


Expand Down Expand Up @@ -127,15 +127,15 @@ CMultilabelLabels::get_num_labels() const
}


int16_t
int32_t
CMultilabelLabels::get_num_classes() const
{
return m_num_classes;
}


void
CMultilabelLabels::set_labels(SGVector <int16_t> * labels)
CMultilabelLabels::set_labels(SGVector <int32_t> * labels)
{
for (int32_t label_j = 0; label_j < m_num_labels; label_j++)
{
Expand All @@ -152,7 +152,7 @@ SGVector <int32_t> ** CMultilabelLabels::get_class_labels() const
int32_t * num_label_idx =
SG_MALLOC(int32_t, get_num_classes());

for (int16_t class_i = 0; class_i < get_num_classes(); class_i++)
for (int32_t class_i = 0; class_i < get_num_classes(); class_i++)
{
num_label_idx[class_i] = 0;
}
Expand All @@ -161,22 +161,22 @@ SGVector <int32_t> ** CMultilabelLabels::get_class_labels() const
{
for (int32_t c_pos = 0; c_pos < m_labels[label_j].vlen; c_pos++)
{
int16_t class_i = m_labels[label_j][c_pos];
int32_t class_i = m_labels[label_j][c_pos];
REQUIRE(class_i < get_num_classes(),
"class_i exceeded number of classes");
num_label_idx[class_i]++;
}
}

for (int16_t class_i = 0; class_i < get_num_classes(); class_i++)
for (int32_t class_i = 0; class_i < get_num_classes(); class_i++)
{
labels_list[class_i] =
new SGVector <int32_t> (num_label_idx[class_i]);
}
SG_FREE(num_label_idx);

int32_t * next_label_idx = SG_MALLOC(int32_t, get_num_classes());
for (int16_t class_i = 0; class_i < get_num_classes(); class_i++)
for (int32_t class_i = 0; class_i < get_num_classes(); class_i++)
{
next_label_idx[class_i] = 0;
}
Expand All @@ -186,7 +186,7 @@ SGVector <int32_t> ** CMultilabelLabels::get_class_labels() const
for (int32_t c_pos = 0; c_pos < m_labels[label_j].vlen; c_pos++)
{
// get class_i of current position
int16_t class_i = m_labels[label_j][c_pos];
int32_t class_i = m_labels[label_j][c_pos];
REQUIRE(class_i < get_num_classes(),
"class_i exceeded number of classes");
// next free element in m_classes[class_i]:
Expand All @@ -204,7 +204,7 @@ SGVector <int32_t> ** CMultilabelLabels::get_class_labels() const
}


SGVector <int16_t> CMultilabelLabels::get_label(int32_t j)
SGVector <int32_t> CMultilabelLabels::get_label(int32_t j)
{
REQUIRE(j < get_num_labels(),
"label index j=%d should be within [%d,%d[",
Expand All @@ -231,19 +231,15 @@ SGVector <D> CMultilabelLabels::to_dense


template
SGVector <float64_t> CMultilabelLabels::to_dense <int16_t, float64_t>
(SGVector <int16_t> *, int32_t, float64_t, float64_t);

template
SGVector <int32_t> CMultilabelLabels::to_dense <int16_t, int32_t>
(SGVector <int16_t> *, int32_t, int32_t, int32_t);
SGVector <int32_t> CMultilabelLabels::to_dense <int32_t, int32_t>
(SGVector <int32_t> *, int32_t, int32_t, int32_t);

template
SGVector <float64_t> CMultilabelLabels::to_dense <int32_t, float64_t>
(SGVector <int32_t> *, int32_t, float64_t, float64_t);

void
CMultilabelLabels::set_label(int32_t j, SGVector <int16_t> label)
CMultilabelLabels::set_label(int32_t j, SGVector <int32_t> label)
{
REQUIRE(j < get_num_labels(),
"label index j=%d should be within [%d,%d[",
Expand All @@ -255,13 +251,13 @@ CMultilabelLabels::set_label(int32_t j, SGVector <int16_t> label)
void
CMultilabelLabels::set_class_labels(SGVector <int32_t> ** labels_list)
{
int16_t * num_class_idx = SG_MALLOC(int16_t , get_num_labels());
int32_t * num_class_idx = SG_MALLOC(int32_t , get_num_labels());
for (int32_t label_j = 0; label_j < get_num_labels(); label_j++)
{
num_class_idx[label_j] = 0;
}

for (int16_t class_i = 0; class_i < get_num_classes(); class_i++)
for (int32_t class_i = 0; class_i < get_num_classes(); class_i++)
{
for (int32_t l_pos = 0; l_pos < labels_list[class_i]->vlen; l_pos++)
{
Expand All @@ -280,13 +276,13 @@ CMultilabelLabels::set_class_labels(SGVector <int32_t> ** labels_list)
}
SG_FREE(num_class_idx);

int16_t * next_class_idx = SG_MALLOC(int16_t , get_num_labels());
int32_t * next_class_idx = SG_MALLOC(int32_t , get_num_labels());
for (int32_t label_j = 0; label_j < get_num_labels(); label_j++)
{
next_class_idx[label_j] = 0;
}

for (int16_t class_i = 0; class_i < get_num_classes(); class_i++)
for (int32_t class_i = 0; class_i < get_num_classes(); class_i++)
{
for (int32_t l_pos = 0; l_pos < labels_list[class_i]->vlen; l_pos++)
{
Expand Down Expand Up @@ -338,7 +334,7 @@ CMultilabelLabels::display() const
{
SG_PRINT(" y_{j=%d}", j);
SGVector <float64_t> dense =
to_dense <int16_t , float64_t> (&m_labels[j], get_num_classes(),
to_dense <int32_t , float64_t> (&m_labels[j], get_num_classes(),
+1, -1);
dense.display_vector("");
}
Expand Down
22 changes: 11 additions & 11 deletions src/shogun/labels/MultilabelLabels.h
Expand Up @@ -55,14 +55,14 @@ class CMultilabelLabels : public CLabels
*
* @param num_classes number of (binary) class assignments per label
*/
CMultilabelLabels(int16_t num_classes);
CMultilabelLabels(int32_t num_classes);

/** constructor
*
* @param num_labels number of labels
* @param num_classes number of (binary) class assignments per label
*/
CMultilabelLabels(int32_t num_labels, int16_t num_classes);
CMultilabelLabels(int32_t num_labels, int32_t num_classes);

/** destructor */
~CMultilabelLabels();
Expand Down Expand Up @@ -100,13 +100,13 @@ class CMultilabelLabels : public CLabels
*
* @return number of classes
*/
virtual int16_t get_num_classes() const;
virtual int32_t get_num_classes() const;

/** set labels
*
* @param labels list of sparse labels
*/
void set_labels(SGVector<int16_t> * labels);
void set_labels(SGVector<int32_t> * labels);

/** get list of sparse class labels (one vector per class)
*
Expand All @@ -116,9 +116,9 @@ class CMultilabelLabels : public CLabels

/** get sparse assignment for j-th label
*
* @return SGVector<int16_t > sparse label
* @return SGVector<int32_t > sparse label
*/
SGVector<int16_t> get_label(int32_t j);
SGVector<int32_t> get_label(int32_t j);

/** Convert sparse label vector to dense. The dense vector
* will be {d_true; d_false}^dense_dim. Indices in sparse
Expand All @@ -137,9 +137,9 @@ class CMultilabelLabels : public CLabels
/** set sparse assignment for j-th label
*
* @param int32_t label index
* @param SGVector<int16_t > sparse label
* @param SGVector<int32_t > sparse label
*/
void set_label(int32_t j, SGVector<int16_t> label);
void set_label(int32_t j, SGVector<int32_t> label);

/** assigning class labels */
void set_class_labels(SGVector <int32_t> ** labels_list);
Expand All @@ -148,13 +148,13 @@ class CMultilabelLabels : public CLabels
void display() const;

private:
void init(int32_t num_labels, int16_t num_classes);
void init(int32_t num_labels, int32_t num_classes);

protected:

int32_t m_num_labels;
int16_t m_num_classes;
SGVector<int16_t> * m_labels;
int32_t m_num_classes;
SGVector<int32_t> * m_labels;
};

}
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/labels/MultilabelLabels_unittest.cc
Expand Up @@ -93,13 +93,13 @@ TEST(MultilabelLabels, clone)

TEST(MultilabelLabels, to_dense)
{
SGVector<int16_t> sparse(2);
SGVector<int32_t> sparse(2);
sparse[0] = 2;
sparse[1] = 5;

EXPECT_EQ(2, sparse.size());

SGVector<float64_t> dense = CMultilabelLabels::to_dense<int16_t, float64_t> (&sparse, 20, +1, 0);
SGVector<float64_t> dense = CMultilabelLabels::to_dense<int32_t, float64_t> (&sparse, 20, +1, 0);
EXPECT_EQ(20, dense.size());
EXPECT_EQ(+1, dense[2]);
EXPECT_EQ(+1, dense[5]);
Expand All @@ -117,10 +117,10 @@ TEST(MultilabelLabels, get_label)

for (int32_t i = 0; i < ml->get_num_labels(); i++)
{
SGVector<int16_t> sparse = ml->get_label(i);
SGVector<int32_t> sparse = ml->get_label(i);
EXPECT_EQ(0, sparse.size());

SGVector<float64_t> dense = CMultilabelLabels::to_dense<int16_t, float64_t> (&sparse, ml->get_num_labels(), +1, -1);
SGVector<float64_t> dense = CMultilabelLabels::to_dense<int32_t, float64_t> (&sparse, ml->get_num_labels(), +1, -1);
EXPECT_EQ(ml->get_num_labels(), dense.size());
}

Expand Down

0 comments on commit b493737

Please sign in to comment.