Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xclosure #85

Merged
merged 1 commit into from
Feb 1, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions include/xtensor/xarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ namespace xt
using shape_type = typename base_type::shape_type;
using strides_type = typename base_type::strides_type;

using closure_type = const self_type&;

xarray();
explicit xarray(const shape_type& shape, layout l = layout::row_major);
explicit xarray(const shape_type& shape, const_reference value, layout l = layout::row_major);
Expand Down Expand Up @@ -147,8 +145,6 @@ namespace xt
using shape_type = typename base_type::shape_type;
using strides_type = typename base_type::strides_type;

using closure_type = const self_type&;

xarray_adaptor(container_type& data);
xarray_adaptor(container_type& data, const shape_type& shape, layout l = layout::row_major);
xarray_adaptor(container_type& data, const shape_type& shape, const strides_type& strides);
Expand Down
129 changes: 64 additions & 65 deletions include/xtensor/xbroadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,34 +59,34 @@ namespace xt
* to a specified shape. xbroadcast is not meant to be used directly, but
* only with the \ref broadcast helper functions.
*
* @tparam E the type of the \ref xexpression to broadcast
* @tparam S the type of the specified shape.
* @tparam CT the closure type of the \ref xexpression to broadcast
* @tparam X the type of the specified shape.
*/
template <class E, class X, bool LV>
class xbroadcast : public xexpression<xbroadcast<E, X, LV>>
template <class CT, class X>
class xbroadcast : public xexpression<xbroadcast<CT, X>>
{

public:

using self_type = xbroadcast<E, X, LV>;
using self_type = xbroadcast<CT, X>;
using xexpression_type = typename CT::xexpression_type;

using value_type = typename E::value_type;
using reference = typename E::reference;
using const_reference = typename E::const_reference;
using pointer = typename E::pointer;
using const_pointer = typename E::const_pointer;
using size_type = typename E::size_type;
using difference_type = typename E::difference_type;
using value_type = typename xexpression_type::value_type;
using reference = typename xexpression_type::reference;
using const_reference = typename xexpression_type::const_reference;
using pointer = typename xexpression_type::pointer;
using const_pointer = typename xexpression_type::const_pointer;
using size_type = typename xexpression_type::size_type;
using difference_type = typename xexpression_type::difference_type;

using shape_type = promote_shape_t<typename E::shape_type, X>;
using closure_type = const self_type;
using shape_type = promote_shape_t<typename xexpression_type::shape_type, X>;

using const_stepper = typename E::const_stepper;
using const_stepper = typename xexpression_type::const_stepper;
using const_iterator = xiterator<const_stepper, shape_type>;
using const_storage_iterator = const_iterator;

template <class S>
xbroadcast(const E& e, S s) noexcept;
xbroadcast(typename CT::xclosure_type e, S s) noexcept;

size_type dimension() const noexcept;
const shape_type & shape() const noexcept;
Expand Down Expand Up @@ -131,7 +131,7 @@ namespace xt

private:

std::conditional_t<LV, const E&, E> m_e;
typename CT::xclosure_type m_e;
shape_type m_shape;
};

