Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions helion/_compiler/compile_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
self.specialized_vars: set[sympy.Symbol] = set()
self.loop_dependency_checker = LoopDependencyChecker()
self._symint_cache: dict[object, torch.SymInt] = {}
self.device_load_count = (
0 # Track number of loads in all device code for eviction policy tuning
)

def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
from .device_function import contains_only_block_size_symbols
Expand Down
1 change: 1 addition & 0 deletions helion/_compiler/device_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config)

self.rng_seed_count = 0
self.device_load_index = 0 # Track which load in device code we're generating (for eviction policy tuning)
# Name of the RNG seed buffer parameter in kernel signature
self.rng_seed_buffer_param_name = None

Expand Down
54 changes: 54 additions & 0 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,55 @@ def visit_For(self, node: ast.For) -> None:
self.generic_visit(node)


def _count_device_loads(device_ir: DeviceIR) -> int:
    """Return how many device loads carry no explicit eviction policy.

    Every graph in ``device_ir.graphs`` is scanned, except graphs whose
    ``graph_id`` is the ``new_graph_id`` of a rolled reduction (the original
    author treats those as duplicates). A node is counted when it is a
    ``call_function`` targeting ``memory_ops.load`` and its eviction policy,
    whether given by keyword or as the fourth positional argument, is ``None``.
    Loads with a user-supplied policy are left out so the user's choice
    overrides tuning.
    """
    from ..language import memory_ops

    # Graph IDs produced by rolling reductions; skipped below.
    skipped_graph_ids = {
        rolled.new_graph_id
        for rolled in device_ir.rolled_reductions
        if rolled.new_graph_id is not None
    }

    tunable_loads = 0
    for info in device_ir.graphs:
        if info.graph_id in skipped_graph_ids:
            continue

        for fx_node in info.graph.nodes:
            is_load = (
                fx_node.op == "call_function" and fx_node.target is memory_ops.load
            )
            if not is_load:
                continue

            # Keyword form wins; fall back to positional slot 3 when the
            # keyword is absent (or None) and enough positional args exist.
            policy = fx_node.kwargs.get("eviction_policy")
            if policy is None and len(fx_node.args) >= 4:
                policy = fx_node.args[3]

            if policy is None:
                tunable_loads += 1
    return tunable_loads


def _register_eviction_policy_tunable(load_count: int) -> None:
    """Expose one eviction-policy choice per device load to the autotuner.

    Does nothing when ``load_count`` is zero. Otherwise it stores, on the
    current compile environment, a ``ListOf`` fragment with ``load_count``
    enum entries drawn from ``VALID_EVICTION_POLICIES`` and records
    ``load_count`` as ``device_load_count``.
    """
    if load_count == 0:
        return

    from ..autotuner.config_fragment import EnumFragment
    from ..autotuner.config_fragment import ListOf
    from ..autotuner.config_spec import VALID_EVICTION_POLICIES

    env = CompileEnvironment.current()
    # One independently tunable enum slot for each counted load.
    env.config_spec.load_eviction_policies = ListOf(
        EnumFragment(choices=VALID_EVICTION_POLICIES),
        length=load_count,
    )
    env.device_load_count = load_count


def lower_to_device_ir(func: HostFunction) -> DeviceIR:
device_ir = DeviceIR()
with func, device_ir, compile_lock:
Expand All @@ -1085,6 +1134,11 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
if len(device_ir.root_ids) > 1:
# xyz not supported with shared program IDs, but persistent kernels are allowed
CompileEnvironment.current().config_spec.disallow_pid_type("xyz")

# Count all device loads and register eviction policy tunable
load_count = _count_device_loads(device_ir)
_register_eviction_policy_tunable(load_count)

return device_ir


