diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index 3777a8cc8d9ff..c4bc824fdb3cf 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -259,6 +259,10 @@ Tensor NestedTensor_to_padded_tensor_generic(
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(nt.get_buffer().numel() == 0);
     return nt.get_buffer().clone();
   }
+  TORCH_CHECK(
+      t.numel() > 0,
+      "to_padded_tensor: at least one constituent tensor should have non-zero numel"
+  )
   // TODO: doesn't handle empty/scalar entries because we don't need
   // it for transformers; see to_padded_tensor in
diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
index 41baa7a31c32a..607518cba6170 100644
--- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
+++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
@@ -140,6 +140,7 @@ Tensor NestedTensor_to_padded_tensor_cuda(
     const Tensor& t,
     double padding,
     OptionalIntArrayRef output_size) {
+  TORCH_CHECK(t.numel() > 0, "to_padded_tensor: at least one constituent tensor should have non-zero numel")
   int64_t t_dim = t.dim();
   if (t_dim >= 2 && t_dim <= 4 &&
       (t.dtype() == at::kFloat || t.dtype() == at::kDouble ||
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index bf8e0def01332..c5e2d16108eaf 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -96,15 +96,34 @@ def noncontiguous_to_padded_tensor(input, shape=None):
 # Helper function to generate a random nested tensor
-def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch.strided):
+def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch.strided, require_non_empty=True):
     if min_dims is None:
         min_dims = tuple([0] * len(max_dims))
+
+    assert len(max_dims) == len(min_dims)
+    for min_dim, max_dim in zip(min_dims, max_dims):
+        assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim"
+        assert min_dim >= 0, "random_nt: min_dim must be non-negative"
+        if require_non_empty:
+            assert not (min_dim == 0 and max_dim == 1), (
+                "random_nt: zero cannot be the only possible value if require_non_empty is True"
+            )
+
+    if require_non_empty:
+        # Select a random idx that will be required to be non-empty
+        non_zero_idx = torch.randint(low=0, high=num_tensors, size=(1,)).item()
+
     ts1 = []
-    for _ in range(num_tensors):
-        tensor_dims = tuple([torch.randint(low=min_dim, high=max_dim, size=(1,)).item()
-                             for (min_dim, max_dim) in zip(min_dims, max_dims)])
+    for i, _ in enumerate(range(num_tensors)):
+        tensor_dims = []
+        for min_dim, max_dim in zip(min_dims, max_dims):
+            new_min_dim = min_dim
+            if require_non_empty and i == non_zero_idx and min_dim == 0:
+                new_min_dim = 1
+            tensor_dims.append(torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item())
         t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
         ts1.append(t1)
+
     return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout)
@@ -1335,7 +1354,7 @@ def test_nested_tensor_sum_dim(self, device, dtype):
         params = ((2, (1, 1)), ((4), (4, 4)), (10, (3, 5, 7)))
 
         def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True):
-            nt = random_nt(device, dtype, ntensors, max_sizes)
+            nt = random_nt(device, dtype, ntensors, max_sizes, require_non_empty=False)
             nt2 = nt.clone()
             ub2 = nt2.unbind()
             nt.requires_grad_(True)
@@ -1898,6 +1917,16 @@ def test_linear_noncontiguous(self, device, dtype):
             lambda: torch.nn.functional.linear(nt_noncontiguous, weight)
         )
 
+    @dtypes(torch.float, torch.float16, torch.double)
+    def test_to_padded_tensor_zero_numel_errors(self, device, dtype):
+        ts = [torch.ones(1, 0), torch.ones(0, 0)]
+        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype, layout=torch.strided)
+        self.assertRaisesRegex(
+            RuntimeError,
+            r"at least one constituent tensor should have non-zero numel",
+            lambda: torch.nested.to_padded_tensor(nt, 0.0)
+        )
+
     @dtypes(torch.float, torch.float16, torch.double)
     def test_transpose(self, device, dtype):
         nt = random_nt(device, dtype, 4, (4, 4))
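
For reviewers, a minimal usage sketch (not part of the patch) of the behavior the new TORCH_CHECK enforces, mirroring the added test: torch.nested.to_padded_tensor now raises a RuntimeError when every constituent tensor has zero numel, while nested tensors with at least one non-empty constituent pad as before. The expected error text below is taken from the patch; the example tensors are illustrative only.

import torch

# All constituents are empty, so the new check fires.
empty_nt = torch.nested.nested_tensor([torch.ones(1, 0), torch.ones(0, 0)])
try:
    torch.nested.to_padded_tensor(empty_nt, 0.0)
except RuntimeError as e:
    print(e)  # to_padded_tensor: at least one constituent tensor should have non-zero numel

# At least one non-empty constituent: padding behaves as before.
ok_nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(1, 3)])
print(torch.nested.to_padded_tensor(ok_nt, 0.0).shape)  # torch.Size([2, 2, 3])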