Skip to content

Commit

Permalink
Copy constructor gets expression's chunk_shape if it is chunked (#2092)
Browse files Browse the repository at this point in the history
Copy constructor gets expression's chunk_shape if it is chunked
  • Loading branch information
davidbrochart committed Jul 16, 2020
1 parent 0ddc79d commit ef8ff77
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 32 deletions.
81 changes: 49 additions & 32 deletions include/xtensor/xchunked_array.hpp
Expand Up @@ -10,6 +10,43 @@

namespace xt
{
namespace detail
{
// Workaround for VS2015
template <class E>
using try_chunk_shape = decltype(std::declval<E>().chunk_shape());

template <class E, template <class> class OP, class = void>
struct chunk_helper_impl
{
static const auto& chunk_shape(const xexpression<E>& e)
{
return e.derived_cast().shape();
}
using is_chunked = std::false_type;
};

template <class E, template <class> class OP>
struct chunk_helper_impl<E, OP, void_t<OP<E>>>
{
static const auto& chunk_shape(const xexpression<E>& e)
{
return e.derived_cast().chunk_shape();
}
using is_chunked = std::true_type;
};

template <class E>
using chunk_helper = chunk_helper_impl<E, try_chunk_shape>;
}

template<class E>
constexpr bool is_chunked(const xexpression<E>& e)
{
using return_type = typename detail::chunk_helper<E>::is_chunked;
return return_type::value;
}

template <class chunk_type>
class xchunked_array;

Expand Down Expand Up @@ -110,37 +147,6 @@ namespace xt
xchunked_array(xchunked_array&&) = default;
xchunked_array& operator=(xchunked_array&&) = default;

template <class E>
xchunked_array(const xexpression<E>& e)
{
const auto& sh = e.derived_cast().shape();
resize_container(m_shape, sh.size());
std::copy(sh.begin(), sh.end(), m_shape.begin());
m_chunk_shape = m_shape;
// Naive implementation to refine later
m_chunk_shape[0] = std::min(size_type(10), m_shape[0]);
size_type nb_chunks = m_shape[0] / m_chunk_shape[0];
bool additional_chunk = m_shape[0] % m_chunk_shape[0] > 0;
if (additional_chunk)
{
m_chunks.resize({nb_chunks + 1u});
}
else
{
m_chunks.resize({nb_chunks});
}
for (size_type i = 0; i < nb_chunks; ++i)
{
noalias(m_chunks(i)) = strided_view(e.derived_cast(),
{range(i * m_chunk_shape[0], (i + 1u) * m_chunk_shape[0]), ellipsis()});
}
if (additional_chunk)
{
noalias(m_chunks(nb_chunks)) = strided_view(e.derived_cast(),
{range(nb_chunks * m_chunk_shape[0], m_shape[0]), ellipsis()});
}
}

template <class E, class S>
xchunked_array(const xexpression<E>& e, S chunk_shape)
{
Expand Down Expand Up @@ -190,6 +196,13 @@ namespace xt
}
}

template <class E>
xchunked_array(const xexpression<E>& e)
{
const auto& chunk_shape = detail::chunk_helper<E>::chunk_shape(e);
*this = xchunked_array<chunk_type>(e, chunk_shape);
}

template <class E>
self_type& operator=(const xexpression<E>& e)
{
Expand Down Expand Up @@ -229,6 +242,11 @@ namespace xt
return shape().size();
}

shape_type chunk_shape() const
{
return m_chunk_shape;
}

template <class S>
bool broadcast_shape(S& s, bool reuse_cache = false) const
{
Expand Down Expand Up @@ -361,4 +379,3 @@ namespace xt
}

#endif

23 changes: 23 additions & 0 deletions test/test_xchunked_array.cpp
Expand Up @@ -44,6 +44,9 @@ namespace xt

TEST(xchunked_array, assign_expression)
{
#ifdef _MSC_FULL_VER
std::cout << "MSC_FULL_VER = " << _MSC_FULL_VER << std::endl;
#endif
std::vector<size_t> shape1 = {2, 2, 2};
std::vector<size_t> chunk_shape1 = {2, 3, 4};
chunked_array a1(shape1, chunk_shape1);
Expand Down Expand Up @@ -75,13 +78,33 @@ namespace xt
{{1., 2., 3.},
{4., 5., 6.},
{7., 8., 9.}};

EXPECT_EQ(xt::is_chunked(a3), false);

std::vector<size_t> chunk_shape4 = {2, 2};
auto a4 = chunked_array(a3, chunk_shape4);

EXPECT_EQ(xt::is_chunked(a4), true);

double i = 1.;
for (const auto& v: a4)
{
EXPECT_EQ(v, i);
i += 1.;
}

auto a5 = chunked_array(a4);
EXPECT_EQ(xt::is_chunked(a5), true);
for (const auto& v: a5.chunk_shape())
{
EXPECT_EQ(v, 2);
}

auto a6 = chunked_array(a3);
EXPECT_EQ(xt::is_chunked(a6), true);
for (const auto& v: a6.chunk_shape())
{
EXPECT_EQ(v, 3);
}
}
}

0 comments on commit ef8ff77

Please sign in to comment.