From dbf1d6f5e7de33ec61a89cb86fee8baa0c3c97a6 Mon Sep 17 00:00:00 2001 From: Prathmesh Adsod <91672445+PrathmeshAdsod@users.noreply.github.com> Date: Tue, 19 Aug 2025 05:51:23 +0000 Subject: [PATCH] Mark map/accumulate iterators exhausted when the user callback raises StopIteration --- Modules/itertoolsmodule.c | 12 +++++++++++- Python/bltinmodule.c | 12 ++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index bc23ad7e8488ee..3c96bcea388e23 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -2983,6 +2983,7 @@ typedef struct { PyObject *it; PyObject *binop; PyObject *initial; + int finished; itertools_state *state; } accumulateobject; @@ -3024,6 +3025,7 @@ itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable, lz->total = NULL; lz->it = it; lz->initial = Py_XNewRef(initial); + lz->finished = 0; lz->state = find_state_by_type(type); return (PyObject *)lz; } @@ -3060,6 +3062,10 @@ accumulate_next(PyObject *op) accumulateobject *lz = accumulateobject_CAST(op); PyObject *val, *newtotal; + if (lz->finished) { + return NULL; + } + if (lz->initial != Py_None) { lz->total = lz->initial; lz->initial = Py_NewRef(Py_None); @@ -3079,8 +3085,12 @@ accumulate_next(PyObject *op) else newtotal = PyObject_CallFunctionObjArgs(lz->binop, lz->total, val, NULL); Py_DECREF(val); - if (newtotal == NULL) + if (newtotal == NULL) { + if (lz->binop != NULL && PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration)) { + lz->finished = 1; + } return NULL; + } Py_INCREF(newtotal); Py_SETREF(lz->total, newtotal); diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c index 209bc56dd1b153..3dd2c91934826c 100644 --- a/Python/bltinmodule.c +++ b/Python/bltinmodule.c @@ -1354,6 +1354,7 @@ typedef struct { PyObject *iters; PyObject *func; int strict; + int finished; } mapobject; #define _mapobject_CAST(op) ((mapobject *)(op)) @@ -1411,6 +1412,7 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds) func = PyTuple_GET_ITEM(args, 0); lz->func = Py_NewRef(func); lz->strict = strict; + lz->finished = 0; return (PyObject *)lz; } @@ -1456,6 +1458,7 @@ map_vectorcall(PyObject *type, PyObject * const*args, lz->iters = iters; lz->func = Py_NewRef(args[0]); lz->strict = 0; + lz->finished = 0; return (PyObject *)lz; } @@ -1489,6 +1492,10 @@ map_next(PyObject *self) PyObject *result = NULL; PyThreadState *tstate = _PyThreadState_GET(); + if (lz->finished) { + return NULL; + } + const Py_ssize_t niters = PyTuple_GET_SIZE(lz->iters); if (niters <= (Py_ssize_t)Py_ARRAY_LENGTH(small_stack)) { stack = small_stack; @@ -1516,6 +1523,11 @@ map_next(PyObject *self) } result = _PyObject_VectorcallTstate(tstate, lz->func, stack, nargs, NULL); + if (result == NULL && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_StopIteration)) { + lz->finished = 1; + } + } exit: for (i=0; i < nargs; i++) {