Permalink
Browse files

Merge branch 'multiclass-ecoc' of git://github.com/pluskid/shogun

  • Loading branch information...
2 parents 1927645 + 1981a90 commit 3be395f1055f986e2976fda7c8879f529f560649 @lisitsyn lisitsyn committed Apr 28, 2012
@@ -35,6 +35,9 @@ CMulticlassMachine::CMulticlassMachine(
SG_REF(machine);
m_machine = machine;
register_parameters();
+
+ if (labs)
+ init_strategy();
}
CMulticlassMachine::~CMulticlassMachine()
@@ -44,13 +47,26 @@ CMulticlassMachine::~CMulticlassMachine()
SG_UNREF(m_machines);
}
+void CMulticlassMachine::set_labels(CLabels* lab) // Set the labels and, when they are non-NULL, refresh the strategy's class count.
+{
+ CMachine::set_labels(lab); // base class handles the labels first (presumably stores them in m_labels, which init_strategy() reads)
+ if (lab) // init_strategy() dereferences m_labels, so it is skipped when no labels are given
+ init_strategy();
+}
+
// Passes the strategy, the base machine and the machine array to SG_ADD,
// each with a name and a description; all three are flagged MS_NOT_AVAILABLE.
void CMulticlassMachine::register_parameters()
{
// NOTE(review): the strategy is registered under the name "m_multiclass_type",
// not "m_multiclass_strategy" - confirm whether anything depends on that name before changing it.
SG_ADD((CSGObject**)&m_multiclass_strategy,"m_multiclass_type", "Multiclass strategy", MS_NOT_AVAILABLE);
SG_ADD((CSGObject**)&m_machine, "m_machine", "The base machine", MS_NOT_AVAILABLE);
SG_ADD((CSGObject**)&m_machines, "machines", "Machines that jointly make up the multi-class machine.", MS_NOT_AVAILABLE);
}
+void CMulticlassMachine::init_strategy() // Copy the number of classes from the labels into the multiclass strategy.
+{
+ int32_t num_classes = m_labels->get_num_classes(); // no NULL check here: the visible callers only invoke this when the labels they were given are non-NULL
+ m_multiclass_strategy->set_num_classes(num_classes); // NOTE(review): assumes m_multiclass_strategy is non-NULL - confirm for every constructor
+}
+
CLabels* CMulticlassMachine::apply(CFeatures* features)
{
init_machines_for_apply(features);
@@ -64,7 +80,6 @@ CLabels* CMulticlassMachine::apply()
if (is_ready())
{
- int32_t num_classes=m_labels->get_num_classes();
int32_t num_vectors=get_num_rhs_vectors();
int32_t num_machines=m_machines->get_num_elements();
if (num_machines <= 0)
@@ -88,15 +103,15 @@ CLabels* CMulticlassMachine::apply()
for (int32_t j=0; j<num_machines; j++)
output_for_i[j] = outputs[j]->get_label(i);
- result->set_label(i, m_multiclass_strategy->decide_label(output_for_i, num_classes));
+ result->set_label(i, m_multiclass_strategy->decide_label(output_for_i));
}
output_for_i.destroy_vector();
for (int32_t i=0; i < num_machines; ++i)
SG_UNREF(outputs[i]);
SG_FREE(outputs);
-
+
return result;
}
else
@@ -160,7 +175,7 @@ float64_t CMulticlassMachine::apply(int32_t num)
SG_UNREF(machine);
}
- float64_t result=m_multiclass_strategy->decide_label(outputs, m_labels->get_num_classes());
+ float64_t result=m_multiclass_strategy->decide_label(outputs);
outputs.destroy_vector();
return result;
@@ -39,6 +39,12 @@ class CMulticlassMachine : public CMachine
/** destructor */
virtual ~CMulticlassMachine();
+ /** set labels
+ *
+ * Forwards to CMachine::set_labels() and, if lab is not NULL,
+ * calls init_strategy() so that the multiclass strategy is given
+ * the number of classes reported by the labels.
+ *
+ * @param lab labels
+ */
+ virtual void set_labels(CLabels* lab);
+
/** set machine
*
* @param num index of machine
@@ -110,6 +116,8 @@ class CMulticlassMachine : public CMachine
}
protected:
+ /** init strategy
+ *
+ * Reads the number of classes from m_labels and passes it to
+ * m_multiclass_strategy via set_num_classes(). Neither pointer is
+ * checked here, so callers must only invoke this once labels are set.
+ */
+ void init_strategy();
/** clear machines */
void clear_machines();
@@ -65,7 +65,7 @@ bool CGMNPSVM::train_machine(CFeatures* data)
}
int32_t num_data = m_labels->get_num_labels();
- int32_t num_classes = m_labels->get_num_classes();
+ int32_t num_classes = m_multiclass_strategy->get_num_classes();
int32_t num_virtual_data= num_data*(num_classes-1);
SG_INFO( "%d trainlabels, %d classes\n", num_data, num_classes);
@@ -35,7 +35,7 @@ bool CMulticlassLibSVM::train_machine(CFeatures* data)
problem = svm_problem();
ASSERT(m_labels && m_labels->get_num_labels());
- int32_t num_classes = m_labels->get_num_classes();
+ int32_t num_classes = m_multiclass_strategy->get_num_classes();
problem.l=m_labels->get_num_labels();
SG_INFO( "%d trainlabels, %d classes\n", problem.l, num_classes);
@@ -69,7 +69,7 @@ bool CMulticlassOCAS::train_machine(CFeatures* data)
set_features((CDotFeatures*)data);
int32_t num_vectors = m_features->get_num_vectors();
- int32_t num_classes = m_labels->get_num_classes();
+ int32_t num_classes = m_multiclass_strategy->get_num_classes();
int32_t num_features = m_features->get_dim_feature_space();
float64_t C = m_C;
@@ -13,14 +13,13 @@
using namespace shogun;
// Default constructor: starts with zero machines.
// NOTE(review): m_train_pair_idx_1 and m_train_pair_idx_2 are not initialized here;
// train_start() assigns m_train_pair_idx_1 - confirm m_train_pair_idx_2 is set there as well.
CMulticlassOneVsOneStrategy::CMulticlassOneVsOneStrategy()
- :CMulticlassStrategy(), m_num_machines(0), m_num_classes(0)
+ :CMulticlassStrategy(), m_num_machines(0) // m_num_classes is no longer a member of this class; it now lives in CMulticlassStrategy
{
}
void CMulticlassOneVsOneStrategy::train_start(CLabels *orig_labels, CLabels *train_labels)
{
CMulticlassStrategy::train_start(orig_labels, train_labels);
- m_num_classes = m_orig_labels->get_num_classes();
m_num_machines=m_num_classes*(m_num_classes-1)/2;
m_train_pair_idx_1 = 0;
@@ -64,15 +63,15 @@ SGVector<int32_t> CMulticlassOneVsOneStrategy::train_prepare_next()
return SGVector<int32_t>(subset.vector, tot);
}
-int32_t CMulticlassOneVsOneStrategy::decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)
+int32_t CMulticlassOneVsOneStrategy::decide_label(const SGVector<float64_t> &outputs)
{
int32_t s=0;
- SGVector<int32_t> votes(num_classes);
+ SGVector<int32_t> votes(m_num_classes);
votes.zero();
- for (int32_t i=0; i<num_classes; i++)
+ for (int32_t i=0; i<m_num_classes; i++)
{
- for (int32_t j=i+1; j<num_classes; j++)
+ for (int32_t j=i+1; j<m_num_classes; j++)
{
if (outputs[s++]>0)
votes[i]++;
@@ -12,7 +12,7 @@
namespace shogun
{
-
+
class CMulticlassOneVsOneStrategy: public CMulticlassStrategy
{
public:
@@ -35,22 +35,14 @@ class CMulticlassOneVsOneStrategy: public CMulticlassStrategy
/** decide the final label.
* @param outputs a vector of output from each machine (in that order)
- * @param num_classes number of classes
*/
- virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes);
+ virtual int32_t decide_label(const SGVector<float64_t> &outputs);
/** get number of machines used in this strategy.
- * @param num_classes number of classes in this problem
*/
- virtual int32_t get_num_machines(int32_t num_classes)
- {
- return num_classes*(num_classes-1)/2;
- }
-
- /** get strategy type */
- virtual EMulticlassStrategy get_strategy_type()
+ virtual int32_t get_num_machines()
{
- return ONE_VS_ONE_STRATEGY;
+ return m_num_classes*(m_num_classes-1)/2;
}
/** get name */
@@ -61,7 +53,6 @@ class CMulticlassOneVsOneStrategy: public CMulticlassStrategy
protected:
int32_t m_num_machines;
- int32_t m_num_classes;
int32_t m_train_pair_idx_1;
int32_t m_train_pair_idx_2;
};
@@ -13,12 +13,12 @@
using namespace shogun;
// Default constructor: no rejection strategy is attached. decide_label() tests
// m_rejection_strategy for NULL before using it, so nothing is rejected by it.
CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy()
- :CMulticlassStrategy(), m_num_machines(0), m_rejection_strategy(NULL)
+ :CMulticlassStrategy(), m_rejection_strategy(NULL) // m_num_machines was dropped from this class; the class count in the base is used instead
{
}
// Constructor taking the rejection strategy that decide_label() consults
// (via reject()) before choosing a label.
CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy(CRejectionStrategy *rejection_strategy)
- :CMulticlassStrategy(), m_num_machines(0), m_rejection_strategy(rejection_strategy)
+ :CMulticlassStrategy(), m_rejection_strategy(rejection_strategy)
{
SG_REF(m_rejection_strategy); // presumably takes a reference on the strategy for the lifetime of this object - confirm it is released in the destructor
}
@@ -39,7 +39,7 @@ SGVector<int32_t> CMulticlassOneVsRestStrategy::train_prepare_next()
return SGVector<int32_t>();
}
-int32_t CMulticlassOneVsRestStrategy::decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)
+int32_t CMulticlassOneVsRestStrategy::decide_label(const SGVector<float64_t> &outputs)
{
if (m_rejection_strategy && m_rejection_strategy->reject(outputs))
return CLabels::REJECTION_LABEL;
@@ -47,38 +47,29 @@ class CMulticlassOneVsRestStrategy: public CMulticlassStrategy
/** start training.
 * With the removed line gone, this override only delegates to
 * CMulticlassStrategy::train_start(); it no longer derives a machine
 * count from the labels.
 * @param orig_labels forwarded unchanged to the base class
 * @param train_labels forwarded unchanged to the base class
 */
virtual void train_start(CLabels *orig_labels, CLabels *train_labels)
{
CMulticlassStrategy::train_start(orig_labels, train_labels);
- m_num_machines=m_orig_labels->get_num_classes();
}
/** has more training phase
 * @return true while m_train_iter is below m_num_classes, which is also
 * what get_num_machines() returns for this strategy (one machine per class)
 */
virtual bool train_has_more()
{
- return m_train_iter < m_num_machines;
+ return m_train_iter < m_num_classes;
}
/** prepare for the next training phase.
* @return NULL, since no subset is needed in one-vs-rest strategy
- */
+ */
virtual SGVector<int32_t> train_prepare_next();
/** decide the final label.
* @param outputs a vector of output from each machine (in that order)
- * @param num_classes number of classes
*/
- virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes);
+ virtual int32_t decide_label(const SGVector<float64_t> &outputs);
/** get number of machines used in this strategy.
- * @param num_classes number of classes in this problem
*/
- virtual int32_t get_num_machines(int32_t num_classes)
- {
- return num_classes;
- }
-
- /** get strategy type */
- virtual EMulticlassStrategy get_strategy_type()
+ virtual int32_t get_num_machines()
{
- return ONE_VS_REST_STRATEGY;
+ return m_num_classes;
}
/** get name */
@@ -88,7 +79,6 @@ class CMulticlassOneVsRestStrategy: public CMulticlassStrategy
};
protected:
- int32_t m_num_machines;
CRejectionStrategy *m_rejection_strategy;
};
@@ -46,7 +46,7 @@ bool CMulticlassSVM::create_multiclass_svm(int32_t num_classes)
{
if (num_classes>0)
{
- int32_t num_svms=m_multiclass_strategy->get_num_machines(num_classes);
+ int32_t num_svms=m_multiclass_strategy->get_num_machines();
m_machines->clear_array();
for (index_t i=0; i<num_svms; ++i)
@@ -128,16 +128,6 @@ bool CMulticlassSVM::load(FILE* modelfl)
}
int_buffer=0;
- if (fscanf(modelfl," multiclass_strategy=%d; \n", &int_buffer) != 1)
- SG_ERROR( "error in svm file, line nr:%d\n", line_number);
-
- if (!feof(modelfl))
- line_number++;
-
- if (int_buffer != m_multiclass_strategy->get_strategy_type())
- SG_ERROR("multiclass strategy does not match %ld vs. %ld\n", int_buffer, m_multiclass_strategy->get_strategy_type());
-
- int_buffer=0;
if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
SG_ERROR( "error in svm file, line nr:%d\n", line_number);
@@ -275,8 +265,7 @@ bool CMulticlassSVM::save(FILE* modelfl)
SG_INFO( "Writing model file...");
fprintf(modelfl,"%%MultiClassSVM\n");
- fprintf(modelfl,"multiclass_strategy=%d;\n", m_multiclass_strategy->get_strategy_type());
- fprintf(modelfl,"num_classes=%d;\n", m_labels->get_num_classes());
+ fprintf(modelfl,"num_classes=%d;\n", m_multiclass_strategy->get_num_classes());
fprintf(modelfl,"num_svms=%d;\n", m_machines->get_num_elements());
fprintf(modelfl,"kernel='%s';\n", m_kernel->get_name());
@@ -17,6 +17,7 @@ using namespace shogun;
CMulticlassStrategy::CMulticlassStrategy()
:m_train_labels(NULL), m_orig_labels(NULL), m_train_iter(0)
{
+ SG_ADD(&m_num_classes, "num_classes", "Number of classes", MS_NOT_AVAILABLE);
}
void CMulticlassStrategy::train_start(CLabels *orig_labels, CLabels *train_labels)
@@ -18,14 +18,6 @@
namespace shogun
{
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-enum EMulticlassStrategy
-{
- ONE_VS_REST_STRATEGY,
- ONE_VS_ONE_STRATEGY,
-};
-#endif
-
class CMulticlassStrategy: public CSGObject
{
public:
@@ -41,8 +33,17 @@ class CMulticlassStrategy: public CSGObject
return "MulticlassStrategy";
};
- /** get strategy type */
- virtual EMulticlassStrategy get_strategy_type()=0;
+ /** set number of classes
+ *
+ * Stores the value as given, without validation. Derived strategies
+ * read it in decide_label() and get_num_machines().
+ *
+ * @param num_classes number of classes
+ */
+ void set_num_classes(int32_t num_classes)
+ {
+ m_num_classes = num_classes;
+ }
+
+ /** get number of classes
+ *
+ * @return the value last passed to set_num_classes()
+ */
+ int32_t get_num_classes() const
+ {
+ return m_num_classes;
+ }
/** start training */
virtual void train_start(CLabels *orig_labels, CLabels *train_labels);
@@ -60,19 +61,18 @@ class CMulticlassStrategy: public CSGObject
/** decide the final label.
* @param outputs a vector of output from each machine (in that order)
- * @param num_classes number of classes
*/
- virtual int32_t decide_label(const SGVector<float64_t> &outputs, int32_t num_classes)=0;
+ virtual int32_t decide_label(const SGVector<float64_t> &outputs)=0;
/** get number of machines used in this strategy.
- * @param num_classes number of classes in this problem
*/
- virtual int32_t get_num_machines(int32_t num_classes)=0;
+ virtual int32_t get_num_machines()=0;
protected:
CLabels *m_train_labels;
CLabels *m_orig_labels;
int32_t m_train_iter;
+ int32_t m_num_classes;
};
} // namespace shogun
@@ -48,7 +48,7 @@ CScatterSVM::~CScatterSVM()
bool CScatterSVM::train_machine(CFeatures* data)
{
ASSERT(m_labels && m_labels->get_num_labels());
- m_num_classes = m_labels->get_num_classes();
+ m_num_classes = m_multiclass_strategy->get_num_classes();
int32_t num_vectors = m_labels->get_num_labels();
if (data)

0 comments on commit 3be395f

Please sign in to comment.