From 6f0c3feaf088e78c75f2abee90164f20446eba08 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Wed, 14 Jun 2023 15:48:18 +0200 Subject: [PATCH] incorporate feedback by @WKarel --- docs/ndarray.rst | 3 ++- include/nanobind/ndarray.h | 5 +++++ src/nb_ndarray.cpp | 9 ++------- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/docs/ndarray.rst b/docs/ndarray.rst index 2c56df0e..e9468e75 100644 --- a/docs/ndarray.rst +++ b/docs/ndarray.rst @@ -353,7 +353,8 @@ interpreted as follows: ndarray's ``owner`` field to a method's ``self`` argument. It fails with an error if there is already a different owner. -- :cpp:enumerator:`rv_policy::move` is unsupported. +- :cpp:enumerator:`rv_policy::move` is unsupported and demoted to + :cpp:enumerator:`rv_policy::copy`. Limitations ----------- diff --git a/include/nanobind/ndarray.h b/include/nanobind/ndarray.h index 014e040b..674fa2dc 100644 --- a/include/nanobind/ndarray.h +++ b/include/nanobind/ndarray.h @@ -298,10 +298,15 @@ template class ndarray { dlpack::dtype dtype = nanobind::dtype(), int32_t device_type = device::cpu::value, int32_t device_id = 0) { + + if (strides.size() != 0 && strides.size() != shape.size()) + detail::fail("ndarray(): shape and strides have incompatible size!"); + m_handle = detail::ndarray_create( (void *) value, shape.size(), shape.begin(), owner.ptr(), (strides.size() == 0) ? nullptr : strides.begin(), &dtype, std::is_const_v, device_type, device_id); + m_dltensor = *detail::ndarray_inc_ref(m_handle); } diff --git a/src/nb_ndarray.cpp b/src/nb_ndarray.cpp index f0e78e8f..9e36c68a 100644 --- a/src/nb_ndarray.cpp +++ b/src/nb_ndarray.cpp @@ -614,7 +614,7 @@ PyObject *ndarray_wrap(ndarray_handle *th, int framework, bool copy; switch (policy) { case rv_policy::reference_internal: - if (cleanup->self() != th->owner) { + if (cleanup && cleanup->self() != th->owner) { if (th->owner) { PyErr_SetString(PyExc_RuntimeError, "nanobind::detail::ndarray_wrap(): " @@ -633,15 +633,10 @@ PyObject *ndarray_wrap(ndarray_handle *th, int framework, break; case rv_policy::copy: + case rv_policy::move: copy = true; break; - case rv_policy::move: - PyErr_SetString(PyExc_RuntimeError, - "nanobind::detail::ndarray_wrap(): rv_policy::move " - "is not supported!"); - return nullptr; - default: copy = false; break;