From f935f93b9d532a5ef1f385445f328d61eb2af97f Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Wed, 9 Nov 2022 21:26:55 +0100 Subject: [PATCH] Preliminary adaptations for PyPy This commit lays the foundation for eventually supporting PyPy as an alternative interpreter besides CPython. This is is dependent on a few remaining issues being worked out on the PyPy end: - https://foss.heptapod.net/pypy/pypy/-/issues/3847 - https://foss.heptapod.net/pypy/pypy/-/issues/3845 - https://foss.heptapod.net/pypy/pypy/-/issues/3844 I decided to already merge this PR because the changes are also nice to have just for the CPython part (in particular, the alternative local implementation of `PyType_FromMetaclass` is more general/robust). --- include/nanobind/nb_defs.h | 2 - include/nanobind/nb_types.h | 6 +-- src/common.cpp | 26 +++++++--- src/error.cpp | 2 +- src/nb_enum.cpp | 10 ++-- src/nb_func.cpp | 9 ++-- src/nb_internals.cpp | 26 +++++++--- src/nb_type.cpp | 96 +++++++++++++++++++++++++++++-------- tests/CMakeLists.txt | 1 + tests/common.py | 15 ++++++ tests/test_classes.py | 48 +++++++++++-------- tests/test_holders.py | 25 +++++----- tests/test_intrusive.py | 17 +++---- tests/test_stl.py | 15 +++--- tests/test_tensor.py | 31 ++++++------ 15 files changed, 219 insertions(+), 110 deletions(-) create mode 100644 tests/common.py diff --git a/include/nanobind/nb_defs.h b/include/nanobind/nb_defs.h index a997aef0..9b966d7d 100644 --- a/include/nanobind/nb_defs.h +++ b/include/nanobind/nb_defs.h @@ -93,14 +93,12 @@ #endif #if PY_VERSION_HEX < 0x03090000 -# define NB_INTERPRETER_STATE_GET _PyInterpreterState_Get # define NB_TYPING_DICT "Dict" # define NB_TYPING_LIST "List" # define NB_TYPING_SET "Set" # define NB_TYPING_TUPLE "Tuple" # define NB_TYPING_TYPE "Type" #else -# define NB_INTERPRETER_STATE_GET PyInterpreterState_Get # define NB_TYPING_DICT "dict" # define NB_TYPING_LIST "list" # define NB_TYPING_SET "set" diff --git a/include/nanobind/nb_types.h b/include/nanobind/nb_types.h index 1cbb9832..647b6b5a 100644 --- a/include/nanobind/nb_types.h +++ b/include/nanobind/nb_types.h @@ -342,7 +342,7 @@ class tuple : public object { template > = 1> detail::accessor operator[](T key) const; -#if !defined(Py_LIMITED_API) +#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION) detail::fast_iterator begin() const; detail::fast_iterator end() const; #endif @@ -362,7 +362,7 @@ class list : public object { template > = 1> detail::accessor operator[](T key) const; -#if !defined(Py_LIMITED_API) +#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION) detail::fast_iterator begin() const; detail::fast_iterator end() const; #endif @@ -615,7 +615,7 @@ NAMESPACE_END(detail) inline detail::dict_iterator dict::begin() const { return { *this }; } inline detail::dict_iterator dict::end() const { return { }; } -#if !defined(Py_LIMITED_API) +#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION) inline detail::fast_iterator tuple::begin() const { return ((PyTupleObject *) m_ptr)->ob_item; } diff --git a/src/common.cpp b/src/common.cpp index e9539ba8..15f2e0a9 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -131,23 +131,37 @@ PyObject *module_import(const char *name) { PyObject *module_new_submodule(PyObject *base, const char *name, const char *doc) noexcept { + PyObject *name_py, *res; - PyObject *base_name = PyModule_GetNameObject(base), - *name_py, *res; +#if !defined(PYPY_VERSION) + PyObject *base_name = PyModule_GetNameObject(base); if (!base_name) goto fail; name_py = PyUnicode_FromFormat("%U.%s", base_name, name); +#else + const char *base_name = PyModule_GetName(base); + if (!base_name) + goto fail; + + name_py = PyUnicode_FromFormat("%s.%s", base_name, name); +#endif if (!name_py) goto fail; +#if !defined(PYPY_VERSION) res = PyImport_AddModuleObject(name_py); +#else + res = PyImport_AddModule(PyUnicode_AsUTF8(name_py)); +#endif + if (doc) { PyObject *doc_py = PyUnicode_FromString(doc); if (!doc_py || PyObject_SetAttrString(res, "__doc__", doc_py)) goto fail; Py_DECREF(doc_py); } + Py_DECREF(name_py); Py_DECREF(base_name); @@ -222,7 +236,7 @@ PyObject *obj_vectorcall(PyObject *base, PyObject *const *args, size_t nargsf, } } -#if PY_VERSION_HEX < 0x03090000 +#if PY_VERSION_HEX < 0x03090000 || defined(PYPY_VERSION) if (method_call) { PyObject *self = PyObject_GetAttr(args[0], /* name = */ base); if (self) { @@ -468,7 +482,7 @@ PyObject **seq_get(PyObject *seq, size_t *size_out, PyObject **temp_out) noexcep goes wrong, it fails gracefully without reporting errors. Other overloads will then be tried. */ -#if !defined(Py_LIMITED_API) +#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION) if (PyTuple_CheckExact(seq)) { size = (size_t) PyTuple_GET_SIZE(seq); result = ((PyTupleObject *) seq)->ob_item; @@ -561,7 +575,7 @@ PyObject **seq_get_with_size(PyObject *seq, size_t size, PyObject *temp = nullptr, **result = nullptr; -#if !defined(Py_LIMITED_API) +#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION) if (PyTuple_CheckExact(seq)) { if (size == (size_t) PyTuple_GET_SIZE(seq)) { result = ((PyTupleObject *) seq)->ob_item; @@ -755,7 +769,7 @@ template NB_INLINE bool load_int(PyObject *o, uint32_t flags, T *out) noexcept { if (NB_LIKELY(PyLong_CheckExact(o))) { // Fast path for integers that aren't too large (max. one 15- or 30-bit "digit") - #if !defined(Py_LIMITED_API) + #if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION) PyLongObject *lo = (PyLongObject *) o; int size = Py_SIZE(o); diff --git a/src/error.cpp b/src/error.cpp index 1b01a451..dd74bf6d 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -81,7 +81,7 @@ const char *python_error::what() const noexcept { PyErr_Clear(); } -#if defined(Py_LIMITED_API) +#if defined(Py_LIMITED_API) || defined(PYPY_VERSION) object mod = module_::import_("traceback"), result = mod.attr("format_exception")(handle(m_type), handle(m_value), handle(m_trace)); m_what = NB_STRDUP(borrow(str("\n").attr("join")(result)).c_str()); diff --git a/src/nb_enum.cpp b/src/nb_enum.cpp index 4623c90f..09fd4ce9 100644 --- a/src/nb_enum.cpp +++ b/src/nb_enum.cpp @@ -101,7 +101,11 @@ static PyObject *nb_enum_int(PyObject *o) { } } -static PyObject *nb_enum_init(PyTypeObject *subtype, PyObject *args, PyObject *kwds) { +static PyObject *nb_enum_init(PyObject *, PyObject *, PyObject *) { + return 0; +} + +static PyObject *nb_enum_new(PyTypeObject *subtype, PyObject *args, PyObject *kwds) { PyObject *arg; if (kwds || NB_TUPLE_GET_SIZE(args) != 1) @@ -201,8 +205,8 @@ void nb_enum_prepare(PyType_Slot **s, bool is_arithmetic) { /* Careful: update 'nb_enum_max_slots' field in nb_type.cpp when adding further type slots */ - *t++ = { Py_tp_new, (void *) nb_enum_init }; - *t++ = { Py_tp_init, (void *) nullptr }; + *t++ = { Py_tp_new, (void *) nb_enum_new }; + *t++ = { Py_tp_init, (void *) nb_enum_init }; *t++ = { Py_tp_repr, (void *) nb_enum_repr }; *t++ = { Py_tp_richcompare, (void *) nb_enum_richcompare }; *t++ = { Py_nb_int, (void *) nb_enum_int }; diff --git a/src/nb_func.cpp b/src/nb_func.cpp index 89081963..855a6b6b 100644 --- a/src/nb_func.cpp +++ b/src/nb_func.cpp @@ -43,7 +43,6 @@ int nb_func_traverse(PyObject *self, visitproc visit, void *arg) { if (f->flags & (uint32_t) func_flags::has_args) { for (size_t j = 0; j < f->nargs; ++j) { Py_VISIT(f->args[j].value); - Py_VISIT(f->args[j].name_py); } } ++f; @@ -63,7 +62,6 @@ int nb_func_clear(PyObject *self) { if (f->flags & (uint32_t) func_flags::has_args) { for (size_t j = 0; j < f->nargs; ++j) { Py_CLEAR(f->args[j].value); - Py_CLEAR(f->args[j].name_py); } } ++f; @@ -538,7 +536,12 @@ static PyObject *nb_func_vectorcall_complex(PyObject *self, PyObject *hit = nullptr; for (size_t j = 0; j < nkwargs_in; ++j) { PyObject *key = NB_TUPLE_GET_ITEM(kwargs_in, j); - if (key == ad.name_py) { + #if defined(PYPY_VERSION) + bool match = PyUnicode_Compare(key, ad.name_py) == 0; + #else + bool match = (key == ad.name_py); + #endif + if (match) { hit = args_in[nargs_in + j]; kwarg_used[j] = true; break; diff --git a/src/nb_internals.cpp b/src/nb_internals.cpp index 28b3833d..c56077a6 100644 --- a/src/nb_internals.cpp +++ b/src/nb_internals.cpp @@ -317,6 +317,7 @@ void default_exception_translator(const std::exception_ptr &p, void *) { } } +#if !defined(PYPY_VERSION) static void internals_cleanup() { bool leak = false; @@ -373,15 +374,28 @@ static void internals_cleanup() { #endif } } +#endif + +static PyObject *internals_dict() { +#if defined(PYPY_VERSION) + PyObject *dict = PyEval_GetBuiltins(); +#elif PY_VERSION_HEX < 0x03090000 + PyObject *dict = PyInterpreterState_GetDict(_PyInterpreterState_Get()); +#else + PyObject *dict = PyInterpreterState_GetDict(PyInterpreterState_Get()); +#endif + if (!dict) + fail("nanobind::detail::internals_dict(): failed!"); + + return dict; +} static void internals_make() { str nb_name("nanobind"); internals_p = new nb_internals(); - PyObject *dict = PyInterpreterState_GetDict(NB_INTERPRETER_STATE_GET()); - if (!dict) - fail("nanobind::detail::internals_make(): PyInterpreterState_GetDict() failed!"); + PyObject *dict = internals_dict(); const char *internals_id = NB_INTERNALS_ID; PyObject *capsule = PyCapsule_New(internals_p, internals_id, nullptr); @@ -495,6 +509,7 @@ static void internals_make() { PyErr_Clear(); } +#if !defined(PYPY_VERSION) // Install the memory leak checker if (Py_AtExit(internals_cleanup)) fprintf(stderr, @@ -503,12 +518,11 @@ static void internals_make() { "resources at interpreter shutdown (e.g., to avoid leaks being " "reported by tools like 'valgrind'). If you are a user of a " "python extension library, you can ignore this warning."); +#endif } static void internals_fetch() { - PyObject *dict = PyInterpreterState_GetDict(NB_INTERPRETER_STATE_GET()); - if (!dict) - fail("nanobind::detail::internals_fetch(): PyInterpreterState_GetDict() failed!"); + PyObject *dict = internals_dict(); const char *internals_id = NB_INTERNALS_ID; PyObject *capsule = PyDict_GetItemString(dict, internals_id); diff --git a/src/nb_type.cpp b/src/nb_type.cpp index 7c968047..a066e7f7 100644 --- a/src/nb_type.cpp +++ b/src/nb_type.cpp @@ -467,10 +467,14 @@ PyObject *nb_type_new(const type_data *t) noexcept { PyHeapTypeObject *temp_ht = (PyHeapTypeObject *) temp; PyTypeObject *temp_tp = &temp_ht->ht_type; - Py_INCREF(temp_ht->ht_name); - Py_INCREF(temp_ht->ht_qualname); - Py_INCREF(temp_tp->tp_base); + Py_INCREF (temp_ht->ht_name); + Py_INCREF (temp_ht->ht_qualname); Py_XINCREF(temp_ht->ht_slots); + Py_INCREF (temp_tp->tp_base); + +#if PY_VERSION_HEX >= 0x03090000 + Py_XINCREF(temp_ht->ht_module); +#endif char *tp_doc = nullptr; if (temp_tp->tp_doc) { @@ -487,36 +491,84 @@ PyObject *nb_type_new(const type_data *t) noexcept { PyHeapTypeObject *ht = (PyHeapTypeObject *) result; PyTypeObject *tp = &ht->ht_type; - memcpy(ht, temp_ht, sizeof(PyHeapTypeObject)); + ht->ht_name = temp_ht->ht_name; + ht->ht_qualname = temp_ht->ht_qualname; + ht->ht_slots = temp_ht->ht_slots; + +#if PY_VERSION_HEX >= 0x03090000 + ht->ht_module = temp_ht->ht_module; +#endif - tp->ob_base.ob_base.ob_type = metaclass; - tp->ob_base.ob_base.ob_refcnt = 1; - tp->ob_base.ob_size = 0; - tp->tp_as_async = &ht->as_async; - tp->tp_as_number = &ht->as_number; - tp->tp_as_sequence = &ht->as_sequence; - tp->tp_as_mapping = &ht->as_mapping; - tp->tp_as_buffer = &ht->as_buffer; tp->tp_name = name_copy; tp->tp_doc = tp_doc; tp->tp_flags = spec.flags | Py_TPFLAGS_HEAPTYPE; + if (temp_tp->tp_flags & Py_TPFLAGS_HAVE_GC) tp->tp_flags |= Py_TPFLAGS_HAVE_GC; + /* The following fields remain intentionally null-initialized + following the call to PyType_GenericAlloc(): tp_dict, tp_bases, tp_mro, + tp_cache, tp_subclasses, tp_weaklist. */ + + #define COPY_FIELD(name) \ + tp->name = temp_tp->name; + + COPY_FIELD(tp_basicsize); + COPY_FIELD(tp_itemsize); + COPY_FIELD(tp_dealloc); + COPY_FIELD(tp_vectorcall_offset); + COPY_FIELD(tp_getattr); + COPY_FIELD(tp_setattr); + COPY_FIELD(tp_repr); + COPY_FIELD(tp_hash); + COPY_FIELD(tp_call); + COPY_FIELD(tp_str); + COPY_FIELD(tp_getattro); + COPY_FIELD(tp_setattro); + COPY_FIELD(tp_traverse); + COPY_FIELD(tp_clear); + COPY_FIELD(tp_richcompare); + COPY_FIELD(tp_weaklistoffset); + COPY_FIELD(tp_iter); + COPY_FIELD(tp_iternext); + COPY_FIELD(tp_methods); + COPY_FIELD(tp_members); + COPY_FIELD(tp_getset); + COPY_FIELD(tp_base); + COPY_FIELD(tp_descr_get); + COPY_FIELD(tp_descr_set); + COPY_FIELD(tp_dictoffset); + COPY_FIELD(tp_init); + COPY_FIELD(tp_alloc); + COPY_FIELD(tp_new); + COPY_FIELD(tp_free); + COPY_FIELD(tp_is_gc); + COPY_FIELD(tp_del); + COPY_FIELD(tp_finalize); + COPY_FIELD(tp_vectorcall); + + #undef COPY_FIELD + + ht->as_async = temp_ht->as_async; + tp->tp_as_async = &ht->as_async; + + ht->as_number = temp_ht->as_number; + tp->tp_as_number = &ht->as_number; + + ht->as_sequence = temp_ht->as_sequence; + tp->tp_as_sequence = &ht->as_sequence; + + ht->as_mapping = temp_ht->as_mapping; + tp->tp_as_mapping = &ht->as_mapping; + + ht->as_buffer = temp_ht->as_buffer; + tp->tp_as_buffer = &ht->as_buffer; + #if PY_VERSION_HEX < 0x03090000 if (has_dynamic_attr) tp->tp_dictoffset = (Py_ssize_t) (basicsize - ptr_size); #endif - tp->tp_dict = tp->tp_bases = tp->tp_mro = tp->tp_cache = nullptr; - tp->tp_subclasses = tp->tp_weaklist = nullptr; - ht->ht_cached_keys = nullptr; - tp->tp_version_tag = 0; - -#if PY_VERSION_HEX >= 0x030B0000 - ht->_ht_tpname = nullptr; -#endif - PyType_Ready(tp); Py_DECREF(temp); #endif @@ -1140,6 +1192,7 @@ type_data *nb_type_data_static(PyTypeObject *o) noexcept { PyObject *nb_type_name(PyTypeObject *tp) noexcept { PyObject *name = PyObject_GetAttrString((PyObject *) tp, "__name__"); +#if !defined(PYPY_VERSION) if (PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE)) { PyObject *mod = PyObject_GetAttrString((PyObject *) tp, "__module__"), *combined = PyUnicode_FromFormat("%U.%U", mod, name); @@ -1148,6 +1201,7 @@ PyObject *nb_type_name(PyTypeObject *tp) noexcept { Py_DECREF(name); name = combined; } +#endif return name; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 558a284b..4e50777f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -35,6 +35,7 @@ target_link_libraries(test_inter_module_1_ext PRIVATE inter_module) target_link_libraries(test_inter_module_2_ext PRIVATE inter_module) set(TEST_FILES + common.py test_functions.py test_classes.py test_holders.py diff --git a/tests/common.py b/tests/common.py new file mode 100644 index 00000000..387fa35d --- /dev/null +++ b/tests/common.py @@ -0,0 +1,15 @@ +import platform +import gc +import pytest + +is_pypy = platform.python_implementation() == 'PyPy' + +def collect(): + if is_pypy: + for i in range(3): + gc.collect() + else: + gc.collect() + +skip_on_pypy = pytest.mark.skipif( + is_pypy, reason="This test currently fails/crashes PyPy") diff --git a/tests/test_classes.py b/tests/test_classes.py index 78075b6f..455e20a6 100644 --- a/tests/test_classes.py +++ b/tests/test_classes.py @@ -1,15 +1,17 @@ import sys import test_classes_ext as t import pytest -import gc +from common import skip_on_pypy, collect + @pytest.fixture def clean(): - gc.collect() + collect() t.reset() + def assert_stats(**kwargs): - gc.collect() + collect() for k, v in t.stats().items(): fail = False if k in kwargs: @@ -329,20 +331,20 @@ def __init__(self): a = Struct2() assert t.keep_alive_arg(s, a) is a del s - gc.collect() + collect() assert constructed == 1 and destructed == 0 del a - gc.collect() + collect() assert constructed == 1 and destructed == 1 s = Struct() a = Struct2() assert t.keep_alive_ret(a, s) is s del a - gc.collect() + collect() assert constructed == 2 and destructed == 1 del s - gc.collect() + collect() assert constructed == 2 and destructed == 2 def f(): @@ -401,21 +403,26 @@ def test18_static_properties(): t.StaticProperties2.value = 50 assert t.StaticProperties2.get() == 50 assert t.StaticProperties.get() == 50 + + +@skip_on_pypy +def test19_static_properties_doc(): import pydoc assert "Static property docstring" in pydoc.render_doc(t.StaticProperties2) -def test19_supplement(): + +def test20_supplement(): c = t.ClassWithSupplement() assert t.check_supplement(c) assert not t.check_supplement(t.Struct()) -def test20_type_callback(): +def test21_type_callback(): o = t.ClassWithLen() assert len(o) == 123 -def test21_low_level(clean): +def test22_low_level(clean): s1, s2, s3 = t.test_lowlevel() assert s1.value() == 123 and s2.value() == 0 and s3.value() == 345 del s1 @@ -429,7 +436,7 @@ def test21_low_level(clean): ) -def test22_handle_t(clean): +def test23_handle_t(clean): assert t.test_handle_t.__doc__ == 'test_handle_t(arg: test_classes_ext.Struct, /) -> object' s = t.test_handle_t(t.Struct(5)) assert s.value() == 5 @@ -444,7 +451,7 @@ def test22_handle_t(clean): ) -def test23_type_object_t(clean): +def test24_type_object_t(clean): if sys.version_info < (3, 9): assert t.test_type_object_t.__doc__ == 'test_type_object_t(arg: Type[test_classes_ext.Struct], /) -> object' else: @@ -459,7 +466,7 @@ def test23_type_object_t(clean): t.test_type_object_t(int) -def test24_none_arg(): +def test25_none_arg(): with pytest.raises(TypeError): t.none_0(None) with pytest.raises(TypeError): @@ -475,14 +482,14 @@ def test24_none_arg(): assert t.none_4.__doc__ == 'none_4(arg: Optional[test_classes_ext.Struct]) -> bool' -def test25_is_final(): +def test26_is_final(): with pytest.raises(TypeError) as excinfo: class MyType(t.FinalType): pass assert "The type 'test_classes_ext.FinalType' prohibits subclassing!" in str(excinfo.value) -def test26_dynamic_attr(clean): +def test27_dynamic_attr(clean): l = [None] * 100 for i in range(100): l[i] = t.StructWithAttr(i) @@ -501,25 +508,25 @@ def test26_dynamic_attr(clean): assert l[i].next.value() == (i+1 if i < 99 else 0) del l - gc.collect() assert_stats( value_constructed=100, destructed=100 ) -def test27_copy_rvp(): + +def test28_copy_rvp(): a = t.Struct.create_reference() b = t.Struct.create_copy() assert a is not b -def test28_pydoc(): +def test29_pydoc(): import pydoc assert "Some documentation" in pydoc.render_doc(t) -def test29_property_assignment_instance(): +def test30_property_assignment_instance(): s = t.PairStruct() s1 = t.Struct(123) s2 = t.Struct(456) @@ -531,7 +538,8 @@ def test29_property_assignment_instance(): assert s1.value() == 123 assert s2.value() == 456 -def test30_cycle(): + +def test31_cycle(): a = t.Wrapper() a.value = a del a diff --git a/tests/test_holders.py b/tests/test_holders.py index bc7ecca9..25260389 100644 --- a/tests/test_holders.py +++ b/tests/test_holders.py @@ -1,10 +1,11 @@ import test_holders_ext as t import pytest -import gc +from common import collect + @pytest.fixture def clean(): - gc.collect() + collect() t.reset() # ------------------------------------------------------------------ @@ -15,7 +16,7 @@ def test01_create(clean): assert t.query_shared_1(e) == 123 assert t.query_shared_2(e) == 123 del e - gc.collect() + collect() assert t.stats() == (1, 1) @@ -24,19 +25,19 @@ def test02_sharedptr_from_python(clean): w = t.SharedWrapper(e) assert w.ptr is e del e - gc.collect() + collect() assert t.stats() == (1, 0) del w - gc.collect() + collect() assert t.stats() == (1, 1) w = t.SharedWrapper(t.Example(234)) assert t.stats() == (2, 1) w.ptr = t.Example(0) - gc.collect() + collect() assert t.stats() == (3, 2) del w - gc.collect() + collect() assert t.stats() == (3, 3) @@ -72,7 +73,7 @@ def test04_uniqueptr_from_cpp(clean): assert a.value == 1 assert b.value == 2 del a, b - gc.collect() + collect() assert t.stats() == (2, 2) @@ -92,7 +93,7 @@ def test05_uniqueptr_from_cpp(clean): assert 'incompatible function arguments' in str(excinfo.value) del a, b del wa, wb - gc.collect() + collect() assert t.stats() == (2, 2) t.reset() @@ -108,7 +109,7 @@ def test05_uniqueptr_from_cpp(clean): assert a.value == 1 and b.value == 2 assert t.stats() == (2, 0) del a, b, a2, b2 - gc.collect() + collect() assert t.stats() == (2, 2) @@ -126,7 +127,7 @@ def test06_uniqueptr_from_py(clean): a2 = wa.get() assert a2.value == 1 and a is a2 del a, a2 - gc.collect() + collect() assert t.stats() == (1, 1) def test07_uniqueptr_passthrough(clean): @@ -134,7 +135,7 @@ def test07_uniqueptr_passthrough(clean): assert t.passthrough_unique(t.unique_from_cpp_2()).value == 2 assert t.passthrough_unique_2(t.unique_from_cpp()).value == 1 assert t.passthrough_unique_2(t.unique_from_cpp_2()).value == 2 - gc.collect() + collect() assert t.stats() == (4, 4) t.reset() diff --git a/tests/test_intrusive.py b/tests/test_intrusive.py index 83d32b33..3337f58f 100644 --- a/tests/test_intrusive.py +++ b/tests/test_intrusive.py @@ -1,11 +1,10 @@ import test_intrusive_ext as t import pytest -import gc +from common import collect @pytest.fixture def clean(): - gc.collect() - gc.collect() + collect() t.reset() def test01_construct(clean): @@ -15,8 +14,7 @@ def test01_construct(clean): assert t.get_value_2(o) == 123 assert t.get_value_3(o) == 123 del o - gc.collect() - gc.collect() + collect() assert t.stats() == (1, 1) @@ -27,8 +25,7 @@ def test02_factory(clean): assert t.get_value_2(o) == 123 assert t.get_value_3(o) == 123 del o - gc.collect() - gc.collect() + collect() assert t.stats() == (1, 1) @@ -39,8 +36,7 @@ def test03_factory_ref(clean): assert t.get_value_2(o) == 123 assert t.get_value_3(o) == 123 del o - gc.collect() - gc.collect() + collect() assert t.stats() == (1, 1) def test04_subclass(clean): @@ -58,6 +54,5 @@ def value(self): assert t.get_value_2(o) == 456 assert t.get_value_3(o) == 456 del o - gc.collect() - gc.collect() + collect() assert t.stats() == (1, 1) diff --git a/tests/test_stl.py b/tests/test_stl.py index bf615763..25b497ac 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -1,15 +1,17 @@ import test_stl_ext as t import pytest -import gc import sys +from common import collect + @pytest.fixture def clean(): - gc.collect() + collect() t.reset() + def assert_stats(**kwargs): - gc.collect() + collect() for k, v in t.stats().items(): fail = False if k in kwargs: @@ -352,8 +354,7 @@ def f(): assert t.FuncWrapper.alive == 1 del b import gc - gc.collect() - gc.collect() + collect() assert t.FuncWrapper.alive == 0 def test33_vec_type_check(): @@ -682,7 +683,7 @@ def test65_class_with_movable_field(clean): ) del m1, m2 - gc.collect() + collect() assert_stats( value_constructed=2, @@ -691,7 +692,7 @@ def test65_class_with_movable_field(clean): ) del cwmf - gc.collect() + collect() assert_stats( value_constructed=2, diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 55e91341..49d75e59 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -1,8 +1,8 @@ import test_tensor_ext as t import pytest import warnings -import gc import importlib +from common import collect try: import numpy as np @@ -33,6 +33,7 @@ def needs_jax(x): needs_jax = pytest.mark.skip(reason="JAX is required") + @needs_numpy def test01_metadata(): a = np.zeros(shape=()) @@ -234,18 +235,18 @@ def test12_implicit_conversion_jax(): def test13_destroy_capsule(): - gc.collect() + collect() dc = t.destruct_count() a = t.return_dlpack() assert dc == t.destruct_count() del a - gc.collect() + collect() assert t.destruct_count() - dc == 1 @needs_numpy def test14_consume_numpy(): - gc.collect() + collect() class wrapper: def __init__(self, value): self.value = value @@ -261,18 +262,18 @@ def __dlpack__(self): pytest.skip('your version of numpy is too old') del a - gc.collect() + collect() assert x.shape == (2, 4) assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]]) assert dc == t.destruct_count() del x - gc.collect() + collect() assert t.destruct_count() - dc == 1 @needs_numpy def test15_passthrough(): - gc.collect() + collect() class wrapper: def __init__(self, value): self.value = value @@ -290,24 +291,24 @@ def __dlpack__(self): del a del b - gc.collect() + collect() assert dc == t.destruct_count() assert y.shape == (2, 4) assert np.all(y == [[1, 2, 3, 4], [5, 6, 7, 8]]) del y - gc.collect() + collect() assert t.destruct_count() - dc == 1 @needs_numpy def test16_return_numpy(): - gc.collect() + collect() dc = t.destruct_count() x = t.ret_numpy() assert x.shape == (2, 4) assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]]) del x - gc.collect() + collect() assert t.destruct_count() - dc == 1 @@ -317,21 +318,21 @@ def test17_return_pytorch(): c = torch.zeros(3, 5) except: pytest.skip('pytorch is missing') - gc.collect() + collect() dc = t.destruct_count() x = t.ret_pytorch() assert x.shape == (2, 4) assert torch.all(x == torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])) del x - gc.collect() + collect() assert t.destruct_count() - dc == 1 @needs_numpy def test18_return_array_scalar(): - gc.collect() + collect() dc = t.destruct_count() x = t.ret_array_scalar() assert np.array_equal(x, np.array(1)) del x - gc.collect() + collect() assert t.destruct_count() - dc == 1