Revert "[Inductor] Add triton.autotune support for user defined trito…
Browse files Browse the repository at this point in the history
…n kernels with constant/simple grids (#112228)"

This reverts commit dbb31a2.

Reverted #112228 on behalf of https://github.com/huydhn due to: "Sorry for reverting your change, but it is failing a ROCm test in trunk: https://hud.pytorch.org/pytorch/pytorch/commit/dbb31a2984fa616b4bb6fac7abb2a06ec0533eb1" ([comment](#112228 (comment)))
pytorchmergebot committed Oct 28, 2023
1 parent 668c3b3 commit 8d44999
Showing 6 changed files with 24 additions and 140 deletions.
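
For context, the reverted change let torch.compile trace user-defined Triton kernels decorated with triton.autotune (only the configs and key arguments of triton.autotune were supported). A minimal sketch of that usage, reconstructed from the removed test_triton_kernel_autotune below; it assumes a CUDA device and a working Triton install:

# Sketch of the reverted feature, reconstructed from the removed
# test_triton_kernel_autotune; requires CUDA and Triton.
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
    ],
    key=[],
)
@triton.jit
def add_kernel_autotuned(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def call_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.zeros_like(x)
    n_elements = output.numel()
    # Grid follows the removed test; the mask keeps the over-launch harmless.
    add_kernel_autotuned[(n_elements,)](x, y, output, n_elements)
    return output


if __name__ == "__main__":
    t1 = torch.rand(5, device="cuda")
    t2 = torch.rand(5, device="cuda")
    # Before this revert, inductor could compile this graph end to end;
    # after it, only plain @triton.jit kernels are supported again.
    compiled = torch.compile(call_triton, fullgraph=True)
    torch.testing.assert_close(compiled(t1, t2), t1 + t2)
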
43 changes: 0 additions & 43 deletions test/dynamo/test_functions.py
@@ -1402,30 +1402,6 @@ def add_kernel(
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)

@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
],
key=[],
)
@triton.jit
def add_kernel_autotuned(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)

@triton.jit
def mul2_kernel(
in_ptr0,
@@ -2011,25 +1987,6 @@ def call_triton(
# reset back
CONSTANT_C = prev_c

@requires_cuda()
@requires_triton()
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_triton_kernel_autotune(self, grad, backend):
def call_triton(x: torch.Tensor, y: torch.Tensor):
output = torch.zeros_like(x, requires_grad=grad)
n_elements = output.numel()
grid = (n_elements,)
add_kernel_autotuned[grid](x, y, output, n_elements)
return output

t1 = torch.rand(5, device="cuda", requires_grad=grad)
t2 = torch.rand(5, device="cuda", requires_grad=grad)

torch_add = t1 + t2
compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
self.assertEqual(compiled_func(t1, t2), torch_add)

@requires_cuda()
@requires_triton()
@common_utils.parametrize("grad", [False, True])
6 changes: 1 addition & 5 deletions torch/_dynamo/variables/builder.py
@@ -362,16 +362,12 @@ def _wrap(self, value):
from torch.utils._triton import has_triton

if has_triton():
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction
else:

class JITFunction:
pass

class Autotuner:
pass

make_guards = self.make_guards

# Handle exact type() match
@@ -720,7 +716,7 @@ def index_source(key):
sym_node_proxy,
new_symint == 1,
)
elif isinstance(value, (JITFunction, Autotuner)):
elif isinstance(value, JITFunction):
return TritonKernelVariable(
value,
None, # No kernel idx provided
17 changes: 1 addition & 16 deletions torch/_dynamo/variables/functions.py
@@ -652,12 +652,10 @@ def get_val(v):

class TritonKernelVariable(VariableTracker):
def __init__(self, kernel, kernel_idx, grid, **kwargs):
from triton.runtime.autotuner import Autotuner
super().__init__(**kwargs)

from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

super().__init__(**kwargs)

assert kernel is not None

self.kernel = kernel
@@ -667,19 +665,6 @@ def __init__(self, kernel, kernel_idx, grid, **kwargs):

self.grid = grid

if isinstance(kernel, Autotuner):
# We only support configs and keys arguments of triton.autotune
# Make sure other arguments are defaulted
defaults = inspect.signature(Autotuner).parameters
if (
defaults["warmup"].default != kernel.warmup
or defaults["rep"].default != kernel.rep
or defaults["prune_configs_by"].default != kernel.early_config_prune
):
raise Unsupported(
"Only configs and keys are supported for triton.autotune"
)

def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
28 changes: 14 additions & 14 deletions torch/_inductor/codegen/wrapper.py
@@ -778,7 +778,7 @@ def get_unique_kernel_name(self, name: str) -> str:
self.user_defined_kernel_count += 1
return new_name

def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs):
def define_user_defined_triton_kernel(self, name, kernel, kwargs):
original_name = kernel.__name__
compile_wrapper = IndentedBuffer()
compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''")
@@ -788,18 +788,26 @@ def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs):
import triton
import triton.language as tl
from torch._inductor.utils import instance_descriptor
from torch._inductor.triton_heuristics import user_autotune
from torch._inductor.triton_heuristics import template
""",
strip=True,
)
compile_wrapper.newline()

# TODO(oulgen): num_stages and num_warps are default values of
# triton.Config. Can we do better? Or ask the user to provide?
num_stages = 2
num_warps = 4

from ..ir import Buffer
from .common import SizeArg, TensorArg

signature: List[Union[TensorArg, SizeArg]] = []
constants = {}
for key, arg in kwargs.items():
# Not a real argument
if key == "grid":
continue
if (
key in kernel.__annotations__
and "constexpr" in kernel.__annotations__[key]
@@ -821,20 +829,12 @@ def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs):
"configs": [config_of(signature)],
"kernel_name": name,
}
configs = [
{
"kwargs": config.kwargs,
"num_warps": config.num_warps,
"num_stages": config.num_stages,
}
for config in configs
]
compile_wrapper.splice(
f"""
@user_autotune(
configs={configs!r},
meta={triton_meta!r},
filename=__file__
@template(
num_stages={num_stages},
num_warps={num_warps},
meta={triton_meta!r}
)
@triton.jit
"""
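
On the TODO above: instead of hard-coding num_stages and num_warps, the defaults can be read off triton.Config by introspection, which is what the removed user_autotune helper (see the triton_heuristics.py hunk below) did. A minimal sketch, assuming Triton is installed:

# Read triton.Config's default num_stages / num_warps by introspection,
# as the removed user_autotune helper did, instead of hard-coding them.
import inspect

import triton

config_defaults = inspect.signature(triton.Config).parameters
default_num_stages = config_defaults["num_stages"].default
default_num_warps = config_defaults["num_warps"].default
print(default_num_stages, default_num_warps)
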
10 changes: 1 addition & 9 deletions torch/_inductor/ir.py
@@ -3768,15 +3768,9 @@ def apply_constraint(self):

class UserDefinedTritonKernel(ExternKernel):
def codegen(self, wrapper):
from triton.runtime.autotuner import Autotuner

from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

kernel = kernel_side_table.get_kernel(self.kernel_idx)
configs = []
if isinstance(kernel, Autotuner):
configs = kernel.configs
kernel = kernel.fn
new_name = wrapper.get_unique_kernel_name(kernel.__name__)

self.codegen_comment(wrapper)
@@ -3785,9 +3779,7 @@ def codegen(self, wrapper):
self.grid,
self.codegen_kwargs(),
)
wrapper.define_user_defined_triton_kernel(
new_name, kernel, configs, self.kwargs
)
wrapper.define_user_defined_triton_kernel(new_name, kernel, self.kwargs)

def should_allocate(self):
return False
60 changes: 7 additions & 53 deletions torch/_inductor/triton_heuristics.py
@@ -12,7 +12,7 @@
import re
import threading
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from typing import Any, Callable, List, Optional, Set, Tuple

import torch

@@ -62,7 +62,6 @@ class HeuristicType(Enum):
REDUCTION = auto()
PERSISTENT_REDUCTION = auto()
TEMPLATE = auto()
USER_AUTOTUNE = auto()


class AutotuneHint(Enum):
@@ -345,7 +344,7 @@ def launcher({', '.join(def_args)}, grid, stream):

return binary, launcher

def bench(self, launcher, *args, grid, **kwargs):
def bench(self, launcher, *args, grid):
"""Measure the performance of a given launcher"""
if launcher.n_spills > config.triton.spill_threshold:
log.debug(
@@ -363,17 +362,16 @@ def kernel_call():
{**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
)

cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
cloned_args = self.clone_args(*args)
launcher(
*cloned_args,
**cloned_kwargs,
grid=grid,
stream=stream,
)

return do_bench(kernel_call, rep=40, fast_flush=True)

def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
def clone_args(self, *args):
from .compile_fx import clone_preserve_strides

# clone inplace buffers to avoid autotune contaminating them if
@@ -387,15 +385,7 @@ def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
else:
cloned_args.append(arg)

cloned_kwargs: Dict[str, Any] = {}
for name, arg in kwargs.items():
if name in self.mutated_arg_names:
assert isinstance(arg, torch.Tensor)
cloned_kwargs[name] = clone_preserve_strides(arg)
else:
cloned_kwargs[name] = arg

return cloned_args, cloned_kwargs
return cloned_args

@dynamo_timed
def benchmark_all_configs(self, *args, **kwargs):
@@ -461,14 +451,11 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs):
Then if coordinate descent tuning is run with max-autotune disabled, it will start from C1;
while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
"""
if (
self.heuristic_type == HeuristicType.TEMPLATE
or self.heuristic_type == HeuristicType.USER_AUTOTUNE
):
if self.heuristic_type == HeuristicType.TEMPLATE:
# skip triton template
return launcher

cloned_args, _ = self.clone_args(*args)
cloned_args = self.clone_args(*args)
config2launcher = {launcher.config: launcher}

def benchmark_one_config(config):
@@ -1143,39 +1130,6 @@ def template(num_stages, num_warps, meta, filename=None):
)


def user_autotune(configs, meta, filename=None):
"""
Compile a user defined triton kernel
"""
defaults = inspect.signature(triton.Config).parameters
default_num_stages = defaults["num_stages"].default
default_num_warps = defaults["num_warps"].default

if len(configs) == 0:
configs = [
triton.Config(
{}, num_stages=default_num_stages, num_warps=default_num_warps
)
]
else:
configs = [
triton.Config(
c.get("kwargs", {}),
num_stages=c.get("num_stages", default_num_stages),
num_warps=c.get("num_warps", default_num_warps),
)
for c in configs
]

return cached_autotune(
None,
configs,
meta=meta,
heuristic_type=HeuristicType.USER_AUTOTUNE,
filename=filename,
)


def foreach(meta, num_warps, filename=None):
"""
Compile a triton foreach kernel
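
For reference, the removed user_autotune path above paired with the wrapper-side change in torch/_inductor/codegen/wrapper.py: the generated code serialized each triton.Config to a plain dict, and user_autotune rebuilt Config objects from those dicts. A short sketch of that round-trip, with config values borrowed from the removed test (the dict keys follow the removed wrapper code; the fallback values are illustrative):

# Sketch of the config round-trip removed by this revert: the wrapper emitted
# each triton.Config as a plain dict, and user_autotune turned the dicts back
# into Config objects, falling back to defaults for missing keys.
import triton

serialized = [
    {"kwargs": {"BLOCK_SIZE": 128}, "num_warps": 8, "num_stages": 3},
    {"kwargs": {"BLOCK_SIZE": 64}, "num_warps": 8, "num_stages": 3},
]
rebuilt = [
    triton.Config(
        c.get("kwargs", {}),
        num_stages=c.get("num_stages", 2),  # illustrative fallback
        num_warps=c.get("num_warps", 4),    # illustrative fallback
    )
    for c in serialized
]
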
