
Commit 8d44999

Revert "[Inductor] Add triton.autotune support for user defined triton kernels with constant/simple grids (#112228)"
This reverts commit dbb31a2. Reverted #112228 on behalf of https://github.com/huydhn: "Sorry for reverting your change, but it is failing a ROCm test in trunk: https://hud.pytorch.org/pytorch/pytorch/commit/dbb31a2984fa616b4bb6fac7abb2a06ec0533eb1" ([comment](#112228 (comment)))
1 parent 668c3b3 commit 8d44999
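
For context, the reverted change let torch.compile trace user-defined Triton kernels wrapped in triton.autotune (with constant/simple grids) without graph breaks. The sketch below is a minimal reconstruction of that usage, assembled from the test deleted in test/dynamo/test_functions.py further down; it is illustrative only and not part of the commit itself.

```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
    ],
    key=[],
)
@triton.jit
def add_kernel_autotuned(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def call_triton(x, y):
    output = torch.zeros_like(x)
    n_elements = output.numel()
    # Constant grid, as exercised by the removed test.
    add_kernel_autotuned[(n_elements,)](x, y, output, n_elements)
    return output


# With #112228 applied this compiled without graph breaks; after this revert,
# Dynamo only recognizes plain @triton.jit functions (JITFunction), not
# Autotuner-wrapped kernels.
compiled = torch.compile(call_triton, backend="inductor", fullgraph=True)
```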

6 files changed: +24 additions, -140 deletions

test/dynamo/test_functions.py
Lines changed: 0 additions & 43 deletions

@@ -1402,30 +1402,6 @@ def add_kernel(
     output = x + y
     tl.store(out_ptr + offsets, output, mask=mask)

-@triton.autotune(
-    configs=[
-        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
-        triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
-    ],
-    key=[],
-)
-@triton.jit
-def add_kernel_autotuned(
-    in_ptr0,
-    in_ptr1,
-    out_ptr,
-    n_elements,
-    BLOCK_SIZE: "tl.constexpr",
-):
-    pid = tl.program_id(axis=0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
-    x = tl.load(in_ptr0 + offsets, mask=mask)
-    y = tl.load(in_ptr1 + offsets, mask=mask)
-    output = x + y
-    tl.store(out_ptr + offsets, output, mask=mask)
-
 @triton.jit
 def mul2_kernel(
     in_ptr0,

@@ -2011,25 +1987,6 @@ def call_triton(
         # reset back
         CONSTANT_C = prev_c

-    @requires_cuda()
-    @requires_triton()
-    @common_utils.parametrize("grad", [False, True])
-    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
-    def test_triton_kernel_autotune(self, grad, backend):
-        def call_triton(x: torch.Tensor, y: torch.Tensor):
-            output = torch.zeros_like(x, requires_grad=grad)
-            n_elements = output.numel()
-            grid = (n_elements,)
-            add_kernel_autotuned[grid](x, y, output, n_elements)
-            return output
-
-        t1 = torch.rand(5, device="cuda", requires_grad=grad)
-        t2 = torch.rand(5, device="cuda", requires_grad=grad)
-
-        torch_add = t1 + t2
-        compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
-        self.assertEqual(compiled_func(t1, t2), torch_add)
-
     @requires_cuda()
     @requires_triton()
     @common_utils.parametrize("grad", [False, True])

torch/_dynamo/variables/builder.py
Lines changed: 1 addition & 5 deletions

@@ -362,16 +362,12 @@ def _wrap(self, value):
         from torch.utils._triton import has_triton

         if has_triton():
-            from triton.runtime.autotuner import Autotuner
             from triton.runtime.jit import JITFunction
         else:

             class JITFunction:
                 pass

-            class Autotuner:
-                pass
-
         make_guards = self.make_guards

         # Handle exact type() match

@@ -720,7 +716,7 @@ def index_source(key):
                 sym_node_proxy,
                 new_symint == 1,
             )
-        elif isinstance(value, (JITFunction, Autotuner)):
+        elif isinstance(value, JITFunction):
             return TritonKernelVariable(
                 value,
                 None,  # No kernel idx provided

torch/_dynamo/variables/functions.py
Lines changed: 1 addition & 16 deletions

@@ -652,12 +652,10 @@ def get_val(v):

 class TritonKernelVariable(VariableTracker):
     def __init__(self, kernel, kernel_idx, grid, **kwargs):
-        from triton.runtime.autotuner import Autotuner
+        super().__init__(**kwargs)

         from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

-        super().__init__(**kwargs)
-
         assert kernel is not None

         self.kernel = kernel

@@ -667,19 +665,6 @@ def __init__(self, kernel, kernel_idx, grid, **kwargs):

         self.grid = grid

-        if isinstance(kernel, Autotuner):
-            # We only support configs and keys arguments of triton.autotune
-            # Make sure other arguments are defaulted
-            defaults = inspect.signature(Autotuner).parameters
-            if (
-                defaults["warmup"].default != kernel.warmup
-                or defaults["rep"].default != kernel.rep
-                or defaults["prune_configs_by"].default != kernel.early_config_prune
-            ):
-                raise Unsupported(
-                    "Only configs and keys are supported for triton.autotune"
-                )
-
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ) -> "VariableTracker":

torch/_inductor/codegen/wrapper.py
Lines changed: 14 additions & 14 deletions

@@ -778,7 +778,7 @@ def get_unique_kernel_name(self, name: str) -> str:
         self.user_defined_kernel_count += 1
         return new_name

-    def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs):
+    def define_user_defined_triton_kernel(self, name, kernel, kwargs):
         original_name = kernel.__name__
         compile_wrapper = IndentedBuffer()
         compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''")

@@ -788,18 +788,26 @@ def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs):
             import triton
             import triton.language as tl
             from torch._inductor.utils import instance_descriptor
-            from torch._inductor.triton_heuristics import user_autotune
+            from torch._inductor.triton_heuristics import template
             """,
             strip=True,
         )
         compile_wrapper.newline()

+        # TODO(oulgen): num_stages and num_warps are default values of
+        # triton.Config. Can we do better? Or ask the user to provide?
+        num_stages = 2
+        num_warps = 4
+
         from ..ir import Buffer
         from .common import SizeArg, TensorArg

         signature: List[Union[TensorArg, SizeArg]] = []
         constants = {}
         for key, arg in kwargs.items():
+            # Not a real argument
+            if key == "grid":
+                continue
             if (
                 key in kernel.__annotations__
                 and "constexpr" in kernel.__annotations__[key]

@@ -821,20 +829,12 @@ def define_user_defined_triton_kernel(self, name, kernel, configs, kwargs):
             "configs": [config_of(signature)],
             "kernel_name": name,
         }
-        configs = [
-            {
-                "kwargs": config.kwargs,
-                "num_warps": config.num_warps,
-                "num_stages": config.num_stages,
-            }
-            for config in configs
-        ]
         compile_wrapper.splice(
             f"""
-            @user_autotune(
-                configs={configs!r},
-                meta={triton_meta!r},
-                filename=__file__
+            @template(
+                num_stages={num_stages},
+                num_warps={num_warps},
+                meta={triton_meta!r}
             )
             @triton.jit
             """

torch/_inductor/ir.py
Lines changed: 1 addition & 9 deletions

@@ -3768,15 +3768,9 @@ def apply_constraint(self):

 class UserDefinedTritonKernel(ExternKernel):
     def codegen(self, wrapper):
-        from triton.runtime.autotuner import Autotuner
-
         from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

         kernel = kernel_side_table.get_kernel(self.kernel_idx)
-        configs = []
-        if isinstance(kernel, Autotuner):
-            configs = kernel.configs
-            kernel = kernel.fn
         new_name = wrapper.get_unique_kernel_name(kernel.__name__)

         self.codegen_comment(wrapper)

@@ -3785,9 +3779,7 @@ def codegen(self, wrapper):
             self.grid,
             self.codegen_kwargs(),
         )
-        wrapper.define_user_defined_triton_kernel(
-            new_name, kernel, configs, self.kwargs
-        )
+        wrapper.define_user_defined_triton_kernel(new_name, kernel, self.kwargs)

     def should_allocate(self):
         return False

torch/_inductor/triton_heuristics.py
Lines changed: 7 additions & 53 deletions

@@ -12,7 +12,7 @@
 import re
 import threading
 from enum import auto, Enum
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple
+from typing import Any, Callable, List, Optional, Set, Tuple

 import torch


@@ -62,7 +62,6 @@ class HeuristicType(Enum):
     REDUCTION = auto()
     PERSISTENT_REDUCTION = auto()
     TEMPLATE = auto()
-    USER_AUTOTUNE = auto()


 class AutotuneHint(Enum):

@@ -345,7 +344,7 @@ def launcher({', '.join(def_args)}, grid, stream):

         return binary, launcher

-    def bench(self, launcher, *args, grid, **kwargs):
+    def bench(self, launcher, *args, grid):
         """Measure the performance of a given launcher"""
         if launcher.n_spills > config.triton.spill_threshold:
             log.debug(

@@ -363,17 +362,16 @@ def kernel_call():
                     {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
                 )

-            cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
+            cloned_args = self.clone_args(*args)
             launcher(
                 *cloned_args,
-                **cloned_kwargs,
                 grid=grid,
                 stream=stream,
             )

         return do_bench(kernel_call, rep=40, fast_flush=True)

-    def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
+    def clone_args(self, *args):
         from .compile_fx import clone_preserve_strides

         # clone inplace buffers to avoid autotune contaminating them if

@@ -387,15 +385,7 @@ def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
             else:
                 cloned_args.append(arg)

-        cloned_kwargs: Dict[str, Any] = {}
-        for name, arg in kwargs.items():
-            if name in self.mutated_arg_names:
-                assert isinstance(arg, torch.Tensor)
-                cloned_kwargs[name] = clone_preserve_strides(arg)
-            else:
-                cloned_kwargs[name] = arg
-
-        return cloned_args, cloned_kwargs
+        return cloned_args

     @dynamo_timed
     def benchmark_all_configs(self, *args, **kwargs):

@@ -461,14 +451,11 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs):
        Then if coordinate descnt tuning is run with max-autotune disabled, it will start from C1;
        while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
        """
-        if (
-            self.heuristic_type == HeuristicType.TEMPLATE
-            or self.heuristic_type == HeuristicType.USER_AUTOTUNE
-        ):
+        if self.heuristic_type == HeuristicType.TEMPLATE:
             # skip triton template
             return launcher

-        cloned_args, _ = self.clone_args(*args)
+        cloned_args = self.clone_args(*args)
         config2launcher = {launcher.config: launcher}

         def benchmark_one_config(config):

@@ -1143,39 +1130,6 @@ def template(num_stages, num_warps, meta, filename=None):
     )


-def user_autotune(configs, meta, filename=None):
-    """
-    Compile a user defined triton kernel
-    """
-    defaults = inspect.signature(triton.Config).parameters
-    default_num_stages = defaults["num_stages"].default
-    default_num_warps = defaults["num_warps"].default
-
-    if len(configs) == 0:
-        configs = [
-            triton.Config(
-                {}, num_stages=default_num_stages, num_warps=default_num_warps
-            )
-        ]
-    else:
-        configs = [
-            triton.Config(
-                c.get("kwargs", {}),
-                num_stages=c.get("num_stages", default_num_stages),
-                num_warps=c.get("num_warps", default_num_warps),
-            )
-            for c in configs
-        ]
-
-    return cached_autotune(
-        None,
-        configs,
-        meta=meta,
-        heuristic_type=HeuristicType.USER_AUTOTUNE,
-        filename=filename,
-    )
-
-
 def foreach(meta, num_warps, filename=None):
     """
     Compile a triton foreach kernel
