Do not generate zero-numel NT by default in helper and improve to_padded_tensor msg

ghstack-source-id: 53b61ff106d3fbee31853a32d06b0f3599fc5a61
Pull Request resolved: #113162
soulitzer committed Nov 7, 2023
1 parent 3b928cb commit 0e389e2
Showing 3 changed files with 28 additions and 7 deletions.
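
Note on observable behavior: the new TORCH_CHECK surfaces in Python as a RuntimeError. A minimal repro sketch, assuming a build that includes this commit (it mirrors the test added below):

    import torch

    # Both constituents have zero numel, so the nested tensor's total numel is 0.
    ts = [torch.ones(1, 0), torch.ones(0, 0)]
    nt = torch.nested.nested_tensor(ts)

    # Raises RuntimeError: to_padded_tensor only supports tensors with non-zero numel
    torch.nested.to_padded_tensor(nt, 0.0)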
1 change: 1 addition & 0 deletions aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -245,6 +245,7 @@ Tensor NestedTensor_to_padded_tensor_generic(
     const Tensor& t,
     double padding,
     OptionalIntArrayRef output_size) {
+  TORCH_CHECK(t.numel() > 0, "to_padded_tensor only supports tensors with non-zero numel");
   // TODO: support noncontiguous case
   // error out for now
   TORCH_CHECK(
@@ -140,6 +140,7 @@ Tensor NestedTensor_to_padded_tensor_cuda(
     const Tensor& t,
     double padding,
     OptionalIntArrayRef output_size) {
+  TORCH_CHECK(t.numel() > 0, "to_padded_tensor only supports tensors with non-zero numel");
   int64_t t_dim = t.dim();
   if (t_dim >= 2 && t_dim <= 4 &&
       (t.dtype() == at::kFloat || t.dtype() == at::kDouble ||
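
The same guard lands at the top of the CUDA entry point, ahead of the fast-path dispatch, so the error should be identical across backends. A hedged sketch, assuming a CUDA-capable build:

    import torch

    if torch.cuda.is_available():
        nt = torch.nested.nested_tensor([torch.ones(0, 0)], device="cuda", dtype=torch.float)
        # Expected to raise the same RuntimeError as on CPU
        torch.nested.to_padded_tensor(nt, 0.0)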
33 changes: 26 additions & 7 deletions test/test_nestedtensor.py
@@ -96,15 +96,24 @@ def noncontiguous_to_padded_tensor(input, shape=None):
 # Helper function to generate a random nested tensor
 
 
-def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch.strided):
+def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch.strided, must_non_empty=True):
     if min_dims is None:
         min_dims = tuple([0] * len(max_dims))
-    ts1 = []
-    for _ in range(num_tensors):
-        tensor_dims = tuple([torch.randint(low=min_dim, high=max_dim, size=(1,)).item()
-                             for (min_dim, max_dim) in zip(min_dims, max_dims)])
-        t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
-        ts1.append(t1)
+    has_non_empty = False
+    while not has_non_empty:
+        # Repeat until we have at least one non-empty tensor
+        ts1 = []
+        has_non_empty = False
+        for _ in range(num_tensors):
+            tensor_dims = tuple([torch.randint(low=min_dim, high=max_dim, size=(1,)).item()
+                                 for (min_dim, max_dim) in zip(min_dims, max_dims)])
+            t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
+            if t1.numel() > 0:
+                has_non_empty = True
+            ts1.append(t1)
+        # Execute only once if must_non_empty is False
+        if not must_non_empty:
+            break
     return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout)
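
Why the retry loop: min_dims defaults to all zeros, so torch.randint can draw a size of 0 for any dimension, making that constituent zero-numel. Since the new check in to_padded_tensor is on the nested tensor's total numel, one non-empty constituent suffices. A small standalone illustration (the numel arithmetic is an assumption about nested-tensor semantics, not taken from this diff):

    import torch

    t = torch.randn((0, 4))  # a 0-size dimension makes the whole tensor zero-numel
    assert t.numel() == 0

    # Total numel sums over constituents: 0 + 8 = 8, which passes the new check.
    nt = torch.nested.nested_tensor([t, torch.randn((2, 4))])
    assert nt.numel() == 8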


@@ -1898,6 +1907,16 @@ def test_linear_noncontiguous(self, device, dtype):
             lambda: torch.nn.functional.linear(nt_noncontiguous, weight)
         )
 
+    @dtypes(torch.float, torch.float16, torch.double)
+    def test_to_padded_tensor_zero_numel_errors(self, device, dtype):
+        ts = [torch.ones(1, 0), torch.ones(0, 0)]
+        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype, layout=torch.strided)
+        self.assertRaisesRegex(
+            RuntimeError,
+            r"to_padded_tensor only supports tensors with non-zero numel",
+            lambda: torch.nested.to_padded_tensor(nt, 0.0)
+        )
+
     @dtypes(torch.float, torch.float16, torch.double)
     def test_transpose(self, device, dtype):
         nt = random_nt(device, dtype, 4, (4, 4))
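
For contrast with the error case exercised above, a usage sketch of the supported path: to_padded_tensor pads each ragged dimension to its maximum size across constituents, filling with the given value:

    import torch

    ts = [torch.ones(2, 3), torch.ones(3, 1)]
    nt = torch.nested.nested_tensor(ts)
    padded = torch.nested.to_padded_tensor(nt, 0.0)
    print(padded.shape)  # torch.Size([2, 3, 3]); shorter constituents are filled with 0.0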
