Commit

Update FisherLDA api
vinx13 authored and vigsterkr committed Jun 4, 2018
1 parent 23c34e9 commit ca8397d
Showing 4 changed files with 30 additions and 42 deletions.
42 changes: 12 additions & 30 deletions src/shogun/preprocessor/FisherLDA.cpp
@@ -50,10 +50,12 @@ using namespace Eigen;
using namespace shogun;

CFisherLDA::CFisherLDA(
- EFLDAMethod method, float64_t thresh, float64_t gamma, bool bdc_svd)
+ int32_t num_dimensions, EFLDAMethod method, float64_t thresh,
+ float64_t gamma, bool bdc_svd)
: CDimensionReductionPreprocessor()
{
initialize_parameters();
+ m_num_dim = num_dimensions;
m_method=method;
m_threshold=thresh;
m_gamma = gamma;
@@ -90,17 +92,10 @@ CFisherLDA::~CFisherLDA()
{
}

- bool CFisherLDA::fit(CFeatures *features, CLabels *labels, int32_t num_dimensions)
+ void CFisherLDA::fit(CFeatures* features, CLabels* labels)
{
REQUIRE(features, "Features are not provided!\n")

- REQUIRE(features->get_feature_class()==C_DENSE,
- "LDA only works with dense features. you provided %s\n",
- features->get_name());
-
- REQUIRE(features->get_feature_type()==F_DREAL,
- "LDA only works with real features.\n");
-
REQUIRE(labels, "Labels for the given features are not specified!\n")

REQUIRE(
@@ -109,8 +104,7 @@ bool CFisherLDA::fit(CFeatures *features, CLabels *labels, int32_t num_dimension
"the type MulticlassLabels! you provided %s\n",
labels->get_name());

- CDenseFeatures<float64_t>* dense_features =
- static_cast<CDenseFeatures<float64_t>*>(features);
+ auto dense_features = features->as<CDenseFeatures<float64_t>>();
CMulticlassLabels* multiclass_labels =
static_cast<CMulticlassLabels*>(labels);

@@ -127,8 +121,6 @@ bool CFisherLDA::fit(CFeatures *features, CLabels *labels, int32_t num_dimension

REQUIRE(num_class > 1, "At least two classes are needed to perform LDA.\n")

- m_num_dim=num_dimensions;
-
// clip number of dimensions to be a valid number
if ((m_num_dim <= 0) || (m_num_dim > (num_class - 1)))
m_num_dim = (num_class - 1);
@@ -137,12 +129,12 @@ bool CFisherLDA::fit(CFeatures *features, CLabels *labels, int32_t num_dimension
m_method == AUTO_FLDA && num_vectors < num_features;

if ((m_method == CANVAR_FLDA) || lda_more_efficient)
- return solver_canvar(dense_features, multiclass_labels);
+ solver_canvar(dense_features, multiclass_labels);
else
- return solver_classic(dense_features, multiclass_labels);
+ solver_classic(dense_features, multiclass_labels);
}

- bool CFisherLDA::solver_canvar(
+ void CFisherLDA::solver_canvar(
CDenseFeatures<float64_t>* features, CMulticlassLabels* labels)
{
auto solver = std::unique_ptr<LDACanVarSolver<float64_t>>(
@@ -151,11 +143,9 @@ bool CFisherLDA::solver_canvar(

m_transformation_matrix = solver->get_eigenvectors();
m_eigenvalues_vector = solver->get_eigenvalues();
-
- return true;
}

- bool CFisherLDA::solver_classic(
+ void CFisherLDA::solver_classic(
CDenseFeatures<float64_t>* features, CMulticlassLabels* labels)
{
SGMatrix<float64_t> data = features->get_feature_matrix();
@@ -199,8 +189,6 @@ bool CFisherLDA::solver_classic(
m_eigenvalues_vector[i] = eigenvalues[k];
m_transformation_matrix.set_column(k, eigenvectors.get_column(i));
}
-
- return true;
}

void CFisherLDA::cleanup()
@@ -212,14 +200,8 @@ void CFisherLDA::cleanup()

SGMatrix<float64_t> CFisherLDA::apply_to_feature_matrix(CFeatures*features)
{
- REQUIRE(features->get_feature_class()==C_DENSE,
- "LDA only works with dense features\n");
-
- REQUIRE(features->get_feature_type()==F_DREAL,
- "LDA only works with real features\n");
-
- SGMatrix<float64_t> m =
- ((CDenseFeatures<float64_t>*)features)->get_feature_matrix();
+ auto simple_features = features->as<CDenseFeatures<float64_t>>();
+ auto m = simple_features->get_feature_matrix();

int32_t num_vectors=m.num_cols;
int32_t num_features=m.num_rows;
@@ -244,7 +226,7 @@ SGMatrix<float64_t> CFisherLDA::apply_to_feature_matrix(CFeatures*features)
}
m.num_rows=m_num_dim;
m.num_cols=num_vectors;
- ((CDenseFeatures<float64_t>*)features)->set_feature_matrix(m);
+ simple_features->set_feature_matrix(m);
return m;
}

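In short, this commit moves num_dimensions out of fit() and into the constructor, and fit() together with both solver helpers now returns void instead of bool; failures surface through the REQUIRE checks inside fit() rather than a return value. A minimal sketch of the new call pattern, assuming the C-prefixed Shogun API of this period (the toy data and variable names are illustrative, not part of the commit):

#include <shogun/base/init.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/preprocessor/FisherLDA.h>

using namespace shogun;

int main()
{
    init_shogun_with_defaults();

    // Hypothetical toy data: 2 features, 6 samples, 2 classes.
    SGMatrix<float64_t> X(2, 6);
    X(0, 0) = 1; X(1, 0) = 2;
    X(0, 1) = 2; X(1, 1) = 1;
    X(0, 2) = 2; X(1, 2) = 2;
    X(0, 3) = 7; X(1, 3) = 8;
    X(0, 4) = 8; X(1, 4) = 7;
    X(0, 5) = 8; X(1, 5) = 8;
    SGVector<float64_t> y(6);
    y[0] = y[1] = y[2] = 0;
    y[3] = y[4] = y[5] = 1;

    auto features = new CDenseFeatures<float64_t>(X);
    auto labels = new CMulticlassLabels(y);
    SG_REF(features);
    SG_REF(labels);

    // Old API: CFisherLDA lda(CANVAR_FLDA); lda.fit(features, labels, 1);
    // New API: num_dimensions is the first constructor argument and
    // fit() no longer returns a status flag.
    CFisherLDA lda(1, CANVAR_FLDA);
    lda.fit(features, labels);
    SGMatrix<float64_t> projected = lda.apply_to_feature_matrix(features);

    SG_UNREF(features);
    SG_UNREF(labels);
    exit_shogun();
    return 0;
}
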
12 changes: 6 additions & 6 deletions src/shogun/preprocessor/FisherLDA.h
@@ -93,6 +93,7 @@ class CFisherLDA: public CDimensionReductionPreprocessor
{
public:
/** standard constructor
+ * @param num_dimensions number of dimensions to retain
* @param method LDA based on :
* ::CLASSIC_FLDA/::CANVAR_FLDA/::AUTO_FLDA[default]
* @param thresh threshold value for ::CANVAR_FLDA only. This is used to
@@ -107,8 +108,8 @@ class CFisherLDA: public CDimensionReductionPreprocessor
* [default = BDC-SVD]
*/
CFisherLDA(
- EFLDAMethod method = AUTO_FLDA, float64_t thresh = 0.01,
- float64_t gamma = 0, bool bdc_svd = true);
+ int32_t num_dimensions = 0, EFLDAMethod method = AUTO_FLDA,
+ float64_t thresh = 0.01, float64_t gamma = 0, bool bdc_svd = true);

/** destructor */
virtual ~CFisherLDA();
@@ -117,9 +118,8 @@ class CFisherLDA: public CDimensionReductionPreprocessor
* @param features using which the transformation matrix will be formed
* @param labels of the given features which will be used here to find
* the transformation matrix unlike PCA where it is not needed.
- * @param num_dimensions number of dimensions to retain
*/
- virtual bool fit(CFeatures* features, CLabels* labels, int32_t num_dimensions=0);
+ virtual void fit(CFeatures* features, CLabels* labels);

/** cleanup */
virtual void cleanup();
@@ -165,15 +165,15 @@ class CFisherLDA: public CDimensionReductionPreprocessor
* @param features training data.
* @param labels multiclass labels.
*/
- bool solver_canvar(
+ void solver_canvar(
CDenseFeatures<float64_t>* features, CMulticlassLabels* labels);

/**
* Train the preprocessor with the classic method.
* @param features training data.
* @param labels multiclass labels.
*/
- bool solver_classic(
+ void solver_classic(
CDenseFeatures<float64_t>* features, CMulticlassLabels* labels);

/** transformation matrix */
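For reference, a hedged sketch of constructing the preprocessor with every argument of the new signature spelled out; the values are illustrative, not recommendations:

// num_dimensions is clipped to num_class - 1 inside fit();
// defaults per the header are 0, AUTO_FLDA, 0.01, 0, true.
CFisherLDA lda(
    2,            // num_dimensions: dimensions to retain
    CANVAR_FLDA,  // method: CLASSIC_FLDA / CANVAR_FLDA / AUTO_FLDA
    0.01,         // thresh: used by CANVAR_FLDA only
    0.1,          // gamma: regularization strength (illustrative value)
    true);        // bdc_svd: BDC-SVD rather than Jacobi SVD
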
6 changes: 6 additions & 0 deletions src/shogun/transformer/Transformer.h
@@ -42,6 +42,12 @@ namespace shogun
virtual void fit(CFeatures* features)
{
}
+
+ /** Fit transformer to features and labels */
+ virtual void fit(CFeatures* features, CLabels* labels)
+ {
+ SG_SNOTIMPLEMENTED;
+ }
};
}
#endif /* TRANSFORMER_H_ */
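The new labeled fit() overload gives supervised transformers such as CFisherLDA a common entry point; the base implementation raises SG_SNOTIMPLEMENTED, so subclasses that need labels override it. A sketch of the override pattern, assuming the base class declared in this file is CTransformer (its name lies outside the excerpt) and using a hypothetical subclass:

#include <shogun/transformer/Transformer.h>

namespace shogun
{
    // Hypothetical supervised transformer; only the override pattern
    // comes from this commit, the class itself is illustrative.
    class CMySupervisedTransformer : public CTransformer
    {
    public:
        // Keep the unlabeled fit(CFeatures*) overload visible.
        using CTransformer::fit;

        // Supervised entry point introduced by this commit.
        virtual void fit(CFeatures* features, CLabels* labels)
        {
            // ... estimate transform parameters from features and labels ...
        }
    };
}
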
12 changes: 6 additions & 6 deletions tests/unit/preprocessor/FisherLDA_unittest.cc
@@ -143,8 +143,8 @@ TEST_F(FLDATest, CANVAR_FLDA_Unit_test)

// comparing outputs against BRMLtoolbox MATLAB "CannonVar.m" implementation
// http://web4.cs.ucl.ac.uk/staff/D.Barber/pmwiki/pmwiki.php?n=Brml.Software
- CFisherLDA fisherlda(CANVAR_FLDA);
- fisherlda.fit(dense_feat, labels, 1);
+ CFisherLDA fisherlda(1, CANVAR_FLDA);
+ fisherlda.fit(dense_feat, labels);
SGMatrix<float64_t> y=fisherlda.apply_to_feature_matrix(dense_feat);

float64_t epsilon=0.00000000001;
@@ -181,8 +181,8 @@ TEST_F(FLDATest, CLASSIC_FLDA_Unit_test)
SG_REF(dense_feat);
SG_REF(labels);

- CFisherLDA fisherlda(CLASSIC_FLDA);
- fisherlda.fit(dense_feat, labels, 1);
+ CFisherLDA fisherlda(1, CLASSIC_FLDA);
+ fisherlda.fit(dense_feat, labels);
SGMatrix<float64_t> y=fisherlda.apply_to_feature_matrix(dense_feat);

float64_t epsilon=0.00000000001;
@@ -266,8 +266,8 @@ TEST_F(FLDATest, CANVAR_FLDA_for_D_greater_than_N )
SG_REF(l);
SG_REF(df);

- CFisherLDA fisherlda(CANVAR_FLDA);
- fisherlda.fit(df, l, 1);
+ CFisherLDA fisherlda(1, CANVAR_FLDA);
+ fisherlda.fit(df, l);
SGMatrix<float64_t> transformy=fisherlda.get_transformation_matrix();

// comparing eigenvectors from the transformation_matrix with that from the
