Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement QR solver in linalg library #3719

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/shogun/mathematics/linalg/LinalgBackendBase.h
Expand Up @@ -45,6 +45,22 @@
namespace shogun
{

namespace linalg
{

/**
* @brief
* Type used to choose the pivoting in algorithms
* that make use of QR decomposition
*/
enum class QRDecompositionPivoting {
None,
Column,
Full
};

}

/** @brief Base interface of generic linalg methods
* and generic memory transfer methods.
*/
Expand Down Expand Up @@ -132,6 +148,21 @@ class LinalgBackendBase
DEFINE_FOR_NON_INTEGER_PTYPE(BACKEND_GENERIC_CHOLESKY_SOLVER, SGMatrix)
#undef BACKEND_GENERIC_CHOLESKY_SOLVER

/**
* Wrapper solver with QR decomposition.
*
* @see linalg::qr_solver
*/
#define BACKEND_GENERIC_QR_SOLVER(Type, Container) \
virtual Container<Type> qr_solver(const Container<Type>& A, const Container<Type>& b, \
linalg::QRDecompositionPivoting pivoting = linalg::QRDecompositionPivoting::None) const \
{ \
SG_SNOTIMPLEMENTED; \
return 0; \
}
DEFINE_FOR_NON_INTEGER_PTYPE(BACKEND_GENERIC_QR_SOLVER, SGMatrix)
#undef BACKEND_GENERIC_QR_SOLVER

/**
* Wrapper method of vector dot-product that works with generic vectors.
*
Expand Down
44 changes: 44 additions & 0 deletions src/shogun/mathematics/linalg/LinalgBackendEigen.h
Expand Up @@ -127,6 +127,16 @@ class LinalgBackendEigen : public LinalgBackendBase
DEFINE_FOR_NON_INTEGER_PTYPE(BACKEND_GENERIC_CHOLESKY_SOLVER, SGMatrix)
#undef BACKEND_GENERIC_CHOLESKY_SOLVER

/** Implementation of @see LinalgBackendBase::qr_solver */
#define BACKEND_GENERIC_QR_SOLVER(Type, Container) \
virtual Container<Type> qr_solver(const Container<Type>& A, \
const Container<Type>& b, linalg::QRDecompositionPivoting pivoting) const \
{ \
return qr_solver_impl(A, b, pivoting); \
}
DEFINE_FOR_NON_INTEGER_PTYPE(BACKEND_GENERIC_QR_SOLVER, SGMatrix)
#undef BACKEND_GENERIC_QR_SOLVER

/** Implementation of @see LinalgBackendBase::dot */
#define BACKEND_GENERIC_DOT(Type, Container) \
virtual Type dot(const Container<Type>& a, const Container<Type>& b) const \
Expand Down Expand Up @@ -381,6 +391,40 @@ class LinalgBackendEigen : public LinalgBackendBase
return x;
}

/** Eigen3 QR solver */
template <typename T>
SGMatrix<T> qr_solver_impl(const SGMatrix<T>& A,
const SGMatrix<T>& b, linalg::QRDecompositionPivoting pivoting) const
{
SGMatrix<T> x = SGMatrix<T>(b.num_rows, b.num_cols);
x.zero();

typename SGMatrix<T>::EigenMatrixXtMap A_eig = A;
typename SGMatrix<T>::EigenMatrixXtMap b_eig = b;
typename SGMatrix<T>::EigenMatrixXtMap x_eig = x;

using linalg::QRDecompositionPivoting;
switch (pivoting) {
case QRDecompositionPivoting::None: {
auto qr = Eigen::HouseholderQR<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> >(A_eig);
x_eig = qr.solve(b_eig);
break;
}
case QRDecompositionPivoting::Column: {
auto qr = Eigen::ColPivHouseholderQR<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> >(A_eig);
x_eig = qr.solve(b_eig);
break;
}
case QRDecompositionPivoting::Full: {
auto qr = Eigen::FullPivHouseholderQR<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> >(A_eig);
x_eig = qr.solve(b_eig);
break;
}
}

return x;
}

