Skip to content

Commit

Permalink
Merge pull request #253 from llllllllll/scanl
Browse files Browse the repository at this point in the history
ENH: Adds `start` argument to accumulate
  • Loading branch information
mrocklin committed Aug 10, 2015
2 parents 78c7b6d + bfc45c5 commit 0d3639b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
19 changes: 12 additions & 7 deletions toolz/itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import collections
import operator
from functools import partial
from toolz.compatibility import (map, filter, filterfalse, zip, zip_longest,
iteritems)
from toolz.compatibility import (map, filterfalse, zip, zip_longest, iteritems)
from toolz.utils import no_default


__all__ = ('remove', 'accumulate', 'groupby', 'merge_sorted', 'interleave',
Expand All @@ -26,7 +26,7 @@ def remove(predicate, seq):
return filterfalse(predicate, seq)


def accumulate(binop, seq):
def accumulate(binop, seq, initial=no_default):
""" Repeatedly apply binary function to a sequence, accumulating results
>>> from operator import add, mul
Expand All @@ -42,11 +42,19 @@ def accumulate(binop, seq):
>>> sum = partial(reduce, add)
>>> cumsum = partial(accumulate, add)
Accumulate also takes an optional argument that will be used as the first
value. This is similar to reduce.
>>> list(accumulate(add, [1, 2, 3], -1))
[-1, 0, 2, 5]
>>> list(accumulate(add, [], 1))
[1]
See Also:
itertools.accumulate : In standard itertools for Python 3.2+
"""
seq = iter(seq)
result = next(seq)
result = next(seq) if initial is no_default else initial
yield result
for elem in seq:
result = binop(result, elem)
Expand Down Expand Up @@ -343,9 +351,6 @@ def last(seq):
rest = partial(drop, 1)


no_default = '__no__default__'


def _get(ind, seq, default):
try:
return seq[ind]
Expand Down
7 changes: 7 additions & 0 deletions toolz/tests/test_itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,13 @@ def test_iterate():
def test_accumulate():
assert list(accumulate(add, [1, 2, 3, 4, 5])) == [1, 3, 6, 10, 15]
assert list(accumulate(mul, [1, 2, 3, 4, 5])) == [1, 2, 6, 24, 120]
assert list(accumulate(add, [1, 2, 3, 4, 5], -1)) == [-1, 0, 2, 5, 9, 14]

def binop(a, b):
raise AssertionError('binop should not be called')

start = object()
assert list(accumulate(binop, [], start)) == [start]


def test_accumulate_works_on_consumable_iterables():
Expand Down

0 comments on commit 0d3639b

Please sign in to comment.