Skip to content

Commit

Permalink
[inductor] Fix CPP wrapper codegen for ExternKernel args (#117931)
Browse files Browse the repository at this point in the history
Summary:

Recently, this has been fixed for the Python wrapper codegen in D52899373 (#117838). Here we extend the fix to CPP wrapper codegen / AOTInductor.

Test Plan:
New unit tests. In OSS:

```
python test/inductor/test_aot_inductor.py -k test_triton_kernel_multi_output_arg
```

```
python test/inductor/test_aot_inductor.py -k test_triton_kernel_extern_kernel_arg
```

Differential Revision: D52936248
  • Loading branch information
aakhundov authored and facebook-github-bot committed Jan 20, 2024
1 parent 5063362 commit 719a71f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
38 changes: 38 additions & 0 deletions test/inductor/test_aot_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,6 +1575,42 @@ def forward(self, a):
model.weight += 1
self.check_model(model, example_inputs)

def test_triton_kernel_extern_kernel_arg(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")

class Model(torch.nn.Module):
def forward(self, x, y):
out = torch.zeros_like(x)
# torch.mm is ExternKernelOut
add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16)
return out

example_inputs = (
torch.randn(4, 4, device="cuda"),
torch.randn(4, 4, device="cuda"),
)

self.check_model(Model(), example_inputs)

def test_triton_kernel_multi_output_arg(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")

class Model(torch.nn.Module):
def forward(self, x, y):
out = torch.zeros_like(x)
# torch.sort creates fallback kernel and hence MultiOutput
add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16)
return out

example_inputs = (
torch.randn(4, 4, device="cuda"),
torch.randn(4, 4, device="cuda"),
)

self.check_model(Model(), example_inputs)


common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)

Expand Down Expand Up @@ -1673,6 +1709,8 @@ def fail_abi_compatible_cuda(is_skip=False):
"test_normal_functional": fail_abi_compatible_cuda(),
# There is a double-free issue which will be fixed in another PR
"test_repeat_output": fail_abi_compatible_cuda(is_skip=True),
# no ABI shim fn for torch.sort; remove this when adding one
"test_triton_kernel_multi_output_arg": fail_abi_compatible_cuda(),
}

if TEST_WITH_ROCM:
Expand Down
6 changes: 3 additions & 3 deletions torch/_inductor/codegen/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from .. import codecache, config, ir
from ..codecache import CudaKernelParamCache
from ..ir import ComputedBuffer, InputBuffer, ReinterpretView
from ..ir import ReinterpretView
from ..triton_heuristics import grid as default_grid
from ..utils import (
cache_on_self,
Expand Down Expand Up @@ -1147,7 +1147,7 @@ def __repr__(self):
return repr(type(s)(Shim(self.val_to_arg_str(a)) for a in s))
elif isinstance(s, torch._ops.OpOverload):
return _get_qualified_name(s)
elif isinstance(s, (ir.Buffer, ComputedBuffer, InputBuffer, ReinterpretView)):
elif isinstance(s, (ir.Buffer, ReinterpretView)):
return s.codegen_reference()
else:
return repr(s)
Expand Down Expand Up @@ -2788,7 +2788,7 @@ def val_to_arg_str(self, val) -> str:
return f"{val}L"
elif isinstance(val, str):
return f'"{val}"'
elif isinstance(val, (ComputedBuffer, InputBuffer, ReinterpretView)):
elif isinstance(val, (ir.Buffer, ReinterpretView)):
return val.codegen_reference()
elif isinstance(val, torch.device):
return self.codegen_device(val)
Expand Down

0 comments on commit 719a71f

Please sign in to comment.