From 4310b8876eb35b1aad77b27e2bb235e6e99985f6 Mon Sep 17 00:00:00 2001 From: fatelei Date: Thu, 11 Dec 2025 20:06:06 +0800 Subject: [PATCH 1/3] fix: __missing__ race condition --- Lib/test/test_defaultdict.py | 73 ++++++++++++++++++++++++++++++++++++ Modules/_collectionsmodule.c | 26 ++++++++++--- 2 files changed, 93 insertions(+), 6 deletions(-) diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py index bdbe9b81e8fb3f..2d0917679ab7c6 100644 --- a/Lib/test/test_defaultdict.py +++ b/Lib/test/test_defaultdict.py @@ -2,6 +2,8 @@ import copy import pickle +import threading +import time import unittest from collections import defaultdict @@ -186,5 +188,76 @@ def test_union(self): with self.assertRaises(TypeError): i |= None + def test_no_value_overwrite_race_condition(self): + """Test that concurrent access to missing keys doesn't overwrite values.""" + # Use a factory that returns unique objects so we can detect overwrites + call_count = 0 + + def unique_factory(): + nonlocal call_count + call_count += 1 + # Return a unique object that identifies this call + return f"value_{call_count}_{threading.get_ident()}" + + d = defaultdict(unique_factory) + results = {} + + def worker(thread_id): + # Multiple threads access the same missing key + for _ in range(5): + value = d['shared_key'] + if 'shared_key' not in results: + results['shared_key'] = value + # Small delay to increase chance of race conditions + time.sleep(0.001) + + # Start multiple threads + threads = [] + for i in range(3): + t = threading.Thread(target=worker, args=(i,)) + threads.append(t) + t.start() + + # Wait for all threads to complete + for t in threads: + t.join() + + # Key should exist in the dictionary + self.assertIn('shared_key', d) + + # All threads should see the same value (no overwrites occurred) + final_value = d['shared_key'] + self.assertEqual(results['shared_key'], final_value) + + # The value should be from the first successful factory call + self.assertTrue(final_value.startswith('value_')) + + # Factory should only be called once (since key only inserted once) + self.assertEqual(call_count, 1) + + def test_factory_called_only_when_key_missing(self): + """Test that factory is only called when key is truly missing.""" + factory_calls = [] + + def tracked_factory(): + factory_calls.append(threading.get_ident()) + return [1, 2, 3] + + d = defaultdict(tracked_factory) + + # First access should call factory + value1 = d['key'] + self.assertEqual(value1, [1, 2, 3]) + initial_call_count = len(factory_calls) + + # Multiple subsequent accesses should not call factory + for _ in range(10): + value = d['key'] + self.assertEqual(value, [1, 2, 3]) + + # Factory call count should not have increased + self.assertEqual(len(factory_calls), initial_call_count) + + if __name__ == "__main__": unittest.main() diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c index 3ba48d5d9d3c64..73f551dc650b34 100644 --- a/Modules/_collectionsmodule.c +++ b/Modules/_collectionsmodule.c @@ -2218,7 +2218,6 @@ defdict_missing(PyObject *op, PyObject *key) { defdictobject *dd = defdictobject_CAST(op); PyObject *factory = dd->default_factory; - PyObject *value; if (factory == NULL || factory == Py_None) { /* XXX Call dict.__missing__(key) */ PyObject *tup; @@ -2228,14 +2227,29 @@ defdict_missing(PyObject *op, PyObject *key) Py_DECREF(tup); return NULL; } - value = _PyObject_CallNoArgs(factory); - if (value == NULL) - return value; - if (PyObject_SetItem(op, key, value) < 0) { + + PyObject *value = _PyObject_CallNoArgs(factory); + if (value == NULL) { + return NULL; + } + + /* Use PyDict_SetDefaultRef to atomically insert the value only if the key is absent. + * This ensures we don't overwrite a value that another thread inserted + * between the factory call and this insertion. + */ + PyObject *result_value = NULL; + int res = PyDict_SetDefaultRef(op, key, value, &result_value); + + if (res < 0) { Py_DECREF(value); return NULL; } - return value; + + if (res != 0) { + Py_DECREF(value); + } + + return result_value; } static inline PyObject* From 437f14f2e1b86a0c52cbdc91cadc9dfd071e3992 Mon Sep 17 00:00:00 2001 From: fatelei Date: Thu, 11 Dec 2025 21:22:15 +0800 Subject: [PATCH 2/3] fix: fix wasi test failed --- Lib/test/test_defaultdict.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py index 2d0917679ab7c6..23aef2123afae0 100644 --- a/Lib/test/test_defaultdict.py +++ b/Lib/test/test_defaultdict.py @@ -7,6 +7,7 @@ import unittest from collections import defaultdict +from test.support import threading_helper def foobar(): return list @@ -188,6 +189,7 @@ def test_union(self): with self.assertRaises(TypeError): i |= None + @threading_helper.requires_working_threading() def test_no_value_overwrite_race_condition(self): """Test that concurrent access to missing keys doesn't overwrite values.""" # Use a factory that returns unique objects so we can detect overwrites From b0836c1d769efc03e6116c3908e1cd21e98449cc Mon Sep 17 00:00:00 2001 From: fatelei Date: Thu, 11 Dec 2025 21:26:25 +0800 Subject: [PATCH 3/3] fix: fix comment problem --- Lib/test/test_defaultdict.py | 15 --------------- Modules/_collectionsmodule.c | 11 +++-------- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py index 23aef2123afae0..625bac46a3ee09 100644 --- a/Lib/test/test_defaultdict.py +++ b/Lib/test/test_defaultdict.py @@ -191,54 +191,42 @@ def test_union(self): @threading_helper.requires_working_threading() def test_no_value_overwrite_race_condition(self): - """Test that concurrent access to missing keys doesn't overwrite values.""" - # Use a factory that returns unique objects so we can detect overwrites call_count = 0 def unique_factory(): nonlocal call_count call_count += 1 - # Return a unique object that identifies this call return f"value_{call_count}_{threading.get_ident()}" d = defaultdict(unique_factory) results = {} def worker(thread_id): - # Multiple threads access the same missing key for _ in range(5): value = d['shared_key'] if 'shared_key' not in results: results['shared_key'] = value - # Small delay to increase chance of race conditions time.sleep(0.001) - # Start multiple threads threads = [] for i in range(3): t = threading.Thread(target=worker, args=(i,)) threads.append(t) t.start() - # Wait for all threads to complete for t in threads: t.join() - # Key should exist in the dictionary self.assertIn('shared_key', d) - # All threads should see the same value (no overwrites occurred) final_value = d['shared_key'] self.assertEqual(results['shared_key'], final_value) - # The value should be from the first successful factory call self.assertTrue(final_value.startswith('value_')) - # Factory should only be called once (since key only inserted once) self.assertEqual(call_count, 1) def test_factory_called_only_when_key_missing(self): - """Test that factory is only called when key is truly missing.""" factory_calls = [] def tracked_factory(): @@ -247,17 +235,14 @@ def tracked_factory(): d = defaultdict(tracked_factory) - # First access should call factory value1 = d['key'] self.assertEqual(value1, [1, 2, 3]) initial_call_count = len(factory_calls) - # Multiple subsequent accesses should not call factory for _ in range(10): value = d['key'] self.assertEqual(value, [1, 2, 3]) - # Factory call count should not have increased self.assertEqual(len(factory_calls), initial_call_count) diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c index 73f551dc650b34..50b077b440091d 100644 --- a/Modules/_collectionsmodule.c +++ b/Modules/_collectionsmodule.c @@ -2218,6 +2218,7 @@ defdict_missing(PyObject *op, PyObject *key) { defdictobject *dd = defdictobject_CAST(op); PyObject *factory = dd->default_factory; + PyObject *value; if (factory == NULL || factory == Py_None) { /* XXX Call dict.__missing__(key) */ PyObject *tup; @@ -2228,10 +2229,9 @@ defdict_missing(PyObject *op, PyObject *key) return NULL; } - PyObject *value = _PyObject_CallNoArgs(factory); - if (value == NULL) { + value = _PyObject_CallNoArgs(factory); + if (value == NULL) return NULL; - } /* Use PyDict_SetDefaultRef to atomically insert the value only if the key is absent. * This ensures we don't overwrite a value that another thread inserted @@ -2240,11 +2240,6 @@ defdict_missing(PyObject *op, PyObject *key) PyObject *result_value = NULL; int res = PyDict_SetDefaultRef(op, key, value, &result_value); - if (res < 0) { - Py_DECREF(value); - return NULL; - } - if (res != 0) { Py_DECREF(value); }