Skip to content

Commit

Permalink
Trivial refactoring to make the capsule API more user friendly. (#4720)
Browse files Browse the repository at this point in the history
* Trivial refactoring to make the capsule API more user friendly.

* Use new API in production code. Thanks @lalaland for pointing this out.
  • Loading branch information
rwgk committed Jun 27, 2023
1 parent e10da79 commit 2fb3d7c
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 21 deletions.
2 changes: 1 addition & 1 deletion include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,8 @@ class cpp_function : public function {
rec->def->ml_flags = METH_VARARGS | METH_KEYWORDS;

capsule rec_capsule(unique_rec.release(),
detail::get_function_record_capsule_name(),
[](void *ptr) { destruct((detail::function_record *) ptr); });
rec_capsule.set_name(detail::get_function_record_capsule_name());
guarded_strdup.release();

object scope_module;
Expand Down
51 changes: 31 additions & 20 deletions include/pybind11/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1925,28 +1925,13 @@ class capsule : public object {
}
}

/// Capsule name is nullptr.
capsule(const void *value, void (*destructor)(void *)) {
m_ptr = PyCapsule_New(const_cast<void *>(value), nullptr, [](PyObject *o) {
// guard if destructor called while err indicator is set
error_scope error_guard;
auto destructor = reinterpret_cast<void (*)(void *)>(PyCapsule_GetContext(o));
if (destructor == nullptr && PyErr_Occurred()) {
throw error_already_set();
}
const char *name = get_name_in_error_scope(o);
void *ptr = PyCapsule_GetPointer(o, name);
if (ptr == nullptr) {
throw error_already_set();
}

if (destructor != nullptr) {
destructor(ptr);
}
});
initialize_with_void_ptr_destructor(value, nullptr, destructor);
}

if (!m_ptr || PyCapsule_SetContext(m_ptr, reinterpret_cast<void *>(destructor)) != 0) {
throw error_already_set();
}
capsule(const void *value, const char *name, void (*destructor)(void *)) {
initialize_with_void_ptr_destructor(value, name, destructor);
}

explicit capsule(void (*destructor)()) {
Expand Down Expand Up @@ -2014,6 +1999,32 @@ class capsule : public object {

return name;
}

void initialize_with_void_ptr_destructor(const void *value,
const char *name,
void (*destructor)(void *)) {
m_ptr = PyCapsule_New(const_cast<void *>(value), name, [](PyObject *o) {
// guard if destructor called while err indicator is set
error_scope error_guard;
auto destructor = reinterpret_cast<void (*)(void *)>(PyCapsule_GetContext(o));
if (destructor == nullptr && PyErr_Occurred()) {
throw error_already_set();
}
const char *name = get_name_in_error_scope(o);
void *ptr = PyCapsule_GetPointer(o, name);
if (ptr == nullptr) {
throw error_already_set();
}

if (destructor != nullptr) {
destructor(ptr);
}
});

if (!m_ptr || PyCapsule_SetContext(m_ptr, reinterpret_cast<void *>(destructor)) != 0) {
throw error_already_set();
}
}
};

class tuple : public object {
Expand Down
9 changes: 9 additions & 0 deletions tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,15 @@ TEST_SUBMODULE(pytypes, m) {
});
});

m.def("return_capsule_with_destructor_3", []() {
py::print("creating capsule");
auto cap = py::capsule((void *) 1233, "oname", [](void *ptr) {
py::print("destructing capsule: {}"_s.format((size_t) ptr));
});
py::print("original name: {}"_s.format(cap.name()));
return cap;
});

m.def("return_renamed_capsule_with_destructor_2", []() {
py::print("creating capsule");
auto cap = py::capsule((void *) 1234, [](void *ptr) {
Expand Down
13 changes: 13 additions & 0 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,19 @@ def test_capsule(capture):
"""
)

with capture:
a = m.return_capsule_with_destructor_3()
del a
pytest.gc_collect()
assert (
capture.unordered
== """
creating capsule
destructing capsule: 1233
original name: oname
"""
)

with capture:
a = m.return_renamed_capsule_with_destructor_2()
del a
Expand Down

0 comments on commit 2fb3d7c

Please sign in to comment.