diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 0d6ee2e03bd6..a1afc216d42a 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -484,6 +484,17 @@ def __len__(self): return self.size +class EmptyTensorDataset(torch.utils.data.Dataset): + def __init__(self, len): + self.len = len + + def __len__(self): + return self.len + + def __getitem__(self, any): + return torch.empty(0) + + class SynchronizedSeedDataset(SynchronizedDataset): def __getitem__(self, idx): self.sync_once() @@ -504,6 +515,24 @@ def _test_timeout_pin_memory(persistent_workers): _ = next(iter(dataloader)) +def _test_large_sampler_indices(persistent_workers): + # See + # test_large_sampler_indices + # https://github.com/pytorch/pytorch/issues/48666 + + dataloader = torch.utils.data.DataLoader( + EmptyTensorDataset(10000000), + batch_size=40960, + persistent_workers=persistent_workers, + num_workers=1) + + it = iter(dataloader) + + for x in it: + assert x.numel() == 0 + raise RuntimeError('My Error') + + def disable_stderr(worker_id): r""" Avoids printing "ERROR: Unexpected segmentation fault encountered in worker." @@ -978,6 +1007,24 @@ def test_timeout(self): finally: p.terminate() + def test_large_sampler_indices(self): + # Test that the data loader cleanly exits when the process errors + # 1. having a reference to the iterator + # 2. using a sampler that yields big elements s.t. 
_index_queues putters block + # + # More context: https://github.com/pytorch/pytorch/issues/48666 + + p = ErrorTrackingProcess(target=_test_large_sampler_indices, args=(self.persistent_workers,)) + p.start() + p.join(JOIN_TIMEOUT) + try: + self.assertFalse(p.is_alive()) + self.assertNotEqual(p.exitcode, 0) + self.assertIsInstance(p.exception, RuntimeError) + self.assertRegex(str(p.exception), r'My Error') + finally: + p.terminate() + def test_invalid_ctor_args_combinations(self): # general with self.assertRaisesRegex(ValueError, "num_workers option should be non-negative"): diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 8d7726ebd129..d1025c02cc9b 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -618,46 +618,72 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): # simple things like acquiring an internal lock of a queue may hang. # Therefore, in this case, we actually need to prevent `__del__` from # being executed, and rely on the automatic termination of daemonic - # children. Thus, we register an `atexit` hook that sets a global flag + # children. + # + # Thus, we register an `atexit` hook that sets a global flag # `_utils.python_exit_status`. Since `atexit` hooks are executed in the # reverse order of registration, we are guaranteed that this flag is - # set before library resources we use are freed. (Hooks freeing those - # resources are registered at importing the Python core libraries at - # the top of this file.) So in `__del__`, we check if - # `_utils.python_exit_status` is set or `None` (freed), and perform - # no-op if so. 
+ # set before library resources we use are freed (which, at least in + # CPython, is done via an `atexit` handler defined in + # `multiprocessing/util.py` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362 + # registered when an object requiring this mechanism is first + # created, e.g., `mp.Queue` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103 + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29 + # ) + # + # So in `__del__`, we check if `_utils.python_exit_status` is set or + # `None` (freed), and perform no-op if so. + # + # However, simply letting library clean-up codes run can also be bad, + # because such codes (i.e., `multiprocessing.util._exit_function()`) + # include join putting threads for `mp.Queue`, which can be blocking. + # Hence, the main process putting threads are called with + # `cancel_join_thread` at creation. See later section + # [ 3b. A process won't hang when putting into a queue; ] + # for more details. + # + # Here are two example cases where library clean-up codes can run + # before `__del__` is called: # - # Another problem with `__del__` is also related to the library cleanup - # calls. When a process ends, it shuts the all its daemonic children - # down with a SIGTERM (instead of joining them without a timeout). - # Simiarly for threads, but by a different mechanism. This fact, - # together with a few implementation details of multiprocessing, forces - # us to make workers daemonic. All of our problems arise when a - # DataLoader is used in a subprocess, and are caused by multiprocessing - # code which looks more or less like this: + # 1. 
If we hold onto a reference to the iterator, it more often + # than not tries to do `multiprocessing` library cleaning before + # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666) + # and thus prevents our cleaning-up code from running first. # - # try: - # your_function_using_a_dataloader() - # finally: - # multiprocessing.util._exit_function() + # 2. A similar issue arises when a `DataLoader` is used in a subprocess. + # When a process ends, it shuts all its daemonic children + # down with a SIGTERM (instead of joining them without a timeout). + # Similarly for threads, but by a different mechanism. This fact, + # together with a few implementation details of multiprocessing, forces + # us to make workers daemonic. All of our problems arise when a + # DataLoader is used in a subprocess, and are caused by multiprocessing + # code which looks more or less like this: # - # The joining/termination mentioned above happens inside - # `_exit_function()`. Now, if `your_function_using_a_dataloader()` - # throws, the stack trace stored in the exception will prevent the - # frame which uses `DataLoaderIter` to be freed. If the frame has any - # reference to the `DataLoaderIter` (e.g., in a method of the iter), - # its `__del__`, which starts the shutdown procedure, will not be - # called. That, in turn, means that workers aren't notified. Attempting - # to join in `_exit_function` will then result in a hang. + # try: + # your_function_using_a_dataloader() + # finally: + # multiprocessing.util._exit_function() # - # For context, `_exit_function` is also registered as an `atexit` call. - # So it is unclear to me (@ssnl) why this is needed in a finally block. - # The code dates back to 2008 and there is no comment on the original - # PEP 371 or patch https://bugs.python.org/issue3050 (containing both - # the finally block and the `atexit` registration) that explains this. 
+ # The joining/termination mentioned above happens inside + # `_exit_function()`. Now, if `your_function_using_a_dataloader()` + # throws, the stack trace stored in the exception will prevent the + # frame which uses `DataLoaderIter` to be freed. If the frame has any + # reference to the `DataLoaderIter` (e.g., in a method of the iter), + # its `__del__`, which starts the shutdown procedure, will not be + # called. That, in turn, means that workers aren't notified. Attempting + # to join in `_exit_function` will then result in a hang. # - # Another choice is to just shutdown workers with logic in 1 above - # whenever we see an error in `next`. This isn't ideal because + # For context, `_exit_function` is also registered as an `atexit` call. + # So it is unclear to me (@ssnl) why this is needed in a finally block. + # The code dates back to 2008 and there is no comment on the original + # PEP 371 or patch https://bugs.python.org/issue3050 (containing both + # the finally block and the `atexit` registration) that explains this. + # + # + # Finally, another choice is to just shutdown workers with logic in 1 + # above whenever we see an error in `next`. This isn't ideal because # a. It prevents users from using try-catch to resume data loading. # b. It doesn't prevent hanging if users have references to the # iterator. @@ -705,30 +731,33 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): # We use `mp.Queue` which has a separate background thread to put # objects from an unbounded buffer array. The background thread is # daemonic and usually automatically joined when the process - # exits. + # *exits*. # - # However, in case that the receiver has ended abruptly while - # reading from the pipe, the join will hang forever. 
Therefore, - # for both `worker_result_queue` (worker -> main process/pin_memory_thread) - # and each `index_queue` (main process -> worker), we use - # `q.cancel_join_thread()` in sender process before any `q.put` to - # prevent this automatic join. - # - # Moreover, having all queues called `cancel_join_thread` makes - # implementing graceful shutdown logic in `__del__` much easier. - # It won't need to get from any queue, which would also need to be - # guarded by periodic status checks. + # In case that the receiver has ended abruptly while + # reading from the pipe, the join will hang forever. The usual + # solution for this in Python is calling `q.cancel_join_thread`, + # which prevents automatically joining it when finalizing + # (exiting). # # Nonetheless, `cancel_join_thread` must only be called when the # queue is **not** going to be read from or write into by another # process, because it may hold onto a lock or leave corrupted data # in the queue, leading other readers/writers to hang. # - # `pin_memory_thread`'s `data_queue` is a `queue.Queue` that does - # a blocking `put` if the queue is full. So there is no above - # problem, but we do need to wrap the `put` in a loop that breaks - # not only upon success, but also when the main process stops - # reading, i.e., is shutting down. + # Hence, + # + For worker processes, we only do so (for their output + # queues, i.e., `worker_result_queue`) before exiting. + # + For `pin_memory_thread`, its output queue `data_queue` is a + # `queue.Queue` that does blocking `put` if the queue is full. + # So there is no above problem, but as a result, in + # `_pin_memory_loop`, we do need to wrap the `put` in a loop + # that breaks not only upon success, but also when the main + # process stops reading, i.e., is shutting down. + # + For loader process, we `cancel_join_thread()` for all + # `_index_queues` because the whole purpose of workers and + # `pin_memory_thread` is to serve the loader process. 
If + # loader process is already exiting, we don't really care if + # the queues are corrupted. # # # Now let's get back to 1: @@ -867,7 +896,9 @@ def __init__(self, loader): for i in range(self._num_workers): # No certainty which module multiprocessing_context is index_queue = multiprocessing_context.Queue() # type: ignore - # index_queue.cancel_join_thread() + # Need to `cancel_join_thread` here! + # See sections (2) and (3b) above. + index_queue.cancel_join_thread() w = multiprocessing_context.Process( target=_utils.worker._worker_loop, args=(self._dataset_kind, self._dataset, index_queue, @@ -1234,6 +1265,9 @@ def _shutdown_workers(self): if not self._shutdown: self._shutdown = True try: + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + # Exit `pin_memory_thread` first because exiting workers may leave # corrupted data in `worker_result_queue` which `pin_memory_thread` # reads from.