Skip to content

Commit

Permalink
Update Dataloader with default parameter device (#65402)
Browse files Browse the repository at this point in the history
Summary:
pin_memory, has optional device parameter to specify
which device you want to pin for.  With this above change
the Dataloader will work only for CUDA backend. To add
support for other backend which supports pinned memory,
dataloader is updated with device as optional parameter.

Fixes #{issue number}

Pull Request resolved: #65402

Reviewed By: zou3519

Differential Revision: D32282204

Pulled By: VitalyFedyunin

fbshipit-source-id: e2e09876969af108d0db38af7c2d1b2f1cfa9858
  • Loading branch information
jeejakp12 authored and facebook-github-bot committed Apr 21, 2022
1 parent 5dfc723 commit 3b76e15
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 15 deletions.
13 changes: 13 additions & 0 deletions test/test_dataloader.py
Expand Up @@ -2349,6 +2349,19 @@ def test_pin_memory(self):
self.assertTrue(sample['a_tensor'].is_pinned())
self.assertTrue(sample['another_dict']['a_number'].is_pinned())

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_pin_memory_device(self):
loader = DataLoader(self.dataset, batch_size=2, pin_memory=True, pin_memory_device='cuda')
for sample in loader:
self.assertTrue(sample['a_tensor'].is_pinned(device='cuda'))
self.assertTrue(sample['another_dict']['a_number'].is_pinned(device='cuda'))

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_pin_memory_with_only_device(self):
loader = DataLoader(self.dataset, batch_size=2, pin_memory_device='cuda')
for sample in loader:
self.assertFalse(sample['a_tensor'].is_pinned(device='cuda'))
self.assertFalse(sample['another_dict']['a_number'].is_pinned(device='cuda'))

class DummyDataset(torch.utils.data.Dataset):
def __init__(self):
Expand Down
20 changes: 10 additions & 10 deletions torch/utils/data/_utils/pin_memory.py
Expand Up @@ -14,7 +14,7 @@
from torch._utils import ExceptionWrapper


def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
# This setting is thread local, and prevents the copy in pin_memory from
# consuming all CPU cores.
torch.set_num_threads(1)
Expand All @@ -31,7 +31,7 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
idx, data = r
if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
try:
data = pin_memory(data)
data = pin_memory(data, device)
except Exception:
data = ExceptionWrapper(
where="in pin memory thread for device {}".format(device_id))
Expand All @@ -45,27 +45,27 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
del r # save memory


def pin_memory(data):
def pin_memory(data, device=None):
if isinstance(data, torch.Tensor):
return data.pin_memory()
return data.pin_memory(device)
elif isinstance(data, string_classes):
return data
elif isinstance(data, collections.abc.Mapping):
try:
return type(data)({k: pin_memory(sample) for k, sample in data.items()}) # type: ignore[call-arg]
return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg]
except TypeError:
# The mapping type may not support `__init__(iterable)`.
return {k: pin_memory(sample) for k, sample in data.items()}
return {k: pin_memory(sample, device) for k, sample in data.items()}
elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
return type(data)(*(pin_memory(sample) for sample in data))
return type(data)(*(pin_memory(sample, device) for sample in data))
elif isinstance(data, tuple):
return [pin_memory(sample) for sample in data] # Backwards compatibility.
return [pin_memory(sample, device) for sample in data] # Backwards compatibility.
elif isinstance(data, collections.abc.Sequence):
try:
return type(data)([pin_memory(sample) for sample in data]) # type: ignore[call-arg]
return type(data)([pin_memory(sample, device) for sample in data]) # type: ignore[call-arg]
except TypeError:
# The sequence type may not support `__init__(iterable)` (e.g., `range`).
return [pin_memory(sample) for sample in data]
return [pin_memory(sample, device) for sample in data]
elif hasattr(data, "pin_memory"):
return data.pin_memory()
else:
Expand Down
28 changes: 23 additions & 5 deletions torch/utils/data/dataloader.py
Expand Up @@ -103,7 +103,7 @@ class DataLoader(Generic[T_co]):
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
into CUDA pinned memory before returning them. If your data elements
into device/CUDA pinned memory before returning them. If your data elements
are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
see the example below.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
Expand All @@ -124,6 +124,8 @@ class DataLoader(Generic[T_co]):
persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
the worker processes after a dataset has been consumed once. This allows to
maintain the workers `Dataset` instances alive. (default: ``False``)
pin_memory_device (str, optional): the data loader will copy Tensors
into device pinned memory before returning them if pin_memory is set to true.
.. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
Expand Down Expand Up @@ -159,6 +161,7 @@ class DataLoader(Generic[T_co]):
drop_last: bool
timeout: float
sampler: Union[Sampler, Iterable]
pin_memory_device: str
prefetch_factor: int
_iterator : Optional['_BaseDataLoaderIter']
__initialized = False
Expand All @@ -171,7 +174,8 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
multiprocessing_context=None, generator=None,
*, prefetch_factor: int = 2,
persistent_workers: bool = False):
persistent_workers: bool = False,
pin_memory_device: str = ""):
torch._C._log_api_usage_once("python.data_loader")

if num_workers < 0:
Expand All @@ -193,6 +197,7 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
self.pin_memory = pin_memory
self.pin_memory_device = pin_memory_device
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.multiprocessing_context = multiprocessing_context
Expand Down Expand Up @@ -503,7 +508,20 @@ def __init__(self, loader: DataLoader) -> None:
self._index_sampler = loader._index_sampler
self._num_workers = loader.num_workers
self._prefetch_factor = loader.prefetch_factor
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
# for other backends, pin_memory_device need to set. if not set
# default behaviour is CUDA device. if pin_memory_device is selected
# and pin_memory is not set, the default behaviour false.
if (len(loader.pin_memory_device) == 0):
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
self._pin_memory_device = None
else:
if not loader.pin_memory:
warn_msg = ("pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
"please set pin_memory to true, if you need to use the device pin memory")
warnings.warn(warn_msg)

self._pin_memory = loader.pin_memory
self._pin_memory_device = loader.pin_memory_device
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler)
Expand Down Expand Up @@ -572,7 +590,7 @@ def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
return data


Expand Down Expand Up @@ -939,7 +957,7 @@ def __init__(self, loader):
target=_utils.pin_memory._pin_memory_loop,
args=(self._worker_result_queue, self._data_queue,
torch.cuda.current_device(),
self._pin_memory_thread_done_event))
self._pin_memory_thread_done_event, self._pin_memory_device))
pin_memory_thread.daemon = True
pin_memory_thread.start()
# Similar to workers (see comment above), we only register
Expand Down

0 comments on commit 3b76e15

Please sign in to comment.