Skip to content

Commit

Permalink
Merge pull request #125 from potpath/fix_lstsq_n_eq_1
Browse files Browse the repository at this point in the history
fix lstsq when N == 1, + test
  • Loading branch information
wolfv committed Jun 25, 2019
2 parents 65c123c + 97424f9 commit bfc5e86
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 2 deletions.
5 changes: 3 additions & 2 deletions include/xtensor-blas/xlapack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -763,14 +763,15 @@ namespace lapack
std::size_t m = A.shape()[0];
std::size_t n = A.shape()[1];

blas_index_t a_stride = static_cast<blas_index_t>(std::max(std::size_t(1), m));
blas_index_t b_stride = static_cast<blas_index_t>(std::max(std::max(std::size_t(1), m), n));

int info = cxxlapack::gelsd<blas_index_t>(
static_cast<blas_index_t>(A.shape()[0]),
static_cast<blas_index_t>(A.shape()[1]),
b_dim,
A.data(),
stride_back(A),
a_stride,
b.data(),
b_stride,
s.data(),
Expand All @@ -794,7 +795,7 @@ namespace lapack
static_cast<blas_index_t>(A.shape()[1]),
b_dim,
A.data(),
stride_back(A),
a_stride,
b.data(),
b_stride,
s.data(),
Expand Down
2 changes: 2 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ set(XTENSOR_BLAS_TESTS
test_blas.cpp
test_lapack.cpp
test_linalg.cpp
test_lstsq.cpp
test_qr.cpp
test_dot.cpp
test_tensordot.cpp
)
Expand Down
23 changes: 23 additions & 0 deletions test/test_generator/cppy_source/test_lstsq.cppy
Original file line number Diff line number Diff line change
Expand Up @@ -164,5 +164,28 @@ namespace xt
EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
}

/*py
a = np.array([[1.], [1.]])
b = np.array([1., 1.])
*/
TEST(xtest_extended, lstsq7)
{
// cannot use "// py_a" due to ambiguous initializer list conversion below
// xarray<double> py_a = {{1.},
// {1.}};
xarray<double> py_a = xt::ones<double>({2, 1});
// py_b
// py_res0 = np.linalg.lstsq(a, b)[0]
// py_res1 = np.linalg.lstsq(a, b)[1]
// py_res2 = np.linalg.lstsq(a, b)[2]
// py_res3 = np.linalg.lstsq(a, b)[3]

auto xres = xt::linalg::lstsq(py_a, py_b);
EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
EXPECT_EQ(std::get<2>(xres), py_res2);
EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
}


}
28 changes: 28 additions & 0 deletions test/test_lstsq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,5 +317,33 @@ namespace xt
EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
}

/*py
a = np.array([[1.], [1.]])
b = np.array([1., 1.])
*/
TEST(xtest_extended, lstsq7)
{
// cannot use "// py_a" due to ambiguous initializer list conversion below
// xarray<double> py_a = {{1.},
// {1.}};
xarray<double> py_a = xt::ones<double>({2, 1});
// py_b
xarray<double> py_b = {1.,1.};
// py_res0 = np.linalg.lstsq(a, b)[0]
xarray<double> py_res0 = {0.9999999999999997};
// py_res1 = np.linalg.lstsq(a, b)[1]
xarray<double> py_res1 = {2.2508083912556065e-33};
// py_res2 = np.linalg.lstsq(a, b)[2]
int py_res2 = 1;
// py_res3 = np.linalg.lstsq(a, b)[3]
xarray<double> py_res3 = {1.4142135623730951};

auto xres = xt::linalg::lstsq(py_a, py_b);
EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
EXPECT_EQ(std::get<2>(xres), py_res2);
EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
}


}

0 comments on commit bfc5e86

Please sign in to comment.