Skip to content

Commit

Permalink
added unit-test
Browse files Browse the repository at this point in the history
  • Loading branch information
hushell authored and hushell committed May 6, 2013
1 parent d39547b commit 2bb00e1
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 10 deletions.
20 changes: 12 additions & 8 deletions src/shogun/multiclass/MulticlassOneVsOneStrategy.cpp
Expand Up @@ -18,16 +18,18 @@ using namespace shogun;
CMulticlassOneVsOneStrategy::CMulticlassOneVsOneStrategy()
:CMulticlassStrategy(), m_num_machines(0), m_num_samples(SGVector<int32_t>())
{
SG_ADD(&m_num_samples, "num_samples", "Number of samples in each training machine", MS_NOT_AVAILABLE);

SG_WARNING("%s::CMulticlassOneVsOneStrategy(): register parameters!\n", get_name());
register_parameters();
}

CMulticlassOneVsOneStrategy::CMulticlassOneVsOneStrategy(EProbHeuristicType prob_heuris)
:CMulticlassStrategy(prob_heuris), m_num_machines(0), m_num_samples(SGVector<int32_t>())
{
SG_ADD(&m_num_samples, "num_samples", "Number of samples in each training machine", MS_NOT_AVAILABLE);
register_parameters();
}

void CMulticlassOneVsOneStrategy::register_parameters()
{
//SG_ADD(&m_num_samples, "num_samples", "Number of samples in each training machine", MS_NOT_AVAILABLE);
SG_WARNING("%s::CMulticlassOneVsOneStrategy(): register parameters!\n", get_name());
}

