diff --git a/src/shogun/preprocessor/FisherLDA.cpp b/src/shogun/preprocessor/FisherLDA.cpp index 1b103a1ab7e..24fe7bf8a1c 100644 --- a/src/shogun/preprocessor/FisherLDA.cpp +++ b/src/shogun/preprocessor/FisherLDA.cpp @@ -50,10 +50,12 @@ using namespace Eigen; using namespace shogun; CFisherLDA::CFisherLDA( - EFLDAMethod method, float64_t thresh, float64_t gamma, bool bdc_svd) + int32_t num_dimensions, EFLDAMethod method, float64_t thresh, + float64_t gamma, bool bdc_svd) : CDimensionReductionPreprocessor() { initialize_parameters(); + m_num_dim = num_dimensions; m_method=method; m_threshold=thresh; m_gamma = gamma; @@ -90,17 +92,10 @@ CFisherLDA::~CFisherLDA() { } -bool CFisherLDA::fit(CFeatures *features, CLabels *labels, int32_t num_dimensions) +void CFisherLDA::fit(CFeatures* features, CLabels* labels) { REQUIRE(features, "Features are not provided!\n") - REQUIRE(features->get_feature_class()==C_DENSE, - "LDA only works with dense features. you provided %s\n", - features->get_name()); - - REQUIRE(features->get_feature_type()==F_DREAL, - "LDA only works with real features.\n"); - REQUIRE(labels, "Labels for the given features are not specified!\n") REQUIRE( @@ -109,8 +104,7 @@ bool CFisherLDA::fit(CFeatures *features, CLabels *labels, int32_t num_dimension "the type MulticlassLabels! you provided %s\n", labels->get_name()); - CDenseFeatures* dense_features = - static_cast*>(features); + auto dense_features = features->as>(); CMulticlassLabels* multiclass_labels = static_cast(labels); @@ -127,8 +121,6 @@ bool CFisherLDA::fit(CFeatures *features, CLabels *labels, int32_t num_dimension REQUIRE(num_class > 1, "At least two classes are needed to perform LDA.\n") - m_num_dim=num_dimensions; - // clip number if Dimensions to be a valid number if ((m_num_dim <= 0) || (m_num_dim > (num_class - 1))) m_num_dim = (num_class - 1); @@ -137,12 +129,12 @@ bool CFisherLDA::fit(CFeatures *features, CLabels *labels, int32_t num_dimension m_method == AUTO_FLDA && num_vectors < num_features; if ((m_method == CANVAR_FLDA) || lda_more_efficient) - return solver_canvar(dense_features, multiclass_labels); + solver_canvar(dense_features, multiclass_labels); else - return solver_classic(dense_features, multiclass_labels); + solver_classic(dense_features, multiclass_labels); } -bool CFisherLDA::solver_canvar( +void CFisherLDA::solver_canvar( CDenseFeatures* features, CMulticlassLabels* labels) { auto solver = std::unique_ptr>( @@ -151,11 +143,9 @@ bool CFisherLDA::solver_canvar( m_transformation_matrix = solver->get_eigenvectors(); m_eigenvalues_vector = solver->get_eigenvalues(); - - return true; } -bool CFisherLDA::solver_classic( +void CFisherLDA::solver_classic( CDenseFeatures* features, CMulticlassLabels* labels) { SGMatrix data = features->get_feature_matrix(); @@ -199,8 +189,6 @@ bool CFisherLDA::solver_classic( m_eigenvalues_vector[i] = eigenvalues[k]; m_transformation_matrix.set_column(k, eigenvectors.get_column(i)); } - - return true; } void CFisherLDA::cleanup() @@ -212,14 +200,8 @@ void CFisherLDA::cleanup() SGMatrix CFisherLDA::apply_to_feature_matrix(CFeatures*features) { - REQUIRE(features->get_feature_class()==C_DENSE, - "LDA only works with dense features\n"); - - REQUIRE(features->get_feature_type()==F_DREAL, - "LDA only works with real features\n"); - - SGMatrix m = - ((CDenseFeatures*)features)->get_feature_matrix(); + auto simple_features = features->as>(); + auto m = simple_features->get_feature_matrix(); int32_t num_vectors=m.num_cols; int32_t num_features=m.num_rows; @@ -244,7 +226,7 @@ SGMatrix CFisherLDA::apply_to_feature_matrix(CFeatures*features) } m.num_rows=m_num_dim; m.num_cols=num_vectors; - ((CDenseFeatures*)features)->set_feature_matrix(m); + simple_features->set_feature_matrix(m); return m; } diff --git a/src/shogun/preprocessor/FisherLDA.h b/src/shogun/preprocessor/FisherLDA.h index 10e1cf392f5..89eb683aeeb 100644 --- a/src/shogun/preprocessor/FisherLDA.h +++ b/src/shogun/preprocessor/FisherLDA.h @@ -93,6 +93,7 @@ class CFisherLDA: public CDimensionReductionPreprocessor { public: /** standard constructor + * @param num_dimensions number of dimensions to retain * @param method LDA based on : * ::CLASSIC_FLDA/::CANVAR_FLDA/::AUTO_FLDA[default] * @param thresh threshold value for ::CANVAR_FLDA only. This is used to @@ -107,8 +108,8 @@ class CFisherLDA: public CDimensionReductionPreprocessor * [default = BDC-SVD] */ CFisherLDA( - EFLDAMethod method = AUTO_FLDA, float64_t thresh = 0.01, - float64_t gamma = 0, bool bdc_svd = true); + int32_t num_dimensions = 0, EFLDAMethod method = AUTO_FLDA, + float64_t thresh = 0.01, float64_t gamma = 0, bool bdc_svd = true); /** destructor */ virtual ~CFisherLDA(); @@ -117,9 +118,8 @@ class CFisherLDA: public CDimensionReductionPreprocessor * @param features using which the transformation matrix will be formed * @param labels of the given features which will be used here to find * the transformation matrix unlike PCA where it is not needed. - * @param num_dimensions number of dimensions to retain */ - virtual bool fit(CFeatures* features, CLabels* labels, int32_t num_dimensions=0); + virtual void fit(CFeatures* features, CLabels* labels); /** cleanup */ virtual void cleanup(); @@ -165,7 +165,7 @@ class CFisherLDA: public CDimensionReductionPreprocessor * @param features training data. * @param labels multiclass labels. */ - bool solver_canvar( + void solver_canvar( CDenseFeatures* features, CMulticlassLabels* labels); /** @@ -173,7 +173,7 @@ class CFisherLDA: public CDimensionReductionPreprocessor * @param features training data. * @param labels multiclass labels. */ - bool solver_classic( + void solver_classic( CDenseFeatures* features, CMulticlassLabels* labels); /** transformation matrix */ diff --git a/src/shogun/transformer/Transformer.h b/src/shogun/transformer/Transformer.h index 55ae65546ae..72ce253c311 100644 --- a/src/shogun/transformer/Transformer.h +++ b/src/shogun/transformer/Transformer.h @@ -42,6 +42,12 @@ namespace shogun virtual void fit(CFeatures* features) { } + + /** Fit transformer to features and labels */ + virtual void fit(CFeatures* features, CLabels* labels) + { + SG_SNOTIMPLEMENTED; + } }; } #endif /* TRANSFORMER_H_ */ diff --git a/tests/unit/preprocessor/FisherLDA_unittest.cc b/tests/unit/preprocessor/FisherLDA_unittest.cc index 15ce676b5f0..9a1b00078fb 100644 --- a/tests/unit/preprocessor/FisherLDA_unittest.cc +++ b/tests/unit/preprocessor/FisherLDA_unittest.cc @@ -143,8 +143,8 @@ TEST_F(FLDATest, CANVAR_FLDA_Unit_test) // comparing outputs against BRMLtoolbox MATLAB "CannonVar.m" implementation // http://web4.cs.ucl.ac.uk/staff/D.Barber/pmwiki/pmwiki.php?n=Brml.Software - CFisherLDA fisherlda(CANVAR_FLDA); - fisherlda.fit(dense_feat, labels, 1); + CFisherLDA fisherlda(1, CANVAR_FLDA); + fisherlda.fit(dense_feat, labels); SGMatrix y=fisherlda.apply_to_feature_matrix(dense_feat); float64_t epsilon=0.00000000001; @@ -181,8 +181,8 @@ TEST_F(FLDATest, CLASSIC_FLDA_Unit_test) SG_REF(dense_feat); SG_REF(labels); - CFisherLDA fisherlda(CLASSIC_FLDA); - fisherlda.fit(dense_feat, labels, 1); + CFisherLDA fisherlda(1, CLASSIC_FLDA); + fisherlda.fit(dense_feat, labels); SGMatrix y=fisherlda.apply_to_feature_matrix(dense_feat); float64_t epsilon=0.00000000001; @@ -266,8 +266,8 @@ TEST_F(FLDATest, CANVAR_FLDA_for_D_greater_than_N ) SG_REF(l); SG_REF(df); - CFisherLDA fisherlda(CANVAR_FLDA); - fisherlda.fit(df, l, 1); + CFisherLDA fisherlda(1, CANVAR_FLDA); + fisherlda.fit(df, l); SGMatrix transformy=fisherlda.get_transformation_matrix(); // comparing eigenvectors from the transformation_matrix with that from the