Skip to content

Commit

Permalink
Fix inheritance of Py_TPFLAGS_HAVE_GC flag in class construction
Browse files Browse the repository at this point in the history
The ``nb_type_new()`` function of nanobind assumed that classes created
via ``PyType_Read()`` / ``PyType_FromMetaclass()`` would inherit the
value of the ``Py_TPFLAGS_HAVE_GC`` flag from the base class, but this
was only true when the new class did not at the same time provide its
own ``tp_traverse``/``tp_clear`` functions.

An implication of this mistake was that classes deriving from classes
with the ``nb::dynamic_attr`` were broken and would cause segfaults.

Fixes issue #279.
  • Loading branch information
wjakob committed Aug 24, 2023
1 parent ed929b7 commit dbedadc
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 2 deletions.
3 changes: 1 addition & 2 deletions src/nb_type.cpp
Expand Up @@ -870,8 +870,7 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
spec.basicsize = (int) basicsize;
}

if (has_traverse && (!base || (PyType_GetFlags((PyTypeObject *) base) &
Py_TPFLAGS_HAVE_GC) == 0))
if (has_traverse)
spec.flags |= Py_TPFLAGS_HAVE_GC;

*s++ = { 0, nullptr };
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Expand Up @@ -33,6 +33,7 @@ nanobind_add_module(test_ndarray_ext test_ndarray.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_intrusive_ext test_intrusive.cpp object.cpp object.h ${NB_EXTRA_ARGS})
nanobind_add_module(test_exception_ext test_exception.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_make_iterator_ext test_make_iterator.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_issue_ext test_issue.cpp ${NB_EXTRA_ARGS})

find_package (Eigen3 3.3.1 NO_MODULE)
if (TARGET Eigen3::Eigen)
Expand Down
45 changes: 45 additions & 0 deletions tests/test_issue.cpp
@@ -0,0 +1,45 @@
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <unordered_map>

namespace nb = nanobind;
using namespace nb::literals;


NB_MODULE(test_issue_ext, m) {
// ------------------------------------
// issue #279: dynamic_attr broken
// ------------------------------------

struct Component {
virtual ~Component() = default;
};

struct Param : Component { };

struct Model : Component {
void add_param(const std::string &name, std::shared_ptr<Param> p) {
params_[name] = std::move(p);
}

std::shared_ptr<Param> get_param(const std::string &name) {
return params_.find(name) != params_.end() ? params_[name] : nullptr;
}

std::unordered_map<std::string, std::shared_ptr<Param>> params_;
};

struct ModelA : Model {
ModelA() {
add_param("a", std::make_shared<Param>());
add_param("b", std::make_shared<Param>());
}
};

nb::class_<Component>(m, "Component");
nb::class_<Param, Component>(m, "ParamBase");
nb::class_<Model, Component>(m, "Model", nb::dynamic_attr()).def(nb::init<>{})
.def("_get_param", &Model::get_param, "name"_a)
.def("_add_param", &Model::add_param, "name"_a, "p"_a);
nb::class_<ModelA, Model>(m, "ModelA").def(nb::init<>{});
}
29 changes: 29 additions & 0 deletions tests/test_issue.py
@@ -0,0 +1,29 @@
import test_issue_ext as m
import pytest

# Issue #279: dynamic_attr broken
@pytest.mark.parametrize("variant", [1, 2])
def test01_issue_279(variant):
def _get_parameter(self: m.Model, key: str):
p = self._get_param(key)
if p is not None: # cache it for fast access later
setattr(self, key, p)
return p
raise AttributeError(f"'key' not found in {self}")

m.Model.__getattr__ = _get_parameter

if variant == 2:
def _print_model(self):
return f"{self.__class__.__qualname__}()"

m.Model.__str__ = _print_model

class Top(m.Model):
def __init__(self):
super().__init__()
self.model_a = m.ModelA()

top = Top()
str(top.model_a)
str(top.model_a.a)

0 comments on commit dbedadc

Please sign in to comment.