pin_memory support for NT #110404

Closed · wants to merge 6 commits · showing changes from 3 commits
aten/src/ATen/native/native_functions.yaml (2 additions, 1 deletion)

@@ -4405,7 +4405,7 @@
 - func: is_pinned(Tensor self, Device? device=None) -> bool
   variants: method
   dispatch:
-    CUDA: is_pinned_cuda
+    NestedTensorCUDA, CUDA: is_pinned_cuda
     MPS: is_pinned_mps
     CompositeExplicitAutograd: is_pinned_default

@@ -4419,6 +4419,7 @@
   dispatch:
     CUDA: _pin_memory_cuda
     MPS: _pin_memory_mps
+    NestedTensorCUDA, NestedTensorCPU: _pin_memory_nested
   autogen: _pin_memory.out

 - func: pinverse(Tensor self, float rcond=1e-15) -> Tensor
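Together, these registrations make pinning observable end to end from Python. A minimal sketch of the resulting behavior, assuming a CUDA build (pinning allocates through the CUDA host allocator even though the tensor stays on CPU):

import torch

nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
assert not nt.is_pinned()             # plain CPU storage
pinned = nt.pin_memory()              # dispatches to _pin_memory_nested
assert pinned.is_pinned()             # dispatches to is_pinned_cuda
assert pinned.pin_memory() is pinned  # re-pinning an already pinned NT is a no-op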
aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp (10 additions)

@@ -132,5 +132,15 @@
 Tensor cos_nested(const Tensor& self) {
   return map_nt(self, at::cos);
 }

+Tensor _pin_memory_nested(const Tensor& self, c10::optional<Device> device) {
+  auto* nt_input = get_nested_tensor_impl(self);
+  const auto& input_buffer = nt_input->get_unsafe_storage_as_tensor();
+  return wrap_buffer(
+      at::_pin_memory(input_buffer, device),
+      nt_input->get_nested_sizes(),
+      nt_input->get_nested_strides(),
+      nt_input->get_storage_offsets());
+}
+
 } // namespace native
 } // namespace at
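The kernel pins the single flat buffer once and re-wraps it with the unchanged nested sizes, strides, and offsets, rather than unbinding and pinning each constituent. A small sanity check of that behavior, again assuming CUDA is available:

import torch

nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
pinned = nt.pin_memory()
assert pinned.data_ptr() != nt.data_ptr()           # pinning copies into fresh pinned storage
assert all(t.is_pinned() for t in pinned.unbind())  # unbound views share the pinned buffer
assert [t.shape for t in pinned.unbind()] == [t.shape for t in nt.unbind()]  # metadata unchanged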
aten/src/ATen/templates/RegisterBackendSelect.cpp (5 additions)

@@ -36,6 +36,11 @@ bool is_pinned(const Tensor& self, c10::optional<at::Device> device) {
 at::Tensor _pin_memory(const Tensor& self, c10::optional<at::Device> device) {
   TORCH_CHECK(self.device().is_cpu(), "cannot pin '", self.toString(), "' only dense CPU tensors can be pinned");
   DispatchKeySet _dk = c10::DispatchKeySet(c10::computeDispatchKey(c10::nullopt, self.layout(), device.value_or(at::kCUDA)));
+  if (self.is_nested()) {
+    auto nested_key_set = c10::DispatchKeySet(
+        {c10::DispatchKey::NestedTensor, c10::DispatchKey::AutogradNestedTensor});
+    _dk = _dk.add(self.key_set() & nested_key_set);
+  }
   return at::_ops::_pin_memory::redispatch(_dk, self, device);
 }
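The backend-select stub always computes a CUDA backend key (pinned memory comes from the CUDA host allocator even for CPU tensors); for nested inputs it additionally carries over the tensor's NestedTensor keys so the redispatch reaches _pin_memory_nested. A toy rendering of that key arithmetic in plain Python sets, not the real c10::DispatchKeySet semantics:

NESTED_KEYS = {"NestedTensorCPU", "NestedTensorCUDA", "AutogradNestedTensor"}

def select_pin_memory_keys(tensor_keys, is_nested):
    dk = {"CUDA"}  # computeDispatchKey(nullopt, layout, device or kCUDA)
    if is_nested:
        dk |= tensor_keys & NESTED_KEYS  # keep only the keys the input actually has
    return dk

# a CPU nested tensor redispatches with both the CUDA and NestedTensorCPU bits set
assert select_pin_memory_keys({"NestedTensorCPU"}, True) == {"CUDA", "NestedTensorCPU"}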
test/test_dataloader.py (15 additions, 9 deletions)

@@ -2282,16 +2282,22 @@ def test_nested_tensor_multiprocessing(self, device, context):

         dataset = [torch.nested.nested_tensor([torch.randn(5)], device=device) for _ in range(10)]

-        loader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=1,
-            num_workers=4,
-            collate_fn=_clone_collate,
-            multiprocessing_context=context,
-        )
-
-        for i, batch in enumerate(loader):
-            self.assertEqual(batch[0], dataset[i])
+        pin_memory_settings = [False]
+        if device == 'cpu' and torch.cuda.is_available():
+            pin_memory_settings.append(True)
+
+        for pin_memory in pin_memory_settings:
+            loader = torch.utils.data.DataLoader(
+                dataset,
+                batch_size=1,
+                num_workers=4,
+                collate_fn=_clone_collate,
+                pin_memory=pin_memory,
+                multiprocessing_context=context,
+            )
+
+            for i, batch in enumerate(loader):
+                self.assertEqual(batch[0], dataset[i])

         # Error case: default collate_fn doesn't currently support batches of nested tensors.
         # Following the current semantics, we'd need to stack them, which isn't possible atm.
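Outside the test harness, the same path can be exercised with a standalone loader. A sketch assuming CUDA is available; _clone_collate below is a stand-in for the test helper of the same name:

import torch

def _clone_collate(batch):
    # default_collate can't stack nested tensors, so pass clones through as a list
    return [x.clone() for x in batch]

dataset = [torch.nested.nested_tensor([torch.randn(5)]) for _ in range(10)]
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    collate_fn=_clone_collate,
    pin_memory=True,  # now supported for batches containing nested CPU tensors
)
for batch in loader:
    assert batch[0].is_pinned()  # DataLoader's pin_memory step recursed into the list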
test/test_nestedtensor.py (15 additions)

@@ -15,6 +15,7 @@
     onlyCPU,
     onlyCUDA,
     skipMeta,
+    PYTORCH_CUDA_MEMCHECK,
 )
 from torch.testing._internal.common_dtype import floating_types_and_half
 from torch.testing._internal.common_utils import (
@@ -2941,6 +2942,20 @@ def compare_metadata(nt1, nt2):
         self.assertEqual(b, nt_contiguous)
         self.assertEqual(b, nt_noncontiguous)

+    @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
+    @onlyCUDA
+    def test_pin_memory(self, device):
+        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
+        for nt in [nt_contiguous, nt_noncontiguous]:
+            self.assertFalse(nt.is_pinned())
+            pinned = nt.pin_memory(device)
+            self.assertTrue(pinned.is_pinned())
+            self.assertEqual(nt, pinned)
+            self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
+            # test that pin_memory on already pinned tensor has no effect
+            self.assertIs(pinned, pinned.pin_memory())
+            self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
+

 instantiate_parametrized_tests(TestNestedTensor)
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
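The practical payoff: pinned host memory is what lets non_blocking host-to-device copies run asynchronously. A usage sketch, assuming nested tensors accept non_blocking the same way dense tensors do:

import torch

nt = torch.nested.nested_tensor([torch.randn(1024), torch.randn(2048)]).pin_memory()
copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    nt_cuda = nt.to("cuda", non_blocking=True)  # async H2D copy from pinned memory
torch.cuda.current_stream().wait_stream(copy_stream)  # order later work after the copy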