Expand Down Expand Up @@ -87,9 +89,7 @@ int32_t CMulticlassOneVsOneStrategy::decide_label(SGVector<float64_t> outputs)
{
// if OVO with prob outputs, find max posterior
if (outputs.vlen==m_num_classes)
{
return SGVector<float64_t>::arg_max(outputs.vector, 1, outputs.vlen);
}

int32_t s=0;
SGVector<int32_t> votes(m_num_classes);
Expand Down Expand Up @@ -143,9 +143,7 @@ int32_t CMulticlassOneVsOneStrategy::decide_label(SGVector<float64_t> outputs)
void CMulticlassOneVsOneStrategy::rescale_outputs(SGVector<float64_t>& outputs)
{
if (m_num_machines < 1)
{
return;
}

SGVector<int32_t> indx1(m_num_machines);
SGVector<int32_t> indx2(m_num_machines);
Expand Down Expand Up @@ -187,8 +185,10 @@ void CMulticlassOneVsOneStrategy::rescale_heuris_price(SGVector<float64_t>& outp
const SGVector<int32_t> indx1, const SGVector<int32_t> indx2)
{
if (m_num_machines != outputs.vlen)
{
SG_ERROR("%s::rescale_heuris_price(): size(outputs) = %d != m_num_machines = %d\n",
get_name(), outputs.vlen, m_num_machines);
}

SGVector<float64_t> new_outputs(m_num_classes);
new_outputs.zero();
Expand Down Expand Up @@ -217,8 +217,10 @@ void CMulticlassOneVsOneStrategy::rescale_heuris_hastie(SGVector<float64_t>& out
const SGVector<int32_t> indx1, const SGVector<int32_t> indx2)
{
if (m_num_machines != outputs.vlen)
{
SG_ERROR("%s::rescale_heuris_hastie(): size(outputs) = %d != m_num_machines = %d\n",
get_name(), outputs.vlen, m_num_machines);
}

SGVector<float64_t> new_outputs(m_num_classes);
new_outputs.zero();
Expand Down Expand Up @@ -291,8 +293,10 @@ void CMulticlassOneVsOneStrategy::rescale_heuris_hamamura(SGVector<float64_t>& o
const SGVector<int32_t> indx1, const SGVector<int32_t> indx2)
{
if (m_num_machines != outputs.vlen)
{
SG_ERROR("%s::rescale_heuris_hamamura(): size(outputs) = %d != m_num_machines = %d\n",
get_name(), outputs.vlen, m_num_machines);
}

SGVector<float64_t> new_outputs(m_num_classes);
SGVector<float64_t>::fill_vector(new_outputs.vector, new_outputs.vlen, 1.0);
Expand Down
14 changes: 14 additions & 0 deletions src/shogun/multiclass/MulticlassOneVsOneStrategy.h
Expand Up @@ -74,6 +74,16 @@ class CMulticlassOneVsOneStrategy: public CMulticlassStrategy
*/
virtual void rescale_outputs(SGVector<float64_t>& outputs);

/** set the number of classes, since the number of machines totally
* depends on the number of classes, which will also be set.
* @param num_classes number of classes
*/
void set_num_classes(int32_t num_classes)
{
CMulticlassStrategy::set_num_classes(num_classes);
m_num_machines = m_num_classes*(m_num_classes-1)/2;
}

protected:
/** OVO Price's heuristic see [1]
* @param outputs a vector of output from each machine (in that order)
Expand All @@ -99,6 +109,10 @@ class CMulticlassOneVsOneStrategy: public CMulticlassStrategy
void rescale_heuris_hamamura(SGVector<float64_t>& outputs,
const SGVector<int32_t> indx1, const SGVector<int32_t> indx2);

private:
/** register parameters */
void register_parameters();

protected:
int32_t m_num_machines; ///< number of machines
int32_t m_train_pair_idx_1; ///< 1st index of current submachine being trained
Expand Down
4 changes: 4 additions & 0 deletions src/shogun/multiclass/MulticlassOneVsRestStrategy.cpp
Expand Up @@ -97,8 +97,10 @@ void CMulticlassOneVsRestStrategy::rescale_outputs(SGVector<float64_t>& outputs,
void CMulticlassOneVsRestStrategy::rescale_heuris_norm(SGVector<float64_t>& outputs)
{
if (m_num_classes != outputs.vlen)
{
SG_ERROR("%s::rescale_heuris_norm(): size(outputs) = %d != m_num_classes = %d\n",
get_name(), outputs.vlen, m_num_classes);
}

float64_t norm = SGVector<float64_t>::sum(outputs);
norm += 1E-10;
Expand All @@ -110,8 +112,10 @@ void CMulticlassOneVsRestStrategy::rescale_heuris_softmax(SGVector<float64_t>& o
const SGVector<float64_t> As, const SGVector<float64_t> Bs)
{
if (m_num_classes != outputs.vlen)
{
SG_ERROR("%s::rescale_heuris_softmax(): size(outputs) = %d != m_num_classes = %d\n",
get_name(), outputs.vlen, m_num_classes);
}

for (int32_t i=0; i<outputs.vlen; i++)
outputs[i] = CMath::exp(-As[i]*outputs[i]-Bs[i]);
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/multiclass/MulticlassStrategy.cpp
Expand Up @@ -33,7 +33,7 @@ void CMulticlassStrategy::register_parameters()
{
SG_ADD((CSGObject**)&m_rejection_strategy, "rejection_strategy", "Strategy of rejection", MS_NOT_AVAILABLE);
SG_ADD(&m_num_classes, "num_classes", "Number of classes", MS_NOT_AVAILABLE);
SG_ADD((machine_int_t*)&m_prob_heuris, "prob_heuris", "Probability estimation heuristics", MS_NOT_AVAILABLE);
//SG_ADD((machine_int_t*)&m_prob_heuris, "prob_heuris", "Probability estimation heuristics", MS_NOT_AVAILABLE);

SG_WARNING("%s::CMulticlassStrategy(): register parameters!\n", get_name());
}
Expand Down
1 change: 0 additions & 1 deletion src/shogun/multiclass/MulticlassStrategy.h
Expand Up @@ -153,7 +153,6 @@ class CMulticlassStrategy: public CSGObject
virtual void rescale_outputs(SGVector<float64_t>& outputs,
const SGVector<float64_t> As, const SGVector<float64_t> Bs)
{
SG_NOTIMPLEMENTED
}

private:
Expand Down
130 changes: 130 additions & 0 deletions tests/unit/multiclass/MulticlassStrategy_unittest.cc
@@ -0,0 +1,130 @@
#include <shogun/multiclass/MulticlassOneVsOneStrategy.h>
#include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
#include <shogun/labels/BinaryLabels.h> |
#include <shogun/labels/MulticlassLabels.h>
#include <gtest/gtest.h>

using namespace shogun;

TEST(MulticlassStrategy,rescale_ova_norm)
{
SGVector<float64_t> scores(3);

for (int32_t i=0; i<3; i++)
scores[i] = (i+1)*0.1;

CMulticlassOneVsRestStrategy ova(OVA_NORM);
ova.set_num_classes(3);
ova.rescale_outputs(scores);

//SGVector<float64_t>::display_vector(scores.vector,scores.vlen);

// GT caculated manually
// scores[0] = scores[0] / sum(scores)
// scores[1] = scores[1] / sum(scores)
// scores[2] = scores[2] / sum(scores)
EXPECT_NEAR(scores[0],0.16666666666666669,1E-5);
EXPECT_NEAR(scores[1],0.33333333333333337,1E-5);
EXPECT_NEAR(scores[2],0.5,1E-5);
}

TEST(MulticlassStrategy,rescale_ova_softmax)
{
SGVector<float64_t> scores(3);
scores.range_fill(1);

SGVector<float64_t> As(3);
SGVector<float64_t> Bs(3);
Bs.zero();
for (int32_t i=0; i<3; i++)
As[i] = (i+1)*0.1;

CMulticlassOneVsRestStrategy ova(OVA_SOFTMAX);
ova.set_num_classes(3);
ova.rescale_outputs(scores,As,Bs);

//SGVector<float64_t>::display_vector(scores.vector,scores.vlen);

// GT caculated manually
// scores[0] = exp(-0.1) / norm
// scores[1] = exp(-0.4) / norm
// scores[2] = exp(-0.9) / norm
// norm = exp(-0.1)+exp(-0.4)+exp(-0.9)
EXPECT_NEAR(scores[0],0.4565903181944378,1E-5);
EXPECT_NEAR(scores[1],0.33825042710530284,1E-5);
EXPECT_NEAR(scores[2],0.20515925470025934,1E-5);
}

TEST(MulticlassStrategy,rescale_ova_price)
{
SGVector<float64_t> scores(3);
SGVector<float64_t>::fill_vector(scores.vector,scores.vlen,0.5);

CMulticlassOneVsOneStrategy ovo(OVO_PRICE);
ovo.set_num_classes(3);
ovo.rescale_outputs(scores);

//SGVector<float64_t>::display_vector(scores.vector,scores.vlen);

// GT caculated manually
// scores[0] = \frac{1}{1/0.5+1/0.5-(3-2)} / norm
// scores[1] = \frac{1}{1/0.5+1/0.5-(3-2)} / norm
// scores[2] = \frac{1}{1/0.5+1/0.5-(3-2)} / norm
// norm = sum(scores)
EXPECT_NEAR(scores[0],0.3333333333333333,1E-5);
EXPECT_NEAR(scores[1],0.3333333333333333,1E-5);
EXPECT_NEAR(scores[2],0.3333333333333333,1E-5);
}

TEST(MulticlassStrategy,rescale_ova_hastie)
{
CMulticlassOneVsOneStrategy ovo(OVO_HASTIE);
ovo.set_num_classes(3);

// training simulation
SGVector<float64_t> labels(3);
labels.range_fill(0);

CMulticlassLabels *orig_labels = new CMulticlassLabels(labels);
SG_REF(orig_labels);

CBinaryLabels *train_labels = new CBinaryLabels(2);
SG_REF(train_labels);

ovo.train_start(orig_labels, train_labels);
for (int32_t i=0; i<3; i++)
{
ovo.train_prepare_next();
}
ovo.train_stop();

SGVector<float64_t> scores(3);
SGVector<float64_t>::fill_vector(scores.vector,scores.vlen,0.5);

ovo.rescale_outputs(scores);

//SGVector<float64_t>::display_vector(scores.vector,scores.vlen);

EXPECT_NEAR(scores[0],0.3333333333333333,1E-5);
EXPECT_NEAR(scores[1],0.3333333333333333,1E-5);
EXPECT_NEAR(scores[2],0.3333333333333333,1E-5);

SG_UNREF(orig_labels);
SG_UNREF(train_labels);
}

TEST(MulticlassStrategy,rescale_ova_hamamura)
{
SGVector<float64_t> scores(3);
SGVector<float64_t>::fill_vector(scores.vector,scores.vlen,0.5);

CMulticlassOneVsOneStrategy ovo(OVO_HAMAMURA);
ovo.set_num_classes(3);
ovo.rescale_outputs(scores);

//SGVector<float64_t>::display_vector(scores.vector,scores.vlen);

EXPECT_NEAR(scores[0],0.3333333333333333,1E-5);
EXPECT_NEAR(scores[1],0.3333333333333333,1E-5);
EXPECT_NEAR(scores[2],0.3333333333333333,1E-5);
}

0 comments on commit 2bb00e1

Please sign in to comment.