Skip to content

Commit

Permalink
xt::linalg::kron: support arguments of arbitrary number of dimensions
Browse files Browse the repository at this point in the history
Before this commit, `xt::linalg::kron` only supports 2D arguments.  This
commit proposes to add support for argument with any number of
dimensions.  This change of behavior is coherent with what `numpy.kron`
does.

The current implementation is a performance step back compared to the
previous one.  Until a better generic implementation is found, keeping a
couple of specialized implementations for the most used scenarios makes
sense.

Tested with `./test/test_xtensor_blas --gtest_filter=xlinalg.kron*`.
  • Loading branch information
lsix committed May 17, 2021
1 parent 7ceb791 commit 1136bab
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 14 deletions.
44 changes: 30 additions & 14 deletions include/xtensor-blas/xlinalg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#define XLINALG_HPP

#include <algorithm>
#include <functional>
#include <limits>
#include <sstream>
#include <chrono>
Expand Down Expand Up @@ -1460,7 +1461,7 @@ namespace linalg
}

/**
* Calculate the Kronecker product between two 2D xexpressions.
* Calculate the Kronecker product between two kD xexpressions.
*/
template <class T, class E>
auto kron(const xexpression<T>& a, const xexpression<E>& b)
Expand All @@ -1470,23 +1471,38 @@ namespace linalg
const auto& da = a.derived_cast();
const auto& db = b.derived_cast();

XTENSOR_ASSERT(da.dimension() == 2);
XTENSOR_ASSERT(db.dimension() == 2);
XTENSOR_ASSERT(da.dimension() == bd.dimension());

std::array<std::size_t, 2> shp = {da.shape()[0] * db.shape()[0], da.shape()[1] * db.shape()[1]};
xtensor<value_type, 2> res(shp);
const auto shapea = da.shape();
const auto shapeb = db.shape();
const std::vector<std::size_t> shp = [&shapea, &shapeb](){
std::vector<std::size_t> r;
r.reserve(shapea.size());
std::transform(shapea.begin(), shapea.end(), shapeb.begin(),
std::back_inserter(r), std::multiplies<std::size_t>());
return r;
}();

for (std::size_t i = 0; i < da.shape()[0]; ++i)
xarray<value_type> res(shp);

std::vector<std::size_t> ia(da.dimension(), 0);
xt::xstrided_slice_vector sv(da.dimension(), 0);

for (auto ii = da.begin(); ii < da.end(); ii++)
{
for (std::size_t j = 0; j < da.shape()[1]; ++j)
for (std::size_t i = 0; i < da.dimension(); i++)
{
for (std::size_t k = 0; k < db.shape()[0]; ++k)
{
for (std::size_t h = 0; h < db.shape()[1]; ++h)
{
res(i * db.shape()[0] + k, j * db.shape()[1] + h) = da(i, j) * db(k, h);
}
}
sv[i] = range(ia[i] * shapeb[i], (ia[i] + 1) * shapeb[i]);
}
strided_view(res, sv) = da[ia] * db;

size_t j = ia.size() - 1;
ia[j]++;
while (ia[j] >= shapea[j] && j > 0)
{
ia[j] = 0;
ia[j - 1]++;
j--;
}
}

Expand Down
16 changes: 16 additions & 0 deletions test/test_linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,22 @@ namespace xt
EXPECT_EQ(expected, res);
}

TEST(xlinalg, kron_3d)
{
xarray<int> arg_0 = {{{1, 2, 3}}};

xarray<int> arg_1 = xt::ones<int>({2, 2, 1});

auto res = xt::linalg::kron(arg_0, arg_1);

xarray<int> expected = {{{1, 2, 3},
{1, 2, 3}},
{{1, 2, 3},
{1, 2, 3}}};

EXPECT_EQ(expected, res);
}

TEST(xlinalg, cholesky)
{
xarray<double> arg_0 = {{ 4, 12,-16},
Expand Down

0 comments on commit 1136bab

Please sign in to comment.