Do not generate zero-numel NT by default in helper and improve to_padded_tensor msg #113162

Closed · wants to merge 3 commits
Changes from 2 commits
1 change: 1 addition & 0 deletions aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -245,6 +245,7 @@ Tensor NestedTensor_to_padded_tensor_generic(
     const Tensor& t,
     double padding,
     OptionalIntArrayRef output_size) {
+  TORCH_CHECK(t.numel() > 0, "to_padded_tensor only supports tensors with non-zero numel");
   // TODO: support noncontiguous case
   // error out for now
   TORCH_CHECK(
@@ -140,6 +140,7 @@ Tensor NestedTensor_to_padded_tensor_cuda(
     const Tensor& t,
     double padding,
     OptionalIntArrayRef output_size) {
+  TORCH_CHECK(t.numel() > 0, "to_padded_tensor only supports tensors with non-zero numel");
   int64_t t_dim = t.dim();
   if (t_dim >= 2 && t_dim <= 4 &&
       (t.dtype() == at::kFloat || t.dtype() == at::kDouble ||
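Taken together, the two checks make both the generic and the CUDA path fail fast when every constituent of a nested tensor is empty, instead of silently producing a degenerate padded tensor. A minimal Python-level sketch of the resulting behavior (assuming a build that contains this change; shapes are illustrative):

import torch

# Every constituent has zero numel, so the nested tensor's total numel is 0
# and the new check fires.
nt_empty = torch.nested.nested_tensor([torch.ones(1, 0), torch.ones(0, 0)])
try:
    torch.nested.to_padded_tensor(nt_empty, 0.0)
except RuntimeError as e:
    print(e)  # to_padded_tensor only supports tensors with non-zero numel

# A single non-empty constituent gives a positive total numel, so the check
# passes and padding behaves as before.
nt_ok = torch.nested.nested_tensor([torch.ones(2, 3), torch.ones(1, 3)])
print(torch.nested.to_padded_tensor(nt_ok, 0.0).shape)  # torch.Size([2, 2, 3])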
32 changes: 25 additions & 7 deletions test/test_nestedtensor.py
@@ -96,15 +96,23 @@ def noncontiguous_to_padded_tensor(input, shape=None):
 # Helper function to generate a random nested tensor


-def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch.strided):
+def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch.strided, must_non_empty=True):
     if min_dims is None:
         min_dims = tuple([0] * len(max_dims))
-    ts1 = []
-    for _ in range(num_tensors):
-        tensor_dims = tuple([torch.randint(low=min_dim, high=max_dim, size=(1,)).item()
-                             for (min_dim, max_dim) in zip(min_dims, max_dims)])
-        t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
-        ts1.append(t1)
+    has_non_empty = False
+    while not has_non_empty:
+        # Repeat until we have at least one non-empty tensor
+        ts1 = []
+        for _ in range(num_tensors):
+            tensor_dims = tuple([torch.randint(low=min_dim, high=max_dim, size=(1,)).item()
+                                 for (min_dim, max_dim) in zip(min_dims, max_dims)])
+            t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
+            if t1.numel() > 0:
+                has_non_empty = True
+            ts1.append(t1)
+        # Execute only once if must_non_empty is False
+        if not must_non_empty:
+            break
     return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout)

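With must_non_empty defaulting to True, tests that use this helper can no longer be handed an all-empty nested tensor by chance, while must_non_empty=False restores the old single-pass behavior. A hypothetical usage sketch, assuming the helper above is in scope (device and dtype arguments are illustrative):

# Default: regenerate until at least one constituent is non-empty.
nt = random_nt("cpu", torch.float, num_tensors=4, max_dims=(4, 4))
assert nt.numel() > 0

# Opt out for tests that deliberately exercise possibly-all-empty inputs.
nt_maybe_empty = random_nt("cpu", torch.float, 4, (4, 4), must_non_empty=False)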
@@ -1898,6 +1906,16 @@ def test_linear_noncontiguous(self, device, dtype):
             lambda: torch.nn.functional.linear(nt_noncontiguous, weight)
         )

+    @dtypes(torch.float, torch.float16, torch.double)
+    def test_to_padded_tensor_zero_numel_errors(self, device, dtype):
+        ts = [torch.ones(1, 0), torch.ones(0, 0)]
+        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype, layout=torch.strided)
+        self.assertRaisesRegex(
+            RuntimeError,
+            r"to_padded_tensor only supports tensors with non-zero numel",
+            lambda: torch.nested.to_padded_tensor(nt, 0.0)
+        )
+
     @dtypes(torch.float, torch.float16, torch.double)
     def test_transpose(self, device, dtype):
         nt = random_nt(device, dtype, 4, (4, 4))