From a79575165134c72c0a26e46772290d0404eae7a3 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Wed, 14 Jun 2023 00:18:28 +0200 Subject: [PATCH] Refine conditions under which ``ndarray_wrap`` copies The ndarray caster uses return value policies to decide when to copy its input, which can lead to bogus results due to a cascade of issues (see issue #188 for a discussion). The preceding commits address many of the issues regarding return value policies. This commit adds one more refinement: it adds an interpretation of several standard return value policies: - ``nb::rv_policy::automatic`` copies when an nd-array has no owner and is not already associated with a Python object. - ``nb::rv_policy::automatic_reference`` and ``nb::rv_policy::reference`` never copy. - ``nb::rv_policy::copy`` always copies. - ``nb::rv_policy::none`` refuses the cast unless the array is already associated with an existing Python object (e.g. a NumPy array), in which case that object is returned. - ``nb::rv_policy::reference_internal`` retroactively sets the ndarray's ``owner`` field to the method's ``self`` argument. It fails with an error if there is already an owner. - ``nb::rv_policy::move`` is not supported. --- docs/ndarray.rst | 26 ++++++++ include/nanobind/nb_lib.h | 2 +- include/nanobind/ndarray.h | 4 +- src/nb_ndarray.cpp | 59 ++++++++++++++--- tests/test_ndarray.cpp | 39 +++++++++-- tests/test_ndarray.py | 131 ++++++++++++++++++++++++++++++------- 6 files changed, 219 insertions(+), 42 deletions(-) diff --git a/docs/ndarray.rst b/docs/ndarray.rst index 110cdc6a..2c56df0e 100644 --- a/docs/ndarray.rst +++ b/docs/ndarray.rst @@ -329,6 +329,32 @@ when all of them have expired: ); }); +Return value policies +--------------------- + +Function bindings that return ndarrays admit additional return value policy +annotations to determine whether or not a copy should be made. 
They are +interpreted as follows: + +- :cpp:enumerator:`rv_policy::automatic` causes the array to be copied when it + has no owner and when it is not already associated with a Python object. + +- :cpp:enumerator:`rv_policy::automatic_reference` and + :cpp:enumerator:`rv_policy::reference` + never copy. + +- :cpp:enumerator:`rv_policy::copy` always copies. + +- :cpp:enumerator:`rv_policy::none` refuses the cast unless the array is + already associated with an existing Python object (e.g. a NumPy array), in + which case that object is returned. + +- :cpp:enumerator:`rv_policy::reference_internal` retroactively sets the + ndarray's ``owner`` field to a method's ``self`` argument. It fails with an + error if there is already a different owner. + +- :cpp:enumerator:`rv_policy::move` is unsupported. + Limitations ----------- diff --git a/include/nanobind/nb_lib.h b/include/nanobind/nb_lib.h index 734198d8..9d51040f 100644 --- a/include/nanobind/nb_lib.h +++ b/include/nanobind/nb_lib.h @@ -411,7 +411,7 @@ NB_CORE void ndarray_dec_ref(ndarray_handle *) noexcept; /// Wrap a ndarray_handle* into a PyCapsule NB_CORE PyObject *ndarray_wrap(ndarray_handle *, int framework, - rv_policy policy) noexcept; + rv_policy policy, cleanup_list *cleanup) noexcept; /// Check if an object is a known ndarray type (NumPy, PyTorch, Tensorflow, JAX) NB_CORE bool ndarray_check(PyObject *o) noexcept; diff --git a/include/nanobind/ndarray.h b/include/nanobind/ndarray.h index db795ff6..014e040b 100644 --- a/include/nanobind/ndarray.h +++ b/include/nanobind/ndarray.h @@ -423,8 +423,8 @@ template struct type_caster> { } static handle from_cpp(const ndarray &tensor, rv_policy policy, - cleanup_list *) noexcept { - return ndarray_wrap(tensor.handle(), int(Value::Info::framework), policy); + cleanup_list *cleanup) noexcept { + return ndarray_wrap(tensor.handle(), int(Value::Info::framework), policy, cleanup); } }; diff --git a/src/nb_ndarray.cpp 
b/src/nb_ndarray.cpp index ee18fbaa..f0e78e8f 100644 --- a/src/nb_ndarray.cpp +++ b/src/nb_ndarray.cpp @@ -607,15 +607,53 @@ static void ndarray_capsule_destructor(PyObject *o) { } PyObject *ndarray_wrap(ndarray_handle *th, int framework, - rv_policy policy) noexcept { + rv_policy policy, cleanup_list *cleanup) noexcept { if (!th) return none().release().ptr(); - bool copy = policy == rv_policy::copy || policy == rv_policy::move; + bool copy; + switch (policy) { + case rv_policy::reference_internal: + if (cleanup->self() != th->owner) { + if (th->owner) { + PyErr_SetString(PyExc_RuntimeError, + "nanobind::detail::ndarray_wrap(): " + "reference_internal policy cannot be " + "applied (ndarray already has an owner)"); + return nullptr; + } else { + th->owner = cleanup->self(); + Py_INCREF(th->owner); + } + } + [[fallthrough]]; + + case rv_policy::automatic: + copy = th->owner == nullptr && th->self == nullptr; + break; + + case rv_policy::copy: + copy = true; + break; + + case rv_policy::move: + PyErr_SetString(PyExc_RuntimeError, + "nanobind::detail::ndarray_wrap(): rv_policy::move " + "is not supported!"); + return nullptr; - if (th->self && !copy) { - Py_INCREF(th->self); - return th->self; + default: + copy = false; + break; + } + + if (!copy) { + if (th->self) { + Py_INCREF(th->self); + return th->self; + } else if (policy == rv_policy::none) { + return nullptr; + } } if ((ndarray_framework) framework == ndarray_framework::numpy) { @@ -670,10 +708,15 @@ PyObject *ndarray_wrap(ndarray_handle *th, int framework, return nullptr; } - object o = steal(PyCapsule_New(th->ndarray, "dltensor", - ndarray_capsule_destructor)); + object o; + if (copy && (ndarray_framework) framework == ndarray_framework::none && th->self) { + o = borrow(th->self); + } else { + o = steal(PyCapsule_New(th->ndarray, "dltensor", + ndarray_capsule_destructor)); + ndarray_inc_ref(th); + } - ndarray_inc_ref(th); if (package.is_valid()) { try { diff --git a/tests/test_ndarray.cpp 
b/tests/test_ndarray.cpp index 2afad92c..bb0a52ec 100644 --- a/tests/test_ndarray.cpp +++ b/tests/test_ndarray.cpp @@ -8,7 +8,7 @@ namespace nb = nanobind; using namespace nb::literals; int destruct_count = 0; -static const float f_const[] { 1, 2, 3, 4, 5, 6, 7, 8 }; +static float f_global[] { 1, 2, 3, 4, 5, 6, 7, 8 }; NB_MODULE(test_ndarray_ext, m) { m.def("get_shape", [](const nb::ndarray &t) { @@ -134,7 +134,9 @@ NB_MODULE(test_ndarray_ext, m) { return nb::ndarray>(f, 2, shape, deleter); }); - m.def("passthrough", [](nb::ndarray<> a) { return a; }); + + m.def("passthrough", [](nb::ndarray<> a) { return a; }, nb::rv_policy::none); + m.def("passthrough_copy", [](nb::ndarray<> a) { return a; }, nb::rv_policy::copy); m.def("ret_numpy", []() { float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 }; @@ -146,12 +148,16 @@ NB_MODULE(test_ndarray_ext, m) { }); return nb::ndarray>(f, 2, shape, - deleter); + deleter); }); - m.def("ret_numpy_const", []() { + m.def("ret_numpy_const_ref", []() { size_t shape[2] = { 2, 4 }; - return nb::ndarray>(f_const, 2, shape); + return nb::ndarray>(f_global, 2, shape); + }, nb::rv_policy::reference); + + m.def("ret_numpy_const", []() { + return nb::ndarray>(f_global, { 2, 4 }); }); m.def("ret_pytorch", []() { @@ -164,7 +170,7 @@ NB_MODULE(test_ndarray_ext, m) { }); return nb::ndarray>(f, 2, shape, - deleter); + deleter); }); m.def("ret_array_scalar", []() { @@ -192,4 +198,25 @@ NB_MODULE(test_ndarray_ext, m) { m.def("accept_ro", [](nb::ndarray> a) { return a(0); }); m.def("check", [](nb::handle h) { return nb::ndarray_check(h); }); + + + struct Cls { + auto f1() { return nb::ndarray(data, { 10 }); } + auto f2() { return nb::ndarray(data, { 10 }, nb::cast(this, nb::rv_policy::none)); } + auto f3(nb::handle owner) { return nb::ndarray(data, { 10 }, owner); } + + ~Cls() { + destruct_count++; + } + + float data [10] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + }; + + nb::class_(m, "Cls") + .def(nb::init<>()) + .def("f1", &Cls::f1) + .def("f2", &Cls::f2) + 
.def("f1_ri", &Cls::f1, nb::rv_policy::reference_internal) + .def("f2_ri", &Cls::f2, nb::rv_policy::reference_internal) + .def("f3_ri", &Cls::f3, nb::rv_policy::reference_internal); } diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 781d0bc5..29973a56 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -301,31 +301,9 @@ def __dlpack__(self): @needs_numpy def test15_passthrough(): - collect() - class wrapper: - def __init__(self, value): - self.value = value - def __dlpack__(self): - return self.value - dc = t.destruct_count() - a = t.return_dlpack() + a = t.ret_numpy() b = t.passthrough(a) - if hasattr(np, '_from_dlpack'): - y = np._from_dlpack(wrapper(b)) - elif hasattr(np, 'from_dlpack'): - y = np.from_dlpack(wrapper(b)) - else: - pytest.skip('your version of numpy is too old') - - del a - del b - collect() - assert dc == t.destruct_count() - assert y.shape == (2, 4) - assert np.all(y == [[1, 2, 3, 4], [5, 6, 7, 8]]) - del y - collect() - assert t.destruct_count() - dc == 1 + assert a is b a = np.array([1,2,3]) b = t.passthrough(a) @@ -436,7 +414,7 @@ def test22_ro_array(): @needs_numpy def test22_return_ro(): - x = t.ret_numpy_const() + x = t.ret_numpy_const_ref() assert t.ret_numpy_const.__doc__ == 'ret_numpy_const() -> numpy.ndarray[dtype=float32, writable=False, shape=(2, 4)]' assert x.shape == (2, 4) assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]]) @@ -459,3 +437,106 @@ def test25_check_tensorflow(): @needs_jax def test26_check_jax(): assert t.check(jnp.zeros((1))) + +@needs_numpy +def test27_rv_policy(): + def p(a): + return a.__array_interface__['data'] + + x1 = t.ret_numpy_const_ref() + x2 = t.ret_numpy_const_ref() + y1 = t.ret_numpy_const() + y2 = t.ret_numpy_const() + + z1 = t.passthrough(y1) + z2 = t.passthrough(y2) + q1 = t.passthrough_copy(y1) + q2 = t.passthrough_copy(y2) + + assert p(x1) == p(x2) + assert p(y1) != p(y2) + + assert z1 is y1 + assert z2 is y2 + assert q1 is not y1 + assert q2 is not y2 + assert p(q1) != 
p(y1) + assert p(q2) != p(y2) + +@needs_numpy +def test28_reference_internal(): + collect() + dc = t.destruct_count() + c = t.Cls() + + v1_a = c.f1() + v1_b = c.f1() + v2_a = c.f2() + v2_b = c.f2() + del c + + assert np.all(v1_a == np.arange(10, dtype=np.float32)) + assert np.all(v1_b == np.arange(10, dtype=np.float32)) + + v1_a += 1 + v1_b += 2 + + assert np.all(v1_a == np.arange(10, dtype=np.float32) + 1) + assert np.all(v1_b == np.arange(10, dtype=np.float32) + 2) + del v1_a + del v1_b + + assert np.all(v2_a == np.arange(10, dtype=np.float32)) + assert np.all(v2_b == np.arange(10, dtype=np.float32)) + + v2_a += 1 + v2_b += 2 + + assert np.all(v2_a == np.arange(10, dtype=np.float32) + 3) + assert np.all(v2_b == np.arange(10, dtype=np.float32) + 3) + + del v2_a + collect() + assert t.destruct_count() == dc + + del v2_b + collect() + dc += 1 + assert t.destruct_count() == dc + + for i in range(2): + c2 = t.Cls() + + if i == 0: + v3_a = c2.f1_ri() + v3_b = c2.f1_ri() + else: + v3_a = c2.f2_ri() + v3_b = c2.f2_ri() + del c2 + + assert np.all(v3_a == np.arange(10, dtype=np.float32)) + assert np.all(v3_b == np.arange(10, dtype=np.float32)) + + v3_a += 1 + v3_b += 2 + + assert np.all(v3_a == np.arange(10, dtype=np.float32) + 3) + assert np.all(v3_b == np.arange(10, dtype=np.float32) + 3) + del v3_a + + collect() + assert t.destruct_count() == dc + + del v3_b + collect() + dc += 1 + assert t.destruct_count() == dc + + c3 = t.Cls() + c3_t = (c3,) + with pytest.raises(RuntimeError) as excinfo: + c3.f3_ri(c3_t) + + msg = 'nanobind::detail::ndarray_wrap(): reference_internal policy cannot be applied (ndarray already has an owner)' + assert msg in str(excinfo.value)