Skip to content

Commit

Permalink
Adding lstsq (#10)
Browse files Browse the repository at this point in the history
add lstsq and tests
  • Loading branch information
wolfv committed Apr 26, 2017
1 parent 369d49e commit 4fe83d8
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 23 deletions.
24 changes: 6 additions & 18 deletions include/xtensor-blas/xlapack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,17 +674,11 @@ namespace lapack
return info;
}

template <class E, class F, std::enable_if_t<!is_complex<typename E::value_type>::value>* = nullptr>
auto gelsd(E& A, F& b, double rcond = -1)
template <class E, class F, class S, std::enable_if_t<!is_complex<typename E::value_type>::value>* = nullptr>
int gelsd(E& A, F& b, S& s, XBLAS_INDEX& rank, double rcond)
{
using value_type = typename E::value_type;

std::size_t M = A.shape()[0], N = A.shape()[1];
std::array<std::size_t, 1> shp = {std::min(M, N)};
xtensor<value_type, 1, layout_type::column_major> s(shp);

XBLAS_INDEX rank;

uvector<value_type> work(1);
uvector<XBLAS_INDEX> iwork(1);

Expand Down Expand Up @@ -731,21 +725,15 @@ namespace lapack
iwork.data()
);

return std::make_tuple(info, s);
return info;
}

template <class E, class F, std::enable_if_t<is_complex<typename E::value_type>::value>* = nullptr>
auto gelsd(E& A, F& b, double rcond = -1)
template <class E, class F, class S, std::enable_if_t<is_complex<typename E::value_type>::value>* = nullptr>
int gelsd(E& A, F& b, S& s, XBLAS_INDEX& rank, double rcond = -1)
{
using value_type = typename E::value_type;
using underlying_value_type = typename value_type::value_type;

std::size_t M = A.shape()[0], N = A.shape()[1];
std::array<std::size_t, 1> shp = {std::min(M, N)};
xtensor<value_type, 1, layout_type::column_major> s(shp);

XBLAS_INDEX rank;

uvector<value_type> work(1);
uvector<underlying_value_type> rwork(1);
uvector<XBLAS_INDEX> iwork(1);
Expand Down Expand Up @@ -796,7 +784,7 @@ namespace lapack
iwork.data()
);

return std::make_tuple(info, s);
return info;
}
}

Expand Down
73 changes: 72 additions & 1 deletion include/xtensor-blas/xlinalg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,6 @@ namespace linalg
/**
* Calculate the Kronecker product between two 2D xexpressions.
*/

template <class T, class E>
auto kron(const xexpression<T>& a, const xexpression<E>& b)
{
Expand Down Expand Up @@ -1075,6 +1074,78 @@ namespace linalg
return sm;
}

/**
* Calculate the least-squares solution to a linear matrix equation.
*
* @param A coefficient matrix
* @param b Ordinate, or dependent variable values. If b is two-dimensional,
* the least-squares solution is calculated for each of the K columns of b.
* @param rcond Cut-off ratio for small singular values of \em A.
* For the purposes of rank determination, singular values are treated
* as zero if they are smaller than rcond times the largest singular value of a.
*
* @return tuple containing (x, residuals, rank, s) where:
* \em x is the least squares solution. Note that the solution is always returned as
* a 2D matrix where the columns are the solutions (even for a 1D \em b).
* \em s Sums of residuals; squared Euclidean 2-norm for each column in b - a*x.
* If the rank of \em A is < N or M <= N, this is an empty xtensor.
* \em rank the rank of \em A
* \em s singular values of \em A
*/
template <class T, class E>
auto lstsq(const xexpression<T>& A, const xexpression<E>& b, double rcond = -1)
{
using value_type = typename T::value_type;
using underlying_value_type = underlying_value_type_t<typename T::value_type>;

xtensor<value_type, 2, layout_type::column_major> dA = A.derived_cast();
xtensor<value_type, 2, layout_type::column_major> db;

const auto& db_t = b.derived_cast();
if (db_t.dimension() == 1)
{
std::size_t sz = db_t.shape()[0];
db.reshape({sz, 1});
std::copy(db_t.data().begin(), db_t.data().end(), db.data().begin());
}
else
{
db = db_t;
}

std::size_t M = dA.shape()[0];
std::size_t N = dA.shape()[1];

std::array<std::size_t, 1> shp = { std::min(M, N) };
xtensor<underlying_value_type, 1, layout_type::column_major> s(shp);

XBLAS_INDEX rank;

int info = lapack::gelsd(dA, db, s, rank, rcond);

std::array<std::size_t, 1> residuals_shp({0});
xtensor<underlying_value_type, 1> residuals(residuals_shp);

if (std::size_t(rank) == N && M > N)
{
residuals.reshape({db.shape()[1]});
for (std::size_t i = 0; i < db.shape()[1]; ++i)
{
underlying_value_type temp = 0;
for (std::size_t j = N; j < db.shape()[0]; ++j)
{
temp += std::pow(std::abs(db(j, i)), 2);
}
residuals(i) = temp;
}
}

auto vdb = view(db, range(0ul, N));
db = vdb;

return std::make_tuple(db, residuals, rank, s);
}

