Skip to content

Commit

Permalink
Merge pull request #310 from jcrist/eq_no_default
Browse files Browse the repository at this point in the history
Use `==` instead of `is` to check for no_default
  • Loading branch information
eriknw committed Jun 2, 2016
2 parents 2fc057d + 23db5b0 commit 6198158
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 11 deletions.
22 changes: 12 additions & 10 deletions toolz/itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def accumulate(binop, seq, initial=no_default):
itertools.accumulate : In standard itertools for Python 3.2+
"""
seq = iter(seq)
result = next(seq) if initial is no_default else initial
result = next(seq) if initial == no_default else initial
yield result
for elem in seq:
result = binop(result, elem)
Expand Down Expand Up @@ -397,7 +397,7 @@ def get(ind, seq, default=no_default):
return seq[ind]
except TypeError: # `ind` may be a list
if isinstance(ind, list):
if default is no_default:
if default == no_default:
if len(ind) > 1:
return operator.itemgetter(*ind)(seq)
elif ind:
Expand All @@ -406,12 +406,12 @@ def get(ind, seq, default=no_default):
return ()
else:
return tuple(_get(i, seq, default) for i in ind)
elif default is not no_default:
elif default != no_default:
return default
else:
raise
except (KeyError, IndexError): # we know `ind` is not a list
if default is no_default:
if default == no_default:
raise
else:
return default
Expand Down Expand Up @@ -555,7 +555,8 @@ def reduceby(key, binop, seq, init=no_default):
{True: set([2, 4]),
False: set([1, 3])}
"""
if init is not no_default and not callable(init):
is_no_default = init == no_default
if not is_no_default and not callable(init):
_init = init
init = lambda: _init
if not callable(key):
Expand All @@ -564,7 +565,7 @@ def reduceby(key, binop, seq, init=no_default):
for item in seq:
k = key(item)
if k not in d:
if init is no_default:
if is_no_default:
d[k] = item
continue
else:
Expand Down Expand Up @@ -721,7 +722,7 @@ def pluck(ind, seqs, default=no_default):
get
map
"""
if default is no_default:
if default == no_default:
get = getter(ind)
return map(get, seqs)
elif isinstance(ind, list):
Expand Down Expand Up @@ -805,6 +806,7 @@ def join(leftkey, leftseq, rightkey, rightseq,
d = groupby(leftkey, leftseq)
seen_keys = set()

left_default_is_no_default = (left_default == no_default)
for item in rightseq:
key = rightkey(item)
seen_keys.add(key)
Expand All @@ -813,10 +815,10 @@ def join(leftkey, leftseq, rightkey, rightseq,
for match in left_matches:
yield (match, item)
except KeyError:
if left_default is not no_default:
if not left_default_is_no_default:
yield (left_default, item)

if right_default is not no_default:
if right_default != no_default:
for key, matches in d.items():
if key not in seen_keys:
for match in matches:
Expand Down Expand Up @@ -847,7 +849,7 @@ def diff(*seqs, **kwargs):
if N < 2:
raise TypeError('Too few sequences given (min 2 required)')
default = kwargs.get('default', no_default)
if default is no_default:
if default == no_default:
iters = zip(*seqs)
else:
iters = zip_longest(*seqs, fillvalue=default)
Expand Down
2 changes: 1 addition & 1 deletion toolz/sandbox/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def fold(binop, seq, default=no_default, map=map, chunksize=128, combine=None):
chunks = partition_all(chunksize, seq)

# Evaluate sequence in chunks via map
if default is no_default:
if default == no_default:
results = map(lambda chunk: reduce(binop, chunk), chunks)
else:
results = map(lambda chunk: reduce(binop, chunk, default), chunks)
Expand Down
6 changes: 6 additions & 0 deletions toolz/sandbox/tests/test_parallel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from toolz.sandbox.parallel import fold
from toolz import reduce
from operator import add
from pickle import dumps, loads

# is comparison will fail between this and no_default
no_default2 = loads(dumps('__no__default__'))


def test_fold():
Expand All @@ -16,3 +20,5 @@ def setadd(s, item):
assert fold(setadd, [1, 2, 3], set()) == set((1, 2, 3))
assert (fold(setadd, [1, 2, 3], set(), chunksize=2, combine=set.union)
== set((1, 2, 3)))

assert fold(add, range(10), default=no_default2) == fold(add, range(10))
17 changes: 17 additions & 0 deletions toolz/tests/test_itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from toolz.utils import raises
from functools import partial
from random import Random
from pickle import dumps, loads
from toolz.itertoolz import (remove, groupby, merge_sorted,
concat, concatv, interleave, unique,
isiterable, getter,
Expand All @@ -17,6 +18,10 @@
from operator import add, mul


# is comparison will fail between this and no_default
no_default2 = loads(dumps('__no__default__'))


def identity(x):
return x

Expand Down Expand Up @@ -181,6 +186,7 @@ def test_get():
assert raises(KeyError, lambda: get(10, {'a': 1}))
assert raises(TypeError, lambda: get({}, [1, 2, 3]))
assert raises(TypeError, lambda: get([1, 2, 3], 1, None))
assert raises(KeyError, lambda: get('foo', {}, default=no_default2))


def test_mapcat():
Expand Down Expand Up @@ -248,6 +254,8 @@ def test_reduceby():

def test_reduce_by_init():
assert reduceby(iseven, add, [1, 2, 3, 4]) == {True: 2 + 4, False: 1 + 3}
assert reduceby(iseven, add, [1, 2, 3, 4], no_default2) == {True: 2 + 4,
False: 1 + 3}


def test_reduce_by_callable_default():
Expand All @@ -274,6 +282,7 @@ def binop(a, b):

start = object()
assert list(accumulate(binop, [], start)) == [start]
assert list(accumulate(add, [1, 2, 3], no_default2)) == [1, 3, 6]


def test_accumulate_works_on_consumable_iterables():
Expand Down Expand Up @@ -328,6 +337,9 @@ def test_pluck():
assert raises(IndexError, lambda: list(pluck(1, [[0]])))
assert raises(KeyError, lambda: list(pluck('name', [{'id': 1}])))

assert list(pluck(0, [[0, 1], [2, 3], [4, 5]], no_default2)) == [0, 2, 4]
assert raises(IndexError, lambda: list(pluck(1, [[0]], no_default2)))


def test_join():
names = [(1, 'one'), (2, 'two'), (3, 'three')]
Expand All @@ -345,6 +357,11 @@ def addpair(pair):

assert result == expected

result = set(starmap(add, join(first, names, second, fruit,
left_default=no_default2,
right_default=no_default2)))
assert result == expected


def test_getter():
assert getter(0)('Alice') == 'A'
Expand Down

0 comments on commit 6198158

Please sign in to comment.