Skip to content

Commit

Permalink
Refine conditions under which ndarray_wrap copies
Browse files Browse the repository at this point in the history
The ndarray caster uses return value policies to decide when to copy its
input, which can lead to bogus results due to a cascade of issues (see
issue #188 for a discussion).

The preceding commits address many of the issues regarding return
value policies. This commit adds one more refinement: it adds an
interpretation of several standard return value policies:

- ``nb::rv_policy::automatic`` copies when an nd-array has no owner and
  is not already associated with a Python object.

- ``nb::rv_policy::automatic_reference`` and
  ``nb::rv_policy::reference`` never copy.

- ``nb::rv_policy::copy`` always copies.

- ``nb::rv_policy::none`` refuses the cast unless the array is already
  associated with an existing Python object (e.g. a NumPy array), in
  which case that object is returned.

- ``nb::rv_policy::reference_internal`` retroactively sets the ndarray's ``owner``
  field to the method's ``self`` argument. It fails with an error
  if there is already a different owner.

- ``nb::rv_policy::move`` is not supported.
  • Loading branch information
wjakob committed Jun 13, 2023
1 parent de11176 commit a795751
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 42 deletions.
26 changes: 26 additions & 0 deletions docs/ndarray.rst
Expand Up @@ -329,6 +329,32 @@ when all of them have expired:
);
});
Return value policies
---------------------

Function bindings that return ndarrays admit additional return value policy
annotations to determine whether or not a copy should be made. They are
interpreted as follows:

- :cpp:enumerator:`rv_policy::automatic` causes the array to be copied when it
has no owner and when it is not already associated with a Python object.

- :cpp:enumerator:`rv_policy::automatic_reference` and
:cpp:enumerator:`rv_policy::reference` never copy.

- :cpp:enumerator:`rv_policy::copy` always copies.

- :cpp:enumerator:`rv_policy::none` refuses the cast unless the array is
already associated with an existing Python object (e.g. a NumPy array), in
which case that object is returned.

- :cpp:enumerator:`rv_policy::reference_internal` retroactively sets the
ndarray's ``owner`` field to a method's ``self`` argument. It fails with an
error if there is already a different owner.

- :cpp:enumerator:`rv_policy::move` is unsupported.

Limitations
-----------

Expand Down
2 changes: 1 addition & 1 deletion include/nanobind/nb_lib.h
Expand Up @@ -411,7 +411,7 @@ NB_CORE void ndarray_dec_ref(ndarray_handle *) noexcept;

/// Wrap a ndarray_handle* into a PyCapsule
NB_CORE PyObject *ndarray_wrap(ndarray_handle *, int framework,
rv_policy policy) noexcept;
rv_policy policy, cleanup_list *cleanup) noexcept;

/// Check if an object is a known ndarray type (NumPy, PyTorch, Tensorflow, JAX)
NB_CORE bool ndarray_check(PyObject *o) noexcept;
Expand Down
4 changes: 2 additions & 2 deletions include/nanobind/ndarray.h
Expand Up @@ -423,8 +423,8 @@ template <typename... Args> struct type_caster<ndarray<Args...>> {
}

static handle from_cpp(const ndarray<Args...> &tensor, rv_policy policy,
cleanup_list *) noexcept {
return ndarray_wrap(tensor.handle(), int(Value::Info::framework), policy);
cleanup_list *cleanup) noexcept {
return ndarray_wrap(tensor.handle(), int(Value::Info::framework), policy, cleanup);
}
};

Expand Down
59 changes: 51 additions & 8 deletions src/nb_ndarray.cpp
Expand Up @@ -607,15 +607,53 @@ static void ndarray_capsule_destructor(PyObject *o) {
}

