Fix inductor torch.cat edge case for empty tensor (#107193)
Align with eager behavior on this edge case: essentially, the empty tensor is ignored by the operator.
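For reference, a minimal eager-mode sketch of the edge case (shapes mirror the new tests below): eager torch.cat accepts a 1-D zero-element tensor alongside inputs of any other shape, and it contributes nothing to the result.

    import torch

    x = torch.randn(1, 3, 3, 16)
    print(torch.cat((x, torch.ones(0))).shape)     # torch.Size([1, 3, 3, 16]) - the empty input is skipped
    print(torch.cat((x, torch.ones(0), x)).shape)  # torch.Size([2, 3, 3, 16])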

Fixes #107118

Pull Request resolved: #107193
Approved by: https://github.com/wanchaol, https://github.com/eellison, https://github.com/peterbell10
wconstab authored and pytorchmergebot committed Aug 16, 2023
1 parent 7cb2a6b commit 1f6c1d9
Showing 3 changed files with 58 additions and 6 deletions.
31 changes: 31 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -3263,6 +3263,37 @@ def fn(a):
             (torch.randn([1, 3, 3, 16]).to(memory_format=torch.channels_last),),
         )
 
+    def test_cat_empty(self):
+        def fn_2(*tensors):
+            return torch.cat(tensors)
+
+        self.common(
+            fn_2,
+            (
+                torch.randn([1, 3, 3, 16]),
+                torch.ones([0]),
+            ),
+        )
+        self.common(
+            fn_2,
+            (
+                torch.randn([1, 3, 3, 16]),
+                torch.ones([0]),
+                torch.randn([1, 3, 3, 16]),
+            ),
+        )
+
+    @expectedFailureCodegenDynamic
+    def test_cat_single_empty(self):
+        # fails dynamic check for 'has a dynamic dimension'
+        def fn_2(*tensors):
+            return torch.cat(tensors)
+
+        self.common(
+            fn_2,
+            (torch.ones([0]),),
+        )
+
     def test_cat_upcasting(self):
         def fn(arg4_1, slice_7):
             cat_1 = aten.cat.default([arg4_1, slice_7], 1)
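A rough standalone equivalent of the check these tests perform, assuming the self.common harness compiles fn_2 with inductor (torch.compile's default backend) and compares the result against eager:

    import torch

    def fn_2(*tensors):
        return torch.cat(tensors)

    compiled_fn = torch.compile(fn_2)
    args = (torch.randn(1, 3, 3, 16), torch.ones(0))
    # the compiled result should match eager for the empty-tensor case this PR fixes
    torch.testing.assert_close(compiled_fn(*args), fn_2(*args))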
13 changes: 12 additions & 1 deletion torch/_inductor/decomposition.py
@@ -212,8 +212,19 @@ def mm(self, input2):
 
 @register_decomposition([aten.cat.default])
 def cat(tensors, dim=0):
-    if len(tensors) == 1:
+    def non_empty_tensor(x):
+        # special case for cat'ing with an empty tensor -
+        # just drop the 'empty' inputs so they don't confuse the logic below.
+        return len(x.shape) > 1 or x.shape[0] > 0
+
+    filtered_tensors = list(filter(non_empty_tensor, tensors))
+
+    if len(filtered_tensors) == 1:
         return tensors[0].clone()
+    elif 1 < len(filtered_tensors) < len(tensors):
+        # on the first call, when we remove empty tensors, we redispatch recursively
+        return aten.cat.default(filtered_tensors, dim)
+    # when no 'filtering' has occurred, we return NotImplemented to prevent infinite recursion (no more decomposition needed)
     return NotImplemented
 
 
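A standalone sketch of the three paths in the updated decomposition (illustrative only: torch.cat stands in for aten.cat.default, and the surviving tensor is indexed directly rather than via tensors[0]):

    import torch

    def cat_decomp_sketch(tensors, dim=0):
        # mirror of non_empty_tensor: drop 1-D zero-element inputs up front
        filtered = [t for t in tensors if len(t.shape) > 1 or t.shape[0] > 0]
        if len(filtered) == 1:
            return filtered[0].clone()       # a single survivor: cat degenerates to a copy
        if 1 < len(filtered) < len(tensors):
            return torch.cat(filtered, dim)  # empties dropped: redispatch on the remaining tensors
        return NotImplemented                # nothing filtered out: no more decomposition needed

    print(cat_decomp_sketch([torch.randn(1, 3), torch.ones(0)]).shape)                     # torch.Size([1, 3])
    print(cat_decomp_sketch([torch.randn(1, 3), torch.ones(0), torch.randn(1, 3)]).shape)  # torch.Size([2, 3])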
20 changes: 15 additions & 5 deletions torch/_inductor/fx_passes/split_cat.py
@@ -1,3 +1,4 @@
+import itertools
 import logging
 import operator
 from typing import Callable, List, Sequence, Tuple, Union
@@ -117,12 +118,21 @@ def normalize_cat_default(match: Match, *args, **kwargs):
         log.info("couldn't find cat args")
         return
     assert isinstance(tensors, (list, tuple))
-    if "example_value" not in tensors[0].meta:
-        log.warning("example value absent for node: %s", tensors[0])
-        return
+    for tensor in itertools.chain([cat_node], tensors):
+        if "example_value" not in tensor.meta:
+            log.warning("example value absent for node: %s", tensor)
+            return
 
-    ndim = tensors[0].meta["example_value"].dim()
-    assert all(ndim == x.meta["example_value"].dim() for x in tensors)
+    ndim = cat_node.meta["example_value"].dim()
+
+    def is_empty_tensor(x):
+        # special case where torch.cat supports cat'ing with an empty tensor
+        x_shape = x.meta["example_value"].shape
+        return len(x_shape) == 1 and x_shape[0] == 0
+
+    assert all(
+        ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors
+    )
 
     if cat_dim < 0:  # Normalize cat dim
         cat_dim += ndim
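A small illustration of the relaxed assertion in normalize_cat_default, using hypothetical shapes in place of the nodes' example_value metadata: a 1-D zero-length input no longer needs to match the cat output's number of dimensions.

    # hypothetical example_value shapes for cat_node (ndim 4) and its inputs
    cat_ndim = 4
    input_shapes = [(1, 3, 3, 16), (0,), (1, 3, 3, 16)]

    def is_empty_shape(shape):
        return len(shape) == 1 and shape[0] == 0

    # mirrors the new assert: every input either matches ndim or is the empty special case
    assert all(len(shape) == cat_ndim or is_empty_shape(shape) for shape in input_shapes)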
