diff --git a/examples/split_k_barrier.py b/examples/split_k_barrier.py new file mode 100644 index 000000000..a554de6ce --- /dev/null +++ b/examples/split_k_barrier.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import torch + +import helion +from helion._testing import DEVICE +from helion._testing import run_example +from helion.autotuner import PowerOfTwoFragment +import helion.language as hl + + +@helion.kernel(static_shapes=True, dot_precision="ieee") +def split_k_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Two-stage split-K matmul using hl.barrier(). The barrier approach + gives deterministic results as opposed to the atomic_add approach. + + Stage 1: + - Split K into `split_k` contiguous chunks. + - Each chunk computes a partial [tile_m, tile_n] product into its own slice of `tmp`. + + Barrier: + - Grid-wide barrier to ensure all partials are written before reduction. + + Stage 2: + - Reduce partials across the split dimension and write `out`. + + Shapes: + a: [M, K] + b: [K, N] + tmp: [M, N, split_k] + out: [M, N] + + Notes: + - Static shapes keep codegen simpler. + - `split_k` is fixed for clarity; autotuning could choose it instead. + """ + m, k = a.shape + _, n = b.shape + split_k = hl.register_tunable("split_k", PowerOfTwoFragment(16, 512, 64)) + block_k = helion.next_power_of_2(helion.cdiv(k, split_k)) + tmp = torch.empty((m, n, split_k), device=a.device, dtype=a.dtype) + out = torch.empty((m, n), device=a.device, dtype=a.dtype) + + for tile_m, tile_n, tile_k_outer in hl.tile( + [m, n, k], block_size=[None, None, block_k] + ): + acc = hl.zeros([tile_m, tile_n], device=a.device, dtype=a.dtype) + for tile_k_inner in hl.tile(tile_k_outer.begin, tile_k_outer.end): + acc = torch.addmm(acc, a[tile_m, tile_k_inner], b[tile_k_inner, tile_n]) + # this could be an hl.atomic_add to avoid the barrier, but that would be non-deterministic + tmp[tile_m, tile_n, tile_k_outer.id] = acc + + hl.barrier() + + for tile_m, tile_n in hl.tile([m, n]): + out[tile_m, tile_n] = torch.sum(tmp[tile_m, tile_n, :], dim=-1) + + return out + + +def check(m: int, k: int, n: int) -> None: + a = torch.randn(m, k, device=DEVICE) + b = torch.randn(n, k, device=DEVICE).T + + run_example( + split_k_matmul, + torch.matmul, + args=(a, b), + atol=5e-1, # long reductions accumulate errors + ) + + +def main() -> None: + torch.manual_seed(0) + check(16, 4096, 16) + + +if __name__ == "__main__": + main() diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 09d63af11..7a5460b80 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -21,7 +21,6 @@ from ..
import exc from ..language.constexpr import ConstExpr -from .loop_dependency_checker import LoopDependencyChecker from .source_location import SourceLocation from .source_location import current_location from .variable_origin import BlockSizeOrigin @@ -102,18 +101,18 @@ def __init__( self.block_sizes: list[BlockSizeInfo] = [] self.debug_shape_renames: dict[sympy.Expr, sympy.Expr] = {} self.config_spec = ConfigSpec() - if settings.autotune_force_persistent: - for pid_type in ("flat", "xyz"): - self.config_spec.disallow_pid_type(pid_type) self.kernel_tensor_sizes: dict[tuple[sympy.Expr, ...], int] = ( collections.Counter() ) self.specialized_vars: set[sympy.Symbol] = set() - self.loop_dependency_checker = LoopDependencyChecker() self._symint_cache: dict[object, torch.SymInt] = {} self.device_load_count = ( 0 # Track number of loads in all device code for eviction policy tuning ) + if settings.autotune_force_persistent: + for pid_type in ("flat", "xyz"): + self.config_spec.disallow_pid_type(pid_type) + self.has_barrier: bool = False def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None: from .device_function import contains_only_block_size_symbols @@ -405,7 +404,6 @@ def __enter__(self) -> Self: assert getattr(tls, "env", None) is None, "CompileEnvironment already active" self.fake_mode.__enter__() tls.env = self - self.loop_dependency_checker = LoopDependencyChecker() return self def __exit__( diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 95d48ab3e..a05048b5e 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -197,7 +197,12 @@ def __init__(self, val: int) -> None: class DeviceFunction: - def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None: + def __init__( + self, + name: str, + config: Config, + codegen: GenerateAST, + ) -> None: super().__init__() self.name = name self.config = config @@ -659,6 +664,11 @@ def codegen_function_call(self) -> ast.AST: [ f"num_warps={num_warps}", f"num_stages={self.config.num_stages}", + *( + ["launch_cooperative_grid=True"] + if CompileEnvironment.current().has_barrier + else [] + ), ] + [ f"{x.removeprefix('_triton_config_')}={self.config[x]}" diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index 658e62f4b..d44d12cca 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -48,6 +48,7 @@ from .inductor_lowering import CodegenState from .inductor_lowering import codegen_call_with_graph from .inductor_lowering import prepare_graph_lowerings +from .loop_dependency_checker import LoopDependencyChecker from .matmul_utils import tensor_matmul_replacement from .matmul_utils import torch_matmul_replacement from .node_masking import remove_unnecessary_masking @@ -189,6 +190,8 @@ def codegen(self, state: CodegenState) -> list[object]: class RootGraphInfo(GraphInfo): + phase_index: int = 0 + @property def name(self) -> str: return f"root_graph_{self.graph_id}" @@ -376,12 +379,22 @@ class RolledReductionInfo(NamedTuple): can_be_rolled_by_caller: bool +@dataclasses.dataclass +class KernelPhase: + roots: list[int] # store root indices + root_nodes: list[ast.For] + loop_dependency_checker: LoopDependencyChecker = dataclasses.field( + default_factory=LoopDependencyChecker + ) + + class DeviceIR: def __init__(self) -> None: super().__init__() self.graphs: list[GraphInfo] = [] self.root_ids: list[int] = [] self.rolled_reductions: list[RolledReductionInfo] = [] + self.phases: 
list[KernelPhase] = [] self.grid_block_ids: list[list[int]] = [] def get_root(self, config: Config, graph_id: int) -> torch.fx.Graph: @@ -435,6 +448,11 @@ def add_reduction_loop_graph( def add_root_graph(self, graph: torch.fx.Graph) -> None: self.root_ids.append(self.add_graph(graph, graph_info_cls=RootGraphInfo)) + def phase_for_root(self, root_id: int) -> int: + graph_info = self.graphs[self.root_ids[root_id]] + assert isinstance(graph_info, RootGraphInfo) + return graph_info.phase_index + def build_rolled_reductions(self) -> None: env = CompileEnvironment.current() rdims = [bs for bs in env.block_sizes if bs.reduction] @@ -1274,6 +1292,10 @@ class WalkHostAST(NodeVisitor): def __init__(self, device_ir: DeviceIR) -> None: super().__init__() self.device_ir = device_ir + self.root_index = 0 + self.current_phase_roots: list[int] = [] + self.phases: list[KernelPhase] = [] + self.root_nodes: list[ast.For] = [] def visit_For(self, node: ast.For) -> None: assert isinstance(node, ExtendedAST) @@ -1292,9 +1314,44 @@ def visit_For(self, node: ast.For) -> None: # pyrefly: ignore [missing-attribute] block_ids = [inner.block_id] self.device_ir.grid_block_ids.append(block_ids) + # store root index (position) not graph id + self.root_nodes.append(node) + self.current_phase_roots.append(len(self.device_ir.root_ids) - 1) + self.root_index += 1 else: self.generic_visit(node) + def visit_Expr(self, node: ast.Expr) -> None: + # Record barrier placement between top-level loops. + from .type_propagation import BarrierResultType + + assert isinstance(node, ExtendedAST) + assert isinstance(node.value, ExtendedAST) + is_barrier = isinstance(node.value._type_info, BarrierResultType) + + if is_barrier: + if self.root_index == 0 or not self.current_phase_roots: + raise exc.BarrierOnlyAllowedAtTopLevel + self.phases.append( + KernelPhase( + roots=self.current_phase_roots, + root_nodes=[self.root_nodes[r] for r in self.current_phase_roots], + ) + ) + self.current_phase_roots = [] + return + self.generic_visit(node) + + def flush_phases(self) -> None: + if self.current_phase_roots: + self.phases.append( + KernelPhase( + roots=self.current_phase_roots, + root_nodes=[self.root_nodes[r] for r in self.current_phase_roots], + ) + ) + self.current_phase_roots = [] + def _count_device_loads_and_stores(device_ir: DeviceIR) -> tuple[int, int, int]: """Count the number of load and store operations in device code for autotuning. @@ -1386,6 +1443,18 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR: visitor = WalkHostAST(device_ir) for stmt in func.body: visitor.visit(stmt) + visitor.flush_phases() + device_ir.phases = visitor.phases + # Run dependency checks once, per phase, so codegen does not redo it per-config. + for phase in device_ir.phases: + checker = phase.loop_dependency_checker + for loop_node in phase.root_nodes: + checker.register_loop(loop_node) + for phase_idx, phase in enumerate(device_ir.phases): + for ridx in phase.roots: + graph_info = device_ir.graphs[device_ir.root_ids[ridx]] + assert isinstance(graph_info, RootGraphInfo) + graph_info.phase_index = phase_idx # If there are no top-level device loops, we cannot generate a valid kernel. # Raise a friendly error instead of emitting an empty Triton function body. if len(device_ir.root_ids) == 0: diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index f07f41ded..d8fab907c 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -10,6 +10,7 @@ from .. 
import exc from ..language._decorators import is_api_func +from ..runtime.config import Config from .ast_extension import ExtendedAST from .ast_extension import LoopType from .ast_extension import NodeVisitor @@ -24,6 +25,7 @@ from .helper_function import CodegenInterface from .inductor_lowering import CodegenState from .inductor_lowering import codegen_call_with_graph +from .loop_dependency_checker import LoopDependencyChecker from .program_id import ForEachProgramID from .tile_strategy import DeviceLoopState from .variable_origin import ArgumentOrigin @@ -35,6 +37,7 @@ from ..runtime import Config from .host_function import HostFunction + from .loop_dependency_checker import LoopDependencyChecker from .tile_strategy import DeviceLoopOrGridState from .type_propagation import TensorType @@ -55,7 +58,11 @@ def __init__(self, func: HostFunction, config: Config) -> None: self.next_else_block: list[ast.AST] | None = None # Now create device function and initialize CodegenInterface - self.device_function = DeviceFunction(f"_helion_{func.name}", config, self) + self.device_function = DeviceFunction( + f"_helion_{func.name}", + config, + self, + ) CodegenInterface.__init__(self, self.device_function) def offset_var(self, block_idx: int) -> str: @@ -69,6 +76,10 @@ def mask_var(self, block_idx: int) -> str | None: return loops[-1].strategy.mask_var(block_idx) return None + def _phase_checker(self, root_id: int) -> LoopDependencyChecker: + phase_idx = self.host_function.device_ir.phase_for_root(root_id) + return self.host_function.device_ir.phases[phase_idx].loop_dependency_checker + def add_statement(self, stmt: ast.AST | str | None) -> None: if stmt is None: return @@ -226,17 +237,20 @@ def visit_For(self, node: ast.For) -> ast.AST | None: if node._loop_type == LoopType.GRID: assert not node.orelse + assert node._root_id is not None + # Loop dependency checks were already run during lowering; phase checker kept for symmetry/debug. 
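+ # _phase_checker() returns the per-phase LoopDependencyChecker built in lower_to_device_ir(); its return value is intentionally unused here.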
+ self._phase_checker(node._root_id) + if len(self.host_function.device_ir.root_ids) == 1: body = self.device_function.body else: assert len(self.host_function.device_ir.root_ids) > 1 - assert node._root_id is not None # Multiple top level for loops if node._root_id == 0: self.device_function.set_pid( ForEachProgramID( - self.device_function.new_var("pid_shared", dce=False) + self.device_function.new_var("pid_shared", dce=False), ) ) self.device_function.body.extend( @@ -309,6 +323,11 @@ def visit_For(self, node: ast.For) -> ast.AST | None: # This ensures block size and rdim vars are defined in the correct order self.device_function.flush_deferred_rdim_defs(self) + if isinstance(self.device_function.pid, ForEachProgramID): + self.device_function.pid.case_phases.append( + self.host_function.device_ir.phase_for_root(node._root_id) + ) + # If we are in a multi top level loop, for all loops except for the last one # emit ifthenelse blocks if node._root_id < len(self.host_function.device_ir.root_ids) - 1: @@ -476,6 +495,9 @@ def generate_ast( func: HostFunction, config: Config, emit_repro_caller: bool ) -> ast.AST: with func: + if len(func.device_ir.phases) > 1: + if not str(config.pid_type).startswith("persistent"): + raise exc.BarrierRequiresPersistent(config.pid_type) codegen = GenerateAST(func, config) with codegen.device_function: for stmt in func.body: diff --git a/helion/_compiler/loop_dependency_checker.py b/helion/_compiler/loop_dependency_checker.py index db8e9ff91..e19267b09 100644 --- a/helion/_compiler/loop_dependency_checker.py +++ b/helion/_compiler/loop_dependency_checker.py @@ -22,14 +22,29 @@ class LoopDependencyChecker: def __init__(self) -> None: self.reads: set[str] = set() self.writes: set[str] = set() - - def register_loop(self, loop_node: ast.For) -> None: + self._barrier_after_root: set[int] = set() + self._root_counter: int = 0 + self.disabled: bool = False + + def insert_barrier_after_root(self, root_id: int) -> None: + """Record that a barrier separates root_id and root_id+1.""" + self._barrier_after_root.add(root_id) + + def register_loop(self, loop_node: ast.For, root_id: int | None = None) -> None: + if self.disabled: + return + current_root = root_id if root_id is not None else self._root_counter + if (current_root - 1) in self._barrier_after_root: + self.reads.clear() + self.writes.clear() + self._barrier_after_root.discard(current_root - 1) rw = ReadWrites.from_list(loop_node.body) self._check_dependencies(rw) self.reads |= set(rw.reads) self.writes |= set(rw.writes) + self._root_counter = current_root + 1 def _check_dependencies(self, rw: ReadWrites) -> None: """ diff --git a/helion/_compiler/program_id.py b/helion/_compiler/program_id.py index 0cb32fd6f..38aedf635 100644 --- a/helion/_compiler/program_id.py +++ b/helion/_compiler/program_id.py @@ -6,10 +6,14 @@ from typing import TYPE_CHECKING from typing import NamedTuple +import torch + +from .ast_extension import create from .ast_extension import expr_from_string from .ast_extension import statement_from_string from .compile_environment import CompileEnvironment from .device_function import DeviceFunction +from .device_function import TensorArg from .host_function import HostFunction if TYPE_CHECKING: @@ -142,7 +146,9 @@ class ForEachProgramID(ProgramIDs): # pyrefly: ignore [bad-override] shared_pid_var: str cases: list[ProgramIDs] = dataclasses.field(default_factory=list) + case_phases: list[int] = dataclasses.field(default_factory=list) pid_info: list[PIDInfo] = dataclasses.field(default_factory=list, 
init=False) + barrier_after_root: set[int] = dataclasses.field(default_factory=set) def codegen_pid_init(self) -> list[ast.stmt]: # Check if persistent kernels are enabled in config - if so, skip regular initialization @@ -174,11 +180,37 @@ def codegen_test(self, state: CodegenState) -> ast.AST: def setup_persistent_kernel( self, device_function: DeviceFunction, total_pids_expr: str | None = None ) -> list[ast.stmt] | None: - # Persistent type will be the same for every case, so we can use the first one - return self.cases[0].setup_persistent_kernel( - device_function, self.total_pids_expr(is_device=True) + total_expr = self.total_pids_expr(is_device=True) + # If there is only one phase, fall back to existing behavior. + has_phases = len(set(self.case_phases)) > 1 + + def _base_strategy(pid: ProgramIDs) -> ProgramIDs: + from .tile_strategy import L2GroupingProgramIDs + + if isinstance(pid, L2GroupingProgramIDs): + assert pid.parent_strategy is not None, ( + "L2 grouping strategy is missing its parent" + ) + return pid.parent_strategy + return pid + + base_strategy = _base_strategy(self.cases[0]) + + if not has_phases: + return base_strategy.setup_persistent_kernel(device_function, total_expr) + + # We expect a persistent-blocked strategy when barriers are present. + if not base_strategy._is_persistent(): + return base_strategy.setup_persistent_kernel(device_function, total_expr) + + assert isinstance(base_strategy, PersistentProgramIDs) + assert base_strategy.is_blocked, ( + "hl.barrier() currently requires persistent_blocked" ) + # Delegate to helper for phase-split persistent loops + return self._emit_phase_loops(base_strategy, device_function, total_expr) + def total_pids_expr(self, *, is_device: bool) -> str: """Get total PIDs expression for ForEachProgramID (sum of all pids).""" cdivs = [pid.total_pids_expr(is_device=is_device) for pid in self.cases] @@ -217,6 +249,109 @@ def _prepare_persistent_body( *body, ] + def _phase_boundaries(self) -> list[str]: + """Compute cumulative PID boundaries at phase transitions.""" + cdivs = [pid.total_pids_expr(is_device=True) for pid in self.cases] + boundaries: list[str] = [] + running = "0" + prev_phase = self.case_phases[0] + for idx, cdiv in enumerate(cdivs): + running = f"({running}) + ({cdiv})" + next_phase = ( + self.case_phases[idx + 1] + if idx + 1 < len(self.case_phases) + else prev_phase + ) + if next_phase != prev_phase or idx == len(cdivs) - 1: + boundaries.append(running) + prev_phase = next_phase + return boundaries + + def _emit_phase_loops( + self, + strategy: PersistentProgramIDs, + device_function: DeviceFunction, + total_expr: str, + ) -> list[ast.stmt]: + """Emit persistent loops split by KernelPhase boundaries.""" + from .tile_strategy import TileStrategy + + # persistent setup preamble (mirrors PersistentProgramIDs.setup_persistent_kernel) + setup_statements = [ + statement_from_string(f"{strategy.total_pids_var} = {total_expr}"), + ] + if strategy.block_size_var and strategy.start_pid_var and strategy.end_pid_var: + assignments = [ + ( + strategy.block_size_var, + f"tl.cdiv({strategy.total_pids_var}, {NUM_SM_VAR})", + ), + ( + strategy.start_pid_var, + f"tl.program_id(0) * {strategy.block_size_var}", + ), + ( + strategy.end_pid_var, + f"tl.minimum({strategy.start_pid_var} + {strategy.block_size_var}, {strategy.total_pids_var})", + ), + ] + setup_statements.extend( + [statement_from_string(f"{var} = {expr}") for var, expr in assignments] + ) + device_function.preamble.extend(setup_statements) + + boundaries = 
self._phase_boundaries() + block_ids = [pid.block_id for pid in strategy.pid_info] + + def range_expr(begin: str, end: str) -> str: + return TileStrategy.get_range_call_str( + device_function.config, block_ids, begin=begin, end=end + ) + + base_body = self._prepare_persistent_body( + device_function.body, device_function, strategy.virtual_pid_var + ) + + sem_arg = device_function.new_var("x_grid_sem", dce=False) + device_function.arguments.append( + TensorArg( + sem_arg, + torch.empty(1, device="meta", dtype=torch.uint32), + f"torch.zeros((1,), device={strategy.get_device_str()}, dtype=torch.uint32)", + ) + ) + + loops: list[ast.stmt] = [] + start_expr = "0" + for boundary in boundaries: + cond = expr_from_string( + f"({strategy.virtual_pid_var} >= ({start_expr})) and ({strategy.virtual_pid_var} < ({boundary}))" + ) + loop_body = [create(ast.If, test=cond, body=list(base_body), orelse=[])] + loops.append( + create( + ast.For, + target=create( + ast.Name, id=strategy.virtual_pid_var, ctx=ast.Store() + ), + iter=expr_from_string( + range_expr( + f"tl.maximum({strategy.start_pid_var}, {start_expr})", + f"tl.minimum({strategy.end_pid_var}, {boundary})", + ) + ), + body=loop_body, + orelse=[], + type_comment=None, + ) + ) + if boundary != boundaries[-1]: + loops.append( + statement_from_string(f"triton_helpers.x_grid_barrier({sem_arg})") + ) + start_expr = boundary + return loops + class XYZProgramIDs(ProgramIDs): """Use the cuda x/y/z launch grid for PIDs""" diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py index a87a3a28a..f074fd5ca 100644 --- a/helion/_compiler/tile_strategy.py +++ b/helion/_compiler/tile_strategy.py @@ -21,6 +21,7 @@ from .compile_environment import CompileEnvironment from .compile_environment import _has_unbacked from .compile_environment import _to_sympy +from .device_function import DeviceFunction from .host_function import HostFunction from .program_id import FlatProgramIDs from .program_id import ForEachProgramID @@ -353,6 +354,18 @@ def _fold_tile_end_op( return loop_info.end_expr return end + def select_pid_strategy(self) -> ProgramIDs: + pid_type = self.fn.config.pid_type + if pid_type == "xyz": + assert 1 < len(self.block_ids) <= 3 + return XYZProgramIDs() + if pid_type == "persistent_blocked": + return PersistentBlockedProgramIDs() + if pid_type == "persistent_interleaved": + return PersistentInterleavedProgramIDs() + assert pid_type == "flat" + return FlatProgramIDs() + class FlattenedTileStrategy(BlockSizeTileStrategy): """Collapse all dimensions into single flat iteration space.""" @@ -439,24 +452,27 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState: ) env = CompileEnvironment.current() dtype = env.triton_index_type() + + pid_var = state.device_function.new_var("pid_flat", dce=True) + pids = self.select_pid_strategy() + if isinstance(state.device_function.pid, ForEachProgramID): + pids.shared_pid_var = state.device_function.pid.shared_pid_var + + pids.append(PIDInfo(pid_var, block_size_var, total_numel, self.block_ids[0])) + state.add_statement( - f"{offsets_var} = tl.program_id(0) * ({block_size_var}) + tl.arange(0, {block_size_var}).to({dtype})" + f"{offsets_var} = {pid_var} * ({block_size_var}) + tl.arange(0, {block_size_var}).to({dtype})" ) state.codegen.statements_stack[-1].extend(statements) - class TmpPid(ProgramIDs): - def codegen_grid(self) -> ast.AST: - return expr_from_string( - f"(triton.cdiv({HostFunction.current().sympy_expr(total_numel)}, {block_size_var}), 1, 1)" - ) - - def codegen(self, state: 
CodegenState) -> None: - pass # No-op implementation for TmpPid - - def total_pids_expr(self, *, is_device: bool) -> str: - return "1" # Simple implementation for TmpPid + pids.codegen(state) - state.device_function.set_pid(TmpPid()) + if isinstance(state.device_function.pid, ForEachProgramID): + shared_pid = state.device_function.pid + shared_pid.cases.append(pids) + shared_pid.codegen(state) + else: + state.device_function.set_pid(pids) block_id_to_info = self._create_block_id_info_dict(state) return DeviceGridState(self, block_id_to_info=block_id_to_info) @@ -518,6 +534,11 @@ def update_allow_flattened(cls, shape: Sequence[sympy.Expr]) -> None: break def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]: + # Keep axis structure intact for multi-phase kernels (e.g., barrier) to + # avoid mismatched ranks in downstream reductions. + if len(HostFunction.current().device_ir.root_ids) > 1: + return shapes + env = CompileEnvironment.current() # Filter out unit-sized blocks that don't need compacting compact_block_ids = [ @@ -654,18 +675,6 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState: block_id_to_info = self._create_block_id_info_dict(state) return DeviceGridState(self, block_id_to_info=block_id_to_info) - def select_pid_strategy(self) -> ProgramIDs: - pid_type = self.fn.config.pid_type - if pid_type == "xyz": - assert 1 < len(self.block_ids) <= 3 - return XYZProgramIDs() - if pid_type == "persistent_blocked": - return PersistentBlockedProgramIDs() - if pid_type == "persistent_interleaved": - return PersistentInterleavedProgramIDs() - assert pid_type == "flat" - return FlatProgramIDs() - def _to_ast(self, x: object, to_dtype: str | None = None) -> ast.AST: if isinstance(x, ast.AST): if to_dtype: diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index 09a74bdff..4a48fab44 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -2497,14 +2497,29 @@ def propagate_types(func: HostFunction) -> None: assert not func.fn.__closure__ prop = TypePropagation(func, local_scope) + def _is_barrier_stmt(statement: ast.stmt) -> bool: + if isinstance(statement, ast.Expr): + value = statement.value + type_info = getattr(value, "_type_info", None) + return isinstance(type_info, BarrierResultType) + return False + seen_for_loop = False seen_non_for_loop_statement_after_for_loop = False + phase_index: int = 0 for stmt in func.body: + prop.visit(stmt) + if _is_barrier_stmt(stmt): + phase_index += 1 + barrier_stmt = _is_barrier_stmt(stmt) if isinstance(stmt, ast.For): if seen_for_loop and seen_non_for_loop_statement_after_for_loop: # TODO(oulgen): This check is too coarse, refine it. raise exc.TopLevelStatementBetweenLoops seen_for_loop = True - elif seen_for_loop: + elif seen_for_loop and not barrier_stmt: seen_non_for_loop_statement_after_for_loop = True - prop.visit(stmt) + + +class BarrierResultType(LiteralType): + """Marker type returned by hl.barrier() to signal a phase boundary.""" diff --git a/helion/exc.py b/helion/exc.py index f29e67d15..b39a7ca2f 100644 --- a/helion/exc.py +++ b/helion/exc.py @@ -84,13 +84,24 @@ class DeviceLoopElseBlock(BaseError): class LoopDependencyError(BaseError): - message = "Loop dependency detected: '{0}' was written in a previous loop." + message = ( + "Loop dependency detected: '{0}' was written in a previous loop. " + "If this dependency is intentional, insert hl.barrier() between the loops." 
+ ) class TopLevelStatementBetweenLoops(BaseError): message = "Statements cannot appear between top level loops." +class BarrierOnlyAllowedAtTopLevel(BaseError): + message = "hl.barrier() is only supported between top level hl.tile/hl.grid loops." + + +class BarrierRequiresPersistent(BaseError): + message = "hl.barrier() requires pid_type to be persistent (got '{0}')." + + class NestedGridLoop(BaseError): message = "Grid loops must be at the top level of a function." diff --git a/helion/language/__init__.py b/helion/language/__init__.py index c8324ebe5..fc734a2c0 100644 --- a/helion/language/__init__.py +++ b/helion/language/__init__.py @@ -8,6 +8,7 @@ from .atomic_ops import atomic_or as atomic_or from .atomic_ops import atomic_xchg as atomic_xchg from .atomic_ops import atomic_xor as atomic_xor +from .barrier import barrier as barrier from .builtin_ops import _builtin_max as _builtin_max from .builtin_ops import _builtin_min as _builtin_min from .constexpr import ConstExpr as constexpr # noqa: F401 diff --git a/helion/language/barrier.py b/helion/language/barrier.py new file mode 100644 index 000000000..fbfab1656 --- /dev/null +++ b/helion/language/barrier.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING + +from .. import exc +from .._compiler.ast_extension import expr_from_string +from .._compiler.compile_environment import CompileEnvironment +from .._compiler.type_propagation import BarrierResultType +from .._compiler.type_propagation import LiteralType +from . import _decorators + +if TYPE_CHECKING: + from .._compiler.inductor_lowering import CodegenState + from .._compiler.variable_origin import Origin + +__all__ = ["barrier"] + + +@_decorators.api( + is_device_loop=False, + is_device_only=False, + cache_type=True, + signature=inspect.signature(lambda: None), +) +def barrier() -> None: + """Grid-wide barrier separating top-level `hl.tile` / `hl.grid` loops.""" + raise exc.NotInsideKernel + + +@_decorators.type_propagation(barrier) +def _(origin: Origin, **kwargs: object) -> LiteralType: + # Only allowed on the host between top-level device loops. + if origin.is_device(): + raise exc.BarrierOnlyAllowedAtTopLevel + + # A barrier introduces a sequential phase boundary between top-level loops, + # so force persistent kernels (other PID choices are incompatible). + env = CompileEnvironment.current() + env.has_barrier = True + for disallowed in ("flat", "xyz", "persistent_interleaved"): + env.config_spec.disallow_pid_type(disallowed) + + # Return None literal with a dedicated marker type. + return BarrierResultType(origin=origin, value=None) + + +@_decorators.codegen(barrier, "triton") +def _(state: CodegenState) -> object: + # No device code emitted; barrier only affects host-side scheduling. 
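+ # The grid-wide sync itself (triton_helpers.x_grid_barrier) is emitted between phase loops by ForEachProgramID._emit_phase_loops in program_id.py.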
+ return expr_from_string("None") + + +@_decorators.ref(barrier) +def _() -> None: + # No-op in ref/interpret mode + return None diff --git a/helion/language/loops.py b/helion/language/loops.py index e8bda9815..9316d20d0 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -571,8 +571,6 @@ def _codegen_loop_helper( indices = cast("list[TileIndexType | GridIndexType]", indices_raw) if loop_type == LoopType.GRID: - env = CompileEnvironment.current() - env.loop_dependency_checker.register_loop(for_loop) block_ids = [t.block_id for t in indices] state.tile_strategy.codegen_grid(state, block_ids) return expr_from_string("None") diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index 875074d9b..2b33881d9 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -80,6 +80,7 @@ def default_launcher( *args: object, num_warps: int, num_stages: int, + launch_cooperative_grid: bool = False, **kwargs: dict, ) -> object: """Default launcher function that executes the kernel immediately.""" @@ -89,5 +90,6 @@ def default_launcher( warmup=False, num_warps=num_warps, num_stages=num_stages, + launch_cooperative_grid=launch_cooperative_grid, **kwargs, ) diff --git a/test/test_barrier.expected b/test/test_barrier.expected new file mode 100644 index 000000000..e73d41f0a --- /dev/null +++ b/test/test_barrier.expected @@ -0,0 +1,454 @@ +This file is automatically generated by assertExpectedJournal calls in test_barrier.py. +Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. + +--- assertExpectedJournal(TestBarrier.test_dep_across_barrier) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_barrier_dep_single(x, tmp, out, x_grid_sem, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: out[t] = tmp[t] + 1 + total_pids = tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + block_size = tl.cdiv(total_pids, _NUM_SM) + start_pid = tl.program_id(0) * block_size + end_pid = tl.minimum(start_pid + block_size, total_pids) + for virtual_pid in tl.range(tl.maximum(start_pid, 0), tl.minimum(end_pid, 0 + tl.cdiv(8, _BLOCK_SIZE_0))): + if virtual_pid >= 0 and virtual_pid < 0 + tl.cdiv(8, _BLOCK_SIZE_0): + pid_shared = virtual_pid + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: tmp[t] = x[t] * 2 + if pid_shared < tl.cdiv(8, _BLOCK_SIZE_0): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_0 = pid_shared + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + # src[test_barrier.py:N]: tmp[t] = x[t] * 2 + load = tl.load(x + indices_0 * 1, None) + v_0 = 2.0 + v_1 = load * v_0 + tl.store(tmp + indices_0 * 1, v_1, None) + else: + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(8, _BLOCK_SIZE_0) + pid_1 = pid_shared + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[test_barrier.py:N]: out[t] = tmp[t] + 1 + load_1 = tl.load(tmp + indices_1 * 1, None) + v_2 = 1.0 + v_3 = load_1 + v_2 + tl.store(out + indices_1 * 1, v_3, None) + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: out[t] = tmp[t] + 1 + 
triton_helpers.x_grid_barrier(x_grid_sem) + for virtual_pid in tl.range(tl.maximum(start_pid, 0 + tl.cdiv(8, _BLOCK_SIZE_0)), tl.minimum(end_pid, 0 + tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1))): + if virtual_pid >= 0 + tl.cdiv(8, _BLOCK_SIZE_0) and virtual_pid < 0 + tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1): + pid_shared = virtual_pid + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: tmp[t] = x[t] * 2 + if pid_shared < tl.cdiv(8, _BLOCK_SIZE_0): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_0 = pid_shared + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + # src[test_barrier.py:N]: tmp[t] = x[t] * 2 + load = tl.load(x + indices_0 * 1, None) + v_0 = 2.0 + v_1 = load * v_0 + tl.store(tmp + indices_0 * 1, v_1, None) + else: + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(8, _BLOCK_SIZE_0) + pid_1 = pid_shared + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[test_barrier.py:N]: out[t] = tmp[t] + 1 + load_1 = tl.load(tmp + indices_1 * 1, None) + v_2 = 1.0 + v_3 = load_1 + v_2 + tl.store(out + indices_1 * 1, v_3, None) + +def barrier_dep_single(x: torch.Tensor, *, _launcher=_default_launcher): + # src[test_barrier.py:N]: tmp = torch.empty_like(x) + tmp = torch.empty_like(x) + # src[test_barrier.py:N]: out = torch.empty_like(x) + out = torch.empty_like(x) + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + _NUM_SM = helion.runtime.get_num_sm(x.device) + _BLOCK_SIZE_0 = 8 + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + _BLOCK_SIZE_1 = 8 + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: out[t] = tmp[t] + 1 + _launcher(_helion_barrier_dep_single, (_NUM_SM,), x, tmp, out, torch.zeros((1,), device=x.device, dtype=torch.uint32), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1, launch_cooperative_grid=True) + # src[test_barrier.py:N]: return out + return out + +--- assertExpectedJournal(TestBarrier.test_multiple_barriers) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_barrier_multiple(x, buf1, buf2, out, x_grid_sem, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: out[t] = buf2[t] - 5 + total_pids = tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, _BLOCK_SIZE_1) + tl.cdiv(6, _BLOCK_SIZE_2) + block_size = tl.cdiv(total_pids, _NUM_SM) + start_pid = tl.program_id(0) * block_size + end_pid = tl.minimum(start_pid + block_size, total_pids) + for virtual_pid in tl.range(tl.maximum(start_pid, 0), tl.minimum(end_pid, 0 + tl.cdiv(6, _BLOCK_SIZE_0))): + if virtual_pid >= 0 and virtual_pid < 0 + tl.cdiv(6, _BLOCK_SIZE_0): + pid_shared = virtual_pid + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: buf1[t] = x[t] + 3 + if pid_shared < tl.cdiv(6, _BLOCK_SIZE_0): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_0 = pid_shared + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < 6 + # src[test_barrier.py:N]: buf1[t] = x[t] + 3 + load = tl.load(x + indices_0 * 1, mask_0, 
other=0) + v_0 = 3.0 + v_1 = load + v_0 + tl.store(buf1 + indices_0 * 1, v_1, mask_0) + elif pid_shared < tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, _BLOCK_SIZE_1): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(6, _BLOCK_SIZE_0) + pid_1 = pid_shared + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < 6 + # src[test_barrier.py:N]: buf2[t] = buf1[t] * 2 + load_1 = tl.load(buf1 + indices_1 * 1, mask_1, other=0) + v_2 = 2.0 + v_3 = load_1 * v_2 + tl.store(buf2 + indices_1 * 1, v_3, mask_1) + else: + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, _BLOCK_SIZE_1) + pid_2 = pid_shared + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) + mask_2 = indices_2 < 6 + # src[test_barrier.py:N]: out[t] = buf2[t] - 5 + load_2 = tl.load(buf2 + indices_2 * 1, mask_2, other=0) + v_4 = 5.0 + v_5 = load_2 - v_4 + tl.store(out + indices_2 * 1, v_5, mask_2) + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: out[t] = buf2[t] - 5 + triton_helpers.x_grid_barrier(x_grid_sem) + for virtual_pid in tl.range(tl.maximum(start_pid, 0 + tl.cdiv(6, _BLOCK_SIZE_0)), tl.minimum(end_pid, 0 + tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, _BLOCK_SIZE_1))): + if virtual_pid >= 0 + tl.cdiv(6, _BLOCK_SIZE_0) and virtual_pid < 0 + tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, _BLOCK_SIZE_1): + pid_shared = virtual_pid + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: buf1[t] = x[t] + 3 + if pid_shared < tl.cdiv(6, _BLOCK_SIZE_0): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_0 = pid_shared + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < 6 + # src[test_barrier.py:N]: buf1[t] = x[t] + 3 + load = tl.load(x + indices_0 * 1, mask_0, other=0) + v_0 = 3.0 + v_1 = load + v_0 + tl.store(buf1 + indices_0 * 1, v_1, mask_0) + elif pid_shared < tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, _BLOCK_SIZE_1): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(6, _BLOCK_SIZE_0) + pid_1 = pid_shared + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < 6 + # src[test_barrier.py:N]: buf2[t] = buf1[t] * 2 + load_1 = tl.load(buf1 + indices_1 * 1, mask_1, other=0) + v_2 = 2.0 + v_3 = load_1 * v_2 + tl.store(buf2 + indices_1 * 1, v_3, mask_1) + else: + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, _BLOCK_SIZE_1) + pid_2 = pid_shared + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) + mask_2 = indices_2 < 6 + # src[test_barrier.py:N]: out[t] = buf2[t] - 5 + load_2 = tl.load(buf2 + indices_2 * 1, mask_2, other=0) + v_4 = 5.0 + v_5 = load_2 - v_4 + tl.store(out + indices_2 * 1, v_5, mask_2) + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: out[t] = buf2[t] - 5 + triton_helpers.x_grid_barrier(x_grid_sem) + for virtual_pid in tl.range(tl.maximum(start_pid, 0 + tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, _BLOCK_SIZE_1)), tl.minimum(end_pid, 0 + tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, _BLOCK_SIZE_1) + tl.cdiv(6, _BLOCK_SIZE_2))): + if virtual_pid >= 0 + tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, _BLOCK_SIZE_1) and virtual_pid < 0 + tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, 
_BLOCK_SIZE_1) + tl.cdiv(6, _BLOCK_SIZE_2): + pid_shared = virtual_pid + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: buf1[t] = x[t] + 3 + if pid_shared < tl.cdiv(6, _BLOCK_SIZE_0): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_0 = pid_shared + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < 6 + # src[test_barrier.py:N]: buf1[t] = x[t] + 3 + load = tl.load(x + indices_0 * 1, mask_0, other=0) + v_0 = 3.0 + v_1 = load + v_0 + tl.store(buf1 + indices_0 * 1, v_1, mask_0) + elif pid_shared < tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, _BLOCK_SIZE_1): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(6, _BLOCK_SIZE_0) + pid_1 = pid_shared + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < 6 + # src[test_barrier.py:N]: buf2[t] = buf1[t] * 2 + load_1 = tl.load(buf1 + indices_1 * 1, mask_1, other=0) + v_2 = 2.0 + v_3 = load_1 * v_2 + tl.store(buf2 + indices_1 * 1, v_3, mask_1) + else: + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(6, _BLOCK_SIZE_0) + tl.cdiv(6, _BLOCK_SIZE_1) + pid_2 = pid_shared + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) + mask_2 = indices_2 < 6 + # src[test_barrier.py:N]: out[t] = buf2[t] - 5 + load_2 = tl.load(buf2 + indices_2 * 1, mask_2, other=0) + v_4 = 5.0 + v_5 = load_2 - v_4 + tl.store(out + indices_2 * 1, v_5, mask_2) + +def barrier_multiple(x: torch.Tensor, *, _launcher=_default_launcher): + # src[test_barrier.py:N]: buf1 = torch.empty_like(x) + buf1 = torch.empty_like(x) + # src[test_barrier.py:N]: buf2 = torch.empty_like(x) + buf2 = torch.empty_like(x) + # src[test_barrier.py:N]: out = torch.empty_like(x) + out = torch.empty_like(x) + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + _NUM_SM = helion.runtime.get_num_sm(x.device) + _BLOCK_SIZE_0 = 8 + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + _BLOCK_SIZE_1 = 8 + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + _BLOCK_SIZE_2 = 8 + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: out[t] = buf2[t] - 5 + _launcher(_helion_barrier_multiple, (_NUM_SM,), x, buf1, buf2, out, torch.zeros((1,), device=x.device, dtype=torch.uint32), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1, launch_cooperative_grid=True) + # src[test_barrier.py:N]: return out + return out + +--- assertExpectedJournal(TestBarrier.test_multiple_loops_between_barriers) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_barrier_groups(x, buf, y, buf2, out, x_grid_sem, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: out[t] = out[t] + 7 + total_pids = tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2) + tl.cdiv(8, _BLOCK_SIZE_3) + block_size = tl.cdiv(total_pids, _NUM_SM) + start_pid = tl.program_id(0) * block_size + end_pid = tl.minimum(start_pid + block_size, total_pids) + for virtual_pid in tl.range(tl.maximum(start_pid, 0), 
tl.minimum(end_pid, 0 + tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1))): + if virtual_pid >= 0 and virtual_pid < 0 + tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1): + pid_shared = virtual_pid + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: buf[t] = x[t] + 1 + if pid_shared < tl.cdiv(8, _BLOCK_SIZE_0): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_0 = pid_shared + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + # src[test_barrier.py:N]: buf[t] = x[t] + 1 + load = tl.load(x + indices_0 * 1, None) + v_0 = 1.0 + v_1 = load + v_0 + tl.store(buf + indices_0 * 1, v_1, None) + elif pid_shared < tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(8, _BLOCK_SIZE_0) + pid_1 = pid_shared + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[test_barrier.py:N]: buf2[t] = y[t] + 5 + load_1 = tl.load(y + indices_1 * 1, None) + v_2 = 5.0 + v_3 = load_1 + v_2 + tl.store(buf2 + indices_1 * 1, v_3, None) + elif pid_shared < tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + pid_2 = pid_shared + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) + # src[test_barrier.py:N]: out[t] = (buf[t] + buf2[t]) * 2 + load_2 = tl.load(buf + indices_2 * 1, None) + load_3 = tl.load(buf2 + indices_2 * 1, None) + v_4 = load_2 + load_3 + v_5 = 2.0 + v_6 = v_4 * v_5 + tl.store(out + indices_2 * 1, v_6, None) + else: + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2) + pid_3 = pid_shared + offset_3 = pid_3 * _BLOCK_SIZE_3 + indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32) + # src[test_barrier.py:N]: out[t] = out[t] + 7 + load_4 = tl.load(out + indices_3 * 1, None) + v_7 = 7.0 + v_8 = load_4 + v_7 + tl.store(out + indices_3 * 1, v_8, None) + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: out[t] = out[t] + 7 + triton_helpers.x_grid_barrier(x_grid_sem) + for virtual_pid in tl.range(tl.maximum(start_pid, 0 + tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1)), tl.minimum(end_pid, 0 + tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2))): + if virtual_pid >= 0 + tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) and virtual_pid < 0 + tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2): + pid_shared = virtual_pid + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: buf[t] = x[t] + 1 + if pid_shared < tl.cdiv(8, _BLOCK_SIZE_0): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_0 = pid_shared + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + # src[test_barrier.py:N]: buf[t] = x[t] + 1 + load = tl.load(x + indices_0 * 1, None) + v_0 = 1.0 + v_1 = load + v_0 + tl.store(buf + indices_0 * 1, v_1, None) + elif pid_shared < tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(8, _BLOCK_SIZE_0) + pid_1 = pid_shared + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + 
tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[test_barrier.py:N]: buf2[t] = y[t] + 5 + load_1 = tl.load(y + indices_1 * 1, None) + v_2 = 5.0 + v_3 = load_1 + v_2 + tl.store(buf2 + indices_1 * 1, v_3, None) + elif pid_shared < tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + pid_2 = pid_shared + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) + # src[test_barrier.py:N]: out[t] = (buf[t] + buf2[t]) * 2 + load_2 = tl.load(buf + indices_2 * 1, None) + load_3 = tl.load(buf2 + indices_2 * 1, None) + v_4 = load_2 + load_3 + v_5 = 2.0 + v_6 = v_4 * v_5 + tl.store(out + indices_2 * 1, v_6, None) + else: + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2) + pid_3 = pid_shared + offset_3 = pid_3 * _BLOCK_SIZE_3 + indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32) + # src[test_barrier.py:N]: out[t] = out[t] + 7 + load_4 = tl.load(out + indices_3 * 1, None) + v_7 = 7.0 + v_8 = load_4 + v_7 + tl.store(out + indices_3 * 1, v_8, None) + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: out[t] = out[t] + 7 + triton_helpers.x_grid_barrier(x_grid_sem) + for virtual_pid in tl.range(tl.maximum(start_pid, 0 + tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2)), tl.minimum(end_pid, 0 + tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2) + tl.cdiv(8, _BLOCK_SIZE_3))): + if virtual_pid >= 0 + tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2) and virtual_pid < 0 + tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2) + tl.cdiv(8, _BLOCK_SIZE_3): + pid_shared = virtual_pid + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: buf[t] = x[t] + 1 + if pid_shared < tl.cdiv(8, _BLOCK_SIZE_0): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_0 = pid_shared + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + # src[test_barrier.py:N]: buf[t] = x[t] + 1 + load = tl.load(x + indices_0 * 1, None) + v_0 = 1.0 + v_1 = load + v_0 + tl.store(buf + indices_0 * 1, v_1, None) + elif pid_shared < tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(8, _BLOCK_SIZE_0) + pid_1 = pid_shared + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[test_barrier.py:N]: buf2[t] = y[t] + 5 + load_1 = tl.load(y + indices_1 * 1, None) + v_2 = 5.0 + v_3 = load_1 + v_2 + tl.store(buf2 + indices_1 * 1, v_3, None) + elif pid_shared < tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2): + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + pid_2 = pid_shared + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) + # src[test_barrier.py:N]: out[t] = (buf[t] + buf2[t]) * 2 + load_2 = tl.load(buf + indices_2 * 1, None) + load_3 = tl.load(buf2 + indices_2 * 1, None) + v_4 = load_2 + load_3 + v_5 = 2.0 + v_6 = v_4 * v_5 + tl.store(out + indices_2 * 1, v_6, None) + else: + # 
src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + pid_shared -= tl.cdiv(8, _BLOCK_SIZE_0) + tl.cdiv(8, _BLOCK_SIZE_1) + tl.cdiv(8, _BLOCK_SIZE_2) + pid_3 = pid_shared + offset_3 = pid_3 * _BLOCK_SIZE_3 + indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32) + # src[test_barrier.py:N]: out[t] = out[t] + 7 + load_4 = tl.load(out + indices_3 * 1, None) + v_7 = 7.0 + v_8 = load_4 + v_7 + tl.store(out + indices_3 * 1, v_8, None) + +def barrier_groups(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_barrier.py:N]: buf = torch.empty_like(x) + buf = torch.empty_like(x) + # src[test_barrier.py:N]: buf2 = torch.empty_like(x) + buf2 = torch.empty_like(x) + # src[test_barrier.py:N]: out = torch.empty_like(x) + out = torch.empty_like(x) + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + _NUM_SM = helion.runtime.get_num_sm(x.device) + _BLOCK_SIZE_0 = 8 + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + _BLOCK_SIZE_1 = 8 + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + _BLOCK_SIZE_2 = 8 + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + _BLOCK_SIZE_3 = 8 + # src[test_barrier.py:N]: for t in hl.tile(x.size(0)): + # src[test_barrier.py:N]: out[t] = out[t] + 7 + _launcher(_helion_barrier_groups, (_NUM_SM,), x, buf, y, buf2, out, torch.zeros((1,), device=x.device, dtype=torch.uint32), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=1, launch_cooperative_grid=True) + # src[test_barrier.py:N]: return out + return out diff --git a/test/test_barrier.py b/test/test_barrier.py new file mode 100644 index 000000000..42d10938f --- /dev/null +++ b/test/test_barrier.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import torch + +import helion +from helion._testing import DEVICE +from helion._testing import RefEagerTestBase +from helion._testing import TestCase +from helion._testing import code_and_output +from helion._testing import skipIfRefEager +import helion.exc as exc +import helion.language as hl + + +@helion.kernel() +def barrier_dep_single(x: torch.Tensor) -> torch.Tensor: + tmp = torch.empty_like(x) + out = torch.empty_like(x) + + for t in hl.tile(x.size(0)): + tmp[t] = x[t] * 2 + + hl.barrier() + + for t in hl.tile(x.size(0)): + out[t] = tmp[t] + 1 + + return out + + +@helion.kernel() +def barrier_multiple(x: torch.Tensor) -> torch.Tensor: + buf1 = torch.empty_like(x) + buf2 = torch.empty_like(x) + out = torch.empty_like(x) + + for t in hl.tile(x.size(0)): + buf1[t] = x[t] + 3 + + hl.barrier() + + for t in hl.tile(x.size(0)): + buf2[t] = buf1[t] * 2 + + hl.barrier() + + for t in hl.tile(x.size(0)): + out[t] = buf2[t] - 5 + + return out + + +@helion.kernel() +def barrier_groups(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + buf = torch.empty_like(x) + buf2 = torch.empty_like(x) + out = torch.empty_like(x) + + # group 1: independent loops + for t in hl.tile(x.size(0)): + buf[t] = x[t] + 1 + for t in hl.tile(x.size(0)): + buf2[t] = y[t] + 5 + + hl.barrier() + + # group 2: consumes both buffers + for t in hl.tile(x.size(0)): + out[t] = (buf[t] + buf2[t]) * 2 + + hl.barrier() + + for t in hl.tile(x.size(0)): + out[t] = out[t] + 7 + + return out + + +class TestBarrier(RefEagerTestBase, TestCase): + def test_dep_across_barrier(self) -> None: + x = torch.arange(8, device=DEVICE, dtype=torch.float32) + code, out = code_and_output( + barrier_dep_single, + (x,), + block_sizes=[8, 8], + pid_type="persistent_blocked", + ) + expected = x * 2 + 1 + torch.testing.assert_close(out, 
expected) + self.assertExpectedJournal(code) + + def test_multiple_barriers(self) -> None: + x = torch.arange(6, device=DEVICE, dtype=torch.float32) + code, out = code_and_output( + barrier_multiple, + (x,), + block_sizes=[8, 8, 8], + pid_type="persistent_blocked", + ) + expected = (x + 3) * 2 - 5 + torch.testing.assert_close(out, expected) + self.assertExpectedJournal(code) + + def test_multiple_loops_between_barriers(self) -> None: + x = torch.arange(8, device=DEVICE, dtype=torch.float32) + y = torch.arange(8, device=DEVICE, dtype=torch.float32) * 3 + code, out = code_and_output( + barrier_groups, + (x, y), + block_sizes=[8, 8, 8, 8], + pid_type="persistent_blocked", + ) + expected = ((x + 1) + (y + 5)) * 2 + 7 + torch.testing.assert_close(out, expected) + self.assertExpectedJournal(code) + + @skipIfRefEager("pid_type validation is only enforced in compiled mode") + def test_non_persistent_pid_type_errors(self) -> None: + x = torch.arange(4, device=DEVICE, dtype=torch.float32) + with self.assertRaisesRegex(exc.BarrierRequiresPersistent, "requires pid_type"): + code_and_output( + barrier_dep_single, + (x,), + block_sizes=[4, 4], + pid_type="flat", + ) + + def test_default_config_is_persistent(self) -> None: + x = torch.arange(4, device=DEVICE, dtype=torch.float32) + code, out = code_and_output( + barrier_dep_single, + (x,), + block_sizes=[4, 4], + pid_type="persistent_blocked", + ) + expected = x * 2 + 1 + torch.testing.assert_close(out, expected) + # Can't see pid_type in ref-mode code; rely on normalization to succeed. diff --git a/test/test_control_flow.expected b/test/test_control_flow.expected index 9df928683..ef195a568 100644 --- a/test/test_control_flow.expected +++ b/test/test_control_flow.expected @@ -48,7 +48,8 @@ from helion.runtime import default_launcher as _default_launcher @triton.jit def _helion_fn(x, out, _BLOCK_SIZE_0_1: tl.constexpr): # src[test_control_flow.py:N]: for tile in hl.tile(x.size()): - offsets_0_1 = tl.program_id(0) * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) + pid_flat = tl.program_id(0) + offsets_0_1 = pid_flat * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) indices_1 = offsets_0_1 % 512 indices_0 = offsets_0_1 // 512 # src[test_control_flow.py:N]: out[tile] = torch.sigmoid(x[tile]) @@ -65,7 +66,7 @@ def fn(x, *, _launcher=_default_launcher): # src[test_control_flow.py:N]: if 3 < v < 7: # src[test_control_flow.py:N]: out[tile] = torch.sigmoid(x[tile]) # src[test_control_flow.py:N-N]: ... 
- _launcher(_helion_fn, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) + _launcher(_helion_fn, (triton.cdiv(262144, _BLOCK_SIZE_0_1),), x, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[test_control_flow.py:N]: return out return out diff --git a/test/test_examples.expected b/test/test_examples.expected index dbb880280..5b300d157 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -12,7 +12,8 @@ from helion.runtime import default_launcher as _default_launcher @triton.jit def _helion_add(x, y, out, _BLOCK_SIZE_0_1: tl.constexpr): # src[add.py:N]: for tile in hl.tile(out.size()): - offsets_0_1 = tl.program_id(0) * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) + pid_flat = tl.program_id(0) + offsets_0_1 = pid_flat * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) indices_1 = offsets_0_1 % 512 indices_0 = offsets_0_1 // 512 # src[add.py:N]: out[tile] = x[tile] + y[tile] @@ -44,7 +45,7 @@ def add(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1 = 128 # src[add.py:N]: for tile in hl.tile(out.size()): # src[add.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, y, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) + _launcher(_helion_add, (triton.cdiv(262144, _BLOCK_SIZE_0_1),), x, y, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[add.py:N]: return out return out @@ -5954,6 +5955,201 @@ def softmax_two_pass(x: torch.Tensor, *, _launcher=_default_launcher): # src[softmax.py:N]: return out return out +--- assertExpectedJournal(TestExamples.test_split_k_barrier) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from helion.runtime import default_launcher as _default_launcher + +import helion._testing.split_k_barrier as _source_module + +@triton.jit +def _helion_split_k_matmul(a, b, tmp, out, x_grid_sem, tmp_stride_0, tmp_stride_1, split_k, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _RDIM_SIZE_6: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr, _BLOCK_SIZE_5: tl.constexpr): + # src[split_k_barrier.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[split_k_barrier.py:N]: out[tile_m, tile_n] = torch.sum(tmp[tile_m, tile_n, :], dim=-1) + total_pids = tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(64, _BLOCK_SIZE_1) * tl.cdiv(512, _BLOCK_SIZE_2) + tl.cdiv(64, _BLOCK_SIZE_4) * tl.cdiv(64, _BLOCK_SIZE_5) + block_size = tl.cdiv(total_pids, _NUM_SM) + start_pid = tl.program_id(0) * block_size + end_pid = tl.minimum(start_pid + block_size, total_pids) + for virtual_pid in tl.range(tl.maximum(start_pid, 0), tl.minimum(end_pid, 0 + tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(64, _BLOCK_SIZE_1) * tl.cdiv(512, _BLOCK_SIZE_2))): + if virtual_pid >= 0 and virtual_pid < 0 + tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(64, _BLOCK_SIZE_1) * tl.cdiv(512, _BLOCK_SIZE_2): + pid_shared = virtual_pid + # src[split_k_barrier.py:N]: for tile_m, tile_n, tile_k_outer in hl.tile( + # src[split_k_barrier.py:N]: [m, n, k], block_size=[None, None, block_k] + # src[split_k_barrier.py:N]: ): + # src[split_k_barrier.py:N-N]: ... 
+ if pid_shared < tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(64, _BLOCK_SIZE_1) * tl.cdiv(512, _BLOCK_SIZE_2): + # src[split_k_barrier.py:N]: for tile_m, tile_n, tile_k_outer in hl.tile( + # src[split_k_barrier.py:N]: [m, n, k], block_size=[None, None, block_k] + # src[split_k_barrier.py:N]: ): + num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0) + num_blocks_1 = tl.cdiv(64, _BLOCK_SIZE_1) + pid_0 = pid_shared % num_blocks_0 + pid_1 = pid_shared // num_blocks_0 % num_blocks_1 + pid_2 = pid_shared // (num_blocks_0 * num_blocks_1) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_7 = tl.arange(0, _RDIM_SIZE_6).to(tl.int32) + mask_6 = indices_7 < split_k + # src[split_k_barrier.py:N]: acc = hl.zeros([tile_m, tile_n], device=a.device, dtype=a.dtype) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[split_k_barrier.py:N]: for tile_k_inner in hl.tile(tile_k_outer.begin, tile_k_outer.end): + tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, 512) + # src[split_k_barrier.py:N]: for tile_k_inner in hl.tile(tile_k_outer.begin, tile_k_outer.end): + # src[split_k_barrier.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k_inner], b[tile_k_inner, tile_n]) + for offset_6 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_3): + indices_6 = offset_6 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) + mask_3 = indices_6 < tile_end + acc_copy = acc + acc_copy_0 = acc_copy + # src[split_k_barrier.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k_inner], b[tile_k_inner, tile_n]) + load = tl.load(a + (indices_0[:, None] * 512 + indices_6[None, :] * 1), mask_3[None, :], other=0) + load_1 = tl.load(b + (indices_6[:, None] * 64 + indices_1[None, :] * 1), mask_3[:, None], other=0) + acc = tl.dot(tl.cast(load, tl.float32), tl.cast(load_1, tl.float32), acc=acc_copy_0, input_precision='ieee', out_dtype=tl.float32) + # src[split_k_barrier.py:N]: tmp[tile_m, tile_n, tile_k_outer.id] = acc + tile_id = offset_2 // _BLOCK_SIZE_2 + tl.store(tmp + (indices_0[:, None] * tmp_stride_0 + indices_1[None, :] * tmp_stride_1 + tile_id * 1), acc, None) + else: + # src[split_k_barrier.py:N]: for tile_m, tile_n in hl.tile([m, n]): + pid_shared -= tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(64, _BLOCK_SIZE_1) * tl.cdiv(512, _BLOCK_SIZE_2) + num_blocks_2 = tl.cdiv(64, _BLOCK_SIZE_4) + pid_3 = pid_shared % num_blocks_2 + pid_4 = pid_shared // num_blocks_2 + offset_4 = pid_3 * _BLOCK_SIZE_4 + indices_4 = (offset_4 + tl.arange(0, _BLOCK_SIZE_4)).to(tl.int32) + offset_5 = pid_4 * _BLOCK_SIZE_5 + indices_5 = (offset_5 + tl.arange(0, _BLOCK_SIZE_5)).to(tl.int32) + indices_7 = tl.arange(0, _RDIM_SIZE_6).to(tl.int32) + mask_6 = indices_7 < split_k + # src[split_k_barrier.py:N]: out[tile_m, tile_n] = torch.sum(tmp[tile_m, tile_n, :], dim=-1) + load_2 = tl.load(tmp + (indices_4[:, None, None] * tmp_stride_0 + indices_5[None, :, None] * tmp_stride_1 + indices_7[None, None, :] * 1), mask_6[None, None, :], other=0) + sum_1 = tl.cast(tl.sum(load_2, 2), tl.float32) + tl.store(out + (indices_4[:, None] * 64 + indices_5[None, :] * 1), sum_1, None) + # src[split_k_barrier.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[split_k_barrier.py:N]: out[tile_m, tile_n] = torch.sum(tmp[tile_m, tile_n, :], dim=-1) + triton_helpers.x_grid_barrier(x_grid_sem) + for virtual_pid in tl.range(tl.maximum(start_pid, 0 + tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(64, 
_BLOCK_SIZE_1) * tl.cdiv(512, _BLOCK_SIZE_2)), tl.minimum(end_pid, 0 + tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(64, _BLOCK_SIZE_1) * tl.cdiv(512, _BLOCK_SIZE_2) + tl.cdiv(64, _BLOCK_SIZE_4) * tl.cdiv(64, _BLOCK_SIZE_5))): + if virtual_pid >= 0 + tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(64, _BLOCK_SIZE_1) * tl.cdiv(512, _BLOCK_SIZE_2) and virtual_pid < 0 + tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(64, _BLOCK_SIZE_1) * tl.cdiv(512, _BLOCK_SIZE_2) + tl.cdiv(64, _BLOCK_SIZE_4) * tl.cdiv(64, _BLOCK_SIZE_5): + pid_shared = virtual_pid + # src[split_k_barrier.py:N]: for tile_m, tile_n, tile_k_outer in hl.tile( + # src[split_k_barrier.py:N]: [m, n, k], block_size=[None, None, block_k] + # src[split_k_barrier.py:N]: ): + # src[split_k_barrier.py:N-N]: ... + if pid_shared < tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(64, _BLOCK_SIZE_1) * tl.cdiv(512, _BLOCK_SIZE_2): + # src[split_k_barrier.py:N]: for tile_m, tile_n, tile_k_outer in hl.tile( + # src[split_k_barrier.py:N]: [m, n, k], block_size=[None, None, block_k] + # src[split_k_barrier.py:N]: ): + num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0) + num_blocks_1 = tl.cdiv(64, _BLOCK_SIZE_1) + pid_0 = pid_shared % num_blocks_0 + pid_1 = pid_shared // num_blocks_0 % num_blocks_1 + pid_2 = pid_shared // (num_blocks_0 * num_blocks_1) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_7 = tl.arange(0, _RDIM_SIZE_6).to(tl.int32) + mask_6 = indices_7 < split_k + # src[split_k_barrier.py:N]: acc = hl.zeros([tile_m, tile_n], device=a.device, dtype=a.dtype) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[split_k_barrier.py:N]: for tile_k_inner in hl.tile(tile_k_outer.begin, tile_k_outer.end): + tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, 512) + # src[split_k_barrier.py:N]: for tile_k_inner in hl.tile(tile_k_outer.begin, tile_k_outer.end): + # src[split_k_barrier.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k_inner], b[tile_k_inner, tile_n]) + for offset_6 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_3): + indices_6 = offset_6 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) + mask_3 = indices_6 < tile_end + acc_copy = acc + acc_copy_0 = acc_copy + # src[split_k_barrier.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k_inner], b[tile_k_inner, tile_n]) + load = tl.load(a + (indices_0[:, None] * 512 + indices_6[None, :] * 1), mask_3[None, :], other=0) + load_1 = tl.load(b + (indices_6[:, None] * 64 + indices_1[None, :] * 1), mask_3[:, None], other=0) + acc = tl.dot(tl.cast(load, tl.float32), tl.cast(load_1, tl.float32), acc=acc_copy_0, input_precision='ieee', out_dtype=tl.float32) + # src[split_k_barrier.py:N]: tmp[tile_m, tile_n, tile_k_outer.id] = acc + tile_id = offset_2 // _BLOCK_SIZE_2 + tl.store(tmp + (indices_0[:, None] * tmp_stride_0 + indices_1[None, :] * tmp_stride_1 + tile_id * 1), acc, None) + else: + # src[split_k_barrier.py:N]: for tile_m, tile_n in hl.tile([m, n]): + pid_shared -= tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(64, _BLOCK_SIZE_1) * tl.cdiv(512, _BLOCK_SIZE_2) + num_blocks_2 = tl.cdiv(64, _BLOCK_SIZE_4) + pid_3 = pid_shared % num_blocks_2 + pid_4 = pid_shared // num_blocks_2 + offset_4 = pid_3 * _BLOCK_SIZE_4 + indices_4 = (offset_4 + tl.arange(0, _BLOCK_SIZE_4)).to(tl.int32) + offset_5 = pid_4 * _BLOCK_SIZE_5 + indices_5 = (offset_5 + tl.arange(0, _BLOCK_SIZE_5)).to(tl.int32) + indices_7 = tl.arange(0, 
_RDIM_SIZE_6).to(tl.int32) + mask_6 = indices_7 < split_k + # src[split_k_barrier.py:N]: out[tile_m, tile_n] = torch.sum(tmp[tile_m, tile_n, :], dim=-1) + load_2 = tl.load(tmp + (indices_4[:, None, None] * tmp_stride_0 + indices_5[None, :, None] * tmp_stride_1 + indices_7[None, None, :] * 1), mask_6[None, None, :], other=0) + sum_1 = tl.cast(tl.sum(load_2, 2), tl.float32) + tl.store(out + (indices_4[:, None] * 64 + indices_5[None, :] * 1), sum_1, None) + +def split_k_matmul(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_launcher): + """ + Two-stage split-K matmul using hl.barrier(). The barrier approach + gives deterministic results as opposed to the atomic_add approach. + + Stage 1: + - Split K into `split_k` contiguous chunks. + - Each chunk computes a partial [tile_m, tile_n] product into its own slice of `tmp`. + + Barrier: + - Grid-wide barrier to ensure all partials are written before reduction. + + Stage 2: + - Reduce partials across the split dimension and write `out`. + + Shapes: + a: [M, K] + b: [K, N] + tmp: [M, N, split_k] + out: [M, N] + + Notes: + - Static shapes keep codegen simpler. + - `split_k` is fixed for clarity; autotuning could choose it instead. + """ + # src[split_k_barrier.py:N]: m, k = a.shape + m, k = a.shape + # src[split_k_barrier.py:N]: _, n = b.shape + _, n = b.shape + # src[split_k_barrier.py:N]: split_k = hl.register_tunable("split_k", PowerOfTwoFragment(16, 512, 64)) + split_k = 64 + # src[split_k_barrier.py:N]: block_k = helion.next_power_of_2(helion.cdiv(k, split_k)) + block_k = helion.next_power_of_2(helion.cdiv(k, split_k)) + # src[split_k_barrier.py:N]: tmp = torch.empty((m, n, split_k), device=a.device, dtype=a.dtype) + tmp = torch.empty((m, n, split_k), device=a.device, dtype=a.dtype) + # src[split_k_barrier.py:N]: out = torch.empty((m, n), device=a.device, dtype=a.dtype) + out = torch.empty((m, n), device=a.device, dtype=a.dtype) + # src[split_k_barrier.py:N]: for tile_m, tile_n, tile_k_outer in hl.tile( + # src[split_k_barrier.py:N]: [m, n, k], block_size=[None, None, block_k] + # src[split_k_barrier.py:N]: ): + _NUM_SM = helion.runtime.get_num_sm(a.device) + _BLOCK_SIZE_0 = 16 + _BLOCK_SIZE_1 = 8 + _BLOCK_SIZE_2 = block_k + _RDIM_SIZE_6 = triton.next_power_of_2(split_k) + # src[split_k_barrier.py:N]: for tile_k_inner in hl.tile(tile_k_outer.begin, tile_k_outer.end): + # src[split_k_barrier.py:N]: acc = torch.addmm(acc, a[tile_m, tile_k_inner], b[tile_k_inner, tile_n]) + _BLOCK_SIZE_3 = 16 + # src[split_k_barrier.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_4 = 16 + _BLOCK_SIZE_5 = 16 + # src[split_k_barrier.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[split_k_barrier.py:N]: out[tile_m, tile_n] = torch.sum(tmp[tile_m, tile_n, :], dim=-1) + _launcher(_helion_split_k_matmul, (_NUM_SM,), a, b, tmp, out, torch.zeros((1,), device=a.device, dtype=torch.uint32), tmp.stride(0), tmp.stride(1), split_k, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _RDIM_SIZE_6, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=1, launch_cooperative_grid=True) + # src[split_k_barrier.py:N]: return out + return out + --- assertExpectedJournal(TestExamples.test_squeeze_and_excitation_net_bwd_da) from __future__ import annotations diff --git a/test/test_examples.py b/test/test_examples.py index 31a796a5f..a5c36938c 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -53,6 +53,24 @@ def test_matmul(self): ) ) + def test_split_k_barrier(self): + m, k, n = 64, 512, 64 + a = torch.randn([m, k], 
device=DEVICE, dtype=torch.float32) + b = torch.randn([k, n], device=DEVICE, dtype=torch.float32) + expected = a @ b + + self.assertExpectedJournal( + check_example( + "split_k_barrier", + (a, b), + expected, + fn_name="split_k_matmul", + block_sizes=[16, 8, 16, 16, 16], + pid_type="persistent_blocked", + split_k=64, + ) + ) + def test_matmul_bwd(self): """Test backward pass for matmul computation.""" # Create tensors with requires_grad=True like rms_norm_bwd test diff --git a/test/test_generate_ast.expected b/test/test_generate_ast.expected index 43714832e..30169e61f 100644 --- a/test/test_generate_ast.expected +++ b/test/test_generate_ast.expected @@ -45,7 +45,8 @@ from helion.runtime import default_launcher as _default_launcher @triton.jit def _helion_add(x, y, out, _BLOCK_SIZE_0_1: tl.constexpr): # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): - offsets_0_1 = tl.program_id(0) * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) + pid_flat = tl.program_id(0) + offsets_0_1 = pid_flat * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) indices_1 = offsets_0_1 % 500 indices_0 = offsets_0_1 // 500 mask_0_1 = offsets_0_1 < 50000 @@ -64,7 +65,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1 = 1024 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(50000, _BLOCK_SIZE_0_1), 1, 1), x, y, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) + _launcher(_helion_add, (triton.cdiv(50000, _BLOCK_SIZE_0_1),), x, y, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -79,7 +80,8 @@ from helion.runtime import default_launcher as _default_launcher @triton.jit def _helion_add(x, y, out, _BLOCK_SIZE_0_1: tl.constexpr): # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): - offsets_0_1 = tl.program_id(0) * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) + pid_flat = tl.program_id(0) + offsets_0_1 = pid_flat * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) indices_0 = offsets_0_1 % 100 indices_1 = offsets_0_1 // 100 mask_0_1 = offsets_0_1 < 50000 @@ -98,7 +100,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1 = 1024 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(50000, _BLOCK_SIZE_0_1), 1, 1), x, y, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) + _launcher(_helion_add, (triton.cdiv(50000, _BLOCK_SIZE_0_1),), x, y, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -113,7 +115,8 @@ from helion.runtime import default_launcher as _default_launcher @triton.jit def _helion_add(x, y, out, _BLOCK_SIZE_0_1_2: tl.constexpr): # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): - offsets_0_1_2 = tl.program_id(0) * _BLOCK_SIZE_0_1_2 + tl.arange(0, _BLOCK_SIZE_0_1_2).to(tl.int32) + pid_flat = tl.program_id(0) + offsets_0_1_2 = pid_flat * _BLOCK_SIZE_0_1_2 + tl.arange(0, _BLOCK_SIZE_0_1_2).to(tl.int32) indices_2 = offsets_0_1_2 % 10 indices_1 = offsets_0_1_2 // 10 % 500 indices_0 = offsets_0_1_2 // 5000 @@ -133,7 +136,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1_2 = 1024 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(500000, _BLOCK_SIZE_0_1_2), 1, 1), x, y, out, 
_BLOCK_SIZE_0_1_2, num_warps=4, num_stages=1) + _launcher(_helion_add, (triton.cdiv(500000, _BLOCK_SIZE_0_1_2),), x, y, out, _BLOCK_SIZE_0_1_2, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -148,7 +151,8 @@ from helion.runtime import default_launcher as _default_launcher @triton.jit def _helion_add(x, y, out, _BLOCK_SIZE_0_1_2: tl.constexpr): # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): - offsets_0_1_2 = tl.program_id(0) * _BLOCK_SIZE_0_1_2 + tl.arange(0, _BLOCK_SIZE_0_1_2).to(tl.int32) + pid_flat = tl.program_id(0) + offsets_0_1_2 = pid_flat * _BLOCK_SIZE_0_1_2 + tl.arange(0, _BLOCK_SIZE_0_1_2).to(tl.int32) indices_1 = offsets_0_1_2 % 500 indices_0 = offsets_0_1_2 // 500 % 100 indices_2 = offsets_0_1_2 // 50000 @@ -168,7 +172,7 @@ def add(x, y, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1_2 = 1024 # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): # src[basic_kernels.py:N]: out[tile] = x[tile] + y[tile] - _launcher(_helion_add, (triton.cdiv(500000, _BLOCK_SIZE_0_1_2), 1, 1), x, y, out, _BLOCK_SIZE_0_1_2, num_warps=4, num_stages=1) + _launcher(_helion_add, (triton.cdiv(500000, _BLOCK_SIZE_0_1_2),), x, y, out, _BLOCK_SIZE_0_1_2, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -434,7 +438,8 @@ from helion.runtime import default_launcher as _default_launcher @triton.jit def _helion_hl_zeros_usage(x, out, _BLOCK_SIZE_0_1: tl.constexpr): # src[basic_kernels.py:N]: for tile in hl.tile(out.size()): - offsets_0_1 = tl.program_id(0) * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) + pid_flat = tl.program_id(0) + offsets_0_1 = pid_flat * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) indices_1 = offsets_0_1 % 512 indices_0 = offsets_0_1 // 512 # src[basic_kernels.py:N]: tmp = hl.zeros(tile, dtype=x.dtype) @@ -457,7 +462,7 @@ def hl_zeros_usage(x: torch.Tensor, *, _launcher=_default_launcher): # src[basic_kernels.py:N]: tmp = hl.zeros(tile, dtype=x.dtype) # src[basic_kernels.py:N]: tmp += x[tile] # src[basic_kernels.py:N-N]: ... 
- _launcher(_helion_hl_zeros_usage, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) + _launcher(_helion_hl_zeros_usage, (triton.cdiv(262144, _BLOCK_SIZE_0_1),), x, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return out return out @@ -515,7 +520,8 @@ from helion.runtime import default_launcher as _default_launcher @triton.jit def _helion_inplace_mul(x, c, _BLOCK_SIZE_0_1: tl.constexpr): # src[basic_kernels.py:N]: for tile in hl.tile(x.size()): - offsets_0_1 = tl.program_id(0) * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) + pid_flat = tl.program_id(0) + offsets_0_1 = pid_flat * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) indices_1 = offsets_0_1 % 512 indices_0 = offsets_0_1 // 512 # src[basic_kernels.py:N]: x[tile] *= c @@ -531,7 +537,7 @@ def inplace_mul(x, c, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1 = 128 # src[basic_kernels.py:N]: for tile in hl.tile(x.size()): # src[basic_kernels.py:N]: x[tile] *= c - _launcher(_helion_inplace_mul, (triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, c, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) + _launcher(_helion_inplace_mul, (triton.cdiv(262144, _BLOCK_SIZE_0_1),), x, c, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[basic_kernels.py:N]: return x return x diff --git a/test/test_loops.expected b/test/test_loops.expected index 6b98854f7..1b74a2ee2 100644 --- a/test/test_loops.expected +++ b/test/test_loops.expected @@ -560,7 +560,8 @@ from helion.runtime import default_launcher as _default_launcher @triton.jit def _helion_silu_kernel(x, out, _BLOCK_SIZE_0_1: tl.constexpr): # src[test_loops.py:N]: for tile in hl.tile(out.size()): - offsets_0_1 = tl.program_id(0) * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) + pid_flat = tl.program_id(0) + offsets_0_1 = pid_flat * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) indices_1 = offsets_0_1 mask_0_1 = offsets_0_1 < 100 # src[test_loops.py:N]: out[tile] = x[tile] * torch.sigmoid(x[tile]) @@ -577,7 +578,7 @@ def silu_kernel(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1 = 16 # src[test_loops.py:N]: for tile in hl.tile(out.size()): # src[test_loops.py:N]: out[tile] = x[tile] * torch.sigmoid(x[tile]) - _launcher(_helion_silu_kernel, (triton.cdiv(100, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=1, num_stages=1) + _launcher(_helion_silu_kernel, (triton.cdiv(100, _BLOCK_SIZE_0_1),), x, out, _BLOCK_SIZE_0_1, num_warps=1, num_stages=1) # src[test_loops.py:N]: return out return out diff --git a/test/test_specialize.expected b/test/test_specialize.expected index c8b19be6f..cf88ec02c 100644 --- a/test/test_specialize.expected +++ b/test/test_specialize.expected @@ -347,7 +347,8 @@ from helion.runtime import default_launcher as _default_launcher @triton.jit def _helion_fn(x, out, _BLOCK_SIZE_0_1: tl.constexpr): # src[test_specialize.py:N]: for tile in hl.tile(x.size()): - offsets_0_1 = tl.program_id(0) * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) + pid_flat = tl.program_id(0) + offsets_0_1 = pid_flat * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32) indices_1 = offsets_0_1 % 512 indices_0 = offsets_0_1 // 512 # src[test_specialize.py:N]: out[tile] = x[tile] * scale @@ -365,7 +366,7 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0_1 = 32 # src[test_specialize.py:N]: for tile in hl.tile(x.size()): # src[test_specialize.py:N]: out[tile] = x[tile] * scale - _launcher(_helion_fn, 
(triton.cdiv(262144, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) + _launcher(_helion_fn, (triton.cdiv(262144, _BLOCK_SIZE_0_1),), x, out, _BLOCK_SIZE_0_1, num_warps=4, num_stages=1) # src[test_specialize.py:N]: return out return out
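
Note: the kernel fixtures exercised by the new barrier tests (barrier_dep_single, barrier_multiple, barrier_groups) are defined earlier in the test file and are not part of this patch excerpt. Purely as an illustrative sketch inferred from the expected values (e.g. x * 2 + 1) and the two block sizes passed per test, not the actual fixture code, barrier_dep_single would look roughly like this:

    import torch

    import helion
    import helion.language as hl


    # Hypothetical reconstruction of the barrier_dep_single fixture: stage 1
    # writes an intermediate buffer, hl.barrier() orders those stores before
    # stage 2 reads them back.
    @helion.kernel()
    def barrier_dep_single(x: torch.Tensor) -> torch.Tensor:
        tmp = torch.empty_like(x)
        out = torch.empty_like(x)
        # Stage 1: each tile writes its slice of the intermediate.
        for tile in hl.tile(x.size(0)):
            tmp[tile] = x[tile] * 2
        # Grid-wide barrier: all stage-1 stores must be visible before any
        # program executes stage 2.
        hl.barrier()
        # Stage 2: consume values possibly written by other programs.
        for tile in hl.tile(x.size(0)):
            out[tile] = tmp[tile] + 1
        return out

With two top-level hl.tile loops this would line up with the block_sizes=[4, 4] configs used in the tests above and the expected x * 2 + 1 result; the real fixtures may differ in decorator arguments, dtypes, and naming.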