Skip to content

Commit

Permalink
pin_memory support for NT
Browse files Browse the repository at this point in the history
ghstack-source-id: fcc16e3726ee7b8cd95e5d5780662dd005b4ff48
Pull Request resolved: #110404
  • Loading branch information
jbschlosser committed Oct 4, 2023
1 parent f8c0ccb commit 909b7a0
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 10 deletions.
3 changes: 2 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4405,7 +4405,7 @@
- func: is_pinned(Tensor self, Device? device=None) -> bool
variants: method
dispatch:
CUDA: is_pinned_cuda
NestedTensorCUDA, CUDA: is_pinned_cuda
MPS: is_pinned_mps
CompositeExplicitAutograd: is_pinned_default

Expand All @@ -4419,6 +4419,7 @@
dispatch:
CUDA: _pin_memory_cuda
MPS: _pin_memory_mps
NestedTensorCUDA, NestedTensorCPU: _pin_memory_nested
autogen: _pin_memory.out

- func: pinverse(Tensor self, float rcond=1e-15) -> Tensor
Expand Down
10 changes: 10 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,5 +132,15 @@ Tensor cos_nested(const Tensor& self) {
return map_nt(self, at::cos);
}

// Pin a nested tensor by pinning its underlying dense storage buffer and
// re-wrapping the pinned buffer with the original nested metadata
// (sizes, strides, storage offsets), so layout is preserved exactly.
Tensor _pin_memory_nested(const Tensor& self, c10::optional<Device> device) {
  auto* impl = get_nested_tensor_impl(self);
  // Delegate the actual pinning to the dense _pin_memory kernel on the
  // buffer view of the storage, then rebuild the nested view on top of it.
  return wrap_buffer(
      at::_pin_memory(impl->get_unsafe_storage_as_tensor(), device),
      impl->get_nested_sizes(),
      impl->get_nested_strides(),
      impl->get_storage_offsets());
}

} // namespace native
} // namespace at
5 changes: 5 additions & 0 deletions aten/src/ATen/templates/RegisterBackendSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ bool is_pinned(const Tensor& self, c10::optional<at::Device> device) {
// BackendSelect entry point for _pin_memory: computes the dispatch key set
// for the target pinning backend and redispatches to the backend kernel.
at::Tensor _pin_memory(const Tensor& self, c10::optional<at::Device> device) {
// Pinning only makes sense for CPU-resident tensors (page-locked host memory).
TORCH_CHECK(self.device().is_cpu(), "cannot pin '", self.toString(), "' only dense CPU tensors can be pinned");
// The pinning kernel is registered on the accelerator backend; when no device
// is specified, default to CUDA.
DispatchKeySet _dk = c10::DispatchKeySet(c10::computeDispatchKey(c10::nullopt, self.layout(), device.value_or(at::kCUDA)));
if (self.is_nested()) {
// For nested tensors, also carry the tensor's NestedTensor (and autograd)
// dispatch keys through, so the redispatch can reach the nested kernel
// (_pin_memory_nested) registered for NestedTensorCPU/NestedTensorCUDA.
constexpr auto nested_key_set = c10::DispatchKeySet(
{c10::DispatchKey::NestedTensor, c10::DispatchKey::AutogradNestedTensor});
// Intersect with self.key_set() so only keys the tensor actually has are added.
_dk = _dk.add(self.key_set() & nested_key_set);
}
return at::_ops::_pin_memory::redispatch(_dk, self, device);
}

Expand Down
24 changes: 15 additions & 9 deletions test/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2282,16 +2282,22 @@ def test_nested_tensor_multiprocessing(self, device, context):

dataset = [torch.nested.nested_tensor([torch.randn(5)], device=device) for _ in range(10)]

loader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
num_workers=4,
collate_fn=_clone_collate,
multiprocessing_context=context,
)
pin_memory_settings = [False]
if device == 'cpu' and torch.cuda.is_available():
pin_memory_settings.append(True)

for pin_memory in pin_memory_settings:
loader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
num_workers=4,
collate_fn=_clone_collate,
pin_memory=pin_memory,
multiprocessing_context=context,
)

for i, batch in enumerate(loader):
self.assertEqual(batch[0], dataset[i])
for i, batch in enumerate(loader):
self.assertEqual(batch[0], dataset[i])

# Error case: default collate_fn doesn't currently support batches of nested tensors.
# Following the current semantics, we'd need to stack them, which isn't possible atm.
Expand Down
15 changes: 15 additions & 0 deletions test/test_nestedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
onlyCPU,
onlyCUDA,
skipMeta,
PYTORCH_CUDA_MEMCHECK,
)
from torch.testing._internal.common_dtype import floating_types_and_half
from torch.testing._internal.common_utils import (
Expand Down Expand Up @@ -2941,6 +2942,20 @@ def compare_metadata(nt1, nt2):
self.assertEqual(b, nt_contiguous)
self.assertEqual(b, nt_noncontiguous)

@unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
@onlyCUDA
def test_pin_memory(self, device):
    # Pinning should work for both contiguous and non-contiguous nested tensors:
    # the result is pinned, value-equal, and backed by fresh (page-locked) memory.
    for nt in random_nt_noncontiguous_pair((2, 3, 6, 7)):
        self.assertFalse(nt.is_pinned())
        pinned = nt.pin_memory(device)
        self.assertTrue(pinned.is_pinned())
        self.assertEqual(nt, pinned)
        self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
        # Pinning an already-pinned tensor must be a no-op that returns
        # the same object, with the same underlying data pointer.
        repinned = pinned.pin_memory()
        self.assertIs(pinned, repinned)
        self.assertEqual(pinned.data_ptr(), repinned.data_ptr())


instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
Expand Down

0 comments on commit 909b7a0

Please sign in to comment.