diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py index bdbe9b81e8fb3f..625bac46a3ee09 100644 --- a/Lib/test/test_defaultdict.py +++ b/Lib/test/test_defaultdict.py @@ -2,9 +2,12 @@ import copy import pickle +import threading +import time import unittest from collections import defaultdict +from test.support import threading_helper def foobar(): return list @@ -186,5 +189,62 @@ def test_union(self): with self.assertRaises(TypeError): i |= None + @threading_helper.requires_working_threading() + def test_no_value_overwrite_race_condition(self): + call_count = 0 + + def unique_factory(): + nonlocal call_count + call_count += 1 + return f"value_{call_count}_{threading.get_ident()}" + + d = defaultdict(unique_factory) + results = {} + + def worker(thread_id): + for _ in range(5): + value = d['shared_key'] + if 'shared_key' not in results: + results['shared_key'] = value + time.sleep(0.001) + + threads = [] + for i in range(3): + t = threading.Thread(target=worker, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + self.assertIn('shared_key', d) + + final_value = d['shared_key'] + self.assertEqual(results['shared_key'], final_value) + + self.assertTrue(final_value.startswith('value_')) + + self.assertEqual(call_count, 1) + + def test_factory_called_only_when_key_missing(self): + factory_calls = [] + + def tracked_factory(): + factory_calls.append(threading.get_ident()) + return [1, 2, 3] + + d = defaultdict(tracked_factory) + + value1 = d['key'] + self.assertEqual(value1, [1, 2, 3]) + initial_call_count = len(factory_calls) + + for _ in range(10): + value = d['key'] + self.assertEqual(value, [1, 2, 3]) + + self.assertEqual(len(factory_calls), initial_call_count) + + if __name__ == "__main__": unittest.main() diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c index 3ba48d5d9d3c64..50b077b440091d 100644 --- a/Modules/_collectionsmodule.c +++ b/Modules/_collectionsmodule.c @@ -2228,14 +2228,23 @@ defdict_missing(PyObject *op, PyObject *key) Py_DECREF(tup); return NULL; } + value = _PyObject_CallNoArgs(factory); if (value == NULL) - return value; - if (PyObject_SetItem(op, key, value) < 0) { - Py_DECREF(value); return NULL; + + /* Use PyDict_SetDefaultRef to atomically insert the value only if the key is absent. + * This ensures we don't overwrite a value that another thread inserted + * between the factory call and this insertion. + */ + PyObject *result_value = NULL; + int res = PyDict_SetDefaultRef(op, key, value, &result_value); + + if (res != 0) { + Py_DECREF(value); } - return value; + + return result_value; } static inline PyObject*