[AOTI] Support ReinterpretView in abi mode (#114169)
#113967 added support for ReinterpretView, but it turns out it is codegen'd
differently in ABI compat mode. This PR adds support for ABI compat mode as
well.

Pull Request resolved: #114169
Approved by: https://github.com/aakhundov
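
For background: a reinterpreted view shares storage with its base tensor at an arbitrary element offset, so Triton's usual 16-byte pointer-alignment assumption cannot be made for it. A minimal sketch in plain PyTorch (illustrative, not AOTI codegen):

    import torch

    base = torch.randn(8)  # fresh allocations are typically 16-byte aligned
    view = base[1:]        # same storage, element offset 1
    # The view's data pointer is shifted by one float (4 bytes), so it is
    # generally not 16-byte aligned even though the base allocation is.
    assert view.storage_offset() == 1
    assert view.data_ptr() - base.data_ptr() == 4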
oulgen authored and pytorchmergebot committed Nov 21, 2023
1 parent b5dd37f commit ef90508
Showing 4 changed files with 34 additions and 4 deletions.
21 changes: 21 additions & 0 deletions test/inductor/test_aot_inductor.py
@@ -1231,6 +1231,27 @@ def forward(self, x):
             ]
         self.check_model(Model(), (a,), constraints=constraints)
 
+    def test_triton_kernel_reinterpret_view(self):
+        if self.device != "cuda":
+            raise unittest.SkipTest("requires CUDA")
+
+        @triton.jit
+        def pass_kernel(x, y):
+            pass
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # AOT export does not allow for input mutation
+                x = x.clone()
+                pass_kernel[(1,)](x, torch.empty_like(x))
+                return x
+
+        example_inputs = (torch.randn(4, device=self.device),)
+        self.check_model(Model(), example_inputs)
+
     def test_shifted_constraint_ranges(self):
         class Model(torch.nn.Module):
             def __init__(self):
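
Outside the test harness, the launch pattern the new test exercises looks like this (a standalone sketch; requires Triton and a CUDA device):

    import torch
    import triton

    @triton.jit
    def pass_kernel(x, y):
        pass  # no-op body: only argument handling/codegen is under test

    x = torch.randn(4, device="cuda")
    # kernel[grid](*args): a grid of one program instance, with a second
    # tensor argument to exercise another pointer parameter
    pass_kernel[(1,)](x, torch.empty_like(x))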
7 changes: 5 additions & 2 deletions torch/_inductor/codegen/common.py
@@ -48,7 +48,7 @@ def data_type_logger(msg):
     schedule_log.debug("Data type propagation: %s", msg)
 
 
-TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"])
+TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype", "check_alignment"])
 SizeArg = namedtuple("SizeArg", ["name", "expr"])
 
 DeviceCodegen = namedtuple("DeviceCodegen", ["scheduling", "wrapper_codegen"])
@@ -633,6 +633,7 @@ def python_argdefs(self):
                         inplaced.inner_name,
                         inplaced.other_names[-1],
                         V.graph.get_dtype(inplaced.other_names[-1]),
+                        True,
                     )
                 )
         for outer, inner in chain(
@@ -642,7 +643,9 @@ def python_argdefs(self):
                 continue
             arg_defs.append(inner)
             call_args.append(outer)
-            precompile_args.append(TensorArg(inner, outer, V.graph.get_dtype(outer)))
+            precompile_args.append(
+                TensorArg(inner, outer, V.graph.get_dtype(outer), True)
+            )
         for outer, inner in self.sizevars.items():
             arg_defs.append(inner)
             call_args.append(outer)
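
Note that namedtuple declares no field defaults here, so adding check_alignment forces every TensorArg construction site to pass the flag explicitly, which is why the remaining hunks all touch call sites. A minimal sketch of the constraint (field values are illustrative):

    from collections import namedtuple

    TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype", "check_alignment"])

    # TensorArg("in_ptr0", "buf0", "float32")  # would now raise TypeError
    arg = TensorArg("in_ptr0", "buf0", "float32", check_alignment=True)
    assert arg.check_alignment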
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/triton_utils.py
@@ -62,7 +62,7 @@ def is_aligned(
     https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222
     """
     if isinstance(x, TensorArg):
-        if x.buffer.startswith("reinterpret_tensor"):
+        if not x.check_alignment:
             return False
         if include_tensor:
             return not V.graph.scheduler.is_unaligned_buffer(x.buffer)
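
The old check keyed on the Python wrapper's spelling of a view reference, reinterpret_tensor(...), which ABI-compat codegen does not produce; carrying an explicit flag on the argument makes the decision independent of how the reference is spelled. A simplified sketch of the resulting predicate (SizeArg handling and scheduler state elided):

    # Simplified restatement of the decision; the real function also
    # handles SizeArg and consults the scheduler for unaligned buffers.
    def is_aligned(x, include_tensor):
        if isinstance(x, TensorArg):
            if not x.check_alignment:  # e.g. ReinterpretView arguments
                return False
            return include_tensor  # real code: scheduler alignment lookup
        return False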
8 changes: 7 additions & 1 deletion torch/_inductor/codegen/wrapper.py
@@ -882,7 +882,13 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
                 continue
             if isinstance(arg, (ir.Buffer, ir.ReinterpretView)):
                 signature.append(
-                    TensorArg(key, arg.codegen_reference(), arg.get_dtype())
+                    TensorArg(
+                        key,
+                        arg.codegen_reference(),
+                        arg.get_dtype(),
+                        # For ReinterpretView, we do not want to check alignment
+                        not isinstance(arg, ReinterpretView),
+                    )
                 )
             else:
                 signature.append(SizeArg(key, arg))
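
Taken together: a plain buffer argument keeps the alignment check, while a ReinterpretView opts out. A hedged illustration with made-up buffer names, reusing the namedtuple from common.py:

    from collections import namedtuple
    TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype", "check_alignment"])

    # Plain buffer argument: alignment may still be checked.
    buf_arg = TensorArg("in_ptr0", "buf0", "float32", True)
    # ReinterpretView argument: the reference may carry a storage offset and
    # is spelled differently under ABI-compat codegen, so skip the check.
    view_arg = TensorArg("in_ptr1", "reinterpret_tensor(buf1, (4,), (1,), 1)", "float32", False)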
