
Commit

Merge pull request #87 from amjames/tensordot
Implement Tensordot
wolfv committed Oct 12, 2018
2 parents 6b2507d + 08c8db5 commit 232aaed
Showing 4 changed files with 429 additions and 1 deletion.
8 changes: 7 additions & 1 deletion docs/source/reference.rst
@@ -12,7 +12,7 @@ Defined in ``xtensor-blas/xlinalg.hpp``

The functions here are closely modeled after NumPy's linalg package.

Matrix and vector products
Matrix, vector and tensor products
----------------------------------

.. doxygenfunction:: xt::linalg::dot
@@ -30,6 +30,12 @@ Matrix and vector products
.. doxygenfunction:: xt::linalg::kron
:project: xtensor-blas

.. doxygenfunction:: xt::linalg::tensordot(const xexpression<T>&, const xexpression<O>&, std::size_t)
:project: xtensor-blas

.. doxygenfunction:: xt::linalg::tensordot(const xexpression<T>&, const xexpression<O>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&)
:project: xtensor-blas

Decompositions
--------------

169 changes: 169 additions & 0 deletions include/xtensor-blas/xlinalg.hpp
@@ -1514,6 +1514,175 @@ namespace linalg
}
return res;
}

/**
* @brief Compute tensor dot product along specified axes for arrays
*
* Compute the sum of products over the last \em naxes axes of a and the
* first \em naxes axes of b.
*
* @param xa input array
* @param xb input array
* @param naxes the number of axes to sum over
* @return resulting array
*/
template <class T, class O>
auto tensordot(const xexpression<T>& xa, const xexpression<O>& xb, std::size_t naxes = 2)
{
using value_type = std::common_type_t<typename T::value_type, typename O::value_type>;
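// choose the result layout: keep the inputs' common static layout when both
// agree on a concrete one, otherwise fall back to XTENSOR_DEFAULT_LAYOUT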
using result_type = std::conditional_t<T::static_layout == O::static_layout &&
(T::static_layout != layout_type::dynamic && T::static_layout != layout_type::any),
xarray<value_type, T::static_layout>,
xarray<value_type, XTENSOR_DEFAULT_LAYOUT>>;

result_type result;
auto&& a = view_eval<T::static_layout>(xa.derived_cast());
auto&& b = view_eval<O::static_layout>(xb.derived_cast());
if (naxes == 0)
{
// special case: tensor outer product
xt::dynamic_shape<std::size_t> result_shape(a.dimension() + b.dimension());
std::size_t j = 0;
for (std::size_t i = 0; i < a.dimension(); ++i)
{
result_shape[j++] = a.shape()[i];
}

for (std::size_t i = 0; i < b.dimension(); ++i)
{
result_shape[j++] = b.shape()[i];
}
// flatten a/b
auto vec_a = xt::ravel<T::static_layout>(a);
auto vec_b = xt::ravel<O::static_layout>(b);
// take the outer product of the two vectors
result = outer(vec_a, vec_b);
// reshape the result
result.reshape(result_shape);
}
else
{
// Sum of products over the last n axes of a and the first n axes of b
XTENSOR_ASSERT(a.dimension() >= naxes);
XTENSOR_ASSERT(b.dimension() >= naxes);

auto as_it = a.shape().begin() + (a.dimension() - naxes);
auto bs_it = b.shape().begin();
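// walk the trailing naxes axes of a alongside the leading naxes axes of b,
// accumulating the total contracted length and checking the sizes pairwise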
std::size_t sum_len = 1;
for (std::size_t i = 0; i < naxes; ++i)
{
auto a_val = *as_it;
auto b_val = *bs_it;
// check for axes size match
if (a_val != b_val)
{
throw std::runtime_error("Shape mismatch for sum");
}
else
{
sum_len *= a_val;
}
++as_it;
++bs_it;
}
xt::dynamic_shape<std::size_t> result_shape;
std::size_t keep_a_len = 1;
for (auto it = a.shape().begin(); it != a.shape().begin() + (a.dimension() - naxes); ++it)
{
std::size_t len = *it;
keep_a_len *= len;
result_shape.push_back(len);
}
std::size_t keep_b_len = 1;
for (auto it = b.shape().begin() + naxes; it != b.shape().end(); ++it)
{
std::size_t len = *it;
keep_b_len *= len;
result_shape.push_back(len);
}
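// collapse a to (keep_a_len, sum_len) and b to (sum_len, keep_b_len); the
// whole contraction then reduces to a single matrix product via dot()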
xarray<value_type, T::static_layout> a_mat = a;
a_mat.reshape({keep_a_len, sum_len});
xarray<value_type, O::static_layout> b_mat = b;
b_mat.reshape({sum_len, keep_b_len});
result = dot(a_mat, b_mat);
if (result_shape.empty())
{
result.reshape({1});
}
else
{
result.reshape(result_shape);
}

}
return result;
}
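
A usage sketch for the integer overload (editor's illustration, not part of the commit; the shapes and the use of xt::random::rand to fill the inputs are invented for the example):

#include <xtensor/xarray.hpp>
#include <xtensor/xrandom.hpp>
#include <xtensor-blas/xlinalg.hpp>

// a: (3, 4, 5) and b: (4, 5, 6); the trailing (4, 5) of a matches the leading (4, 5) of b
xt::xarray<double> a = xt::random::rand<double>({3, 4, 5});
xt::xarray<double> b = xt::random::rand<double>({4, 5, 6});

auto c = xt::linalg::tensordot(a, b);     // naxes defaults to 2; result has shape (3, 6)
auto o = xt::linalg::tensordot(a, b, 0);  // naxes == 0 is the outer product: shape (3, 4, 5, 4, 5, 6)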

/**
* @brief Compute tensor dot product along specified axes for arrays
*
* Compute the sum of products along the axes \em ax_a of a and \em ax_b of b.
*
* @param xa input array
* @param xb input array
* @param ax_a axes to sum over for \em a
* @param ax_b axes to sum over for \em b
* @return resulting array
*/
template <class T, class O>
auto tensordot(const xexpression<T>& xa, const xexpression<O>& xb, const std::vector<std::size_t>& ax_a,
const std::vector<std::size_t>& ax_b)
{
auto&& a = view_eval<T::static_layout>(xa.derived_cast());
auto&& b = view_eval<O::static_layout>(xb.derived_cast());
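// the two axis lists must pair up one-to-one and fit within each operand's rank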
XTENSOR_ASSERT(ax_a.size() == ax_b.size());
XTENSOR_ASSERT(ax_a.size() <= a.dimension());
XTENSOR_ASSERT(ax_b.size() <= b.dimension());
std::size_t n_ax = ax_a.size();
for (std::size_t i = 0; i < n_ax; ++i)
{
XTENSOR_ASSERT(ax_a[i] < a.dimension());
XTENSOR_ASSERT(ax_b[i] < b.dimension());
}

// Move the axes to sum over to the end of a
xt::dynamic_shape<std::size_t> newaxes_a;
xt::dynamic_shape<std::size_t> result_shape;
for (std::size_t i = 0; i < a.dimension(); ++i)
{
auto a_ax_it = std::find(ax_a.begin(), ax_a.end(), i);
// first pass: if i is not in ax_a, add it to newaxes_a
if (a_ax_it == ax_a.end())
{
newaxes_a.push_back(i);
}
}
for (auto& a_ax : ax_a)
{
newaxes_a.push_back(a_ax);
}

// Move the axes to sum over to the start of b
xt::dynamic_shape<std::size_t> newaxes_b;
for (auto& b_ax : ax_b)
{
newaxes_b.push_back(b_ax);
}
for (std::size_t i = 0; i < b.dimension(); ++i)
{
auto b_ax_it = std::find(ax_b.begin(), ax_b.end(), i);
// second pass: if i is not in ax_b, add it to newaxes_b
if (b_ax_it == ax_b.end())
{
newaxes_b.push_back(i);
}
}
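// permute so the contracted axes sit last in a and first in b, which is the
// layout the integer overload above contracts over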
auto a_t = xt::transpose(a, newaxes_a);
auto b_t = xt::transpose(b, newaxes_b);

// the integer overload of tensordot handles the reshape of the output for us
return tensordot(a_t, b_t, n_ax);
}
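
A usage sketch for the axis-pair overload (editor's illustration, not part of the commit; shapes are invented, includes as in the previous sketch), mirroring numpy.tensordot's axes pairs:

// a: (3, 4, 5) and b: (4, 6, 5); contract a's axis 1 with b's axis 0 (both 4)
// and a's axis 2 with b's axis 2 (both 5)
xt::xarray<double> a = xt::random::rand<double>({3, 4, 5});
xt::xarray<double> b = xt::random::rand<double>({4, 6, 5});

// the remaining axes keep their order, so the result has shape (3, 6)
auto c = xt::linalg::tensordot(a, b, {1, 2}, {0, 2});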
}
}
#endif
1 change: 1 addition & 0 deletions test/CMakeLists.txt
@@ -98,6 +98,7 @@ set(XTENSOR_BLAS_TESTS
test_lapack.cpp
test_linalg.cpp
test_dot.cpp
test_tensordot.cpp
)

set(XTENSOR_BLAS_TARGET test_xtensor_blas)
