Skip to content

Commit

Permalink
Allow nanobind methods on non-nanobind classes.
Browse files Browse the repository at this point in the history
This allows us to extend an existing Python (or
pybind11!) class with methods defined in nanobind. This is handy if you
want to add a small amount of nanobind code to a class
defined in other ways. For example, one use is to migrate methods of a
pybind11-defined class to nanobind one by one.

All we have to do to allow this is remove the type check. I
can't see anything that goes wrong if we simply allow dispatch to
proceed.
  • Loading branch information
hawkinsp committed Dec 2, 2022
1 parent 6a25ef3 commit 415ec55
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 43 deletions.
72 changes: 29 additions & 43 deletions src/nb_func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,28 +421,19 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,
if (!nb_type_cache)
nb_type_cache = internals_get().nb_type;

if (self_arg && Py_TYPE((PyObject *) Py_TYPE(self_arg)) != nb_type_cache)
self_arg = nullptr;

if (!self_arg) {
PyErr_SetString(
PyExc_RuntimeError,
"nanobind::detail::nb_func_vectorcall(): the 'self' argument "
"of a method call should be a nanobind class.");
return nullptr;
}

self_flags = nb_type_data(Py_TYPE(self_arg))->flags;
if (self_flags & (uint32_t) type_flags::is_trampoline)
current_method_data = current_method{ fr->name, self_arg };

if (is_constructor) {
if (((nb_inst *) self_arg)->ready) {
PyErr_SetString(
PyExc_RuntimeError,
"nanobind::detail::nb_func_vectorcall(): the __init__ "
"method should not be called on an initialized object!");
return nullptr;
if (self_arg && Py_TYPE((PyObject *) Py_TYPE(self_arg)) == nb_type_cache) {
self_flags = nb_type_data(Py_TYPE(self_arg))->flags;
if (self_flags & (uint32_t) type_flags::is_trampoline)
current_method_data = current_method{ fr->name, self_arg };

if (is_constructor) {
if (((nb_inst *) self_arg)->ready) {
PyErr_SetString(
PyExc_RuntimeError,
"nanobind::detail::nb_func_vectorcall(): the __init__ "
"method should not be called on an initialized object!");
return nullptr;
}
}
}
}
Expand Down Expand Up @@ -687,25 +678,19 @@ static PyObject *nb_func_vectorcall_simple(PyObject *self,
if (NB_UNLIKELY(!nb_type_cache))
nb_type_cache = internals_get().nb_type;

if (NB_UNLIKELY(!self_arg || Py_TYPE((PyObject *) Py_TYPE(self_arg)) != nb_type_cache)) {
PyErr_SetString(
PyExc_RuntimeError,
"nanobind::detail::nb_func_vectorcall_simple(): the 'self' "
"argument of a method call should be a nanobind class.");
return nullptr;
}

self_flags = nb_type_data(Py_TYPE(self_arg))->flags;
if (NB_UNLIKELY(self_flags & (uint32_t) type_flags::is_trampoline))
current_method_data = current_method{ fr->name, self_arg };

if (is_constructor) {
if (NB_UNLIKELY(((nb_inst *) self_arg)->ready)) {
PyErr_SetString(PyExc_RuntimeError,
"nanobind::detail::nb_func_vectorcall_simple():"
" the __init__ method should not be called on "
"an initialized object!");
return nullptr;
if (NB_LIKELY(self_arg && Py_TYPE((PyObject *) Py_TYPE(self_arg)) == nb_type_cache)) {
self_flags = nb_type_data(Py_TYPE(self_arg))->flags;
if (NB_UNLIKELY(self_flags & (uint32_t) type_flags::is_trampoline))
current_method_data = current_method{ fr->name, self_arg };

if (is_constructor) {
if (NB_UNLIKELY(((nb_inst *) self_arg)->ready)) {
PyErr_SetString(PyExc_RuntimeError,
"nanobind::detail::nb_func_vectorcall_simple():"
" the __init__ method should not be called on "
"an initialized object!");
return nullptr;
}
}
}
}
Expand Down Expand Up @@ -807,11 +792,12 @@ static PyObject *nb_bound_method_vectorcall(PyObject *self,
result = mb->func->vectorcall((PyObject *) mb->func, args_tmp, nargs + 1, kwargs_in);
args_tmp[0] = tmp;
} else {
PyObject **args_tmp = (PyObject **) PyObject_Malloc((nargs + 1) * sizeof(PyObject *));
size_t nkwargs_in = kwargs_in ? (size_t) NB_TUPLE_GET_SIZE(kwargs_in) : 0;
PyObject **args_tmp = (PyObject **) PyObject_Malloc((nargs + nkwargs_in + 1) * sizeof(PyObject *));
if (!args_tmp)
return PyErr_NoMemory();
args_tmp[0] = mb->self;
for (size_t i = 0; i < nargs; ++i)
for (size_t i = 0; i < nargs + nkwargs_in; ++i)
args_tmp[i + 1] = args_in[i];
result = mb->func->vectorcall((PyObject *) mb->func, args_tmp, nargs + 1, kwargs_in);
PyObject_Free(args_tmp);
Expand Down
7 changes: 7 additions & 0 deletions tests/test_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,11 @@ NB_MODULE(test_functions_ext, m) {
m.def("identity_u32", [](uint32_t i) { return i; });
m.def("identity_i64", [](int64_t i) { return i; });
m.def("identity_u64", [](uint64_t i) { return i; });

m.attr("test_33") = nb::cpp_function([](nb::object self, int y) {
return nb::cast<int>(self.attr("x")) + y;
}, nb::is_method());
m.attr("test_34") = nb::cpp_function([](nb::object self, int y) {
return nb::cast<int>(self.attr("x")) * y;
}, nb::arg("y"), nb::is_method());
}
10 changes: 10 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,13 @@ def test31_range(func):
else:
value_out = func(value)
assert value_out == value

def test33_method_on_non_nanobind_class():
class AClass:
def __init__(self):
self.x = 42
AClass.simple_method = t.test_33
AClass.complex_method = t.test_34
a = AClass()
assert a.simple_method(7) == 49
assert a.complex_method(y=2) == 84

0 comments on commit 415ec55

Please sign in to comment.