Expand Down
1 change: 1 addition & 0 deletions helion/autotuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .config_fragment import BooleanFragment as BooleanFragment
from .config_fragment import EnumFragment as EnumFragment
from .config_fragment import IntegerFragment as IntegerFragment
from .config_fragment import ListOf as ListOf
from .config_fragment import PowerOfTwoFragment as PowerOfTwoFragment
from .config_spec import ConfigSpec as ConfigSpec
from .differential_evolution import (
Expand Down
46 changes: 46 additions & 0 deletions helion/autotuner/config_fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,49 @@ def category(self) -> Category:
class NumWarpsFragment(PowerOfTwoFragment):
    """Power-of-two fragment that reports itself under the NUM_WARPS category."""

    def category(self) -> Category:
        """Return ``Category.NUM_WARPS`` for this fragment."""
        return Category.NUM_WARPS


@dataclasses.dataclass
class ListOf(ConfigSpecFragment):
    """Fragment describing a fixed-length list of independently tuned values.

    Every position of the list is governed by the same ``inner`` fragment and
    each operation below is applied position by position.

    Example:
        ``ListOf(EnumFragment(choices=("a", "b", "c")), length=5)`` describes a
        list of 5 enum values, each tunable on its own.
    """

    # Fragment defining the search space of a single list element.
    inner: ConfigSpecFragment
    # Number of elements in every list this fragment produces or accepts.
    length: int

    def default(self) -> list[object]:
        """Return ``length`` values, each obtained from ``inner.default()``."""
        return [self.inner.default() for _slot in range(self.length)]

    def random(self) -> list[object]:
        """Return ``length`` values, each obtained from ``inner.random()``."""
        return [self.inner.random() for _slot in range(self.length)]

    def pattern_neighbors(self, current: object) -> list[object]:
        """Return every list that differs from ``current`` in one position.

        For each position, the inner fragment's neighbors of that element are
        substituted in turn, leaving all other positions untouched.

        Raises:
            ValueError: if ``current`` is not a list of exactly ``length`` items.
        """
        if not (isinstance(current, list) and len(current) == self.length):
            raise ValueError(f"Expected list of length {self.length}, got {current!r}")

        # Build a fresh list per neighbor so ``current`` is never mutated.
        return [
            [*current[:index], replacement, *current[index + 1 :]]
            for index, element in enumerate(current)
            for replacement in self.inner.pattern_neighbors(element)
        ]

    def differential_mutation(self, a: object, b: object, c: object) -> list[object]:
        """Apply the inner fragment's differential mutation position by position."""
        assert isinstance(a, list) and len(a) == self.length
        assert isinstance(b, list) and len(b) == self.length
        assert isinstance(c, list) and len(c) == self.length

        # The asserts above pin all three lists to ``length`` items, so zip
        # visits exactly the positions 0..length-1.
        return [
            self.inner.differential_mutation(item_a, item_b, item_c)
            for item_a, item_b, item_c in zip(a, b, c)
        ]
27 changes: 21 additions & 6 deletions helion/autotuner/config_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .config_fragment import ConfigSpecFragment
from .config_fragment import EnumFragment
from .config_fragment import IntegerFragment
from .config_fragment import ListOf
from .config_fragment import NumWarpsFragment
from .config_fragment import PermutationFragment
from .config_fragment import PowerOfTwoFragment
Expand Down Expand Up @@ -50,9 +51,11 @@
"num_stages",
"pid_type",
"indexing",
"load_eviction_policies",
]
)
# Accepted values for the ``pid_type`` config entry; also the default set of
# allowed pid types on a ConfigSpec.
VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")
# Short names a config may use for a load's eviction policy; used as the enum
# choices for each entry of ``load_eviction_policies``.
VALID_EVICTION_POLICIES = ("", "first", "last")


@dataclasses.dataclass
Expand Down Expand Up @@ -97,6 +100,11 @@ class ConfigSpec:
default_factory=functools.partial(tuple, VALID_PID_TYPES)
)
grid_block_ids: list[int] = dataclasses.field(default_factory=list)
load_eviction_policies: ListOf = dataclasses.field(
default_factory=lambda: ListOf(
EnumFragment(choices=VALID_EVICTION_POLICIES), length=0
)
)

@staticmethod
def _valid_indexing_types() -> tuple[IndexingLiteral, ...]:
Expand Down Expand Up @@ -206,12 +214,16 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
"range_multi_buffers",
"range_flattens",
"static_ranges",
"load_eviction_policies",
):
if not config[name]:
config.pop(name)
if not config.get(name):
config.pop(name, None)

