2 changes: 2 additions & 0 deletions helion/_compiler/compile_environment.py

@@ -79,6 +79,8 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
         super().__init__()
         self.device = device
         self.settings = settings
+        # TODO(jansel): make backend configurable
+        self.backend = "triton"
         self.shape_env = ShapeEnv(
             specialize_zero_one=True,
             duck_shape=False,
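One possible direction for the TODO above, sketched as a hypothetical follow-up: this PR hardcodes `"triton"`, and `Settings` has no `backend` field yet, so the attribute lookup below is an assumption rather than current API.

```python
def _resolve_backend(settings: object) -> str:
    # Hypothetical follow-up: read the backend from settings when available,
    # falling back to the default this PR hardcodes.
    return getattr(settings, "backend", None) or "triton"
```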
24 changes: 17 additions & 7 deletions helion/_compiler/generate_ast.py

@@ -272,7 +272,13 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
         assert fn_node._type_info is not None
         fn = fn_node._type_info.proxy()
         assert is_api_func(fn)
-        assert fn._codegen is not None
+        env = CompileEnvironment.current()
+        codegen_fn = fn._codegen.get(env.backend)
+        if codegen_fn is None:
+            raise exc.BackendImplementationMissing(
+                env.backend,
+                f"codegen for API function {fn.__qualname__}",
+            )
         bound = fn._signature.bind(*args, **kwargs)
         bound.apply_defaults()

@@ -285,7 +291,7 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
             ast_args=None,  # pyright: ignore[reportArgumentType]
         )

-        fn._codegen(state)
+        codegen_fn(state)
         assert node._root_id is not None
         codegen_call_with_graph(
             self,

@@ -376,11 +382,15 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
                     [x.from_config(self.device_function.config) for x in block_infos]
                 )
             )
-        elif (
-            isinstance(fn_type_info := func_node._type_info, CallableType)
-            and is_api_func(api := fn_type_info.value)
-            and api._codegen is not None
+        elif isinstance(fn_type_info := func_node._type_info, CallableType) and (
+            is_api_func(api := fn_type_info.value)
         ):
+            codegen_fn = api._codegen.get(env.backend)
+            if codegen_fn is None:
+                raise exc.BackendImplementationMissing(
+                    env.backend,
+                    f"codegen for API function {api.__qualname__}",
+                )
             ast_args = []
             ast_kwargs = {}
             proxy_args = []

@@ -401,7 +411,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
             proxy_params = api._signature.bind(*proxy_args, **proxy_kwargs)
             ast_params.apply_defaults()
             proxy_params.apply_defaults()
-            return api._codegen(  # pyright: ignore[reportReturnType]
+            return codegen_fn(  # pyright: ignore[reportReturnType]
                 CodegenState(
                     self,
                     None,
12 changes: 8 additions & 4 deletions helion/_compiler/inductor_lowering.py

@@ -626,7 +626,6 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
             self.buffer.data.inner_fn(indices, reduction_indices)
         )

-        from .. import exc
         from .generate_ast import GenerateAST

         if not isinstance(ctx.cg, GenerateAST):

@@ -744,14 +743,19 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         ast_args = [*map_arg(node.args, lambda arg: ctx.env[arg])]
         proxy_args = [*map_arg(node.args, lambda arg: arg.meta["val"])]

-        assert self.api_func._codegen is not None
+        from .. import exc
+
+        env = CompileEnvironment.current()
+        codegen_fn = self.api_func._codegen.get(env.backend)
+        if codegen_fn is None:
+            raise exc.BackendImplementationMissing(
+                env.backend,
+                f"codegen for API function {self.api_func.__qualname__}",
+            )
         from .generate_ast import GenerateAST

         if not isinstance(ctx.cg, GenerateAST):
             raise exc.NotAllowedInHelperFunction

-        return self.api_func._codegen(
+        return codegen_fn(
             CodegenState(
                 ctx.cg,
                 fx_node=node,
7 changes: 7 additions & 0 deletions helion/exc.py

@@ -52,6 +52,13 @@ class AutotuneError(BaseError):
     message = "{0}"


+class BackendImplementationMissing(BaseError):
+    message = "Backend '{backend}' is missing required implementation: {detail}"
+
+    def __init__(self, backend: str, detail: str) -> None:
+        super().__init__(backend=backend, detail=detail)
+
+
 class CacheAssertionError(BaseError):
     message = "Expected cache hit for kernel '{0}', but got cache miss. See stderr for diagnostic information."
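For illustration, a sketch of how the new exception reads when raised, assuming `BaseError` formats `message` with the keyword arguments forwarded by `super().__init__` (as the class body suggests); the backend name here is hypothetical:

```python
from helion import exc

try:
    raise exc.BackendImplementationMissing(
        "cpu",  # hypothetical backend name
        "codegen for API function tile",
    )
except exc.BackendImplementationMissing as e:
    # Expected (assumption):
    # Backend 'cpu' is missing required implementation: codegen for API function tile
    print(e)
```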
13 changes: 7 additions & 6 deletions helion/language/_decorators.py

@@ -55,7 +55,7 @@ class APIFunc(Protocol):
         _cache_type: Whether to cache the type information for repeated calls.
         _type_function: A callable that determines the return type of this function
             during type propagation phase.
-        _codegen: A callable that generates the device code for this function.
+        _codegen: Mapping of backend names to callables that generate device code.
         _fake_fn: A callable that provides a "fake" implementation used during
             tracing and compilation.
         _prepare_args: A callable that preprocesses the arguments before they're

@@ -72,7 +72,7 @@ class APIFunc(Protocol):
     _tiles_as_sizes: bool
     _cache_type: bool
     _type_function: Callable[..., TypeInfo] | None
-    _codegen: Callable[[CodegenState], object] | None
+    _codegen: dict[str, Callable[[CodegenState], object]]
     _fake_fn: Callable[..., object] | None
     _prepare_args: Callable[[tuple[object, ...]], tuple[object, ...]]
     _get_masked_value: Callable[[torch.fx.Node], float | bool | None] | None

@@ -189,7 +189,7 @@ def wrapper(*args: object, **kwargs: object) -> object:
    api._prepare_args = no_op_prepare_args
    api._cache_type = cache_type
    api._type_function = None
-   api._codegen = None
+   api._codegen = {}
    api._fake_fn = None
    api._get_masked_value = None
    api._to_device_ir = None

@@ -254,15 +254,16 @@ def _impl(

 def codegen(
     original_fn: Callable[..., object],
+    backend: str,
 ) -> _NoReturnDecorator[object]:
     def _impl(codegen_fn: Callable[[CodegenState], object]) -> Callable[..., Never]:
         assert is_api_func(original_fn), (
             f"{type_propagation.__qualname__} can only be used on API functions"
         )
-        assert original_fn._codegen is None, (
-            "codegen can only be used once per function"
+        assert backend not in original_fn._codegen, (
+            f"codegen already registered for backend {backend!r}"
         )
-        original_fn._codegen = codegen_fn
+        original_fn._codegen[backend] = codegen_fn
         return _no_call

     return _impl  # pyright: ignore[reportReturnType]
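To make the new two-argument decorator concrete, a minimal sketch of registering codegen for a custom op. The import paths and the `api()` decorator signature are assumptions inferred from the usage visible in this diff, and `my_op` is hypothetical:

```python
import ast

from helion._compiler.ast_extension import expr_from_string  # assumed path
from helion._compiler.inductor_lowering import CodegenState  # assumed path
from helion.language import _decorators


@_decorators.api()  # assumed signature, mirroring the built-in ops
def my_op(x: object) -> object:  # hypothetical API function
    raise AssertionError("this should never be called")


@_decorators.codegen(my_op, "triton")
def _(state: CodegenState) -> ast.AST:
    # Emit the Triton-side expression for my_op; a second registration
    # for "triton" would now trip the updated assert.
    return expr_from_string("tl.abs({x})", x=state.ast_arg(0))
```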
22 changes: 11 additions & 11 deletions helion/language/_tracing_ops.py

@@ -40,7 +40,7 @@ def _get_symnode(debug_name: str) -> int:
     raise AssertionError("this should never be called")


-@_decorators.codegen(_get_symnode)
+@_decorators.codegen(_get_symnode, "triton")

[Review thread on the "triton" literal above]
Contributor: instead of hardcoding the string here, can we use an enum?
Author: This would make out-of-tree backends harder to implement.
(A sketch of the out-of-tree registration this enables follows this file's diff.)

 def _(state: CodegenState) -> ast.AST:
     val = state.fx_node.meta["val"]  # pyright: ignore[reportOptionalMemberAccess]

@@ -69,7 +69,7 @@ def _host_tensor(debug_name: str) -> torch.Tensor:
     raise AssertionError("this should never be called")


-@_decorators.codegen(_host_tensor)
+@_decorators.codegen(_host_tensor, "triton")
 def _(state: CodegenState) -> ast.AST:
     return expr_from_string("_host_tensor")  # should be unused

@@ -83,7 +83,7 @@ def _for_loop(
     raise AssertionError("this should never be called")


-@_decorators.codegen(_for_loop)
+@_decorators.codegen(_for_loop, "triton")
 def _(state: CodegenState) -> None:
     return HostFunction.current().device_ir.graphs[state.proxy_arg(0)].codegen(state)  # pyright: ignore[reportArgumentType,reportCallIssue]

@@ -100,7 +100,7 @@ def _while_loop(
     raise AssertionError("this should never be called")


-@_decorators.codegen(_while_loop)
+@_decorators.codegen(_while_loop, "triton")
 def _(state: CodegenState) -> None:
     return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state)  # pyright: ignore[reportArgumentType,reportCallIssue]

@@ -112,7 +112,7 @@ def _if(test: object, graph_id: int, args: list[object]) -> list[object]:
     raise AssertionError("this should never be called")


-@_decorators.codegen(_if)
+@_decorators.codegen(_if, "triton")
 def _(state: CodegenState) -> None:
     return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state)  # pyright: ignore[reportArgumentType,reportCallIssue]

@@ -139,7 +139,7 @@ def _(lhs: object, rhs: object) -> object:
     return torch.empty_like(lhs)


-@_decorators.codegen(_phi)
+@_decorators.codegen(_phi, "triton")
 def _(state: CodegenState) -> ast.Name:
     lhs = state.ast_arg(0)
     assert isinstance(lhs, ast.Name), lhs

@@ -180,7 +180,7 @@ def _and(left: object, right: object) -> object:
     raise NotInsideKernel


-@_decorators.codegen(_and)
+@_decorators.codegen(_and, "triton")
 def _(state: CodegenState) -> None:
     return expr_from_string(
         "{lhs} and {rhs}", lhs=state.ast_arg(0), rhs=state.ast_arg(1)

@@ -233,7 +233,7 @@ def _(left: object, right: object) -> object:
     return env.shape_env.create_unbacked_symbool()


-@_decorators.codegen(_or)
+@_decorators.codegen(_or, "triton")
 def _(state: CodegenState) -> None:
     return expr_from_string(
         "{lhs} or {rhs}", lhs=state.ast_arg(0), rhs=state.ast_arg(1)

@@ -258,7 +258,7 @@ def _(left: object) -> object:
     return env.shape_env.create_unbacked_symbool()


-@_decorators.codegen(_not)
+@_decorators.codegen(_not, "triton")
 def _(state: CodegenState) -> ast.AST:
     return expr_from_string(
         "not {lhs}",

@@ -289,7 +289,7 @@ def _(tensor: torch.Tensor, other: float) -> torch.Tensor:
     return torch.empty_like(tensor)


-@_decorators.codegen(_mask_to)
+@_decorators.codegen(_mask_to, "triton")
 def _(state: CodegenState) -> ast.AST:
     tensor = state.proxy_arg(0)
     assert isinstance(tensor, torch.Tensor)

@@ -351,7 +351,7 @@ def _(value: _T) -> _T:
     raise NotImplementedError(f"Unsupported type for _new_var: {type(value)}")


-@_decorators.codegen(_new_var)
+@_decorators.codegen(_new_var, "triton")
 def _(state: CodegenState) -> ast.AST:
     value = state.ast_arg(0)
     assert isinstance(value, ast.AST)
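Following up on the enum question in the review thread above: a minimal sketch of what plain string keys buy an out-of-tree backend, which can register its codegen without helion shipping (or being patched to extend) an enum of backends. The backend name and emitter body are hypothetical:

```python
# In a third-party package: pick any unique string key.
from helion.language import _decorators
from helion.language._tracing_ops import _and


@_decorators.codegen(_and, "my_backend")  # hypothetical backend name
def _(state):
    # Emit this backend's representation of a logical AND; with an enum,
    # adding "my_backend" would first require a change inside helion.
    ...
```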
16 changes: 8 additions & 8 deletions helion/language/atomic_ops.py

@@ -263,7 +263,7 @@ def apply(t: torch.Tensor, idx_tuple: tuple, v: object) -> None:
     return prev


-@_decorators.codegen(atomic_add)
+@_decorators.codegen(atomic_add, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_add", state, _to_ast_values([value_expr]))

@@ -343,7 +343,7 @@ def _(
     return prev


-@_decorators.codegen(atomic_xchg)
+@_decorators.codegen(atomic_xchg, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_xchg", state, _to_ast_values([value_expr]))

@@ -420,7 +420,7 @@ def _(
     return prev


-@_decorators.codegen(atomic_and)
+@_decorators.codegen(atomic_and, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_and", state, _to_ast_values([value_expr]))

@@ -494,7 +494,7 @@ def _(
     return prev


-@_decorators.codegen(atomic_or)
+@_decorators.codegen(atomic_or, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_or", state, _to_ast_values([value_expr]))

@@ -568,7 +568,7 @@ def _(
     return prev


-@_decorators.codegen(atomic_xor)
+@_decorators.codegen(atomic_xor, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_xor", state, _to_ast_values([value_expr]))

@@ -634,7 +634,7 @@ def apply(t: torch.Tensor, idx: tuple, v: object) -> None:
     _ref_apply(target, index, apply, value)


-@_decorators.codegen(atomic_max)
+@_decorators.codegen(atomic_max, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_max", state, _to_ast_values([value_expr]))

@@ -709,7 +709,7 @@ def _(
     return prev


-@_decorators.codegen(atomic_min)
+@_decorators.codegen(atomic_min, "triton")
 def _(state: CodegenState) -> ast.AST:
     value_expr = state.ast_args[2]
     return _codegen_common("atomic_min", state, _to_ast_values([value_expr]))

@@ -810,7 +810,7 @@ def _(
     return prev


-@_decorators.codegen(atomic_cas)
+@_decorators.codegen(atomic_cas, "triton")
 def _(state: CodegenState) -> ast.AST:
     exp_expr = state.ast_args[2]
     val_expr = state.ast_args[3]
2 changes: 1 addition & 1 deletion helion/language/constexpr.py

@@ -94,7 +94,7 @@ def handle_symint(symint: torch.SymInt) -> int:
     return TypeInfo.from_example(specialized, origin=origin)


-@_decorators.codegen(specialize)
+@_decorators.codegen(specialize, "triton")
 def _(state: CodegenState) -> ast.AST:
     value = state.proxy_arg(0)
     specialized = _convert_specializable(value)
2 changes: 1 addition & 1 deletion helion/language/creation_ops.py

@@ -128,7 +128,7 @@ def _full_fake(
     )


-@_decorators.codegen(full)
+@_decorators.codegen(full, "triton")
 def _full_codegen(state: CodegenState) -> ast.AST:
     fake_value = state.fake_value
     assert isinstance(fake_value, torch.Tensor)
2 changes: 1 addition & 1 deletion helion/language/debug_ops.py

@@ -47,7 +47,7 @@ def _(*args: object, origin: Origin, **kwargs: object) -> TypeInfo:
     return LiteralType(origin, None)


-@_decorators.codegen(breakpoint)
+@_decorators.codegen(breakpoint, "triton")
 def _(state: CodegenState) -> None:
     state.add_statement("breakpoint()")
2 changes: 1 addition & 1 deletion helion/language/device_print.py

@@ -67,7 +67,7 @@ def _(*args: object, origin: Origin, **kwargs: object) -> TypeInfo:
     return NoType(origin)


-@_decorators.codegen(device_print)
+@_decorators.codegen(device_print, "triton")
 def _(state: CodegenState) -> None:
     prefix = state.proxy_arg(0)
     call_args: list[ast.AST] = [create(ast.Constant, value=prefix)]
2 changes: 1 addition & 1 deletion helion/language/inline_asm_ops.py

@@ -148,7 +148,7 @@ def _(
     return torch.empty(broadcast_shape, dtype=dtypes[0], device=env.device)


-@_decorators.codegen(inline_asm_elementwise)
+@_decorators.codegen(inline_asm_elementwise, "triton")
 def _(state: CodegenState) -> ast.AST | list[ast.AST]:
     # Get arguments
     asm_str = state.proxy_arg(0)
2 changes: 1 addition & 1 deletion helion/language/inline_triton_ops.py

@@ -337,7 +337,7 @@ def _emit_output_assertions(
     )


-@_decorators.codegen(inline_triton)
+@_decorators.codegen(inline_triton, "triton")
 def _(state: CodegenState) -> ast.AST | list[ast.AST]:
     triton_source = state.proxy_arg(0)
     args_obj = state.proxy_arg(1)
4 changes: 2 additions & 2 deletions helion/language/loops.py

@@ -460,7 +460,7 @@ def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:
     return hint < get_max_y_grid()


-@_decorators.codegen(tile)
+@_decorators.codegen(tile, "triton")
 def _(state: CodegenState) -> ast.AST:
     return _codegen_loop_helper(state)

@@ -753,7 +753,7 @@ def _(
     return IterType(origin, result)


-@_decorators.codegen(grid)
+@_decorators.codegen(grid, "triton")
 def _(state: CodegenState) -> ast.AST:
     return _codegen_loop_helper(state)