diff --git a/include/xtensor-blas/xlinalg.hpp b/include/xtensor-blas/xlinalg.hpp index 39334c4..150444b 100644 --- a/include/xtensor-blas/xlinalg.hpp +++ b/include/xtensor-blas/xlinalg.hpp @@ -11,6 +11,7 @@ #define XLINALG_HPP #include +#include #include #include #include @@ -1460,7 +1461,7 @@ namespace linalg } /** - * Calculate the Kronecker product between two 2D xexpressions. + * Calculate the Kronecker product between two xexpressions with the same number of dimensions. */ template auto kron(const xexpression& a, const xexpression& b) @@ -1470,23 +1471,38 @@ namespace linalg const auto& da = a.derived_cast(); const auto& db = b.derived_cast(); - XTENSOR_ASSERT(da.dimension() == 2); - XTENSOR_ASSERT(db.dimension() == 2); + XTENSOR_ASSERT(da.dimension() == db.dimension()); - std::array shp = {da.shape()[0] * db.shape()[0], da.shape()[1] * db.shape()[1]}; - xtensor res(shp); + const auto shapea = da.shape(); + const auto shapeb = db.shape(); + const std::vector shp = [&shapea, &shapeb](){ + std::vector r; + r.reserve(shapea.size()); + std::transform(shapea.begin(), shapea.end(), shapeb.begin(), + std::back_inserter(r), std::multiplies()); + return r; + }(); - for (std::size_t i = 0; i < da.shape()[0]; ++i) + xarray res(shp); + + std::vector ia(da.dimension(), 0); + xt::xstrided_slice_vector sv(da.dimension(), 0); + + for (auto ii = da.begin(); ii < da.end(); ii++) { - for (std::size_t j = 0; j < da.shape()[1]; ++j) + for (std::size_t i = 0; i < da.dimension(); i++) { - for (std::size_t k = 0; k < db.shape()[0]; ++k) - { - for (std::size_t h = 0; h < db.shape()[1]; ++h) - { - res(i * db.shape()[0] + k, j * db.shape()[1] + h) = da(i, j) * db(k, h); - } - } + sv[i] = range(ia[i] * shapeb[i], (ia[i] + 1) * shapeb[i]); + } + strided_view(res, sv) = da[ia] * db; + + std::size_t j = ia.size() - 1; + ia[j]++; + while (ia[j] >= shapea[j] && j > 0) + { + ia[j] = 0; + ia[j - 1]++; + j--; } } diff --git a/test/test_linalg.cpp b/test/test_linalg.cpp index 570750f..2ddaef9 100644 --- a/test/test_linalg.cpp +++ 
b/test/test_linalg.cpp @@ -366,6 +366,22 @@ namespace xt EXPECT_EQ(expected, res); } + TEST(xlinalg, kron_3d) + { + xarray arg_0 = {{{1, 2, 3}}}; + + xarray arg_1 = xt::ones({2, 2, 1}); + + auto res = xt::linalg::kron(arg_0, arg_1); + + xarray expected = {{{1, 2, 3}, + {1, 2, 3}}, + {{1, 2, 3}, + {1, 2, 3}}}; + + EXPECT_EQ(expected, res); + } + TEST(xlinalg, cholesky) { xarray arg_0 = {{ 4, 12,-16},