From 520a1218ae49e907bd11fe98f0cb162972eb229d Mon Sep 17 00:00:00 2001 From: lucas sproule Date: Mon, 22 Apr 2024 16:44:12 -0400 Subject: [PATCH] added implementation for nwise to itertools --- .gitignore | 3 +- Modules/clinic/itertoolsmodule.c.h | 36 +++++++ Modules/itertoolsmodule.c | 156 +++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 8872e9d5508ff1..6ec871ac559d40 100644 --- a/.gitignore +++ b/.gitignore @@ -46,7 +46,8 @@ gmon.out .DS_Store *.exe - +.cache +compile_commands.json # Ignore core dumps... but not Tools/msi/core/ or the like. core !core/ diff --git a/Modules/clinic/itertoolsmodule.c.h b/Modules/clinic/itertoolsmodule.c.h index 3ec479943a83d4..0078b485894c4c 100644 --- a/Modules/clinic/itertoolsmodule.c.h +++ b/Modules/clinic/itertoolsmodule.c.h @@ -134,6 +134,42 @@ pairwise_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) return return_value; } + +PyDoc_STRVAR(nwise_new__doc__, +"nwise(iterable, /)\n" +"--\n" +"\n" +"Return an iterator of overlapping pairs taken from the input iterator.\n" +"\n" +" s -> (s0,s1), (s1,s2), (s2, s3), ..."); + +static PyObject * +nwise_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n); + +static PyObject * +nwise_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) +{ + PyObject *return_value = NULL; + PyTypeObject *base_tp = clinic_state()->nwise_type; + PyObject *iterable; + Py_ssize_t n; + + if ((type == base_tp || type->tp_init == base_tp->tp_init) && + !_PyArg_NoKeywords("nwise", kwargs)) { + goto exit; + } + if (!_PyArg_CheckPositional("nwise", PyTuple_GET_SIZE(args), 2, 2)) { + goto exit; + } + iterable = PyTuple_GET_ITEM(args, 0); + n = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, 1)); + return_value = nwise_new_impl(type, iterable, n); + +exit: + return return_value; +} + + PyDoc_STRVAR(itertools_groupby__doc__, "groupby(iterable, key=None)\n" "--\n" diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 6ee447ef6a8cd6..f6d7192a2024dc 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -28,6 +28,7 @@ typedef struct { PyTypeObject *_grouper_type; PyTypeObject *islice_type; PyTypeObject *pairwise_type; + PyTypeObject *nwise_type; PyTypeObject *permutations_type; PyTypeObject *product_type; PyTypeObject *repeat_type; @@ -84,6 +85,7 @@ class itertools.compress "compressobject *" "clinic_state()->compress_type" class itertools.filterfalse "filterfalseobject *" "clinic_state()->filterfalse_type" class itertools.count "countobject *" "clinic_state()->count_type" class itertools.pairwise "pairwiseobject *" "clinic_state()->pairwise_type" +class itertools.nwise "nwiseobject *" "clinic_state()->nwise_type" [clinic start generated code]*/ /*[clinic end generated code: output=da39a3ee5e6b4b0d input=aa48fe4de9d4080f]*/ @@ -384,6 +386,157 @@ static PyType_Spec pairwise_spec = { }; +/* nwise object ***********************************************************/ + +typedef struct { + PyObject_HEAD PyObject *it; + PyObject*list; + Py_ssize_t n; + Py_ssize_t current_index; +} nwiseobject; + +/*[clinic input] +@classmethod +itertools.nwise.__new__ as nwise_new + iterable: object + n: integer + / +Return an iterator of overlapping n values taken from the input iterator. + + s -> (s0,s1, s2), (s1,s2, s3), (s2, s3, s4), ... + +[clinic start generated code]*/ + +static PyObject *nwise_new_impl(PyTypeObject *type, PyObject *iterable, + Py_ssize_t n) { + PyObject *it; + nwiseobject *no; + it = PyObject_GetIter(iterable); + + if (it == NULL) { + return NULL; + } + no = (nwiseobject *)type->tp_alloc(type, 0); + if (no == NULL) { + Py_DECREF(it); + return NULL; + } + no->it = it; + no->n = n; + no->list = PyList_New(0); + if (no->list== NULL) { + // Failed to create a new list + return NULL; + } + no->current_index = 0; + + return (PyObject *)no; +} + +static void nwise_dealloc(nwiseobject *no) { + PyTypeObject *tp = Py_TYPE(no); + PyObject_GC_UnTrack(no); + Py_XDECREF(no->it); + + tp->tp_free(no); + Py_DECREF(tp); +} + +static int nwise_traverse(nwiseobject *no, visitproc visit, void *arg) { + Py_VISIT(Py_TYPE(no)); + Py_VISIT(no->it); + + return 0; +} + +PyObject *nwise_next(nwiseobject *no) { + PyObject *it = no->it; + PyObject *item; + + if (it == NULL) { + return NULL; + } + + + + // Create a list to store items + + + iternextfunc iternext = *Py_TYPE(it)->tp_iternext; + if (PyList_Size(no->list) < no->n) { + for (Py_ssize_t i = 0; i < no->n; i++) { + item = iternext(it); + if (item == NULL) { + // Iterator is exhausted prematurely + Py_CLEAR(no->it); + Py_DECREF(no->list); + return NULL; + } + PyList_Append(no->list, item); // Set items in the list + } + PyObject *tuple_result = PyList_AsTuple(no->list); + return tuple_result; + } else { + // If the list is already full, remove the oldest item and add the next one + PyObject *first_item = PyList_GetItem(no->list, 0); + Py_DECREF(first_item); + + // Shift the remaining items to the left + for (Py_ssize_t i = 1; i < no->n; i++) { + PyObject *current_item = PyList_GetItem(no->list, i); + PyList_SET_ITEM(no->list, i - 1, current_item); + } + + // Retrieve the next item from the iterator and add it to the end of the list + item = iternext(it); + if (PyErr_Occurred()) { + if (!PyErr_ExceptionMatches(PyExc_StopIteration)) { + /* Input raised an exception other than StopIteration */ + Py_CLEAR(no->it); + Py_DECREF(no->list); + return NULL; + } + PyErr_Clear(); + } + + + if (item == NULL) { + // 'iternext' returned NULL unexpectedly + Py_CLEAR(no->it); + Py_DECREF(no->list); + return NULL; + } + PyList_SET_ITEM(no->list, no->n - 1, item); + + } + + // Convert the list to a tuple and return + PyObject *tuple_result = PyList_AsTuple(no->list); + return tuple_result; +} + + +static PyType_Slot nwise_slots[] = { + {Py_tp_dealloc, nwise_dealloc}, + {Py_tp_getattro, PyObject_GenericGetAttr}, + {Py_tp_doc, (void *)nwise_new__doc__}, + {Py_tp_traverse, nwise_traverse}, + {Py_tp_iter, PyObject_SelfIter}, + {Py_tp_iternext, nwise_next}, + {Py_tp_alloc, PyType_GenericAlloc}, + {Py_tp_new, nwise_new}, + {Py_tp_free, PyObject_GC_Del}, + {0, NULL}, +}; + +static PyType_Spec nwise_spec = { + .name = "itertools.nwise", + .basicsize = sizeof(nwiseobject), + .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE | + Py_TPFLAGS_IMMUTABLETYPE), + .slots = nwise_slots, +}; + /* groupby object ************************************************************/ typedef struct { @@ -4660,6 +4813,7 @@ itertoolsmodule_traverse(PyObject *mod, visitproc visit, void *arg) Py_VISIT(state->_grouper_type); Py_VISIT(state->islice_type); Py_VISIT(state->pairwise_type); + Py_VISIT(state->nwise_type); Py_VISIT(state->permutations_type); Py_VISIT(state->product_type); Py_VISIT(state->repeat_type); @@ -4689,6 +4843,7 @@ itertoolsmodule_clear(PyObject *mod) Py_CLEAR(state->_grouper_type); Py_CLEAR(state->islice_type); Py_CLEAR(state->pairwise_type); + Py_CLEAR(state->nwise_type); Py_CLEAR(state->permutations_type); Py_CLEAR(state->product_type); Py_CLEAR(state->repeat_type); @@ -4735,6 +4890,7 @@ itertoolsmodule_exec(PyObject *mod) ADD_TYPE(mod, state->_grouper_type, &_grouper_spec); ADD_TYPE(mod, state->islice_type, &islice_spec); ADD_TYPE(mod, state->pairwise_type, &pairwise_spec); + ADD_TYPE(mod, state->nwise_type, &nwise_spec); ADD_TYPE(mod, state->permutations_type, &permutations_spec); ADD_TYPE(mod, state->product_type, &product_spec); ADD_TYPE(mod, state->repeat_type, &repeat_spec);