Skip to content

Commit

Permalink
Merge pull request #92 from wolfv/fix_qr_lstsq
Browse files Browse the repository at this point in the history
fixes for lstsq and qr + more and generated tests
  • Loading branch information
JohanMabille committed Nov 19, 2018
2 parents 5d2cb7d + 6799920 commit 3414dd2
Show file tree
Hide file tree
Showing 8 changed files with 1,069 additions and 75 deletions.
6 changes: 5 additions & 1 deletion include/xtensor-blas/xlapack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,11 @@ namespace lapack
uvector<blas_index_t> iwork(1);

blas_index_t b_dim = b.dimension() > 1 ? static_cast<blas_index_t>(b.shape().back()) : 1;
blas_index_t b_stride = b_dim == 1 ? static_cast<blas_index_t>(b.shape().front()) : stride_back(b);

std::size_t m = A.shape()[0];
std::size_t n = A.shape()[1];

blas_index_t b_stride = (blas_index_t) std::max(std::max(std::size_t(1), m), n);

int info = cxxlapack::gelsd<blas_index_t>(
static_cast<blas_index_t>(A.shape()[0]),
Expand Down
120 changes: 80 additions & 40 deletions include/xtensor-blas/xlinalg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,21 @@ namespace linalg
}
}

namespace xblas_detail
{
template <class T>
inline void triu_inplace(T& R)
{
for (std::size_t i = 0; i < R.shape()[0]; ++i)
{
for (std::size_t j = 0; j < i && j < R.shape()[1]; j++)
{
R(i, j) = 0;
}
}
}
}

/// Select the mode for the qr decomposition ``K = min(M, K)``
enum class qrmode {
reduced, ///< return Q, R with dimensions (M, K), (K, N) (default)
Expand All @@ -1079,60 +1094,59 @@ namespace linalg
auto qr(const xexpression<T>& A, qrmode mode = qrmode::reduced)
{
using value_type = typename T::value_type;
using xtype = xtensor<value_type, 2, layout_type::column_major>;
using result_xtype = xtensor<value_type, 2>;
using xtype = xarray<value_type, layout_type::column_major>;

xtype R = A.derived_cast();

std::size_t M = R.shape()[0];
std::size_t N = R.shape()[1];
std::size_t K = std::min(M, N);

std::array<std::size_t, 2> tau_shp = {K, 1};
xtype tau(tau_shp);
auto tau = xarray<value_type, layout_type::column_major>::from_shape({K});
int info = lapack::geqrf(R, tau);

if (info != 0)
{
throw std::runtime_error("QR decomposition failed.");
}

xtype Q;
// explicitly set shape/size == 0!
auto Q = xtype::from_shape({0});

if (mode == qrmode::raw)
if (mode == qrmode::r)
{
return std::make_tuple(R, tau);
R = xt::view(R, range(0, K), all());
xblas_detail::triu_inplace(R);
return std::make_tuple(Q, R);
}

if (mode == qrmode::reduced)
if (mode == qrmode::raw)
{
Q = R;
detail::call_gqr(Q, tau, static_cast<blas_index_t>(K));
auto vR = view(R, range(std::size_t(0), K), all());
R = vR;
R = transpose(R);
return std::make_tuple(R, tau);
}
if (mode == qrmode::complete)

blas_index_t mc;

if (mode == qrmode::complete && M > N)
{
mc = (blas_index_t) M;
Q.resize({M, M});
// TODO replace with assignment to view
for (std::size_t i = 0; i < R.shape()[0]; ++i)
{
for (std::size_t j = 0; j < R.shape()[1]; ++j)
{
Q(i, j) = R(i, j);
}
}
detail::call_gqr(Q, tau, static_cast<blas_index_t>(M));
}

for (std::size_t i = 0; i < R.shape()[0]; ++i)
else
{
for (std::size_t j = 0; j < i && j < R.shape()[1]; j++)
{
R(i, j) = 0;
}
mc = (blas_index_t) K;
Q.resize({M, N});
}

xt::view(Q, all(), range(0, N)) = R;
detail::call_gqr(Q, tau, mc);

Q = xt::view(Q, all(), range(0, mc));
R = xt::view(R, range(0, mc), all());

xblas_detail::triu_inplace(R);

return std::make_tuple(Q, R);
}

Expand Down Expand Up @@ -1420,22 +1434,44 @@ namespace linalg
using underlying_value_type = xtl::complex_value_type_t<typename T::value_type>;

xtensor<value_type, 2, layout_type::column_major> dA = A.derived_cast();
xtensor<value_type, 2, layout_type::column_major> db;

const auto& db_t = b.derived_cast();
if (db_t.dimension() == 1)
std::size_t M = dA.shape()[0];
std::size_t N = dA.shape()[1];

auto& b_ref = b.derived_cast();

if (dA.dimension() != 2)
{
std::size_t sz = db_t.shape()[0];
db.resize({sz, 1});
std::copy(db_t.storage().begin(), db_t.storage().end(), db.storage().begin());
throw std::runtime_error("Expected 2D expression for A");
}
else

if (!(b_ref.dimension() <= 2))
{
db = db_t;
throw std::runtime_error("Expected 1- or 2D expression for A.");
}

std::size_t M = dA.shape()[0];
std::size_t N = dA.shape()[1];
if (b_ref.shape()[0] != M)
{
throw std::runtime_error("Shape of 'b' for lstsq does not match.");
}

// find number of rhs
std::size_t nrhs = (b_ref.dimension() == 1) ? 1 : b_ref.shape()[1];

// as the dgelsd docs say, on entry it's M-by-nrhs, then result N-by-nrhs
// that is why we need to allocate *MORE* space than just b here for M > N
auto db = xarray<value_type, layout_type::column_major>::from_shape({ std::max(M, N), nrhs });

bool is_1d = false;
if (b_ref.dimension() == 1)
{
is_1d = true;
xt::view(db, range(0, M), xt::all()) = xt::view(b_ref, xt::all(), xt::newaxis());
}
else
{
xt::view(db, range(0, M), xt::all()) = b_ref;
}

auto s = xtensor<underlying_value_type, 1, layout_type::column_major>::from_shape({ std::min(M, N) });

Expand All @@ -1460,8 +1496,12 @@ namespace linalg
}
}

