From afcd33b8feb1c808612a8d97c5a5bc7626d67789 Mon Sep 17 00:00:00 2001 From: Tongzhou Wang Date: Thu, 1 Aug 2019 02:04:21 -0400 Subject: [PATCH] fix pin_memory_thread not exiting quickly --- torch/utils/data/_utils/pin_memory.py | 2 +- torch/utils/data/dataloader.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py index cebaacc844219..b73306a32ae13 100644 --- a/torch/utils/data/_utils/pin_memory.py +++ b/torch/utils/data/_utils/pin_memory.py @@ -22,7 +22,7 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event): except queue.Empty: continue idx, data = r - if not isinstance(data, ExceptionWrapper): + if not done_event.is_set() and not isinstance(data, ExceptionWrapper): try: data = pin_memory(data) except Exception: diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 33621bd9b5de2..c37378679eb5b 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -886,9 +886,13 @@ def _shutdown_workers(self): # corrupted data in `worker_result_queue` which `pin_memory_thread` # reads from. if hasattr(self, 'pin_memory_thread'): - self.pin_memory_thread_done_event.set() # Use hasattr in case error happens before we set the attribute. + self.pin_memory_thread_done_event.set() + # Send something to pin_memory_thread in case it is waiting + # so that it can wake up and check `pin_memory_thread_done_event` + self.worker_result_queue.put((None, None)) self.pin_memory_thread.join() + self.worker_result_queue.close() # Exit workers now. self.workers_done_event.set()