Skip to content

Commit

Permalink
Preliminary adaptations for PyPy
Browse files Browse the repository at this point in the history
This commit lays the foundation for eventually supporting PyPy as an
alternative interpreter besides CPython. This is is dependent on a few
remaining issues being worked out on the PyPy end:

- https://foss.heptapod.net/pypy/pypy/-/issues/3847
- https://foss.heptapod.net/pypy/pypy/-/issues/3845
- https://foss.heptapod.net/pypy/pypy/-/issues/3844

I decided to already merge this PR because the changes are also nice to
have just for the CPython part (in particular, the alternative local
implementation of `PyType_FromMetaclass` is more general/robust).
  • Loading branch information
wjakob committed Nov 9, 2022
1 parent 12f17e2 commit f935f93
Show file tree
Hide file tree
Showing 15 changed files with 219 additions and 110 deletions.
2 changes: 0 additions & 2 deletions include/nanobind/nb_defs.h
Expand Up @@ -93,14 +93,12 @@
#endif

#if PY_VERSION_HEX < 0x03090000
# define NB_INTERPRETER_STATE_GET _PyInterpreterState_Get
# define NB_TYPING_DICT "Dict"
# define NB_TYPING_LIST "List"
# define NB_TYPING_SET "Set"
# define NB_TYPING_TUPLE "Tuple"
# define NB_TYPING_TYPE "Type"
#else
# define NB_INTERPRETER_STATE_GET PyInterpreterState_Get
# define NB_TYPING_DICT "dict"
# define NB_TYPING_LIST "list"
# define NB_TYPING_SET "set"
Expand Down
6 changes: 3 additions & 3 deletions include/nanobind/nb_types.h
Expand Up @@ -342,7 +342,7 @@ class tuple : public object {
template <typename T, detail::enable_if_t<std::is_arithmetic_v<T>> = 1>
detail::accessor<detail::num_item_tuple> operator[](T key) const;

#if !defined(Py_LIMITED_API)
#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION)
detail::fast_iterator begin() const;
detail::fast_iterator end() const;
#endif
Expand All @@ -362,7 +362,7 @@ class list : public object {
template <typename T, detail::enable_if_t<std::is_arithmetic_v<T>> = 1>
detail::accessor<detail::num_item_list> operator[](T key) const;

#if !defined(Py_LIMITED_API)
#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION)
detail::fast_iterator begin() const;
detail::fast_iterator end() const;
#endif
Expand Down Expand Up @@ -615,7 +615,7 @@ NAMESPACE_END(detail)
inline detail::dict_iterator dict::begin() const { return { *this }; }
inline detail::dict_iterator dict::end() const { return { }; }

#if !defined(Py_LIMITED_API)
#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION)
inline detail::fast_iterator tuple::begin() const {
return ((PyTupleObject *) m_ptr)->ob_item;
}
Expand Down
26 changes: 20 additions & 6 deletions src/common.cpp
Expand Up @@ -131,23 +131,37 @@ PyObject *module_import(const char *name) {

PyObject *module_new_submodule(PyObject *base, const char *name,
const char *doc) noexcept {
PyObject *name_py, *res;

PyObject *base_name = PyModule_GetNameObject(base),
*name_py, *res;
#if !defined(PYPY_VERSION)
PyObject *base_name = PyModule_GetNameObject(base);
if (!base_name)
goto fail;

name_py = PyUnicode_FromFormat("%U.%s", base_name, name);
#else
const char *base_name = PyModule_GetName(base);
if (!base_name)
goto fail;

name_py = PyUnicode_FromFormat("%s.%s", base_name, name);
#endif
if (!name_py)
goto fail;

#if !defined(PYPY_VERSION)
res = PyImport_AddModuleObject(name_py);
#else
res = PyImport_AddModule(PyUnicode_AsUTF8(name_py));
#endif

if (doc) {
PyObject *doc_py = PyUnicode_FromString(doc);
if (!doc_py || PyObject_SetAttrString(res, "__doc__", doc_py))
goto fail;
Py_DECREF(doc_py);
}

Py_DECREF(name_py);
Py_DECREF(base_name);

Expand Down Expand Up @@ -222,7 +236,7 @@ PyObject *obj_vectorcall(PyObject *base, PyObject *const *args, size_t nargsf,
}
}

#if PY_VERSION_HEX < 0x03090000
#if PY_VERSION_HEX < 0x03090000 || defined(PYPY_VERSION)
if (method_call) {
PyObject *self = PyObject_GetAttr(args[0], /* name = */ base);
if (self) {
Expand Down Expand Up @@ -468,7 +482,7 @@ PyObject **seq_get(PyObject *seq, size_t *size_out, PyObject **temp_out) noexcep
goes wrong, it fails gracefully without reporting errors. Other
overloads will then be tried. */

#if !defined(Py_LIMITED_API)
#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION)
if (PyTuple_CheckExact(seq)) {
size = (size_t) PyTuple_GET_SIZE(seq);
result = ((PyTupleObject *) seq)->ob_item;
Expand Down Expand Up @@ -561,7 +575,7 @@ PyObject **seq_get_with_size(PyObject *seq, size_t size,
PyObject *temp = nullptr,
**result = nullptr;

#if !defined(Py_LIMITED_API)
#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION)
if (PyTuple_CheckExact(seq)) {
if (size == (size_t) PyTuple_GET_SIZE(seq)) {
result = ((PyTupleObject *) seq)->ob_item;
Expand Down Expand Up @@ -755,7 +769,7 @@ template <typename T>
NB_INLINE bool load_int(PyObject *o, uint32_t flags, T *out) noexcept {
if (NB_LIKELY(PyLong_CheckExact(o))) {
// Fast path for integers that aren't too large (max. one 15- or 30-bit "digit")
#if !defined(Py_LIMITED_API)
#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION)
PyLongObject *lo = (PyLongObject *) o;
int size = Py_SIZE(o);

Expand Down
2 changes: 1 addition & 1 deletion src/error.cpp
Expand Up @@ -81,7 +81,7 @@ const char *python_error::what() const noexcept {
PyErr_Clear();
}

#if defined(Py_LIMITED_API)
#if defined(Py_LIMITED_API) || defined(PYPY_VERSION)
object mod = module_::import_("traceback"),
result = mod.attr("format_exception")(handle(m_type), handle(m_value), handle(m_trace));
m_what = NB_STRDUP(borrow<str>(str("\n").attr("join")(result)).c_str());
Expand Down
10 changes: 7 additions & 3 deletions src/nb_enum.cpp
Expand Up @@ -101,7 +101,11 @@ static PyObject *nb_enum_int(PyObject *o) {
}
}

static PyObject *nb_enum_init(PyTypeObject *subtype, PyObject *args, PyObject *kwds) {
static PyObject *nb_enum_init(PyObject *, PyObject *, PyObject *) {
return 0;
}

static PyObject *nb_enum_new(PyTypeObject *subtype, PyObject *args, PyObject *kwds) {
PyObject *arg;

if (kwds || NB_TUPLE_GET_SIZE(args) != 1)
Expand Down Expand Up @@ -201,8 +205,8 @@ void nb_enum_prepare(PyType_Slot **s, bool is_arithmetic) {

/* Careful: update 'nb_enum_max_slots' field in nb_type.cpp
when adding further type slots */
*t++ = { Py_tp_new, (void *) nb_enum_init };
*t++ = { Py_tp_init, (void *) nullptr };
*t++ = { Py_tp_new, (void *) nb_enum_new };
*t++ = { Py_tp_init, (void *) nb_enum_init };
*t++ = { Py_tp_repr, (void *) nb_enum_repr };
*t++ = { Py_tp_richcompare, (void *) nb_enum_richcompare };
*t++ = { Py_nb_int, (void *) nb_enum_int };
Expand Down
9 changes: 6 additions & 3 deletions src/nb_func.cpp
Expand Up @@ -43,7 +43,6 @@ int nb_func_traverse(PyObject *self, visitproc visit, void *arg) {
if (f->flags & (uint32_t) func_flags::has_args) {
for (size_t j = 0; j < f->nargs; ++j) {
Py_VISIT(f->args[j].value);
Py_VISIT(f->args[j].name_py);
}
}
++f;
Expand All @@ -63,7 +62,6 @@ int nb_func_clear(PyObject *self) {
if (f->flags & (uint32_t) func_flags::has_args) {
for (size_t j = 0; j < f->nargs; ++j) {
Py_CLEAR(f->args[j].value);
Py_CLEAR(f->args[j].name_py);
}
}
++f;
Expand Down Expand Up @@ -538,7 +536,12 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self,
PyObject *hit = nullptr;
for (size_t j = 0; j < nkwargs_in; ++j) {
PyObject *key = NB_TUPLE_GET_ITEM(kwargs_in, j);
if (key == ad.name_py) {
#if defined(PYPY_VERSION)
bool match = PyUnicode_Compare(key, ad.name_py) == 0;
#else
bool match = (key == ad.name_py);
#endif
if (match) {
hit = args_in[nargs_in + j];
kwarg_used[j] = true;
break;
Expand Down
26 changes: 20 additions & 6 deletions src/nb_internals.cpp
Expand Up @@ -317,6 +317,7 @@ void default_exception_translator(const std::exception_ptr &p, void *) {
}
}

#if !defined(PYPY_VERSION)
static void internals_cleanup() {
bool leak = false;

Expand Down Expand Up @@ -373,15 +374,28 @@ static void internals_cleanup() {
#endif
}
}
#endif

static PyObject *internals_dict() {
#if defined(PYPY_VERSION)
PyObject *dict = PyEval_GetBuiltins();
#elif PY_VERSION_HEX < 0x03090000
PyObject *dict = PyInterpreterState_GetDict(_PyInterpreterState_Get());
#else
PyObject *dict = PyInterpreterState_GetDict(PyInterpreterState_Get());
#endif
if (!dict)
fail("nanobind::detail::internals_dict(): failed!");

return dict;
}

static void internals_make() {
str nb_name("nanobind");

internals_p = new nb_internals();

PyObject *dict = PyInterpreterState_GetDict(NB_INTERPRETER_STATE_GET());
if (!dict)
fail("nanobind::detail::internals_make(): PyInterpreterState_GetDict() failed!");
PyObject *dict = internals_dict();

const char *internals_id = NB_INTERNALS_ID;
PyObject *capsule = PyCapsule_New(internals_p, internals_id, nullptr);
Expand Down Expand Up @@ -495,6 +509,7 @@ static void internals_make() {
PyErr_Clear();
}

#if !defined(PYPY_VERSION)
// Install the memory leak checker
if (Py_AtExit(internals_cleanup))
fprintf(stderr,
Expand All @@ -503,12 +518,11 @@ static void internals_make() {
"resources at interpreter shutdown (e.g., to avoid leaks being "
"reported by tools like 'valgrind'). If you are a user of a "
"python extension library, you can ignore this warning.");
#endif
}

static void internals_fetch() {
PyObject *dict = PyInterpreterState_GetDict(NB_INTERPRETER_STATE_GET());
if (!dict)
fail("nanobind::detail::internals_fetch(): PyInterpreterState_GetDict() failed!");
PyObject *dict = internals_dict();

const char *internals_id = NB_INTERNALS_ID;
PyObject *capsule = PyDict_GetItemString(dict, internals_id);
Expand Down
96 changes: 75 additions & 21 deletions src/nb_type.cpp
Expand Up @@ -467,10 +467,14 @@ PyObject *nb_type_new(const type_data *t) noexcept {
PyHeapTypeObject *temp_ht = (PyHeapTypeObject *) temp;
PyTypeObject *temp_tp = &temp_ht->ht_type;

Py_INCREF(temp_ht->ht_name);
Py_INCREF(temp_ht->ht_qualname);
Py_INCREF(temp_tp->tp_base);
Py_INCREF (temp_ht->ht_name);
Py_INCREF (temp_ht->ht_qualname);
Py_XINCREF(temp_ht->ht_slots);
Py_INCREF (temp_tp->tp_base);

#if PY_VERSION_HEX >= 0x03090000
Py_XINCREF(temp_ht->ht_module);
#endif

char *tp_doc = nullptr;
if (temp_tp->tp_doc) {
Expand All @@ -487,36 +491,84 @@ PyObject *nb_type_new(const type_data *t) noexcept {
PyHeapTypeObject *ht = (PyHeapTypeObject *) result;
PyTypeObject *tp = &ht->ht_type;

memcpy(ht, temp_ht, sizeof(PyHeapTypeObject));
ht->ht_name = temp_ht->ht_name;
ht->ht_qualname = temp_ht->ht_qualname;
ht->ht_slots = temp_ht->ht_slots;

#if PY_VERSION_HEX >= 0x03090000
ht->ht_module = temp_ht->ht_module;
#endif

tp->ob_base.ob_base.ob_type = metaclass;
tp->ob_base.ob_base.ob_refcnt = 1;
tp->ob_base.ob_size = 0;
tp->tp_as_async = &ht->as_async;
tp->tp_as_number = &ht->as_number;
tp->tp_as_sequence = &ht->as_sequence;
tp->tp_as_mapping = &ht->as_mapping;
tp->tp_as_buffer = &ht->as_buffer;
tp->tp_name = name_copy;
tp->tp_doc = tp_doc;
tp->tp_flags = spec.flags | Py_TPFLAGS_HEAPTYPE;

if (temp_tp->tp_flags & Py_TPFLAGS_HAVE_GC)
tp->tp_flags |= Py_TPFLAGS_HAVE_GC;

/* The following fields remain intentionally null-initialized
following the call to PyType_GenericAlloc(): tp_dict, tp_bases, tp_mro,
tp_cache, tp_subclasses, tp_weaklist. */

#define COPY_FIELD(name) \
tp->name = temp_tp->name;

COPY_FIELD(tp_basicsize);
COPY_FIELD(tp_itemsize);
COPY_FIELD(tp_dealloc);
COPY_FIELD(tp_vectorcall_offset);
COPY_FIELD(tp_getattr);
COPY_FIELD(tp_setattr);
COPY_FIELD(tp_repr);
COPY_FIELD(tp_hash);
COPY_FIELD(tp_call);
COPY_FIELD(tp_str);
COPY_FIELD(tp_getattro);
COPY_FIELD(tp_setattro);
COPY_FIELD(tp_traverse);
COPY_FIELD(tp_clear);
COPY_FIELD(tp_richcompare);
COPY_FIELD(tp_weaklistoffset);
COPY_FIELD(tp_iter);
COPY_FIELD(tp_iternext);
COPY_FIELD(tp_methods);
COPY_FIELD(tp_members);
COPY_FIELD(tp_getset);
COPY_FIELD(tp_base);
COPY_FIELD(tp_descr_get);
COPY_FIELD(tp_descr_set);
COPY_FIELD(tp_dictoffset);
COPY_FIELD(tp_init);
COPY_FIELD(tp_alloc);
COPY_FIELD(tp_new);
COPY_FIELD(tp_free);
COPY_FIELD(tp_is_gc);
COPY_FIELD(tp_del);
COPY_FIELD(tp_finalize);
COPY_FIELD(tp_vectorcall);

#undef COPY_FIELD

ht->as_async = temp_ht->as_async;
tp->tp_as_async = &ht->as_async;

ht->as_number = temp_ht->as_number;
tp->tp_as_number = &ht->as_number;

ht->as_sequence = temp_ht->as_sequence;
tp->tp_as_sequence = &ht->as_sequence;

ht->as_mapping = temp_ht->as_mapping;
tp->tp_as_mapping = &ht->as_mapping;

ht->as_buffer = temp_ht->as_buffer;
tp->tp_as_buffer = &ht->as_buffer;

#if PY_VERSION_HEX < 0x03090000
if (has_dynamic_attr)
tp->tp_dictoffset = (Py_ssize_t) (basicsize - ptr_size);
#endif

tp->tp_dict = tp->tp_bases = tp->tp_mro = tp->tp_cache = nullptr;
tp->tp_subclasses = tp->tp_weaklist = nullptr;
ht->ht_cached_keys = nullptr;
tp->tp_version_tag = 0;

#if PY_VERSION_HEX >= 0x030B0000
ht->_ht_tpname = nullptr;
#endif

PyType_Ready(tp);
Py_DECREF(temp);
#endif
Expand Down Expand Up @@ -1140,6 +1192,7 @@ type_data *nb_type_data_static(PyTypeObject *o) noexcept {
PyObject *nb_type_name(PyTypeObject *tp) noexcept {
PyObject *name = PyObject_GetAttrString((PyObject *) tp, "__name__");

#if !defined(PYPY_VERSION)
if (PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE)) {
PyObject *mod = PyObject_GetAttrString((PyObject *) tp, "__module__"),
*combined = PyUnicode_FromFormat("%U.%U", mod, name);
Expand All @@ -1148,6 +1201,7 @@ PyObject *nb_type_name(PyTypeObject *tp) noexcept {
Py_DECREF(name);
name = combined;
}
#endif

return name;
}
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Expand Up @@ -35,6 +35,7 @@ target_link_libraries(test_inter_module_1_ext PRIVATE inter_module)
target_link_libraries(test_inter_module_2_ext PRIVATE inter_module)

set(TEST_FILES
common.py
test_functions.py
test_classes.py
test_holders.py
Expand Down

0 comments on commit f935f93

Please sign in to comment.