From 3b76e151964fce442e27fe8fb5c37af930da4fa1 Mon Sep 17 00:00:00 2001
From: Jeeja
Date: Wed, 20 Apr 2022 18:27:48 -0700
Subject: [PATCH] Update Dataloader with default parameter device (#65402)

Summary:
`pin_memory` has an optional `device` parameter that specifies which device the
memory should be pinned for. Without this change, the Dataloader supports pinned
memory only for the CUDA backend. To add support for other backends that provide
pinned memory, the Dataloader is updated to take the device as an optional
parameter.

Fixes #{issue number}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65402

Reviewed By: zou3519

Differential Revision: D32282204

Pulled By: VitalyFedyunin

fbshipit-source-id: e2e09876969af108d0db38af7c2d1b2f1cfa9858
---
 test/test_dataloader.py               | 13 +++++++++++++
 torch/utils/data/_utils/pin_memory.py | 20 +++++++++----------
 torch/utils/data/dataloader.py        | 28 ++++++++++++++++++++++-----
 3 files changed, 46 insertions(+), 15 deletions(-)

diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 4900cd31516a..e57737dda155 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -2349,6 +2349,19 @@ def test_pin_memory(self):
             self.assertTrue(sample['a_tensor'].is_pinned())
             self.assertTrue(sample['another_dict']['a_number'].is_pinned())
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_pin_memory_device(self):
+        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True, pin_memory_device='cuda')
+        for sample in loader:
+            self.assertTrue(sample['a_tensor'].is_pinned(device='cuda'))
+            self.assertTrue(sample['another_dict']['a_number'].is_pinned(device='cuda'))
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_pin_memory_with_only_device(self):
+        loader = DataLoader(self.dataset, batch_size=2, pin_memory_device='cuda')
+        for sample in loader:
+            self.assertFalse(sample['a_tensor'].is_pinned(device='cuda'))
+            self.assertFalse(sample['another_dict']['a_number'].is_pinned(device='cuda'))
 
 class DummyDataset(torch.utils.data.Dataset):
     def __init__(self):
diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py
index e5c73a542639..fd2879228d76 100644
--- a/torch/utils/data/_utils/pin_memory.py
+++ b/torch/utils/data/_utils/pin_memory.py
@@ -14,7 +14,7 @@
 from torch._utils import ExceptionWrapper
 
 
-def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
+def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
     # This setting is thread local, and prevents the copy in pin_memory from
     # consuming all CPU cores.
     torch.set_num_threads(1)
@@ -31,7 +31,7 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
         idx, data = r
         if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
             try:
-                data = pin_memory(data)
+                data = pin_memory(data, device)
             except Exception:
                 data = ExceptionWrapper(
                     where="in pin memory thread for device {}".format(device_id))
@@ -45,27 +45,27 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
         del r  # save memory
 
 
-def pin_memory(data):
+def pin_memory(data, device=None):
     if isinstance(data, torch.Tensor):
-        return data.pin_memory()
+        return data.pin_memory(device)
     elif isinstance(data, string_classes):
         return data
     elif isinstance(data, collections.abc.Mapping):
         try:
-            return type(data)({k: pin_memory(sample) for k, sample in data.items()})  # type: ignore[call-arg]
+            return type(data)({k: pin_memory(sample, device) for k, sample in data.items()})  # type: ignore[call-arg]
         except TypeError:
             # The mapping type may not support `__init__(iterable)`.
-            return {k: pin_memory(sample) for k, sample in data.items()}
+            return {k: pin_memory(sample, device) for k, sample in data.items()}
     elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
-        return type(data)(*(pin_memory(sample) for sample in data))
+        return type(data)(*(pin_memory(sample, device) for sample in data))
     elif isinstance(data, tuple):
-        return [pin_memory(sample) for sample in data]  # Backwards compatibility.
+        return [pin_memory(sample, device) for sample in data]  # Backwards compatibility.
     elif isinstance(data, collections.abc.Sequence):
         try:
-            return type(data)([pin_memory(sample) for sample in data])  # type: ignore[call-arg]
+            return type(data)([pin_memory(sample, device) for sample in data])  # type: ignore[call-arg]
         except TypeError:
             # The sequence type may not support `__init__(iterable)` (e.g., `range`).
-            return [pin_memory(sample) for sample in data]
+            return [pin_memory(sample, device) for sample in data]
     elif hasattr(data, "pin_memory"):
         return data.pin_memory()
     else:
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index eac1f6778a77..b76a6dc5c331 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -103,7 +103,7 @@ class DataLoader(Generic[T_co]):
             mini-batch of Tensor(s).  Used when using batched loading from a
             map-style dataset.
         pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
-            into CUDA pinned memory before returning them.  If your data elements
+            into device/CUDA pinned memory before returning them.  If your data elements
             are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
             see the example below.
         drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
@@ -124,6 +124,8 @@ class DataLoader(Generic[T_co]):
         persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
             the worker processes after a dataset has been consumed once. This allows to
             maintain the workers `Dataset` instances alive. (default: ``False``)
+        pin_memory_device (str, optional): the data loader will copy Tensors into
+            device pinned memory before returning them if ``pin_memory`` is set to ``True``.
 
 
     .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
@@ -159,6 +161,7 @@ class DataLoader(Generic[T_co]):
     drop_last: bool
     timeout: float
     sampler: Union[Sampler, Iterable]
+    pin_memory_device: str
     prefetch_factor: int
     _iterator : Optional['_BaseDataLoaderIter']
     __initialized = False
@@ -171,7 +174,8 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                  timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                  multiprocessing_context=None, generator=None,
                  *, prefetch_factor: int = 2,
-                 persistent_workers: bool = False):
+                 persistent_workers: bool = False,
+                 pin_memory_device: str = ""):
         torch._C._log_api_usage_once("python.data_loader")
 
         if num_workers < 0:
@@ -193,6 +197,7 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
         self.num_workers = num_workers
         self.prefetch_factor = prefetch_factor
         self.pin_memory = pin_memory
+        self.pin_memory_device = pin_memory_device
         self.timeout = timeout
         self.worker_init_fn = worker_init_fn
         self.multiprocessing_context = multiprocessing_context
@@ -503,7 +508,20 @@ def __init__(self, loader: DataLoader) -> None:
         self._index_sampler = loader._index_sampler
         self._num_workers = loader.num_workers
         self._prefetch_factor = loader.prefetch_factor
-        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
+        # For other backends, pin_memory_device needs to be set; if it is not
+        # set, the default behaviour targets the CUDA device. If pin_memory_device
+        # is set but pin_memory is not, pinning stays disabled (the default, False).
+        if (len(loader.pin_memory_device) == 0):
+            self._pin_memory = loader.pin_memory and torch.cuda.is_available()
+            self._pin_memory_device = None
+        else:
+            if not loader.pin_memory:
+                warn_msg = ("pin_memory_device is set but the pin_memory flag is False, so device "
+                            "pinned memory won't be used; please set pin_memory to True to use it")
+                warnings.warn(warn_msg)
+
+            self._pin_memory = loader.pin_memory
+            self._pin_memory_device = loader.pin_memory_device
         self._timeout = loader.timeout
         self._collate_fn = loader.collate_fn
         self._sampler_iter = iter(self._index_sampler)
@@ -572,7 +590,7 @@ def _next_data(self):
         index = self._next_index()  # may raise StopIteration
         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
         if self._pin_memory:
-            data = _utils.pin_memory.pin_memory(data)
+            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
         return data
 
 
@@ -939,7 +957,7 @@ def __init__(self, loader):
                 target=_utils.pin_memory._pin_memory_loop,
                 args=(self._worker_result_queue, self._data_queue,
                       torch.cuda.current_device(),
-                      self._pin_memory_thread_done_event))
+                      self._pin_memory_thread_done_event, self._pin_memory_device))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
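
Usage note: a minimal sketch of how the new parameter is exercised from user
code, mirroring the tests above. It assumes a CUDA build of PyTorch; the
`TensorDataset` here is just a stand-in for any map-style dataset.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Any map-style dataset works; this toy one yields 1-tuples of tensors.
    dataset = TensorDataset(torch.randn(8, 3))

    # pin_memory=True enables pinning; pin_memory_device selects which
    # backend's pinned memory is used. Leaving it at the default "" keeps
    # the previous CUDA-only behaviour.
    loader = DataLoader(dataset, batch_size=2, pin_memory=True,
                        pin_memory_device="cuda")

    for (batch,) in loader:
        # Pinned host memory allows an asynchronous host-to-device copy.
        assert batch.is_pinned(device="cuda")
        batch = batch.to("cuda", non_blocking=True)

As test_pin_memory_with_only_device shows, setting pin_memory_device alone is
not enough: pin_memory must also be True, otherwise the loader only emits a
warning and returns unpinned batches.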