Skip to content

Commit

Permalink
Merge 5435591 into 10c26d8
Browse files Browse the repository at this point in the history
  • Loading branch information
mjstevens777 committed Oct 17, 2018
2 parents 10c26d8 + 5435591 commit 4dd0cd0
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 20 deletions.
57 changes: 37 additions & 20 deletions tqdm/_tqdm.py
Expand Up @@ -32,7 +32,6 @@
'TqdmExperimentalWarning', 'TqdmDeprecationWarning',
'TqdmMonitorWarning']


class TqdmTypeError(TypeError):
pass

Expand Down Expand Up @@ -69,30 +68,23 @@ class TqdmMonitorWarning(TqdmWarning, RuntimeWarning):
pass


# Create global parallelism locks to avoid racing issues with parallel bars
# works only if fork available (Linux, MacOSX, but not on Windows)
try:
mp_lock = mp.RLock() # multiprocessing lock
except ImportError: # pragma: no cover
mp_lock = None
except OSError: # pragma: no cover
mp_lock = None
try:
th_lock = th.RLock() # thread lock
except OSError: # pragma: no cover
th_lock = None


class TqdmDefaultWriteLock(object):
"""
Provide a default write lock for thread and multiprocessing safety.
Works only on platforms supporting `fork` (so Windows is excluded).
You must initialize an instance of tqdm or TqdmDefaultWriteLock before
forking in order for the write lock to work.
On Windows, you need to supply the lock from the parent to the children as
an argument to joblib or the parallelism lib you use.
"""

def __init__(self):
global mp_lock, th_lock
self.locks = [lk for lk in [mp_lock, th_lock] if lk is not None]
# Create global parallelism locks to avoid racing issues with parallel bars
# works only if fork available (Linux, MacOSX, but not on Windows)
self.create_mp_lock()
self.create_th_lock()
cls = type(self)
self.locks = [lock for lock in [cls.mp_lock, cls.th_lock] if lock is not None]

def acquire(self):
for lock in self.locks:
Expand All @@ -108,6 +100,31 @@ def __enter__(self):
def __exit__(self, *exc):
self.release()

@classmethod
def create_mp_lock(cls):
if not hasattr(cls, 'mp_lock'):
try:
cls.mp_lock = mp.RLock() # multiprocessing lock
except ImportError: # pragma: no cover
cls.mp_lock = None
except OSError: # pragma: no cover
cls.mp_lock = None

@classmethod
def create_th_lock(cls):
if not hasattr(cls, 'th_lock'):
try:
cls.th_lock = th.RLock() # thread lock
except OSError: # pragma: no cover
cls.th_lock = None


# Create a thread lock before instantiation so that no setup needs to be done
# before running in a multithreaded environment.
# Do not create the multiprocessing lock because it sets the multiprocessing
# context and does not allow the user to use 'spawn' or 'forkserver' methods.
TqdmDefaultWriteLock.create_th_lock()


class tqdm(Comparable):
"""
Expand All @@ -118,7 +135,7 @@ class tqdm(Comparable):

monitor_interval = 10 # set to 0 to disable the thread
monitor = None
_lock = TqdmDefaultWriteLock()
# _instances and _lock defined in __new__

@staticmethod
def format_sizeof(num, suffix='', divisor=1000):
Expand Down Expand Up @@ -444,9 +461,9 @@ def __new__(cls, *args, **kwargs):
# Create a new instance
instance = object.__new__(cls)
# Add to the list of instances
if "_instances" not in cls.__dict__:
if not hasattr(cls, '_instances'):
cls._instances = WeakSet()
if "_lock" not in cls.__dict__:
if not hasattr(cls, '_lock'):
cls._lock = TqdmDefaultWriteLock()
with cls._lock:
cls._instances.add(instance)
Expand Down
38 changes: 38 additions & 0 deletions tqdm/tests/tests_tqdm.py
Expand Up @@ -710,6 +710,44 @@ def test_smoothed_dynamic_min_iters_with_min_interval():
assert '14%' in out and '14%' in out2


@with_setup(pretest, posttest)
def test_rlock_creation():
"""Test that importing tqdm does not create multiprocessing objects."""
import multiprocessing as mp
if sys.version_info < (3, 3):
# unittest.mock is a 3.3+ feature
raise SkipTest

# Use 'spawn' instead of 'fork' so that the process does not inherit any
# globals that have been constructed by running other tests
ctx = mp.get_context('spawn')
with ctx.Pool(1) as pool:
# The pool will propagate the error if the target method fails
pool.apply(_rlock_creation_target)


def _rlock_creation_target():
"""Check that the RLock has not been constructed."""
from unittest.mock import patch
import multiprocessing as mp

# Patch the RLock class/method but use the original implementation
with patch('multiprocessing.RLock', wraps=mp.RLock) as rlock_mock:
# Importing the module should not create a lock
from tqdm import tqdm
assert rlock_mock.call_count == 0
# Creating a progress bar should initialize the lock
with closing(StringIO()) as our_file:
with tqdm(file=our_file) as t:
pass
assert rlock_mock.call_count == 1
# Creating a progress bar again should reuse the lock
with closing(StringIO()) as our_file:
with tqdm(file=our_file) as t:
pass
assert rlock_mock.call_count == 1


@with_setup(pretest, posttest)
def test_disable():
"""Test disable"""
Expand Down

0 comments on commit 4dd0cd0

Please sign in to comment.