Skip to content

Commit

Permalink
[3.10] bpo-46433: _PyType_GetModuleByDef: handle static types in MRO (G…
Browse files Browse the repository at this point in the history
…H-30696) (GH-31262)


(cherry picked from commit 0ef0853)
  • Loading branch information
encukou committed Feb 11, 2022
1 parent 1124ab6 commit 8b8673f
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 13 deletions.
16 changes: 16 additions & 0 deletions Lib/test/test_capi.py
Expand Up @@ -1030,6 +1030,22 @@ def test_state_access(self):
with self.assertRaises(TypeError):
increment_count(1, 2, 3)

def test_get_module_bad_def(self):
# _PyType_GetModuleByDef fails gracefully if it doesn't
# find what it's looking for.
# see bpo-46433
instance = self.module.StateAccessType()
with self.assertRaises(TypeError):
instance.getmodulebydef_bad_def()

def test_get_module_static_in_mro(self):
# Here, the class _PyType_GetModuleByDef is looking for
# appears in the MRO after a static type (Exception).
# see bpo-46433
class Subclass(BaseException, self.module.StateAccessType):
pass
self.assertIs(Subclass().get_defining_module(), self.module)


if __name__ == "__main__":
unittest.main()
@@ -0,0 +1,2 @@
The internal function _PyType_GetModuleByDef now correctly handles
inheritance patterns involving static types.
24 changes: 21 additions & 3 deletions Modules/_testmultiphase.c
Expand Up @@ -122,6 +122,8 @@ static PyType_Spec Example_Type_spec = {


static PyModuleDef def_meth_state_access;
static PyModuleDef def_nonmodule;
static PyModuleDef def_nonmodule_with_methods;

/*[clinic input]
_testmultiphase.StateAccessType.get_defining_module
Expand Down Expand Up @@ -149,6 +151,24 @@ _testmultiphase_StateAccessType_get_defining_module_impl(StateAccessTypeObject *
return retval;
}

/*[clinic input]
_testmultiphase.StateAccessType.getmodulebydef_bad_def
cls: defining_class
Test that result of _PyType_GetModuleByDef with a bad def is NULL.
[clinic start generated code]*/

static PyObject *
_testmultiphase_StateAccessType_getmodulebydef_bad_def_impl(StateAccessTypeObject *self,
PyTypeObject *cls)
/*[clinic end generated code: output=64509074dfcdbd31 input=906047715ee293cd]*/
{
_PyType_GetModuleByDef(Py_TYPE(self), &def_nonmodule); // should raise
assert(PyErr_Occurred());
return NULL;
}

/*[clinic input]
_testmultiphase.StateAccessType.increment_count_clinic
Expand Down Expand Up @@ -245,6 +265,7 @@ _testmultiphase_StateAccessType_get_count_impl(StateAccessTypeObject *self,

static PyMethodDef StateAccessType_methods[] = {
_TESTMULTIPHASE_STATEACCESSTYPE_GET_DEFINING_MODULE_METHODDEF
_TESTMULTIPHASE_STATEACCESSTYPE_GETMODULEBYDEF_BAD_DEF_METHODDEF
_TESTMULTIPHASE_STATEACCESSTYPE_GET_COUNT_METHODDEF
_TESTMULTIPHASE_STATEACCESSTYPE_INCREMENT_COUNT_CLINIC_METHODDEF
{
Expand Down Expand Up @@ -433,9 +454,6 @@ PyInit__testmultiphase(PyObject *spec)

/**** Importing a non-module object ****/

static PyModuleDef def_nonmodule;
static PyModuleDef def_nonmodule_with_methods;

/* Create a SimpleNamespace(three=3) */
static PyObject*
createfunc_nonmodule(PyObject *spec, PyModuleDef *def)
Expand Down
32 changes: 31 additions & 1 deletion Modules/clinic/_testmultiphase.c.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 7 additions & 9 deletions Objects/typeobject.c
Expand Up @@ -3707,22 +3707,20 @@ _PyType_GetModuleByDef(PyTypeObject *type, struct PyModuleDef *def)
// to check i < PyTuple_GET_SIZE(mro) at the first loop iteration.
assert(PyTuple_GET_SIZE(mro) >= 1);

Py_ssize_t i = 0;
do {
Py_ssize_t n = PyTuple_GET_SIZE(mro);
for (Py_ssize_t i = 0; i < n; i++) {
PyObject *super = PyTuple_GET_ITEM(mro, i);
// _PyType_GetModuleByDef() must only be called on a heap type created
// by PyType_FromModuleAndSpec() or on its subclasses.
// type_ready_mro() ensures that a static type cannot inherit from a
// heap type.
assert(_PyType_HasFeature((PyTypeObject *)type, Py_TPFLAGS_HEAPTYPE));
if(!_PyType_HasFeature((PyTypeObject *)super, Py_TPFLAGS_HEAPTYPE)) {
// Static types in the MRO need to be skipped
continue;
}

PyHeapTypeObject *ht = (PyHeapTypeObject*)super;
PyObject *module = ht->ht_module;
if (module && _PyModule_GetDef(module) == def) {
return module;
}
i++;
} while (i < PyTuple_GET_SIZE(mro));
}

PyErr_Format(
PyExc_TypeError,
Expand Down

0 comments on commit 8b8673f

Please sign in to comment.