Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions Lib/test/test_free_threading/test_capi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import ctypes
import sys
import unittest

from test.support import threading_helper
from test.support.threading_helper import run_concurrently


_PyImport_AddModuleRef = ctypes.pythonapi.PyImport_AddModuleRef
_PyImport_AddModuleRef.argtypes = (ctypes.c_char_p,)
_PyImport_AddModuleRef.restype = ctypes.py_object


@threading_helper.requires_working_threading()
class TestImportCAPI(unittest.TestCase):
def test_pyimport_addmoduleref_thread_safe(self):
# gh-137422: Concurrent calls to PyImport_AddModuleRef with the same
# module name must return the same module object.

NUM_ITERS = 10
NTHREADS = 4

module_name = f"test_free_threading_addmoduleref_{id(self)}"
module_name_bytes = module_name.encode()
sys.modules.pop(module_name, None)
results = []

def worker():
module = _PyImport_AddModuleRef(module_name_bytes)
results.append(module)

for _ in range(NUM_ITERS):
try:
run_concurrently(worker_func=worker, nthreads=NTHREADS)
self.assertEqual(len(results), NTHREADS)
reference = results[0]
for module in results[1:]:
self.assertIs(module, reference)
self.assertIn(module_name, sys.modules)
self.assertIs(sys.modules[module_name], reference)
finally:
results.clear()
sys.modules.pop(module_name, None)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fix :term:`free threading` race condition in
:c:func:`PyImport_AddModuleRef`. It was previously possible for two calls to
the function return two different objects, only one of which was stored in
:data:`sys.modules`.
23 changes: 17 additions & 6 deletions Python/import.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "Python.h"
#include "pycore_audit.h" // _PySys_Audit()
#include "pycore_ceval.h"
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
#include "pycore_hashtable.h" // _Py_hashtable_new_full()
#include "pycore_import.h" // _PyImport_BootstrapImp()
#include "pycore_initconfig.h" // _PyStatus_OK()
Expand Down Expand Up @@ -309,13 +310,8 @@ PyImport_GetModule(PyObject *name)
if not, create a new one and insert it in the modules dictionary. */

static PyObject *
import_add_module(PyThreadState *tstate, PyObject *name)
import_add_module_lock_held(PyObject *modules, PyObject *name)
{
PyObject *modules = get_modules_dict(tstate, false);
if (modules == NULL) {
return NULL;
}

PyObject *m;
if (PyMapping_GetOptionalItem(modules, name, &m) < 0) {
return NULL;
Expand All @@ -335,6 +331,21 @@ import_add_module(PyThreadState *tstate, PyObject *name)
return m;
}

static PyObject *
import_add_module(PyThreadState *tstate, PyObject *name)
{
PyObject *modules = get_modules_dict(tstate, false);
if (modules == NULL) {
return NULL;
}

PyObject *m;
Py_BEGIN_CRITICAL_SECTION(modules);
m = import_add_module_lock_held(modules, name);
Py_END_CRITICAL_SECTION();
return m;
}

PyObject *
PyImport_AddModuleRef(const char *name)
{
Expand Down
Loading