Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inductor] Support ReinterpretView in inductor codegen #113967

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1989,12 +1989,17 @@ def call_triton(x: torch.Tensor):

@requires_cuda()
@skipIfRocm
def test_triton_kernel_None_arg(self):
def test_triton_kernel_various_args(self):
@triton.autotune(
configs=[triton.Config({"BLOCK_SIZE": 128})],
key=[],
)
@triton.jit
def pass_kernel(
out_ptr,
dummy_None,
n_elements,
dummy_None,
dummy_empty,
BLOCK_SIZE: "tl.constexpr",
):
pass
Expand All @@ -2003,7 +2008,7 @@ def pass_kernel(
def call_triton(output):
n_elements = output.numel()
grid = (n_elements,)
pass_kernel[grid](output, None, n_elements, BLOCK_SIZE=16)
pass_kernel[grid](output, n_elements, None, torch.empty_like(output))
return output

output = torch.randn(5, device="cuda")
Expand Down
2 changes: 2 additions & 0 deletions torch/_inductor/codegen/triton_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def is_aligned(
https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222
"""
if isinstance(x, TensorArg):
if x.buffer.startswith("reinterpret_tensor"):
return False
if include_tensor:
return not V.graph.scheduler.is_unaligned_buffer(x.buffer)
else:
Expand Down
16 changes: 5 additions & 11 deletions torch/_inductor/codegen/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from .. import codecache, config, ir
from ..codecache import CudaKernelParamCache
from ..ir import ComputedBuffer, InputBuffer
from ..ir import ComputedBuffer, InputBuffer, ReinterpretView
from ..triton_heuristics import grid as default_grid
from ..utils import (
cache_on_self,
Expand Down Expand Up @@ -834,14 +834,12 @@ def define_kernel(
self.header.splice(f"\n\n{metadata_comment}{name} = {kernel}")

def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
from ..ir import Buffer

original_name = kernel.__name__

# Distinguish between different functions using function id
cache_key = [id(kernel.fn)]
for arg in kwargs.values():
if isinstance(arg, Buffer):
if isinstance(arg, (ir.Buffer, ir.ReinterpretView)):
cache_key.append(arg.get_dtype())
elif len(configs) > 0:
# We need to key on non tensor arg only in autotune mode
Expand Down Expand Up @@ -880,7 +878,7 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
):
constants[key] = arg
continue
if isinstance(arg, Buffer):
if isinstance(arg, (ir.Buffer, ir.ReinterpretView)):
signature.append(
TensorArg(key, arg.codegen_reference(), arg.get_dtype())
)
Expand Down Expand Up @@ -1041,9 +1039,7 @@ def __repr__(self):
return repr(type(s)(Shim(self.val_to_arg_str(a)) for a in s))
elif isinstance(s, torch._ops.OpOverload):
return _get_qualified_name(s)
elif isinstance(s, ComputedBuffer):
return s.codegen_reference()
elif isinstance(s, InputBuffer):
elif isinstance(s, (ComputedBuffer, InputBuffer, ReinterpretView)):
return s.codegen_reference()
else:
return repr(s)
Expand Down Expand Up @@ -2342,9 +2338,7 @@ def val_to_arg_str(self, val):
return f"{val}L"
elif isinstance(val, str):
return f'"{val}"'
elif isinstance(val, ComputedBuffer):
return val.codegen_reference()
elif isinstance(val, InputBuffer):
elif isinstance(val, (ComputedBuffer, InputBuffer, ReinterpretView)):
return val.codegen_reference()
elif isinstance(val, torch.device):
return self.codegen_device(val)
Expand Down
Loading