From 88bbbea02edd32da9ea1a43057e402669f79a618 Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Fri, 7 Nov 2025 17:34:04 -0800
Subject: [PATCH] Refactor inductor_lowering.py into two files

stack-info: PR: https://github.com/pytorch/helion/pull/1103, branch: jansel/stack/223
---
 helion/_compiler/aten_lowering.py     | 547 ++++++++++++++++++++++++++
 helion/_compiler/inductor_lowering.py | 522 +-------------------------
 helion/_compiler/roll_reduction.py    |   2 +-
 3 files changed, 566 insertions(+), 505 deletions(-)
 create mode 100644 helion/_compiler/aten_lowering.py

diff --git a/helion/_compiler/aten_lowering.py b/helion/_compiler/aten_lowering.py
new file mode 100644
index 000000000..e1c15ab93
--- /dev/null
+++ b/helion/_compiler/aten_lowering.py
@@ -0,0 +1,547 @@
+from __future__ import annotations
+
+import ast
+from collections.abc import Callable
+import dataclasses
+from operator import getitem
+from typing import TYPE_CHECKING
+from typing import cast
+
+import torch
+from torch._inductor.codegen.simd import constant_repr
+from torch._inductor.utils import triton_type
+from torch.fx.node import Argument
+from torch.fx.node import Node
+from torch.fx.node import map_arg
+from triton import next_power_of_2
+
+from ..language.matmul_ops import enforce_dot_requirements
+from .ast_extension import create
+from .ast_extension import expr_from_string
+from .ast_extension import statement_from_string
+from .compile_environment import CompileEnvironment
+from .matmul_utils import emit_tl_dot_with_padding
+from .node_masking import apply_masking
+from .node_masking import cached_masked_value
+from .node_masking import getitem_masked_value
+
+if TYPE_CHECKING:
+    from .helper_function import CodegenInterface
+
+
+class LoweringContext:
+    cg: CodegenInterface
+    env: dict[Node, Argument]
+
+    def to_ast(self, value: object) -> ast.AST:
+        raise NotImplementedError
+
+
+class Lowering:
+    def codegen(self, ctx: LoweringContext, node: Node) -> object:
+        raise NotImplementedError
+
+    def get_masked_value(self, node: Node) -> float | bool | None:
+        """Get the masked value for this node."""
+        return None
+
+
+MaskedValueFn = Callable[[Node], float | bool | None]
+CodegenHandler = Callable[[LoweringContext, Node], object]
+
+
+def _env_arg(ctx: LoweringContext, node: Node) -> Argument:
+    return cast("Argument", ctx.env[node])
+
+
+@dataclasses.dataclass
+class LambdaLowering(Lowering):
+    fn: Callable[..., object]
+    masked_value_fn: MaskedValueFn | None = None
+
+    def codegen(self, ctx: LoweringContext, node: Node) -> object:
+        return self.fn(ctx, node)
+
+    def get_masked_value(self, node: Node) -> float | bool | None:
+        if self.masked_value_fn is not None:
+            return self.masked_value_fn(node)
+        return None
+
+
+def passthrough_masked_value(node: Node) -> float | bool | None:
+    for input_node in node.all_input_nodes:
+        if isinstance(input_node.meta["val"], torch.Tensor):
+            return cached_masked_value(input_node)
+    return None
+
+
+aten_lowering_dispatch: dict[object, Callable[[Node], Lowering]] = {}
+
+
+def default_make_lowering(
+    handler: CodegenHandler,
+    node: Node,
+    masked_value_fn: MaskedValueFn | None = None,
+) -> Lowering:
+    return LambdaLowering(handler, masked_value_fn=masked_value_fn)
+
+
+def register_lowering(
+    fn: object,
+    make_lowering: Callable[[CodegenHandler, Node], Lowering] = default_make_lowering,
+    masked_value_fn: MaskedValueFn | None = None,
+) -> Callable[[CodegenHandler], CodegenHandler]:
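+    """Register handler as the codegen function for fn.
+
+    The handler is recorded in aten_lowering_dispatch, keyed by fn.
+    make_lowering wraps it into a Lowering (LambdaLowering by default), and
+    masked_value_fn optionally reports the fill value that masked-out
+    elements of this op's output are known to hold.
+    """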
aten_lowering_dispatch, f"Lowering for {fn} already registered" + + aten_lowering_dispatch[fn] = lambda node: make_lowering( + handler, + node, + masked_value_fn=masked_value_fn, # pyright: ignore[reportCallIssue] + ) + return handler + + return decorator + + +@register_lowering(torch.ops.aten.sym_size.int) # pyright: ignore[reportAttributeAccessIssue] +def codegen_sym_size(ctx: LoweringContext, node: Node) -> object: + val = node.meta["val"] + assert isinstance( + val, (int, float, bool, torch.SymInt, torch.SymBool, torch.SymFloat) + ) + return val + + +@register_lowering(getitem, masked_value_fn=getitem_masked_value) +def codegen_getitem(ctx: LoweringContext, node: Node) -> object: + assert not node.kwargs, "getitem kwargs not supported" + lhs, rhs = map_arg(node.args, lambda arg: _env_arg(ctx, arg)) + assert isinstance(lhs, (list, tuple)) + assert isinstance(rhs, int) + return lhs[rhs] + + +@register_lowering( + torch.ops.aten.full.default, # pyright: ignore[reportAttributeAccessIssue] + masked_value_fn=lambda n: ( + n.args[1] if isinstance(n.args[1], (int, float, bool)) else None + ), +) +def codegen_full(ctx: LoweringContext, node: Node) -> object: + env = CompileEnvironment.current() + size = map_arg(node.args[0], lambda n: n.meta["val"]) + dtype = node.kwargs.get("dtype", torch.get_default_dtype()) + assert isinstance(dtype, torch.dtype) + device = node.kwargs.get("device", env.device) + assert device == env.device, f"expected {env.device}, got {device}" + assert not node.kwargs.get("pin_memory"), "pin_memory not supported" + value_ast = map_arg(node.args[1], lambda arg: _env_arg(ctx, arg)) + if isinstance(value_ast, (int, float, bool)): + value_ast = expr_from_string(constant_repr(value_ast)) + assert isinstance(value_ast, ast.AST), value_ast + shape_str = ctx.cg.device_function.tile_strategy.shape_str([*size]) # pyright: ignore[reportGeneralTypeIssues,reportOptionalIterable] + return expr_from_string( + f"tl.full({shape_str}, {{value}}, {triton_type(dtype)})", + value=value_ast, + ) + + +@register_lowering( + torch.ops.aten.unsqueeze.default, # pyright: ignore[reportAttributeAccessIssue] + masked_value_fn=passthrough_masked_value, +) +def codegen_unsqueeze(ctx: LoweringContext, node: Node) -> object: + assert not node.kwargs, "getitem kwargs not supported" + tensor, dim = map_arg(node.args, lambda arg: _env_arg(ctx, arg)) + assert isinstance(tensor, ast.AST) + assert isinstance(dim, int) + ndim = node.args[0].meta["val"].ndim # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] + if dim < 0: + dim += ndim + assert 0 <= dim <= ndim, f"Invalid dim {dim} for tensor with {ndim} dims" + args = [":"] * ndim + args.insert(dim, "None") + return expr_from_string( + f"{{tensor}}[{', '.join(args)}]", + tensor=tensor, + ) + + +@register_lowering(torch.ops.aten.squeeze.dim, masked_value_fn=passthrough_masked_value) # pyright: ignore[reportAttributeAccessIssue] +@register_lowering( + torch.ops.aten.view.default, # pyright: ignore[reportAttributeAccessIssue] + masked_value_fn=passthrough_masked_value, +) +@register_lowering( + torch.ops.aten.reshape.default, # pyright: ignore[reportAttributeAccessIssue] + masked_value_fn=passthrough_masked_value, +) +def codegen_view(ctx: LoweringContext, node: Node) -> object: + assert not node.kwargs, "view kwargs not supported" + tensor = map_arg(node.args[0], lambda arg: _env_arg(ctx, arg)) + assert isinstance(tensor, ast.AST) + shape_str = ctx.cg.device_function.tile_strategy.shape_str( + [*node.meta["val"].size()] + ) + return 
expr_from_string(f"tl.reshape({{tensor}}, {shape_str})", tensor=tensor) + + +@register_lowering( + torch.ops.aten.permute.default, # pyright: ignore[reportAttributeAccessIssue] + masked_value_fn=passthrough_masked_value, +) +def codegen_permute(ctx: LoweringContext, node: Node) -> object: + assert not node.kwargs, "getitem kwargs not supported" + tensor, dims = map_arg(node.args, lambda arg: _env_arg(ctx, arg)) + assert isinstance(tensor, ast.AST) + dims = [*dims] # pyright: ignore[reportGeneralTypeIssues,reportOptionalIterable] + assert {*dims} == {*range(len(dims))}, dims + return expr_from_string( + f"tl.permute({{tensor}}, {dims!r})", + tensor=tensor, + ) + + +@register_lowering( + torch.ops.aten.stack.default, # pyright: ignore[reportAttributeAccessIssue] + masked_value_fn=passthrough_masked_value, +) +def codegen_stack(ctx: LoweringContext, node: Node) -> object: + tensors = node.args[0] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + + assert isinstance(tensors, (list, tuple)) + tensor_asts = [ctx.env[t] for t in tensors] # pyright: ignore[reportArgumentType] + n = len(tensor_asts) + + if n == 0: + raise ValueError("Cannot stack empty tensor list") + + # Round up to power of 2 for efficient masking + padded_size = 1 << (n - 1).bit_length() + + # Create index array [0, 1, 2, 3, ...] for tensor selection + idx = ctx.cg.device_function.new_var("stack_idx") + ctx.cg.add_statement(statement_from_string(f"{idx} = tl.arange(0, {padded_size})")) + + # Broadcast index to target dimension shape + # e.g., dim=0: [:, None, None], dim=1: [None, :, None], dim=2: [None, None, :] + bidx = ctx.cg.device_function.new_var("broadcast_idx") + assert isinstance(dim, int) + pattern = "[" + ", ".join(["None"] * dim + [":"] + ["None"] * max(0, 2 - dim)) + "]" + ctx.cg.add_statement(statement_from_string(f"{bidx} = {idx}{pattern}")) + + # Expand each input tensor along the stack dimension + expanded = [ctx.cg.device_function.new_var(f"expanded_{i}") for i in range(n)] + for var, tensor in zip(expanded, tensor_asts, strict=False): + tensor_ast = cast("ast.AST", tensor) + ctx.cg.add_statement( + statement_from_string(f"{var} = tl.expand_dims({{t}}, {dim})", t=tensor_ast) + ) + + # Initialize result with zeros + result = ctx.cg.device_function.new_var("stacked_result") + ctx.cg.add_statement( + statement_from_string(f"{result} = tl.zeros_like({expanded[0]})") + ) + + # Select each tensor using masks + for i in range(n): + mask = ctx.cg.device_function.new_var(f"mask_{i}") + ctx.cg.add_statement(statement_from_string(f"{mask} = {bidx} == {i}")) + ctx.cg.add_statement( + statement_from_string( + f"{result} = tl.where({mask}, {expanded[i]}, {result})" + ) + ) + + return expr_from_string(result) + + +@register_lowering( + torch.ops.aten.expand.default, # pyright: ignore[reportAttributeAccessIssue] + masked_value_fn=passthrough_masked_value, +) +def codegen_expand(ctx: LoweringContext, node: Node) -> object: + assert not node.kwargs, "getitem kwargs not supported" + tensor, _ = map_arg(node.args, lambda arg: _env_arg(ctx, arg)) + assert isinstance(tensor, ast.AST) + val = node.meta["val"] + assert isinstance(val, torch.Tensor) + shape = [*val.size()] + if node.args[0].meta["val"].ndim != len(shape): # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] + broadcasting = [":"] * len(shape) + for i in range(len(shape) - node.args[0].meta["val"].ndim): # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] + broadcasting[i] = "None" + tensor = 
+    tensors = node.args[0]
+    dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+
+    assert isinstance(tensors, (list, tuple))
+    tensor_asts = [ctx.env[t] for t in tensors]  # pyright: ignore[reportArgumentType]
+    n = len(tensor_asts)
+
+    if n == 0:
+        raise ValueError("Cannot stack empty tensor list")
+
+    # Round up to power of 2 for efficient masking
+    padded_size = 1 << (n - 1).bit_length()
+
+    # Create index array [0, 1, 2, 3, ...] for tensor selection
+    idx = ctx.cg.device_function.new_var("stack_idx")
+    ctx.cg.add_statement(statement_from_string(f"{idx} = tl.arange(0, {padded_size})"))
+
+    # Broadcast index to target dimension shape
+    # e.g., dim=0: [:, None, None], dim=1: [None, :, None], dim=2: [None, None, :]
+    bidx = ctx.cg.device_function.new_var("broadcast_idx")
+    assert isinstance(dim, int)
+    pattern = "[" + ", ".join(["None"] * dim + [":"] + ["None"] * max(0, 2 - dim)) + "]"
+    ctx.cg.add_statement(statement_from_string(f"{bidx} = {idx}{pattern}"))
+
+    # Expand each input tensor along the stack dimension
+    expanded = [ctx.cg.device_function.new_var(f"expanded_{i}") for i in range(n)]
+    for var, tensor in zip(expanded, tensor_asts, strict=False):
+        tensor_ast = cast("ast.AST", tensor)
+        ctx.cg.add_statement(
+            statement_from_string(f"{var} = tl.expand_dims({{t}}, {dim})", t=tensor_ast)
+        )
+
+    # Initialize result with zeros
+    result = ctx.cg.device_function.new_var("stacked_result")
+    ctx.cg.add_statement(
+        statement_from_string(f"{result} = tl.zeros_like({expanded[0]})")
+    )
+
+    # Select each tensor using masks
+    for i in range(n):
+        mask = ctx.cg.device_function.new_var(f"mask_{i}")
+        ctx.cg.add_statement(statement_from_string(f"{mask} = {bidx} == {i}"))
+        ctx.cg.add_statement(
+            statement_from_string(
+                f"{result} = tl.where({mask}, {expanded[i]}, {result})"
+            )
+        )
+
+    return expr_from_string(result)
+
+
+@register_lowering(
+    torch.ops.aten.expand.default,  # pyright: ignore[reportAttributeAccessIssue]
+    masked_value_fn=passthrough_masked_value,
+)
+def codegen_expand(ctx: LoweringContext, node: Node) -> object:
+    assert not node.kwargs, "expand kwargs not supported"
+    tensor, _ = map_arg(node.args, lambda arg: _env_arg(ctx, arg))
+    assert isinstance(tensor, ast.AST)
+    val = node.meta["val"]
+    assert isinstance(val, torch.Tensor)
+    shape = [*val.size()]
+    if node.args[0].meta["val"].ndim != len(shape):  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+        broadcasting = [":"] * len(shape)
+        for i in range(len(shape) - node.args[0].meta["val"].ndim):  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+            broadcasting[i] = "None"
+        tensor = expr_from_string(
+            f"{{tensor}}[{', '.join(broadcasting)}]", tensor=tensor
+        )
+    shape_str = ctx.cg.device_function.tile_strategy.shape_str(shape)
+    return expr_from_string(
+        f"tl.broadcast_to({{tensor}}, {shape_str})",
+        tensor=tensor,
+    )
+
+
+def apply_dot_requirements(
+    handler: CodegenHandler,
+    node: Node,
+    masked_value_fn: MaskedValueFn | None = None,
+) -> Lowering:
+    """Apply min_dot_size requirements to the config_spec"""
+    assert not node.kwargs, "dot kwargs not supported"
+    assert len(node.args) in (2, 3)
+    lproxy, rproxy = map_arg(node.args[-2:], lambda arg: arg.meta["val"])
+    assert isinstance(lproxy, torch.Tensor)
+    assert isinstance(rproxy, torch.Tensor)
+    # Update config spec min sizes for M, N, K
+    enforce_dot_requirements(lproxy, rproxy)
+    # inputs to the dot operation must be zero-masked
+    *maybe_acc, lnode, rnode = node.args
+    assert isinstance(lnode, Node)
+    assert isinstance(rnode, Node)
+    lnode = apply_masking(lnode, base_node=node, other=0)
+    rnode = apply_masking(rnode, base_node=node, other=0)
+    node.args = (*maybe_acc, lnode, rnode)
+    return LambdaLowering(handler, masked_value_fn=masked_value_fn)
+
+
+def reduce_3d_dot(ctx: LoweringContext, node: Node, with_acc: bool) -> ast.AST:
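+    """Emit a tl.dot for mm/bmm-style nodes, with an accumulator when
+    with_acc is set. Shape padding and output dtype selection are
+    delegated to emit_tl_dot_with_padding; FP8 inputs are rejected and
+    redirected to hl.dot().
+    """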
+    acc = None
+    acc_node: Node | None = None
+    if with_acc:
+        acc, lhs, rhs = map_arg(node.args, lambda arg: _env_arg(ctx, arg))
+        assert isinstance(acc, ast.AST)
+        assert isinstance(node.args[0], Node)
+        acc_node = node.args[0]
+        lhs_node = node.args[1]
+        rhs_node = node.args[2]
+    else:
+        lhs, rhs = map_arg(node.args, lambda arg: _env_arg(ctx, arg))
+        lhs_node = node.args[0]
+        rhs_node = node.args[1]
+    assert isinstance(lhs, ast.AST)
+    assert isinstance(rhs, ast.AST)
+    assert isinstance(lhs_node, Node)
+    assert isinstance(rhs_node, Node)
+
+    # Check if inputs are FP8 - if so, redirect user to hl.dot()
+    lhs_dtype = lhs_node.meta["val"].dtype
+    rhs_dtype = rhs_node.meta["val"].dtype
+    acc_dtype_meta: torch.dtype | None = None
+    if with_acc:
+        assert acc_node is not None
+        assert isinstance(acc_node, Node)
+        acc_dtype_meta = acc_node.meta["val"].dtype
+    if lhs_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and rhs_dtype in [
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+    ]:
+        raise NotImplementedError(
+            "FP8 GEMM via torch API is not supported yet. Please use hl.dot() instead."
+        )
+
+    lhs_shape = list(lhs_node.meta["val"].size())  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+    rhs_shape = list(rhs_node.meta["val"].size())  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+    acc_shape = (
+        list(acc_node.meta["val"].size())
+        if (with_acc and acc_node is not None)
+        else None
+    )  # pyright: ignore[reportOptionalMemberAccess]
+
+    # Extract expected output dtype from FX node to match PyTorch eager mode behavior
+    out_dtype: torch.dtype | None = None
+    if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
+        out_dtype = node.meta["val"].dtype
+
+    return emit_tl_dot_with_padding(
+        lhs,
+        rhs,
+        acc if with_acc else None,
+        lhs_dtype,
+        rhs_dtype,
+        acc_dtype=acc_dtype_meta if with_acc else None,
+        out_dtype=out_dtype,
+        lhs_shape=lhs_shape,
+        rhs_shape=rhs_shape,
+        acc_shape=acc_shape,
+    )
+
+
+@register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)  # pyright: ignore[reportAttributeAccessIssue]
+@register_lowering(torch.ops.aten.mm.default, apply_dot_requirements)  # pyright: ignore[reportAttributeAccessIssue]
+def codegen_mm(ctx: LoweringContext, node: Node) -> ast.AST:
+    assert not node.kwargs, "matmul kwargs not supported"
+
+    return reduce_3d_dot(ctx, node, False)
+
+
+@register_lowering(torch.ops.aten.addmm.default, apply_dot_requirements)  # pyright: ignore[reportAttributeAccessIssue]
+def codegen_addmm(ctx: LoweringContext, node: Node) -> ast.AST:
+    assert not node.kwargs, "addmm kwargs not supported"
+    return reduce_3d_dot(ctx, node, True)
+
+
+@register_lowering(torch.ops.aten.baddbmm.default, apply_dot_requirements)  # pyright: ignore[reportAttributeAccessIssue]
+def codegen_baddbmm(ctx: LoweringContext, node: Node) -> ast.AST:
+    assert not node.kwargs, "baddbmm kwargs not supported"
+    return reduce_3d_dot(ctx, node, True)
+
+
+@register_lowering(torch.ops.prims.iota.default)  # pyright: ignore[reportAttributeAccessIssue]
+def codegen_iota(ctx: LoweringContext, node: Node) -> object:
+    """Generate tl.arange for torch.ops.prims.iota.default operations with automatic power-of-2 padding."""
+    start = node.kwargs.get("start", 0)
+    step = node.kwargs.get("step", 1)
+    dtype = (
+        node.kwargs.get("dtype") or CompileEnvironment.current().settings.index_dtype
+    )
+    assert isinstance(dtype, torch.dtype)
+    (length_arg,) = node.args  # expecting a single argument for length
+
+    # Pad static non-power-of-2 lengths to next power of 2
+    length_expr = "{length}"
+    if isinstance(length_arg, int) and length_arg != next_power_of_2(length_arg):
+        length_expr = str(next_power_of_2(length_arg))
+
+    expr = f"tl.arange(0, {length_expr})"
+    if step != 1:
+        expr = f"{{step}} * {expr}"
+    if start != 0:
+        expr = f"{{start}} + {expr}"
+    if dtype != torch.int32:
+        expr = f"({expr}).to({triton_type(dtype)})"
+    return expr_from_string(
+        expr,
+        start=ctx.to_ast(start),
+        step=ctx.to_ast(step),
+        length=ctx.to_ast(length_arg),
+    )
+
+
+def _codegen_rng_op(
+    ctx: LoweringContext,
+    node: Node,
+    rng_function: str,
+) -> object:
+    """Common codegen implementation for all RNG operations.
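+
+    Each element's offset is its row-major linear index, rebuilt from the
+    per-dimension indices_i variables and the block sizes, so every element
+    draws a distinct position from the stream of the seed loaded for this op.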
+
+    Args:
+        ctx: The lowering context
+        node: The FX node for this operation
+        rng_function: Either "rand" or "randn"
+    """
+    from .generate_ast import GenerateAST
+
+    assert rng_function in ["rand", "randn"]
+    assert isinstance(ctx.cg, GenerateAST)
+
+    # Get unique seed index for this RNG operation
+    device_fn = ctx.cg.device_function
+    seed_index = device_fn.allocate_rng_seed()
+
+    # Get dimensionality and dtype
+    assert hasattr(node, "meta") and "val" in node.meta
+    fake_value = node.meta["val"]
+    ndim = fake_value.ndim
+    dtype = node.kwargs.get("dtype", None)
+
+    # Get dimension names for offset calculation
+    env = CompileEnvironment.current()
+    dim_names = []
+    for size in fake_value.size():
+        block_id = env.get_block_id(size)
+        assert block_id is not None
+        block_size = env.block_sizes[block_id].size
+        dim_names.append(device_fn.literal_expr(block_size))
+
+    offset_parts = []
+
+    for i in range(ndim):
+        # Create the index variable with proper broadcasting
+        index_expr = f"indices_{i}"
+
+        # Add broadcasting slices for this dimension
+        # For 1D tensors, this will just be indices_0 with no slicing
+        slice_parts = []
+        for j in range(ndim):
+            if j < i:
+                slice_parts.append("None")
+            elif j == i:
+                slice_parts.append(":")
+            else:
+                slice_parts.append("None")
+
+        # Create the broadcasted index expression
+        if ndim == 1:
+            # For 1D, no broadcasting needed
+            broadcasted_index = index_expr
+        else:
+            broadcasted_index = f"{index_expr}[{', '.join(slice_parts)}]"
+
+        # Calculate stride (product of dimensions after this one)
+        if i < ndim - 1:
+            # Use the actual dimension variable names
+            stride_parts = dim_names[i + 1 :]
+            stride_expr = " * ".join(stride_parts)
+            offset_parts.append(f"{broadcasted_index} * {stride_expr}")
+        else:
+            # Last dimension has no stride multiplication
+            offset_parts.append(broadcasted_index)
+
+    offset_expr = expr_from_string(" + ".join(offset_parts))
+
+    # Load seed from buffer using the kernel parameter name
+    assert device_fn.rng_seed_buffer_param_name is not None
+    seed_expr = expr_from_string(
+        "tl.load({buffer} + {index})",
+        buffer=expr_from_string(device_fn.rng_seed_buffer_param_name),
+        index=create(ast.Constant, value=seed_index),
+    )
+
+    # Generate the RNG call
+    # Note: tl.rand() and tl.randn() always return float32
+    rng_expr = expr_from_string(
+        f"tl.{rng_function}({{seed}}, {{offset}})", seed=seed_expr, offset=offset_expr
+    )
+
+    # Cast to target dtype only if explicitly specified
+    if dtype is not None:
+        assert isinstance(dtype, torch.dtype)
+        rng_expr = expr_from_string(f"{{val}}.to({triton_type(dtype)})", val=rng_expr)
+
+    return rng_expr
+
+
+@register_lowering(torch.ops.aten.rand.default)  # pyright: ignore[reportAttributeAccessIssue]
+def codegen_rand(ctx: LoweringContext, node: Node) -> object:
+    return _codegen_rng_op(ctx, node, "rand")
+
+
+@register_lowering(torch.ops.aten.randn.default)  # pyright: ignore[reportAttributeAccessIssue]
+def codegen_randn(ctx: LoweringContext, node: Node) -> object:
+    return _codegen_rng_op(ctx, node, "randn")
diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py
index 0931dd44b..c0df920a4 100644
--- a/helion/_compiler/inductor_lowering.py
+++ b/helion/_compiler/inductor_lowering.py
@@ -6,9 +6,9 @@
 import functools
 from operator import getitem
 from typing import TYPE_CHECKING
-from typing import Callable
 from typing import ContextManager
 from typing import NamedTuple
+from typing import cast
 
 import sympy
 import torch
@@ -18,7 +18,6 @@
 from torch._inductor.codegen.simd import (
     SIMDKernelFeatures,  # pyright: ignore[reportPrivateImportUsage]
 )
-from torch._inductor.codegen.simd import constant_repr
 from torch._inductor.codegen.triton import TritonKernel
 from torch._inductor.codegen.triton import TritonOverrides
 from torch._inductor.graph import GraphLowering
@@ -37,33 +36,30 @@
 from torch.fx.experimental import proxy_tensor
 from torch.fx.experimental.sym_node import SymNode
 from torch.fx.interpreter import Interpreter
+from torch.fx.node import Argument
 from torch.fx.node import Node
 from torch.fx.node import map_arg
-from triton import next_power_of_2
 
 from .. import exc
 from ..exc import InductorLoweringError
 from ..language._decorators import APIFunc
 from ..language._decorators import is_api_func
-from ..language.matmul_ops import enforce_dot_requirements
 from .ast_extension import ExtendedAST
 from .ast_extension import create
 from .ast_extension import expr_from_string
 from .ast_extension import statement_from_string
+from .aten_lowering import Lowering
+from .aten_lowering import LoweringContext
+from .aten_lowering import aten_lowering_dispatch
 from .compile_environment import CompileEnvironment
 from .compile_environment import FixedBlockSizeSource
 from .device_function import VarInfo
 from .device_function import contains_only_block_size_symbols
 from .dtype_utils import cast_ast
-from .matmul_utils import emit_tl_dot_with_padding
-from .node_masking import apply_masking
-from .node_masking import cached_masked_value
-from .node_masking import getitem_masked_value
 from .node_masking import inductor_masked_value
 from .node_masking import mask_node_inputs
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
     from collections.abc import Iterator
 
     from torch.utils._ordered_set import OrderedSet
@@ -74,8 +70,6 @@
     from .helper_function import CodegenInterface
     from .tile_dispatch import TileStrategyDispatch
 
-    CodegenHandler = Callable[["GraphInterpreter", torch.fx.Node], object]
-
 
 INDUCTOR_PATCH: dict[str, object] = {
     # Allow implicit upcasts to FP32 for elementwise math correctness
     "triton.codegen_upcast_to_fp32": True,
@@ -355,23 +349,14 @@ def _unpack_symint(x: torch.SymInt | int) -> sympy.Expr:
     raise TypeError(f"Expected SymInt or int, got {type(x)}")
 
 
-class Lowering:
-    def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-        raise NotImplementedError
-
-    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
-        """Get the masked value for this node."""
-        return None
-
-
 @dataclasses.dataclass
 class InductorLowering(Lowering):
     buffer: ComputedBuffer
     input_names: list[str]
 
-    def input_asts(self, ctx: GraphInterpreter, node: torch.fx.Node) -> list[ast.AST]:
+    def input_asts(self, ctx: LoweringContext, node: torch.fx.Node) -> list[ast.AST]:
         def visit(n: torch.fx.Node) -> None:
-            ast_val = ctx.env[n]
+            ast_val = cast("ast.AST", ctx.env[n])
             if isinstance(fake_val := n.meta["val"], torch.Tensor):
                 if fake_val.ndim < ndim:
                     # Broadcast to force ranks to match
@@ -413,13 +398,13 @@ def visit(n: torch.fx.Node) -> torch.fx.Node:
         map_arg((node.args, node.kwargs), visit)
         return result
 
-    def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
+    def codegen(self, ctx: LoweringContext, node: torch.fx.Node) -> object:
         raise NotImplementedError(
             f"codegen not implemented for {type(self).__name__}: {self.buffer}"
         )
 
     def install_kernel_handlers(
-        self, ctx: GraphInterpreter, node: torch.fx.Node
+        self, ctx: LoweringContext, node: torch.fx.Node
     ) -> ContextManager[None]:
         return install_inductor_kernel_handlers(
             ctx.cg,
@@ -461,7 +446,7 @@ def __init__(self) -> None:
 
 
 class PointwiseLowering(InductorLowering):
-    def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
+    def codegen(self, ctx: LoweringContext, node: torch.fx.Node) -> object:
         # Validate broadcasting of tile block dimensions to catch shape mismatches
         self._check_block_broadcast_compatibility(node)
         with self.install_kernel_handlers(ctx, node):
@@ -612,7 +597,7 @@ def add_input_mask(self, node: torch.fx.Node) -> None:
         assert isinstance(default, (float, int, bool))
         mask_node_inputs(node, default)
 
-    def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
+    def codegen(self, ctx: LoweringContext, node: torch.fx.Node) -> object:
         reduction = self.buffer.data
         assert isinstance(reduction, Reduction)
         indices = [sympy.Symbol(f"i{n}") for n in range(len(reduction.ranges))]
@@ -738,9 +723,9 @@ def __init__(self, api_func: object) -> None:
         assert is_api_func(api_func)
         self.api_func: APIFunc = api_func
 
-    def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
+    def codegen(self, ctx: LoweringContext, node: torch.fx.Node) -> object:
         assert not node.kwargs
-        ast_args = [*map_arg(node.args, lambda arg: ctx.env[arg])]
+        ast_args = [*map_arg(node.args, lambda arg: cast("Argument", ctx.env[arg]))]
         proxy_args = [*map_arg(node.args, lambda arg: arg.meta["val"])]
 
         env = CompileEnvironment.current()
@@ -784,7 +769,7 @@ def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
 class SympyExprLowering(Lowering):
     expr: sympy.Expr
 
-    def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
+    def codegen(self, ctx: LoweringContext, node: torch.fx.Node) -> object:
         return expr_from_string(ctx.cg.device_function.user_sympy_expr(self.expr))
 
     def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
@@ -795,344 +780,6 @@ def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
         return None
 
 
-@dataclasses.dataclass
-class LambdaLowering(Lowering):
-    fn: Callable[..., object]
-    masked_value_fn: Callable[[torch.fx.Node], float | bool | None] | None = None
-
-    def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-        return self.fn(ctx, node)
-
-    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
-        if self.masked_value_fn is not None:
-            return self.masked_value_fn(node)
-        return None
-
-
-def passthrough_masked_value(node: torch.fx.Node) -> float | bool | None:
-    for input_node in node.all_input_nodes:
-        if isinstance(input_node.meta["val"], torch.Tensor):
-            return cached_masked_value(input_node)
-    return None
-
-
-aten_lowering_dispatch: dict[object, Callable[[torch.fx.Node], Lowering]] = {}
-
-
-def default_make_lowering(
-    handler: CodegenHandler,
-    node: torch.fx.Node,
-    masked_value_fn: Callable[[torch.fx.Node], float | bool | None] | None = None,
-) -> Lowering:
-    return LambdaLowering(handler, masked_value_fn=masked_value_fn)
-
-
-def register_lowering(
-    fn: object,
-    make_lowering: Callable[
-        [CodegenHandler, torch.fx.Node], Lowering
-    ] = default_make_lowering,
-    masked_value_fn: Callable[[torch.fx.Node], float | bool | None] | None = None,
-) -> Callable[[CodegenHandler], CodegenHandler]:
-    def decorator(handler: CodegenHandler) -> CodegenHandler:
-        assert fn not in aten_lowering_dispatch, f"Lowering for {fn} already registered"
-
-        aten_lowering_dispatch[fn] = lambda node: make_lowering(
-            handler,
-            node,
-            masked_value_fn=masked_value_fn,  # pyright: ignore[reportCallIssue]
-        )
-        return handler
-
-    return decorator
-
-
-@register_lowering(torch.ops.aten.sym_size.int)  # pyright: ignore[reportAttributeAccessIssue]
-def codegen_sym_size(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-    val = node.meta["val"]
-    assert isinstance(
-        val, (int, float, bool, torch.SymInt, torch.SymBool, torch.SymFloat)
-    )
-    return val
-
-
-@register_lowering(getitem, masked_value_fn=getitem_masked_value)
-def codegen_getitem(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-    assert not node.kwargs, "getitem kwargs not supported"
-    lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(lhs, (list, tuple))
-    assert isinstance(rhs, int)
-    return lhs[rhs]
-
-
-@register_lowering(
-    torch.ops.aten.full.default,  # pyright: ignore[reportAttributeAccessIssue]
-    masked_value_fn=lambda n: (
-        n.args[1] if isinstance(n.args[1], (int, float, bool)) else None
-    ),
-)
-def codegen_full(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-    env = CompileEnvironment.current()
-    size = map_arg(node.args[0], lambda n: n.meta["val"])
-    dtype = node.kwargs.get("dtype", torch.get_default_dtype())
-    assert isinstance(dtype, torch.dtype)
-    device = node.kwargs.get("device", env.device)
-    assert device == env.device, f"expected {env.device}, got {device}"
-    assert not node.kwargs.get("pin_memory"), "pin_memory not supported"
-    value_ast = map_arg(node.args[1], lambda arg: ctx.env[arg])
-    if isinstance(value_ast, (int, float, bool)):
-        value_ast = expr_from_string(constant_repr(value_ast))
-    assert isinstance(value_ast, ast.AST), value_ast
-    shape_str = ctx.cg.device_function.tile_strategy.shape_str([*size])  # pyright: ignore[reportGeneralTypeIssues,reportOptionalIterable]
-    return expr_from_string(
-        f"tl.full({shape_str}, {{value}}, {triton_type(dtype)})",
-        value=value_ast,
-    )
-
-
-@register_lowering(
-    torch.ops.aten.unsqueeze.default,  # pyright: ignore[reportAttributeAccessIssue]
-    masked_value_fn=passthrough_masked_value,
-)
-def codegen_unsqueeze(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-    assert not node.kwargs, "getitem kwargs not supported"
-    tensor, dim = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(tensor, ast.AST)
-    assert isinstance(dim, int)
-    ndim = node.args[0].meta["val"].ndim  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
-    if dim < 0:
-        dim += ndim
-    assert 0 <= dim <= ndim, f"Invalid dim {dim} for tensor with {ndim} dims"
-    args = [":"] * ndim
-    args.insert(dim, "None")
-    return expr_from_string(
-        f"{{tensor}}[{', '.join(args)}]",
-        tensor=tensor,
-    )
-
-
-@register_lowering(torch.ops.aten.squeeze.dim, masked_value_fn=passthrough_masked_value)  # pyright: ignore[reportAttributeAccessIssue]
-@register_lowering(
-    torch.ops.aten.view.default,  # pyright: ignore[reportAttributeAccessIssue]
-    masked_value_fn=passthrough_masked_value,
-)
-@register_lowering(
-    torch.ops.aten.reshape.default,  # pyright: ignore[reportAttributeAccessIssue]
-    masked_value_fn=passthrough_masked_value,
-)
-def codegen_view(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-    assert not node.kwargs, "view kwargs not supported"
-    tensor = map_arg(node.args[0], lambda arg: ctx.env[arg])
-    assert isinstance(tensor, ast.AST)
-    shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.meta["val"].size()]
-    )
-    return expr_from_string(f"tl.reshape({{tensor}}, {shape_str})", tensor=tensor)
-
-
-@register_lowering(
-    torch.ops.aten.permute.default,  # pyright: ignore[reportAttributeAccessIssue]
-    masked_value_fn=passthrough_masked_value,
-)
-def codegen_permute(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-    assert not node.kwargs, "getitem kwargs not supported"
-    tensor, dims = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(tensor, ast.AST)
-    dims = [*dims]  # pyright: ignore[reportGeneralTypeIssues,reportOptionalIterable]
-    assert {*dims} == {*range(len(dims))}, dims
-    return expr_from_string(
-        f"tl.permute({{tensor}}, {dims!r})",
-        tensor=tensor,
-    )
-
-
-@register_lowering(
-    torch.ops.aten.stack.default,  # pyright: ignore[reportAttributeAccessIssue]
-    masked_value_fn=passthrough_masked_value,
-)
-def codegen_stack(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-    tensors = node.args[0]
-    dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
-
-    assert isinstance(tensors, (list, tuple))
-    tensor_asts = [ctx.env[t] for t in tensors]  # pyright: ignore[reportArgumentType]
-    n = len(tensor_asts)
-
-    if n == 0:
-        raise ValueError("Cannot stack empty tensor list")
-
-    # Round up to power of 2 for efficient masking
-    padded_size = 1 << (n - 1).bit_length()
-
-    # Create index array [0, 1, 2, 3, ...] for tensor selection
-    idx = ctx.cg.device_function.new_var("stack_idx")
-    ctx.cg.add_statement(statement_from_string(f"{idx} = tl.arange(0, {padded_size})"))
-
-    # Broadcast index to target dimension shape
-    # e.g., dim=0: [:, None, None], dim=1: [None, :, None], dim=2: [None, None, :]
-    bidx = ctx.cg.device_function.new_var("broadcast_idx")
-    assert isinstance(dim, int)
-    pattern = "[" + ", ".join(["None"] * dim + [":"] + ["None"] * max(0, 2 - dim)) + "]"
-    ctx.cg.add_statement(statement_from_string(f"{bidx} = {idx}{pattern}"))
-
-    # Expand each input tensor along the stack dimension
-    expanded = [ctx.cg.device_function.new_var(f"expanded_{i}") for i in range(n)]
-    for var, tensor in zip(expanded, tensor_asts, strict=False):
-        ctx.cg.add_statement(
-            statement_from_string(f"{var} = tl.expand_dims({{t}}, {dim})", t=tensor)
-        )
-
-    # Initialize result with zeros
-    result = ctx.cg.device_function.new_var("stacked_result")
-    ctx.cg.add_statement(
-        statement_from_string(f"{result} = tl.zeros_like({expanded[0]})")
-    )
-
-    # Select each tensor using masks
-    for i in range(n):
-        mask = ctx.cg.device_function.new_var(f"mask_{i}")
-        ctx.cg.add_statement(statement_from_string(f"{mask} = {bidx} == {i}"))
-        ctx.cg.add_statement(
-            statement_from_string(
-                f"{result} = tl.where({mask}, {expanded[i]}, {result})"
-            )
-        )
-
-    return expr_from_string(result)
-
-
-@register_lowering(
-    torch.ops.aten.expand.default,  # pyright: ignore[reportAttributeAccessIssue]
-    masked_value_fn=passthrough_masked_value,
-)
-def codegen_expand(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-    assert not node.kwargs, "getitem kwargs not supported"
-    tensor, _ = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(tensor, ast.AST)
-    val = node.meta["val"]
-    assert isinstance(val, torch.Tensor)
-    shape = [*val.size()]
-    if node.args[0].meta["val"].ndim != len(shape):  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
-        broadcasting = [":"] * len(shape)
-        for i in range(len(shape) - node.args[0].meta["val"].ndim):  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
-            broadcasting[i] = "None"
-        tensor = expr_from_string(
-            f"{{tensor}}[{', '.join(broadcasting)}]", tensor=tensor
-        )
-    shape_str = ctx.cg.device_function.tile_strategy.shape_str(shape)
-    return expr_from_string(
-        f"tl.broadcast_to({{tensor}}, {shape_str})",
-        tensor=tensor,
-    )
-
-
-def apply_dot_requirements(
-    handler: CodegenHandler,
-    node: torch.fx.Node,
-    masked_value_fn: Callable[[torch.fx.Node], float | bool | None] | None = None,
-) -> Lowering:
-    """Apply min_dot_size requirements to the config_spec"""
-    assert not node.kwargs, "dot kwargs not supported"
-    assert len(node.args) in (2, 3)
-    lproxy, rproxy = map_arg(node.args[-2:], lambda arg: arg.meta["val"])
-    assert isinstance(lproxy, torch.Tensor)
-    assert isinstance(rproxy, torch.Tensor)
-    # Update config spec min sizes for M, N, K
-    enforce_dot_requirements(lproxy, rproxy)
-    # inputs to the dot operation must be zero-masked
-    *maybe_acc, lnode, rnode = node.args
-    assert isinstance(lnode, torch.fx.Node)
-    assert isinstance(rnode, torch.fx.Node)
-    lnode = apply_masking(lnode, base_node=node, other=0)
-    rnode = apply_masking(rnode, base_node=node, other=0)
-    node.args = (*maybe_acc, lnode, rnode)
-    return LambdaLowering(handler, masked_value_fn=masked_value_fn)
-
-
-def reduce_3d_dot(
-    ctx: GraphInterpreter, node: torch.fx.Node, with_acc: bool
-) -> ast.AST:
-    acc = None
-    acc_node: torch.fx.Node | None = None
-    if with_acc:
-        acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-        assert isinstance(acc, ast.AST)
-        assert isinstance(node.args[0], torch.fx.Node)
-        acc_node = node.args[0]
-        lhs_node = node.args[1]
-        rhs_node = node.args[2]
-    else:
-        lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-        lhs_node = node.args[0]
-        rhs_node = node.args[1]
-    assert isinstance(lhs, ast.AST)
-    assert isinstance(rhs, ast.AST)
-    assert isinstance(lhs_node, torch.fx.Node)
-    assert isinstance(rhs_node, torch.fx.Node)
-
-    # Check if inputs are FP8 - if so, redirect user to hl.dot()
-    lhs_dtype = lhs_node.meta["val"].dtype
-    rhs_dtype = rhs_node.meta["val"].dtype
-    acc_dtype_meta: torch.dtype | None = None
-    if with_acc:
-        assert acc_node is not None
-        assert isinstance(acc_node, torch.fx.Node)
-        acc_dtype_meta = acc_node.meta["val"].dtype
-    if lhs_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and rhs_dtype in [
-        torch.float8_e4m3fn,
-        torch.float8_e5m2,
-    ]:
-        raise NotImplementedError(
-            "FP8 GEMM via torch API is not supported yet. Please use hl.dot() instead."
-        )
-
-    lhs_shape = list(lhs_node.meta["val"].size())  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
-    rhs_shape = list(rhs_node.meta["val"].size())  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
-    acc_shape = (
-        list(acc_node.meta["val"].size())
-        if (with_acc and acc_node is not None)
-        else None
-    )  # pyright: ignore[reportOptionalMemberAccess]
-
-    # Extract expected output dtype from FX node to match PyTorch eager mode behavior
-    out_dtype: torch.dtype | None = None
-    if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
-        out_dtype = node.meta["val"].dtype
-
-    return emit_tl_dot_with_padding(
-        lhs,
-        rhs,
-        acc if with_acc else None,
-        lhs_dtype,
-        rhs_dtype,
-        acc_dtype=acc_dtype_meta if with_acc else None,
-        out_dtype=out_dtype,
-        lhs_shape=lhs_shape,
-        rhs_shape=rhs_shape,
-        acc_shape=acc_shape,
-    )
-
-
-@register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)  # pyright: ignore[reportAttributeAccessIssue]
-@register_lowering(torch.ops.aten.mm.default, apply_dot_requirements)  # pyright: ignore[reportAttributeAccessIssue]
-def codegen_mm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
-    assert not node.kwargs, "matmul kwargs not supported"
-
-    return reduce_3d_dot(ctx, node, False)
-
-
-@register_lowering(torch.ops.aten.addmm.default, apply_dot_requirements)  # pyright: ignore[reportAttributeAccessIssue]
-def codegen_addmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
-    assert not node.kwargs, "addmm kwargs not supported"
-    return reduce_3d_dot(ctx, node, True)
-
-
-@register_lowering(torch.ops.aten.baddbmm.default, apply_dot_requirements)  # pyright: ignore[reportAttributeAccessIssue]
-def codegen_baddbmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
-    assert not node.kwargs, "baddbmm kwargs not supported"
-    return reduce_3d_dot(ctx, node, True)
 
 
 class GenerateASTFromInductor(DefaultHandler):
     def __init__(
         self, cg: CodegenInterface, input_name_lookup: dict[str, ast.AST]
@@ -1276,10 +923,11 @@ def _unpack_opsvalue(value: object) -> str:
     return value
 
 
-class GraphInterpreter(Interpreter):
+class GraphInterpreter(LoweringContext, Interpreter):
     def __init__(self, graph: torch.fx.Graph, cg: CodegenInterface) -> None:
         super().__init__(_LazyGraphModule({}, graph), garbage_collect_values=False)
         self.cg = cg
+        self.env = cast("dict[Node, Argument]", self.env)
 
     def to_ast(self, value: object) -> ast.AST:
         """
@@ -1365,7 +1013,7 @@ def _collect_multi_outputs(
         # Check if this operation has multiple outputs using the new metadata
         assert "output_nodes" in node.meta
         output_nodes = node.meta["output_nodes"]
-        outputs = [None] * len(output_nodes)
+        outputs: list[object | None] = [None] * len(output_nodes)
         all_nodes = {
             n.name: n
            for n in self.module.graph.nodes  # pyright: ignore[reportAttributeAccessIssue,reportGeneralTypeIssues]
@@ -1390,6 +1038,7 @@ def _collect_multi_outputs(
             assert result is not None
             if not isinstance(result, ast.Name):
                 var_name = self.cg.device_function.new_var(f"{node.name}_output{i}")
+                assert isinstance(result, ast.AST)
                 self.cg.add_statement(
                     statement_from_string(f"{var_name} = {{result}}", result=result)
                 )
@@ -1517,138 +1166,3 @@ def add_statement(self, statement: ast.AST | str) -> None:
 
     def sympy_expr(self, expr: sympy.Expr) -> str:
         return self.codegen.device_function.sympy_expr(expr)
-
-
-@register_lowering(torch.ops.prims.iota.default)  # pyright: ignore[reportAttributeAccessIssue]
-def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
"""Generate tl.arange for torch.ops.prims.iota.default operations with automatic power-of-2 padding.""" - start = node.kwargs.get("start", 0) - step = node.kwargs.get("step", 1) - dtype = ( - node.kwargs.get("dtype") or CompileEnvironment.current().settings.index_dtype - ) - assert isinstance(dtype, torch.dtype) - (length_arg,) = node.args # expecting a single argument for length - - # Pad static non-power-of-2 lengths to next power of 2 - length_expr = "{length}" - if isinstance(length_arg, int) and length_arg != next_power_of_2(length_arg): - length_expr = str(next_power_of_2(length_arg)) - - expr = f"tl.arange(0, {length_expr})" - if step != 1: - expr = f"{{step}} * {expr}" - if start != 0: - expr = f"{{start}} + {expr}" - if dtype != torch.int32: - expr = f"({expr}).to({triton_type(dtype)})" - return expr_from_string( - expr, - start=ctx.to_ast(start), - step=ctx.to_ast(step), - length=ctx.to_ast(length_arg), - ) - - -def _codegen_rng_op( - ctx: GraphInterpreter, - node: torch.fx.Node, - rng_function: str, -) -> object: - """Common codegen implementation for all RNG operations. - - Args: - ctx: The graph interpreter context - node: The FX node for this operation - rng_function: Either "rand" or "randn" - """ - from .generate_ast import GenerateAST - - assert rng_function in ["rand", "randn"] - assert isinstance(ctx.cg, GenerateAST) - - # Get unique seed index for this RNG operation - device_fn = ctx.cg.device_function - seed_index = device_fn.allocate_rng_seed() - - # Get dimensionality and dtype - assert hasattr(node, "meta") and "val" in node.meta - fake_value = node.meta["val"] - ndim = fake_value.ndim - dtype = node.kwargs.get("dtype", None) - - # Get dimension names for offset calculation - env = CompileEnvironment.current() - dim_names = [] - for size in fake_value.size(): - block_id = env.get_block_id(size) - assert block_id is not None - block_size = env.block_sizes[block_id].size - dim_names.append(device_fn.literal_expr(block_size)) - - offset_parts = [] - - for i in range(ndim): - # Create the index variable with proper broadcasting - index_expr = f"indices_{i}" - - # Add broadcasting slices for this dimension - # For 1D tensors, this will just be indices_0 with no slicing - slice_parts = [] - for j in range(ndim): - if j < i: - slice_parts.append("None") - elif j == i: - slice_parts.append(":") - else: - slice_parts.append("None") - - # Create the broadcasted index expression - if ndim == 1: - # For 1D, no broadcasting needed - broadcasted_index = index_expr - else: - broadcasted_index = f"{index_expr}[{', '.join(slice_parts)}]" - - # Calculate stride (product of dimensions after this one) - if i < ndim - 1: - # Use the actual dimension variable names - stride_parts = dim_names[i + 1 :] - stride_expr = " * ".join(stride_parts) - offset_parts.append(f"{broadcasted_index} * {stride_expr}") - else: - # Last dimension has no stride multiplication - offset_parts.append(broadcasted_index) - - offset_expr = expr_from_string(" + ".join(offset_parts)) - - # Load seed from buffer using the kernel parameter name - assert device_fn.rng_seed_buffer_param_name is not None - seed_expr = expr_from_string( - "tl.load({buffer} + {index})", - buffer=expr_from_string(device_fn.rng_seed_buffer_param_name), - index=create(ast.Constant, value=seed_index), - ) - - # Generate the RNG call - # Note: tl.rand() and tl.randn() always return float32 - rng_expr = expr_from_string( - f"tl.{rng_function}({{seed}}, {{offset}})", seed=seed_expr, offset=offset_expr - ) - - # Cast to target dtype only if 
-    # Cast to target dtype only if explicitly specified
-    if dtype is not None:
-        assert isinstance(dtype, torch.dtype)
-        rng_expr = expr_from_string(f"{{val}}.to({triton_type(dtype)})", val=rng_expr)
-
-    return rng_expr
-
-
-@register_lowering(torch.ops.aten.rand.default)  # pyright: ignore[reportAttributeAccessIssue]
-def codegen_rand(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-    return _codegen_rng_op(ctx, node, "rand")
-
-
-@register_lowering(torch.ops.aten.randn.default)  # pyright: ignore[reportAttributeAccessIssue]
-def codegen_randn(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
-    return _codegen_rng_op(ctx, node, "randn")
diff --git a/helion/_compiler/roll_reduction.py b/helion/_compiler/roll_reduction.py
index 4aed09f93..262abcfa3 100644
--- a/helion/_compiler/roll_reduction.py
+++ b/helion/_compiler/roll_reduction.py
@@ -24,10 +24,10 @@
 from ..language.reduce_ops import _reduce
 from ..language.view_ops import join as hl_join
 from ..language.view_ops import split as hl_split
+from .aten_lowering import aten_lowering_dispatch
 from .compile_environment import CompileEnvironment
 from .inductor_lowering import APIFuncLowering
 from .inductor_lowering import ReductionLowering
-from .inductor_lowering import aten_lowering_dispatch
 
 if TYPE_CHECKING:
     from .compile_environment import BlockSizeInfo