diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index 6bcda307f256f2..c016fb76bfd0a0 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -164,11 +164,14 @@ loops that truncate the stream. Added the optional *initial* parameter. -.. function:: batched(iterable, n) +.. function:: batched(iterable, n, *, strict=False) Batch data from the *iterable* into tuples of length *n*. The last batch may be shorter than *n*. + If *strict* is true, will raise a :exc:`ValueError` if the final + batch is shorter than *n*. + Loops over the input iterable and accumulates data into tuples up to size *n*. The input is consumed lazily, just enough to fill a batch. The result is yielded as soon as the batch is full or when the input @@ -190,16 +193,21 @@ loops that truncate the stream. Roughly equivalent to:: - def batched(iterable, n): + def batched(iterable, n, *, strict=False): # batched('ABCDEFG', 3) --> ABC DEF G if n < 1: raise ValueError('n must be at least one') it = iter(iterable) while batch := tuple(islice(it, n)): + if strict and len(batch) != n: + raise ValueError('batched(): incomplete batch') yield batch .. versionadded:: 3.12 + .. versionchanged:: 3.13 + Added the *strict* option. + .. function:: chain(*iterables) @@ -1039,7 +1047,7 @@ The following recipes have a more mathematical flavor: def reshape(matrix, cols): "Reshape a 2-D matrix to have a given number of columns." # reshape([(0, 1), (2, 3), (4, 5)], 3) --> (0, 1, 2), (3, 4, 5) - return batched(chain.from_iterable(matrix), cols) + return batched(chain.from_iterable(matrix), cols, strict=True) def transpose(matrix): "Swap the rows and columns of a 2-D matrix." @@ -1270,6 +1278,10 @@ The following recipes have a more mathematical flavor: [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)] >>> list(reshape(M, 4)) [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)] + >>> list(reshape(M, 5)) + Traceback (most recent call last): + ... + ValueError: batched(): incomplete batch >>> list(reshape(M, 6)) [(0, 1, 2, 3, 4, 5), (6, 7, 8, 9, 10, 11)] >>> list(reshape(M, 12)) diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 705e880d98685e..9af0730ea98004 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -187,7 +187,11 @@ def test_batched(self): [('A', 'B'), ('C', 'D'), ('E', 'F'), ('G',)]) self.assertEqual(list(batched('ABCDEFG', 1)), [('A',), ('B',), ('C',), ('D',), ('E',), ('F',), ('G',)]) + self.assertEqual(list(batched('ABCDEF', 2, strict=True)), + [('A', 'B'), ('C', 'D'), ('E', 'F')]) + with self.assertRaises(ValueError): # Incomplete batch when strict + list(batched('ABCDEFG', 3, strict=True)) with self.assertRaises(TypeError): # Too few arguments list(batched('ABCDEFG')) with self.assertRaises(TypeError): diff --git a/Misc/NEWS.d/next/Library/2023-12-15-18-10-26.gh-issue-113202.xv_Ww8.rst b/Misc/NEWS.d/next/Library/2023-12-15-18-10-26.gh-issue-113202.xv_Ww8.rst new file mode 100644 index 00000000000000..44f26aef60a33a --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-12-15-18-10-26.gh-issue-113202.xv_Ww8.rst @@ -0,0 +1 @@ +Add a ``strict`` option to ``batched()`` in the ``itertools`` module. diff --git a/Modules/clinic/itertoolsmodule.c.h b/Modules/clinic/itertoolsmodule.c.h index fa2c5e0e922387..3ec479943a83d4 100644 --- a/Modules/clinic/itertoolsmodule.c.h +++ b/Modules/clinic/itertoolsmodule.c.h @@ -10,7 +10,7 @@ preserve #include "pycore_modsupport.h" // _PyArg_UnpackKeywords() PyDoc_STRVAR(batched_new__doc__, -"batched(iterable, n)\n" +"batched(iterable, n, *, strict=False)\n" "--\n" "\n" "Batch data into tuples of length n. The last batch may be shorter than n.\n" @@ -25,10 +25,14 @@ PyDoc_STRVAR(batched_new__doc__, " ...\n" " (\'A\', \'B\', \'C\')\n" " (\'D\', \'E\', \'F\')\n" -" (\'G\',)"); +" (\'G\',)\n" +"\n" +"If \"strict\" is True, raises a ValueError if the final batch is shorter\n" +"than n."); static PyObject * -batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n); +batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n, + int strict); static PyObject * batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) @@ -36,14 +40,14 @@ batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) PyObject *return_value = NULL; #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) - #define NUM_KEYWORDS 2 + #define NUM_KEYWORDS 3 static struct { PyGC_Head _this_is_not_used; PyObject_VAR_HEAD PyObject *ob_item[NUM_KEYWORDS]; } _kwtuple = { .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) - .ob_item = { &_Py_ID(iterable), &_Py_ID(n), }, + .ob_item = { &_Py_ID(iterable), &_Py_ID(n), &_Py_ID(strict), }, }; #undef NUM_KEYWORDS #define KWTUPLE (&_kwtuple.ob_base.ob_base) @@ -52,18 +56,20 @@ batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) # define KWTUPLE NULL #endif // !Py_BUILD_CORE - static const char * const _keywords[] = {"iterable", "n", NULL}; + static const char * const _keywords[] = {"iterable", "n", "strict", NULL}; static _PyArg_Parser _parser = { .keywords = _keywords, .fname = "batched", .kwtuple = KWTUPLE, }; #undef KWTUPLE - PyObject *argsbuf[2]; + PyObject *argsbuf[3]; PyObject * const *fastargs; Py_ssize_t nargs = PyTuple_GET_SIZE(args); + Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 2; PyObject *iterable; Py_ssize_t n; + int strict = 0; fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 2, 2, 0, argsbuf); if (!fastargs) { @@ -82,7 +88,15 @@ batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) } n = ival; } - return_value = batched_new_impl(type, iterable, n); + if (!noptargs) { + goto skip_optional_kwonly; + } + strict = PyObject_IsTrue(fastargs[2]); + if (strict < 0) { + goto exit; + } +skip_optional_kwonly: + return_value = batched_new_impl(type, iterable, n, strict); exit: return return_value; @@ -914,4 +928,4 @@ itertools_count(PyTypeObject *type, PyObject *args, PyObject *kwargs) exit: return return_value; } -/*[clinic end generated code: output=782fe7e30733779b input=a9049054013a1b77]*/ +/*[clinic end generated code: output=c6a515f765da86b5 input=a9049054013a1b77]*/ diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index ab99fa4d873bf5..164741495c7baf 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -105,20 +105,11 @@ class itertools.pairwise "pairwiseobject *" "clinic_state()->pairwise_type" /* batched object ************************************************************/ -/* Note: The built-in zip() function includes a "strict" argument - that was needed because that function would silently truncate data, - and there was no easy way for a user to detect the data loss. - The same reasoning does not apply to batched() which never drops data. - Instead, batched() produces a shorter tuple which can be handled - as the user sees fit. If requested, it would be reasonable to add - "fillvalue" support which had demonstrated value in zip_longest(). - For now, the API is kept simple and clean. - */ - typedef struct { PyObject_HEAD PyObject *it; Py_ssize_t batch_size; + bool strict; } batchedobject; /*[clinic input] @@ -126,6 +117,9 @@ typedef struct { itertools.batched.__new__ as batched_new iterable: object n: Py_ssize_t + * + strict: bool = False + Batch data into tuples of length n. The last batch may be shorter than n. Loops over the input iterable and accumulates data into tuples @@ -140,11 +134,15 @@ or when the input iterable is exhausted. ('D', 'E', 'F') ('G',) +If "strict" is True, raises a ValueError if the final batch is shorter +than n. + [clinic start generated code]*/ static PyObject * -batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n) -/*[clinic end generated code: output=7ebc954d655371b6 input=ffd70726927c5129]*/ +batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n, + int strict) +/*[clinic end generated code: output=c6de11b061529d3e input=7814b47e222f5467]*/ { PyObject *it; batchedobject *bo; @@ -170,6 +168,7 @@ batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n) } bo->batch_size = n; bo->it = it; + bo->strict = (bool) strict; return (PyObject *)bo; } @@ -233,6 +232,12 @@ batched_next(batchedobject *bo) Py_DECREF(result); return NULL; } + if (bo->strict) { + Py_CLEAR(bo->it); + Py_DECREF(result); + PyErr_SetString(PyExc_ValueError, "batched(): incomplete batch"); + return NULL; + } _PyTuple_Resize(&result, i); return result; }