Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions include/xtensor-python/pyarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ namespace xt
explicit pyarray(const shape_type& shape, const strides_type& strides, const_reference value);
explicit pyarray(const shape_type& shape, const strides_type& strides);

template <class S = shape_type>
static pyarray from_shape(S&& s);

pyarray(const self_type& rhs);
self_type& operator=(const self_type& rhs);

Expand Down Expand Up @@ -605,6 +608,18 @@ namespace xt
{
init_array(shape, strides);
}

/**
 * Allocates and returns a pyarray with the specified shape.
 * @param shape the shape of the pyarray
 * @return a new pyarray of that shape
 */
template <class T, layout_type L>
template <class S>
inline pyarray<T, L> pyarray<T, L>::from_shape(S&& shape)
{
    // Preserve the caller's value category: an rvalue shape sequence is
    // moved by xtl::forward_sequence instead of being copied.
    return self_type(xtl::forward_sequence<shape_type>(std::forward<S>(shape)));
}
//@}

/**
Expand Down Expand Up @@ -726,6 +741,12 @@ namespace xt
static_cast<size_type>(PyArray_NDIM(this->python_array())));
m_strides = inner_strides_type(reinterpret_cast<size_type*>(PyArray_STRIDES(this->python_array())),
static_cast<size_type>(PyArray_NDIM(this->python_array())));

if (L != layout_type::dynamic && !do_strides_match(m_shape, m_strides, L))
{
throw std::runtime_error("NumPy: passing container with bad strides for layout (is it a view?).");
}

m_backstrides = backstrides_type(*this);
m_storage = storage_type(reinterpret_cast<pointer>(PyArray_DATA(this->python_array())),
this->get_min_stride() * static_cast<size_type>(PyArray_SIZE(this->python_array())));
Expand Down
39 changes: 31 additions & 8 deletions include/xtensor-python/pycontainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ namespace xt
bool check_array(const pybind11::handle& src)
{
using is_arithmetic_type = std::integral_constant<bool, bool(pybind11::detail::satisfies_any_of<T, std::is_arithmetic, xtl::is_complex>::value)>;
return PyArray_Check(src.ptr()) &&
check_array_type<T>(src, is_arithmetic_type{});
return PyArray_Check(src.ptr()) && check_array_type<T>(src, is_arithmetic_type{});
}
}

Expand Down Expand Up @@ -277,9 +276,7 @@ namespace xt
template <class D>
inline bool pycontainer<D>::check_(pybind11::handle h)
{
auto dtype = pybind11::detail::npy_format_descriptor<value_type>::dtype();
return PyArray_Check(h.ptr()) &&
PyArray_EquivTypes_(PyArray_TYPE(reinterpret_cast<PyArrayObject*>(h.ptr())), dtype.ptr());
return detail::check_array<typename D::value_type>(h);
}

template <class D>
Expand Down Expand Up @@ -321,6 +318,30 @@ namespace xt
return *static_cast<const derived_type*>(this);
}

namespace detail
{
    // Validates a requested number of dimensions against the container's
    // shape type.
    //
    // Primary template: dynamically sized shapes (e.g. std::vector) can
    // hold any rank, so the check always succeeds.
    template <class S>
    struct check_dims
    {
        static bool run(std::size_t)
        {
            return true;
        }
    };

    // Specialization for statically sized shapes: the requested rank must
    // match the fixed extent N exactly.
    template <class T, std::size_t N>
    struct check_dims<std::array<T, N>>
    {
        // Throws std::runtime_error on mismatch, returns true otherwise.
        static bool run(std::size_t new_dim)
        {
            if (new_dim != N)
            {
                // Include both ranks in the message, consistent with the
                // size-mismatch message thrown by reshape.
                throw std::runtime_error("Dims not matching (expected " + std::to_string(N) +
                                         ", got " + std::to_string(new_dim) + ").");
            }
            // The mismatch case always throws above, so reaching this
            // point means the ranks agree.
            return true;
        }
    };
}

/**
* resizes the container.
Expand Down Expand Up @@ -359,6 +380,7 @@ namespace xt
template <class S>
inline void pycontainer<D>::resize(const S& shape, const strides_type& strides)
{
detail::check_dims<shape_type>::run(shape.size());
derived_type tmp(xtl::forward_sequence<shape_type>(shape), strides);
*static_cast<derived_type*>(this) = std::move(tmp);
}
Expand All @@ -369,9 +391,9 @@ namespace xt
{
if (compute_size(shape) != this->size())
{
throw std::runtime_error("Cannot reshape with incorrect number of elements.");
throw std::runtime_error("Cannot reshape with incorrect number of elements (" + std::to_string(this->size()) + " vs " + std::to_string(compute_size(shape)) + ")");
}

detail::check_dims<shape_type>::run(shape.size());
layout = default_assignable_layout(layout);

NPY_ORDER npy_layout;
Expand All @@ -388,7 +410,8 @@ namespace xt
throw std::runtime_error("Cannot reshape with unknown layout_type.");
}

PyArray_Dims dims = {reinterpret_cast<npy_intp*>(shape.data()), static_cast<int>(shape.size())};
using shape_ptr = typename std::decay_t<S>::pointer;
PyArray_Dims dims = {reinterpret_cast<npy_intp*>(const_cast<shape_ptr>(shape.data())), static_cast<int>(shape.size())};
auto new_ptr = PyArray_Newshape((PyArrayObject*) this->ptr(), &dims, npy_layout);
auto old_ptr = this->ptr();
this->ptr() = new_ptr;
Expand Down
22 changes: 22 additions & 0 deletions include/xtensor-python/pytensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ namespace xt
explicit pytensor(const shape_type& shape, const strides_type& strides, const_reference value);
explicit pytensor(const shape_type& shape, const strides_type& strides);

template <class S = shape_type>
static pytensor from_shape(S&& shape);

pytensor(const self_type& rhs);
self_type& operator=(const self_type& rhs);

Expand Down Expand Up @@ -315,6 +318,19 @@ namespace xt
{
init_tensor(shape, strides);
}

/**
 * Allocates and returns a pytensor with the specified shape.
 * @param shape the shape of the pytensor
 * @throws std::runtime_error if the rank of `shape` does not match the
 *         static rank N of the tensor.
 */
