diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py index 76353e94..9e886ab7 100644 --- a/toolz/itertoolz.py +++ b/toolz/itertoolz.py @@ -9,11 +9,12 @@ from toolz.utils import no_default -__all__ = ('remove', 'accumulate', 'groupby', 'merge_sorted', 'interleave', - 'unique', 'isiterable', 'isdistinct', 'take', 'drop', 'take_nth', - 'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv', - 'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate', - 'sliding_window', 'partition', 'partition_all', 'count', 'pluck', +__all__ = ('remove', 'accumulate', 'groupby', 'indices', 'merge_sorted', + 'interleave', 'unique', 'isiterable', 'isdistinct', 'take', + 'drop', 'take_nth', 'first', 'second', 'nth', 'last', 'get', + 'concat', 'concatv', 'mapcat', 'cons', 'interpose', + 'frequencies', 'reduceby', 'iterate', 'sliding_window', + 'partition', 'partition_all', 'count', 'pluck', 'join', 'tail', 'diff', 'topk', 'peek', 'random_sample') @@ -97,6 +98,32 @@ def groupby(key, seq): return rv +def indices(*sizes): + """ Iterates over a length/shape. + + >>> list(indices(3, 2)) + [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)] + + This can help nicely index an array. + + >>> l = [[1, 2], + ... [3, 4], + ... [5, 6]] + + >>> for i, j in indices(3, 2): + ... print("l[%i][%i] = %i" % (i, j, l[i][j])) + l[0][0] = 1 + l[0][1] = 2 + l[1][0] = 3 + l[1][1] = 4 + l[2][0] = 5 + l[2][1] = 6 + + """ + + return itertools.product(*map(range, sizes)) + + def merge_sorted(*seqs, **kwargs): """ Merge and sort a collection of sorted collections diff --git a/toolz/tests/test_itertoolz.py b/toolz/tests/test_itertoolz.py index 93aa856d..1afe27e0 100644 --- a/toolz/tests/test_itertoolz.py +++ b/toolz/tests/test_itertoolz.py @@ -4,7 +4,7 @@ from functools import partial from random import Random from pickle import dumps, loads -from toolz.itertoolz import (remove, groupby, merge_sorted, +from toolz.itertoolz import (remove, groupby, indices, merge_sorted, concat, concatv, interleave, unique, isiterable, getter, mapcat, isdistinct, first, second, @@ -52,6 +52,31 @@ def test_groupby(): assert groupby(iseven, [1, 2, 3, 4]) == {True: [2, 4], False: [1, 3]} +def test_indices(): + assert list(indices(0)) == [] + assert list(indices(0, 5)) == [] + assert list(indices(5, 0)) == [] + + assert list(indices(5)) == [(0,), + (1,), + (2,), + (3,), + (4,)] + + assert list(indices(1, 5)) == [(0, 0,), + (0, 1,), + (0, 2,), + (0, 3,), + (0, 4,)] + + assert list(indices(3, 2)) == [(0, 0), + (0, 1), + (1, 0), + (1, 1), + (2, 0), + (2, 1)] + + def test_groupby_non_callable(): assert groupby(0, [(1, 2), (1, 3), (2, 2), (2, 4)]) == \ {1: [(1, 2), (1, 3)],