Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions include/xtensor/xmath.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3013,6 +3013,94 @@ namespace detail {
return cov(eval(stack(xtuple(x, y))));
}
}



/*
* convolution mode placeholders for selecting the algorithm
* used in computing a 1D convolution.
* Same as NumPy's mode parameter.
*/
namespace convolve_mode
{
struct valid{};
struct full{};
}

namespace detail {
template <class E1, class E2>
inline auto convolve_impl(E1&& e1, E2&& e2, convolve_mode::valid)
{
using value_type = typename std::decay<E1>::type::value_type;

size_t const na = e1.size();
size_t const nv = e2.size();
size_t const n = na - nv + 1;
xt::xtensor<value_type, 1> out = xt::zeros<value_type>({ n });
for (size_t i = 0; i < n; i++)
{
for (int j = 0; j < nv; j++)
{
out(i) += e1(j) * e2(j + i);
}
}
return out;
}

template <class E1, class E2>
inline auto convolve_impl(E1&& e1, E2&& e2, convolve_mode::full mode)
{
using value_type = typename std::decay<E1>::type::value_type;

size_t const na = e1.size();
size_t const nv = e2.size();
size_t const n = na + nv - 1;
xt::xtensor<value_type, 1> out = xt::zeros<value_type>({ n });
for (size_t i = 0; i < n; i++)
{
size_t const jmn = (i >= nv - 1) ? i - (nv - 1) : 0;
size_t const jmx = (i < na - 1) ? i : na - 1;
for (size_t j = jmn; j <= jmx; ++j)
{
out(i) += e1(j) * e2(i - j);
}
}
return out;
}
}

/*
* @brief computes the 1D convolution between two 1D expressions
*
* @param a 1D expression
* @param v 1D expression
* @param mode placeholder Select algorithm #convolve_mode
*
* @detail the algorithm convolves a with v and will incur a copy overhead
* should v be longer than a.
*/
template <class E1, class E2, class E3>
inline auto convolve(E1&& a, E2&& v, E3 mode)
{

if (a.dimension() != 1 || v.dimension() != 1)
{
XTENSOR_THROW(std::runtime_error, "Invalid dimentions convolution arguments must be 1D expressions");
}

XTENSOR_ASSERT(a.size() > 0 && v.size() > 0);

//swap them so a is always the longest one
if (a.size() < v.size())
{
return detail::convolve_impl(std::forward<E2>(v), std::forward<E1>(a), mode);
}
else
{
return detail::convolve_impl(std::forward<E1>(a), std::forward<E2>(v), mode);
}
}
}


#endif
23 changes: 23 additions & 0 deletions test/test_xmath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,4 +915,27 @@ namespace xt

EXPECT_EQ(expected, xt::cov(x, y));
}


TEST(xmath, convolve_full)
{
xt::xarray<double> x = { 1.0, 3.0, 1.0 };
xt::xarray<double> y = { 1.0, 1.0, 1.0 };
xt::xarray<double> expected = { 1, 4, 5, 4, 1 };

auto result = xt::convolve(x, y, xt::convolve_mode::full());

EXPECT_EQ(result, expected);
}

TEST(xmath, convolve_valid)
{
xt::xarray<double> x = { 3.0, 1.0, 1.0 };
xt::xarray<double> y = { 1.0, 1.0, 1.0 };
xt::xarray<double> expected = { 5 };

auto result = xt::convolve(x, y, xt::convolve_mode::valid());

EXPECT_EQ(result, expected);
}
}