@@ -226,8 +226,7 @@ namespace xt
226226 bool check_array (const pybind11::handle& src)
227227 {
228228 using is_arithmetic_type = std::integral_constant<bool , bool (pybind11::detail::satisfies_any_of<T, std::is_arithmetic, xtl::is_complex>::value)>;
229- return PyArray_Check (src.ptr ()) &&
230- check_array_type<T>(src, is_arithmetic_type{});
229+ return PyArray_Check (src.ptr ()) && check_array_type<T>(src, is_arithmetic_type{});
231230 }
232231 }
233232
@@ -277,9 +276,7 @@ namespace xt
277276 template <class D >
278277 inline bool pycontainer<D>::check_(pybind11::handle h)
279278 {
280- auto dtype = pybind11::detail::npy_format_descriptor<value_type>::dtype ();
281- return PyArray_Check (h.ptr ()) &&
282- PyArray_EquivTypes_ (PyArray_TYPE (reinterpret_cast <PyArrayObject*>(h.ptr ())), dtype.ptr ());
279+ return detail::check_array<typename D::value_type>(h);
283280 }
284281
285282 template <class D >
@@ -321,6 +318,30 @@ namespace xt
321318 return *static_cast <const derived_type*>(this );
322319 }
323320
321+ namespace detail
322+ {
323+ template <class S >
324+ struct check_dims
325+ {
326+ static bool run (std::size_t )
327+ {
328+ return true ;
329+ }
330+ };
331+
332+ template <class T , std::size_t N>
333+ struct check_dims <std::array<T, N>>
334+ {
335+ static bool run (std::size_t new_dim)
336+ {
337+ if (new_dim != N)
338+ {
339+ throw std::runtime_error (" Dims not matching." );
340+ }
341+ return new_dim == N;
342+ }
343+ };
344+ }
324345
325346 /* *
326347 * resizes the container.
@@ -359,6 +380,7 @@ namespace xt
359380 template <class S >
360381 inline void pycontainer<D>::resize(const S& shape, const strides_type& strides)
361382 {
383+ detail::check_dims<shape_type>::run (shape.size ());
362384 derived_type tmp (xtl::forward_sequence<shape_type>(shape), strides);
363385 *static_cast <derived_type*>(this ) = std::move (tmp);
364386 }
@@ -369,9 +391,9 @@ namespace xt
369391 {
370392 if (compute_size (shape) != this ->size ())
371393 {
372- throw std::runtime_error (" Cannot reshape with incorrect number of elements. " );
394+ throw std::runtime_error (" Cannot reshape with incorrect number of elements ( " + std::to_string ( this -> size ()) + " vs " + std::to_string ( compute_size (shape)) + " ) " );
373395 }
374-
396+ detail::check_dims<shape_type>:: run (shape. size ());
375397 layout = default_assignable_layout (layout);
376398
377399 NPY_ORDER npy_layout;
@@ -388,7 +410,8 @@ namespace xt
388410 throw std::runtime_error (" Cannot reshape with unknown layout_type." );
389411 }
390412
391- PyArray_Dims dims = {reinterpret_cast <npy_intp*>(shape.data ()), static_cast <int >(shape.size ())};
413+ using shape_ptr = typename std::decay_t <S>::pointer;
414+ PyArray_Dims dims = {reinterpret_cast <npy_intp*>(const_cast <shape_ptr>(shape.data ())), static_cast <int >(shape.size ())};
392415 auto new_ptr = PyArray_Newshape ((PyArrayObject*) this ->ptr (), &dims, npy_layout);
393416 auto old_ptr = this ->ptr ();
394417 this ->ptr () = new_ptr;
0 commit comments