template <class T, std::size_t N, layout_type L>
template <class S>
inline pytensor<T, N, L> pytensor<T, N, L>::from_shape(S&& shape)
{
    // Reject shapes whose rank differs from the fixed rank N before
    // constructing anything.
    detail::check_dims<shape_type>::run(shape.size());
    // Preserve the caller's value category: an rvalue shape sequence is
    // moved by xtl::forward_sequence instead of being copied.
    return self_type(xtl::forward_sequence<shape_type>(std::forward<S>(shape)));
}
//@}

/**
Expand Down Expand Up @@ -429,6 +445,12 @@ namespace xt
std::transform(PyArray_STRIDES(this->python_array()), PyArray_STRIDES(this->python_array()) + N, m_strides.begin(),
[](auto v) { return v / sizeof(T); });
adapt_strides(m_shape, m_strides, m_backstrides);

if (L != layout_type::dynamic && !do_strides_match(m_shape, m_strides, L))
{
throw std::runtime_error("NumPy: passing container with bad strides for layout (is it a view?).");
}

m_storage = storage_type(reinterpret_cast<pointer>(PyArray_DATA(this->python_array())),
this->get_min_stride() * static_cast<size_type>(PyArray_SIZE(this->python_array())));
}
Expand Down
9 changes: 9 additions & 0 deletions test/test_pyarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ namespace xt
}
}

TEST(pyarray, from_shape)
{
    // Build an array from an initializer-list shape and verify rank,
    // element count and the shape values themselves.
    auto a = pyarray<double>::from_shape({5, 2, 6});
    std::vector<std::size_t> expected{5, 2, 6};
    EXPECT_EQ(a.shape().size(), 3);
    EXPECT_EQ(a.size(), 5 * 2 * 6);
    EXPECT_TRUE(std::equal(a.shape().begin(), a.shape().end(), expected.begin()));
}

TEST(pyarray, strided_constructor)
{
central_major_result<> cmr;
Expand Down
18 changes: 17 additions & 1 deletion test/test_pytensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ namespace xt
}
}

TEST(pytensor, from_shape)
{
    // A matching rank-3 shape builds a tensor with the requested
    // dimensions and size.
    auto t = pytensor<double, 3>::from_shape({5, 2, 6});
    std::vector<std::size_t> expected{5, 2, 6};
    EXPECT_EQ(t.shape().size(), 3);
    EXPECT_EQ(t.size(), 5 * 2 * 6);
    EXPECT_TRUE(std::equal(t.shape().begin(), t.shape().end(), expected.begin()));
    // A rank-2 shape must be rejected by a rank-3 tensor.
    std::vector<std::size_t> bad_shape{5, 2};
    EXPECT_THROW(pytensor<double, 3>::from_shape(bad_shape), std::runtime_error);
}

TEST(pytensor, strided_constructor)
{
central_major_result<container_type> cmr;
Expand Down Expand Up @@ -211,8 +223,12 @@ namespace xt
{
pytensor<int, 2> a = {{1,2,3}, {4,5,6}};
auto ptr = a.data();
a.reshape(a.shape()); // compilation check
a.reshape({1, 6});
EXPECT_EQ(ptr, a.data());
EXPECT_THROW(a.reshape({6}), std::runtime_error);
EXPECT_THROW(a.reshape(std::vector<std::size_t>{6}), std::runtime_error);
// note this throws because std array has only 1 element initialized
// and the second element is `0`.
EXPECT_THROW(a.reshape({6, 5}), std::runtime_error);
}
}
20 changes: 20 additions & 0 deletions test_python/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "xtensor/xarray.hpp"
#define FORCE_IMPORT_ARRAY
#include "xtensor-python/pyarray.hpp"
#include "xtensor-python/pytensor.hpp"
#include "xtensor-python/pyvectorize.hpp"

namespace py = pybind11;
Expand Down Expand Up @@ -154,6 +155,22 @@ void char_array(xt::pyarray<char[20]>& carr)
carr(2)[3] = '\0';
}

// Verifies that a row-major pytensor argument exposes a raw double*
// iterator from begin(); throws std::runtime_error otherwise.
void row_major_tensor(xt::pytensor<double, 3, xt::layout_type::row_major>& arg)
{
    using iterator_type = decltype(arg.begin());
    constexpr bool is_raw_pointer = std::is_same<iterator_type, double*>::value;
    if (!is_raw_pointer)
    {
        throw std::runtime_error("TEST FAILED");
    }
}

// Verifies that a column-major pyarray argument exposes a raw double*
// iterator when iterated in column-major order; throws std::runtime_error
// otherwise.
void col_major_array(xt::pyarray<double, xt::layout_type::column_major>& arg)
{
    using iterator_type = decltype(arg.template begin<xt::layout_type::column_major>());
    constexpr bool is_raw_pointer = std::is_same<iterator_type, double*>::value;
    if (!is_raw_pointer)
    {
        throw std::runtime_error("TEST FAILED");
    }
}

PYBIND11_MODULE(xtensor_python_test, m)
{
xt::import_numpy();
Expand Down Expand Up @@ -197,4 +214,7 @@ PYBIND11_MODULE(xtensor_python_test, m)
m.def("dtype_to_python", dtype_to_python);
m.def("dtype_from_python", dtype_from_python);
m.def("char_array", char_array);

m.def("col_major_array", col_major_array);
m.def("row_major_tensor", row_major_tensor);
}
1 change: 1 addition & 0 deletions test_python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __str__(self):
['main.cpp'],
include_dirs=[
# Path to pybind11 headers
'../include/',
get_pybind_include(),
get_pybind_include(user=True),
# Path to numpy headers
Expand Down
21 changes: 21 additions & 0 deletions test_python/test_pyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,24 @@ def test_char_array(self):
self.assertEqual(var[0], b'hello')
self.assertEqual(var[1], b'from')
self.assertEqual(var[2], b'c++')

def test_col_row_major(self):
    """C-ordered input passes row_major_tensor and fails col_major_array;
    transposed, strided and wrong-rank views are rejected; F-ordered input
    passes col_major_array."""
    c_ordered = np.arange(50, dtype=float).reshape(2, 5, 5)

    # a C-contiguous array is not column-major
    with self.assertRaises(RuntimeError):
        xt.col_major_array(c_ordered)

    # transposed and strided views are not row-major contiguous
    for bad_view in (c_ordered.T, c_ordered[:, ::2, ::2]):
        with self.assertRaises(RuntimeError):
            xt.row_major_tensor(bad_view)

    # raise for wrong dimension
    with self.assertRaises(RuntimeError):
        xt.row_major_tensor(c_ordered[0, 0, :])

    xt.row_major_tensor(c_ordered)

    f_ordered = np.arange(50, dtype=float).reshape(2, 5, 5, order='F')
    xt.col_major_array(f_ordered)
    xt.col_major_array(f_ordered[:, :, 0])  # still col major!