Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions examples/split_k_barrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

import torch

import helion
from helion._testing import DEVICE
from helion._testing import run_example
from helion.autotuner import PowerOfTwoFragment
import helion.language as hl


@helion.kernel(static_shapes=True, dot_precision="ieee")
def split_k_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Two-stage split-K matmul using hl.barrier(). The barrier approach
    gives deterministic results as opposed to the atomic_add approach.

    Stage 1:
      - Split K into contiguous chunks of `block_k` elements (at most `split_k` chunks).
      - Each chunk computes a partial [tile_m, tile_n] product into its own slice of `tmp`.

    Barrier:
      - Grid-wide barrier to ensure all partials are written before reduction.

    Stage 2:
      - Reduce partials across the split dimension and write `out`.

    Shapes:
      a:   [M, K]
      b:   [K, N]
      tmp: [M, N, split_k]
      out: [M, N]

    Notes:
      - Static shapes keep codegen simpler.
      - `split_k` is registered as a tunable, so the autotuner may choose it.
      - `tmp` is zero-initialized: `block_k` is rounded up to a power of two, so
        stage 1 may write fewer than `split_k` slices. The slices it never
        touches must contribute 0 to the stage 2 sum.
    """
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, f"size mismatch {k} != {k2}"
    split_k = hl.register_tunable("split_k", PowerOfTwoFragment(16, 512, 64))
    block_k = helion.next_power_of_2(helion.cdiv(k, split_k))
    # Stage 1 writes cdiv(k, block_k) <= split_k slices; zeros (not empty) keeps
    # any unwritten slice from adding garbage to the reduction below.
    tmp = torch.zeros((m, n, split_k), device=a.device, dtype=a.dtype)
    out = torch.empty((m, n), device=a.device, dtype=a.dtype)

    for tile_m, tile_n, tile_k_outer in hl.tile(
        [m, n, k], block_size=[None, None, block_k]
    ):
        acc = hl.zeros([tile_m, tile_n], device=a.device, dtype=a.dtype)
        for tile_k_inner in hl.tile(tile_k_outer.begin, tile_k_outer.end):
            acc = torch.addmm(acc, a[tile_m, tile_k_inner], b[tile_k_inner, tile_n])
        # this could be a hl.atomic_add to avoid the barrier, but that would be non-deterministic
        tmp[tile_m, tile_n, tile_k_outer.id] = acc

    hl.barrier()

    for tile_m, tile_n in hl.tile([m, n]):
        out[tile_m, tile_n] = torch.sum(tmp[tile_m, tile_n, :], dim=-1)

    return out


def check(m: int, k: int, n: int) -> None:
    """Compare split_k_matmul with torch.matmul on random [m, k] @ [k, n] inputs."""
    lhs = torch.randn(m, k, device=DEVICE)
    # Generated as [n, k] and transposed, so the kernel sees a [k, n] transposed view.
    rhs = torch.randn(n, k, device=DEVICE).T

    # Loose absolute tolerance: rounding error accumulates over the long K reduction.
    tolerance = 5e-1
    run_example(split_k_matmul, torch.matmul, args=(lhs, rhs), atol=tolerance)


def main() -> None:
    """Seed torch's RNG, then run a single check with small M/N and a long K."""
    torch.manual_seed(0)
    rows, inner, cols = 16, 4096, 16
    check(rows, inner, cols)

