Skip to content

Commit

Permalink
xt::linalg::kron: support arguments of arbitrary number of dimensions
Browse files Browse the repository at this point in the history
Before this commit, `xt::linalg::kron` only supports 2D arguments.  This
commit proposes to add support for argument with any number of
dimensions.  This change of behavior is coherent with what `numpy.kron`
does.

Given the implementation, it would probably make sense to return a
xexpression instead of a xarray and allow lazy evaluation.  This might be
done in a separate commit.  It could also be possible to have a dynamic
check of the number of dimensions and use specialized implementation for
the more common cases (i.e. 2D) at runtime, which should be more
efficient.

Tested with `./test/test_xtensor_blas --gtest_filter=xlinalg.kron*`.
  • Loading branch information
lsix committed May 12, 2021
1 parent 7ceb791 commit ca4cc18
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 14 deletions.
48 changes: 34 additions & 14 deletions include/xtensor-blas/xlinalg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#define XLINALG_HPP

#include <algorithm>
#include <functional>
#include <limits>
#include <sstream>
#include <chrono>
Expand Down Expand Up @@ -1460,7 +1461,7 @@ namespace linalg
}

/**
* Calculate the Kronecker product between two 2D xexpressions.
* Calculate the Kronecker product between two kD xexpressions.
*/
template <class T, class E>
auto kron(const xexpression<T>& a, const xexpression<E>& b)
Expand All @@ -1470,23 +1471,42 @@ namespace linalg
const auto& da = a.derived_cast();
const auto& db = b.derived_cast();

XTENSOR_ASSERT(da.dimension() == 2);
XTENSOR_ASSERT(db.dimension() == 2);
XTENSOR_ASSERT(da.dimension() == bd.dimension());

std::array<std::size_t, 2> shp = {da.shape()[0] * db.shape()[0], da.shape()[1] * db.shape()[1]};
xtensor<value_type, 2> res(shp);
const auto shapea = da.shape();
const auto shapeb = db.shape();
const std::vector<std::size_t> shp = [&shapea, &shapeb](){
std::vector<std::size_t> r;
r.reserve(shapea.size());
std::transform(shapea.begin(), shapea.end(), shapeb.begin(),
std::back_inserter(r), std::multiplies<std::size_t>());
return r;
}();

for (std::size_t i = 0; i < da.shape()[0]; ++i)
xarray<value_type> res(shp);

std::vector<std::size_t> ires(da.dimension(), 0);
std::vector<std::size_t> ia(da.dimension(), 0);
std::vector<std::size_t> ib(da.dimension(), 0);

for (size_t i = 0; i < res.size(); i++)
{
for (std::size_t j = 0; j < da.shape()[1]; ++j)
for (size_t j = 0; j < shp.size(); j++)
{
for (std::size_t k = 0; k < db.shape()[0]; ++k)
{
for (std::size_t h = 0; h < db.shape()[1]; ++h)
{
res(i * db.shape()[0] + k, j * db.shape()[1] + h) = da(i, j) * db(k, h);
}
}
ia[j] = ires[j] / shapeb[j];
ib[j] = ires[j] % shapeb[j];
}

res[ires] = da[ia] * db[ib];

// Figure out the index of the next element
size_t j = ires.size() - 1;
ires[j]++;
while (ires[j] >= shp[j] && j > 0)
{
ires[j] = 0;
ires[j - 1]++;
j--;
}
}

Expand Down
16 changes: 16 additions & 0 deletions test/test_linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,22 @@ namespace xt
EXPECT_EQ(expected, res);
}

TEST(xlinalg, kron_3d)
{
xarray<int> arg_0 = {{{1, 2, 3}}};

xarray<int> arg_1 = xt::ones<int>({2, 2, 1});

auto res = xt::linalg::kron(arg_0, arg_1);

xarray<int> expected = {{{1, 2, 3},
{1, 2, 3}},
{{1, 2, 3},
{1, 2, 3}}};

EXPECT_EQ(expected, res);
}

TEST(xlinalg, cholesky)
{
xarray<double> arg_0 = {{ 4, 12,-16},
Expand Down

0 comments on commit ca4cc18

Please sign in to comment.