/** Eigen3 vector dot-product method */
template <typename T>
T dot_impl(const SGVector<T>& a, const SGVector<T>& b) const
Expand Down
22 changes: 22 additions & 0 deletions src/shogun/mathematics/linalg/LinalgNamespace.h
Expand Up @@ -381,6 +381,28 @@ SGVector<T> cholesky_solver(const SGMatrix<T>& L, const SGVector<T>& b,
return infer_backend(L, SGMatrix<T>(b))->cholesky_solver(L, b, lower);
}

/**
* Solve the linear equations \f$Ax=b\f$ by means of the
* QR factorization of \f$A\f$
*
* @param A Square matrix
* @param b Right-hand side matrix
* @param p Pivoting type in QR decomposition
* @return \f$\x\f$
*/
template <typename T>
SGMatrix<T> qr_solver(const SGMatrix<T>& A, const SGMatrix<T>& b,
QRDecompositionPivoting pivoting = QRDecompositionPivoting::None)
{
REQUIRE(A.num_rows == A.num_cols, "Matrix is not square!\n");
REQUIRE((A.num_cols == b.num_rows),
"Number of columns of matrix A (%d) must match \
number of rows of matrix b (%d).\n",
A.num_cols, b.num_rows);

return infer_backend(A, b)->qr_solver(A, b, pivoting);
}

/**
* Vector dot-product that works with generic vectors.
*
Expand Down
10 changes: 7 additions & 3 deletions src/shogun/preprocessor/FisherLDA.cpp
Expand Up @@ -42,6 +42,7 @@
#include <shogun/preprocessor/FisherLDA.h>
#include <shogun/preprocessor/DimensionReductionPreprocessor.h>
#include <shogun/mathematics/eigen3.h>
#include <shogun/mathematics/linalg/LinalgNamespace.h>
#include <vector>

using namespace std;
Expand Down Expand Up @@ -233,10 +234,10 @@ bool CFisherLDA::fit(CFeatures *features, CLabels *labels, int32_t num_dimension
else
{
// For holding the within class scatter.
MatrixXd Sw=fmatrix*fmatrix.transpose();
SGMatrix<float64_t>::EigenMatrixXt Sw=fmatrix*fmatrix.transpose();

// For holding the between class scatter.
MatrixXd Sb(num_features, C);
SGMatrix<float64_t>::EigenMatrixXt Sb(num_features, C);

for (i=0; i<C; i++)
Sb.col(i)=mean_class[i];
Expand All @@ -250,7 +251,10 @@ bool CFisherLDA::fit(CFeatures *features, CLabels *labels, int32_t num_dimension
// x=M
// MatrixXd M=Sw.householderQr().solve(Sb);
// calculate the eigenvalues and eigenvectors of M.
EigenSolver<MatrixXd> es(Sw.householderQr().solve(Sb));
SGMatrix<float64_t> M=linalg::qr_solver<float64_t>(Sw, Sb);

SGMatrix<float64_t>::EigenMatrixXtMap M_eig=M;
EigenSolver<MatrixXd> es(M_eig);

MatrixXd all_eigenvectors=es.eigenvectors().real();
VectorXd all_eigenvalues=es.eigenvalues().real();
Expand Down
Expand Up @@ -158,6 +158,32 @@ TEST(LinalgBackendEigen, SGMatrix_cholesky_solver)
EXPECT_EQ(x_ref.size(), x_cal.size());
}

TEST(LinalgBackendEigen, SGMatrix_qr_solver)
{
const index_t size=3, n_x=2;
SGMatrix<float64_t> m(size, size);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

data/matrix generators should go into a generator class in unit testing :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes there really should be a patch for this soon

SGMatrix<float64_t> b(size, n_x);

for (index_t i=0; i<size*size; ++i)
m[i]=i*i;
for (index_t i=0; i<size*n_x; ++i)
b[i]=i;

SGMatrix<float64_t> x_cal = qr_solver(m, b);

SGMatrix<float64_t> x_ref(size, n_x);
x_ref[0] = -0.25;
x_ref[1] = 0.33333333;
x_ref[2] = -0.08333333;
x_ref[3] = -0.08333333;
x_ref[4] = 0.0;
x_ref[5] = 0.08333333;

for (index_t i=0; i<size*n_x; ++i)
EXPECT_NEAR(x_ref[i], x_cal[i], 1E-8);
EXPECT_EQ(x_ref.size(), x_cal.size());
}

TEST(LinalgBackendEigen, SGVector_dot)
{
const index_t size = 3;
Expand Down