pin_memory support for NT
ghstack-source-id: 02f149130fb37d360be9435d9135921343a33b29
Pull Request resolved: #110404
jbschlosser committed Oct 2, 2023
1 parent 6b00b30 commit 8af4ee3
Showing 3 changed files with 27 additions and 1 deletion.
5 changes: 4 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -4405,14 +4405,17 @@
 - func: is_pinned(Tensor self, Device? device=None) -> bool
   variants: method
   dispatch:
-    CUDA: is_pinned_cuda
+    NestedTensorCUDA, CUDA: is_pinned_cuda
     MPS: is_pinned_mps
     CompositeExplicitAutograd: is_pinned_default
 
 # TODO: add a copy kwarg that guarantees that the tensor is put into fresh
 # pinned memory
 - func: pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)
   variants: method
+  dispatch:
+    NestedTensorCUDA, NestedTensorCPU: pin_memory_nested
+    CompositeImplicitAutograd: pin_memory
 
 # Unlike pin_memory, this is guaranteed to give a new non-aliasing tensor
 - func: _pin_memory(Tensor self, Device? device=None) -> Tensor
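The schema changes above do two things: `is_pinned` now routes `NestedTensorCUDA` through the existing `is_pinned_cuda` kernel, and `pin_memory` gains an explicit dispatch section so nested tensors reach the new `pin_memory_nested` kernel rather than the composite default. A minimal sketch of the user-facing behavior this enables (assuming a CUDA-capable build; `torch.nested.nested_tensor` is the public nested tensor constructor):

    import torch

    # Build a nested tensor on CPU from components of different lengths.
    nt = torch.nested.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])
    print(nt.is_pinned())      # False: ordinary pageable CPU memory
    pinned = nt.pin_memory()   # now dispatches to pin_memory_nested
    print(pinned.is_pinned())  # True
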
9 changes: 9 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
@@ -132,5 +132,14 @@ Tensor cos_nested(const Tensor& self) {
   return map_nt(self, at::cos);
 }
 
+Tensor pin_memory_nested(const Tensor& self, c10::optional<Device> device) {
+  if (self.is_pinned(device)) {
+    return self;
+  }
+  auto* nt_input = get_nested_tensor_impl(self);
+  const auto& input_buffer = nt_input->get_buffer();
+  return wrap_buffer(at::_pin_memory(input_buffer, device), nt_input->get_nested_sizes());
+}
+
 } // namespace native
 } // namespace at
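The implementation leans on the fact that a nested tensor stores all of its components in a single flat buffer: pinning reduces to one `at::_pin_memory` call on that buffer, rewrapped with the original nested sizes, and an already-pinned input is returned as-is. A rough Python demonstration of the resulting semantics (assuming a CUDA-capable build):

    import torch

    nt = torch.nested.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])
    pinned = nt.pin_memory()
    # _pin_memory copies the flat buffer into fresh pinned memory,
    # so the result never aliases the input.
    assert nt.data_ptr() != pinned.data_ptr()
    # Pinning an already-pinned nested tensor is a no-op that
    # returns the same object.
    assert pinned.pin_memory() is pinned
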
14 changes: 14 additions & 0 deletions test/test_nestedtensor.py
@@ -15,6 +15,7 @@
     onlyCPU,
     onlyCUDA,
     skipMeta,
+    PYTORCH_CUDA_MEMCHECK,
 )
 from torch.testing._internal.common_dtype import floating_types_and_half
 from torch.testing._internal.common_utils import (
@@ -2941,6 +2942,19 @@ def compare_metadata(nt1, nt2):
         self.assertEqual(b, nt_contiguous)
         self.assertEqual(b, nt_noncontiguous)
 
+    @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
+    @onlyCUDA
+    def test_pin_memory(self, device):
+        nt = random_nt_from_dims([3, None, 5], device='cpu')
+        self.assertFalse(nt.is_pinned())
+        pinned = nt.pin_memory(device)
+        self.assertTrue(pinned.is_pinned())
+        self.assertEqual(nt, pinned)
+        self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
+        # test that pin_memory on already pinned tensor has no effect
+        self.assertIs(pinned, pinned.pin_memory())
+        self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
+
 
 instantiate_parametrized_tests(TestNestedTensor)
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
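The test mirrors the dense-tensor pinning contract: an unpinned CPU nested tensor, a pinned copy with equal values but distinct storage, and idempotence on a second `pin_memory` call. The practical payoff is that page-locked host memory permits asynchronous host-to-device copies; a hedged sketch, assuming nested tensors accept `non_blocking=True` in `Tensor.to` the way dense tensors do:

    import torch

    nt = torch.nested.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])
    pinned = nt.pin_memory()
    # From pinned memory, the copy below can overlap with GPU compute.
    gpu_nt = pinned.to('cuda', non_blocking=True)
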
