Skip to content

Commit

Permalink
[AOTI] Support ReinterpretView in abi mode
Browse files Browse the repository at this point in the history
#113967 added support for
ReinterpretView, but it turns out we codegen it differently in ABI
compat mode. This PR adds support for ABI compat mode as well.

ghstack-source-id: f435abbdd5497343124c4c779f35743d87ca1dfd
Pull Request resolved: #114169
  • Loading branch information
oulgen committed Nov 20, 2023
1 parent 100b995 commit bf821f4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
21 changes: 21 additions & 0 deletions test/inductor/test_aot_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,27 @@ def forward(self, x):
]
self.check_model(Model(), (a,), constraints=constraints)

def test_triton_kernel_reinterpret_view(self):
    """Check AOT Inductor handles a user-defined Triton kernel whose
    argument is materialized as a reinterpret view / RAII tensor handle."""
    # User-defined Triton kernels only run on CUDA devices.
    if self.device != "cuda":
        raise unittest.SkipTest("requires CUDA")

    @triton.jit
    def pass_kernel(x, y):
        # No-op kernel: we only care that codegen for the launch succeeds.
        pass

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            # Clone first: AOT export does not allow mutating graph inputs.
            cloned = x.clone()
            # Single-program launch; second arg exercises the temp-buffer path.
            pass_kernel[(1,)](cloned, torch.empty_like(cloned))
            return cloned

    inputs = (torch.randn(4, device=self.device),)
    self.check_model(Model(), inputs)


# Expand the shared test template into concrete, parametrized test cases
# (e.g. per-device variants) registered with the test framework.
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)

Expand Down
4 changes: 3 additions & 1 deletion torch/_inductor/codegen/triton_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def is_aligned(
https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222
"""
if isinstance(x, TensorArg):
if x.buffer.startswith("reinterpret_tensor"):
if x.buffer.startswith("reinterpret_tensor") or x.buffer.startswith(
"RAIIAtenTensorHandle"
):
return False
if include_tensor:
return not V.graph.scheduler.is_unaligned_buffer(x.buffer)
Expand Down

0 comments on commit bf821f4

Please sign in to comment.