
Commit c452790

wolfv authored and SylvainCorlay committed
add more dimension checks, new from_shape constructor, const intp cas… (#153)
* add more dimension checks, new from_shape constructor, const intp casting
* check strides in cast
* ..
1 parent 0d6a645 commit c452790

File tree

8 files changed: 142 additions & 9 deletions

include/xtensor-python/pyarray.hpp

Lines changed: 21 additions & 0 deletions
@@ -346,6 +346,9 @@ namespace xt
         explicit pyarray(const shape_type& shape, const strides_type& strides, const_reference value);
         explicit pyarray(const shape_type& shape, const strides_type& strides);

+        template <class S = shape_type>
+        static pyarray from_shape(S&& s);
+
         pyarray(const self_type& rhs);
         self_type& operator=(const self_type& rhs);

@@ -605,6 +608,18 @@ namespace xt
     {
         init_array(shape, strides);
     }
+
+    /**
+     * Allocates and returns an pyarray with the specified shape.
+     * @param shape the shape of the pyarray
+     */
+    template <class T, layout_type L>
+    template <class S>
+    inline pyarray<T, L> pyarray<T, L>::from_shape(S&& shape)
+    {
+        auto shp = xtl::forward_sequence<shape_type>(shape);
+        return self_type(shp);
+    }
     //@}

     /**
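
For reference, a minimal usage sketch of the new from_shape constructor, mirroring the call made in the new unit test further down. It assumes pybind11 >= 2.2 for the embedded-interpreter header and that the xtensor-python and NumPy headers are on the include path; inside a regular extension module the interpreter already exists and only xt::import_numpy() is needed.

    // Sketch only: allocate a 3 x 4 pyarray from a shape (values start uninitialized).
    #include <algorithm>

    #include <pybind11/embed.h>

    #define FORCE_IMPORT_ARRAY
    #include "xtensor-python/pyarray.hpp"

    int main()
    {
        pybind11::scoped_interpreter guard{};  // start Python; pyarray needs a live interpreter
        xt::import_numpy();                    // initialize the NumPy C API

        auto a = xt::pyarray<double>::from_shape({3, 4});
        std::fill(a.begin(), a.end(), 1.5);    // from_shape only allocates, so fill explicitly

        return a.size() == 12 ? 0 : 1;
    }
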
@@ -726,6 +741,12 @@ namespace xt
                                       static_cast<size_type>(PyArray_NDIM(this->python_array())));
         m_strides = inner_strides_type(reinterpret_cast<size_type*>(PyArray_STRIDES(this->python_array())),
                                        static_cast<size_type>(PyArray_NDIM(this->python_array())));
+
+        if (L != layout_type::dynamic && !do_strides_match(m_shape, m_strides, L))
+        {
+            throw std::runtime_error("NumPy: passing container with bad strides for layout (is it a view?).");
+        }
+
         m_backstrides = backstrides_type(*this);
         m_storage = storage_type(reinterpret_cast<pointer>(PyArray_DATA(this->python_array())),
                                  this->get_min_stride() * static_cast<size_type>(PyArray_SIZE(this->python_array())));
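
The do_strides_match guard added above rejects NumPy buffers whose strides are inconsistent with the requested compile-time layout, which is typically the case for views created by slicing or transposition. The snippet below is only an illustration of the invariant being checked, written for the row-major case and xtensor's element-count stride convention; it is not the library's implementation.

    // Illustration only: "strides match row-major layout" means the last axis has
    // stride 1 and every axis stride equals the product of the extents after it.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    bool row_major_strides_ok(const std::vector<std::size_t>& shape,
                              const std::vector<std::size_t>& strides)
    {
        if (shape.size() != strides.size())
        {
            return false;
        }
        std::size_t expected = 1;
        for (std::size_t i = shape.size(); i-- > 0;)
        {
            if (strides[i] != expected)
            {
                return false;
            }
            expected *= shape[i];
        }
        return true;
    }

    int main()
    {
        // A contiguous (2, 5, 5) array has row-major strides (25, 5, 1) ...
        std::cout << row_major_strides_ok({2, 5, 5}, {25, 5, 1}) << '\n';   // prints 1
        // ... while a strided view such as arr[:, ::2, ::2] does not.
        std::cout << row_major_strides_ok({2, 3, 3}, {25, 10, 2}) << '\n';  // prints 0
    }
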

include/xtensor-python/pycontainer.hpp

Lines changed: 31 additions & 8 deletions
@@ -226,8 +226,7 @@ namespace xt
         bool check_array(const pybind11::handle& src)
         {
             using is_arithmetic_type = std::integral_constant<bool, bool(pybind11::detail::satisfies_any_of<T, std::is_arithmetic, xtl::is_complex>::value)>;
-            return PyArray_Check(src.ptr()) &&
-                check_array_type<T>(src, is_arithmetic_type{});
+            return PyArray_Check(src.ptr()) && check_array_type<T>(src, is_arithmetic_type{});
         }
     }

@@ -277,9 +276,7 @@ namespace xt
     template <class D>
     inline bool pycontainer<D>::check_(pybind11::handle h)
     {
-        auto dtype = pybind11::detail::npy_format_descriptor<value_type>::dtype();
-        return PyArray_Check(h.ptr()) &&
-               PyArray_EquivTypes_(PyArray_TYPE(reinterpret_cast<PyArrayObject*>(h.ptr())), dtype.ptr());
+        return detail::check_array<typename D::value_type>(h);
     }

     template <class D>
@@ -321,6 +318,30 @@ namespace xt
         return *static_cast<const derived_type*>(this);
     }

+    namespace detail
+    {
+        template <class S>
+        struct check_dims
+        {
+            static bool run(std::size_t)
+            {
+                return true;
+            }
+        };
+
+        template <class T, std::size_t N>
+        struct check_dims<std::array<T, N>>
+        {
+            static bool run(std::size_t new_dim)
+            {
+                if(new_dim != N)
+                {
+                    throw std::runtime_error("Dims not matching.");
+                }
+                return new_dim == N;
+            }
+        };
+    }

     /**
      * resizes the container.
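
The new check_dims helper dispatches on the type of the shape container: a std::array shape (the fixed-rank pytensor case) must match the static rank N, while dynamically sized shapes (the pyarray case) always pass. A standalone sketch of the same partial-specialization pattern, using a hypothetical rank_check name:

    // Sketch of the dispatch used by detail::check_dims: the primary template
    // accepts any run-time rank, the std::array specialization enforces N.
    #include <array>
    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    template <class S>
    struct rank_check                       // dynamic shape types: nothing to verify
    {
        static bool run(std::size_t) { return true; }
    };

    template <class T, std::size_t N>
    struct rank_check<std::array<T, N>>     // fixed-rank shape types
    {
        static bool run(std::size_t new_dim)
        {
            if (new_dim != N)
            {
                throw std::runtime_error("Dims not matching.");
            }
            return true;
        }
    };

    int main()
    {
        rank_check<std::vector<std::size_t>>::run(7);        // ok, any rank
        rank_check<std::array<std::size_t, 3>>::run(3);      // ok, rank matches
        try
        {
            rank_check<std::array<std::size_t, 3>>::run(2);  // throws
        }
        catch (const std::runtime_error&)
        {
            return 0;
        }
        return 1;
    }
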
@@ -359,6 +380,7 @@ namespace xt
     template <class S>
     inline void pycontainer<D>::resize(const S& shape, const strides_type& strides)
     {
+        detail::check_dims<shape_type>::run(shape.size());
         derived_type tmp(xtl::forward_sequence<shape_type>(shape), strides);
         *static_cast<derived_type*>(this) = std::move(tmp);
     }
@@ -369,9 +391,9 @@ namespace xt
     {
         if (compute_size(shape) != this->size())
         {
-            throw std::runtime_error("Cannot reshape with incorrect number of elements.");
+            throw std::runtime_error("Cannot reshape with incorrect number of elements (" + std::to_string(this->size()) + " vs " + std::to_string(compute_size(shape)) + ")");
        }
-
+        detail::check_dims<shape_type>::run(shape.size());
         layout = default_assignable_layout(layout);

         NPY_ORDER npy_layout;
@@ -388,7 +410,8 @@ namespace xt
            throw std::runtime_error("Cannot reshape with unknown layout_type.");
        }

-        PyArray_Dims dims = {reinterpret_cast<npy_intp*>(shape.data()), static_cast<int>(shape.size())};
+        using shape_ptr = typename std::decay_t<S>::pointer;
+        PyArray_Dims dims = {reinterpret_cast<npy_intp*>(const_cast<shape_ptr>(shape.data())), static_cast<int>(shape.size())};
         auto new_ptr = PyArray_Newshape((PyArrayObject*) this->ptr(), &dims, npy_layout);
         auto old_ptr = this->ptr();
         this->ptr() = new_ptr;
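
The const_cast introduced above is needed because reshape receives the new shape by const reference, so shape.data() yields a pointer-to-const, while NumPy's PyArray_Dims stores a mutable npy_intp*. The typename std::decay_t<S>::pointer alias recovers the container's non-const pointer type whatever shape container S happens to be. A small type-level sketch with a stand-in struct (no NumPy headers required; dims_like and make_dims are illustrative names, not library API):

    // Sketch: why the cast compiles for both vector- and array-backed shapes.
    #include <array>
    #include <cstddef>
    #include <type_traits>
    #include <vector>

    struct dims_like { std::size_t* ptr; int len; };  // stand-in for PyArray_Dims

    template <class S>
    dims_like make_dims(const S& shape)
    {
        // S is deduced from a const reference; decay it before asking for its pointer type.
        using shape_ptr = typename std::decay_t<S>::pointer;  // e.g. std::size_t*
        return {const_cast<shape_ptr>(shape.data()), static_cast<int>(shape.size())};
    }

    int main()
    {
        const std::vector<std::size_t> v = {1, 6};
        const std::array<std::size_t, 2> a = {6, 1};
        // data() on a const container returns a pointer-to-const ...
        static_assert(std::is_same<decltype(v.data()), const std::size_t*>::value, "");
        // ... which make_dims turns back into the mutable pointer the C struct wants.
        return make_dims(v).len + make_dims(a).len == 4 ? 0 : 1;
    }
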

include/xtensor-python/pytensor.hpp

Lines changed: 22 additions & 0 deletions
@@ -163,6 +163,9 @@ namespace xt
         explicit pytensor(const shape_type& shape, const strides_type& strides, const_reference value);
         explicit pytensor(const shape_type& shape, const strides_type& strides);

+        template <class S = shape_type>
+        static pytensor from_shape(S&& shape);
+
         pytensor(const self_type& rhs);
         self_type& operator=(const self_type& rhs);

@@ -315,6 +318,19 @@ namespace xt
     {
         init_tensor(shape, strides);
     }
+
+    /**
+     * Allocates and returns an pytensor with the specified shape.
+     * @param shape the shape of the pytensor
+     */
+    template <class T, std::size_t N, layout_type L>
+    template <class S>
+    inline pytensor<T, N, L> pytensor<T, N, L>::from_shape(S&& shape)
+    {
+        detail::check_dims<shape_type>::run(shape.size());
+        auto shp = xtl::forward_sequence<shape_type>(shape);
+        return self_type(shp);
+    }
     //@}

     /**
@@ -429,6 +445,12 @@ namespace xt
         std::transform(PyArray_STRIDES(this->python_array()), PyArray_STRIDES(this->python_array()) + N, m_strides.begin(),
                        [](auto v) { return v / sizeof(T); });
         adapt_strides(m_shape, m_strides, m_backstrides);
+
+        if (L != layout_type::dynamic && !do_strides_match(m_shape, m_strides, L))
+        {
+            throw std::runtime_error("NumPy: passing container with bad strides for layout (is it a view?).");
+        }
+
         m_storage = storage_type(reinterpret_cast<pointer>(PyArray_DATA(this->python_array())),
                                  this->get_min_stride() * static_cast<size_type>(PyArray_SIZE(this->python_array())));
     }

test/test_pyarray.cpp

Lines changed: 9 additions & 0 deletions
@@ -52,6 +52,15 @@ namespace xt
         }
     }

+    TEST(pyarray, from_shape)
+    {
+        auto arr = pyarray<double>::from_shape({5, 2, 6});
+        auto exp_shape = std::vector<std::size_t>{5, 2, 6};
+        EXPECT_TRUE(std::equal(arr.shape().begin(), arr.shape().end(), exp_shape.begin()));
+        EXPECT_EQ(arr.shape().size(), 3);
+        EXPECT_EQ(arr.size(), 5 * 2 * 6);
+    }
+
    TEST(pyarray, strided_constructor)
    {
        central_major_result<> cmr;

test/test_pytensor.cpp

Lines changed: 17 additions & 1 deletion
@@ -51,6 +51,18 @@ namespace xt
         }
     }

+    TEST(pytensor, from_shape)
+    {
+        auto arr = pytensor<double, 3>::from_shape({5, 2, 6});
+        auto exp_shape = std::vector<std::size_t>{5, 2, 6};
+        EXPECT_TRUE(std::equal(arr.shape().begin(), arr.shape().end(), exp_shape.begin()));
+        EXPECT_EQ(arr.shape().size(), 3);
+        EXPECT_EQ(arr.size(), 5 * 2 * 6);
+        using pyt3 = pytensor<double, 3>;
+        std::vector<std::size_t> shp = std::vector<std::size_t>{5, 2};
+        EXPECT_THROW(pyt3::from_shape(shp), std::runtime_error);
+    }
+
    TEST(pytensor, strided_constructor)
    {
        central_major_result<container_type> cmr;
@@ -211,8 +223,12 @@ namespace xt
     {
         pytensor<int, 2> a = {{1,2,3}, {4,5,6}};
         auto ptr = a.data();
+        a.reshape(a.shape()); // compilation check
         a.reshape({1, 6});
         EXPECT_EQ(ptr, a.data());
-        EXPECT_THROW(a.reshape({6}), std::runtime_error);
+        EXPECT_THROW(a.reshape(std::vector<std::size_t>{6}), std::runtime_error);
+        // note this throws because std array has only 1 element initialized
+        // and the second element is `0`.
+        EXPECT_THROW(a.reshape({6, 5}), std::runtime_error);
     }
 }
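
The comment added to this test refers to aggregate initialization of the fixed-size shape type: for a rank-N pytensor the shape is a std::array of size N, so a braced list with fewer initializers than N zero-fills the remaining extents and the shape's element count collapses to zero. A minimal illustration of that C++ behaviour:

    // A std::array given fewer initializers than its size zero-fills the rest,
    // so a rank-2 shape written as {6} is really {6, 0} with 0 elements total.
    #include <array>
    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <numeric>

    int main()
    {
        std::array<std::size_t, 2> shape = {6};   // second extent becomes 0
        std::size_t n = std::accumulate(shape.begin(), shape.end(),
                                        std::size_t(1), std::multiplies<std::size_t>());
        std::cout << shape[0] << ", " << shape[1] << " -> " << n << " elements\n";  // 6, 0 -> 0 elements
        return 0;
    }
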

test_python/main.cpp

Lines changed: 20 additions & 0 deletions
@@ -12,6 +12,7 @@
 #include "xtensor/xarray.hpp"
 #define FORCE_IMPORT_ARRAY
 #include "xtensor-python/pyarray.hpp"
+#include "xtensor-python/pytensor.hpp"
 #include "xtensor-python/pyvectorize.hpp"

 namespace py = pybind11;
@@ -154,6 +155,22 @@ void char_array(xt::pyarray<char[20]>& carr)
     carr(2)[3] = '\0';
 }

+void row_major_tensor(xt::pytensor<double, 3, xt::layout_type::row_major>& arg)
+{
+    if (!std::is_same<decltype(arg.begin()), double*>::value)
+    {
+        throw std::runtime_error("TEST FAILED");
+    }
+}
+
+void col_major_array(xt::pyarray<double, xt::layout_type::column_major>& arg)
+{
+    if (!std::is_same<decltype(arg.template begin<xt::layout_type::column_major>()), double*>::value)
+    {
+        throw std::runtime_error("TEST FAILED");
+    }
+}
+
 PYBIND11_MODULE(xtensor_python_test, m)
 {
     xt::import_numpy();
@@ -197,4 +214,7 @@ PYBIND11_MODULE(xtensor_python_test, m)
     m.def("dtype_to_python", dtype_to_python);
     m.def("dtype_from_python", dtype_from_python);
     m.def("char_array", char_array);
+
+    m.def("col_major_array", col_major_array);
+    m.def("row_major_tensor", row_major_tensor);
 }

test_python/setup.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ def __str__(self):
         ['main.cpp'],
         include_dirs=[
             # Path to pybind11 headers
+            '../include/',
             get_pybind_include(),
             get_pybind_include(user=True),
             # Path to numpy headers

test_python/test_pyarray.py

Lines changed: 21 additions & 0 deletions
@@ -125,3 +125,24 @@ def test_char_array(self):
         self.assertEqual(var[0], b'hello')
         self.assertEqual(var[1], b'from')
         self.assertEqual(var[2], b'c++')
+
+    def test_col_row_major(self):
+        var = np.arange(50, dtype=float).reshape(2, 5, 5)
+
+        with self.assertRaises(RuntimeError):
+            xt.col_major_array(var)
+
+        with self.assertRaises(RuntimeError):
+            xt.row_major_tensor(var.T)
+
+        with self.assertRaises(RuntimeError):
+            xt.row_major_tensor(var[:, ::2, ::2])
+
+        with self.assertRaises(RuntimeError):
+            # raise for wrong dimension
+            xt.row_major_tensor(var[0, 0, :])
+
+        xt.row_major_tensor(var)
+        varF = np.arange(50, dtype=float).reshape(2, 5, 5, order='F')
+        xt.col_major_array(varF)
+        xt.col_major_array(varF[:, :, 0])  # still col major!

0 commit comments