Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions include/xtensor-python/pyarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,24 +330,14 @@ namespace xt
template<typename... Args>
inline auto pyarray<T, ExtraFlags>::operator()(Args... args) -> reference
{
if (sizeof...(args) != dimension())
{
pybind_array::fail_dim_check(sizeof...(args), "index dimension mismatch");
}
// not using pybind_array::offset_at() / index_at() here so as to avoid another dimension check.
return *(static_cast<pointer>(pybind_array::mutable_data()) + pybind_array::get_byte_offset(args...) / itemsize());
return *(static_cast<pointer>(pybind_array::mutable_data()) + pybind_array::byte_offset(args...) / itemsize());
}

template <class T, int ExtraFlags>
template<typename... Args>
inline auto pyarray<T, ExtraFlags>::operator()(Args... args) const -> const_reference
{
if (sizeof...(args) != dimension())
{
pybind_array::fail_dim_check(sizeof...(args), "index dimension mismatch");
}
// not using pybind_array::offset_at() / index_at() here so as to avoid another dimension check.
return *(static_cast<const_pointer>(pybind_array::data()) + pybind_array::get_byte_offset(args...) / itemsize());
return *(static_cast<const_pointer>(pybind_array::data()) + pybind_array::byte_offset(args...) / itemsize());
}

template <class T, int ExtraFlags>
Expand Down Expand Up @@ -522,7 +512,7 @@ namespace xt
template<typename... Args>
inline auto pyarray<T, ExtraFlags>::index_at(Args... args) const -> size_type
{
return pybind_array::offset_at(args...) / itemsize();
return pybind_array::byte_offset(args...) / itemsize();
}

template <class T, int ExtraFlags>
Expand Down
45 changes: 7 additions & 38 deletions include/xtensor-python/pybind11_backport.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ namespace pybind11

size_type size() const
{
return std::accumulate(shape(), shape() + ndim(), size_type{1}, std::multiplies<size_type>());
return std::accumulate(shape(), shape() + ndim(), size_type(1), std::multiplies<size_type>());
}

size_type itemsize() const
Expand All @@ -114,61 +114,30 @@ namespace pybind11
return reinterpret_cast<const size_type*>(PyArray_GET_(m_ptr, strides));
}

template<typename... Ix>
void* data()
{
return static_cast<void*>(PyArray_GET_(m_ptr, data));
}

template<typename... Ix>
void* mutable_data()
{
// check_writeable();
return static_cast<void *>(PyArray_GET_(m_ptr, data));
}

template<typename... Ix>
size_type offset_at(Ix... index) const
{
if (sizeof...(index) > ndim())
{
fail_dim_check(sizeof...(index), "too many indices for an array");
}
return get_byte_offset(index...);
}
protected:

size_type offset_at() const
template<size_t dim = 0>
inline size_type byte_offset() const
{
return 0;
}

protected:

void fail_dim_check(size_type dim, const std::string& msg) const
template <size_t dim = 0, class... Args>
inline size_type byte_offset(size_type i, Args... args) const
{
throw index_error(msg + ": " + std::to_string(dim) +
" (ndim = " + std::to_string(ndim()) + ")");
return i * strides()[dim] + byte_offset<dim + 1>(args...);
}

template<typename... Ix>
size_type get_byte_offset(Ix... index) const
{
const size_type idx[] = { static_cast<size_type>(index)... };
if (!std::equal(idx + 0, idx + sizeof...(index), shape(), std::less<size_type>{}))
{
auto mismatch = std::mismatch(idx + 0, idx + sizeof...(index), shape(), std::less<size_type>{});
throw index_error(std::string("index ") + std::to_string(*mismatch.first) +
" is out of bounds for axis " + std::to_string(mismatch.first - idx) +
" with size " + std::to_string(*mismatch.second));
}
return std::inner_product(idx + 0, idx + sizeof...(index), strides(), size_type{0});
}

size_type get_byte_offset() const
{
return 0;
}

static std::vector<size_type>
default_strides(const std::vector<size_type>& shape, size_type itemsize)
{
Expand Down