diff --git a/Lib/test/test_free_threading/test_itertools.py b/Lib/test/test_free_threading/test_itertools.py index 9d366041917bb3..bb6047e8669475 100644 --- a/Lib/test/test_free_threading/test_itertools.py +++ b/Lib/test/test_free_threading/test_itertools.py @@ -1,94 +1,59 @@ import unittest -from threading import Thread, Barrier -from itertools import batched, chain, cycle +from itertools import batched, chain, combinations_with_replacement, cycle, permutations from test.support import threading_helper threading_helper.requires_working_threading(module=True) -class ItertoolsThreading(unittest.TestCase): - - @threading_helper.reap_threads - def test_batched(self): - number_of_threads = 10 - number_of_iterations = 20 - barrier = Barrier(number_of_threads) - def work(it): - barrier.wait() - while True: - try: - next(it) - except StopIteration: - break - data = tuple(range(1000)) - for it in range(number_of_iterations): - batch_iterator = batched(data, 2) - worker_threads = [] - for ii in range(number_of_threads): - worker_threads.append( - Thread(target=work, args=[batch_iterator])) +def work_iterator(it): + while True: + try: + next(it) + except StopIteration: + break - with threading_helper.start_threads(worker_threads): - pass - barrier.reset() +class ItertoolsThreading(unittest.TestCase): @threading_helper.reap_threads - def test_cycle(self): - number_of_threads = 6 + def test_batched(self): number_of_iterations = 10 - number_of_cycles = 400 + for _ in range(number_of_iterations): + it = batched(tuple(range(1000)), 2) + threading_helper.run_concurrently(work_iterator, nthreads=10, args=[it]) - barrier = Barrier(number_of_threads) + @threading_helper.reap_threads + def test_cycle(self): def work(it): - barrier.wait() - for _ in range(number_of_cycles): - try: - next(it) - except StopIteration: - pass + for _ in range(400): + next(it) - data = (1, 2, 3, 4) - for it in range(number_of_iterations): - cycle_iterator = cycle(data) - worker_threads = [] - for ii in range(number_of_threads): - worker_threads.append( - Thread(target=work, args=[cycle_iterator])) - - with threading_helper.start_threads(worker_threads): - pass - - barrier.reset() + number_of_iterations = 6 + for _ in range(number_of_iterations): + it = cycle((1, 2, 3, 4)) + threading_helper.run_concurrently(work, nthreads=6, args=[it]) @threading_helper.reap_threads def test_chain(self): - number_of_threads = 6 - number_of_iterations = 20 - - barrier = Barrier(number_of_threads) - def work(it): - barrier.wait() - while True: - try: - next(it) - except StopIteration: - break - - data = [(1, )] * 200 - for it in range(number_of_iterations): - chain_iterator = chain(*data) - worker_threads = [] - for ii in range(number_of_threads): - worker_threads.append( - Thread(target=work, args=[chain_iterator])) - - with threading_helper.start_threads(worker_threads): - pass + number_of_iterations = 10 + for _ in range(number_of_iterations): + it = chain(*[(1,)] * 200) + threading_helper.run_concurrently(work_iterator, nthreads=6, args=[it]) - barrier.reset() + @threading_helper.reap_threads + def test_combinations_with_replacement(self): + number_of_iterations = 6 + for _ in range(number_of_iterations): + it = combinations_with_replacement(tuple(range(2)), 2) + threading_helper.run_concurrently(work_iterator, nthreads=6, args=[it]) + @threading_helper.reap_threads + def test_permutations(self): + number_of_iterations = 6 + for _ in range(number_of_iterations): + it = permutations(tuple(range(4)), 2) + threading_helper.run_concurrently(work_iterator, nthreads=6, args=[it]) if __name__ == "__main__": diff --git a/Misc/NEWS.d/next/Library/2026-02-03-08-50-58.gh-issue-123471.yF1Gym.rst b/Misc/NEWS.d/next/Library/2026-02-03-08-50-58.gh-issue-123471.yF1Gym.rst new file mode 100644 index 00000000000000..85e9a03426e1fc --- /dev/null +++ b/Misc/NEWS.d/next/Library/2026-02-03-08-50-58.gh-issue-123471.yF1Gym.rst @@ -0,0 +1 @@ +Make concurrent iteration over :class:`itertools.combinations_with_replacement` and :class:`itertools.permutations` safe under free-threading. diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 8685eff8be65c3..7e73f76bc20b58 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -2587,7 +2587,7 @@ cwr_traverse(PyObject *op, visitproc visit, void *arg) } static PyObject * -cwr_next(PyObject *op) +cwr_next_lock_held(PyObject *op) { cwrobject *co = cwrobject_CAST(op); PyObject *elem; @@ -2666,6 +2666,16 @@ cwr_next(PyObject *op) return NULL; } +static PyObject * +cwr_next(PyObject *op) +{ + PyObject *result; + Py_BEGIN_CRITICAL_SECTION(op); + result = cwr_next_lock_held(op); + Py_END_CRITICAL_SECTION() + return result; +} + static PyMethodDef cwr_methods[] = { {"__sizeof__", cwr_sizeof, METH_NOARGS, sizeof_doc}, {NULL, NULL} /* sentinel */ @@ -2846,7 +2856,7 @@ permutations_traverse(PyObject *op, visitproc visit, void *arg) } static PyObject * -permutations_next(PyObject *op) +permutations_next_lock_held(PyObject *op) { permutationsobject *po = permutationsobject_CAST(op); PyObject *elem; @@ -2936,6 +2946,16 @@ permutations_next(PyObject *op) return NULL; } +static PyObject * +permutations_next(PyObject *op) +{ + PyObject *result; + Py_BEGIN_CRITICAL_SECTION(op); + result = permutations_next_lock_held(op); + Py_END_CRITICAL_SECTION() + return result; +} + static PyMethodDef permuations_methods[] = { {"__sizeof__", permutations_sizeof, METH_NOARGS, sizeof_doc}, {NULL, NULL} /* sentinel */