From 0e389e2286fc0b7a6423e39cc6ac1926f0a9a7dd Mon Sep 17 00:00:00 2001 From: soulitzer Date: Tue, 7 Nov 2023 10:28:47 -0500 Subject: [PATCH] Do not generate zero-numel NT by default in helper and improve to_padded_tensor msg ghstack-source-id: 53b61ff106d3fbee31853a32d06b0f3599fc5a61 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113162 --- .../ATen/native/nested/NestedTensorMath.cpp | 1 + .../cuda/NestedTensorTransformerFunctions.cpp | 1 + test/test_nestedtensor.py | 33 +++++++++++++++---- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp index 3777a8cc8d9ff..614a21ecb3c82 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp @@ -245,6 +245,7 @@ Tensor NestedTensor_to_padded_tensor_generic( const Tensor& t, double padding, OptionalIntArrayRef output_size) { + TORCH_CHECK(t.numel() > 0, "to_padded_tensor only supports tensors with non-zero numel") // TODO: support noncontiguous case // error out for now TORCH_CHECK( diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 41baa7a31c32a..9da8dfd547e31 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -140,6 +140,7 @@ Tensor NestedTensor_to_padded_tensor_cuda( const Tensor& t, double padding, OptionalIntArrayRef output_size) { + TORCH_CHECK(t.numel() > 0, "to_padded_tensor only supports tensors with non-zero numel") int64_t t_dim = t.dim(); if (t_dim >= 2 && t_dim <= 4 && (t.dtype() == at::kFloat || t.dtype() == at::kDouble || diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index bf8e0def01332..8cc5ed82729e0 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -96,15 +96,24 @@ 
def noncontiguous_to_padded_tensor(input, shape=None): # Helper function to generate a random nested tensor -def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch.strided): +def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch.strided, must_non_empty=True): if min_dims is None: min_dims = tuple([0] * len(max_dims)) - ts1 = [] - for _ in range(num_tensors): - tensor_dims = tuple([torch.randint(low=min_dim, high=max_dim, size=(1,)).item() - for (min_dim, max_dim) in zip(min_dims, max_dims)]) - t1 = torch.randn(tensor_dims, device=device, dtype=dtype) - ts1.append(t1) + has_non_empty = False + while not has_non_empty: + # Repeat until we have at least one non-empty tensor + ts1 = [] + has_non_empty = False + for _ in range(num_tensors): + tensor_dims = tuple([torch.randint(low=min_dim, high=max_dim, size=(1,)).item() + for (min_dim, max_dim) in zip(min_dims, max_dims)]) + t1 = torch.randn(tensor_dims, device=device, dtype=dtype) + if t1.numel() > 0: + has_non_empty = True + ts1.append(t1) + # When must_non_empty is False, do a single pass and accept an all-empty result + if not must_non_empty: + break return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout) @@ -1898,6 +1907,16 @@ def test_linear_noncontiguous(self, device, dtype): lambda: torch.nn.functional.linear(nt_noncontiguous, weight) ) + @dtypes(torch.float, torch.float16, torch.double) + def test_to_padded_tensor_zero_numel_errors(self, device, dtype): + ts = [torch.ones(1, 0), torch.ones(0, 0)] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype, layout=torch.strided) + self.assertRaisesRegex( + RuntimeError, + r"to_padded_tensor only supports tensors with non-zero numel", + lambda: torch.nested.to_padded_tensor(nt, 0.0) + ) + @dtypes(torch.float, torch.float16, torch.double) + def test_transpose(self, device, dtype): + nt = random_nt(device, dtype, 4, (4, 4))