From 909b7a0ccb549f5441b65db2c212140ba8df6bf8 Mon Sep 17 00:00:00 2001
From: Joel Schlosser
Date: Wed, 4 Oct 2023 17:43:48 -0400
Subject: [PATCH] pin_memory support for NT

ghstack-source-id: fcc16e3726ee7b8cd95e5d5780662dd005b4ff48
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110404
---
 aten/src/ATen/native/native_functions.yaml   |  3 ++-
 .../native/nested/NestedTensorUnaryOps.cpp   | 10 ++++++++
 .../ATen/templates/RegisterBackendSelect.cpp |  5 ++++
 test/test_dataloader.py                      | 24 ++++++++++++-------
 test/test_nestedtensor.py                    | 15 ++++++++++++
 5 files changed, 47 insertions(+), 10 deletions(-)

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 067c94c40cfdc..59814637a74e9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4405,7 +4405,7 @@
 - func: is_pinned(Tensor self, Device? device=None) -> bool
   variants: method
   dispatch:
-    CUDA: is_pinned_cuda
+    NestedTensorCUDA, CUDA: is_pinned_cuda
     MPS: is_pinned_mps
     CompositeExplicitAutograd: is_pinned_default
 
@@ -4419,6 +4419,7 @@
   dispatch:
     CUDA: _pin_memory_cuda
     MPS: _pin_memory_mps
+    NestedTensorCUDA, NestedTensorCPU: _pin_memory_nested
   autogen: _pin_memory.out
 
 - func: pinverse(Tensor self, float rcond=1e-15) -> Tensor
diff --git a/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp b/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
index e01535323ea92..c41b6f15214aa 100644
--- a/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
@@ -132,5 +132,15 @@ Tensor cos_nested(const Tensor& self) {
   return map_nt(self, at::cos);
 }
 
+Tensor _pin_memory_nested(const Tensor& self, c10::optional<Device> device) {
+  auto* nt_input = get_nested_tensor_impl(self);
+  const auto& input_buffer = nt_input->get_unsafe_storage_as_tensor();
+  return wrap_buffer(
+      at::_pin_memory(input_buffer, device),
+      nt_input->get_nested_sizes(),
+      nt_input->get_nested_strides(),
+      nt_input->get_storage_offsets());
+}
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/templates/RegisterBackendSelect.cpp b/aten/src/ATen/templates/RegisterBackendSelect.cpp
index 6463701a4939f..dcb5986ab69ed 100644
--- a/aten/src/ATen/templates/RegisterBackendSelect.cpp
+++ b/aten/src/ATen/templates/RegisterBackendSelect.cpp
@@ -36,6 +36,11 @@ bool is_pinned(const Tensor& self, c10::optional<Device> device) {
 at::Tensor _pin_memory(const Tensor& self, c10::optional<Device> device) {
   TORCH_CHECK(self.device().is_cpu(), "cannot pin '", self.toString(), "' only dense CPU tensors can be pinned");
   DispatchKeySet _dk = c10::DispatchKeySet(c10::computeDispatchKey(c10::nullopt, self.layout(), device.value_or(at::kCUDA)));
+  if (self.is_nested()) {
+    constexpr auto nested_key_set = c10::DispatchKeySet(
+        {c10::DispatchKey::NestedTensor, c10::DispatchKey::AutogradNestedTensor});
+    _dk = _dk.add(self.key_set() & nested_key_set);
+  }
   return at::_ops::_pin_memory::redispatch(_dk, self, device);
 }
 
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 1e0080f900fa7..7188373103049 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -2282,16 +2282,22 @@ def test_nested_tensor_multiprocessing(self, device, context):
         dataset = [torch.nested.nested_tensor([torch.randn(5)], device=device)
                    for _ in range(10)]
 
-        loader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=1,
-            num_workers=4,
-            collate_fn=_clone_collate,
-            multiprocessing_context=context,
-        )
+        pin_memory_settings = [False]
+        if device == 'cpu' and torch.cuda.is_available():
+            pin_memory_settings.append(True)
+
+        for pin_memory in pin_memory_settings:
+            loader = torch.utils.data.DataLoader(
+                dataset,
+                batch_size=1,
+                num_workers=4,
+                collate_fn=_clone_collate,
+                pin_memory=pin_memory,
+                multiprocessing_context=context,
+            )
 
-        for i, batch in enumerate(loader):
-            self.assertEqual(batch[0], dataset[i])
+            for i, batch in enumerate(loader):
+                self.assertEqual(batch[0], dataset[i])
 
         # Error case: default collate_fn doesn't currently support batches of nested tensors.
         # Following the current semantics, we'd need to stack them, which isn't possible atm.
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index aec2297675d91..64b1f050646ef 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -15,6 +15,7 @@
     onlyCPU,
     onlyCUDA,
     skipMeta,
+    PYTORCH_CUDA_MEMCHECK,
 )
 from torch.testing._internal.common_dtype import floating_types_and_half
 from torch.testing._internal.common_utils import (
@@ -2941,6 +2942,20 @@ def compare_metadata(nt1, nt2):
         self.assertEqual(b, nt_contiguous)
         self.assertEqual(b, nt_noncontiguous)
 
+    @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
+    @onlyCUDA
+    def test_pin_memory(self, device):
+        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
+        for nt in [nt_contiguous, nt_noncontiguous]:
+            self.assertFalse(nt.is_pinned())
+            pinned = nt.pin_memory(device)
+            self.assertTrue(pinned.is_pinned())
+            self.assertEqual(nt, pinned)
+            self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
+            # test that pin_memory on already pinned tensor has no effect
+            self.assertIs(pinned, pinned.pin_memory())
+            self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
+
 instantiate_parametrized_tests(TestNestedTensor)
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
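For context, a minimal usage sketch of the behavior this patch adds. This is a sketch only: it assumes a CUDA build of PyTorch (pinning redispatches through the CUDA backend), and the component shapes are illustrative; the semantics asserted below are the same ones exercised by test_pin_memory above.

    import torch

    # Nested tensors start out unpinned on CPU; only dense CPU tensors can be pinned.
    nt = torch.nested.nested_tensor([torch.randn(3, 4), torch.randn(5, 4)])
    assert not nt.is_pinned()

    # pin_memory() pins the flat underlying buffer and rewraps it with the
    # original nested sizes, strides, and storage offsets (_pin_memory_nested).
    pinned = nt.pin_memory()
    assert pinned.is_pinned()
    assert pinned.data_ptr() != nt.data_ptr()

    # Pinning an already-pinned nested tensor is a no-op returning the same tensor.
    assert pinned.pin_memory() is pinned

This is also what makes the DataLoader path work: with pin_memory=True, the pin_memory worker thread can now pin nested-tensor batches the same way it pins dense ones.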