Skip to content

Commit

Permalink
Fix dataloader hang with large sampler (pytorch#48669)
Browse files Browse the repository at this point in the history
Summary:
Fixes pytorch#48666

Pull Request resolved: pytorch#48669

Reviewed By: zhangguanheng66

Differential Revision: D25255763

Pulled By: VitalyFedyunin

fbshipit-source-id: d06421f52bb1d00cdf8025f1a2ba0d1f9284731a
  • Loading branch information
ssnl authored and shaibagon committed Dec 3, 2020
1 parent 4d3fa6c commit d5abf18
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 51 deletions.
47 changes: 47 additions & 0 deletions test/test_dataloader.py
Expand Up @@ -484,6 +484,17 @@ def __len__(self):
return self.size


class EmptyTensorDataset(torch.utils.data.Dataset):
def __init__(self, len):
self.len = len

def __len__(self):
return self.len

def __getitem__(self, any):
return torch.empty(0)


class SynchronizedSeedDataset(SynchronizedDataset):
def __getitem__(self, idx):
self.sync_once()
Expand All @@ -504,6 +515,24 @@ def _test_timeout_pin_memory(persistent_workers):
_ = next(iter(dataloader))


def _test_large_sampler_indices(persistent_workers):
# See
# test_large_sampler_indices
# https://github.com/pytorch/pytorch/issues/48666

dataloader = torch.utils.data.DataLoader(
EmptyTensorDataset(10000000),
batch_size=40960,
persistent_workers=persistent_workers,
num_workers=1)

it = iter(dataloader)

for x in it:
assert x.numel() == 0
raise RuntimeError('My Error')


def disable_stderr(worker_id):
r"""
Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
Expand Down Expand Up @@ -978,6 +1007,24 @@ def test_timeout(self):
finally:
p.terminate()

def test_large_sampler_indices(self):
# Test that the data loader cleanly exit when the process errors
# 1. having an reference to the iterator
# 2. using a sampler that yields big elements s.t. _index_queues putters block
#
# More context: https://github.com/pytorch/pytorch/issues/48666

p = ErrorTrackingProcess(target=_test_large_sampler_indices, args=(self.persistent_workers,))
p.start()
p.join(JOIN_TIMEOUT)
try:
self.assertFalse(p.is_alive())
self.assertNotEqual(p.exitcode, 0)
self.assertIsInstance(p.exception, RuntimeError)
self.assertRegex(str(p.exception), r'My Error')
finally:
p.terminate()

def test_invalid_ctor_args_combinations(self):
# general
with self.assertRaisesRegex(ValueError, "num_workers option should be non-negative"):
Expand Down
136 changes: 85 additions & 51 deletions torch/utils/data/dataloader.py
Expand Up @@ -618,46 +618,72 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
# simple things like acquiring an internal lock of a queue may hang.
# Therefore, in this case, we actually need to prevent `__del__` from
# being executed, and rely on the automatic termination of daemonic
# children. Thus, we register an `atexit` hook that sets a global flag
# children.
#
# Thus, we register an `atexit` hook that sets a global flag
# `_utils.python_exit_status`. Since `atexit` hooks are executed in the
# reverse order of registration, we are guaranteed that this flag is
# set before library resources we use are freed. (Hooks freeing those
# resources are registered at importing the Python core libraries at
# the top of this file.) So in `__del__`, we check if
# `_utils.python_exit_status` is set or `None` (freed), and perform
# no-op if so.
# set before library resources we use are freed (which, at least in
# CPython, is done via an `atexit` handler defined in
# `multiprocessing/util.py`
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
# registered when an object requiring this mechanism is first
# created, e.g., `mp.Queue`
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
# )
#
# So in `__del__`, we check if `_utils.python_exit_status` is set or
# `None` (freed), and perform no-op if so.
#
# However, simply letting library clean-up codes run can also be bad,
# because such codes (i.e., `multiprocessing.util._exit_function()`)
# include join putting threads for `mp.Queue`, which can be blocking.
# Hence, the main process putting threads are called with
# `cancel_join_thread` at creation. See later section
# [ 3b. A process won't hang when putting into a queue; ]
# for more details.
#
# Here are two example cases where library clean-up codes can run
# before `__del__` is called:
#
# Another problem with `__del__` is also related to the library cleanup
# calls. When a process ends, it shuts the all its daemonic children
# down with a SIGTERM (instead of joining them without a timeout).
# Simiarly for threads, but by a different mechanism. This fact,
# together with a few implementation details of multiprocessing, forces
# us to make workers daemonic. All of our problems arise when a
# DataLoader is used in a subprocess, and are caused by multiprocessing
# code which looks more or less like this:
# 1. If we hold onto a reference to the iterator, it more often
# than not tries to do `multiprocessing` library cleaning before
# clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
# and thus prevents our cleaning-up code to run first.
#
# try:
# your_function_using_a_dataloader()
# finally:
# multiprocessing.util._exit_function()
# 2. A similar issue araises when a `DataLoader` is used in a subprocess.
# When a process ends, it shuts the all its daemonic children
# down with a SIGTERM (instead of joining them without a timeout).
# Simiarly for threads, but by a different mechanism. This fact,
# together with a few implementation details of multiprocessing, forces
# us to make workers daemonic. All of our problems arise when a
# DataLoader is used in a subprocess, and are caused by multiprocessing
# code which looks more or less like this:
#
# The joining/termination mentioned above happens inside
# `_exit_function()`. Now, if `your_function_using_a_dataloader()`
# throws, the stack trace stored in the exception will prevent the
# frame which uses `DataLoaderIter` to be freed. If the frame has any
# reference to the `DataLoaderIter` (e.g., in a method of the iter),
# its `__del__`, which starts the shutdown procedure, will not be
# called. That, in turn, means that workers aren't notified. Attempting
# to join in `_exit_function` will then result in a hang.
# try:
# your_function_using_a_dataloader()
# finally:
# multiprocessing.util._exit_function()
#
# For context, `_exit_function` is also registered as an `atexit` call.
# So it is unclear to me (@ssnl) why this is needed in a finally block.
# The code dates back to 2008 and there is no comment on the original
# PEP 371 or patch https://bugs.python.org/issue3050 (containing both
# the finally block and the `atexit` registration) that explains this.
# The joining/termination mentioned above happens inside
# `_exit_function()`. Now, if `your_function_using_a_dataloader()`
# throws, the stack trace stored in the exception will prevent the
# frame which uses `DataLoaderIter` to be freed. If the frame has any
# reference to the `DataLoaderIter` (e.g., in a method of the iter),
# its `__del__`, which starts the shutdown procedure, will not be
# called. That, in turn, means that workers aren't notified. Attempting
# to join in `_exit_function` will then result in a hang.
#
# Another choice is to just shutdown workers with logic in 1 above
# whenever we see an error in `next`. This isn't ideal because
# For context, `_exit_function` is also registered as an `atexit` call.
# So it is unclear to me (@ssnl) why this is needed in a finally block.
# The code dates back to 2008 and there is no comment on the original
# PEP 371 or patch https://bugs.python.org/issue3050 (containing both
# the finally block and the `atexit` registration) that explains this.
#
#
# Finally, another choice is to just shutdown workers with logic in 1
# above whenever we see an error in `next`. This isn't ideal because
# a. It prevents users from using try-catch to resume data loading.
# b. It doesn't prevent hanging if users have references to the
# iterator.
Expand Down Expand Up @@ -705,30 +731,33 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
# We use `mp.Queue` which has a separate background thread to put
# objects from an unbounded buffer array. The background thread is
# daemonic and usually automatically joined when the process
# exits.
# *exits*.
#
# However, in case that the receiver has ended abruptly while
# reading from the pipe, the join will hang forever. Therefore,
# for both `worker_result_queue` (worker -> main process/pin_memory_thread)
# and each `index_queue` (main process -> worker), we use
# `q.cancel_join_thread()` in sender process before any `q.put` to
# prevent this automatic join.
#
# Moreover, having all queues called `cancel_join_thread` makes
# implementing graceful shutdown logic in `__del__` much easier.
# It won't need to get from any queue, which would also need to be
# guarded by periodic status checks.
# In case that the receiver has ended abruptly while
# reading from the pipe, the join will hang forever. The usual
# solution for this in Python is calling `q.cancel_join_thread`,
# which prevents automatically joining it when finalizing
# (exiting).
#
# Nonetheless, `cancel_join_thread` must only be called when the
# queue is **not** going to be read from or write into by another
# process, because it may hold onto a lock or leave corrupted data
# in the queue, leading other readers/writers to hang.
#
# `pin_memory_thread`'s `data_queue` is a `queue.Queue` that does
# a blocking `put` if the queue is full. So there is no above
# problem, but we do need to wrap the `put` in a loop that breaks
# not only upon success, but also when the main process stops
# reading, i.e., is shutting down.
# Hence,
# + For worker processes, we only do so (for their output
# queues, i.e., `worker_result_queue`) before exiting.
# + For `pin_memory_thread`, its output queue `data_queue` is a
# `queue.Queue` that does blocking `put` if the queue is full.
# So there is no above problem, but as a result, in
# `_pin_memory_loop`, we do need to wrap the `put` in a loop
# that breaks not only upon success, but also when the main
# process stops reading, i.e., is shutting down.
# + For loader process, we `cancel_join_thread()` for all
# `_index_queues` because the whole purpose of workers and
# `pin_memory_thread` is to serve the loader process. If
# loader process is already exiting, we don't really care if
# the queues are corrupted.
#
#
# Now let's get back to 1:
Expand Down Expand Up @@ -867,7 +896,9 @@ def __init__(self, loader):
for i in range(self._num_workers):
# No certainty which module multiprocessing_context is
index_queue = multiprocessing_context.Queue() # type: ignore
# index_queue.cancel_join_thread()
# Need to `cancel_join_thread` here!
# See sections (2) and (3b) above.
index_queue.cancel_join_thread()
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop,
args=(self._dataset_kind, self._dataset, index_queue,
Expand Down Expand Up @@ -1234,6 +1265,9 @@ def _shutdown_workers(self):
if not self._shutdown:
self._shutdown = True
try:
# Normal exit when last reference is gone / iterator is depleted.
# See (1) and the second half of the note.

# Exit `pin_memory_thread` first because exiting workers may leave
# corrupted data in `worker_result_queue` which `pin_memory_thread`
# reads from.
Expand Down

0 comments on commit d5abf18

Please sign in to comment.