/**
* Non-broadcasting cross product between two vectors
* Calculate cross product between two 1D vectors with 2- or 3 entries.
Expand Down
55 changes: 51 additions & 4 deletions test/test_linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,19 @@ namespace xt

TEST(xlinalg, matrix_rank)
{
int a = linalg::matrix_rank(eye<double>(4));
xarray<double> eall = eye<double>(4);
int a = linalg::matrix_rank(eall);
EXPECT_EQ(4, a);

xarray<double> b = eye<double>(4);
b(1, 1) = 0;
int rb = linalg::matrix_rank(b);
EXPECT_EQ(3, rb);
int ro = linalg::matrix_rank(ones<double>({4, 4}));
xarray<double> ones_arr = ones<double>({4, 4});
int ro = linalg::matrix_rank(ones_arr);
EXPECT_EQ(1, ro);
int rz = linalg::matrix_rank(zeros<double>({4, 4}));
xarray<double> zarr = zeros<double>({4, 4});
int rz = linalg::matrix_rank(zarr);
EXPECT_EQ(0, rz);
}

Expand Down Expand Up @@ -355,7 +358,6 @@ namespace xt
xarray<std::complex<double>> cmplexpected = {{ 1.+0.i, 0.+0.i},
{ 0.+2.i, 1.+0.i}};
EXPECT_EQ(cmplexpected, cmplres);

}

TEST(xlinalg, qr)
Expand Down Expand Up @@ -409,4 +411,49 @@ namespace xt
// EXPECT_TRUE(allclose(tau, eTau));
// EXPECT_TRUE(allclose(erawR, rawR));
}

TEST(xlinalg, lstsq)
{
xarray<double> arg_0 = {{ 0., 1.},
{ 1., 1.},
{ 2., 1.},
{ 3., 1.}};

xarray<double> arg_1 = {{-1., 0.2, 0.9, 2.1},
{ 2., 3. , 2. , 1. }};
arg_1 = transpose(arg_1);
auto res = xt::linalg::lstsq(arg_0, arg_1);

xarray<double, layout_type::column_major> el_0 = {{ 1. ,-0.4 },
{-0.95, 2.6 }};
xarray<double> el_1 = { 0.05, 1.2 };
int el_2 = 2;
xarray<double> el_3 = { 4.10003045, 1.09075677};


EXPECT_TRUE(allclose(el_0, std::get<0>(res)));
EXPECT_TRUE(allclose(el_1, std::get<1>(res)));
EXPECT_EQ(el_2, std::get<2>(res));
EXPECT_TRUE(allclose(el_3, std::get<3>(res)));

xarray<std::complex<double>> carg_0 = {{ 0., 1.},
{ 1. - 3i, 1.},
{ 2., 1.},
{ 3., 1.}};
xarray<std::complex<double>> carg_1 = {{-1. , 0.2+4i, 0.9, 2.1-1i}, {2,3i,2,1}};
carg_1 = transpose(carg_1);
auto cres = xt::linalg::lstsq(carg_0, carg_1);

xarray<std::complex<double>, layout_type::column_major> cel_0 = {{-0.40425532-0.38723404i,-0.61702128-0.44680851i},
{ 1.44680851+1.02765957i, 2.51063830+0.95744681i}};
xarray<double> cel_1 = { 16.11787234, 2.68085106};
int cel_2 = 2;
xarray<double> cel_3 = { 5.01295356, 1.36758789};

EXPECT_TRUE(allclose(imag(cel_0), imag(std::get<0>(cres))));
EXPECT_TRUE(allclose(real(cel_0), real(std::get<0>(cres))));
EXPECT_TRUE(allclose(cel_1, std::get<1>(cres)));
EXPECT_EQ(cel_2, std::get<2>(cres));
EXPECT_TRUE(allclose(cel_3, std::get<3>(cres)));
}
}

0 comments on commit 4fe83d8

Please sign in to comment.