From 0e623935f8690b183973df849094d172768853ba Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Tue, 7 Oct 2025 13:03:19 -0700 Subject: [PATCH] Autotune eviction policy stack-info: PR: https://github.com/pytorch/helion/pull/823, branch: oulgen/stack/118 --- helion/_compiler/compile_environment.py | 3 + helion/_compiler/device_function.py | 1 + helion/_compiler/device_ir.py | 54 ++++++++++++ helion/autotuner/__init__.py | 1 + helion/autotuner/config_fragment.py | 46 ++++++++++ helion/autotuner/config_spec.py | 27 ++++-- helion/language/memory_ops.py | 17 ++++ helion/runtime/config.py | 7 ++ test/test_autotuner.expected | 40 ++++----- test/test_eviction_policy.expected | 77 ++++++++++++++++ test/test_eviction_policy.py | 112 ++++++++++++++++++++++++ test/test_register_tunable.expected | 2 +- 12 files changed, 360 insertions(+), 27 deletions(-) diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 6f7088c2f..d99f7c87f 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -96,6 +96,9 @@ def __init__(self, device: torch.device, settings: Settings) -> None: self.specialized_vars: set[sympy.Symbol] = set() self.loop_dependency_checker = LoopDependencyChecker() self._symint_cache: dict[object, torch.SymInt] = {} + self.device_load_count = ( + 0 # Track number of loads in all device code for eviction policy tuning + ) def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None: from .device_function import contains_only_block_size_symbols diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index cb04619e7..9bec72f47 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -242,6 +242,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None: self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config) self.rng_seed_count = 0 + self.device_load_index = 0 # Track which load in device code we're generating (for eviction policy tuning) # Name of the RNG seed buffer parameter in kernel signature self.rng_seed_buffer_param_name = None diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index b791b1abc..fd89f3ba5 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -1065,6 +1065,55 @@ def visit_For(self, node: ast.For) -> None: self.generic_visit(node) +def _count_device_loads(device_ir: DeviceIR) -> int: + """Count the number of load operations in all device code for eviction policy tuning.""" + from ..language import memory_ops + + # Build set of rolled graph IDs to exclude (these are duplicates) + rolled_graph_ids = { + info.new_graph_id + for info in device_ir.rolled_reductions + if info.new_graph_id is not None + } + + load_count = 0 + # Walk all graphs except rolled duplicates + for graph_info in device_ir.graphs: + if graph_info.graph_id in rolled_graph_ids: + continue + + for node in graph_info.graph.nodes: + # Check if this is a load operation + if node.op == "call_function" and node.target is memory_ops.load: + # Only count loads without explicit eviction policy + # (user can still specify eviction_policy to override tuning) + # Check kwargs first, then check if 4th arg (eviction_policy) is None + eviction_policy_arg = node.kwargs.get("eviction_policy") + if eviction_policy_arg is None: + # Check if eviction_policy was passed as positional arg (index 3) + if len(node.args) >= 4: + eviction_policy_arg = node.args[3] + if 
eviction_policy_arg is None: + load_count += 1 + return load_count + + +def _register_eviction_policy_tunable(load_count: int) -> None: + """Register the eviction policy tunable for all device loads.""" + if load_count == 0: + return + + from ..autotuner.config_fragment import EnumFragment + from ..autotuner.config_fragment import ListOf + from ..autotuner.config_spec import VALID_EVICTION_POLICIES + + env = CompileEnvironment.current() + # Register a tunable for eviction policies for all device loads + fragment = ListOf(EnumFragment(choices=VALID_EVICTION_POLICIES), length=load_count) + env.config_spec.load_eviction_policies = fragment + env.device_load_count = load_count + + def lower_to_device_ir(func: HostFunction) -> DeviceIR: device_ir = DeviceIR() with func, device_ir, compile_lock: @@ -1085,6 +1134,11 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR: if len(device_ir.root_ids) > 1: # xyz not supported with shared program IDs, but persistent kernels are allowed CompileEnvironment.current().config_spec.disallow_pid_type("xyz") + + # Count all device loads and register eviction policy tunable + load_count = _count_device_loads(device_ir) + _register_eviction_policy_tunable(load_count) + return device_ir diff --git a/helion/autotuner/__init__.py b/helion/autotuner/__init__.py index 401205ae8..cc8d90bc7 100644 --- a/helion/autotuner/__init__.py +++ b/helion/autotuner/__init__.py @@ -3,6 +3,7 @@ from .config_fragment import BooleanFragment as BooleanFragment from .config_fragment import EnumFragment as EnumFragment from .config_fragment import IntegerFragment as IntegerFragment +from .config_fragment import ListOf as ListOf from .config_fragment import PowerOfTwoFragment as PowerOfTwoFragment from .config_spec import ConfigSpec as ConfigSpec from .differential_evolution import ( diff --git a/helion/autotuner/config_fragment.py b/helion/autotuner/config_fragment.py index 2b4969c0b..c8bddf2b3 100644 --- a/helion/autotuner/config_fragment.py +++ b/helion/autotuner/config_fragment.py @@ -221,3 +221,49 @@ def category(self) -> Category: class NumWarpsFragment(PowerOfTwoFragment): def category(self) -> Category: return Category.NUM_WARPS + + +@dataclasses.dataclass +class ListOf(ConfigSpecFragment): + """Wrapper that creates a list of independently tunable fragments. + + Example: + ListOf(EnumFragment(choices=("a", "b", "c")), length=5) + creates a list of 5 independently tunable enum values. 
+ """ + + inner: ConfigSpecFragment + length: int + + def default(self) -> list[object]: + """Return a list of default values.""" + return [self.inner.default() for _ in range(self.length)] + + def random(self) -> list[object]: + """Return a list of random values.""" + return [self.inner.random() for _ in range(self.length)] + + def pattern_neighbors(self, current: object) -> list[object]: + """Return neighbors by changing one element at a time.""" + if not isinstance(current, list) or len(current) != self.length: + raise ValueError(f"Expected list of length {self.length}, got {current!r}") + + neighbors: list[object] = [] + # For each position, try all neighbors from the inner fragment + for i in range(self.length): + for neighbor_value in self.inner.pattern_neighbors(current[i]): + neighbor = current.copy() + neighbor[i] = neighbor_value + neighbors.append(neighbor) + return neighbors + + def differential_mutation(self, a: object, b: object, c: object) -> list[object]: + """Create a new value by combining a, b, and c element-wise.""" + assert isinstance(a, list) and len(a) == self.length + assert isinstance(b, list) and len(b) == self.length + assert isinstance(c, list) and len(c) == self.length + + return [ + self.inner.differential_mutation(a[i], b[i], c[i]) + for i in range(self.length) + ] diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py index 234086d48..59e416e6b 100644 --- a/helion/autotuner/config_spec.py +++ b/helion/autotuner/config_spec.py @@ -18,6 +18,7 @@ from .config_fragment import ConfigSpecFragment from .config_fragment import EnumFragment from .config_fragment import IntegerFragment +from .config_fragment import ListOf from .config_fragment import NumWarpsFragment from .config_fragment import PermutationFragment from .config_fragment import PowerOfTwoFragment @@ -50,9 +51,11 @@ "num_stages", "pid_type", "indexing", + "load_eviction_policies", ] ) VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved") +VALID_EVICTION_POLICIES = ("", "first", "last") @dataclasses.dataclass @@ -97,6 +100,11 @@ class ConfigSpec: default_factory=functools.partial(tuple, VALID_PID_TYPES) ) grid_block_ids: list[int] = dataclasses.field(default_factory=list) + load_eviction_policies: ListOf = dataclasses.field( + default_factory=lambda: ListOf( + EnumFragment(choices=VALID_EVICTION_POLICIES), length=0 + ) + ) @staticmethod def _valid_indexing_types() -> tuple[IndexingLiteral, ...]: @@ -206,12 +214,16 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: "range_multi_buffers", "range_flattens", "static_ranges", + "load_eviction_policies", ): - if not config[name]: - config.pop(name) + if not config.get(name): + config.pop(name, None) config.setdefault("num_warps", DEFAULT_NUM_WARPS) config.setdefault("num_stages", DEFAULT_NUM_STAGES) + config.setdefault( + "load_eviction_policies", self.load_eviction_policies.default() + ) # TODO(jansel): include num_ctas and max_nreg for name, values in ( @@ -266,10 +278,12 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf "num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)), "indexing": fn(EnumFragment(self._valid_indexing_types())), "pid_type": fn(EnumFragment(self.allowed_pid_types)), + "load_eviction_policies": fn(self.load_eviction_policies), } # Add tunable parameters - for key, fragment in self.user_defined_tunables.items(): - config[key] = fn(fragment) + config.update( + {key: fn(fragment) for key, fragment in 
self.user_defined_tunables.items()} + ) for name in ( "loop_orders", @@ -282,9 +296,10 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf "range_multi_buffers", "range_flattens", "static_ranges", + "load_eviction_policies", ): - if not config[name]: - config.pop(name) + if not config.get(name): + config.pop(name, None) self.normalize(config) return helion.Config(**config) diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py index 23fbab363..1fe790cbc 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ -16,6 +16,13 @@ __all__ = ["load", "store"] +# Map short config names to full Triton API names for eviction policies +_EVICTION_POLICY_MAP = { + "": None, + "first": "evict_first", + "last": "evict_last", +} + @has_side_effect @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True) @@ -242,6 +249,16 @@ def _(state: CodegenState) -> ast.AST: extra_mask = state.ast_args[2] assert isinstance(extra_mask, (type(None), ast.AST)) eviction_policy = state.ast_args[3] if len(state.ast_args) > 3 else None + + # If no explicit eviction_policy and we're in device code, use tunable + if eviction_policy is None and state.codegen.on_device: + policies = state.config.load_eviction_policies + idx = state.device_function.device_load_index + if idx < len(policies): + policy_value = policies[idx] + eviction_policy = _EVICTION_POLICY_MAP.get(policy_value, policy_value) + state.device_function.device_load_index += 1 + if eviction_policy is not None: assert isinstance(eviction_policy, str) eviction_policy = ast.Constant(value=eviction_policy) diff --git a/helion/runtime/config.py b/helion/runtime/config.py index 107f03424..185d372f0 100644 --- a/helion/runtime/config.py +++ b/helion/runtime/config.py @@ -34,6 +34,7 @@ def __init__( range_multi_buffers: list[bool | None] | None = None, range_flattens: list[bool | None] | None = None, static_ranges: list[bool] | None = None, + load_eviction_policies: list[str] | None = None, num_warps: int | None = None, num_stages: int | None = None, pid_type: PidTypeLiteral | None = None, @@ -55,6 +56,7 @@ def __init__( range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls. range_flattens: Controls flatten parameter for tl.range calls. static_ranges: Whether to use tl.static_range instead tl.range. + load_eviction_policies: Eviction policies for load operations ("", "first", "last"). num_warps: Number of warps per block. num_stages: Number of stages for software pipelining. pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved"). 
@@ -74,6 +76,7 @@ def __init__( "range_multi_buffers": range_multi_buffers, "range_flattens": range_flattens, "static_ranges": static_ranges, + "load_eviction_policies": load_eviction_policies, "num_warps": num_warps, "num_stages": num_stages, "indexing": indexing, @@ -189,6 +192,10 @@ def range_flattens(self) -> list[bool | None]: def static_ranges(self) -> list[bool]: return cast("list[bool]", self.config.get("static_ranges", [])) + @property + def load_eviction_policies(self) -> list[str]: + return cast("list[str]", self.config.get("load_eviction_policies", [])) + @property def indexing(self) -> IndexingLiteral: return self.config.get("indexing", "pointer") # type: ignore[return-value] diff --git a/test/test_autotuner.expected b/test/test_autotuner.expected index 71e4fabd6..0a3e87e72 100644 --- a/test/test_autotuner.expected +++ b/test/test_autotuner.expected @@ -2,28 +2,28 @@ This file is automatically generated by assertExpectedJournal calls in test_auto Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. --- assertExpectedJournal(TestAutotuner.test_config_fragment0) -helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[1], loop_orders=[[0, 1]], num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None]) -helion.Config(block_sizes=[16, 16, 16], indexing='block_ptr', l2_groupings=[8], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, True], range_multi_buffers=[False, True], range_num_stages=[3, 0], range_unroll_factors=[1, 2], range_warp_specializes=[None, True]) -helion.Config(block_sizes=[16, 64, 32], indexing='tensor_descriptor', l2_groupings=[16], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 3], range_warp_specializes=[None, False]) -helion.Config(block_sizes=[16, 16, 32], indexing='pointer', l2_groupings=[1], loop_orders=[[1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True, True], range_multi_buffers=[False, None], range_num_stages=[2, 4], range_unroll_factors=[2, 3], range_warp_specializes=[True, None]) -helion.Config(block_sizes=[16, 16, 16], indexing='tensor_descriptor', l2_groupings=[8], loop_orders=[[0, 1]], num_stages=2, num_warps=16, pid_type='persistent_blocked', range_flattens=[True, None], range_multi_buffers=[True, True], range_num_stages=[0, 2], range_unroll_factors=[1, 4], range_warp_specializes=[True, None]) -helion.Config(block_sizes=[16, 16, 16], indexing='tensor_descriptor', l2_groupings=[1], loop_orders=[[0, 1]], num_stages=4, num_warps=4, pid_type='persistent_blocked', range_flattens=[False, False], range_multi_buffers=[False, False], range_num_stages=[0, 2], range_unroll_factors=[0, 4], range_warp_specializes=[None, True]) -helion.Config(block_sizes=[16, 128, 16], indexing='pointer', l2_groupings=[64], loop_orders=[[1, 0]], num_stages=3, num_warps=4, pid_type='persistent_blocked', range_flattens=[False, False], range_multi_buffers=[None, False], range_num_stages=[4, 4], range_unroll_factors=[0, 1], range_warp_specializes=[False, False]) -helion.Config(block_sizes=[64, 128, 16], indexing='block_ptr', l2_groupings=[64], loop_orders=[[1, 0]], num_stages=5, num_warps=32, pid_type='persistent_blocked', range_flattens=[False, None], 
range_multi_buffers=[None, True], range_num_stages=[2, 2], range_unroll_factors=[0, 4], range_warp_specializes=[False, False]) -helion.Config(block_sizes=[16, 32, 16], indexing='tensor_descriptor', l2_groupings=[8], loop_orders=[[1, 0]], num_stages=4, num_warps=2, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, False], range_num_stages=[0, 4], range_unroll_factors=[0, 4], range_warp_specializes=[None, False]) -helion.Config(block_sizes=[16, 32, 32], indexing='tensor_descriptor', l2_groupings=[2], loop_orders=[[0, 1]], num_stages=6, num_warps=2, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 2], range_warp_specializes=[None, None]) +helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None]) +helion.Config(block_sizes=[32, 128, 64], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['', ''], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, True], range_multi_buffers=[False, True], range_num_stages=[3, 0], range_unroll_factors=[1, 2], range_warp_specializes=[None, True]) +helion.Config(block_sizes=[16, 16, 16], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 3], range_warp_specializes=[None, False]) +helion.Config(block_sizes=[16, 32, 256], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True, True], range_multi_buffers=[False, None], range_num_stages=[2, 4], range_unroll_factors=[2, 3], range_warp_specializes=[True, None]) +helion.Config(block_sizes=[64, 32, 16], indexing='block_ptr', l2_groupings=[2], load_eviction_policies=['first', 'last'], loop_orders=[[1, 0]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, True], range_num_stages=[0, 4], range_unroll_factors=[0, 1], range_warp_specializes=[None, None]) +helion.Config(block_sizes=[16, 16, 16], indexing='tensor_descriptor', l2_groupings=[32], load_eviction_policies=['last', 'first'], loop_orders=[[0, 1]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None, False], range_multi_buffers=[None, None], range_num_stages=[0, 2], range_unroll_factors=[0, 2], range_warp_specializes=[None, False]) +helion.Config(block_sizes=[16, 32, 64], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[1, 0]], num_stages=5, num_warps=16, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, False], range_num_stages=[0, 3], range_unroll_factors=[0, 3], range_warp_specializes=[None, None]) +helion.Config(block_sizes=[16, 32, 16], indexing='pointer', l2_groupings=[2], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=8, num_warps=16, pid_type='persistent_interleaved', range_flattens=[False, None], range_multi_buffers=[False, None], range_num_stages=[3, 3], range_unroll_factors=[2, 3], range_warp_specializes=[False, True]) 
+helion.Config(block_sizes=[256, 16, 16], indexing='pointer', l2_groupings=[2], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=5, num_warps=32, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, False], range_num_stages=[0, 1], range_unroll_factors=[0, 2], range_warp_specializes=[None, True]) +helion.Config(block_sizes=[16, 64, 16], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=3, num_warps=32, pid_type='persistent_interleaved', range_flattens=[True, False], range_multi_buffers=[False, None], range_num_stages=[3, 0], range_unroll_factors=[3, 4], range_warp_specializes=[False, True]) --- assertExpectedJournal(TestAutotuner.test_config_fragment1) -helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing='pointer', l2_groupings=[1], loop_orders=[[0, 1, 2]], num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[2, 128, 128], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[True]) -helion.Config(block_sizes=[2, 16, 4], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[8], loop_orders=[[0, 2, 1]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[8, 2, 512], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[32], loop_orders=[[0, 2, 1]], num_stages=7, num_warps=32, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[8, 4, 256], flatten_loops=[True], indexing='pointer', l2_groupings=[2], loop_orders=[[2, 0, 1]], num_stages=7, num_warps=8, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[True]) -helion.Config(block_sizes=[2, 2, 32], flatten_loops=[True], indexing='block_ptr', l2_groupings=[1], loop_orders=[[2, 0, 1]], num_stages=3, num_warps=16, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[1, 32, 32], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[64], loop_orders=[[1, 2, 0]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[8, 64, 2], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[8], loop_orders=[[0, 1, 2]], num_stages=7, num_warps=1, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[2], range_warp_specializes=[False]) -helion.Config(block_sizes=[1, 256, 16], flatten_loops=[True], indexing='pointer', l2_groupings=[64], loop_orders=[[1, 0, 2]], num_stages=5, num_warps=32, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[2], range_warp_specializes=[True]) -helion.Config(block_sizes=[4, 64, 256], 
flatten_loops=[False], indexing='block_ptr', l2_groupings=[8], loop_orders=[[1, 2, 0]], num_stages=5, num_warps=2, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[True], range_unroll_factors=[1], range_warp_specializes=[True]) +helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) +helion.Config(block_sizes=[1, 64, 64], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[True]) +helion.Config(block_sizes=[2, 8, 512], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['first', 'first'], loop_orders=[[2, 0, 1]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) +helion.Config(block_sizes=[1, 512, 1], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[1], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=5, num_warps=2, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[2], range_warp_specializes=[True]) +helion.Config(block_sizes=[1, 4, 256], flatten_loops=[True], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0, 2]], num_stages=2, num_warps=32, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[1], range_warp_specializes=[True]) +helion.Config(block_sizes=[1, 128, 16], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1, 2]], num_stages=1, num_warps=1, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[4], range_warp_specializes=[None]) +helion.Config(block_sizes=[8, 32, 256], flatten_loops=[False], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', 'last'], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=8, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[None]) +helion.Config(block_sizes=[2, 64, 32], flatten_loops=[False], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[1, 2, 0]], num_stages=5, num_warps=16, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) +helion.Config(block_sizes=[4, 32, 1], flatten_loops=[True], indexing='pointer', l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[3], range_warp_specializes=[True]) +helion.Config(block_sizes=[4, 2, 128], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[2], load_eviction_policies=['', 'first'], loop_orders=[[1, 2, 0]], num_stages=2, num_warps=4, pid_type='persistent_blocked', 
range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[1], range_warp_specializes=[False]) --- assertExpectedJournal(TestAutotuner.test_save_load_config) { diff --git a/test/test_eviction_policy.expected b/test/test_eviction_policy.expected index 6321638d5..9709330f8 100644 --- a/test/test_eviction_policy.expected +++ b/test/test_eviction_policy.expected @@ -1,6 +1,56 @@ This file is automatically generated by assertExpectedJournal calls in test_eviction_policy.py. Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. +--- assertExpectedJournal(TestEvictionPolicy.test_eviction_policy_in_generated_code) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_kernel_with_eviction(x, y, out, x_size_0, out_stride_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + val_x = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + val_y = tl.load(y + indices_0 * y_stride_0, mask_0, other=0, eviction_policy='evict_last') + v_0 = val_x + val_y + tl.store(out + indices_0 * out_stride_0, v_0, mask_0) + +def kernel_with_eviction(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + out = torch.empty_like(x) + _BLOCK_SIZE_0 = 16 + _launcher(_helion_kernel_with_eviction, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, out, x.size(0), out.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestEvictionPolicy.test_explicit_eviction_policy_overrides_tunable) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_kernel_with_override(x, y, out, x_size_0, out_stride_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + val_x = tl.load(x + indices_0 * x_stride_0, mask_0, other=0, eviction_policy='evict_last') + val_y = tl.load(y + indices_0 * y_stride_0, mask_0, other=0, eviction_policy='evict_first') + v_0 = val_x + val_y + tl.store(out + indices_0 * out_stride_0, v_0, mask_0) + +def kernel_with_override(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + out = torch.empty_like(x) + _BLOCK_SIZE_0 = 16 + _launcher(_helion_kernel_with_override, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, out, x.size(0), out.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return out + --- assertExpectedJournal(TestEvictionPolicy.test_hl_load_eviction_policy_emitted) from __future__ import annotations @@ -67,3 +117,30 @@ def copy_with_eviction(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 16 _launcher(_helion_copy_with_eviction, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) return out + +--- assertExpectedJournal(TestEvictionPolicy.test_multiple_loads_different_policies) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + 
+@triton.jit +def _helion_kernel_multiple_loads(x, y, z, out, x_size_0, out_stride_0, x_stride_0, y_stride_0, z_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + val_x = tl.load(x + indices_0 * x_stride_0, mask_0, other=0, eviction_policy='evict_first') + val_y = tl.load(y + indices_0 * y_stride_0, mask_0, other=0, eviction_policy='evict_last') + val_z = tl.load(z + indices_0 * z_stride_0, mask_0, other=0) + v_0 = val_x + val_y + v_1 = v_0 + val_z + tl.store(out + indices_0 * out_stride_0, v_1, mask_0) + +def kernel_multiple_loads(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *, _launcher=_default_launcher): + out = torch.empty_like(x) + _BLOCK_SIZE_0 = 16 + _launcher(_helion_kernel_multiple_loads, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, z, out, x.size(0), out.stride(0), x.stride(0), y.stride(0), z.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return out diff --git a/test/test_eviction_policy.py b/test/test_eviction_policy.py index b6dc347a6..deaca9fe5 100644 --- a/test/test_eviction_policy.py +++ b/test/test_eviction_policy.py @@ -12,6 +12,7 @@ from helion._testing import RefEagerTestBase from helion._testing import TestCase from helion._testing import code_and_output +from helion._testing import skipIfRefEager import helion.language as hl @@ -38,6 +39,117 @@ def copy_with_eviction(x: torch.Tensor) -> torch.Tensor: self.assertIn("eviction_policy", code) self.assertIn("evict_last", code) + @skipIfRefEager("Config spec inspection not applicable in ref eager mode") + def test_autotune_eviction_policy_registered(self): + """Test that eviction policy tunable is automatically registered for loads in device loops.""" + + @helion.kernel + def kernel_with_loads(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + for tile in hl.tile(x.size(0)): + val_x = hl.load(x, [tile]) + val_y = hl.load(y, [tile]) + out[tile] = val_x + val_y + return out + + x = torch.randn([128], device=DEVICE, dtype=torch.float32) + y = torch.randn([128], device=DEVICE, dtype=torch.float32) + + bound_kernel = kernel_with_loads.bind((x, y)) + config_spec = bound_kernel.config_spec + + from helion.autotuner import EnumFragment + from helion.autotuner import ListOf + + fragment = config_spec.load_eviction_policies + self.assertIsInstance(fragment, ListOf) + self.assertEqual(fragment.length, 2) + self.assertIsInstance(fragment.inner, EnumFragment) + self.assertIn("", fragment.inner.choices) + self.assertIn("first", fragment.inner.choices) + self.assertIn("last", fragment.inner.choices) + + def test_eviction_policy_in_generated_code(self): + """Test that eviction policies appear in generated code when configured.""" + + @helion.kernel( + config={ + "block_size": 16, + "load_eviction_policies": ["", "last"], + } + ) + def kernel_with_eviction(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + for tile in hl.tile(x.size(0)): + val_x = hl.load(x, [tile]) # No eviction policy + val_y = hl.load(y, [tile]) # Should get evict_last + out[tile] = val_x + val_y + return out + + x = torch.randn([128], device=DEVICE, dtype=torch.float32) + y = torch.randn([128], device=DEVICE, dtype=torch.float32) + + code, result = code_and_output(kernel_with_eviction, (x, y)) + torch.testing.assert_close(result, x + y) + + # Check that evict_last appears in the generated code + self.assertIn("evict_last", code) + 
self.assertExpectedJournal(code) + + def test_explicit_eviction_policy_overrides_tunable(self): + @helion.kernel( + config={ + "block_size": 16, + "load_eviction_policies": ["first", "first"], + } + ) + def kernel_with_override(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + for tile in hl.tile(x.size(0)): + # Explicit eviction_policy should override tunable + val_x = hl.load(x, [tile], eviction_policy="evict_last") + val_y = hl.load(y, [tile]) + out[tile] = val_x + val_y + return out + + x = torch.randn([128], device=DEVICE, dtype=torch.float32) + y = torch.randn([128], device=DEVICE, dtype=torch.float32) + + code, result = code_and_output(kernel_with_override, (x, y)) + torch.testing.assert_close(result, x + y) + + self.assertIn("evict_last", code) + self.assertExpectedJournal(code) + + def test_multiple_loads_different_policies(self): + @helion.kernel( + config={ + "block_size": 16, + "load_eviction_policies": ["first", "last", ""], + } + ) + def kernel_multiple_loads( + x: torch.Tensor, y: torch.Tensor, z: torch.Tensor + ) -> torch.Tensor: + out = torch.empty_like(x) + for tile in hl.tile(x.size(0)): + val_x = hl.load(x, [tile]) # evict_first + val_y = hl.load(y, [tile]) # evict_last + val_z = hl.load(z, [tile]) # None + out[tile] = val_x + val_y + val_z + return out + + x = torch.randn([128], device=DEVICE, dtype=torch.float32) + y = torch.randn([128], device=DEVICE, dtype=torch.float32) + z = torch.randn([128], device=DEVICE, dtype=torch.float32) + + code, result = code_and_output(kernel_multiple_loads, (x, y, z)) + torch.testing.assert_close(result, x + y + z) + + self.assertIn("evict_first", code) + self.assertIn("evict_last", code) + self.assertExpectedJournal(code) + instantiate_parametrized_tests(TestEvictionPolicy) diff --git a/test/test_register_tunable.expected b/test/test_register_tunable.expected index da25192e5..ddcd01167 100644 --- a/test/test_register_tunable.expected +++ b/test/test_register_tunable.expected @@ -2,7 +2,7 @@ This file is automatically generated by assertExpectedJournal calls in test_regi Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. --- assertExpectedJournal(TestRegisterTunable.test_integer_fragment) -helion.Config(block_sizes=[128], indexing='pointer', multiplier=3, num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]) +helion.Config(block_sizes=[128], indexing='pointer', load_eviction_policies=[''], multiplier=3, num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]) --- assertExpectedJournal(TestRegisterTunable.test_integer_fragment) from __future__ import annotations
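
Usage sketch, for reference: the snippet below mirrors `TestEvictionPolicy.test_eviction_policy_in_generated_code` from this patch; the kernel name `vec_add`, the input sizes, and the `device="cuda"` choice are illustrative, and an installed helion plus a GPU device are assumed. Per the patch, entries in `load_eviction_policies` are applied to `hl.load` calls in device code in load order, with `""` meaning no hint, `"first"` lowering to Triton's `eviction_policy='evict_first'`, and `"last"` to `'evict_last'`.

# Illustrative sketch (not part of the patch): pins one eviction policy per
# device load instead of letting the autotuner search over them.
import torch

import helion
import helion.language as hl


@helion.kernel(
    config={
        "block_size": 16,
        # One entry per hl.load in device code, in load order:
        # "" -> no hint, "first" -> evict_first, "last" -> evict_last.
        "load_eviction_policies": ["first", "last"],
    }
)
def vec_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        val_x = hl.load(x, [tile])  # lowered with eviction_policy='evict_first'
        val_y = hl.load(y, [tile])  # lowered with eviction_policy='evict_last'
        out[tile] = val_x + val_y
    return out


x = torch.randn(128, device="cuda")
y = torch.randn(128, device="cuda")
torch.testing.assert_close(vec_add(x, y), x + y)

When `load_eviction_policies` is omitted from the config, `ConfigSpec.normalize` fills in the all-default list and the autotuner explores the registered `ListOf(EnumFragment(...))` space, one enum value per counted load; a load that passes an explicit `eviction_policy=` argument is skipped by `_count_device_loads` and always keeps its user-specified policy over the tuned one.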