Skip to content

Commit

Permalink
LinearSolver base, direct solver implementation added (log-det)
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday committed Jun 26, 2013
1 parent 2aa7804 commit 84258f4
Show file tree
Hide file tree
Showing 4 changed files with 367 additions and 0 deletions.
93 changes: 93 additions & 0 deletions src/shogun/mathematics/logdet/DirectLinearSolverComplex.cpp
@@ -0,0 +1,93 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General turalPublic License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2013 Soumyajit De
*/

#include <shogun/lib/config.h>

#ifdef HAVE_EIGEN3
#include <shogun/lib/SGVector.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/mathematics/eigen3.h>
#include <shogun/mathematics/logdet/DenseMatrixOperator.h>
#include <shogun/mathematics/logdet/DirectLinearSolverComplex.h>

using namespace Eigen;

namespace shogun
{

CDirectLinearSolverComplex::CDirectLinearSolverComplex()
: CLinearSolver<complex64_t>(),
m_type(DS_QR_NOPERM)
{
SG_GCDEBUG("%s created (%p)\n", this->get_name(), this)
}

CDirectLinearSolverComplex::CDirectLinearSolverComplex(EDirectSolverType type)
: CLinearSolver<complex64_t>(),
m_type(type)
{
SG_GCDEBUG("%s created (%p)\n", this->get_name(), this)
}

CDirectLinearSolverComplex::~CDirectLinearSolverComplex()
{
SG_GCDEBUG("%s destroyed (%p)\n", this->get_name(), this)
}

SGVector<complex64_t> CDirectLinearSolverComplex::solve(
CLinearOperator<complex64_t>* A, SGVector<complex64_t> b)
{
SG_DEBUG("Entering..\n");

SGVector<complex64_t> x(b.vlen);

REQUIRE(A, "Operator is NULL!\n");
REQUIRE(A->get_dim()==b.vlen, "Dimension mismatch!\n");

CDenseMatrixOperator<complex64_t> *op=
dynamic_cast<CDenseMatrixOperator<complex64_t>*>(A);
REQUIRE(op, "Operator is not CDenseMatrixOperator<complex64_t> type!\n");

SGMatrix<complex64_t> mat_A=op->get_matrix_operator();
Map<MatrixXcd> map_A(mat_A.matrix, mat_A.num_rows, mat_A.num_cols);
Map<VectorXcd> map_b(b.vector, b.vlen);
Map<VectorXcd> map_x(x.vector, x.vlen);

// rank checking for LLT
FullPivLU<MatrixXcd> lu(map_A);
bool full_rank=lu.rank()==mat_A.num_cols;

switch (m_type)
{
case DS_LLT:
if (full_rank)
map_x=map_A.llt().solve(map_b);
else
SG_WARNING("Matrix is not full rank!\n");
break;
case DS_QR_NOPERM:
map_x=map_A.householderQr().solve(map_b);
break;
case DS_QR_COLPERM:
map_x=map_A.colPivHouseholderQr().solve(map_b);
break;
case DS_QR_FULLPERM:
map_x=map_A.fullPivHouseholderQr().solve(map_b);
break;
case DS_SVD:
map_x=map_A.jacobiSvd(ComputeThinU|ComputeThinV).solve(map_b);
break;
};

SG_DEBUG("Leaving..\n");
return x;
}

}
#endif // HAVE_EIGEN3
74 changes: 74 additions & 0 deletions src/shogun/mathematics/logdet/DirectLinearSolverComplex.h
@@ -0,0 +1,74 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General turalPublic License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2013 Soumyajit De
*/
#ifndef DIRECT_LINEAR_SOLVER_COMPLEX_H_
#define DIRECT_LINEAR_SOLVER_COMPLEX_H_

#include <shogun/lib/config.h>

#ifdef HAVE_EIGEN3
#include <shogun/mathematics/logdet/LinearSolver.h>

