diff --git a/helion/_compiler/aten_lowering.py b/helion/_compiler/aten_lowering.py index e1c15ab93..517a89e23 100644 --- a/helion/_compiler/aten_lowering.py +++ b/helion/_compiler/aten_lowering.py @@ -15,6 +15,7 @@ from torch.fx.node import map_arg from triton import next_power_of_2 +from .. import exc from ..language.matmul_ops import enforce_dot_requirements from .ast_extension import create from .ast_extension import expr_from_string @@ -55,12 +56,34 @@ def _env_arg(ctx: LoweringContext, node: Node) -> Argument: @dataclasses.dataclass -class LambdaLowering(Lowering): - fn: Callable[..., object] +class AtenLowering(Lowering): + target: object | None = None masked_value_fn: MaskedValueFn | None = None + codegen_impls: dict[str, CodegenHandler] = dataclasses.field(default_factory=dict) + + def register_codegen( + self, backend: str + ) -> Callable[[CodegenHandler], CodegenHandler]: + def decorator(handler: CodegenHandler) -> CodegenHandler: + assert backend not in self.codegen_impls, ( + f"codegen already registered for backend {backend!r}" + ) + self.codegen_impls[backend] = handler + return handler + + return decorator def codegen(self, ctx: LoweringContext, node: Node) -> object: - return self.fn(ctx, node) + backend = CompileEnvironment.current().backend + try: + handler = self.codegen_impls[backend] + except KeyError as err: # pragma: no cover - defensive + target = self.target or "unknown" + raise exc.BackendImplementationMissing( + backend, + f"Aten lowering codegen not registered for {target!r}", + ) from err + return handler(ctx, node) def get_masked_value(self, node: Node) -> float | bool | None: if self.masked_value_fn is not None: @@ -78,33 +101,27 @@ def passthrough_masked_value(node: Node) -> float | bool | None: aten_lowering_dispatch: dict[object, Callable[[Node], Lowering]] = {} -def default_make_lowering( - handler: CodegenHandler, - node: Node, - masked_value_fn: MaskedValueFn | None = None, -) -> Lowering: - return LambdaLowering(handler, masked_value_fn=masked_value_fn) +def default_make_lowering(lowering: AtenLowering, node: Node) -> Lowering: + return lowering def register_lowering( fn: object, - make_lowering: Callable[[CodegenHandler, Node], Lowering] = default_make_lowering, + make_lowering: Callable[[AtenLowering, Node], Lowering] = default_make_lowering, masked_value_fn: MaskedValueFn | None = None, -) -> Callable[[CodegenHandler], CodegenHandler]: - def decorator(handler: CodegenHandler) -> CodegenHandler: - assert fn not in aten_lowering_dispatch, f"Lowering for {fn} already registered" - - aten_lowering_dispatch[fn] = lambda node: make_lowering( - handler, - node, - masked_value_fn=masked_value_fn, # pyright: ignore[reportCallIssue] - ) - return handler +) -> AtenLowering: + assert fn not in aten_lowering_dispatch, f"Lowering for {fn} already registered" + lowering = AtenLowering(target=fn, masked_value_fn=masked_value_fn) + aten_lowering_dispatch[fn] = lambda node: make_lowering(lowering, node) + return lowering - return decorator +sym_size_lowering = register_lowering( + torch.ops.aten.sym_size.int # pyright: ignore[reportAttributeAccessIssue] +) -@register_lowering(torch.ops.aten.sym_size.int) # pyright: ignore[reportAttributeAccessIssue] + +@sym_size_lowering.register_codegen("triton") def codegen_sym_size(ctx: LoweringContext, node: Node) -> object: val = node.meta["val"] assert isinstance( @@ -113,7 +130,10 @@ def codegen_sym_size(ctx: LoweringContext, node: Node) -> object: return val -@register_lowering(getitem, masked_value_fn=getitem_masked_value) +getitem_lowering = register_lowering(getitem, masked_value_fn=getitem_masked_value) + + +@getitem_lowering.register_codegen("triton") def codegen_getitem(ctx: LoweringContext, node: Node) -> object: assert not node.kwargs, "getitem kwargs not supported" lhs, rhs = map_arg(node.args, lambda arg: _env_arg(ctx, arg)) @@ -122,12 +142,15 @@ def codegen_getitem(ctx: LoweringContext, node: Node) -> object: return lhs[rhs] -@register_lowering( +full_lowering = register_lowering( torch.ops.aten.full.default, # pyright: ignore[reportAttributeAccessIssue] masked_value_fn=lambda n: ( n.args[1] if isinstance(n.args[1], (int, float, bool)) else None ), ) + + +@full_lowering.register_codegen("triton") def codegen_full(ctx: LoweringContext, node: Node) -> object: env = CompileEnvironment.current() size = map_arg(node.args[0], lambda n: n.meta["val"]) @@ -147,10 +170,13 @@ def codegen_full(ctx: LoweringContext, node: Node) -> object: ) -@register_lowering( +unsqueeze_lowering = register_lowering( torch.ops.aten.unsqueeze.default, # pyright: ignore[reportAttributeAccessIssue] masked_value_fn=passthrough_masked_value, ) + + +@unsqueeze_lowering.register_codegen("triton") def codegen_unsqueeze(ctx: LoweringContext, node: Node) -> object: assert not node.kwargs, "getitem kwargs not supported" tensor, dim = map_arg(node.args, lambda arg: _env_arg(ctx, arg)) @@ -168,15 +194,23 @@ def codegen_unsqueeze(ctx: LoweringContext, node: Node) -> object: ) -@register_lowering(torch.ops.aten.squeeze.dim, masked_value_fn=passthrough_masked_value) # pyright: ignore[reportAttributeAccessIssue] -@register_lowering( +squeeze_lowering = register_lowering( + torch.ops.aten.squeeze.dim, # pyright: ignore[reportAttributeAccessIssue] + masked_value_fn=passthrough_masked_value, +) +view_lowering = register_lowering( torch.ops.aten.view.default, # pyright: ignore[reportAttributeAccessIssue] masked_value_fn=passthrough_masked_value, ) -@register_lowering( +reshape_lowering = register_lowering( torch.ops.aten.reshape.default, # pyright: ignore[reportAttributeAccessIssue] masked_value_fn=passthrough_masked_value, ) + + +@squeeze_lowering.register_codegen("triton") +@view_lowering.register_codegen("triton") +@reshape_lowering.register_codegen("triton") def codegen_view(ctx: LoweringContext, node: Node) -> object: assert not node.kwargs, "view kwargs not supported" tensor = map_arg(node.args[0], lambda arg: _env_arg(ctx, arg)) @@ -187,10 +221,13 @@ def codegen_view(ctx: LoweringContext, node: Node) -> object: return expr_from_string(f"tl.reshape({{tensor}}, {shape_str})", tensor=tensor) -@register_lowering( +permute_lowering = register_lowering( torch.ops.aten.permute.default, # pyright: ignore[reportAttributeAccessIssue] masked_value_fn=passthrough_masked_value, ) + + +@permute_lowering.register_codegen("triton") def codegen_permute(ctx: LoweringContext, node: Node) -> object: assert not node.kwargs, "getitem kwargs not supported" tensor, dims = map_arg(node.args, lambda arg: _env_arg(ctx, arg)) @@ -203,10 +240,13 @@ def codegen_permute(ctx: LoweringContext, node: Node) -> object: ) -@register_lowering( +stack_lowering = register_lowering( torch.ops.aten.stack.default, # pyright: ignore[reportAttributeAccessIssue] masked_value_fn=passthrough_masked_value, ) + + +@stack_lowering.register_codegen("triton") def codegen_stack(ctx: LoweringContext, node: Node) -> object: tensors = node.args[0] dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) @@ -259,10 +299,13 @@ def codegen_stack(ctx: LoweringContext, node: Node) -> object: return expr_from_string(result) -@register_lowering( +expand_lowering = register_lowering( torch.ops.aten.expand.default, # pyright: ignore[reportAttributeAccessIssue] masked_value_fn=passthrough_masked_value, ) + + +@expand_lowering.register_codegen("triton") def codegen_expand(ctx: LoweringContext, node: Node) -> object: assert not node.kwargs, "getitem kwargs not supported" tensor, _ = map_arg(node.args, lambda arg: _env_arg(ctx, arg)) @@ -284,11 +327,7 @@ def codegen_expand(ctx: LoweringContext, node: Node) -> object: ) -def apply_dot_requirements( - handler: CodegenHandler, - node: Node, - masked_value_fn: MaskedValueFn | None = None, -) -> Lowering: +def apply_dot_requirements(lowering: AtenLowering, node: Node) -> Lowering: """Apply min_dot_size requirements to the config_spec""" assert not node.kwargs, "dot kwargs not supported" assert len(node.args) in (2, 3) @@ -304,7 +343,7 @@ def apply_dot_requirements( lnode = apply_masking(lnode, base_node=node, other=0) rnode = apply_masking(rnode, base_node=node, other=0) node.args = (*maybe_acc, lnode, rnode) - return LambdaLowering(handler, masked_value_fn=masked_value_fn) + return lowering def reduce_3d_dot(ctx: LoweringContext, node: Node, with_acc: bool) -> ast.AST: @@ -369,27 +408,54 @@ def reduce_3d_dot(ctx: LoweringContext, node: Node, with_acc: bool) -> ast.AST: ) -@register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements) # pyright: ignore[reportAttributeAccessIssue] -@register_lowering(torch.ops.aten.mm.default, apply_dot_requirements) # pyright: ignore[reportAttributeAccessIssue] +bmm_lowering = register_lowering( + torch.ops.aten.bmm.default, # pyright: ignore[reportAttributeAccessIssue] + apply_dot_requirements, +) +mm_lowering = register_lowering( + torch.ops.aten.mm.default, # pyright: ignore[reportAttributeAccessIssue] + apply_dot_requirements, +) + + +@bmm_lowering.register_codegen("triton") +@mm_lowering.register_codegen("triton") def codegen_mm(ctx: LoweringContext, node: Node) -> ast.AST: assert not node.kwargs, "matmul kwargs not supported" return reduce_3d_dot(ctx, node, False) -@register_lowering(torch.ops.aten.addmm.default, apply_dot_requirements) # pyright: ignore[reportAttributeAccessIssue] +addmm_lowering = register_lowering( + torch.ops.aten.addmm.default, # pyright: ignore[reportAttributeAccessIssue] + apply_dot_requirements, +) + + +@addmm_lowering.register_codegen("triton") def codegen_addmm(ctx: LoweringContext, node: Node) -> ast.AST: assert not node.kwargs, "addmm kwargs not supported" return reduce_3d_dot(ctx, node, True) -@register_lowering(torch.ops.aten.baddbmm.default, apply_dot_requirements) # pyright: ignore[reportAttributeAccessIssue] +baddbmm_lowering = register_lowering( + torch.ops.aten.baddbmm.default, # pyright: ignore[reportAttributeAccessIssue] + apply_dot_requirements, +) + + +@baddbmm_lowering.register_codegen("triton") def codegen_baddbmm(ctx: LoweringContext, node: Node) -> ast.AST: assert not node.kwargs, "baddbmm kwargs not supported" return reduce_3d_dot(ctx, node, True) -@register_lowering(torch.ops.prims.iota.default) # pyright: ignore[reportAttributeAccessIssue] +iota_lowering = register_lowering( + torch.ops.prims.iota.default # pyright: ignore[reportAttributeAccessIssue] +) + + +@iota_lowering.register_codegen("triton") def codegen_iota(ctx: LoweringContext, node: Node) -> object: """Generate tl.arange for torch.ops.prims.iota.default operations with automatic power-of-2 padding.""" start = node.kwargs.get("start", 0) @@ -514,11 +580,21 @@ def _codegen_rng_op( return rng_expr -@register_lowering(torch.ops.aten.rand.default) # pyright: ignore[reportAttributeAccessIssue] +rand_lowering = register_lowering( + torch.ops.aten.rand.default # pyright: ignore[reportAttributeAccessIssue] +) + + +@rand_lowering.register_codegen("triton") def codegen_rand(ctx: LoweringContext, node: Node) -> object: return _codegen_rng_op(ctx, node, "rand") -@register_lowering(torch.ops.aten.randn.default) # pyright: ignore[reportAttributeAccessIssue] +randn_lowering = register_lowering( + torch.ops.aten.randn.default # pyright: ignore[reportAttributeAccessIssue] +) + + +@randn_lowering.register_codegen("triton") def codegen_randn(ctx: LoweringContext, node: Node) -> object: return _codegen_rng_op(ctx, node, "randn")