3535from ..._dynamo .utils import counters
3636from .. import config , ir , scheduler
3737from ..analyze_preserves_zero_mask import prologue_preserves_zero_mask
38- from ..codecache import code_hash
38+ from ..codecache import code_hash , PyCodeCache
3939from ..dependencies import MemoryDep , StarDep , WeakDep
4040
4141
4242if TYPE_CHECKING :
4343 from ..ir import IRNode
4444
4545from ..optimize_indexing import indexing_dtype_strength_reduction
46- from ..runtime .runtime_utils import green_text , yellow_text
46+ from ..runtime .coordinate_descent_tuner import CoordescTuner
47+ from ..runtime .hints import DeviceProperties
48+ from ..runtime .runtime_utils import green_text , next_power_of_2 , yellow_text
4749from ..scheduler import BaseSchedulerNode , BaseScheduling , WhyNoFuse
4850from ..utils import (
4951 cache_property_on_self ,
@@ -1535,6 +1537,63 @@ def _split_mix_order_reduction_epilogue(self, node):
15351537 epilogues .append (node )
15361538 return reductions , epilogues
15371539
    def _generate_kernel_code_for_mix_order_reduction(
        self, kernel_features, split_size: int, for_benchmark: bool
    ):
        """
        Build and codegen a persistent mix-order-reduction kernel for the
        nodes captured in *kernel_features*, splitting the x dimension into
        chunks of *split_size* rows.

        for_benchmark:
            True if the generated code is for benchmarking. We need make
            sure benchmark harness code is generated.

        Returns a ``(kernel, ws_name, src_code)`` tuple: the kernel object
        (with partial-accumulate state populated), the name of the workspace
        buffer allocated for the partial reductions, and the generated kernel
        source code.
        """
        numel, rnumel = kernel_features.numel, kernel_features.reduction_numel
        node_schedule = kernel_features.node_schedule

        # Force a persistent mix-order reduction kernel; create_kernel_choices
        # returns a list of candidates and we take the first one.
        kernel = self.create_kernel_choices(
            kernel_features,
            [{"x": numel, "r0_": rnumel}],
            {
                "features": kernel_features,
                "tiling_scores": None,
                "mix_order_reduction": True,
                "override_persistent_reduction": True,
            },
        )[0]
        assert kernel.persistent_reduction
        assert kernel.mix_order_reduction
        kernel.rsplit_size = split_size
        # Populates kernel state (including saved_partial_accumulate, used
        # below to size the workspace).
        self.codegen_node_schedule_with_kernel(node_schedule, kernel)

        # allocate workspace for this kernel:
        # one float per (partial accumulator, r0_ element, x-split), where the
        # number of x-splits is ceil(xnumel / rsplit_size).
        _, ws_name, ws_off = kernel.args.workspace(
            len(kernel.saved_partial_accumulate)
            * kernel.numels["r0_"]
            * ((kernel.numels["x"] + kernel.rsplit_size - 1) // kernel.rsplit_size),
            False,
            dtype=torch.float,
        )
        # The workspace is expected to be the first (only) allocation, so its
        # offset must be zero.
        assert ws_off == 0, f"{ws_off=}"
        with kernel:
            kernel.codegen_body()

        # Benchmarking needs the harness code that config.benchmark_kernel
        # enables; use an ExitStack so the patch is applied conditionally
        # inside the same `with` statement.
        stack = contextlib.ExitStack()
        with V.set_kernel_handler(kernel), stack:
            if for_benchmark:
                stack.enter_context(config.patch(benchmark_kernel=True))
            src_code = kernel.codegen_kernel()

        if for_benchmark:
            # only do this if we are doing benchmarking.
            # When we are generating final code, the kernel name
            # should be decided differently with node type, fx node name
            # etc.
            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
        return kernel, ws_name, src_code
    def benchmark_codegened_module(
        self, mod, n_spills_threshold=8, node_names: Optional[OrderedSet[str]] = None
    ) -> tuple[float, str]:
        """
        Benchmark an already-codegened kernel module *mod*.

        Abstract hook: backend-specific schedulers are expected to override
        this (it is called from the mix-order-reduction split-size autotuning
        path in this file). Presumably returns (latency_ms, some string
        describing the run) and n_spills_threshold rejects kernels with too
        many register spills — TODO confirm against the overriding backend.
        """
        raise NotImplementedError
15381597 def _codegen_mix_order_reduction (self , node1 , node2 ):
15391598 numel , rnumel = scheduler .MixOrderReduction .get_numel_rnumel (node1 )
15401599
@@ -1544,7 +1603,21 @@ def _codegen_mix_order_reduction(self, node1, node2):
15441603 ):
15451604 return self ._codegen_mix_order_reduction (node2 , node1 )
15461605
1547- # pyrefly: ignore [bad-assignment]
1606+ def _pick_split_size ():
1607+ # the overridden has highest priority
1608+ if config .triton .mix_order_reduction_split_size is not None :
1609+ return config .triton .mix_order_reduction_split_size
1610+
1611+ # heuristics based on number of SMs
1612+ device_prop = DeviceProperties .create (node1 .get_device ())
1613+ num_sm = device_prop .multi_processor_count
1614+ estimated_num_splits = num_sm * 8
1615+ split_size = max (next_power_of_2 (numel // estimated_num_splits ), 16 )
1616+ split_size = min (split_size , 128 )
1617+ return split_size
1618+
1619+ split_size = _pick_split_size ()
1620+
15481621 metrics .codegen_mix_order_reduction += 1
15491622
15501623 assert V .graph .sizevars .statically_known_gt (
@@ -1557,9 +1630,6 @@ def _codegen_mix_order_reduction(self, node1, node2):
15571630 node2
15581631 )
15591632
1560- split_size = config .triton .mix_order_reduction_split_size
1561- nsplit = (numel + split_size - 1 ) // split_size
1562-
15631633 converted_nodes = []
15641634 for subnode in node2_reductions :
15651635 subnode .cancel_reduction_split ()
@@ -1570,25 +1640,40 @@ def _codegen_mix_order_reduction(self, node1, node2):
15701640 node1 .get_nodes () + converted_nodes , numel , rnumel
15711641 )
15721642 kernel_features = SIMDKernelFeatures (node_schedule , numel , rnumel )
1573- kernel = self .create_kernel_choices (
1574- kernel_features ,
1575- [{"x" : numel , "r0_" : rnumel }],
1576- {
1577- "features" : kernel_features ,
1578- "tiling_scores" : None ,
1579- "mix_order_reduction" : True ,
1580- "override_persistent_reduction" : True ,
1581- },
1582- )[0 ]
1583- assert kernel .persistent_reduction
1584- assert kernel .mix_order_reduction
1585- kernel .rsplit_size = split_size
1586- self .codegen_node_schedule_with_kernel (node_schedule , kernel )
15871643
1588- is_split_reduction = bool (node2_reductions [0 ].node ._split_size )
1644+ # The autotuning is skipped in deterministic mode
1645+ if (
1646+ not torch ._inductor .config .deterministic
1647+ and config .triton .mix_order_reduction_split_size is None
1648+ and config .triton .mix_order_reduction_autotune_split_size
1649+ ):
1650+
1651+ def _bench (candidate_split_size ):
1652+ _ , _ , src_code = self ._generate_kernel_code_for_mix_order_reduction (
1653+ kernel_features ,
1654+ split_size = candidate_split_size ,
1655+ for_benchmark = True ,
1656+ )
1657+ mod = PyCodeCache .load (src_code )
1658+ ms , _ = self .benchmark_codegened_module (mod )
1659+ return ms
1660+
1661+ split_size = CoordescTuner .autotune_single_field (
1662+ _bench ,
1663+ split_size ,
1664+ 8 ,
1665+ )
1666+ # print(f"Autotuning pick split size {split_size}")
1667+
1668+ kernel , ws_name , src_code = self ._generate_kernel_code_for_mix_order_reduction (
1669+ kernel_features ,
1670+ split_size = split_size ,
1671+ for_benchmark = False ,
1672+ )
15891673
15901674 # rename intermediate reduction output to final reduction
15911675 # output
1676+ is_split_reduction = bool (node2_reductions [0 ].node ._split_size )
15921677 rename = {}
15931678 if is_split_reduction :
15941679 for subnode in node2_reductions :
@@ -1611,19 +1696,6 @@ def _codegen_mix_order_reduction(self, node1, node2):
16111696 partial_accum .buffer_name , partial_accum .buffer_name
16121697 )
16131698
1614- # allocate workspace for this kernel
1615- _ , ws_name , ws_off = kernel .args .workspace (
1616- len (kernel .saved_partial_accumulate )
1617- * kernel .numels ["r0_" ]
1618- * ((kernel .numels ["x" ] + kernel .rsplit_size - 1 ) // kernel .rsplit_size ),
1619- False ,
1620- dtype = torch .float ,
1621- )
1622- assert ws_off == 0 , f"{ ws_off = } "
1623- with kernel :
1624- kernel .codegen_body ()
1625- with V .set_kernel_handler (kernel ):
1626- src_code = kernel .codegen_kernel ()
16271699 kernel_name = self .define_kernel (src_code , node_schedule , kernel )
16281700 kernel .kernel_name = kernel_name
16291701 kernel .code_hash = code_hash (src_code )
@@ -1643,6 +1715,7 @@ def _codegen_mix_order_reduction(self, node1, node2):
16431715
16441716 # a extra round of reduction
16451717 assert len (converted_nodes ) == len (kernel .saved_partial_accumulate )
1718+ nsplit = (numel + split_size - 1 ) // split_size
16461719 for idx , partial_accum in enumerate (kernel .saved_partial_accumulate ):
16471720 buffer_name = partial_accum .buffer_name
16481721
0 commit comments