Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix inheritance of
Py_TPFLAGS_HAVE_GC
flag in class construction
The ``nb_type_new()`` function of nanobind assumed that classes created via ``PyType_Read()`` / ``PyType_FromMetaclass()`` would inherit the value of the ``Py_TPFLAGS_HAVE_GC`` flag from the base class, but this was only true when the new class did not at the same time provide its own ``tp_traverse``/``tp_clear`` functions. An implication of this mistake was that classes deriving from classes with the ``nb::dynamic_attr`` were broken and would cause segfaults. Fixes issue #279.
- Loading branch information
Showing
4 changed files
with
76 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#include <nanobind/stl/shared_ptr.h> | ||
#include <nanobind/stl/string.h> | ||
#include <unordered_map> | ||
|
||
namespace nb = nanobind; | ||
using namespace nb::literals; | ||
|
||
|
||
NB_MODULE(test_issue_ext, m) { | ||
// ------------------------------------ | ||
// issue #279: dynamic_attr broken | ||
// ------------------------------------ | ||
|
||
struct Component { | ||
virtual ~Component() = default; | ||
}; | ||
|
||
struct Param : Component { }; | ||
|
||
struct Model : Component { | ||
void add_param(const std::string &name, std::shared_ptr<Param> p) { | ||
params_[name] = std::move(p); | ||
} | ||
|
||
std::shared_ptr<Param> get_param(const std::string &name) { | ||
return params_.find(name) != params_.end() ? params_[name] : nullptr; | ||
} | ||
|
||
std::unordered_map<std::string, std::shared_ptr<Param>> params_; | ||
}; | ||
|
||
struct ModelA : Model { | ||
ModelA() { | ||
add_param("a", std::make_shared<Param>()); | ||
add_param("b", std::make_shared<Param>()); | ||
} | ||
}; | ||
|
||
nb::class_<Component>(m, "Component"); | ||
nb::class_<Param, Component>(m, "ParamBase"); | ||
nb::class_<Model, Component>(m, "Model", nb::dynamic_attr()).def(nb::init<>{}) | ||
.def("_get_param", &Model::get_param, "name"_a) | ||
.def("_add_param", &Model::add_param, "name"_a, "p"_a); | ||
nb::class_<ModelA, Model>(m, "ModelA").def(nb::init<>{}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import test_issue_ext as m | ||
import pytest | ||
|
||
# Issue #279: dynamic_attr broken | ||
@pytest.mark.parametrize("variant", [1, 2]) | ||
def test01_issue_279(variant): | ||
def _get_parameter(self: m.Model, key: str): | ||
p = self._get_param(key) | ||
if p is not None: # cache it for fast access later | ||
setattr(self, key, p) | ||
return p | ||
raise AttributeError(f"'key' not found in {self}") | ||
|
||
m.Model.__getattr__ = _get_parameter | ||
|
||
if variant == 2: | ||
def _print_model(self): | ||
return f"{self.__class__.__qualname__}()" | ||
|
||
m.Model.__str__ = _print_model | ||
|
||
class Top(m.Model): | ||
def __init__(self): | ||
super().__init__() | ||
self.model_a = m.ModelA() | ||
|
||
top = Top() | ||
str(top.model_a) | ||
str(top.model_a.a) |