Skip to content

Commit

Permalink
Merge pull request #41 from wolfv/0.6.0
Browse files Browse the repository at this point in the history
Fix 2D vector dot
  • Loading branch information
wolfv committed Nov 8, 2017
2 parents f2e838c + 3a4a7c6 commit e9be213
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 8 deletions.
8 changes: 4 additions & 4 deletions include/xtensor-blas/xblas_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,25 +85,25 @@ namespace xt
// Leading dimension (in elements) to hand to BLAS for a column-major
// expression: the stride of the last axis.
//
// xtensor reports a stride of 0 for singleton/broadcast axes, but BLAS
// requires a leading dimension >= 1; in that case fall back to the extent
// of the first axis. (The unguarded `return a.strides().back();` that
// previously stood here was dead code left over from the old version and
// has been removed.)
template <class A, std::enable_if_t<A::static_layout == layout_type::column_major>* = nullptr>
inline BLAS_IDX get_leading_stride(const A& a)
{
    return (BLAS_IDX) (a.strides().back() == 0 ? a.shape().front() : a.strides().back());
}

// Leading dimension (in elements) to hand to BLAS for a row-major
// expression: the stride of the first axis.
//
// A stride of 0 (reported by xtensor for singleton/broadcast axes) is not
// a valid BLAS leading dimension, so fall back to the extent of the last
// axis in that case. (The unguarded `return a.strides().front();` that
// previously stood here was dead code left over from the old version and
// has been removed.)
template <class A, std::enable_if_t<A::static_layout == layout_type::row_major>* = nullptr>
inline BLAS_IDX get_leading_stride(const A& a)
{
    return (BLAS_IDX) (a.strides().front() == 0 ? a.shape().back() : a.strides().front());
}

