Skip to content

Commit

Permalink
Make instances weak-referenceable (#335)
Browse files Browse the repository at this point in the history
The logic closely follows that of internal attribute dictionaries and involves an additional pointer to store a weak reference list.
  • Loading branch information
huangweiwu authored and wjakob committed Feb 18, 2024
1 parent 21eaffc commit fc77093
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 26 deletions.
5 changes: 5 additions & 0 deletions docs/api_core.rst
Expand Up @@ -1685,6 +1685,11 @@ parameter of the constructor :cpp:func:`class_::class_`.

Indicate that instances of a type require a Python dictionary to support the dynamic addition of attributes.

.. cpp:struct:: weak_referenceable

Indicate that instances of a type require weak reference list so that they
can be referenced by the Python ``weakref`` type.

.. cpp:struct:: template <typename T> supplement

Indicate that ``sizeof(T)`` bytes of memory should be set aside to
Expand Down
5 changes: 5 additions & 0 deletions docs/changelog.rst
Expand Up @@ -18,6 +18,11 @@ below inherit that of the preceding release.
Version 1.9.0 (TBA)
-------------------

* Nanobind instances can now be :ref:`made weak-referenceable <weak_refs>` by
specifying the :cpp:class:`nb::weak_referenceable <weak_referenceable>` tag
in the :cpp:class:`nb::class_\<..\> <class_>` constructor. (PR `#335
<https://github.com/wjakob/nanobind/pull/335>`__.)

* :cpp:func:`nb::try_cast() <try_cast>` no longer crashes the interpreter
when attempting to cast a Python ``None`` to a C++ type that was bound
using ``nb::class_<>``. Previously this would raise an exception from the
Expand Down
16 changes: 16 additions & 0 deletions docs/classes.rst
Expand Up @@ -365,6 +365,22 @@ default, so this is not anything to worry about. By default, nanobind classes
are more efficient than native Python classes. Enabling dynamic attributes just
brings them on par.

.. _weak_refs:

Weak references
---------------

By default, nanobind instances cannot be referenced via Python's ``weakref``
class, and attempting to do so will raise an exception.

To support this, add the :class:`nb::is_weak_referenceable` tag to the
:class:`nb::class_` constructor. Note that this will increase the size of every
instance by ``sizeof(void*)`` due to the need to store a weak reference list.

.. code-block:: cpp
nb::class_<Pet>(m, "Pet", nb::is_weak_referenceable());
.. _inheriting_in_python:

Extending C++ classes in Python
Expand Down
1 change: 1 addition & 0 deletions include/nanobind/nb_attr.h
Expand Up @@ -48,6 +48,7 @@ template <typename... Ts> struct call_guard {
};

struct dynamic_attr {};
struct weak_referenceable {};
struct is_method {};
struct is_implicit {};
struct is_operator {};
Expand Down
8 changes: 8 additions & 0 deletions include/nanobind/nb_class.h
Expand Up @@ -49,6 +49,9 @@ enum class type_flags : uint32_t {
/// If so, type_data::keep_shared_from_this_alive is also set.
has_shared_from_this = (1 << 12),

/// Instances of this type can be referenced by 'weakref'
is_weak_referenceable = (1 << 13),

// Six more flag bits available (13 through 18) without needing
// a larger reorganization
};
Expand Down Expand Up @@ -98,6 +101,7 @@ struct type_data {
bool (*keep_shared_from_this_alive)(PyObject *) noexcept;
#if defined(Py_LIMITED_API)
size_t dictoffset;
size_t weaklistoffset;
#endif
};

Expand Down Expand Up @@ -152,6 +156,10 @@ NB_INLINE void type_extra_apply(type_init_data &t, dynamic_attr) {
t.flags |= (uint32_t) type_flags::has_dynamic_attr;
}

NB_INLINE void type_extra_apply(type_data & t, weak_referenceable) {
t.flags |= (uint32_t)type_flags::is_weak_referenceable;
}

template <typename T>
NB_INLINE void type_extra_apply(type_init_data &t, supplement<T>) {
static_assert(std::is_trivially_default_constructible_v<T>,
Expand Down
111 changes: 85 additions & 26 deletions src/nb_type.cpp
Expand Up @@ -18,23 +18,35 @@ NAMESPACE_BEGIN(detail)

static PyObject **nb_dict_ptr(PyObject *self) {
PyTypeObject *tp = Py_TYPE(self);
#if !defined(Py_LIMITED_API)
return (PyObject **) ((uint8_t *) self + tp->tp_dictoffset);
#if defined(Py_LIMITED_API)
Py_ssize_t dictoffset = nb_type_data(tp)->dictoffset;
#else
return (PyObject **) ((uint8_t *) self + nb_type_data(tp)->dictoffset);
Py_ssize_t dictoffset = tp->tp_dictoffset;
#endif
return dictoffset ? (PyObject **) ((uint8_t *) self + dictoffset) : nullptr;
}

static PyObject **nb_weaklist_ptr(PyObject *self) {
PyTypeObject *tp = Py_TYPE(self);
#if defined(Py_LIMITED_API)
Py_ssize_t weaklistoffset = nb_type_data(tp)->weaklistoffset;
#else
Py_ssize_t weaklistoffset = tp->tp_weaklistoffset;
#endif
return weaklistoffset ? (PyObject **) ((uint8_t *) self + weaklistoffset) : nullptr;
}

static int inst_clear(PyObject *self) {
PyObject *&dict = *nb_dict_ptr(self);
Py_CLEAR(dict);
PyObject **dict = nb_dict_ptr(self);
if (dict)
Py_CLEAR(*dict);
return 0;
}

static int inst_traverse(PyObject *self, visitproc visit, void *arg) {
PyObject *&dict = *nb_dict_ptr(self);
PyObject **dict = nb_dict_ptr(self);
if (dict)
Py_VISIT(dict);
Py_VISIT(*dict);
#if PY_VERSION_HEX >= 0x03090000
Py_VISIT(Py_TYPE(self));
#endif
Expand Down Expand Up @@ -183,12 +195,24 @@ static void inst_dealloc(PyObject *self) {
if (NB_UNLIKELY(gc)) {
PyObject_GC_UnTrack(self);

if (t->flags & (uint32_t) type_flags::has_dynamic_attr) {
PyObject *&dict = *nb_dict_ptr(self);
Py_CLEAR(dict);
if (t->flags & (uint32_t)type_flags::has_dynamic_attr) {
PyObject **dict = nb_dict_ptr(self);
if (dict)
Py_CLEAR(*dict);
}
}

if (t->flags & (uint32_t)type_flags::is_weak_referenceable &&
nb_weaklist_ptr(self) != nullptr) {
#if defined(PYPY_VERSION)
PyObject **weaklist = nb_weaklist_ptr(self);
if (weaklist)
Py_CLEAR(*weaklist);
#else
PyObject_ClearWeakRefs(self);
#endif
}

nb_inst *inst = (nb_inst *) self;
void *p = inst_ptr(inst);

Expand Down Expand Up @@ -765,14 +789,15 @@ static PyTypeObject *nb_type_tp(size_t supplement) noexcept {

/// Called when a C++ type is bound via nb::class_<>
PyObject *nb_type_new(const type_init_data *t) noexcept {
bool has_doc = t->flags & (uint32_t) type_init_flags::has_doc,
has_base = t->flags & (uint32_t) type_init_flags::has_base,
has_base_py = t->flags & (uint32_t) type_init_flags::has_base_py,
has_type_slots = t->flags & (uint32_t) type_init_flags::has_type_slots,
has_supplement = t->flags & (uint32_t) type_init_flags::has_supplement,
has_dynamic_attr = t->flags & (uint32_t) type_flags::has_dynamic_attr,
intrusive_ptr = t->flags & (uint32_t) type_flags::intrusive_ptr,
has_shared_from_this = t->flags & (uint32_t) type_flags::has_shared_from_this;
bool has_doc = t->flags & (uint32_t) type_init_flags::has_doc,
has_base = t->flags & (uint32_t) type_init_flags::has_base,
has_base_py = t->flags & (uint32_t) type_init_flags::has_base_py,
has_type_slots = t->flags & (uint32_t) type_init_flags::has_type_slots,
has_supplement = t->flags & (uint32_t) type_init_flags::has_supplement,
has_dynamic_attr = t->flags & (uint32_t) type_flags::has_dynamic_attr,
is_weak_referenceable = t->flags & (uint32_t) type_flags::is_weak_referenceable,
intrusive_ptr = t->flags & (uint32_t) type_flags::intrusive_ptr,
has_shared_from_this = t->flags & (uint32_t) type_flags::has_shared_from_this;

str name(t->name), qualname = name;
object modname;
Expand Down Expand Up @@ -834,6 +859,9 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
if (tb->flags & (uint32_t) type_flags::has_dynamic_attr)
has_dynamic_attr = true;

if (tb->flags & (uint32_t) type_flags::is_weak_referenceable)
is_weak_referenceable = true;

/* Handle a corner case (base class larger than derived class)
which can arise when extending trampoline base classes */
size_t base_basicsize = sizeof(nb_inst) + tb->size;
Expand All @@ -853,7 +881,7 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
nb_total_slots = nb_type_max_slots +
nb_extra_slots + 1;

PyMemberDef members[2] { };
PyMemberDef members[3] { };
PyType_Slot slots[nb_total_slots], *s = slots;
PyType_Spec spec = {
/* .name = */ name_copy,
Expand Down Expand Up @@ -898,26 +926,50 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
for (PyType_Slot *ts = slots; ts != s; ++ts)
has_traverse |= ts->slot == Py_tp_traverse;

if (has_dynamic_attr) {
// realign to sizeof(void*), add one pointer
Py_ssize_t dictoffset = 0, weaklistoffset = 0;
int num_members = 0;

// realign to sizeof(void*) if needed
if (has_dynamic_attr || is_weak_referenceable)
basicsize = (basicsize + ptr_size - 1) / ptr_size * ptr_size;

if (has_dynamic_attr) {
dictoffset = (Py_ssize_t) basicsize;
basicsize += ptr_size;

members[0] = PyMemberDef{ "__dictoffset__", T_PYSSIZET,
(Py_ssize_t) (basicsize - ptr_size), READONLY,
nullptr };
*s++ = { Py_tp_members, (void *) members };
members[num_members] = PyMemberDef{ "__dictoffset__", T_PYSSIZET,
dictoffset, READONLY, nullptr };
++num_members;

// Install GC traverse and clear routines if not inherited/overridden
if (!has_traverse) {
*s++ = { Py_tp_traverse, (void *) inst_traverse };
*s++ = { Py_tp_clear, (void *) inst_clear };
has_traverse = true;
}
spec.basicsize = (int) basicsize;
}

if (is_weak_referenceable) {
weaklistoffset = (Py_ssize_t) basicsize;
basicsize += ptr_size;

members[num_members] = PyMemberDef{ "__weaklistoffset__", T_PYSSIZET,
weaklistoffset, READONLY, nullptr };
++num_members;

// Install GC traverse and clear routines if not inherited/overridden
if (!has_traverse) {
*s++ = { Py_tp_traverse, (void *) inst_traverse };
*s++ = { Py_tp_clear, (void *) inst_clear };
has_traverse = true;
}
spec.basicsize = (int) basicsize;
}

if (num_members > 0)
*s++ = { Py_tp_members, (void*)members };

if (has_traverse)
spec.flags |= Py_TPFLAGS_HAVE_GC;

Expand Down Expand Up @@ -955,7 +1007,14 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
if (has_dynamic_attr) {
to->flags |= (uint32_t) type_flags::has_dynamic_attr;
#if defined(Py_LIMITED_API)
to->dictoffset = (size_t) (basicsize - ptr_size);
to->dictoffset = dictoffset;
#endif
}

if (is_weak_referenceable) {
to->flags |= (uint32_t)type_flags::is_weak_referenceable;
#if defined(Py_LIMITED_API)
to->weaklistoffset = weaklistoffset;
#endif
}

Expand Down
11 changes: 11 additions & 0 deletions tests/test_classes.cpp
Expand Up @@ -90,6 +90,10 @@ struct Wrapper {
std::shared_ptr<Wrapper> value;
};

struct StructWithWeakrefs : Struct { };

struct StructWithWeakrefsAndDynamicAttrs : Struct { };

int wrapper_tp_traverse(PyObject *self, visitproc visit, void *arg) {
Wrapper *w = nb::inst_ptr<Wrapper>(self);

Expand Down Expand Up @@ -554,4 +558,11 @@ NB_MODULE(test_classes_ext, m) {
"get_incrementing_struct_value",
[](IncrementingStruct &s) { return new Struct(s.i + 100); },
nb::keep_alive<0, 1>());

nb::class_<StructWithWeakrefs, Struct>(m, "StructWithWeakrefs", nb::weak_referenceable())
.def(nb::init<int>());

nb::class_<StructWithWeakrefsAndDynamicAttrs, Struct>(m, "StructWithWeakrefsAndDynamicAttrs",
nb::weak_referenceable(), nb::dynamic_attr())
.def(nb::init<int>());
}
22 changes: 22 additions & 0 deletions tests/test_classes.py
Expand Up @@ -759,3 +759,25 @@ def test41_implicit_conversion_keep_alive():
assert d1 == []
assert d2 == [5]
assert d3 == [106, 6]

def test42_weak_references():
import weakref
import gc
import time
o = t.StructWithWeakrefs(42)
w = weakref.ref(o)
assert w() is o
del o
gc.collect()
gc.collect()
assert w() is None

p = t.StructWithWeakrefsAndDynamicAttrs(43)
p.a_dynamic_attr = 101
w = weakref.ref(p)
assert w() is p
assert w().a_dynamic_attr == 101
del p
gc.collect()
gc.collect()
assert w() is None

0 comments on commit fc77093

Please sign in to comment.