auto vdb = view(db, range(std::size_t(0), N));
auto vdb = view(db, range(std::size_t(0), N), xt::all());
db = vdb;
if (is_1d)
{
db = xt::squeeze(db);
}

return std::make_tuple(db, residuals, rank, s);
}
Expand Down Expand Up @@ -1685,4 +1725,4 @@ namespace linalg
}
}
}
#endif
#endif
168 changes: 168 additions & 0 deletions test/test_generator/cppy_source/test_lstsq.cppy
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/***************************************************************************
* Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
* *
* Distributed under the terms of the BSD 3-Clause License. *
* *
* The full license is in the file LICENSE, distributed with this software. *
****************************************************************************/

#include <algorithm>

#include "gtest/gtest.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xfixed.hpp"
#include "xtensor/xnoalias.hpp"
#include "xtensor/xstrided_view.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xview.hpp"

#include "xtensor-blas/xlinalg.hpp"

namespace xt
{
using namespace xt::placeholders;

/*py
a = np.random.random((6, 3))
b = np.ones((6))
*/
TEST(xtest_extended, lstsq1)
{
// py_a
// py_b
// py_res0 = np.linalg.lstsq(a, b)[0]
// py_res1 = np.linalg.lstsq(a, b)[1]
// py_res2 = np.linalg.lstsq(a, b)[2]
// py_res3 = np.linalg.lstsq(a, b)[3]

auto xres = xt::linalg::lstsq(py_a, py_b);
EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
EXPECT_EQ(std::get<2>(xres), py_res2);
EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
}

/*py
a = np.random.random((3, 3))
b = np.ones((3))
*/
TEST(xtest_extended, lstsq20)
{
// py_a
// py_b
// py_res0 = np.linalg.lstsq(a, b)[0]
// py_res1 = np.linalg.lstsq(a, b)[1]
// py_res2 = np.linalg.lstsq(a, b)[2]
// py_res3 = np.linalg.lstsq(a, b)[3]

auto xres = xt::linalg::lstsq(py_a, py_b);

EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
EXPECT_EQ(std::get<2>(xres), py_res2);
EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
}

/*py
a = np.random.random((3, 3))
b = np.ones((3, 3))
*/
TEST(xtest_extended, lstsq21)
{
// py_a
// py_b
// py_res0 = np.linalg.lstsq(a, b)[0]
// py_res1 = np.linalg.lstsq(a, b)[1]
// py_res2 = np.linalg.lstsq(a, b)[2]
// py_res3 = np.linalg.lstsq(a, b)[3]

auto xres = xt::linalg::lstsq(py_a, py_b);

EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
EXPECT_EQ(std::get<2>(xres), py_res2);
EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
}

/*py
a = np.random.random((2, 5))
b = np.ones((2))
*/
TEST(xtest_extended, lstsq3)
{
// py_a
// py_b
// py_res0 = np.linalg.lstsq(a, b)[0]
// py_res1 = np.linalg.lstsq(a, b)[1]
// py_res2 = np.linalg.lstsq(a, b)[2]
// py_res3 = np.linalg.lstsq(a, b)[3]

auto xres = xt::linalg::lstsq(py_a, py_b);
EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
EXPECT_EQ(std::get<2>(xres), py_res2);
EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
}

/*py
a = np.random.random((2, 5))
b = np.ones((2, 10))
*/
TEST(xtest_extended, lstsq4)
{
// py_a
// py_b
// py_res0 = np.linalg.lstsq(a, b)[0]
// py_res1 = np.linalg.lstsq(a, b)[1]
// py_res2 = np.linalg.lstsq(a, b)[2]
// py_res3 = np.linalg.lstsq(a, b)[3]

auto xres = xt::linalg::lstsq(py_a, py_b);
EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
EXPECT_EQ(std::get<2>(xres), py_res2);
EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
}

/*py
a = np.random.random((10, 5))
b = np.ones((10, 20))
*/
TEST(xtest_extended, lstsq5)
{
// py_a
// py_b
// py_res0 = np.linalg.lstsq(a, b)[0]
// py_res1 = np.linalg.lstsq(a, b)[1]
// py_res2 = np.linalg.lstsq(a, b)[2]
// py_res3 = np.linalg.lstsq(a, b)[3]

auto xres = xt::linalg::lstsq(py_a, py_b);
EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
EXPECT_EQ(std::get<2>(xres), py_res2);
EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
}

/*py
a = np.array([[0., 1.]])
b = np.array([1.])
*/
TEST(xtest_extended, lstsq6)
{
// py_a
// py_b
// py_res0 = np.linalg.lstsq(a, b)[0]
// py_res1 = np.linalg.lstsq(a, b)[1]
// py_res2 = np.linalg.lstsq(a, b)[2]
// py_res3 = np.linalg.lstsq(a, b)[3]

auto xres = xt::linalg::lstsq(py_a, py_b);
EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
EXPECT_EQ(std::get<2>(xres), py_res2);
EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
}


}

0 comments on commit 3414dd2

Please sign in to comment.