Permalink
Browse files

Use the remembered num_classes in a MulticlassStrategy

  • Loading branch information...
1 parent 7830726 commit 022c18ddeb8fb5784b514efcb0c430b3f68a7ca9 @pluskid pluskid committed Apr 28, 2012
@@ -80,7 +80,6 @@ CLabels* CMulticlassMachine::apply()
if (is_ready())
{
- int32_t num_classes=m_labels->get_num_classes();
int32_t num_vectors=get_num_rhs_vectors();
int32_t num_machines=m_machines->get_num_elements();
if (num_machines <= 0)
@@ -104,7 +103,7 @@ CLabels* CMulticlassMachine::apply()
for (int32_t j=0; j<num_machines; j++)
output_for_i[j] = outputs[j]->get_label(i);
- result->set_label(i, m_multiclass_strategy->decide_label(output_for_i, num_classes));
+ result->set_label(i, m_multiclass_strategy->decide_label(output_for_i));
}
output_for_i.destroy_vector();
@@ -176,7 +175,7 @@ float64_t CMulticlassMachine::apply(int32_t num)
SG_UNREF(machine);
}
- float64_t result=m_multiclass_strategy->decide_label(outputs, m_labels->get_num_classes());
+ float64_t result=m_multiclass_strategy->decide_label(outputs);
outputs.destroy_vector();
return result;
@@ -65,7 +65,7 @@ bool CGMNPSVM::train_machine(CFeatures* data)
}
int32_t num_data = m_labels->get_num_labels();
- int32_t num_classes = m_labels->get_num_classes();
+ int32_t num_classes = m_multiclass_strategy->get_num_classes();
int32_t num_virtual_data= num_data*(num_classes-1);
SG_INFO( "%d trainlabels, %d classes\n", num_data, num_classes);
@@ -35,7 +35,7 @@ bool CMulticlassLibSVM::train_machine(CFeatures* data)
problem = svm_problem();
ASSERT(m_labels && m_labels->get_num_labels());
- int32_t num_classes = m_labels->get_num_classes();
+ int32_t num_classes = m_multiclass_strategy->get_num_classes();
problem.l=m_labels->get_num_labels();
SG_INFO( "%d trainlabels, %d classes\n", problem.l, num_classes);
@@ -69,7 +69,7 @@ bool CMulticlassOCAS::train_machine(CFeatures* data)
set_features((CDotFeatures*)data);
int32_t num_vectors = m_features->get_num_vectors();
- int32_t num_classes = m_labels->get_num_classes();
+ int32_t num_classes = m_multiclass_strategy->get_num_classes();
int32_t num_features = m_features->get_dim_feature_space();
float64_t C = m_C;
@@ -63,15 +63,15 @@ SGVector<int32_t> CMulticlassOneVsOneStrategy::train_prepare_next()
return SGVector<int32_t>(subset.vector, tot);
}
-int32_t CMulticlassOneVsOneStrategy::decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)
+int32_t CMulticlassOneVsOneStrategy::decide_label(const SGVector<float64_t> &outputs)
{
int32_t s=0;
- SGVector<int32_t> votes(num_classes);
+ SGVector<int32_t> votes(m_num_classes);
votes.zero();
- for (int32_t i=0; i<num_classes; i++)
+ for (int32_t i=0; i<m_num_classes; i++)
{
- for (int32_t j=i+1; j<num_classes; j++)
+ for (int32_t j=i+1; j<m_num_classes; j++)
{
if (outputs[s++]>0)
votes[i]++;
@@ -35,9 +35,8 @@ class CMulticlassOneVsOneStrategy: public CMulticlassStrategy
/** decide the final label.
* @param outputs a vector of output from each machine (in that order)
- * @param num_classes number of classes
*/
- virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes);
+ virtual int32_t decide_label(const SGVector<float64_t> &outputs);
/** get number of machines used in this strategy.
*/
@@ -39,7 +39,7 @@ SGVector<int32_t> CMulticlassOneVsRestStrategy::train_prepare_next()
return SGVector<int32_t>();
}
-int32_t CMulticlassOneVsRestStrategy::decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)
+int32_t CMulticlassOneVsRestStrategy::decide_label(const SGVector<float64_t> &outputs)
{
if (m_rejection_strategy && m_rejection_strategy->reject(outputs))
return CLabels::REJECTION_LABEL;
@@ -62,9 +62,8 @@ class CMulticlassOneVsRestStrategy: public CMulticlassStrategy
/** decide the final label.
* @param outputs a vector of output from each machine (in that order)
- * @param num_classes number of classes
*/
- virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes);
+ virtual int32_t decide_label(const SGVector<float64_t> &outputs);
/** get number of machines used in this strategy.
*/
@@ -265,7 +265,7 @@ bool CMulticlassSVM::save(FILE* modelfl)
SG_INFO( "Writing model file...");
fprintf(modelfl,"%%MultiClassSVM\n");
- fprintf(modelfl,"num_classes=%d;\n", m_labels->get_num_classes());
+ fprintf(modelfl,"num_classes=%d;\n", m_multiclass_strategy->get_num_classes());
fprintf(modelfl,"num_svms=%d;\n", m_machines->get_num_elements());
fprintf(modelfl,"kernel='%s';\n", m_kernel->get_name());
@@ -61,9 +61,8 @@ class CMulticlassStrategy: public CSGObject
/** decide the final label.
* @param outputs a vector of output from each machine (in that order)
- * @param num_classes number of classes
*/
- virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)=0;
+ virtual int32_t decide_label(const SGVector<float64_t> &outputs)=0;
/** get number of machines used in this strategy.
*/
@@ -48,7 +48,7 @@ CScatterSVM::~CScatterSVM()
bool CScatterSVM::train_machine(CFeatures* data)
{
ASSERT(m_labels && m_labels->get_num_labels());
- m_num_classes = m_labels->get_num_classes();
+ m_num_classes = m_multiclass_strategy->get_num_classes();
int32_t num_vectors = m_labels->get_num_labels();
if (data)

0 comments on commit 022c18d

Please sign in to comment.