Skip to content

Commit

Permalink
Fix inductor sub with symbolic integers.
Browse files Browse the repository at this point in the history
Fix: #108159

ghstack-source-id: 0a7d24b4bdb0c9280a8ce31d38138b7817e0cfdf
Pull Request resolved: #108518
  • Loading branch information
ysiraichi committed Sep 4, 2023
1 parent db63bf3 commit e2a8984
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
2 changes: 0 additions & 2 deletions test/inductor/test_torchinductor_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,6 @@ def div(x):
test(div)

@onlyCPU
@unittest.expectedFailure
# Ref: https://github.com/pytorch/pytorch/issues/108159
def test_sub_constant_folding(self, device):
def sub(x):
return x - torch.zeros(3)
Expand Down
19 changes: 14 additions & 5 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,12 +377,21 @@ def inner(*inputs: List[TensorBox], alpha=None):
inputs[-1] = mul(inputs[-1], alpha)
else:
assert alpha is None

# Get the first TensorBox/ExpandView input as the reference tensor for
# sizes, data types, and device.
tensor_inputs = [
inp for inp in inputs if isinstance(inp, (TensorBox, ExpandView))
]
# Use the first tensor found. Otherwise, fallback to using the first argument.
ref = tensor_inputs[0] if len(tensor_inputs) > 0 else inputs[0]

loaders = [x.make_loader() for x in inputs]
ranges = inputs[0].get_size()
dtype = override_return_dtype or inputs[0].get_dtype()
is_cuda = decode_device(inputs[0].get_device()).type == "cuda"
ranges = ref.get_size()
dtype = override_return_dtype or ref.get_dtype()
is_cuda = decode_device(ref.get_device()).type == "cuda"

for other in inputs[1:]:
for other in inputs:
assert isinstance(other, ir.BaseConstant) or len(ranges) == len(
other.get_size()
), f"ndim mismatch {fn} {ranges} {other.get_size()}"
Expand All @@ -403,7 +412,7 @@ def inner_fn(index):
device = i.get_device()
break
if not device:
device = inputs[0].get_device()
device = ref.get_device()

device = override_device or device

Expand Down

0 comments on commit e2a8984

Please sign in to comment.