config.setdefault("num_warps", DEFAULT_NUM_WARPS)
config.setdefault("num_stages", DEFAULT_NUM_STAGES)
config.setdefault(
"load_eviction_policies", self.load_eviction_policies.default()
)
# TODO(jansel): include num_ctas and max_nreg

for name, values in (
Expand Down Expand Up @@ -266,10 +278,12 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
"indexing": fn(EnumFragment(self._valid_indexing_types())),
"pid_type": fn(EnumFragment(self.allowed_pid_types)),
"load_eviction_policies": fn(self.load_eviction_policies),
}
# Add tunable parameters
for key, fragment in self.user_defined_tunables.items():
config[key] = fn(fragment)
config.update(
{key: fn(fragment) for key, fragment in self.user_defined_tunables.items()}
)

for name in (
"loop_orders",
Expand All @@ -282,9 +296,10 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
"range_multi_buffers",
"range_flattens",
"static_ranges",
"load_eviction_policies",
):
if not config[name]:
config.pop(name)
if not config.get(name):
config.pop(name, None)
self.normalize(config)
return helion.Config(**config)

Expand Down
17 changes: 17 additions & 0 deletions helion/language/memory_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@

__all__ = ["load", "store"]

# Map short config names to full Triton API names for eviction policies.
# The empty string maps to None, which the load codegen treats the same as
# "no eviction policy given".
_EVICTION_POLICY_MAP = {
    "": None,
    "first": "evict_first",
    "last": "evict_last",
}


@has_side_effect
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
Expand Down Expand Up @@ -242,6 +249,16 @@ def _(state: CodegenState) -> ast.AST:
extra_mask = state.ast_args[2]
assert isinstance(extra_mask, (type(None), ast.AST))
eviction_policy = state.ast_args[3] if len(state.ast_args) > 3 else None

# If no explicit eviction_policy and we're in device code, use tunable
if eviction_policy is None and state.codegen.on_device:
policies = state.config.load_eviction_policies
idx = state.device_function.device_load_index
if idx < len(policies):
policy_value = policies[idx]
eviction_policy = _EVICTION_POLICY_MAP.get(policy_value, policy_value)
state.device_function.device_load_index += 1

if eviction_policy is not None:
assert isinstance(eviction_policy, str)
eviction_policy = ast.Constant(value=eviction_policy)
Expand Down
7 changes: 7 additions & 0 deletions helion/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
range_multi_buffers: list[bool | None] | None = None,
range_flattens: list[bool | None] | None = None,
static_ranges: list[bool] | None = None,
load_eviction_policies: list[str] | None = None,
num_warps: int | None = None,
num_stages: int | None = None,
pid_type: PidTypeLiteral | None = None,
Expand All @@ -55,6 +56,7 @@ def __init__(
range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
range_flattens: Controls flatten parameter for tl.range calls.
static_ranges: Whether to use tl.static_range instead of tl.range.
load_eviction_policies: Eviction policies for load operations ("", "first", "last").
num_warps: Number of warps per block.
num_stages: Number of stages for software pipelining.
pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
Expand All @@ -74,6 +76,7 @@ def __init__(
"range_multi_buffers": range_multi_buffers,
"range_flattens": range_flattens,
"static_ranges": static_ranges,
"load_eviction_policies": load_eviction_policies,
"num_warps": num_warps,
"num_stages": num_stages,
"indexing": indexing,
Expand Down Expand Up @@ -189,6 +192,10 @@ def range_flattens(self) -> list[bool | None]:
def static_ranges(self) -> list[bool]:
return cast("list[bool]", self.config.get("static_ranges", []))

@property
def load_eviction_policies(self) -> list[str]:
    """Eviction policies stored under ``"load_eviction_policies"``.

    Falls back to an empty list when the key is absent from ``self.config``.
    """
    stored = self.config.get("load_eviction_policies", [])
    return cast("list[str]", stored)

@property
def indexing(self) -> IndexingLiteral:
    """Indexing strategy stored under ``"indexing"``; ``"pointer"`` when unset."""
    # The dict lookup is untyped, hence the ignore on the return.
    return self.config.get("indexing", "pointer")  # type: ignore[return-value]
Expand Down
Loading
Loading