PyObject *ndarray_wrap(ndarray_handle *th, int framework,
rv_policy policy) noexcept {
rv_policy policy, cleanup_list *cleanup) noexcept {
if (!th)
return none().release().ptr();

bool copy = policy == rv_policy::copy || policy == rv_policy::move;
bool copy;
switch (policy) {
case rv_policy::reference_internal:
if (cleanup->self() != th->owner) {

This comment has been minimized.

Copy link
@WKarel

WKarel Jun 14, 2023

Contributor

cleanup may be nullptr.

This comment has been minimized.

Copy link
@wjakob

wjakob Jun 14, 2023

Author Owner

Good catch, thanks!

if (th->owner) {
PyErr_SetString(PyExc_RuntimeError,
"nanobind::detail::ndarray_wrap(): "
"reference_internal policy cannot be "
"applied (ndarray already has an owner)");
return nullptr;
} else {
th->owner = cleanup->self();
Py_INCREF(th->owner);
}
}
[[fallthrough]];

case rv_policy::automatic:
copy = th->owner == nullptr && th->self == nullptr;
break;

case rv_policy::copy:
copy = true;
break;

case rv_policy::move:
PyErr_SetString(PyExc_RuntimeError,

This comment has been minimized.

Copy link
@WKarel

WKarel Jun 14, 2023

Contributor

"move" could gracefully demote to "copy" here.

This comment has been minimized.

Copy link
@wjakob

wjakob Jun 14, 2023

Author Owner

good idea.

"nanobind::detail::ndarray_wrap(): rv_policy::move "
"is not supported!");
return nullptr;

if (th->self && !copy) {
Py_INCREF(th->self);
return th->self;
default:
copy = false;
break;
}

if (!copy) {
if (th->self) {
Py_INCREF(th->self);
return th->self;
} else if (policy == rv_policy::none) {
return nullptr;
}
}

if ((ndarray_framework) framework == ndarray_framework::numpy) {
Expand Down Expand Up @@ -670,10 +708,15 @@ PyObject *ndarray_wrap(ndarray_handle *th, int framework,
return nullptr;
}

object o = steal(PyCapsule_New(th->ndarray, "dltensor",
ndarray_capsule_destructor));
object o;
if (copy && (ndarray_framework) framework == ndarray_framework::none && th->self) {
o = borrow(th->self);
} else {
o = steal(PyCapsule_New(th->ndarray, "dltensor",
ndarray_capsule_destructor));
ndarray_inc_ref(th);
}

ndarray_inc_ref(th);

if (package.is_valid()) {
try {
Expand Down
39 changes: 33 additions & 6 deletions tests/test_ndarray.cpp
Expand Up @@ -8,7 +8,7 @@ namespace nb = nanobind;
using namespace nb::literals;

int destruct_count = 0;
static const float f_const[] { 1, 2, 3, 4, 5, 6, 7, 8 };
static float f_global[] { 1, 2, 3, 4, 5, 6, 7, 8 };

NB_MODULE(test_ndarray_ext, m) {
m.def("get_shape", [](const nb::ndarray<nb::ro> &t) {
Expand Down Expand Up @@ -134,7 +134,9 @@ NB_MODULE(test_ndarray_ext, m) {

return nb::ndarray<float, nb::shape<2, 4>>(f, 2, shape, deleter);
});
m.def("passthrough", [](nb::ndarray<> a) { return a; });

m.def("passthrough", [](nb::ndarray<> a) { return a; }, nb::rv_policy::none);
m.def("passthrough_copy", [](nb::ndarray<> a) { return a; }, nb::rv_policy::copy);

m.def("ret_numpy", []() {
float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
Expand All @@ -146,12 +148,16 @@ NB_MODULE(test_ndarray_ext, m) {
});

return nb::ndarray<nb::numpy, float, nb::shape<2, 4>>(f, 2, shape,
deleter);
deleter);
});

m.def("ret_numpy_const", []() {
m.def("ret_numpy_const_ref", []() {
size_t shape[2] = { 2, 4 };
return nb::ndarray<nb::numpy, const float, nb::shape<2, 4>>(f_const, 2, shape);
return nb::ndarray<nb::numpy, const float, nb::shape<2, 4>>(f_global, 2, shape);
}, nb::rv_policy::reference);

m.def("ret_numpy_const", []() {
return nb::ndarray<nb::numpy, const float, nb::shape<2, 4>>(f_global, { 2, 4 });
});

m.def("ret_pytorch", []() {
Expand All @@ -164,7 +170,7 @@ NB_MODULE(test_ndarray_ext, m) {
});

return nb::ndarray<nb::pytorch, float, nb::shape<2, 4>>(f, 2, shape,
deleter);
deleter);
});

m.def("ret_array_scalar", []() {
Expand Down Expand Up @@ -192,4 +198,25 @@ NB_MODULE(test_ndarray_ext, m) {
m.def("accept_ro", [](nb::ndarray<const float, nb::shape<2>> a) { return a(0); });

m.def("check", [](nb::handle h) { return nb::ndarray_check(h); });


// Test fixture for the ndarray return value policy tests: it exposes its
// member array `data` as 1-D NumPy-framework ndarrays of shape {10} with
// different ownership setups, and records its own destruction.
struct Cls {
// Returns a view of `data` without passing any owner to the ndarray.
auto f1() { return nb::ndarray<nb::numpy, float>(data, { 10 }); }
// Returns a view of `data` whose owner is the result of casting `this` with
// rv_policy::none (presumably the already-existing Python wrapper of this
// instance -- TODO confirm against nb::cast semantics).
auto f2() { return nb::ndarray<nb::numpy, float>(data, { 10 }, nb::cast(this, nb::rv_policy::none)); }
// Returns a view of `data` whose owner is the handle supplied by the caller.
auto f3(nb::handle owner) { return nb::ndarray<nb::numpy, float>(data, { 10 }, owner); }

// Bumps the file-scope `destruct_count` so tests can observe when an
// instance is actually destroyed.
~Cls() {
destruct_count++;
}

// Backing storage handed out (by pointer) by f1/f2/f3; initialized to 0..9.
float data [10] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
};

nb::class_<Cls>(m, "Cls")
.def(nb::init<>())
.def("f1", &Cls::f1)
.def("f2", &Cls::f2)
.def("f1_ri", &Cls::f1, nb::rv_policy::reference_internal)
.def("f2_ri", &Cls::f2, nb::rv_policy::reference_internal)
.def("f3_ri", &Cls::f3, nb::rv_policy::reference_internal);
}
131 changes: 106 additions & 25 deletions tests/test_ndarray.py
Expand Up @@ -301,31 +301,9 @@ def __dlpack__(self):

@needs_numpy
def test15_passthrough():
collect()
class wrapper:
def __init__(self, value):
self.value = value
def __dlpack__(self):
return self.value
dc = t.destruct_count()
a = t.return_dlpack()
a = t.ret_numpy()
b = t.passthrough(a)
if hasattr(np, '_from_dlpack'):
y = np._from_dlpack(wrapper(b))
elif hasattr(np, 'from_dlpack'):
y = np.from_dlpack(wrapper(b))
else:
pytest.skip('your version of numpy is too old')

del a
del b
collect()
assert dc == t.destruct_count()
assert y.shape == (2, 4)
assert np.all(y == [[1, 2, 3, 4], [5, 6, 7, 8]])
del y
collect()
assert t.destruct_count() - dc == 1
assert a is b

a = np.array([1,2,3])
b = t.passthrough(a)
Expand Down Expand Up @@ -436,7 +414,7 @@ def test22_ro_array():

@needs_numpy
def test22_return_ro():
x = t.ret_numpy_const()
x = t.ret_numpy_const_ref()
assert t.ret_numpy_const.__doc__ == 'ret_numpy_const() -> numpy.ndarray[dtype=float32, writable=False, shape=(2, 4)]'
assert x.shape == (2, 4)
assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]])
Expand All @@ -459,3 +437,106 @@ def test25_check_tensorflow():
@needs_jax
def test26_check_jax():
assert t.check(jnp.zeros((1)))

# Checks how the return value policy of a binding decides whether the returned
# NumPy array aliases existing storage or is a fresh copy.
@needs_numpy
def test27_rv_policy():
# Helper: the 'data' entry of the array interface, used here to compare the
# underlying buffers of two arrays.
def p(a):
return a.__array_interface__['data']

# ret_numpy_const_ref is bound with rv_policy::reference, ret_numpy_const
# without an explicit policy; both wrap the same global buffer on the C++ side.
x1 = t.ret_numpy_const_ref()
x2 = t.ret_numpy_const_ref()
y1 = t.ret_numpy_const()
y2 = t.ret_numpy_const()

# passthrough is bound with rv_policy::none, passthrough_copy with
# rv_policy::copy.
z1 = t.passthrough(y1)
z2 = t.passthrough(y2)
q1 = t.passthrough_copy(y1)
q2 = t.passthrough_copy(y2)

# 'reference' results share one buffer; results without an explicit policy
# each get their own buffer.
assert p(x1) == p(x2)
assert p(y1) != p(y2)

# 'none' hands back the very same Python object; 'copy' yields a distinct
# object backed by distinct data.
assert z1 is y1
assert z2 is y2
assert q1 is not y1
assert q2 is not y2
assert p(q1) != p(y1)
assert p(q2) != p(y2)

# Exercises ndarray ownership for arrays that view a member of a bound C++
# object (t.Cls): copies vs. shared views, object lifetime extension, and the
# error raised when reference_internal conflicts with an existing owner.
@needs_numpy
def test28_reference_internal():
collect()
# Baseline destructor count; later checks compare against this value.
dc = t.destruct_count()
c = t.Cls()

# f1: no owner, bound without an explicit policy.
# f2: the ndarray is constructed with an owner on the C++ side.
v1_a = c.f1()
v1_b = c.f1()
v2_a = c.f2()
v2_b = c.f2()
del c

assert np.all(v1_a == np.arange(10, dtype=np.float32))
assert np.all(v1_b == np.arange(10, dtype=np.float32))

# The f1 results are independent: each in-place update is visible only in the
# array it was applied to.
v1_a += 1
v1_b += 2

assert np.all(v1_a == np.arange(10, dtype=np.float32) + 1)
assert np.all(v1_b == np.arange(10, dtype=np.float32) + 2)
del v1_a
del v1_b

assert np.all(v2_a == np.arange(10, dtype=np.float32))
assert np.all(v2_b == np.arange(10, dtype=np.float32))

# The f2 results alias the same storage: both updates show up in both arrays.
v2_a += 1
v2_b += 2

assert np.all(v2_a == np.arange(10, dtype=np.float32) + 3)
assert np.all(v2_b == np.arange(10, dtype=np.float32) + 3)

# The Cls instance must stay alive while at least one f2 view exists ...
del v2_a
collect()
assert t.destruct_count() == dc

# ... and is destroyed once the last view is gone.
del v2_b
collect()
dc += 1
assert t.destruct_count() == dc

# Repeat the shared-view and lifetime checks for the bindings that use
# rv_policy::reference_internal: f1_ri (no pre-existing owner) on the first
# pass, f2_ri (owner already set on the C++ side) on the second.
for i in range(2):
c2 = t.Cls()

if i == 0:
v3_a = c2.f1_ri()
v3_b = c2.f1_ri()
else:
v3_a = c2.f2_ri()
v3_b = c2.f2_ri()
del c2

assert np.all(v3_a == np.arange(10, dtype=np.float32))
assert np.all(v3_b == np.arange(10, dtype=np.float32))

# Both results alias the same storage, as in the f2 case.
v3_a += 1
v3_b += 2

assert np.all(v3_a == np.arange(10, dtype=np.float32) + 3)
assert np.all(v3_b == np.arange(10, dtype=np.float32) + 3)
del v3_a

# Still referenced by v3_b, so no destruction yet.
collect()
assert t.destruct_count() == dc

del v3_b
collect()
dc += 1
assert t.destruct_count() == dc

# f3_ri passes an explicit owner (a tuple, i.e. not the instance itself), so
# applying reference_internal on top of it must be rejected.
c3 = t.Cls()
c3_t = (c3,)
with pytest.raises(RuntimeError) as excinfo:
c3.f3_ri(c3_t)

msg = 'nanobind::detail::ndarray_wrap(): reference_internal policy cannot be applied (ndarray already has an owner)'
assert msg in str(excinfo.value)

0 comments on commit a795751

Please sign in to comment.