Expand Down Expand Up @@ -192,8 +192,7 @@ namespace xt
template <class E, class S>
inline auto broadcast(E&& e, const S& s) noexcept
{
constexpr bool is_lvalue = std::is_lvalue_reference<decltype(e)>::value;
using broadcast_type = xbroadcast<get_xexpression_type<E>, S, is_lvalue>;
using broadcast_type = xbroadcast<xclosure<E>, S>;
using shape_type = typename broadcast_type::shape_type;
return broadcast_type(std::forward<E>(e), detail::forward_shape<shape_type>(s));
}
Expand All @@ -202,15 +201,15 @@ namespace xt
template <class E, class I>
inline auto broadcast(E&& e, std::initializer_list<I> s) noexcept
{
using broadcast_type = xbroadcast<get_xexpression_type<E>, std::vector<std::size_t>, false>;
using broadcast_type = xbroadcast<xclosure<E>, std::vector<std::size_t>>;
using shape_type = typename broadcast_type::shape_type;
return broadcast_type(std::forward<E>(e), detail::forward_shape<shape_type>(s));
}
#else
template <class E, class I, std::size_t L>
inline auto broadcast(E&& e, const I(&s)[L]) noexcept
{
using broadcast_type = xbroadcast<get_xexpression_type<E>, std::array<std::size_t, L>, false>;
using broadcast_type = xbroadcast<xclosure<E>, std::array<std::size_t, L>>;
using shape_type = typename broadcast_type::shape_type;
return broadcast_type(std::forward<E>(e), detail::forward_shape<shape_type>(s));
}
Expand All @@ -231,9 +230,9 @@ namespace xt
* @param e the expression to broadcast
* @param s the shape to apply
*/
template <class E, class X, bool LV>
template <class CT, class X>
template <class S>
inline xbroadcast<E, X, LV>::xbroadcast(const E& e, S s) noexcept
inline xbroadcast<CT, X>::xbroadcast(typename CT::xclosure_type e, S s) noexcept
: m_e(e), m_shape(std::move(s))
{
xt::broadcast_shape(e.shape(), m_shape);
Expand All @@ -247,17 +246,17 @@ namespace xt
/**
* Returns the number of dimensions of the expression.
*/
template <class E, class X, bool LV>
inline auto xbroadcast<E, X, LV>::dimension() const noexcept -> size_type
template <class CT, class X>
inline auto xbroadcast<CT, X>::dimension() const noexcept -> size_type
{
return m_shape.size();
}

/**
* Returns the shape of the expression.
*/
template <class E, class X, bool LV>
inline auto xbroadcast<E, X, LV>::shape() const noexcept -> const shape_type &
template <class CT, class X>
inline auto xbroadcast<CT, X>::shape() const noexcept -> const shape_type &
{
return m_shape;
}
Expand All @@ -272,9 +271,9 @@ namespace xt
* must be unsigned integers, the number of indices should be equal or greater than
* the number of dimensions of the expression.
*/
template <class E, class X, bool LV>
template <class CT, class X>
template <class... Args>
inline auto xbroadcast<E, X, LV>::operator()(Args... args) const -> const_reference
inline auto xbroadcast<CT, X>::operator()(Args... args) const -> const_reference
{
return detail::get_element(m_e, args...);
}
Expand All @@ -285,8 +284,8 @@ namespace xt
* must be unsigned integers, the number of indices in the sequence should be equal or greater
* than the number of dimensions of the container.
*/
template <class E, class X, bool LV>
inline auto xbroadcast<E, X, LV>::operator[](const xindex& index) const -> const_reference
template <class CT, class X>
inline auto xbroadcast<CT, X>::operator[](const xindex& index) const -> const_reference
{
return element(index.cbegin(), index.cend());
}
Expand All @@ -298,9 +297,9 @@ namespace xt
* The number of indices in the squence should be equal or greater
* than the number of dimensions of the function.
*/
template <class E, class X, bool LV>
template <class CT, class X>
template <class It>
inline auto xbroadcast<E, X, LV>::element(It, It last) const -> const_reference
inline auto xbroadcast<CT, X>::element(It, It last) const -> const_reference
{
// Workaround MSVC bug. m_e.element(last - dimension(), last) does not build.
It first = last;
Expand All @@ -318,9 +317,9 @@ namespace xt
* @param shape the result shape
* @return a boolean indicating whether the broadcasting is trivial
*/
template <class E, class X, bool LV>
template <class CT, class X>
template <class S>
inline bool xbroadcast<E, X, LV>::broadcast_shape(S& shape) const
inline bool xbroadcast<CT, X>::broadcast_shape(S& shape) const
{
return xt::broadcast_shape(m_shape, shape);
}
Expand All @@ -330,9 +329,9 @@ namespace xt
* the broadcasting is trivial.
* @return a boolean indicating whether the broadcasting is trivial
*/
template <class E, class X, bool LV>
template <class CT, class X>
template <class S>
inline bool xbroadcast<E, X, LV>::is_trivial_broadcast(const S& strides) const noexcept
inline bool xbroadcast<CT, X>::is_trivial_broadcast(const S& strides) const noexcept
{
return dimension() == m_e.dimension() &&
std::equal(m_shape.cbegin(), m_shape.cend(), m_e.shape().cbegin()) &&
Expand All @@ -347,8 +346,8 @@ namespace xt
/**
* Returns a constant iterator to the first element of the expression.
*/
template <class E, class X, bool LV>
inline auto xbroadcast<E, X, LV>::begin() const noexcept -> const_iterator
template <class CT, class X>
inline auto xbroadcast<CT, X>::begin() const noexcept -> const_iterator
{
return cxbegin(shape());
}
Expand All @@ -357,17 +356,17 @@ namespace xt
* Returns a constant iterator to the element following the last element
* of the expression.
*/
template <class E, class X, bool LV>
inline auto xbroadcast<E, X, LV>::end() const noexcept -> const_iterator
template <class CT, class X>
inline auto xbroadcast<CT, X>::end() const noexcept -> const_iterator
{
return cxend(shape());
}

/**
* Returns a constant iterator to the first element of the expression.
*/
template <class E, class X, bool LV>
inline auto xbroadcast<E, X, LV>::cbegin() const noexcept -> const_iterator
template <class CT, class X>
inline auto xbroadcast<CT, X>::cbegin() const noexcept -> const_iterator
{
return cxbegin(shape());
}
Expand All @@ -376,8 +375,8 @@ namespace xt
* Returns a constant iterator to the element following the last element
* of the expression.
*/
template <class E, class X, bool LV>
inline auto xbroadcast<E, X, LV>::cend() const noexcept -> const_iterator
template <class CT, class X>
inline auto xbroadcast<CT, X>::cend() const noexcept -> const_iterator
{
return cxend(shape());
}
Expand All @@ -387,9 +386,9 @@ namespace xt
* iteration is broadcasted to the specified shape.
* @param shape the shape used for braodcasting
*/
template <class E, class X, bool LV>
template <class CT, class X>
template <class S>
inline auto xbroadcast<E, X, LV>::xbegin(const S& shape) const noexcept -> xiterator<const_stepper, S>
inline auto xbroadcast<CT, X>::xbegin(const S& shape) const noexcept -> xiterator<const_stepper, S>
{
// Could check if (broadcastable(shape, m_shape)
return xiterator<const_stepper, S>(stepper_begin(shape), shape);
Expand All @@ -400,9 +399,9 @@ namespace xt
* expression. The iteration is broadcasted to the specified shape.
* @param shape the shape used for broadcasting
*/
template <class E, class X, bool LV>
template <class CT, class X>
template <class S>
inline auto xbroadcast<E, X, LV>::xend(const S& shape) const noexcept -> xiterator<const_stepper, S>
inline auto xbroadcast<CT, X>::xend(const S& shape) const noexcept -> xiterator<const_stepper, S>
{
// Could check if (broadcastable(shape, m_shape)
return xiterator<const_stepper, S>(stepper_end(shape), shape);
Expand All @@ -413,9 +412,9 @@ namespace xt
* iteration is broadcasted to the specified shape.
* @param shape the shape used for braodcasting
*/
template <class E, class X, bool LV>
template <class CT, class X>
template <class S>
inline auto xbroadcast<E, X, LV>::cxbegin(const S& shape) const noexcept -> xiterator<const_stepper, S>
inline auto xbroadcast<CT, X>::cxbegin(const S& shape) const noexcept -> xiterator<const_stepper, S>
{
// Could check if (broadcastable(shape, m_shape)
return xiterator<const_stepper, S>(stepper_begin(shape), shape);
Expand All @@ -426,26 +425,26 @@ namespace xt
* expression. The iteration is broadcasted to the specified shape.
* @param shape the shape used for broadcasting
*/
template <class E, class X, bool LV>
template <class CT, class X>
template <class S>
inline auto xbroadcast<E, X, LV>::cxend(const S& shape) const noexcept -> xiterator<const_stepper, S>
inline auto xbroadcast<CT, X>::cxend(const S& shape) const noexcept -> xiterator<const_stepper, S>
{
// Could check if (broadcastable(shape, m_shape)
return xiterator<const_stepper, S>(stepper_end(shape), shape);
}
//@}

template <class E, class X, bool LV>
template <class CT, class X>
template <class S>
inline auto xbroadcast<E, X, LV>::stepper_begin(const S& shape) const noexcept -> const_stepper
inline auto xbroadcast<CT, X>::stepper_begin(const S& shape) const noexcept -> const_stepper
{
// Could check if (broadcastable(shape, m_shape)
return m_e.stepper_begin(shape);
}

template <class E, class X, bool LV>
template <class CT, class X>
template <class S>
inline auto xbroadcast<E, X, LV>::stepper_end(const S& shape) const noexcept -> const_stepper
inline auto xbroadcast<CT, X>::stepper_end(const S& shape) const noexcept -> const_stepper
{
// Could check if (broadcastable(shape, m_shape)
return m_e.stepper_end(shape);
Expand All @@ -458,8 +457,8 @@ namespace xt
* Returns an iterator to the first element of the buffer
* containing the elements of the expression.
*/
template <class E, class X, bool LV>
inline auto xbroadcast<E, X, LV>::storage_begin() const noexcept -> const_storage_iterator
template <class CT, class X>
inline auto xbroadcast<CT, X>::storage_begin() const noexcept -> const_storage_iterator
{
return cbegin();
}
Expand All @@ -468,8 +467,8 @@ namespace xt
* Returns an iterator to the element following the last
* element of the buffer containing the elements of the expression.
*/
template <class E, class X, bool LV>
inline auto xbroadcast<E, X, LV>::storage_end() const noexcept -> const_storage_iterator
template <class CT, class X>
inline auto xbroadcast<CT, X>::storage_end() const noexcept -> const_storage_iterator
{
return cend();
}
Expand All @@ -478,8 +477,8 @@ namespace xt
* Returns a constant iterator to the first element of the buffer
* containing the elements of the expression.
*/
template <class E, class X, bool LV>
inline auto xbroadcast<E, X, LV>::storage_cbegin() const noexcept -> const_storage_iterator
template <class CT, class X>
inline auto xbroadcast<CT, X>::storage_cbegin() const noexcept -> const_storage_iterator
{
return cbegin();
}
Expand All @@ -488,8 +487,8 @@ namespace xt
* Returns a constant iterator to the element following the last
* element of the buffer containing the elements of the expression.
*/
template <class E, class X, bool LV>
inline auto xbroadcast<E, X, LV>::storage_cend() const noexcept -> const_storage_iterator
template <class CT, class X>
inline auto xbroadcast<CT, X>::storage_cend() const noexcept -> const_storage_iterator
{
return cend();
}
Expand Down
Loading