Skip to content

Commit

Permalink
added serialisation support for multiclass labels now that the confid…
Browse files Browse the repository at this point in the history
…ences are stored in a matrix
  • Loading branch information
karlnapf committed Mar 14, 2013
1 parent a76e700 commit ea53122
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
39 changes: 28 additions & 11 deletions src/shogun/labels/MulticlassLabels.cpp
Expand Up @@ -6,60 +6,77 @@ using namespace shogun;

CMulticlassLabels::CMulticlassLabels() : CDenseLabels()
{
init();
}

CMulticlassLabels::CMulticlassLabels(int32_t num_labels) : CDenseLabels(num_labels)
{
init();
}

CMulticlassLabels::CMulticlassLabels(const SGVector<float64_t> src) : CDenseLabels()
{
init();
set_labels(src);
}

CMulticlassLabels::CMulticlassLabels(CFile* loader) : CDenseLabels(loader)
{
init();
}

CMulticlassLabels::~CMulticlassLabels()
{
}

void CMulticlassLabels::init()
{
SG_ADD(&m_multiclass_confidences, "multiclass_confidences", "Vectors of "
"multiclass confidences", MS_NOT_AVAILABLE);

m_multiclass_confidences=SGMatrix<float64_t>();
}

void CMulticlassLabels::set_multiclass_confidences(int32_t i, SGVector<float64_t> confidences)
{
REQUIRE(confidences.size()==m_multiclass_confidences.num_rows,"Length of confidences should match size of the matrix");
REQUIRE(confidences.size()==m_multiclass_confidences.num_rows,
"%s::set_multiclass_confidences(): Length of confidences should "
"match size of the matrix", get_name());

for (index_t j=0; j<confidences.size(); j++)
{
m_multiclass_confidences(j,i) = confidences[j];
}
}

SGVector<float64_t> CMulticlassLabels::get_multiclass_confidences(int32_t i)
{
SGVector<float64_t> confs(m_multiclass_confidences.num_rows);
for (index_t j=0; j<confs.size(); j++)
{
confs[j] = m_multiclass_confidences(j,i);
}

return confs;
}

void CMulticlassLabels::allocate_confidences_for(int32_t n_classes)
{
int32_t n_labels = m_labels.size();
REQUIRE(n_labels!=0,"There should be labels to store confidences");
REQUIRE(n_labels!=0,"%s::allocate_confidences_for(): There should be "
"labels to store confidences", get_name());

m_multiclass_confidences = SGMatrix<float64_t>(n_classes,n_labels);
}

CMulticlassLabels* CMulticlassLabels::obtain_from_generic(CLabels* base_labels)
{
if ( base_labels->get_label_type() == LT_MULTICLASS )
return (CMulticlassLabels*) base_labels;
else
SG_SERROR("base_labels must be of dynamic type CMulticlassLabels")
if (!base_labels)
return NULL;

if(base_labels->get_label_type()!=LT_MULTICLASS)
{
SG_SERROR("CMulticlassLabels::base_labels is of wrong type \"%s\"!\n",
base_labels->get_name());
}

return NULL;
return (CMulticlassLabels*) base_labels;
}

void CMulticlassLabels::ensure_valid(const char* context)
Expand Down
4 changes: 4 additions & 0 deletions src/shogun/labels/MulticlassLabels.h
Expand Up @@ -127,6 +127,10 @@ class CMulticlassLabels : public CDenseLabels
/** @return object name */
virtual const char* get_name() const { return "MulticlassLabels"; }

private:
/** initialises and register parameters */
void init();

protected:

/** multiclass confidences */
Expand Down

0 comments on commit ea53122

Please sign in to comment.