164 changes: 120 additions & 44 deletions helion/_compiler/aten_lowering.py
@@ -15,6 +15,7 @@
from torch.fx.node import map_arg
from triton import next_power_of_2

from .. import exc
from ..language.matmul_ops import enforce_dot_requirements
from .ast_extension import create
from .ast_extension import expr_from_string
@@ -55,12 +56,34 @@ def _env_arg(ctx: LoweringContext, node: Node) -> Argument:


@dataclasses.dataclass
class LambdaLowering(Lowering):
fn: Callable[..., object]
class AtenLowering(Lowering):
target: object | None = None
masked_value_fn: MaskedValueFn | None = None
codegen_impls: dict[str, CodegenHandler] = dataclasses.field(default_factory=dict)

def register_codegen(
self, backend: str
) -> Callable[[CodegenHandler], CodegenHandler]:
def decorator(handler: CodegenHandler) -> CodegenHandler:
assert backend not in self.codegen_impls, (
f"codegen already registered for backend {backend!r}"
)
self.codegen_impls[backend] = handler
return handler

return decorator

def codegen(self, ctx: LoweringContext, node: Node) -> object:
return self.fn(ctx, node)
backend = CompileEnvironment.current().backend
try:
handler = self.codegen_impls[backend]
except KeyError as err: # pragma: no cover - defensive
target = self.target or "unknown"
raise exc.BackendImplementationMissing(
backend,
f"Aten lowering codegen not registered for {target!r}",
) from err
return handler(ctx, node)

def get_masked_value(self, node: Node) -> float | bool | None:
if self.masked_value_fn is not None:
@@ -78,33 +101,27 @@ def passthrough_masked_value(node: Node) -> float | bool | None:
aten_lowering_dispatch: dict[object, Callable[[Node], Lowering]] = {}


def default_make_lowering(
handler: CodegenHandler,
node: Node,
masked_value_fn: MaskedValueFn | None = None,
) -> Lowering:
return LambdaLowering(handler, masked_value_fn=masked_value_fn)
def default_make_lowering(lowering: AtenLowering, node: Node) -> Lowering:
return lowering


def register_lowering(
fn: object,
make_lowering: Callable[[CodegenHandler, Node], Lowering] = default_make_lowering,
make_lowering: Callable[[AtenLowering, Node], Lowering] = default_make_lowering,
masked_value_fn: MaskedValueFn | None = None,
) -> Callable[[CodegenHandler], CodegenHandler]:
def decorator(handler: CodegenHandler) -> CodegenHandler:
assert fn not in aten_lowering_dispatch, f"Lowering for {fn} already registered"

aten_lowering_dispatch[fn] = lambda node: make_lowering(
handler,
node,
masked_value_fn=masked_value_fn, # pyright: ignore[reportCallIssue]
)
return handler
) -> AtenLowering:
assert fn not in aten_lowering_dispatch, f"Lowering for {fn} already registered"
lowering = AtenLowering(target=fn, masked_value_fn=masked_value_fn)
aten_lowering_dispatch[fn] = lambda node: make_lowering(lowering, node)
return lowering

return decorator
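With this change, register_lowering returns the AtenLowering object instead of acting as a decorator, and backend handlers are attached separately through register_codegen. A minimal sketch of the new pattern, using a hypothetical op name for illustration:

# Sketch only: `my_op` stands in for a real aten overload.
my_op_lowering = register_lowering(my_op)

@my_op_lowering.register_codegen("triton")
def codegen_my_op(ctx: LoweringContext, node: Node) -> object:
    # Invoked when CompileEnvironment.current().backend == "triton";
    # an unregistered backend raises exc.BackendImplementationMissing.
    ...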

sym_size_lowering = register_lowering(
torch.ops.aten.sym_size.int # pyright: ignore[reportAttributeAccessIssue]
)

@register_lowering(torch.ops.aten.sym_size.int) # pyright: ignore[reportAttributeAccessIssue]

@sym_size_lowering.register_codegen("triton")
def codegen_sym_size(ctx: LoweringContext, node: Node) -> object:
val = node.meta["val"]
assert isinstance(
@@ -113,7 +130,10 @@ def codegen_sym_size(ctx: LoweringContext, node: Node) -> object:
return val


@register_lowering(getitem, masked_value_fn=getitem_masked_value)
getitem_lowering = register_lowering(getitem, masked_value_fn=getitem_masked_value)


@getitem_lowering.register_codegen("triton")
def codegen_getitem(ctx: LoweringContext, node: Node) -> object:
assert not node.kwargs, "getitem kwargs not supported"
lhs, rhs = map_arg(node.args, lambda arg: _env_arg(ctx, arg))
@@ -122,12 +142,15 @@ def codegen_getitem(ctx: LoweringContext, node: Node) -> object:
return lhs[rhs]


@register_lowering(
full_lowering = register_lowering(
torch.ops.aten.full.default, # pyright: ignore[reportAttributeAccessIssue]
masked_value_fn=lambda n: (
n.args[1] if isinstance(n.args[1], (int, float, bool)) else None
),
)


@full_lowering.register_codegen("triton")
def codegen_full(ctx: LoweringContext, node: Node) -> object:
env = CompileEnvironment.current()
size = map_arg(node.args[0], lambda n: n.meta["val"])
@@ -147,10 +170,13 @@ def codegen_full(ctx: LoweringContext, node: Node) -> object:
)


@register_lowering(
unsqueeze_lowering = register_lowering(
torch.ops.aten.unsqueeze.default, # pyright: ignore[reportAttributeAccessIssue]
masked_value_fn=passthrough_masked_value,
)


@unsqueeze_lowering.register_codegen("triton")
def codegen_unsqueeze(ctx: LoweringContext, node: Node) -> object:
assert not node.kwargs, "unsqueeze kwargs not supported"
tensor, dim = map_arg(node.args, lambda arg: _env_arg(ctx, arg))
@@ -168,15 +194,23 @@ def codegen_unsqueeze(ctx: LoweringContext, node: Node) -> object:
)


@register_lowering(torch.ops.aten.squeeze.dim, masked_value_fn=passthrough_masked_value) # pyright: ignore[reportAttributeAccessIssue]
@register_lowering(
squeeze_lowering = register_lowering(
torch.ops.aten.squeeze.dim, # pyright: ignore[reportAttributeAccessIssue]
masked_value_fn=passthrough_masked_value,
)
view_lowering = register_lowering(
torch.ops.aten.view.default, # pyright: ignore[reportAttributeAccessIssue]
masked_value_fn=passthrough_masked_value,
)
@register_lowering(
reshape_lowering = register_lowering(
torch.ops.aten.reshape.default, # pyright: ignore[reportAttributeAccessIssue]
masked_value_fn=passthrough_masked_value,
)


@squeeze_lowering.register_codegen("triton")
@view_lowering.register_codegen("triton")
@reshape_lowering.register_codegen("triton")
def codegen_view(ctx: LoweringContext, node: Node) -> object:
assert not node.kwargs, "view kwargs not supported"
tensor = map_arg(node.args[0], lambda arg: _env_arg(ctx, arg))
@@ -187,10 +221,13 @@ def codegen_view(ctx: LoweringContext, node: Node) -> object:
return expr_from_string(f"tl.reshape({{tensor}}, {shape_str})", tensor=tensor)


@register_lowering(
permute_lowering = register_lowering(
torch.ops.aten.permute.default, # pyright: ignore[reportAttributeAccessIssue]
masked_value_fn=passthrough_masked_value,
)


@permute_lowering.register_codegen("triton")
def codegen_permute(ctx: LoweringContext, node: Node) -> object:
assert not node.kwargs, "permute kwargs not supported"
tensor, dims = map_arg(node.args, lambda arg: _env_arg(ctx, arg))
@@ -203,10 +240,13 @@ def codegen_permute(ctx: LoweringContext, node: Node) -> object:
)


@register_lowering(
stack_lowering = register_lowering(
torch.ops.aten.stack.default, # pyright: ignore[reportAttributeAccessIssue]
masked_value_fn=passthrough_masked_value,
)


@stack_lowering.register_codegen("triton")
def codegen_stack(ctx: LoweringContext, node: Node) -> object:
tensors = node.args[0]
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
@@ -259,10 +299,13 @@ def codegen_stack(ctx: LoweringContext, node: Node) -> object:
return expr_from_string(result)


@register_lowering(
expand_lowering = register_lowering(
torch.ops.aten.expand.default, # pyright: ignore[reportAttributeAccessIssue]
masked_value_fn=passthrough_masked_value,
)


@expand_lowering.register_codegen("triton")
def codegen_expand(ctx: LoweringContext, node: Node) -> object:
assert not node.kwargs, "expand kwargs not supported"
tensor, _ = map_arg(node.args, lambda arg: _env_arg(ctx, arg))
@@ -284,11 +327,7 @@ def codegen_expand(ctx: LoweringContext, node: Node) -> object:
)


def apply_dot_requirements(
handler: CodegenHandler,
node: Node,
masked_value_fn: MaskedValueFn | None = None,
) -> Lowering:
def apply_dot_requirements(lowering: AtenLowering, node: Node) -> Lowering:
"""Apply min_dot_size requirements to the config_spec"""
assert not node.kwargs, "dot kwargs not supported"
assert len(node.args) in (2, 3)
@@ -304,7 +343,7 @@ def apply_dot_requirements(
lnode = apply_masking(lnode, base_node=node, other=0)
rnode = apply_masking(rnode, base_node=node, other=0)
node.args = (*maybe_acc, lnode, rnode)
return LambdaLowering(handler, masked_value_fn=masked_value_fn)
return lowering
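A custom make_lowering hook such as apply_dot_requirements runs per node when the dispatch entry is invoked, so it can rewrite node.args before handing back the shared AtenLowering. A hedged sketch of that shape, with hypothetical names:

# Sketch only: `normalize_args` and `my_dot_op` are placeholders.
def normalize_args(lowering: AtenLowering, node: Node) -> Lowering:
    # Adjust node.args here (as apply_dot_requirements masks dot operands),
    # then return the shared lowering so backend dispatch is unchanged.
    return lowering

my_dot_lowering = register_lowering(my_dot_op, normalize_args)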


def reduce_3d_dot(ctx: LoweringContext, node: Node, with_acc: bool) -> ast.AST:
@@ -369,27 +408,54 @@ def reduce_3d_dot(ctx: LoweringContext, node: Node, with_acc: bool) -> ast.AST:
)


@register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements) # pyright: ignore[reportAttributeAccessIssue]
@register_lowering(torch.ops.aten.mm.default, apply_dot_requirements) # pyright: ignore[reportAttributeAccessIssue]
bmm_lowering = register_lowering(
torch.ops.aten.bmm.default, # pyright: ignore[reportAttributeAccessIssue]
apply_dot_requirements,
)
mm_lowering = register_lowering(
torch.ops.aten.mm.default, # pyright: ignore[reportAttributeAccessIssue]
apply_dot_requirements,
)


@bmm_lowering.register_codegen("triton")
@mm_lowering.register_codegen("triton")
def codegen_mm(ctx: LoweringContext, node: Node) -> ast.AST:
assert not node.kwargs, "matmul kwargs not supported"

return reduce_3d_dot(ctx, node, False)


@register_lowering(torch.ops.aten.addmm.default, apply_dot_requirements) # pyright: ignore[reportAttributeAccessIssue]
addmm_lowering = register_lowering(
torch.ops.aten.addmm.default, # pyright: ignore[reportAttributeAccessIssue]
apply_dot_requirements,
)


@addmm_lowering.register_codegen("triton")
def codegen_addmm(ctx: LoweringContext, node: Node) -> ast.AST:
assert not node.kwargs, "addmm kwargs not supported"
return reduce_3d_dot(ctx, node, True)


@register_lowering(torch.ops.aten.baddbmm.default, apply_dot_requirements) # pyright: ignore[reportAttributeAccessIssue]
baddbmm_lowering = register_lowering(
torch.ops.aten.baddbmm.default, # pyright: ignore[reportAttributeAccessIssue]
apply_dot_requirements,
)


@baddbmm_lowering.register_codegen("triton")
def codegen_baddbmm(ctx: LoweringContext, node: Node) -> ast.AST:
assert not node.kwargs, "baddbmm kwargs not supported"
return reduce_3d_dot(ctx, node, True)


@register_lowering(torch.ops.prims.iota.default) # pyright: ignore[reportAttributeAccessIssue]
iota_lowering = register_lowering(
torch.ops.prims.iota.default # pyright: ignore[reportAttributeAccessIssue]
)


@iota_lowering.register_codegen("triton")
def codegen_iota(ctx: LoweringContext, node: Node) -> object:
"""Generate tl.arange for torch.ops.prims.iota.default operations with automatic power-of-2 padding."""
start = node.kwargs.get("start", 0)
@@ -514,11 +580,21 @@ def _codegen_rng_op(
return rng_expr


@register_lowering(torch.ops.aten.rand.default) # pyright: ignore[reportAttributeAccessIssue]
rand_lowering = register_lowering(
torch.ops.aten.rand.default # pyright: ignore[reportAttributeAccessIssue]
)


@rand_lowering.register_codegen("triton")
def codegen_rand(ctx: LoweringContext, node: Node) -> object:
return _codegen_rng_op(ctx, node, "rand")


@register_lowering(torch.ops.aten.randn.default) # pyright: ignore[reportAttributeAccessIssue]
randn_lowering = register_lowering(
torch.ops.aten.randn.default # pyright: ignore[reportAttributeAccessIssue]
)


@randn_lowering.register_codegen("triton")
def codegen_randn(ctx: LoweringContext, node: Node) -> object:
return _codegen_rng_op(ctx, node, "randn")
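Because each lowering is now a named object, another backend can attach its own handler to an existing registration without touching the original call site. A sketch, assuming a hypothetical backend string:

# Sketch only: "my_backend" is a placeholder; registering the same
# backend twice on one lowering trips the assert in register_codegen.
@randn_lowering.register_codegen("my_backend")
def codegen_randn_my_backend(ctx: LoweringContext, node: Node) -> object:
    ...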