[Inductor] Add triton.autotune support for user defined triton kernels with constant/simple grids #112228

Closed · wants to merge 4 commits (changes shown from 2 commits)
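As a quick illustration of what this PR enables, here is a minimal sketch, mirroring the new test added below (the wrapper function and tensor setup are illustrative), of a user-defined kernel decorated with @triton.autotune being launched with a constant grid from code compiled by torch.compile:

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
    ],
    key=[],
)
@triton.jit
def add_kernel_autotuned(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load both inputs, add them, and store the result.
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
    output = torch.zeros_like(x)
    n_elements = output.numel()
    # Constant/simple grid, the case covered by this PR; BLOCK_SIZE is chosen by the autotuner.
    add_kernel_autotuned[(n_elements,)](x, y, output, n_elements)
    return output

compiled_add = torch.compile(add, backend="inductor", fullgraph=True)
# On CUDA tensors, compiled_add(t1, t2) matches t1 + t2, as the new test below asserts.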
43 changes: 43 additions & 0 deletions test/dynamo/test_functions.py
@@ -1402,6 +1402,30 @@ def add_kernel(
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)

@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
],
key=[],
)
@triton.jit
def add_kernel_autotuned(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)

@triton.jit
def mul2_kernel(
in_ptr0,
@@ -1987,6 +2011,25 @@ def call_triton(
# reset back
CONSTANT_C = prev_c

@requires_cuda()
@requires_triton()
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_triton_kernel_autotune(self, grad, backend):
def call_triton(x: torch.Tensor, y: torch.Tensor):
output = torch.zeros_like(x, requires_grad=grad)
n_elements = output.numel()
grid = (n_elements,)
add_kernel_autotuned[grid](x, y, output, n_elements)
return output

t1 = torch.rand(5, device="cuda", requires_grad=grad)
t2 = torch.rand(5, device="cuda", requires_grad=grad)

torch_add = t1 + t2
compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
self.assertEqual(compiled_func(t1, t2), torch_add)

@requires_cuda()
@requires_triton()
@common_utils.parametrize("grad", [False, True])
6 changes: 5 additions & 1 deletion torch/_dynamo/variables/builder.py
@@ -362,12 +362,16 @@ def _wrap(self, value):
from torch.utils._triton import has_triton

if has_triton():
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction
else:

class JITFunction:
pass

class Autotuner:
pass

make_guards = self.make_guards

# Handle exact type() match
@@ -716,7 +720,7 @@ def index_source(key):
sym_node_proxy,
new_symint == 1,
)
elif isinstance(value, JITFunction):
elif isinstance(value, (JITFunction, Autotuner)):
return TritonKernelVariable(
value,
None, # No kernel idx provided
17 changes: 16 additions & 1 deletion torch/_dynamo/variables/functions.py
@@ -652,10 +652,12 @@ def get_val(v):

class TritonKernelVariable(VariableTracker):
def __init__(self, kernel, kernel_idx, grid, **kwargs):
super().__init__(**kwargs)
from triton.runtime.autotuner import Autotuner

from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

super().__init__(**kwargs)

assert kernel is not None

self.kernel = kernel
@@ -665,6 +667,19 @@ def __init__(self, kernel, kernel_idx, grid, **kwargs):

self.grid = grid

if isinstance(kernel, Autotuner):
# We only support configs and keys arguments of triton.autotune
# Make sure other arguments are defaulted
defaults = inspect.signature(Autotuner).parameters
if (
defaults["warmup"].default != kernel.warmup
or defaults["rep"].default != kernel.rep
or defaults["prune_configs_by"].default != kernel.early_config_prune
):
raise Unsupported(
"Only configs and keys are supported for triton.autotune"
)

def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
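For contrast with the restriction enforced above, a hedged sketch of an autotune usage that the new check rejects (the kernel is illustrative, and the warmup/rep defaults are assumed to be those of the installed Triton): passing anything beyond configs and key, such as custom benchmarking settings, raises Unsupported.

import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8)],
    key=[],
    warmup=100,  # non-default warmup/rep are not supported under torch.compile
    rep=200,
)
@triton.jit
def add_kernel_custom_bench(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(out_ptr + offsets, tl.load(in_ptr0 + offsets, mask=mask) + tl.load(in_ptr1 + offsets, mask=mask), mask=mask)

# Tracing a call to add_kernel_custom_bench under torch.compile hits
# Unsupported("Only configs and keys are supported for triton.autotune").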
20 changes: 14 additions & 6 deletions torch/_inductor/codegen/wrapper.py
@@ -777,7 +777,7 @@ def get_unique_kernel_name(self, name: str) -> str:
self.user_defined_kernel_count += 1
return new_name

def define_user_defined_triton_kernel(self, name, kernel, kwargs):
def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs):
original_name = kernel.__name__
compile_wrapper = IndentedBuffer()
compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''")
@@ -787,7 +787,7 @@ def define_user_defined_triton_kernel(self, name, kernel, kwargs):
import triton
import triton.language as tl
from torch._inductor.utils import instance_descriptor
from torch._inductor.triton_heuristics import template
from torch._inductor.triton_heuristics import user_autotune
""",
strip=True,
)
@@ -828,12 +828,20 @@ def define_user_defined_triton_kernel(self, name, kernel, kwargs):
"configs": [config_of(signature)],
"kernel_name": name,
}
configs = [
{
"kwargs": config.kwargs,
"num_warps": config.num_warps,
"num_stages": config.num_stages,
}
for config in configs
]
compile_wrapper.splice(
f"""
@template(
num_stages={num_stages},
num_warps={num_warps},
meta={triton_meta!r}
@user_autotune(
configs={configs!r},
meta={triton_meta!r},
filename=__file__
)
@triton.jit
"""
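To make the splice above concrete: with the two configs from the new test, the generated wrapper would decorate the kernel source roughly as follows (a sketch; the triton_meta contents are elided, and the suffixed kernel name is assumed to come from get_unique_kernel_name):

from torch._inductor.triton_heuristics import user_autotune
import triton
import triton.language as tl

triton_meta = {...}  # signature/device/constants/configs/kernel_name, as assembled above

@user_autotune(
    configs=[
        {"kwargs": {"BLOCK_SIZE": 128}, "num_warps": 8, "num_stages": 3},
        {"kwargs": {"BLOCK_SIZE": 64}, "num_warps": 8, "num_stages": 3},
    ],
    meta=triton_meta,
    filename=__file__,
)
@triton.jit
def add_kernel_autotuned_0(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    ...  # kernel body as written by the user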
10 changes: 9 additions & 1 deletion torch/_inductor/ir.py
@@ -3768,9 +3768,15 @@ def apply_constraint(self):

class UserDefinedTritonKernel(ExternKernel):
def codegen(self, wrapper):
from triton.runtime.autotuner import Autotuner

from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

kernel = kernel_side_table.get_kernel(self.kernel_idx)
configs = []
if isinstance(kernel, Autotuner):
configs = kernel.configs
kernel = kernel.fn
new_name = wrapper.get_unique_kernel_name(kernel.__name__)

self.codegen_comment(wrapper)
@@ -3779,7 +3785,9 @@ def codegen(self, wrapper):
self.grid,
self.codegen_kwargs(),
)
wrapper.define_user_defined_triton_kernel(new_name, kernel, self.kwargs)
wrapper.define_user_defined_triton_kernel(
new_name, kernel, configs, self.kwargs
)

def should_allocate(self):
return False
60 changes: 53 additions & 7 deletions torch/_inductor/triton_heuristics.py
@@ -12,7 +12,7 @@
import re
import threading
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Set, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import torch

@@ -62,6 +62,7 @@ class HeuristicType(Enum):
REDUCTION = auto()
PERSISTENT_REDUCTION = auto()
TEMPLATE = auto()
USER_AUTOTUNE = auto()


class AutotuneHint(Enum):
@@ -344,7 +345,7 @@ def launcher({', '.join(def_args)}, grid, stream):

return binary, launcher

def bench(self, launcher, *args, grid):
def bench(self, launcher, *args, grid, **kwargs):
"""Measure the performance of a given launcher"""
if launcher.n_spills > config.triton.spill_threshold:
log.debug(
@@ -362,16 +363,17 @@ def kernel_call():
{**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
Comment from the PR author on this line:
Hmm, I'm not really sure what this pre_hook stuff is. I assume I need to put **kwargs here too, how to trigger/test this?

)

cloned_args = self.clone_args(*args)
cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
launcher(
*cloned_args,
**cloned_kwargs,
grid=grid,
stream=stream,
)

return do_bench(kernel_call, rep=40, fast_flush=True)

def clone_args(self, *args):
def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
from .compile_fx import clone_preserve_strides

# clone inplace buffers to avoid autotune contaminating them if
Expand All @@ -385,7 +387,15 @@ def clone_args(self, *args):
else:
cloned_args.append(arg)

return cloned_args
cloned_kwargs: Dict[str, Any] = {}
for name, arg in kwargs.items():
if name in self.mutated_arg_names:
assert isinstance(arg, torch.Tensor)
cloned_kwargs[name] = clone_preserve_strides(arg)
else:
cloned_kwargs[name] = arg

return cloned_args, cloned_kwargs

@dynamo_timed
def benchmark_all_configs(self, *args, **kwargs):
@@ -451,11 +461,14 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs):
Then if coordinate descent tuning is run with max-autotune disabled, it will start from C1;
while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
"""
if self.heuristic_type == HeuristicType.TEMPLATE:
if (
self.heuristic_type == HeuristicType.TEMPLATE
or self.heuristic_type == HeuristicType.USER_AUTOTUNE
):
# skip triton templates and user-defined autotuned kernels
return launcher

cloned_args = self.clone_args(*args)
cloned_args, _ = self.clone_args(*args)
config2launcher = {launcher.config: launcher}

def benchmark_one_config(config):
@@ -1130,6 +1143,39 @@ def template(num_stages, num_warps, meta, filename=None):
)


def user_autotune(configs, meta, filename=None):
"""
Compile a user defined triton kernel
"""
defaults = inspect.signature(triton.Config).parameters
default_num_stages = defaults["num_stages"].default
default_num_warps = defaults["num_warps"].default

if len(configs) == 0:
configs = [
triton.Config(
{}, num_stages=default_num_stages, num_warps=default_num_warps
)
]
else:
configs = [
triton.Config(
c.get("kwargs", {}),
num_stages=c.get("num_stages", default_num_stages),
num_warps=c.get("num_warps", default_num_warps),
)
for c in configs
]

return cached_autotune(
None,
configs,
meta=meta,
heuristic_type=HeuristicType.USER_AUTOTUNE,
filename=filename,
)


def foreach(meta, num_warps, filename=None):
"""
Compile a triton foreach kernel