From 4997c59814adda88d1ba493e6ab0da7f6037689e Mon Sep 17 00:00:00 2001 From: Sam Van Kooten Date: Fri, 12 May 2023 11:28:05 -0600 Subject: [PATCH 1/2] Fix length detection in contrib.concurrent --- tests/tests_concurrent.py | 4 +++- tqdm/contrib/concurrent.py | 9 +++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/tests_concurrent.py b/tests/tests_concurrent.py index 5cd439c94..f00410908 100644 --- a/tests/tests_concurrent.py +++ b/tests/tests_concurrent.py @@ -38,7 +38,9 @@ def test_process_map(): @mark.parametrize("iterables,should_warn", [([], False), (['x'], False), ([()], False), (['x', ()], False), (['x' * 1001], True), - (['x' * 100, ('x',) * 1001], True)]) + (['x' * 100, ('x',) * 1001], False), + (['x' * 1001, ('x',) * 100], False), + (['x' * 1001, ('x',) * 1001], True)]) def test_chunksize_warning(iterables, should_warn): """Test contrib.concurrent.process_map chunksize warnings""" patch = importorskip('unittest.mock').patch diff --git a/tqdm/contrib/concurrent.py b/tqdm/contrib/concurrent.py index cd81d622a..0deb2db24 100644 --- a/tqdm/contrib/concurrent.py +++ b/tqdm/contrib/concurrent.py @@ -39,7 +39,8 @@ def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs): """ kwargs = tqdm_kwargs.copy() if "total" not in kwargs: - kwargs["total"] = length_hint(iterables[0]) + shortest_iterable_len = min(map(length_hint, iterables)) + kwargs["total"] = shortest_iterable_len tqdm_class = kwargs.pop("tqdm_class", tqdm_auto) max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4)) chunksize = kwargs.pop("chunksize", 1) @@ -92,12 +93,12 @@ def process_map(fn, *iterables, **tqdm_kwargs): if iterables and "chunksize" not in tqdm_kwargs: # default `chunksize=1` has poor performance for large iterables # (most time spent dispatching items to workers). - longest_iterable_len = max(map(length_hint, iterables)) - if longest_iterable_len > 1000: + shortest_iterable_len = min(map(length_hint, iterables)) + if shortest_iterable_len > 1000: from warnings import warn warn("Iterable length %d > 1000 but `chunksize` is not set." " This may seriously degrade multiprocess performance." - " Set `chunksize=1` or more." % longest_iterable_len, + " Set `chunksize=1` or more." % shortest_iterable_len, TqdmWarning, stacklevel=2) if "lock_name" not in tqdm_kwargs: tqdm_kwargs = tqdm_kwargs.copy() From dd37e6c9a4b681d56ec58ac06b960f94d8cd82da Mon Sep 17 00:00:00 2001 From: Sam Van Kooten Date: Fri, 12 May 2023 12:00:09 -0600 Subject: [PATCH 2/2] Handle no-length iterables in contrib.concurrent --- tqdm/contrib/concurrent.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tqdm/contrib/concurrent.py b/tqdm/contrib/concurrent.py index 0deb2db24..5c467394e 100644 --- a/tqdm/contrib/concurrent.py +++ b/tqdm/contrib/concurrent.py @@ -26,6 +26,16 @@ def ensure_lock(tqdm_class, lock_name=""): tqdm_class.set_lock(old_lock) +def _shortest_iterable_length(iterables): + # Negative default for iterables that have no length (e.g. itertools.repeat) + iterable_lengths = [length_hint(iterable, -1) for iterable in iterables] + # Remove the negative values + iterable_lengths = filter(lambda length: length >= 0, iterable_lengths) + # Take the shortest of the finite iterables + shortest_iterable_len = min(iterable_lengths) + return shortest_iterable_len + + def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs): """ Implementation of `thread_map` and `process_map`. @@ -39,8 +49,7 @@ def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs): """ kwargs = tqdm_kwargs.copy() if "total" not in kwargs: - shortest_iterable_len = min(map(length_hint, iterables)) - kwargs["total"] = shortest_iterable_len + kwargs["total"] = _shortest_iterable_length(iterables) tqdm_class = kwargs.pop("tqdm_class", tqdm_auto) max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4)) chunksize = kwargs.pop("chunksize", 1) @@ -93,7 +102,7 @@ def process_map(fn, *iterables, **tqdm_kwargs): if iterables and "chunksize" not in tqdm_kwargs: # default `chunksize=1` has poor performance for large iterables # (most time spent dispatching items to workers). - shortest_iterable_len = min(map(length_hint, iterables)) + shortest_iterable_len = _shortest_iterable_length(iterables) if shortest_iterable_len > 1000: from warnings import warn warn("Iterable length %d > 1000 but `chunksize` is not set."