From 719a71f40eb50bec902db3f9069a64d785503aa7 Mon Sep 17 00:00:00 2001
From: Adnan Akhundov
Date: Sat, 20 Jan 2024 10:35:51 -0800
Subject: [PATCH] [inductor] Fix CPP wrapper codegen for ExternKernel args
 (#117931)

Summary: Recently, wrapper codegen for ExternKernel args passed to
user-defined Triton kernels was fixed for the Python wrapper in D52899373
(https://github.com/pytorch/pytorch/pull/117838). Here we extend the fix to
the CPP wrapper codegen / AOTInductor.

Test Plan: New unit tests. In OSS:

```
python test/inductor/test_aot_inductor.py -k test_triton_kernel_multi_output_arg
```

```
python test/inductor/test_aot_inductor.py -k test_triton_kernel_extern_kernel_arg
```

Differential Revision: D52936248
---
 test/inductor/test_aot_inductor.py | 38 ++++++++++++++++++++++++++++++
 torch/_inductor/codegen/wrapper.py |  6 ++---
 2 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index ea539804a9759..55b4e66bc5a62 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -1575,6 +1575,42 @@ def forward(self, a):
         model.weight += 1
         self.check_model(model, example_inputs)
 
+    def test_triton_kernel_extern_kernel_arg(self):
+        if self.device != "cuda":
+            raise unittest.SkipTest("requires CUDA")
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                out = torch.zeros_like(x)
+                # torch.mm is ExternKernelOut
+                add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16)
+                return out
+
+        example_inputs = (
+            torch.randn(4, 4, device="cuda"),
+            torch.randn(4, 4, device="cuda"),
+        )
+
+        self.check_model(Model(), example_inputs)
+
+    def test_triton_kernel_multi_output_arg(self):
+        if self.device != "cuda":
+            raise unittest.SkipTest("requires CUDA")
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                out = torch.zeros_like(x)
+                # torch.sort creates fallback kernel and hence MultiOutput
+                add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16)
+                return out
+
+        example_inputs = (
+            torch.randn(4, 4, device="cuda"),
+            torch.randn(4, 4, device="cuda"),
+        )
+
+        self.check_model(Model(), example_inputs)
+
 
 common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
 
@@ -1673,6 +1709,8 @@ def fail_abi_compatible_cuda(is_skip=False):
     "test_normal_functional": fail_abi_compatible_cuda(),
     # There is a double-free issue which will be fixed in another PR
     "test_repeat_output": fail_abi_compatible_cuda(is_skip=True),
+    # no ABI shim fn for torch.sort; remove this when adding one
+    "test_triton_kernel_multi_output_arg": fail_abi_compatible_cuda(),
 }
 
 if TEST_WITH_ROCM:
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 75ea609780db3..67a40ad145f40 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -23,7 +23,7 @@
 from .. import codecache, config, ir
 from ..codecache import CudaKernelParamCache
-from ..ir import ComputedBuffer, InputBuffer, ReinterpretView
+from ..ir import ReinterpretView
 from ..triton_heuristics import grid as default_grid
 from ..utils import (
     cache_on_self,
@@ -1147,7 +1147,7 @@ def __repr__(self):
             return repr(type(s)(Shim(self.val_to_arg_str(a)) for a in s))
         elif isinstance(s, torch._ops.OpOverload):
             return _get_qualified_name(s)
-        elif isinstance(s, (ir.Buffer, ComputedBuffer, InputBuffer, ReinterpretView)):
+        elif isinstance(s, (ir.Buffer, ReinterpretView)):
             return s.codegen_reference()
         else:
             return repr(s)
@@ -2788,7 +2788,7 @@ def val_to_arg_str(self, val) -> str:
             return f"{val}L"
         elif isinstance(val, str):
             return f'"{val}"'
-        elif isinstance(val, (ComputedBuffer, InputBuffer, ReinterpretView)):
+        elif isinstance(val, (ir.Buffer, ReinterpretView)):
             return val.codegen_reference()
         elif isinstance(val, torch.device):
             return self.codegen_device(val)
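
Note: the new tests call `add_kernel`, a user-defined Triton kernel that is
not defined in this diff (in the test suite it is imported from
`torch.testing._internal.triton_utils`). A minimal sketch of such a kernel,
assuming the argument layout visible in the calls above (two input pointers,
an output pointer, `n_elements`, and a constexpr `BLOCK_SIZE`), could look
like this; it is illustrative, not the verbatim helper:

```
# Sketch of an elementwise-add Triton kernel matching the call
# add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16) in the tests above.
# Assumed layout: (in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE).
import triton
import triton.language as tl


@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)  # one program instance per BLOCK_SIZE chunk
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard against the ragged tail
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)
```

The second argument in both tests (`torch.mm(x, y)` / `torch.sort(y).values`)
lowers to an ExternKernelOut / MultiOutput buffer. Since both are `ir.Buffer`
subclasses, widening the isinstance checks in `val_to_arg_str` to `ir.Buffer`
lets the CPP wrapper resolve such args via `codegen_reference()` as well.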