template <class A, std::enable_if_t<A::static_layout != layout_type::row_major && A::static_layout != layout_type::column_major>* = nullptr>
inline BLAS_IDX get_leading_stride(const A& a)
{
if (a.layout() == layout_type::row_major)
{
return (BLAS_IDX) a.strides().front();
return (BLAS_IDX) (a.strides().front() == 0 ? a.shape().back() : a.strides().front());
}
else if (a.layout() == layout_type::column_major)
{
return (BLAS_IDX) a.strides().back();
return (BLAS_IDX) (a.strides().back() == 0 ? a.shape().front() : a.strides().back());
}
else
{
Expand Down
28 changes: 24 additions & 4 deletions include/xtensor-blas/xlinalg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,14 +595,20 @@ namespace linalg
using common_type = std::common_type_t<typename T::value_type, typename O::value_type>;
using return_type = xarray<common_type, T::static_layout>;

auto&& t = view_eval<layout_type::row_major>(xt.derived_cast());
auto&& o = view_eval<layout_type::row_major>(xo.derived_cast());

auto&& t = view_eval<T::static_layout>(xt.derived_cast());
auto&& o = view_eval<T::static_layout>(xo.derived_cast());

return_type result;

if (t.dimension() == 1 && o.dimension() == 1)
{
result.reshape(std::vector<std::size_t>{1});
if (t.shape()[0] != o.shape()[0])
{
throw std::runtime_error("Dot: shape mismatch.");
}

if (xtl::is_complex<typename T::value_type>::value)
{
blas::dotu(t, o, result(0));
Expand All @@ -617,16 +623,30 @@ namespace linalg
{
if (t.dimension() == 2 && o.dimension() == 1)
{
if (t.shape()[1] != o.shape()[0])
{
throw std::runtime_error("Dot: shape mismatch.");
}

result.reshape({t.shape()[0]});
blas::gemv(t, o, result);
}
else if (t.dimension() == 1 && o.dimension() == 2)
{
if (t.shape()[0] != o.shape()[0])
{
throw std::runtime_error("Dot: shape mismatch.");
}

result.reshape({o.shape()[1]});
blas::gemv(o, t, result, true);
}
else if (t.dimension() == 2 && o.dimension() == 2)
{
if (t.shape()[1] != o.shape()[0])
{
throw std::runtime_error("Dot: shape mismatch.");
}
result.reshape({t.shape()[0], o.shape()[1]});
blas::gemm(t, o, result);
}
Expand All @@ -641,7 +661,7 @@ namespace linalg
}
if (o.shape()[match_dim] != l)
{
throw std::runtime_error("Dot alignment error.");
throw std::runtime_error("Dot: shape mismatch.");
}

int a_dim = (int) t.dimension();
Expand Down Expand Up @@ -1092,7 +1112,7 @@ namespace linalg
xtype mat = A.derived_cast();

XTENSOR_ASSERT(mat.dimension() == 2);
XTENSOR_ASSERT(mat.shape()[0] == mat_inp.shape()[1]);
XTENSOR_ASSERT(mat.shape()[0] == mat.shape()[1]);

xtype res(mat.shape());
if (n == 0)
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ set(XTENSOR_BLAS_TESTS
test_blas.cpp
test_lapack.cpp
test_linalg.cpp
test_dot.cpp
)

set(XTENSOR_BLAS_TARGET test_xtensor_blas)
Expand Down
133 changes: 133 additions & 0 deletions test/test_dot.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
#include "gtest/gtest.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xview.hpp"
#include "xtensor/xbuilder.hpp"
#include "xtensor/xstridedview.hpp"

#include "xtensor-blas/xlinalg.hpp"

namespace xt
{
TEST(xdot, matrix_times_vector)
{
    // (1x4) . (4x3) -> (1x3): row vector times matrix
    xarray<float> row = xt::ones<float>({1, 4});
    xarray<float> mat = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {1, 1, 1}};

    xarray<float> expected_row = {{13, 16, 19}};
    EXPECT_EQ(expected_row, linalg::dot(row, mat));

    // (4x3) . (3x1) -> (4x1): matrix times column vector
    xarray<float> col = xt::ones<float>({3, 1});
    xarray<float> expected_col = {{6, 15, 24, 3}};
    expected_col.reshape({4, 1});
    EXPECT_EQ(expected_col, linalg::dot(mat, col));

    // mismatched inner dimensions must be rejected
    EXPECT_THROW(linalg::dot(mat, row), std::runtime_error);
    EXPECT_THROW(linalg::dot(col, mat), std::runtime_error);
}

TEST(xdot, square_matrix_times_vector)
{
    xarray<float> row = {{1, 1, 1}};
    xarray<float> sq = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};

    // (1x3) . (3x3): a row of ones sums the columns of sq
    xarray<float> col_sums = {{12, 15, 18}};
    EXPECT_EQ(linalg::dot(row, sq), col_sums);

    // (3x3) . (3x1): a column of ones sums the rows of sq
    xarray<float> row_sums = xarray<float>::from_shape({3, 1});
    row_sums(0, 0) = 6.f;
    row_sums(1, 0) = 15.f;
    row_sums(2, 0) = 24.f;
    EXPECT_EQ(linalg::dot(sq, xt::transpose(row)), row_sums);

    // (3x3) . (1x3): inner dimensions do not match
    EXPECT_THROW(linalg::dot(sq, row), std::runtime_error);
}

TEST(xdot, vector_times_vector)
{
    xarray<float> row = xt::ones<float>({1, 3});
    xarray<float> col = xt::ones<float>({3, 1});

    // inner product: (1x3) . (3x1) -> (1x1)
    xarray<float> inner = xarray<float>::from_shape({1, 1});
    inner(0, 0) = 3;
    EXPECT_EQ(inner, linalg::dot(row, col));

    // outer product: (3x1) . (1x3) -> (3x3)
    xarray<float> outer = xt::ones<float>({3, 3});
    EXPECT_EQ(outer, linalg::dot(col, row));

    // (3x1) . (1x1): scales the column vector
    EXPECT_EQ(col * 3, linalg::dot(col, inner));
}

TEST(xdot, matrix_times_vector_cm)
{
    // same checks as matrix_times_vector, on column-major storage
    using cm_array = xarray<float, layout_type::column_major>;

    cm_array row = xt::ones<float>({1, 4});
    cm_array mat = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {1, 1, 1}};

    cm_array expected_row = {{13, 16, 19}};
    EXPECT_EQ(expected_row, linalg::dot(row, mat));

    cm_array col = xt::ones<float>({3, 1});
    cm_array expected_col = {{6, 15, 24, 3}};
    expected_col.reshape({4, 1});
    EXPECT_EQ(expected_col, linalg::dot(mat, col));

    // mismatched inner dimensions must be rejected
    EXPECT_THROW(linalg::dot(mat, row), std::runtime_error);
    EXPECT_THROW(linalg::dot(col, mat), std::runtime_error);
}

TEST(xdot, square_matrix_times_vector_cm)
{
    // same checks as square_matrix_times_vector, on column-major storage
    using cm_array = xarray<float, layout_type::column_major>;

    cm_array row = {{1, 1, 1}};
    cm_array sq = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};

    // (1x3) . (3x3): sums the columns of sq
    cm_array col_sums = {{12, 15, 18}};
    EXPECT_EQ(linalg::dot(row, sq), col_sums);

    // (3x3) . (3x1): sums the rows of sq
    cm_array row_sums = cm_array::from_shape({3, 1});
    row_sums(0, 0) = 6.f;
    row_sums(1, 0) = 15.f;
    row_sums(2, 0) = 24.f;
    EXPECT_EQ(linalg::dot(sq, xt::transpose(row)), row_sums);

    // (3x3) . (1x3): inner dimensions do not match
    EXPECT_THROW(linalg::dot(sq, row), std::runtime_error);
}

TEST(xdot, vector_times_vector_cm)
{
    // same checks as vector_times_vector, on column-major storage
    using cm_array = xarray<float, layout_type::column_major>;

    cm_array row = xt::ones<float>({1, 3});
    cm_array col = xt::ones<float>({3, 1});

    // inner product: (1x3) . (3x1) -> (1x1)
    cm_array inner = cm_array::from_shape({1, 1});
    inner(0, 0) = 3;
    EXPECT_EQ(inner, linalg::dot(row, col));

    // outer product: (3x1) . (1x3) -> (3x3)
    cm_array outer = xt::ones<float>({3, 3});
    EXPECT_EQ(outer, linalg::dot(col, row));

    // (3x1) . (1x1): scales the column vector
    EXPECT_EQ(col * 3, linalg::dot(col, inner));
}

}

0 comments on commit e9be213

Please sign in to comment.