randperm: add torch check to ensure generator device = tensor device (#47022)

Summary:
**BC-breaking Note:**

This PR disallows passing a generator whose device differs from the device of the tensor being created by `randperm`. For example, the following code, which used to work, no longer works.
```
> torch.randperm(3, device='cuda', generator=torch.Generator(device='cpu'))
tensor([0, 1, 2], device='cuda:0')
```
It now errors:
```
> torch.randperm(3, device='cuda', generator=torch.Generator(device='cpu'))
RuntimeError: Expected a 'cuda:0' generator device but found 'cpu'
```
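
A matching-device generator is still accepted; a minimal sketch of the intended usage (assuming a CUDA-enabled build):
```
import torch

# Accepted: generator device matches the tensor device
g = torch.Generator(device='cuda')
g.manual_seed(0)
torch.randperm(3, device='cuda', generator=g)

# Accepted: CPU generator with a tensor on the default (CPU) device
torch.randperm(3, generator=torch.Generator(device='cpu'))
```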

**PR Summary:**

Fixes #44714

Tests were also added and run to verify this behavior.

Disclaimer: More work is needed with regard to small CUDA tensors when a generator is specified; see the issue thread for more details.

Pull Request resolved: #47022

Reviewed By: samestep

Differential Revision: D24608237

Pulled By: janeyx99

fbshipit-source-id: b83c47219c7816d93f938f7ce86dc8857513961b
janeyx99 authored and facebook-github-bot committed Nov 4, 2020
1 parent 07e8f48 commit e4bc785
Showing 4 changed files with 19 additions and 5 deletions.
4 changes: 1 addition & 3 deletions aten/src/ATen/core/Generator.h
```
@@ -31,7 +31,6 @@
  *
  * By default, there is one generator per device, and a device's generator is
  * lazily created. A user can use the torch.Generator() api to create their own generator.
- * Currently torch.Generator() can only create a CPUGeneratorImpl.
  */
 
 /**
@@ -43,7 +42,7 @@
  * Please use the public mutex_ when using any methods from these classes, except for the
  * read-only methods. You can learn about the usage by looking into the unittests
  * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard.
- * 
+ *
  * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making
  * them non-thread safe and instead making the generator state splittable, to accommodate
  * forks into other threads).
@@ -126,4 +125,3 @@ Generator make_generator(Args&&... args) {
 }
 
 } // namespace at
-
```
1 change: 1 addition & 0 deletions aten/src/ATen/native/TensorFactories.cpp
```
@@ -692,6 +692,7 @@ Tensor& randperm_out(Tensor& result, int64_t n) {
 
 Tensor& randperm_out_cpu(Tensor& result, int64_t n, c10::optional<Generator> generator) {
   TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
+  TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()), "Expected a '", result.device(), "' generator device but found '", generator->device(), "'");
   check_supported_max_int_with_precision(n, result);
   result.resize_({n});
   auto gen = get_generator_or_default<CPUGeneratorImpl>(generator, detail::getDefaultCPUGenerator());
```
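
For illustration, the CPU-path check above raises an error of the same shape when handed a mismatched (CUDA) generator; a hypothetical interactive session, with the exact device index in the message depending on the setup:
```
>>> import torch
>>> torch.randperm(3, generator=torch.Generator(device='cuda'))  # result tensor is implicitly on the CPU
RuntimeError: Expected a 'cpu' generator device but found 'cuda:0'
```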
1 change: 1 addition & 0 deletions aten/src/ATen/native/cuda/TensorFactories.cu
```
@@ -81,6 +81,7 @@ Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, const TensorOpti
 
 Tensor& randperm_out_cuda(Tensor& result, int64_t n, c10::optional<Generator> generator) {
   TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
+  TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()), "Expected a '", result.device(), "' generator device but found '", generator->device(), "'");
   check_supported_max_int_with_precision(n, result);
 
   result.resize_({n});
```
18 changes: 16 additions & 2 deletions test/test_tensor_creation_ops.py
```
@@ -1066,8 +1066,8 @@ def test_logspace_special_steps(self, device, dtype):
         self._test_logspace_base2(device, dtype, steps=steps)
 
     @dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_half=False, include_complex=False))
-    @dtypesIfCUDA(*((torch.testing.get_all_int_dtypes() + [torch.float32, torch.float16, torch.bfloat16]) 
-                    if TEST_WITH_ROCM 
+    @dtypesIfCUDA(*((torch.testing.get_all_int_dtypes() + [torch.float32, torch.float16, torch.bfloat16])
+                    if TEST_WITH_ROCM
                     else torch.testing.get_all_dtypes(include_bool=False, include_half=True, include_complex=False)))
     def test_logspace(self, device, dtype):
         _from = random.random()
@@ -1273,6 +1273,20 @@ def test_randperm(self, device):
             torch.randperm(n, out=non_contiguous_tensor)
             self.assertEqual(non_contiguous_tensor, res)
 
+    # Test exceptions when device and generator types are incompatible
+    @onlyCUDA
+    def test_randperm_device_compatibility(self, device):
+        cuda_gen = torch.Generator(device='cuda')
+        cpu_gen = torch.Generator(device='cpu')
+        for n in (0, 3, 100, 30000):
+            regex = 'Expected a .* generator device but found .*'
+            cuda_t = torch.tensor(n, device='cuda')
+            self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cuda', generator=cpu_gen))
+            self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cuda', generator=cpu_gen, out=cuda_t))
+            cpu_t = torch.tensor(n, device='cpu')
+            self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cpu', generator=cuda_gen))
+            self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, device='cpu', generator=cuda_gen, out=cpu_t))
+            self.assertRaisesRegex(RuntimeError, regex, lambda: torch.randperm(n, generator=cuda_gen))  # implicitly on CPU
 
 # Class for testing *like ops, like torch.ones_like
 class TestLikeTensorCreation(TestCase):
```
