From bf821f4f3fb6df2077ff08a5c3a3369550deb3a1 Mon Sep 17 00:00:00 2001
From: Oguz Ulgen
Date: Mon, 20 Nov 2023 13:19:02 -0800
Subject: [PATCH] [AOTI] Support ReinterpretView in abi mode

https://github.com/pytorch/pytorch/pull/113967 added support for
ReinterpretView, but it turns out we codegen it differently in ABI-compat
mode. This PR adds support for ABI-compat mode as well.

ghstack-source-id: f435abbdd5497343124c4c779f35743d87ca1dfd
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114169
---
 test/inductor/test_aot_inductor.py      | 21 +++++++++++++++++++++
 torch/_inductor/codegen/triton_utils.py |  4 +++-
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index fa7411dd5cf2c..381f79bce3dcf 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -1205,6 +1205,27 @@ def forward(self, x):
         ]
         self.check_model(Model(), (a,), constraints=constraints)
 
+    def test_triton_kernel_reinterpret_view(self):
+        if self.device != "cuda":
+            raise unittest.SkipTest("requires CUDA")
+
+        @triton.jit
+        def pass_kernel(x, y):
+            pass
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # AOT export does not allow for input mutation
+                x = x.clone()
+                pass_kernel[(1,)](x, torch.empty_like(x))
+                return x
+
+        example_inputs = (torch.randn(4, device=self.device),)
+        self.check_model(Model(), example_inputs)
+
 
 common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
 
diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
index 87490fdae3569..e0d66ba1dc7fc 100644
--- a/torch/_inductor/codegen/triton_utils.py
+++ b/torch/_inductor/codegen/triton_utils.py
@@ -60,7 +60,9 @@ def is_aligned(
     https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222
     """
     if isinstance(x, TensorArg):
-        if x.buffer.startswith("reinterpret_tensor"):
+        if x.buffer.startswith("reinterpret_tensor") or x.buffer.startswith(
+            "RAIIAtenTensorHandle"
+        ):
             return False
         if include_tensor:
             return not V.graph.scheduler.is_unaligned_buffer(x.buffer)
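
Note (illustration, not part of the patch): in ABI-compat mode the generated
C++ wrapper names a view argument with an owning-handle expression, so the
buffer string seen by is_aligned() can begin with "RAIIAtenTensorHandle"
rather than "reinterpret_tensor". Either prefix marks a view whose storage
offset is not statically known, so the argument must be reported as
unaligned and Triton's divisibility-by-16 specialization skipped for it.
Below is a minimal standalone sketch of the extended prefix check; the
helper name buffer_is_possibly_unaligned is hypothetical, not an Inductor
API:

    def buffer_is_possibly_unaligned(buffer_name: str) -> bool:
        # Hypothetical helper, for illustration only. A view may start at
        # an arbitrary byte offset into its base buffer, so 16-byte
        # alignment cannot be assumed for either codegen spelling.
        return buffer_name.startswith(
            ("reinterpret_tensor", "RAIIAtenTensorHandle")
        )

    assert buffer_is_possibly_unaligned("reinterpret_tensor(buf0, (4,), (1,), 0)")
    assert buffer_is_possibly_unaligned("RAIIAtenTensorHandle(buf2)")
    assert not buffer_is_possibly_unaligned("buf3")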