From 1f2ae251c5feda250e3d7d12894e6be3c09fd268 Mon Sep 17 00:00:00 2001 From: Justin Carpentier Date: Thu, 5 Aug 2021 10:25:04 +0200 Subject: [PATCH] cholesky: support solve with input matrices --- include/eigenpy/decompositions/LDLT.hpp | 17 ++++++++++------- include/eigenpy/decompositions/LLT.hpp | 17 ++++++++++------- unittest/python/test_LDLT.py | 7 ++++++- unittest/python/test_LLT.py | 8 ++++++-- 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/include/eigenpy/decompositions/LDLT.hpp b/include/eigenpy/decompositions/LDLT.hpp index 6c30ae645..1e4dbab8b 100644 --- a/include/eigenpy/decompositions/LDLT.hpp +++ b/include/eigenpy/decompositions/LDLT.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2020 INRIA + * Copyright 2020-2021 INRIA */ #ifndef __eigenpy_decomposition_ldlt_hpp__ @@ -23,7 +23,8 @@ namespace eigenpy typedef _MatrixType MatrixType; typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::RealScalar RealScalar; - typedef Eigen::Matrix VectorType; + typedef Eigen::Matrix VectorXs; + typedef Eigen::Matrix MatrixXs; typedef Eigen::LDLT Solver; template @@ -55,7 +56,7 @@ namespace eigenpy "Returns the LDLT decomposition matrix.", bp::return_internal_reference<>()) - .def("rankUpdate",(Solver & (Solver::*)(const Eigen::MatrixBase &, const RealScalar &))&Solver::template rankUpdate, + .def("rankUpdate",(Solver & (Solver::*)(const Eigen::MatrixBase &, const RealScalar &))&Solver::template rankUpdate, bp::args("self","vector","sigma"), bp::return_self<>()) @@ -78,8 +79,10 @@ namespace eigenpy #endif .def("reconstructedMatrix",&Solver::reconstructedMatrix,bp::arg("self"), "Returns the matrix represented by the decomposition, i.e., it returns the product: L L^*. This function is provided for debug purpose.") - .def("solve",&solve,bp::args("self","b"), + .def("solve",&solve,bp::args("self","b"), "Returns the solution x of A x = b using the current decomposition of A.") + .def("solve",&solve,bp::args("self","B"), + "Returns the solution X of A X = B using the current decomposition of A where B is a right hand side matrix.") .def("setZero",&Solver::setZero,bp::arg("self"), "Clear any existing decomposition.") @@ -107,7 +110,7 @@ namespace eigenpy static MatrixType matrixL(const Solver & self) { return self.matrixL(); } static MatrixType matrixU(const Solver & self) { return self.matrixU(); } - static VectorType vectorD(const Solver & self) { return self.vectorD(); } + static VectorXs vectorD(const Solver & self) { return self.vectorD(); } static MatrixType transpositionsP(const Solver & self) { @@ -115,8 +118,8 @@ namespace eigenpy self.matrixL().rows()); } - template - static VectorType solve(const Solver & self, const VectorType & vec) + template + static MatrixOrVector solve(const Solver & self, const MatrixOrVector & vec) { return self.solve(vec); } diff --git a/include/eigenpy/decompositions/LLT.hpp b/include/eigenpy/decompositions/LLT.hpp index ba8c2f2b7..a06e93724 100644 --- a/include/eigenpy/decompositions/LLT.hpp +++ b/include/eigenpy/decompositions/LLT.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2020 INRIA + * Copyright 2020-2021 INRIA */ #ifndef __eigenpy_decomposition_llt_hpp__ @@ -23,7 +23,8 @@ namespace eigenpy typedef _MatrixType MatrixType; typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::RealScalar RealScalar; - typedef Eigen::Matrix VectorType; + typedef Eigen::Matrix VectorXs; + typedef Eigen::Matrix MatrixXs; typedef Eigen::LLT Solver; template @@ -46,10 +47,10 @@ namespace eigenpy bp::return_internal_reference<>()) #if EIGEN_VERSION_AT_LEAST(3,3,90) - .def("rankUpdate",(Solver& (Solver::*)(const VectorType &, const RealScalar &))&Solver::template rankUpdate, + .def("rankUpdate",(Solver& (Solver::*)(const VectorXs &, const RealScalar &))&Solver::template rankUpdate, bp::args("self","vector","sigma"), bp::return_self<>()) #else - .def("rankUpdate",(Solver (Solver::*)(const VectorType &, const RealScalar &))&Solver::template rankUpdate, + .def("rankUpdate",(Solver (Solver::*)(const VectorXs &, const RealScalar &))&Solver::template rankUpdate, bp::args("self","vector","sigma")) #endif @@ -72,8 +73,10 @@ namespace eigenpy #endif .def("reconstructedMatrix",&Solver::reconstructedMatrix,bp::arg("self"), "Returns the matrix represented by the decomposition, i.e., it returns the product: L L^*. This function is provided for debug purpose.") - .def("solve",&solve,bp::args("self","b"), + .def("solve",&solve,bp::args("self","b"), "Returns the solution x of A x = b using the current decomposition of A.") + .def("solve",&solve,bp::args("self","B"), + "Returns the solution X of A X = B using the current decomposition of A where B is a right hand side matrix.") ; } @@ -99,8 +102,8 @@ namespace eigenpy static MatrixType matrixL(const Solver & self) { return self.matrixL(); } static MatrixType matrixU(const Solver & self) { return self.matrixU(); } - template - static VectorType solve(const Solver & self, const VectorType & vec) + template + static MatrixOrVector solve(const Solver & self, const MatrixOrVector & vec) { return self.solve(vec); } diff --git a/unittest/python/test_LDLT.py b/unittest/python/test_LDLT.py index 7937aaec6..1cdb930c7 100644 --- a/unittest/python/test_LDLT.py +++ b/unittest/python/test_LDLT.py @@ -1,5 +1,4 @@ import eigenpy -eigenpy.switchToNumpyArray() import numpy as np import numpy.linalg as la @@ -16,3 +15,9 @@ P = ldlt.transpositionsP() assert eigenpy.is_approx(np.transpose(P).dot(L.dot(np.diag(D).dot(np.transpose(L).dot(P)))),A) + +X = np.random.rand(dim,20) +B = A.dot(X) +X_est = ldlt.solve(B) +assert eigenpy.is_approx(X,X_est) +assert eigenpy.is_approx(A.dot(X_est),B) diff --git a/unittest/python/test_LLT.py b/unittest/python/test_LLT.py index 4ad38fc21..2aac689ae 100644 --- a/unittest/python/test_LLT.py +++ b/unittest/python/test_LLT.py @@ -1,5 +1,4 @@ import eigenpy -eigenpy.switchToNumpyArray() import numpy as np import numpy.linalg as la @@ -12,5 +11,10 @@ llt = eigenpy.LLT(A) L = llt.matrixL() - assert eigenpy.is_approx(L.dot(np.transpose(L)),A) + +X = np.random.rand(dim,20) +B = A.dot(X) +X_est = llt.solve(B) +assert eigenpy.is_approx(X,X_est) +assert eigenpy.is_approx(A.dot(X_est),B)