
Commit

Merge 206ed42 into 8f12810
kislayabhi committed Jun 25, 2014
2 parents: 8f12810 + 206ed42 · commit 046749d
Showing 2 changed files with 81 additions and 131 deletions.
208 changes: 79 additions & 129 deletions src/shogun/classifier/LDA.cpp
100644 → 100755
@@ -5,195 +5,145 @@
  * (at your option) any later version.
  *
  * Written (W) 1999-2009 Soeren Sonnenburg
+ * Written (W) 2014 Abhijeet Kislay
  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
  */
 #include <shogun/lib/config.h>
 
-#ifdef HAVE_LAPACK
+#ifdef HAVE_EIGEN3
+#include <shogun/lib/common.h>
 #include <shogun/machine/Machine.h>
 #include <shogun/machine/LinearMachine.h>
 #include <shogun/classifier/LDA.h>
 #include <shogun/labels/Labels.h>
 #include <shogun/labels/BinaryLabels.h>
 #include <shogun/mathematics/Math.h>
-#include <shogun/mathematics/lapack.h>
+#include <shogun/mathematics/eigen3.h>
 
+using namespace Eigen;
 using namespace shogun;
 
 CLDA::CLDA(float64_t gamma)
-: CLinearMachine(), m_gamma(gamma)
+    : CLinearMachine(), m_gamma(gamma)
 {
 }
 
-CLDA::CLDA(float64_t gamma, CDenseFeatures<float64_t>* traindat, CLabels* trainlab)
-: CLinearMachine(), m_gamma(gamma)
+CLDA::CLDA(float64_t gamma, CDenseFeatures<float64_t> *traindat,
+    CLabels *trainlab)
+    : CLinearMachine(), m_gamma(gamma)
 {
     set_features(traindat);
     set_labels(trainlab);
 }
 
 CLDA::~CLDA()
 {
 }
 
-bool CLDA::train_machine(CFeatures* data)
+bool CLDA::train_machine(CFeatures *data)
 {
     ASSERT(m_labels)
-    if (data)
+
+    if(data)
     {
-        if (!data->has_property(FP_DOT))
+        if(!data->has_property(FP_DOT))
             SG_ERROR("Specified features are not of type CDotFeatures\n")
-        set_features((CDotFeatures*) data);
+        set_features((CDotFeatures *) data);
     }
 
     ASSERT(features)
-    SGVector<int32_t> train_labels=((CBinaryLabels*) m_labels)->get_int_labels();
+    SGVector<int32_t>train_labels=((CBinaryLabels *)m_labels)->get_int_labels();
     ASSERT(train_labels.vector)
 
-    int32_t num_feat=features->get_dim_feature_space();
-    int32_t num_vec=features->get_num_vectors();
-    ASSERT(num_vec==train_labels.vlen)
+    SGMatrix<float64_t>feature_matrix=((CDenseFeatures<float64_t> *)features)
+            ->get_feature_matrix();
+    int32_t num_feat=feature_matrix.num_rows;
+    int32_t num_vec=feature_matrix.num_cols;
+    REQUIRE(num_vec ==train_labels.vlen,
+            "Number of training examples should be equal to number of labels specified");
 
-    int32_t* classidx_neg=SG_MALLOC(int32_t, num_vec);
-    int32_t* classidx_pos=SG_MALLOC(int32_t, num_vec);
+    SGVector<int32_t> classidx_neg(num_vec);
+    SGVector<int32_t> classidx_pos(num_vec);
 
     int32_t i=0;
-    int32_t j=0;
     int32_t num_neg=0;
     int32_t num_pos=0;
-    for (i=0; i<train_labels.vlen; i++)
+
+    for(i=0; i<train_labels.vlen; i++)
     {
-        if (train_labels.vector[i]==-1)
+        if(train_labels.vector[i]==-1)
             classidx_neg[num_neg++]=i;
-        else if (train_labels.vector[i]==+1)
+        else if(train_labels.vector[i]==+1)
             classidx_pos[num_pos++]=i;
         else
         {
-            SG_ERROR("found label != +/- 1 bailing...")
+            SG_ERROR("found label !=+/- 1 bailing...")
             return false;
         }
     }
 
-    if (num_neg<=0 || num_pos<=0)
+    if(num_neg<=0 || num_pos<=0)
     {
         SG_ERROR("whooooo ? only a single class found\n")
         return false;
     }
 
     w=SGVector<float64_t>(num_feat);
 
-    float64_t* mean_neg=SG_MALLOC(float64_t, num_feat);
-    memset(mean_neg,0,num_feat*sizeof(float64_t));
-
-    float64_t* mean_pos=SG_MALLOC(float64_t, num_feat);
-    memset(mean_pos,0,num_feat*sizeof(float64_t));
-
-    /* calling external lib */
-    double* scatter=SG_MALLOC(double, num_feat*num_feat);
-    double* buffer=SG_MALLOC(double, num_feat*CMath::max(num_neg, num_pos));
-    int nf = (int) num_feat;
-
-    CDenseFeatures<float64_t>* rf = (CDenseFeatures<float64_t>*) features;
+    Map<MatrixXd> fmatrix_org(feature_matrix.matrix, num_feat, num_vec);
+    MatrixXd fmatrix=fmatrix_org;
+    VectorXd mean_neg(num_feat);
+    VectorXd mean_pos(num_feat);
 
     //mean neg
-    for (i=0; i<num_neg; i++)
-    {
-        int32_t vlen;
-        bool vfree;
-        float64_t* vec=
-            rf->get_feature_vector(classidx_neg[i], vlen, vfree);
-        ASSERT(vec)
-
-        for (j=0; j<vlen; j++)
-        {
-            mean_neg[j]+=vec[j];
-            buffer[num_feat*i+j]=vec[j];
-        }
-
-        rf->free_feature_vector(vec, classidx_neg[i], vfree);
-    }
-
-    for (j=0; j<num_feat; j++)
-        mean_neg[j]/=num_neg;
-
-    for (i=0; i<num_neg; i++)
-    {
-        for (j=0; j<num_feat; j++)
-            buffer[num_feat*i+j]-=mean_neg[j];
-    }
-    cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf,
-        (int) num_neg, 1.0, buffer, nf, buffer, nf, 0, scatter, nf);
+    for(i=0; i<num_neg; i++)
+        mean_neg+=fmatrix.col(classidx_neg[i]);
+
+    mean_neg /=num_neg;
+
+    // get m(-ve) - mean(-ve)
+    for(i=0; i<num_neg; i++)
+        fmatrix.col(classidx_neg[i])-=mean_neg;
 
     //mean pos
-    for (i=0; i<num_pos; i++)
-    {
-        int32_t vlen;
-        bool vfree;
-        float64_t* vec=
-            rf->get_feature_vector(classidx_pos[i], vlen, vfree);
-        ASSERT(vec)
-
-        for (j=0; j<vlen; j++)
-        {
-            mean_pos[j]+=vec[j];
-            buffer[num_feat*i+j]=vec[j];
-        }
-
-        rf->free_feature_vector(vec, classidx_pos[i], vfree);
-    }
-
-    for (j=0; j<num_feat; j++)
-        mean_pos[j]/=num_pos;
-
-    for (i=0; i<num_pos; i++)
-    {
-        for (j=0; j<num_feat; j++)
-            buffer[num_feat*i+j]-=mean_pos[j];
-    }
-    cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf, (int) num_pos,
-        1.0/(train_labels.vlen-1), buffer, nf, buffer, nf,
-        1.0/(train_labels.vlen-1), scatter, nf);
-
-    float64_t trace=SGMatrix<float64_t>::trace((float64_t*) scatter, num_feat, num_feat);
-
-    double s=1.0-m_gamma; /* calling external lib; indirectly */
-    for (i=0; i<num_feat*num_feat; i++)
-        scatter[i]*=s;
-
-    for (i=0; i<num_feat; i++)
-        scatter[i*num_feat+i]+= trace*m_gamma/num_feat;
-
-    double* inv_scatter= (double*) SGMatrix<float64_t>::pinv(
-        scatter, num_feat, num_feat, NULL);
-
-    float64_t* w_pos=buffer;
-    float64_t* w_neg=&buffer[num_feat];
-
-    cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
-        (double*) mean_pos, 1, 0., (double*) w_pos, 1);
-    cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
-        (double*) mean_neg, 1, 0, (double*) w_neg, 1);
-
-    bias=0.5*(SGVector<float64_t>::dot(w_neg, mean_neg, num_feat)-SGVector<float64_t>::dot(w_pos, mean_pos, num_feat));
-    for (i=0; i<num_feat; i++)
-        w.vector[i]=w_pos[i]-w_neg[i];
-
-#ifdef DEBUG_LDA
-    SG_PRINT("bias: %f\n", bias)
-    SGVector<float64_t>::display_vector(w.vector, num_feat, "w");
-    SGVector<float64_t>::display_vector(w_pos, num_feat, "w_pos");
-    SGVector<float64_t>::display_vector(w_neg, num_feat, "w_neg");
-    SGVector<float64_t>::display_vector(mean_pos, num_feat, "mean_pos");
-    SGVector<float64_t>::display_vector(mean_neg, num_feat, "mean_neg");
-#endif
-
-    SG_FREE(mean_neg);
-    SG_FREE(mean_pos);
-    SG_FREE(scatter);
-    SG_FREE(inv_scatter);
-    SG_FREE(classidx_neg);
-    SG_FREE(classidx_pos);
-    SG_FREE(buffer);
+    for(i=0; i<num_pos; i++)
+        mean_pos+=fmatrix.col(classidx_pos[i]);
+
+    mean_pos/=num_pos;
+
+    // get m(+ve) - mean(+ve)
+    for(i=0; i<num_pos; i++)
+        fmatrix.col(classidx_pos[i])-=mean_pos;
+
+    SGMatrix<float64_t>scatter_matrix(num_feat, num_feat);
+    Map<MatrixXd> scatter(scatter_matrix.matrix, num_feat, num_feat);
+
+    // covariance matrix.
+    MatrixXd cov_mat(num_feat, num_feat);
+    cov_mat=fmatrix*fmatrix.transpose();
+    scatter=cov_mat/(num_vec-1);
+    float64_t trace=scatter.trace();
+    double s=1.0-m_gamma;
+    scatter *=s;
+    scatter.diagonal()+=VectorXd::Constant(num_feat, trace*m_gamma/num_feat);
+
+    // we need to find a Basic Linear Solution of A.x=b for 'x'.
+    // Instead of crudely Inverting A, we go for solve() using Decompositions.
+    // where:
+    // MatrixXd A=scatter;
+    // VectorXd b=mean_pos-mean_neg;
+    // VectorXd x=w;
+    Map<VectorXd> x(w.vector, num_feat);
+    ColPivHouseholderQR<MatrixXd> decomposition(scatter);
+    x=decomposition.solve(mean_pos-mean_neg);
+    MatrixXd scatter_inv=decomposition.inverse();
+
+    // get the weights w_neg(for -ve class) and w_pos(for +ve class)
+    VectorXd w_neg=scatter_inv*mean_neg;
+    VectorXd w_pos=scatter_inv*mean_pos;
+
+    // get the bias.
+    bias=0.5*(w_neg.dot(mean_neg)-w_pos.dot(mean_pos));
     return true;
 }
-#endif
+#endif//HAVE_EIGEN3
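
Note on the change above: both versions compute the same closed-form two-class LDA solution. With pooled within-class scatter S (normalized by num_vec-1) and shrinkage S_reg = (1-gamma)*S + gamma*(trace(S)/num_feat)*I, the weight vector is w = S_reg^-1 * (mean_pos - mean_neg) and the bias is 0.5*(mean_neg'*S_reg^-1*mean_neg - mean_pos'*S_reg^-1*mean_pos). The old code assembled S with cblas_dgemm and applied an explicit pseudo-inverse; the new code assembles it with Eigen expressions and solves the linear system with a column-pivoting Householder QR, avoiding an explicit inverse for w. The standalone sketch below reproduces that Eigen pattern on toy data; every name and the toy values are illustrative, and nothing beyond Eigen's documented API (MatrixXd, ColPivHouseholderQR, solve) is assumed.

// Standalone sketch (not Shogun code) of the numerical pattern the new
// train_machine() uses: shrinkage-regularized pooled scatter plus a QR solve.
#include <Eigen/Dense>
#include <iostream>

int main()
{
    using namespace Eigen;

    // Toy data: 3-dimensional features, 4 columns per class, class +1 shifted.
    MatrixXd Xneg=MatrixXd::Random(3, 4);
    MatrixXd Xpos=(MatrixXd::Random(3, 4).array()+2.0).matrix();
    const int d=3;
    const int n=Xneg.cols()+Xpos.cols();
    const double gamma=0.1; // shrinkage amount, playing the role of m_gamma

    // Class means.
    VectorXd mean_neg=Xneg.rowwise().mean();
    VectorXd mean_pos=Xpos.rowwise().mean();

    // Center each class, pool the scatter, normalize by (n-1).
    MatrixXd Cneg=Xneg.colwise()-mean_neg;
    MatrixXd Cpos=Xpos.colwise()-mean_pos;
    MatrixXd scatter=(Cneg*Cneg.transpose()+Cpos*Cpos.transpose())/(n-1);

    // Shrinkage: (1-gamma)*S + gamma*(trace(S)/d)*I keeps S well-conditioned.
    double trace=scatter.trace();
    scatter*=(1.0-gamma);
    scatter.diagonal().array()+=trace*gamma/d;

    // Solve S_reg*w = (mean_pos - mean_neg) instead of forming the inverse.
    ColPivHouseholderQR<MatrixXd> qr(scatter);
    VectorXd w=qr.solve(mean_pos-mean_neg);

    // Bias places the decision threshold between the two class means.
    VectorXd w_neg=qr.solve(mean_neg);
    VectorXd w_pos=qr.solve(mean_pos);
    double bias=0.5*(w_neg.dot(mean_neg)-w_pos.dot(mean_pos));

    std::cout<<"w = "<<w.transpose()<<"\nbias = "<<bias<<"\n";
    return 0;
}

Solving through a decomposition is cheaper and numerically safer than computing S_reg^-1 outright; the commit still calls decomposition.inverse() once because it needs S_reg^-1 against both class means for the bias, whereas the sketch above reuses the same factorization with two extra solve() calls.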
4 changes: 2 additions & 2 deletions tests/unit/classifier/LDA_unittest.cc
@@ -33,7 +33,7 @@
 #include <shogun/classifier/LDA.h>
 #include <gtest/gtest.h>
 
-#ifdef HAVE_LAPACK
+#ifdef HAVE_EIGEN3
 using namespace shogun;
 class LDATest: public::testing::Test
 {
@@ -133,4 +133,4 @@ TEST_F(LDATest, CheckProjection)
     EXPECT_NEAR(+271.14418463, projection[8], epsilon);
     EXPECT_NEAR(+291.21213655, projection[9], epsilon);
 }
-#endif //HAVE_LAPACK
+#endif //HAVE_EIGEN3
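
For reference, a rough sketch of how the class under test is driven, assuming the shogun 3.x API visible in this diff (CDenseFeatures, CBinaryLabels, the CLDA constructor added above) plus CLinearMachine::get_w()/get_bias(); the exact signatures are unverified assumptions, so check the headers before copying.

// Hedged usage sketch: train CLDA on a tiny two-class problem and read out
// the learned hyperplane. Data values are illustrative only.
#include <shogun/base/init.h>
#include <shogun/classifier/LDA.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/BinaryLabels.h>

using namespace shogun;

int main()
{
    init_shogun_with_defaults();

    // 2-dimensional features, 4 examples, stored column-major.
    SGMatrix<float64_t> X(2, 4);
    double vals[]={-1,-1, -2,-1, 1,1, 2,1};
    for (int i=0; i<8; i++)
        X.matrix[i]=vals[i];

    SGVector<float64_t> y(4);
    y[0]=-1; y[1]=-1; y[2]=+1; y[3]=+1;

    CDenseFeatures<float64_t>* feats=new CDenseFeatures<float64_t>(X);
    CBinaryLabels* labels=new CBinaryLabels(y);

    CLDA* lda=new CLDA(0.0, feats, labels); // gamma=0: no shrinkage
    lda->train();

    // The learned hyperplane: sign(w.x + bias) is the predicted class.
    SGVector<float64_t> w=lda->get_w();
    float64_t bias=lda->get_bias();
    w.display_vector("w");
    SG_SPRINT("bias=%f\n", bias);

    SG_UNREF(lda);
    exit_shogun();
    return 0;
}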
