Merge pull request #150 from wolfv/fix_svd_horizontal_vertical
fix svd for horizontal and vertical
JohanMabille committed Mar 19, 2020
2 parents 0ab6257 + 327b385 commit c579f74
Showing 4 changed files with 100 additions and 22 deletions.
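Note on the change: LAPACK routines such as gesdd, geqrf, orgqr, and ungqr require the leading dimension of a column-major m-by-n input to satisfy lda >= max(1, m). xtensor reports a stride of 0 for axes of length 1, so the previously used stride_back(A) could evaluate to 0 for row and column vectors (shapes like 1x3 or 3x1) and those calls would fail; this commit passes an explicit a_stride = max(1, m) instead, and adds an assert_nd_square check to the routines that only accept square input. A minimal sketch of the leading-dimension rule (illustrative only; leading_dimension is a hypothetical helper, not code from this commit):

    #include <algorithm>
    #include <cstddef>

    // For a column-major m x n matrix, LAPACK requires lda >= max(1, m),
    // so the leading dimension must never be 0, even for 1 x n inputs.
    inline int leading_dimension(std::size_t m)
    {
        return std::max<int>(1, static_cast<int>(m));
    }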
19 changes: 18 additions & 1 deletion include/xtensor-blas/xblas_utils.hpp
@@ -185,9 +185,26 @@ namespace xt
         std::false_type is_xfunction_impl(...);
     }
 
-    template<typename T>
+    template<class T>
     constexpr bool is_xfunction(T&& t) {
         return decltype(detail::is_xfunction_impl(t))::value;
     }
+
+    /***********************************
+     * assert_nd_square implementation *
+     ***********************************/
+
+    template <class T>
+#if !defined(_MSC_VER) || _MSC_VER >= 1910
+    constexpr
+#endif
+    void assert_nd_square(const xexpression<T>& t)
+    {
+        auto& dt = t.derived_cast();
+        if (dt.shape()[dt.dimension() - 1] != dt.shape()[dt.dimension() - 2])
+        {
+            throw std::runtime_error("Last 2 dimensions of the array must be square.");
+        }
+    }
 }
 #endif
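For reference, a minimal usage sketch of the new check (hypothetical snippet, not part of the commit; assumes the xtensor and xtensor-blas headers are available):

    #include <xtensor/xarray.hpp>
    #include <xtensor/xbuilder.hpp>
    #include <xtensor-blas/xblas_utils.hpp>

    int main()
    {
        xt::xarray<double> sq = xt::ones<double>({2, 2});
        xt::assert_nd_square(sq);        // OK: last two dimensions are equal

        xt::xarray<double> rect = xt::ones<double>({3, 1});
        xt::assert_nd_square(rect);      // throws std::runtime_error
    }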
47 changes: 29 additions & 18 deletions include/xtensor-blas/xlapack.hpp
@@ -89,12 +89,15 @@ namespace lapack
             n = static_cast<blas_index_t>(A.shape()[1]);
         }
 
+        blas_index_t m = static_cast<blas_index_t>(A.shape()[0]);
+        blas_index_t a_stride = std::max(blas_index_t(1), m);
+
         int info = cxxlapack::orgqr<blas_index_t>(
-            static_cast<blas_index_t>(A.shape()[0]),
+            m,
             n,
             static_cast<blas_index_t>(tau.size()),
             A.data(),
-            stride_back(A),
+            a_stride,
             tau.data(),
             work.data(),
             static_cast<blas_index_t>(-1)
@@ -108,11 +111,11 @@ namespace lapack
         work.resize(static_cast<std::size_t>(work[0]));
 
         info = cxxlapack::orgqr<blas_index_t>(
-            static_cast<blas_index_t>(A.shape()[0]),
+            m,
             n,
             static_cast<blas_index_t>(tau.size()),
             A.data(),
-            stride_back(A),
+            a_stride,
             tau.data(),
             work.data(),
             static_cast<blas_index_t>(work.size())
@@ -133,12 +136,15 @@
             n = static_cast<blas_index_t>(A.shape()[1]);
         }
 
+        blas_index_t m = static_cast<blas_index_t>(A.shape()[0]);
+        blas_index_t a_stride = std::max(blas_index_t(1), m);
+
         int info = cxxlapack::ungqr<blas_index_t>(
-            static_cast<blas_index_t>(A.shape()[0]),
+            m,
             n,
             static_cast<blas_index_t>(tau.size()),
             A.data(),
-            stride_back(A),
+            a_stride,
             tau.data(),
             work.data(),
             static_cast<blas_index_t>(-1)
@@ -152,11 +158,11 @@
         work.resize(static_cast<std::size_t>(std::real(work[0])));
 
         info = cxxlapack::ungqr<blas_index_t>(
-            static_cast<blas_index_t>(A.shape()[0]),
+            m,
             n,
             static_cast<blas_index_t>(tau.size()),
             A.data(),
-            stride_back(A),
+            a_stride,
             tau.data(),
             work.data(),
             static_cast<blas_index_t>(work.size())
@@ -174,12 +180,14 @@
         XTENSOR_ASSERT(A.layout() == layout_type::column_major);
 
         uvector<value_type> work(1);
+        blas_index_t m = static_cast<blas_index_t>(A.shape()[0]);
+        blas_index_t a_stride = std::max(blas_index_t(1), m);
 
         int info = cxxlapack::geqrf<blas_index_t>(
-            static_cast<blas_index_t>(A.shape()[0]),
+            m,
             static_cast<blas_index_t>(A.shape()[1]),
             A.data(),
-            stride_back(A),
+            a_stride,
             tau.data(),
             work.data(),
             static_cast<blas_index_t>(-1)
@@ -193,10 +201,10 @@
         work.resize(static_cast<std::size_t>(std::real(work[0])));
 
         info = cxxlapack::geqrf<blas_index_t>(
-            static_cast<blas_index_t>(A.shape()[0]),
+            m,
             static_cast<blas_index_t>(A.shape()[1]),
             A.data(),
-            stride_back(A),
+            a_stride,
             tau.data(),
             work.data(),
             static_cast<blas_index_t>(work.size())
@@ -241,7 +249,9 @@ namespace lapack
                 return m >= n ? std::make_pair(1, stride_back(vt)) :
                                 std::make_pair(stride_back(u), 1);
             }
-            return std::make_pair(stride_back(u), stride_back(vt));
+
+            return std::make_pair(std::max(blas_index_t(u.shape()[0]), 1),
+                                  std::max(blas_index_t(vt.shape()[0]), 1));
         }
     }
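A worked case for the init_u_vt change above (reasoning, not code from the commit): xtensor reports a stride of 0 for axes of length 1, so stride_back(u) or stride_back(vt) could return 0 for the degenerate shapes this PR fixes, while LAPACK requires every leading dimension to be at least 1. Deriving the strides from the first extents instead, a 1x3 input with full matrices requested gives u of shape 1x1 and vt of shape 3x3, so the returned pair becomes (max(1, 1), max(3, 1)) = (1, 3).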

@@ -269,13 +279,14 @@ namespace lapack
         std::tie(u_stride, vt_stride) = detail::init_u_vt(u, vt, jobz, m, n);
 
         uvector<blas_index_t> iwork(8 * std::min(m, n));
+        blas_index_t a_stride = static_cast<blas_index_t>(std::max(std::size_t(1), m));
 
         int info = cxxlapack::gesdd<blas_index_t>(
             jobz,
             static_cast<blas_index_t>(A.shape()[0]),
             static_cast<blas_index_t>(A.shape()[1]),
             A.data(),
-            stride_back(A),
+            a_stride,
             s.data(),
             u.data(),
             u_stride,
@@ -292,13 +303,12 @@
         }
 
         work.resize(static_cast<std::size_t>(work[0]));
-
         info = cxxlapack::gesdd<blas_index_t>(
             jobz,
             static_cast<blas_index_t>(A.shape()[0]),
             static_cast<blas_index_t>(A.shape()[1]),
             A.data(),
-            stride_back(A),
+            a_stride,
             s.data(),
             u.data(),
             u_stride,
@@ -355,13 +365,14 @@
 
         blas_index_t u_stride, vt_stride;
         std::tie(u_stride, vt_stride) = detail::init_u_vt(u, vt, jobz, m, n);
+        blas_index_t a_stride = static_cast<blas_index_t>(std::max(std::size_t(1), m));
 
         int info = cxxlapack::gesdd<blas_index_t>(
             jobz,
             static_cast<blas_index_t>(A.shape()[0]),
             static_cast<blas_index_t>(A.shape()[1]),
             A.data(),
-            stride_back(A),
+            a_stride,
             s.data(),
             u.data(),
             u_stride,
@@ -384,7 +395,7 @@
             static_cast<blas_index_t>(A.shape()[0]),
             static_cast<blas_index_t>(A.shape()[1]),
             A.data(),
-            stride_back(A),
+            a_stride,
             s.data(),
             u.data(),
             u_stride,
20 changes: 18 additions & 2 deletions include/xtensor-blas/xlinalg.hpp
@@ -226,6 +226,7 @@ namespace linalg
     template <class E1, class E2>
     auto solve(const xexpression<E1>& A, const xexpression<E2>& b)
     {
+        assert_nd_square(A);
         auto dA = copy_to_layout<layout_type::column_major>(A.derived_cast());
         auto db = copy_to_layout<layout_type::column_major>(b.derived_cast());
 
@@ -248,6 +249,7 @@
     template <class E1>
     auto inv(const xexpression<E1>& A)
     {
+        assert_nd_square(A);
         auto dA = copy_to_layout<layout_type::column_major>(A.derived_cast());
 
         uvector<blas_index_t> piv(std::min(dA.shape()[0], dA.shape()[1]));
@@ -299,6 +301,7 @@ namespace linalg
         using underlying_type = typename E::value_type;
         using value_type = typename E::value_type;
 
+        assert_nd_square(A);
         auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
 
         std::size_t N = M.shape()[0];
@@ -348,6 +351,7 @@
     {
         using value_type = typename E::value_type;
 
+        assert_nd_square(A);
         auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
 
         std::size_t N = M.shape()[0];
@@ -380,6 +384,7 @@
     {
         using value_type = typename E::value_type;
 
+        assert_nd_square(A);
         auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
 
         std::size_t N = M.shape()[0];
@@ -401,6 +406,7 @@
         using value_type = typename E::value_type;
         using underlying_value_type = typename value_type::value_type;
 
+        assert_nd_square(A);
         auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
 
         std::size_t N = M.shape()[0];
@@ -424,10 +430,11 @@
      * @return xtensor containing the eigenvalues.
      */
     template <class E, std::enable_if_t<!xtl::is_complex<typename E::value_type>::value>* = nullptr>
-    auto eigh(const xexpression<E>& A, const xexpression<E>& B,const char UPLO = 'L')
+    auto eigh(const xexpression<E>& A, const xexpression<E>& B, const char UPLO = 'L')
     {
         using value_type = typename E::value_type;
 
+        assert_nd_square(A);
         auto M1 = copy_to_layout<layout_type::column_major>(A.derived_cast());
         auto M2 = copy_to_layout<layout_type::column_major>(B.derived_cast());
 
@@ -478,6 +485,7 @@
     {
         using value_type = typename E::value_type;
 
+        assert_nd_square(A);
         auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
 
         std::size_t N = M.shape()[0];
@@ -511,6 +519,7 @@
     {
         using value_type = typename E::value_type;
 
+        assert_nd_square(A);
         auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
 
         std::size_t N = M.shape()[0];
@@ -545,6 +554,7 @@
     {
         using value_type = typename E::value_type;
 
+        assert_nd_square(A);
         auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
 
         std::size_t N = M.shape()[0];
@@ -566,6 +576,7 @@
         using value_type = typename E::value_type;
         using underlying_value_type = typename value_type::value_type;
 
+        assert_nd_square(A);
         auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
 
         std::size_t N = M.shape()[0];
@@ -989,8 +1000,9 @@ namespace linalg
     auto det(const xexpression<T>& A)
     {
         using value_type = typename T::value_type;
-        xtensor<value_type, 2, layout_type::column_major> LU = A.derived_cast();
+        assert_nd_square(A);
 
+        xtensor<value_type, 2, layout_type::column_major> LU = A.derived_cast();
         uvector<blas_index_t> piv(std::min(LU.shape()[0], LU.shape()[1]));
 
         lapack::getrf(LU, piv);
@@ -1025,6 +1037,7 @@ namespace linalg
     auto slogdet(const xexpression<T>& A)
     {
         using value_type = typename T::value_type;
+        assert_nd_square(A);
 
         xtensor<value_type, 2, layout_type::column_major> LU = A.derived_cast();
         uvector<blas_index_t> piv(std::min(LU.shape()[0], LU.shape()[1]));
@@ -1059,6 +1072,8 @@ namespace linalg
    auto slogdet(const xexpression<T>& A)
    {
        using value_type = typename T::value_type;
+        assert_nd_square(A);
+
        xtensor<value_type, 2, layout_type::column_major> LU = A.derived_cast();
        uvector<blas_index_t> piv(std::min(LU.shape()[0], LU.shape()[1]));
 
@@ -1214,6 +1229,7 @@ namespace linalg
     template <class T>
     auto cholesky(const xexpression<T>& A)
     {
+        assert_nd_square(A);
         auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
 
         int info = lapack::potr(M, 'L');
36 changes: 35 additions & 1 deletion test/test_linalg.cpp
@@ -134,6 +134,19 @@ namespace xt
         EXPECT_TRUE(allclose(std::get<2>(res), expected_2));
     }
 
+    TEST(xlinalg, svd_horizontal_vertical)
+    {
+        xarray<double> a = xt::ones<double>({3, 1});
+        xarray<double> b = xt::ones<double>({1, 3});
+        xarray<double> u, s, vt;
+
+        std::tie(u, s, vt) = linalg::svd(a, false);
+        EXPECT_TRUE(allclose(a, xt::linalg::dot(u * s, vt)));
+
+        std::tie(u, s, vt) = linalg::svd(b, false);
+        EXPECT_TRUE(allclose(b, xt::linalg::dot(u * s, vt)));
+    }
+
     TEST(xlinalg, matrix_rank)
     {
         xarray<double> eall = eye<double>(4);
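A note on the reconstruction check in the new test: with full_matrices == false, the economy SVD of the 3x1 matrix a yields u of shape 3x1, s with a single singular value, and vt of shape 1x1, so u * s broadcasts the singular value across u and linalg::dot(u * s, vt) rebuilds a; the 1x3 case b works symmetrically, with u of shape 1x1 and vt of shape 1x3.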
@@ -590,8 +603,29 @@ namespace xt
 
         auto res = xt::linalg::dot(A1, A2);
         EXPECT_EQ(res(), 94);
     }
 
-
+    TEST(xlinalg, asserts)
+    {
+        EXPECT_THROW(xt::linalg::eigh(xt::ones<double>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::eig(xt::ones<double>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::solve(xt::ones<double>({3, 1}), xt::ones<double>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::inv(xt::ones<double>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::eigvals(xt::ones<double>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::eigvalsh(xt::ones<double>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::det(xt::ones<double>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::slogdet(xt::ones<double>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::cholesky(xt::ones<double>({3, 1})), std::runtime_error);
+
+        EXPECT_THROW(xt::linalg::eigh(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::eig(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::solve(xt::ones<std::complex<double>>({3, 1}), xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::inv(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::eigvals(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::eigvalsh(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::det(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::slogdet(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
+        EXPECT_THROW(xt::linalg::cholesky(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
+    }

}
