diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py index f67c3618b..c2f9ae0da 100644 --- a/helion/autotuner/config_spec.py +++ b/helion/autotuner/config_spec.py @@ -245,20 +245,24 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: name, config.get(name, ()), block_ids=self.grid_block_ids ) - # Only one range_warp_specializes is allowed, take the last one range_warp_specializes = cast( "list[bool | None]", config.get("range_warp_specializes", []) ) if range_warp_specializes and any(range_warp_specializes): - for i in [j for j, val in enumerate(range_warp_specializes) if val][:-1]: + # Only one range_warp_specializes is allowed, take the first one + # Prefer warp specialize on outermost loop + first_idx = range_warp_specializes.index(True) + for i in range(first_idx + 1, len(range_warp_specializes)): range_warp_specializes[i] = None range_unroll_factors = cast( "list[int]", config.get("range_unroll_factors", []) ) - if range_unroll_factors and range_unroll_factors[-1]: - range_unroll_factors[-1] = 0 + if range_unroll_factors and range_unroll_factors[first_idx] > 1: + if range_unroll_factors[first_idx]: + range_unroll_factors[first_idx] = 0 + config["range_unroll_factors"] = range_unroll_factors config["range_warp_specializes"] = range_warp_specializes diff --git a/test/test_autotuner.expected b/test/test_autotuner.expected index 762c75345..449562b85 100644 --- a/test/test_autotuner.expected +++ b/test/test_autotuner.expected @@ -4,8 +4,8 @@ Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environmen --- assertExpectedJournal(TestAutotuner.test_config_fragment0) helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None]) helion.Config(block_sizes=[32, 128, 64], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['', ''], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, True], range_multi_buffers=[False, True], range_num_stages=[3, 0], range_unroll_factors=[1, 0], range_warp_specializes=[None, True]) -helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[16], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='persistent_interleaved', range_flattens=[True, None], range_multi_buffers=[None, None], range_num_stages=[2, 0], range_unroll_factors=[2, 0], range_warp_specializes=[True, False]) -helion.Config(block_sizes=[16, 128, 64], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True, True], range_multi_buffers=[False, None], range_num_stages=[2, 4], range_unroll_factors=[2, 0], range_warp_specializes=[True, None]) +helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[16], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='persistent_interleaved', range_flattens=[True, None], range_multi_buffers=[None, None], range_num_stages=[2, 0], range_unroll_factors=[0, 3], range_warp_specializes=[True, None]) +helion.Config(block_sizes=[16, 128, 64], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True, True], range_multi_buffers=[False, None], range_num_stages=[2, 4], range_unroll_factors=[0, 3], range_warp_specializes=[True, None]) helion.Config(block_sizes=[64, 32, 16], indexing='tensor_descriptor', l2_groupings=[2], load_eviction_policies=['first', 'last'], loop_orders=[[1, 0]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, True], range_num_stages=[0, 4], range_unroll_factors=[0, 1], range_warp_specializes=[None, None]) helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[32], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=2, num_warps=1, pid_type='persistent_interleaved', range_flattens=[True, False], range_multi_buffers=[True, None], range_num_stages=[3, 2], range_unroll_factors=[2, 2], range_warp_specializes=[False, False]) helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=5, num_warps=4, pid_type='persistent_interleaved', range_flattens=[None, True], range_multi_buffers=[False, False], range_num_stages=[3, 4], range_unroll_factors=[3, 2], range_warp_specializes=[None, None]) @@ -15,7 +15,7 @@ helion.Config(block_sizes=[16, 128, 16], indexing='pointer', l2_groupings=[8], l --- assertExpectedJournal(TestAutotuner.test_config_fragment1) helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[1, 32, 32], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[1, 32, 32], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[True]) helion.Config(block_sizes=[2, 512, 4], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['last', ''], loop_orders=[[2, 1, 0]], num_stages=4, num_warps=1, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[None], range_unroll_factors=[3], range_warp_specializes=[False]) helion.Config(block_sizes=[1, 2, 8], flatten_loops=[True], indexing='pointer', l2_groupings=[32], load_eviction_policies=['last', 'last'], loop_orders=[[1, 2, 0]], num_stages=7, num_warps=16, pid_type='persistent_interleaved', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True]) helion.Config(block_sizes=[1, 128, 4], flatten_loops=[True], indexing='pointer', l2_groupings=[2], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=6, num_warps=1, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True])