Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions Lib/test/test_defaultdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import copy
import pickle
import threading
import time
import unittest

from collections import defaultdict
from test.support import threading_helper

def foobar():
return list
Expand Down Expand Up @@ -186,5 +189,62 @@ def test_union(self):
with self.assertRaises(TypeError):
i |= None

@threading_helper.requires_working_threading()
def test_no_value_overwrite_race_condition(self):
call_count = 0

def unique_factory():
nonlocal call_count
call_count += 1
return f"value_{call_count}_{threading.get_ident()}"

d = defaultdict(unique_factory)
results = {}

def worker(thread_id):
for _ in range(5):
value = d['shared_key']
if 'shared_key' not in results:
results['shared_key'] = value
time.sleep(0.001)

threads = []
for i in range(3):
t = threading.Thread(target=worker, args=(i,))
threads.append(t)
t.start()

for t in threads:
t.join()

self.assertIn('shared_key', d)

final_value = d['shared_key']
self.assertEqual(results['shared_key'], final_value)

self.assertTrue(final_value.startswith('value_'))

self.assertEqual(call_count, 1)

def test_factory_called_only_when_key_missing(self):
factory_calls = []

def tracked_factory():
factory_calls.append(threading.get_ident())
return [1, 2, 3]

d = defaultdict(tracked_factory)

value1 = d['key']
self.assertEqual(value1, [1, 2, 3])
initial_call_count = len(factory_calls)

for _ in range(10):
value = d['key']
self.assertEqual(value, [1, 2, 3])

self.assertEqual(len(factory_calls), initial_call_count)


if __name__ == "__main__":
unittest.main()
17 changes: 13 additions & 4 deletions Modules/_collectionsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -2228,14 +2228,23 @@ defdict_missing(PyObject *op, PyObject *key)
Py_DECREF(tup);
return NULL;
}

value = _PyObject_CallNoArgs(factory);
if (value == NULL)
return value;
if (PyObject_SetItem(op, key, value) < 0) {
Py_DECREF(value);
return NULL;

/* Use PyDict_SetDefaultRef to atomically insert the value only if the key is absent.
* This ensures we don't overwrite a value that another thread inserted
* between the factory call and this insertion.
*/
PyObject *result_value = NULL;
int res = PyDict_SetDefaultRef(op, key, value, &result_value);

if (res != 0) {
Py_DECREF(value);
}
return value;

return result_value;
}

static inline PyObject*
Expand Down
Loading