pin_memory support for NT #110404

Closed · wants to merge 6 commits · changes shown from 2 commits
aten/src/ATen/native/native_functions.yaml (4 additions, 1 deletion)
@@ -4405,14 +4405,17 @@
 - func: is_pinned(Tensor self, Device? device=None) -> bool
   variants: method
   dispatch:
-    CUDA: is_pinned_cuda
+    NestedTensorCUDA, CUDA: is_pinned_cuda
     MPS: is_pinned_mps
     CompositeExplicitAutograd: is_pinned_default
 
 # TODO: add a copy kwarg that guarantees that the tensor is put into fresh
 # pinned memory
 - func: pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)
   variants: method
+  dispatch:
+    NestedTensorCUDA, NestedTensorCPU: pin_memory_nested
+    CompositeImplicitAutograd: pin_memory
 
 # Unlike pin_memory, this is guaranteed to give a new non-aliasing tensor
 - func: _pin_memory(Tensor self, Device? device=None) -> Tensor
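
For readers skimming the YAML above: the new dispatch keys route Tensor.is_pinned() and Tensor.pin_memory() to nested-tensor kernels, so both methods become callable on nested tensors. A minimal sketch of the resulting user-facing behavior (not part of the PR; assumes a CUDA-capable build, and the shapes are arbitrary):

import torch

# Nested tensor on CPU; pinning is only meaningful for host memory that
# will later be copied to a CUDA device.
nt = torch.nested.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])

if torch.cuda.is_available():
    pinned = nt.pin_memory()   # routed to pin_memory_nested via the dispatch above
    print(pinned.is_pinned())  # True, routed to is_pinned_cuda
    # Pinned host memory allows an asynchronous host-to-device copy.
    nt_cuda = pinned.to("cuda", non_blocking=True)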
aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp (9 additions)
@@ -132,5 +132,14 @@ Tensor cos_nested(const Tensor& self) {
   return map_nt(self, at::cos);
 }
 
+Tensor pin_memory_nested(const Tensor& self, c10::optional<Device> device) {
+  if (self.is_pinned(device)) {
+    return self;
+  }
+  auto* nt_input = get_nested_tensor_impl(self);
+  const auto& input_buffer = nt_input->get_buffer();
+  return wrap_buffer(at::_pin_memory(input_buffer, device), nt_input->get_nested_sizes());
+}
+
 } // namespace native
 } // namespace at
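
The kernel above pins the nested tensor's single underlying buffer with at::_pin_memory and re-wraps it with the original nested sizes, returning early when the input is already pinned. A hedged Python sketch of the semantics this gives (the unbind()-based check on component views is my assumption, not something asserted in the PR; the re-pinning identity check mirrors the test added below):

import torch

if torch.cuda.is_available():
    nt = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(5, 4)])
    pinned = nt.pin_memory()

    # All components share one flat buffer, so pinning the nested tensor
    # should pin every component at once.
    assert all(component.is_pinned() for component in pinned.unbind())

    # Re-pinning hits the early-return branch in pin_memory_nested and
    # hands back the same object.
    assert pinned.pin_memory() is pinned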
test/test_dataloader.py (16 additions, 10 deletions)
@@ -2276,16 +2276,22 @@ class TestDataLoaderDeviceType(TestCase):
     def test_nested_tensor_multiprocessing(self, device):
         dataset = [torch.nested.nested_tensor([torch.randn(5)], device=device) for _ in range(100)]
 
-        loader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=1,
-            num_workers=4,
-            collate_fn=_identity,
-            multiprocessing_context=('spawn' if 'cuda' in device else 'fork'),
-        )
-
-        for i, batch in enumerate(loader):
-            self.assertEqual(batch[0], dataset[i])
+        pin_memory_settings = [False]
+        if device == 'cpu' and torch.cuda.is_available():
+            pin_memory_settings.append(True)
+
+        for pin_memory in pin_memory_settings:
+            loader = torch.utils.data.DataLoader(
+                dataset,
+                batch_size=1,
+                num_workers=4,
+                collate_fn=_identity,
+                pin_memory=pin_memory,
+                multiprocessing_context=('spawn' if 'cuda' in device else 'fork'),
+            )
+
+            for i, batch in enumerate(loader):
+                self.assertEqual(batch[0], dataset[i])
 
 
 class IntegrationTestDataLoaderDataPipe(TestCase):
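
The updated test exercises the DataLoader path: with pin_memory=True, the loader's pin-memory worker calls .pin_memory() on each tensor in the batch, which now works for nested tensors as well. A small usage sketch along the same lines (the identity collate function and sizes here are illustrative, not from the PR):

import torch
from torch.utils.data import DataLoader

def identity_collate(batch):
    # Nested tensors of varying shape cannot be stacked by the default
    # collate function, so return the list unchanged.
    return batch

dataset = [torch.nested.nested_tensor([torch.randn(n, 8)]) for n in range(1, 11)]

if torch.cuda.is_available():
    loader = DataLoader(dataset, batch_size=2, collate_fn=identity_collate,
                        pin_memory=True)
    for batch in loader:
        # Each nested tensor in the batch should arrive in pinned host memory.
        assert all(nt.is_pinned() for nt in batch)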
test/test_nestedtensor.py (14 additions)
@@ -15,6 +15,7 @@
     onlyCPU,
     onlyCUDA,
     skipMeta,
+    PYTORCH_CUDA_MEMCHECK,
 )
 from torch.testing._internal.common_dtype import floating_types_and_half
 from torch.testing._internal.common_utils import (

@@ -2941,6 +2942,19 @@ def compare_metadata(nt1, nt2):
         self.assertEqual(b, nt_contiguous)
         self.assertEqual(b, nt_noncontiguous)
 
+    @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
+    @onlyCUDA
+    def test_pin_memory(self, device):
+        nt = random_nt_from_dims([3, None, 5], device='cpu')
+        self.assertFalse(nt.is_pinned())
+        pinned = nt.pin_memory(device)
+        self.assertTrue(pinned.is_pinned())
+        self.assertEqual(nt, pinned)
+        self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
+        # test that pin_memory on already pinned tensor has no effect
+        self.assertIs(pinned, pinned.pin_memory())
+        self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
+
 
 instantiate_parametrized_tests(TestNestedTensor)
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())