Commit ede4f1c

shunting314 authored and etaf committed
[Inductor] mix order reduction heuristics and tuning (#166585)
Pull Request resolved: #166585
Approved by: https://github.com/jansel, https://github.com/PaulZhang12
ghstack dependencies: #166053, #166382, #166461
1 parent d7f02cb commit ede4f1c

File tree

3 files changed (+146, −35 lines)

torch/_inductor/codegen/simd.py

Lines changed: 107 additions & 34 deletions

@@ -35,15 +35,17 @@
 from ..._dynamo.utils import counters
 from .. import config, ir, scheduler
 from ..analyze_preserves_zero_mask import prologue_preserves_zero_mask
-from ..codecache import code_hash
+from ..codecache import code_hash, PyCodeCache
 from ..dependencies import MemoryDep, StarDep, WeakDep


 if TYPE_CHECKING:
     from ..ir import IRNode

 from ..optimize_indexing import indexing_dtype_strength_reduction
-from ..runtime.runtime_utils import green_text, yellow_text
+from ..runtime.coordinate_descent_tuner import CoordescTuner
+from ..runtime.hints import DeviceProperties
+from ..runtime.runtime_utils import green_text, next_power_of_2, yellow_text
 from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
 from ..utils import (
     cache_property_on_self,
@@ -1535,6 +1537,63 @@ def _split_mix_order_reduction_epilogue(self, node):
             epilogues.append(node)
         return reductions, epilogues

+    def _generate_kernel_code_for_mix_order_reduction(
+        self, kernel_features, split_size, for_benchmark
+    ):
+        """
+        for_benchmark:
+            True if the generated code is for benchmarking. We need to make
+            sure the benchmark harness code is generated.
+        """
+        numel, rnumel = kernel_features.numel, kernel_features.reduction_numel
+        node_schedule = kernel_features.node_schedule
+
+        kernel = self.create_kernel_choices(
+            kernel_features,
+            [{"x": numel, "r0_": rnumel}],
+            {
+                "features": kernel_features,
+                "tiling_scores": None,
+                "mix_order_reduction": True,
+                "override_persistent_reduction": True,
+            },
+        )[0]
+        assert kernel.persistent_reduction
+        assert kernel.mix_order_reduction
+        kernel.rsplit_size = split_size
+        self.codegen_node_schedule_with_kernel(node_schedule, kernel)
+
+        # allocate workspace for this kernel
+        _, ws_name, ws_off = kernel.args.workspace(
+            len(kernel.saved_partial_accumulate)
+            * kernel.numels["r0_"]
+            * ((kernel.numels["x"] + kernel.rsplit_size - 1) // kernel.rsplit_size),
+            False,
+            dtype=torch.float,
+        )
+        assert ws_off == 0, f"{ws_off=}"
+        with kernel:
+            kernel.codegen_body()
+
+        stack = contextlib.ExitStack()
+        with V.set_kernel_handler(kernel), stack:
+            if for_benchmark:
+                stack.enter_context(config.patch(benchmark_kernel=True))
+            src_code = kernel.codegen_kernel()
+
+        if for_benchmark:
+            # Only do this when benchmarking. When generating the final
+            # code, the kernel name should be decided differently, based
+            # on node type, fx node name, etc.
+            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
+        return kernel, ws_name, src_code
+
+    def benchmark_codegened_module(
+        self, mod, n_spills_threshold=8, node_names: Optional[OrderedSet[str]] = None
+    ) -> tuple[float, str]:
+        raise NotImplementedError
+
     def _codegen_mix_order_reduction(self, node1, node2):
         numel, rnumel = scheduler.MixOrderReduction.get_numel_rnumel(node1)

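For intuition about the workspace call above: it reserves one fp32 cell per (partial accumulator, reduction element, x-split) triple. A back-of-the-envelope sketch with made-up sizes (the names mirror the diff; all numbers are purely illustrative):

num_partial_accumulators = 2  # len(kernel.saved_partial_accumulate), hypothetical
rnumel = 768                  # kernel.numels["r0_"], hypothetical
xnumel = 8192                 # kernel.numels["x"], hypothetical
rsplit_size = 64              # kernel.rsplit_size

# ceil-division of xnumel by the split size, as in the diff
nsplit = (xnumel + rsplit_size - 1) // rsplit_size  # 128 splits

workspace_elems = num_partial_accumulators * rnumel * nsplit  # 196608 floats
workspace_bytes = workspace_elems * 4                         # 786432 bytes, i.e. 0.75 MiB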
@@ -1544,7 +1603,21 @@ def _codegen_mix_order_reduction(self, node1, node2):
         ):
             return self._codegen_mix_order_reduction(node2, node1)

-        # pyrefly: ignore [bad-assignment]
+        def _pick_split_size():
+            # an explicit override has the highest priority
+            if config.triton.mix_order_reduction_split_size is not None:
+                return config.triton.mix_order_reduction_split_size
+
+            # heuristic based on the number of SMs
+            device_prop = DeviceProperties.create(node1.get_device())
+            num_sm = device_prop.multi_processor_count
+            estimated_num_splits = num_sm * 8
+            split_size = max(next_power_of_2(numel // estimated_num_splits), 16)
+            split_size = min(split_size, 128)
+            return split_size
+
+        split_size = _pick_split_size()
+
         metrics.codegen_mix_order_reduction += 1

         assert V.graph.sizevars.statically_known_gt(
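Worked through on concrete numbers, the heuristic targets roughly eight splits per SM and clamps the result to [16, 128]. A standalone sketch (the local next_power_of_2 stands in for the runtime_utils helper; the 108-SM figure is an illustrative A100-class value):

def next_power_of_2(n: int) -> int:
    # smallest power of two >= n; local stand-in for runtime_utils.next_power_of_2
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def pick_split_size(numel: int, num_sm: int) -> int:
    estimated_num_splits = num_sm * 8  # aim for ~8 splits per SM
    split_size = max(next_power_of_2(numel // estimated_num_splits), 16)
    return min(split_size, 128)        # clamp to [16, 128]

print(pick_split_size(32768, 108))  # 32768 // 864 = 37 -> rounded up to 64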
@@ -1557,9 +1630,6 @@ def _codegen_mix_order_reduction(self, node1, node2):
             node2
         )

-        split_size = config.triton.mix_order_reduction_split_size
-        nsplit = (numel + split_size - 1) // split_size
-
         converted_nodes = []
         for subnode in node2_reductions:
             subnode.cancel_reduction_split()
@@ -1570,25 +1640,40 @@ def _codegen_mix_order_reduction(self, node1, node2):
             node1.get_nodes() + converted_nodes, numel, rnumel
         )
         kernel_features = SIMDKernelFeatures(node_schedule, numel, rnumel)
-        kernel = self.create_kernel_choices(
-            kernel_features,
-            [{"x": numel, "r0_": rnumel}],
-            {
-                "features": kernel_features,
-                "tiling_scores": None,
-                "mix_order_reduction": True,
-                "override_persistent_reduction": True,
-            },
-        )[0]
-        assert kernel.persistent_reduction
-        assert kernel.mix_order_reduction
-        kernel.rsplit_size = split_size
-        self.codegen_node_schedule_with_kernel(node_schedule, kernel)

-        is_split_reduction = bool(node2_reductions[0].node._split_size)
+        # Autotuning is skipped in deterministic mode
+        if (
+            not torch._inductor.config.deterministic
+            and config.triton.mix_order_reduction_split_size is None
+            and config.triton.mix_order_reduction_autotune_split_size
+        ):
+
+            def _bench(candidate_split_size):
+                _, _, src_code = self._generate_kernel_code_for_mix_order_reduction(
+                    kernel_features,
+                    split_size=candidate_split_size,
+                    for_benchmark=True,
+                )
+                mod = PyCodeCache.load(src_code)
+                ms, _ = self.benchmark_codegened_module(mod)
+                return ms
+
+            split_size = CoordescTuner.autotune_single_field(
+                _bench,
+                split_size,
+                8,
+            )
+            # print(f"Autotuning pick split size {split_size}")
+
+        kernel, ws_name, src_code = self._generate_kernel_code_for_mix_order_reduction(
+            kernel_features,
+            split_size=split_size,
+            for_benchmark=False,
+        )

         # rename intermediate reduction output to final reduction
         # output
+        is_split_reduction = bool(node2_reductions[0].node._split_size)
         rename = {}
         if is_split_reduction:
             for subnode in node2_reductions:
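Taken together, the split size is now chosen in three tiers: an explicit config value wins outright; otherwise the SM heuristic provides a seed, which benchmarking then refines unless deterministic mode or the autotune knob rules it out. A condensed restatement of that control flow (a sketch, not the actual call sites; cfg, heuristic_guess, and bench are stand-ins):

from torch._inductor.runtime.coordinate_descent_tuner import CoordescTuner

def choose_split_size(cfg, deterministic, heuristic_guess, bench):
    # tier 1: an explicit override short-circuits everything
    if cfg.mix_order_reduction_split_size is not None:
        return cfg.mix_order_reduction_split_size
    # tier 2: seed with the SM-count heuristic
    split_size = heuristic_guess
    # tier 3: refine by benchmarking, unless disabled
    if not deterministic and cfg.mix_order_reduction_autotune_split_size:
        split_size = CoordescTuner.autotune_single_field(bench, split_size, 8)
    return split_size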
@@ -1611,19 +1696,6 @@ def _codegen_mix_order_reduction(self, node1, node2):
                 partial_accum.buffer_name, partial_accum.buffer_name
             )

-        # allocate workspace for this kernel
-        _, ws_name, ws_off = kernel.args.workspace(
-            len(kernel.saved_partial_accumulate)
-            * kernel.numels["r0_"]
-            * ((kernel.numels["x"] + kernel.rsplit_size - 1) // kernel.rsplit_size),
-            False,
-            dtype=torch.float,
-        )
-        assert ws_off == 0, f"{ws_off=}"
-        with kernel:
-            kernel.codegen_body()
-        with V.set_kernel_handler(kernel):
-            src_code = kernel.codegen_kernel()
         kernel_name = self.define_kernel(src_code, node_schedule, kernel)
         kernel.kernel_name = kernel_name
         kernel.code_hash = code_hash(src_code)
@@ -1643,6 +1715,7 @@ def _codegen_mix_order_reduction(self, node1, node2):

         # an extra round of reduction
         assert len(converted_nodes) == len(kernel.saved_partial_accumulate)
+        nsplit = (numel + split_size - 1) // split_size
         for idx, partial_accum in enumerate(kernel.saved_partial_accumulate):
             buffer_name = partial_accum.buffer_name

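The "extra round of reduction" is the usual second pass of a split reduction: pass one writes nsplit partial results per output element into the workspace, and pass two reduces across the splits. A minimal PyTorch illustration of the dataflow (eager code with illustrative sizes, not the generated Triton):

import torch

xnumel, rnumel, split_size = 256, 768, 64
nsplit = (xnumel + split_size - 1) // split_size  # 4, mirroring the diff

x = torch.randn(xnumel, rnumel, dtype=torch.float64)

# pass 1: each split reduces its slab of rows into a partial result
workspace = torch.stack(
    [x[i * split_size : (i + 1) * split_size].sum(dim=0) for i in range(nsplit)]
)  # shape: (nsplit, rnumel)

# pass 2: the extra round of reduction combines the partials
out = workspace.sum(dim=0)
torch.testing.assert_close(out, x.sum(dim=0))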
torch/_inductor/config.py

Lines changed: 2 additions & 1 deletion

@@ -1555,7 +1555,8 @@ class triton:
         os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION", "0") == "1"
     )

-    mix_order_reduction_split_size = 64
+    mix_order_reduction_split_size: Optional[int] = None
+    mix_order_reduction_autotune_split_size = True


 class aot_inductor:
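With this change there is no fixed default split size; the SM heuristic plus autotuning decide. To restore the previous fixed behavior, one could pin the value and disable the search (a sketch using the two knobs added here):

import torch._inductor.config as inductor_config

inductor_config.triton.mix_order_reduction_split_size = 64   # previously hard-coded
inductor_config.triton.mix_order_reduction_autotune_split_size = False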

torch/_inductor/runtime/coordinate_descent_tuner.py

Lines changed: 37 additions & 0 deletions

@@ -331,3 +331,40 @@ def autotune(
         )

         return best_config
+
+    @staticmethod
+    def autotune_single_field(fn, init_val, min_val=None, max_val=None):
+        """
+        fn is a function that takes the field value and returns the benchmarking result.
+        init_val is the starting point of the autotuning.
+
+        Should work well for a parabola-like curve. Here is a real example
+        for the split size of mix-order reduction: https://github.com/pytorch/pytorch/pull/166461
+        """
+        cache = {}
+
+        def _bench(val):
+            if val not in cache:
+                cache[val] = fn(val)
+                # print(f"split size {val} -> {cache[val]:.3f} ms")
+            return cache[val]
+
+        if min_val is None:
+            min_val = 1
+        if max_val is None:
+            max_val = 2**30  # some arbitrarily large value
+
+        best_val = init_val
+        improved = True
+        while improved:
+            improved = False
+            candlist = [best_val // 2, best_val * 2]
+            for cand in candlist:
+                cand = max(cand, min_val)
+                cand = min(cand, max_val)
+
+                if _bench(cand) < _bench(best_val):
+                    best_val = cand
+                    improved = True
+
+        return best_val
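As a quick sanity check of the search behavior, here is a toy run on a convex cost curve (illustrative; in this commit the callback compiles and benchmarks the generated kernel instead):

from torch._inductor.runtime.coordinate_descent_tuner import CoordescTuner

def cost(split_size):
    # toy benchmark: minimized near 64
    return abs(split_size - 64) + 0.01 * split_size

best = CoordescTuner.autotune_single_field(cost, init_val=16, min_val=8)
print(best)  # 64: the tuner walks 16 -> 32 -> 64; 32 and 128 are then both worse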
