diff --git a/tests/tests_itertools.py b/tests/tests_itertools.py index bfb6eb2cd..0c030e751 100644 --- a/tests/tests_itertools.py +++ b/tests/tests_itertools.py @@ -24,3 +24,5 @@ def test_product(): assert list(product(a, a[::-1], file=our_file)) == list(it.product(a, a[::-1])) assert list(product(a, NoLenIter(a), file=our_file)) == list(it.product(a, NoLenIter(a))) + + assert list(product(a, repeat=2, file=our_file)) == list(it.product(a, repeat=2)) diff --git a/tqdm/contrib/itertools.py b/tqdm/contrib/itertools.py index e67651a41..2d3a7f194 100644 --- a/tqdm/contrib/itertools.py +++ b/tqdm/contrib/itertools.py @@ -9,7 +9,7 @@ __all__ = ['product'] -def product(*iterables, **tqdm_kwargs): +def product(*iterables, repeat=1, **tqdm_kwargs): """ Equivalent of `itertools.product`. @@ -29,7 +29,7 @@ def product(*iterables, **tqdm_kwargs): total *= i kwargs.setdefault("total", total) with tqdm_class(**kwargs) as t: - it = itertools.product(*iterables) + it = itertools.product(*iterables, repeat=repeat) for i in it: yield i t.update()