From c30fb6c08c3b26411c813228d263300dbdf2e779 Mon Sep 17 00:00:00 2001
From: Oguz Ulgen
Date: Thu, 26 Oct 2023 19:34:33 -0700
Subject: [PATCH 1/3] [Inductor] Add triton.autotune support for user defined triton kernels with constant/simple grids

[ghstack-poisoned]
---
 test/dynamo/test_functions.py        | 43 +++++++++++++++++++++
 torch/_dynamo/variables/builder.py   |  6 ++-
 torch/_dynamo/variables/functions.py | 17 +++++++-
 torch/_inductor/codegen/wrapper.py   | 20 +++++++---
 torch/_inductor/ir.py                | 10 ++++-
 torch/_inductor/triton_heuristics.py | 58 +++++++++++++++++++++++++---
 6 files changed, 139 insertions(+), 15 deletions(-)

diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index 8837c4cc13985..f19b689783daf 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -1402,6 +1402,30 @@ def add_kernel(
         output = x + y
         tl.store(out_ptr + offsets, output, mask=mask)
 
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def add_kernel_autotuned(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
     @triton.jit
     def mul2_kernel(
         in_ptr0,
@@ -1987,6 +2011,25 @@ def call_triton(
         # reset back
         CONSTANT_C = prev_c
 
+    @requires_cuda()
+    @requires_triton()
+    @common_utils.parametrize("grad", [False, True])
+    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
+    def test_triton_kernel_autotune(self, grad, backend):
+        def call_triton(x: torch.Tensor, y: torch.Tensor):
+            output = torch.zeros_like(x, requires_grad=grad)
+            n_elements = output.numel()
+            grid = (n_elements,)
+            add_kernel_autotuned[grid](x, y, output, n_elements)
+            return output
+
+        t1 = torch.rand(5, device="cuda", requires_grad=grad)
+        t2 = torch.rand(5, device="cuda", requires_grad=grad)
+
+        torch_add = t1 + t2
+        compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
+        self.assertEqual(compiled_func(t1, t2), torch_add)
+
     @requires_cuda()
     @requires_triton()
     @common_utils.parametrize("grad", [False, True])
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 64aadc9157a46..74c8974f3d000 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -362,12 +362,16 @@ def _wrap(self, value):
         from torch.utils._triton import has_triton
 
         if has_triton():
+            from triton.runtime.autotuner import Autotuner
             from triton.runtime.jit import JITFunction
         else:
 
             class JITFunction:
                 pass
 
+            class Autotuner:
+                pass
+
         make_guards = self.make_guards
 
         # Handle exact type() match
@@ -716,7 +720,7 @@ def index_source(key):
                 sym_node_proxy,
                 new_symint == 1,
             )
-        elif isinstance(value, JITFunction):
+        elif isinstance(value, (JITFunction, Autotuner)):
             return TritonKernelVariable(
                 value,
                 None,  # No kernel idx provided
diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py
index 1bd06a0727ab1..138b82aeed741 100644
--- a/torch/_dynamo/variables/functions.py
+++ b/torch/_dynamo/variables/functions.py
@@ -652,10 +652,12 @@ def get_val(v):
 
 class TritonKernelVariable(VariableTracker):
     def __init__(self, kernel, kernel_idx, grid, **kwargs):
-        super().__init__(**kwargs)
+        from triton.runtime.autotuner import Autotuner
 
         from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
 
+        super().__init__(**kwargs)
+
         assert kernel is not None
 
         self.kernel = kernel
@@ -665,6 +667,19 @@ def __init__(self, kernel, kernel_idx, grid, **kwargs):
 
         self.grid = grid
 
+        if isinstance(kernel, Autotuner):
+            # We only support configs and keys arguments of triton.autotune
+            # Make sure other arguments are defaulted
+            defaults = inspect.signature(Autotuner).parameters
+            if (
+                defaults["warmup"].default != kernel.warmup
+                or defaults["rep"].default != kernel.rep
+                or defaults["prune_configs_by"].default != kernel.early_config_prune
+            ):
+                raise Unsupported(
+                    "Only configs and keys are supported for triton.autotune"
+                )
+
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ) -> "VariableTracker":
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 9c844152e3c7c..8924729964e7c 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -777,7 +777,7 @@ def get_unique_kernel_name(self, name: str) -> str:
         self.user_defined_kernel_count += 1
         return new_name
 
-    def define_user_defined_triton_kernel(self, name, kernel, kwargs):
+    def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs):
         original_name = kernel.__name__
         compile_wrapper = IndentedBuffer()
         compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''")
@@ -787,7 +787,7 @@ def define_user_defined_triton_kernel(self, name, kernel, kwargs):
             import triton
             import triton.language as tl
             from torch._inductor.utils import instance_descriptor
-            from torch._inductor.triton_heuristics import template
+            from torch._inductor.triton_heuristics import user_autotune
             """,
             strip=True,
         )
@@ -828,12 +828,20 @@ def define_user_defined_triton_kernel(self, name, kernel, kwargs):
             "configs": [config_of(signature)],
             "kernel_name": name,
         }
+        configs = [
+            {
+                "kwargs": config.kwargs,
+                "num_warps": config.num_warps,
+                "num_stages": config.num_stages,
+            }
+            for config in configs
+        ]
         compile_wrapper.splice(
             f"""
-            @template(
-                num_stages={num_stages},
-                num_warps={num_warps},
-                meta={triton_meta!r}
+            @user_autotune(
+                configs={configs!r},
+                meta={triton_meta!r},
+                filename=__file__
             )
             @triton.jit
             """
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index d00ee55b04631..5ced029eea5fb 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -3768,9 +3768,15 @@ def apply_constraint(self):
 
 class UserDefinedTritonKernel(ExternKernel):
     def codegen(self, wrapper):
+        from triton.runtime.autotuner import Autotuner
+
         from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
 
         kernel = kernel_side_table.get_kernel(self.kernel_idx)
+        configs = []
+        if isinstance(kernel, Autotuner):
+            configs = kernel.configs
+            kernel = kernel.fn
         new_name = wrapper.get_unique_kernel_name(kernel.__name__)
 
         self.codegen_comment(wrapper)
@@ -3779,7 +3785,9 @@ def codegen(self, wrapper):
             self.grid,
             self.codegen_kwargs(),
         )
-        wrapper.define_user_defined_triton_kernel(new_name, kernel, self.kwargs)
+        wrapper.define_user_defined_triton_kernel(
+            new_name, kernel, configs, self.kwargs
+        )
 
     def should_allocate(self):
         return False
diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py
index be062e8eb3cbe..0d65ed4aed646 100644
--- a/torch/_inductor/triton_heuristics.py
+++ b/torch/_inductor/triton_heuristics.py
@@ -12,7 +12,7 @@
 import re
 import threading
 from enum import auto, Enum
-from typing import Any, Callable, List, Optional, Set, Tuple
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
 
@@ -62,6 +62,7 @@ class HeuristicType(Enum):
     REDUCTION = auto()
     PERSISTENT_REDUCTION = auto()
     TEMPLATE = auto()
+    USER_AUTOTUNE = auto()
 
 
 class AutotuneHint(Enum):
@@ -344,7 +345,7 @@ def launcher({', '.join(def_args)}, grid, stream):
 
         return binary, launcher
 
-    def bench(self, launcher, *args, grid):
+    def bench(self, launcher, *args, grid, **kwargs):
         """Measure the performance of a given launcher"""
         if launcher.n_spills > config.triton.spill_threshold:
             log.debug(
@@ -362,16 +363,17 @@ def kernel_call():
                     {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
                 )
 
-            cloned_args = self.clone_args(*args)
+            cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
             launcher(
                 *cloned_args,
+                **cloned_kwargs,
                 grid=grid,
                 stream=stream,
             )
 
         return do_bench(kernel_call, rep=40, fast_flush=True)
 
-    def clone_args(self, *args):
+    def clone_args(self, *args, **kwargs):
         from .compile_fx import clone_preserve_strides
 
         # clone inplace buffers to avoid autotune contaminating them if
@@ -385,7 +387,15 @@ def clone_args(self, *args):
             else:
                 cloned_args.append(arg)
 
-        return cloned_args
+        cloned_kwargs: Dict[str, Any] = {}
+        for name, arg in kwargs.items():
+            if name in self.mutated_arg_names:
+                assert isinstance(arg, torch.Tensor)
+                cloned_kwargs[name] = clone_preserve_strides(arg)
+            else:
+                cloned_kwargs[name] = arg
+
+        return cloned_args, cloned_kwargs
 
     @dynamo_timed
     def benchmark_all_configs(self, *args, **kwargs):
@@ -451,7 +461,10 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs):
         Then if coordinate descnt tuning is run with max-autotune disabled, it will start from C1;
         while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
""" - if self.heuristic_type == HeuristicType.TEMPLATE: + if ( + self.heuristic_type == HeuristicType.TEMPLATE + or self.heuristic_type == HeuristicType.USER_AUTOTUNE + ): # skip triton template return launcher @@ -1130,6 +1143,39 @@ def template(num_stages, num_warps, meta, filename=None): ) +def user_autotune(configs, meta, filename=None): + """ + Compile a user defined triton kernel + """ + defaults = inspect.signature(triton.Config).parameters + default_num_stages = defaults["num_stages"].default + default_num_warps = defaults["num_warps"].default + + if len(configs) == 0: + configs = [ + triton.Config( + {}, num_stages=default_num_stages, num_warps=default_num_warps + ) + ] + else: + configs = [ + triton.Config( + c.get("kwargs", {}), + num_stages=c.get("num_stages", default_num_stages), + num_warps=c.get("num_warps", default_num_warps), + ) + for c in configs + ] + + return cached_autotune( + None, + configs, + meta=meta, + heuristic_type=HeuristicType.USER_AUTOTUNE, + filename=filename, + ) + + def foreach(meta, num_warps, filename=None): """ Compile a triton foreach kernel From fea16b4becc180e08b2b43dab52a4a730fd9e66e Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Thu, 26 Oct 2023 19:50:31 -0700 Subject: [PATCH 2/3] Update on "[Inductor] Add triton.autotune support for user defined triton kernels with constant/simple grids" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned] --- torch/_inductor/triton_heuristics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py index 0d65ed4aed646..218d97d2ab7d6 100644 --- a/torch/_inductor/triton_heuristics.py +++ b/torch/_inductor/triton_heuristics.py @@ -373,7 +373,7 @@ def kernel_call(): return do_bench(kernel_call, rep=40, fast_flush=True) - def clone_args(self, *args, **kwargs): + def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: from .compile_fx import clone_preserve_strides # clone inplace buffers to avoid autotune contaminating them if @@ -468,7 +468,7 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs): # skip triton template return launcher - cloned_args = self.clone_args(*args) + cloned_args, _ = self.clone_args(*args) config2launcher = {launcher.config: launcher} def benchmark_one_config(config): From b3d02e9758937f834b660616ba100d304ab457e4 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Fri, 27 Oct 2023 00:48:33 -0700 Subject: [PATCH 3/3] Update on "[Inductor] Add triton.autotune support for user defined triton kernels with constant/simple grids" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned] --- torch/_inductor/codegen/wrapper.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 8924729964e7c..bb0bbb5c75602 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -793,20 +793,12 @@ def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs): ) compile_wrapper.newline() - # TODO(oulgen): num_stages and num_warps are default values of - # triton.Config. Can we do better? Or ask the user to provide? 
-        num_stages = 2
-        num_warps = 4
-
         from ..ir import Buffer
         from .common import SizeArg, TensorArg
 
         signature: List[Union[TensorArg, SizeArg]] = []
         constants = {}
         for key, arg in kwargs.items():
-            # Not a real argument
-            if key == "grid":
-                continue
             if (
                 key in kernel.__annotations__
                 and "constexpr" in kernel.__annotations__[key]