namespace shogun
{

/** solver type for direct solvers */
enum EDirectSolverType
{
DS_LLT=0,
DS_QR_NOPERM=1,
DS_QR_COLPERM=2,
DS_QR_FULLPERM=3,
DS_SVD=4
};

/** @brief Class that provides a solve method for complex dense-matrix
* linear systems
*/
class CDirectLinearSolverComplex : public CLinearSolver<complex64_t>
{
public:
/** default constructor */
CDirectLinearSolverComplex();

/**
* constructor
*
* @param type the type of solver to be used in solve method
*/
CDirectLinearSolverComplex(EDirectSolverType type);

/** destructor */
virtual ~CDirectLinearSolverComplex();

/**
* solve method for solving complex linear systems
*
* @param A the linear operator of the system
* @param b the vector of the system
* @return the solution vector
*/
virtual SGVector<complex64_t> solve(
CLinearOperator<complex64_t>* A, SGVector<complex64_t> b);

/** @return object name */
virtual const char* get_name() const
{
return "CDirectLinearSolverComplex";
}

private:
/** the type of solver to be used in solve method */
const EDirectSolverType m_type;

};

}

#endif // HAVE_EIGEN3
#endif // DIRECT_LINEAR_SOLVER_COMPLEX_H_
73 changes: 73 additions & 0 deletions src/shogun/mathematics/logdet/LinearSolver.h
@@ -0,0 +1,73 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2013 Soumyajit De
*/
#ifndef LINEAR_SOLVER_H_
#define LINEAR_SOLVER_H_

#include <shogun/lib/config.h>

namespace shogun
{
template<class T> class SGVector;
template<class T> class CLinearOperator;

/** @brief Abstract template base class that provides an abstract solve method
* for linear systems, that takes a linear operator \f$A\f$, a vector \f$b\f$,
* solves the system \f$Ax=b\f$ and returns the vector \f$x\f$.
*/
template<class T> class CLinearSolver : public CSGObject
{
public:
/** default constructor */
CLinearSolver()
: CSGObject()
{
SG_GCDEBUG("%s created (%p)\n", this->get_name(), this)
}

/** destructor */
virtual ~CLinearSolver()
{
SG_GCDEBUG("%s destroyed (%p)\n", this->get_name(), this)
}

/**
* abstract solve method for solving linear systems
*
* @param A the linear operator of the system
* @param b the vector of the system
* @return the solution vector
*/
virtual SGVector<T> solve(CLinearOperator<T>* A, SGVector<T> b) = 0;

/** @return object name */
virtual const char* get_name() const
{
return "CLinearSolver";
}

};

template class CLinearSolver<bool>;
template class CLinearSolver<char>;
template class CLinearSolver<int8_t>;
template class CLinearSolver<uint8_t>;
template class CLinearSolver<int16_t>;
template class CLinearSolver<uint16_t>;
template class CLinearSolver<int32_t>;
template class CLinearSolver<uint32_t>;
template class CLinearSolver<int64_t>;
template class CLinearSolver<uint64_t>;
template class CLinearSolver<float32_t>;
template class CLinearSolver<float64_t>;
template class CLinearSolver<floatmax_t>;
template class CLinearSolver<complex64_t>;

}

#endif // LINEAR_SOLVER_H_
127 changes: 127 additions & 0 deletions tests/unit/mathematics/logdet/DirectLinearSolverComplex_unittest.cc
@@ -0,0 +1,127 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2013 Soumyajit De
*/

#include <shogun/lib/config.h>

#ifdef HAVE_EIGEN3
#include <shogun/lib/SGVector.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/mathematics/eigen3.h>
#include <shogun/mathematics/logdet/DenseMatrixOperator.h>
#include <shogun/mathematics/logdet/DirectLinearSolverComplex.h>
#include <gtest/gtest.h>

using namespace shogun;
using namespace Eigen;