if __name__ == "__main__":
main()
10 changes: 4 additions & 6 deletions helion/_compiler/compile_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from .. import exc
from ..language.constexpr import ConstExpr
from .loop_dependency_checker import LoopDependencyChecker
from .source_location import SourceLocation
from .source_location import current_location
from .variable_origin import BlockSizeOrigin
Expand Down Expand Up @@ -102,18 +101,18 @@ def __init__(
self.block_sizes: list[BlockSizeInfo] = []
self.debug_shape_renames: dict[sympy.Expr, sympy.Expr] = {}
self.config_spec = ConfigSpec()
if settings.autotune_force_persistent:
for pid_type in ("flat", "xyz"):
self.config_spec.disallow_pid_type(pid_type)
self.kernel_tensor_sizes: dict[tuple[sympy.Expr, ...], int] = (
collections.Counter()
)
self.specialized_vars: set[sympy.Symbol] = set()
self.loop_dependency_checker = LoopDependencyChecker()
self._symint_cache: dict[object, torch.SymInt] = {}
self.device_load_count = (
0 # Track number of loads in all device code for eviction policy tuning
)
if settings.autotune_force_persistent:
for pid_type in ("flat", "xyz"):
self.config_spec.disallow_pid_type(pid_type)
self.has_barrier: bool = False

def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
from .device_function import contains_only_block_size_symbols
Expand Down Expand Up @@ -405,7 +404,6 @@ def __enter__(self) -> Self:
assert getattr(tls, "env", None) is None, "CompileEnvironment already active"
self.fake_mode.__enter__()
tls.env = self
self.loop_dependency_checker = LoopDependencyChecker()
return self

def __exit__(
Expand Down
12 changes: 11 additions & 1 deletion helion/_compiler/device_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,12 @@ def __init__(self, val: int) -> None:


class DeviceFunction:
def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
def __init__(
self,
name: str,
config: Config,
codegen: GenerateAST,
) -> None:
super().__init__()
self.name = name
self.config = config
Expand Down Expand Up @@ -659,6 +664,11 @@ def codegen_function_call(self) -> ast.AST:
[
f"num_warps={num_warps}",
f"num_stages={self.config.num_stages}",
*(
["launch_cooperative_grid=True"]
if CompileEnvironment.current().has_barrier
else []
),
]
+ [
f"{x.removeprefix('_triton_config_')}={self.config[x]}"
Expand Down
69 changes: 69 additions & 0 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from .inductor_lowering import CodegenState
from .inductor_lowering import codegen_call_with_graph
from .inductor_lowering import prepare_graph_lowerings
from .loop_dependency_checker import LoopDependencyChecker
from .matmul_utils import tensor_matmul_replacement
from .matmul_utils import torch_matmul_replacement
from .node_masking import remove_unnecessary_masking
Expand Down Expand Up @@ -189,6 +190,8 @@ def codegen(self, state: CodegenState) -> list[object]:


class RootGraphInfo(GraphInfo):
phase_index: int = 0

@property
def name(self) -> str:
return f"root_graph_{self.graph_id}"
Expand Down Expand Up @@ -376,12 +379,22 @@ class RolledReductionInfo(NamedTuple):
can_be_rolled_by_caller: bool


# One group of top-level device loops with no barrier between them.
# WalkHostAST closes a phase each time it sees a barrier expression and once
# more in flush_phases() for the trailing group.
@dataclasses.dataclass
class KernelPhase:
# Positions into DeviceIR.root_ids (not graph ids) of the loops in this phase.
roots: list[int] # store root indices
# The host-AST `for` nodes for those roots, in the same order as `roots`.
root_nodes: list[ast.For]
# Each phase gets its own checker instance (default_factory), so reads/writes
# recorded for one phase are not shared with another phase's checker.
loop_dependency_checker: LoopDependencyChecker = dataclasses.field(
default_factory=LoopDependencyChecker
)


class DeviceIR:
# Start with empty containers; they are filled in while the host function is lowered.
def __init__(self) -> None:
super().__init__()
# All graphs; phase_for_root() indexes this list with the ids stored in root_ids.
self.graphs: list[GraphInfo] = []
# Ids returned by add_graph() for root graphs, appended by add_root_graph().
self.root_ids: list[int] = []
# NOTE(review): presumably populated by build_rolled_reductions() — confirm.
self.rolled_reductions: list[RolledReductionInfo] = []
# Barrier-delimited groups of root loops; assigned from WalkHostAST.phases
# in lower_to_device_ir().
self.phases: list[KernelPhase] = []
# One list of block ids per root loop, appended by WalkHostAST.visit_For.
self.grid_block_ids: list[list[int]] = []

def get_root(self, config: Config, graph_id: int) -> torch.fx.Graph:
Expand Down Expand Up @@ -435,6 +448,11 @@ def add_reduction_loop_graph(
def add_root_graph(self, graph: torch.fx.Graph) -> None:
self.root_ids.append(self.add_graph(graph, graph_info_cls=RootGraphInfo))

def phase_for_root(self, root_id: int) -> int:
    """Return the phase index stored on the root graph at position ``root_id``.

    ``root_id`` is a position in ``self.root_ids``, not a graph id; the graph
    id found there is used to look up the graph info.
    """
    graph_id = self.root_ids[root_id]
    root_info = self.graphs[graph_id]
    assert isinstance(root_info, RootGraphInfo)
    return root_info.phase_index

def build_rolled_reductions(self) -> None:
env = CompileEnvironment.current()
rdims = [bs for bs in env.block_sizes if bs.reduction]
Expand Down Expand Up @@ -1274,6 +1292,10 @@ class WalkHostAST(NodeVisitor):
# Set up the bookkeeping used to split top-level device loops into phases.
def __init__(self, device_ir: DeviceIR) -> None:
super().__init__()
self.device_ir = device_ir
# Number of root loops recorded so far (visit_For increments it per root loop).
self.root_index = 0
# Root positions collected since the last barrier (or since the start).
self.current_phase_roots: list[int] = []
# Phases already closed by a barrier or by flush_phases().
self.phases: list[KernelPhase] = []
# Root `for` nodes in the order visit_For recorded them; indexed by root position.
self.root_nodes: list[ast.For] = []

def visit_For(self, node: ast.For) -> None:
assert isinstance(node, ExtendedAST)
Expand All @@ -1292,9 +1314,44 @@ def visit_For(self, node: ast.For) -> None:
# pyrefly: ignore [missing-attribute]
block_ids = [inner.block_id]
self.device_ir.grid_block_ids.append(block_ids)
# store root index (position) not graph id
self.root_nodes.append(node)
self.current_phase_roots.append(len(self.device_ir.root_ids) - 1)
self.root_index += 1
else:
self.generic_visit(node)

def visit_Expr(self, node: ast.Expr) -> None:
    """Close the current phase when an expression statement is a barrier.

    Non-barrier expressions are traversed normally. For a barrier, the roots
    collected since the previous barrier become one ``KernelPhase`` and
    collection restarts with an empty list.

    Raises:
        exc.BarrierOnlyAllowedAtTopLevel: if no root loop has been recorded
            since the start or since the previous barrier.
    """
    from .type_propagation import BarrierResultType

    assert isinstance(node, ExtendedAST)
    expr_value = node.value
    assert isinstance(expr_value, ExtendedAST)

    # Anything that is not a barrier call is just walked.
    if not isinstance(expr_value._type_info, BarrierResultType):
        self.generic_visit(node)
        return

    # A barrier must have at least one root loop in front of it.
    if self.root_index == 0 or not self.current_phase_roots:
        raise exc.BarrierOnlyAllowedAtTopLevel

    closed_roots = self.current_phase_roots
    closed_nodes = [self.root_nodes[idx] for idx in closed_roots]
    self.phases.append(KernelPhase(roots=closed_roots, root_nodes=closed_nodes))
    self.current_phase_roots = []

def flush_phases(self) -> None:
    """Turn any roots collected since the last barrier into a final phase.

    Does nothing when no roots are pending.
    """
    pending = self.current_phase_roots
    if not pending:
        return

    pending_nodes = [self.root_nodes[idx] for idx in pending]
    self.phases.append(KernelPhase(roots=pending, root_nodes=pending_nodes))
    self.current_phase_roots = []


def _count_device_loads_and_stores(device_ir: DeviceIR) -> tuple[int, int, int]:
"""Count the number of load and store operations in device code for autotuning.
Expand Down Expand Up @@ -1386,6 +1443,18 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
visitor = WalkHostAST(device_ir)
for stmt in func.body:
visitor.visit(stmt)
visitor.flush_phases()
device_ir.phases = visitor.phases
# Run dependency checks once, per phase, so codegen does not redo it per-config.
for phase in device_ir.phases:
checker = phase.loop_dependency_checker
for loop_node in phase.root_nodes:
checker.register_loop(loop_node)
for phase_idx, phase in enumerate(device_ir.phases):
for ridx in phase.roots:
graph_info = device_ir.graphs[device_ir.root_ids[ridx]]
assert isinstance(graph_info, RootGraphInfo)
graph_info.phase_index = phase_idx
# If there are no top-level device loops, we cannot generate a valid kernel.
# Raise a friendly error instead of emitting an empty Triton function body.
if len(device_ir.root_ids) == 0:
Expand Down
28 changes: 25 additions & 3 deletions helion/_compiler/generate_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .. import exc
from ..language._decorators import is_api_func
from ..runtime.config import Config
from .ast_extension import ExtendedAST
from .ast_extension import LoopType
from .ast_extension import NodeVisitor
Expand All @@ -24,6 +25,7 @@
from .helper_function import CodegenInterface
from .inductor_lowering import CodegenState
from .inductor_lowering import codegen_call_with_graph
from .loop_dependency_checker import LoopDependencyChecker
from .program_id import ForEachProgramID
from .tile_strategy import DeviceLoopState
from .variable_origin import ArgumentOrigin
Expand All @@ -35,6 +37,7 @@

from ..runtime import Config
from .host_function import HostFunction
from .loop_dependency_checker import LoopDependencyChecker
from .tile_strategy import DeviceLoopOrGridState
from .type_propagation import TensorType

Expand All @@ -55,7 +58,11 @@ def __init__(self, func: HostFunction, config: Config) -> None:
self.next_else_block: list[ast.AST] | None = None

# Now create device function and initialize CodegenInterface
self.device_function = DeviceFunction(f"_helion_{func.name}", config, self)
self.device_function = DeviceFunction(
f"_helion_{func.name}",
config,
self,
)
CodegenInterface.__init__(self, self.device_function)

def offset_var(self, block_idx: int) -> str:
Expand All @@ -69,6 +76,10 @@ def mask_var(self, block_idx: int) -> str | None:
return loops[-1].strategy.mask_var(block_idx)
return None

def _phase_checker(self, root_id: int) -> LoopDependencyChecker:
    """Return the loop dependency checker of the phase that contains ``root_id``."""
    device_ir = self.host_function.device_ir
    phase = device_ir.phases[device_ir.phase_for_root(root_id)]
    return phase.loop_dependency_checker

def add_statement(self, stmt: ast.AST | str | None) -> None:
if stmt is None:
return
Expand Down Expand Up @@ -226,17 +237,20 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
if node._loop_type == LoopType.GRID:
assert not node.orelse

assert node._root_id is not None
# Loop dependency checks were already run during lowering; phase checker kept for symmetry/debug.
self._phase_checker(node._root_id)

if len(self.host_function.device_ir.root_ids) == 1:
body = self.device_function.body
else:
assert len(self.host_function.device_ir.root_ids) > 1
assert node._root_id is not None
# Multiple top level for loops

if node._root_id == 0:
self.device_function.set_pid(
ForEachProgramID(
self.device_function.new_var("pid_shared", dce=False)
self.device_function.new_var("pid_shared", dce=False),
)
)
self.device_function.body.extend(
Expand Down Expand Up @@ -309,6 +323,11 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
# This ensures block size and rdim vars are defined in the correct order
self.device_function.flush_deferred_rdim_defs(self)

if isinstance(self.device_function.pid, ForEachProgramID):
self.device_function.pid.case_phases.append(
self.host_function.device_ir.phase_for_root(node._root_id)
)

# If we are in a multi top level loop, for all loops except for the last one
# emit ifthenelse blocks
if node._root_id < len(self.host_function.device_ir.root_ids) - 1:
Expand Down Expand Up @@ -476,6 +495,9 @@ def generate_ast(
func: HostFunction, config: Config, emit_repro_caller: bool
) -> ast.AST:
with func:
if len(func.device_ir.phases) > 1:
if not str(config.pid_type).startswith("persistent"):
raise exc.BarrierRequiresPersistent(config.pid_type)
codegen = GenerateAST(func, config)
with codegen.device_function:
for stmt in func.body:
Expand Down
19 changes: 17 additions & 2 deletions helion/_compiler/loop_dependency_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,29 @@ class LoopDependencyChecker:
def __init__(self) -> None:
self.reads: set[str] = set()
self.writes: set[str] = set()

def register_loop(self, loop_node: ast.For) -> None:
self._barrier_after_root: set[int] = set()
self._root_counter: int = 0
self.disabled: bool = False

def insert_barrier_after_root(self, root_id: int) -> None:
    """Mark that a barrier sits between ``root_id`` and ``root_id + 1``.

    register_loop() consults this set and clears the recorded reads and
    writes when the loop it is registering directly follows a marked root.
    """
    self._barrier_after_root.update((root_id,))

def register_loop(self, loop_node: ast.For, root_id: int | None = None) -> None:
if self.disabled:
return
current_root = root_id if root_id is not None else self._root_counter
if (current_root - 1) in self._barrier_after_root:
self.reads.clear()
self.writes.clear()
self._barrier_after_root.discard(current_root - 1)
rw = ReadWrites.from_list(loop_node.body)

self._check_dependencies(rw)

self.reads |= set(rw.reads)
self.writes |= set(rw.writes)
self._root_counter = current_root + 1

def _check_dependencies(self, rw: ReadWrites) -> None:
"""
Expand Down
Loading
Loading