diff --git a/include/xtensor-python/pytensor.hpp b/include/xtensor-python/pytensor.hpp index 0fab55a..ce3d115 100644 --- a/include/xtensor-python/pytensor.hpp +++ b/include/xtensor-python/pytensor.hpp @@ -100,11 +100,26 @@ namespace pybind11 } }; - } + } // namespace detail } namespace xt { + namespace detail { + + template + struct numpy_strides + { + npy_intp value[N]; + }; + + template <> + struct numpy_strides<0> + { + npy_intp* value = nullptr; + }; + + } // namespace detail template struct xiterable_inner_types> @@ -433,8 +448,8 @@ namespace xt template inline void pytensor::init_tensor(const shape_type& shape, const strides_type& strides) { - npy_intp python_strides[N]; - std::transform(strides.begin(), strides.end(), python_strides, + detail::numpy_strides python_strides; + std::transform(strides.begin(), strides.end(), python_strides.value, [](auto v) { return sizeof(T) * v; }); int flags = NPY_ARRAY_ALIGNED; if (!std::is_const::value) @@ -445,7 +460,7 @@ namespace xt auto tmp = pybind11::reinterpret_steal( PyArray_NewFromDescr(&PyArray_Type, (PyArray_Descr*) dtype.release().ptr(), static_cast(shape.size()), - const_cast(shape.data()), python_strides, + const_cast(shape.data()), python_strides.value, nullptr, flags, nullptr)); if (!tmp) diff --git a/test/test_pytensor.cpp b/test/test_pytensor.cpp index 637d58f..f2ac013 100644 --- a/test/test_pytensor.cpp +++ b/test/test_pytensor.cpp @@ -65,6 +65,15 @@ namespace xt EXPECT_THROW(pyt3::from_shape(shp), std::runtime_error); } + TEST(pytensor, scalar_from_shape) + { + std::array shape; + auto a = pytensor::from_shape(shape); + pytensor b(1.2); + EXPECT_TRUE(a.size() == b.size()); + EXPECT_TRUE(xt::has_shape(a, b.shape())); + } + TEST(pytensor, strided_constructor) { central_major_result cmr; diff --git a/test_python/main.cpp b/test_python/main.cpp index b4b6cc6..66f4e3c 100644 --- a/test_python/main.cpp +++ b/test_python/main.cpp @@ -227,6 +227,11 @@ void col_major_array(xt::pyarray& arg) } } +xt::pytensor xscalar(const xt::pytensor& arg) +{ + return xt::sum(arg); +} + template using ndarray = xt::pyarray; @@ -285,6 +290,8 @@ PYBIND11_MODULE(xtensor_python_test, m) m.def("col_major_array", col_major_array); m.def("row_major_tensor", row_major_tensor); + m.def("xscalar", xscalar); + py::class_(m, "C") .def(py::init<>()) .def_property_readonly( diff --git a/test_python/test_pyarray.py b/test_python/test_pyarray.py index c3c1447..71e9dd3 100644 --- a/test_python/test_pyarray.py +++ b/test_python/test_pyarray.py @@ -151,6 +151,10 @@ def test_col_row_major(self): xt.col_major_array(varF) xt.col_major_array(varF[:, :, 0]) # still col major! + def test_xscalar(self): + var = np.arange(50, dtype=int) + self.assertTrue(np.sum(var) == xt.xscalar(var)) + def test_bad_argument_call(self): with self.assertRaises(TypeError): xt.simple_array("foo")