TEST(DirectLinearSolverComplex, solve_SVD)
{
const index_t size=2;
SGMatrix<complex64_t> m(size, size);
m(0,0)=complex64_t(2.0);
m(0,1)=complex64_t(1.0, 2.0);
m(1,0)=complex64_t(1.0, 2.0);
m(1,1)=complex64_t(3.0);

CDenseMatrixOperator<complex64_t>* A=new CDenseMatrixOperator<complex64_t>(m);
SG_REF(A);
SGVector<complex64_t> b(size);
b.set_const(complex64_t(1.0));

CDirectLinearSolverComplex solver(DS_SVD);
SGVector<complex64_t> x=solver.solve((CLinearOperator<complex64_t>*)A, b);

SGVector<complex64_t> bp=A->apply(x);
Map<VectorXcd> map_b(b.vector, b.vlen);
Map<VectorXcd> map_bp(bp.vector, bp.vlen);

EXPECT_NEAR((map_b-map_bp).norm()/map_b.norm(), 0.0, 1E-15);

SG_UNREF(A);
}

TEST(DirectLinearSolverComplex, solve_QR_NOPERM)
{
const index_t size=2;
SGMatrix<complex64_t> m(size, size);
m(0,0)=complex64_t(2.0);
m(0,1)=complex64_t(1.0, 2.0);
m(1,0)=complex64_t(1.0, 2.0);
m(1,1)=complex64_t(3.0);

CDenseMatrixOperator<complex64_t>* A=new CDenseMatrixOperator<complex64_t>(m);
SG_REF(A);
SGVector<complex64_t> b(size);
b.set_const(complex64_t(1.0));

CDirectLinearSolverComplex solver(DS_QR_NOPERM);
SGVector<complex64_t> x=solver.solve((CLinearOperator<complex64_t>*)A, b);

SGVector<complex64_t> bp=A->apply(x);
Map<VectorXcd> map_b(b.vector, b.vlen);
Map<VectorXcd> map_bp(bp.vector, bp.vlen);

EXPECT_NEAR((map_b-map_bp).norm()/map_b.norm(), 0.0, 1E-15);

SG_UNREF(A);
}

TEST(DirectLinearSolverComplex, solve_QR_COLPERM)
{
const index_t size=2;
SGMatrix<complex64_t> m(size, size);
m(0,0)=complex64_t(2.0);
m(0,1)=complex64_t(1.0, 2.0);
m(1,0)=complex64_t(1.0, 2.0);
m(1,1)=complex64_t(3.0);

CDenseMatrixOperator<complex64_t>* A=new CDenseMatrixOperator<complex64_t>(m);
SG_REF(A);
SGVector<complex64_t> b(size);
b.set_const(complex64_t(1.0));

CDirectLinearSolverComplex solver(DS_QR_COLPERM);
SGVector<complex64_t> x=solver.solve((CLinearOperator<complex64_t>*)A, b);

SGVector<complex64_t> bp=A->apply(x);
Map<VectorXcd> map_b(b.vector, b.vlen);
Map<VectorXcd> map_bp(bp.vector, bp.vlen);

EXPECT_NEAR((map_b-map_bp).norm()/map_b.norm(), 0.0, 1E-15);

SG_UNREF(A);
}

TEST(DirectLinearSolverComplex, solve_QR_FULLPERM)
{
const index_t size=2;
SGMatrix<complex64_t> m(size, size);
m(0,0)=complex64_t(2.0);
m(0,1)=complex64_t(1.0, 2.0);
m(1,0)=complex64_t(1.0, 2.0);
m(1,1)=complex64_t(3.0);

CDenseMatrixOperator<complex64_t>* A=new CDenseMatrixOperator<complex64_t>(m);
SG_REF(A);
SGVector<complex64_t> b(size);
b.set_const(complex64_t(1.0));

CDirectLinearSolverComplex solver(DS_QR_FULLPERM);
SGVector<complex64_t> x=solver.solve((CLinearOperator<complex64_t>*)A, b);

SGVector<complex64_t> bp=A->apply(x);
Map<VectorXcd> map_b(b.vector, b.vlen);
Map<VectorXcd> map_bp(bp.vector, bp.vlen);

EXPECT_NEAR((map_b-map_bp).norm()/map_b.norm(), 0.0, 1E-15);

SG_UNREF(A);
}

#endif //HAVE_EIGEN3

0 comments on commit 84258f4

Please sign in to comment.