From 7b6570a24c7e0213a91676e4e4139d5f738ff5f8 Mon Sep 17 00:00:00 2001 From: Mital Ashok Date: Thu, 31 Aug 2017 23:24:57 +0100 Subject: [PATCH] Added peekn --- doc/source/api.rst | 1 + toolz/curried/__init__.py | 1 + toolz/itertoolz.py | 22 ++++++++++++++++++++-- toolz/tests/test_itertoolz.py | 15 +++++++++++++-- 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/doc/source/api.rst b/doc/source/api.rst index f53ff19d..86ac4c0b 100644 --- a/doc/source/api.rst +++ b/doc/source/api.rst @@ -34,6 +34,7 @@ Itertoolz partition partition_all peek + peekn pluck random_sample reduceby diff --git a/toolz/curried/__init__.py b/toolz/curried/__init__.py index 43aeffd4..5a361952 100644 --- a/toolz/curried/__init__.py +++ b/toolz/curried/__init__.py @@ -80,6 +80,7 @@ partition = toolz.curry(toolz.partition) partition_all = toolz.curry(toolz.partition_all) partitionby = toolz.curry(toolz.partitionby) +peekn = toolz.curry(toolz.peekn) pluck = toolz.curry(toolz.pluck) random_sample = toolz.curry(toolz.random_sample) reduce = toolz.curry(toolz.reduce) diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py index a25eea3c..54fe0c74 100644 --- a/toolz/itertoolz.py +++ b/toolz/itertoolz.py @@ -14,7 +14,7 @@ 'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv', 'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate', 'sliding_window', 'partition', 'partition_all', 'count', 'pluck', - 'join', 'tail', 'diff', 'topk', 'peek', 'random_sample') + 'join', 'tail', 'diff', 'topk', 'peek', 'peekn', 'random_sample') def remove(predicate, seq): @@ -942,7 +942,25 @@ def peek(seq): """ iterator = iter(seq) item = next(iterator) - return item, itertools.chain([item], iterator) + return item, itertools.chain((item,), iterator) + + +def peekn(n, seq): + """ Retrieve the next n elements of a sequence + + Returns a tuple of the first n elements and an iterable equivalent + to the original, still having the elements retrieved. + + >>> seq = [0, 1, 2, 3, 4] + >>> first_two, seq = peekn(2, seq) + >>> first_two + (0, 1) + >>> list(seq) + [0, 1, 2, 3, 4] + """ + iterator = iter(seq) + peeked = tuple(take(n, iterator)) + return peeked, itertools.chain(iter(peeked), iterator) def random_sample(prob, seq, random_state=None): diff --git a/toolz/tests/test_itertoolz.py b/toolz/tests/test_itertoolz.py index 93aa856d..bf06ab6f 100644 --- a/toolz/tests/test_itertoolz.py +++ b/toolz/tests/test_itertoolz.py @@ -13,7 +13,7 @@ reduceby, iterate, accumulate, sliding_window, count, partition, partition_all, take_nth, pluck, join, - diff, topk, peek, random_sample) + diff, topk, peek, peekn, random_sample) from toolz.compatibility import range, filter from operator import add, mul @@ -496,12 +496,23 @@ def test_topk_is_stable(): def test_peek(): alist = ["Alice", "Bob", "Carol"] element, blist = peek(alist) - element == alist[0] + assert element == alist[0] assert list(blist) == alist assert raises(StopIteration, lambda: peek([])) +def test_peekn(): + alist = ("Alice", "Bob", "Carol") + elements, blist = peekn(2, alist) + assert elements == alist[:2] + assert tuple(blist) == alist + + elements, blist = peekn(len(alist) * 4, alist) + assert elements == alist + assert tuple(blist) == alist + + def test_random_sample(): alist = list(range(100))