Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions docs/api/language.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,36 @@ Executes target-specific inline assembly on elements of one or more tensors with

Embeds small Triton code snippets directly inside a Helion kernel. Common indentation is removed automatically, placeholders are replaced using ``str.format`` with tuple or dict arguments, and the final line in the snippet becomes the return value. Provide tensors (or tuples of tensors) via ``output_like`` so Helion knows the type of the return value.

### triton_kernel()

```{eval-rst}
.. autofunction:: triton_kernel
```

Define (once) and call a ``@triton.jit`` function from Helion device code.

- Accepts either:
- a source string containing a single Triton function definition,
- a function name string referring to a ``@triton.jit`` function in the kernel’s module, or
- a Python function object (or Triton JITFunction; unwrapped via ``.fn``).
- The function is emitted at module scope once and then invoked from the kernel body.
- Pass ``output_like`` tensors for shape/dtype checks identical to ``inline_triton``.

Example (by name):

```python
@triton.jit
def add_pairs(a, b):
return a + b

@helion.kernel()
def k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
for tile in hl.tile(x.shape):
out[tile] = hl.triton_kernel("add_pairs", args=(x[tile], y[tile]), output_like=x[tile])
return out
```

## Tensor Creation

### zeros()
Expand Down
8 changes: 6 additions & 2 deletions helion/_compiler/compile_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,15 @@ def to_fake(self, obj: object, origin: Origin) -> object:
# Handle functions and Kernel objects
from ..runtime.kernel import Kernel

if isinstance(obj, (types.FunctionType, Kernel)):
if isinstance(obj, (types.FunctionType, Kernel)) or hasattr(obj, "fn"):
from .helper_function import extract_helper_function
from .lift_closures import lift_closures

fn = extract_helper_function(obj)
# If Triton JITFunction is passed, try to unwrap to underlying Python function
if hasattr(obj, "fn") and isinstance(obj.fn, types.FunctionType):
fn = obj.fn
else:
fn = extract_helper_function(obj)
return lift_closures(fn, origin)
# Handle GraphModule - treat it like a function
if isinstance(obj, torch.fx.GraphModule):
Expand Down
2 changes: 2 additions & 0 deletions helion/_compiler/generate_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, func: HostFunction, config: Config) -> None:
# Initialize our attributes
self.host_function = func
self.host_statements: list[ast.AST] = []
self.module_statements: list[ast.stmt] = []
self.statements_stack: list[list[ast.AST]] = [self.host_statements]
self.on_device = False
self.active_device_loops: dict[int, list[DeviceLoopOrGridState]] = (
Expand Down Expand Up @@ -502,6 +503,7 @@ def generate_ast(
result = ast.Module(
[
*func.codegen_imports(),
*codegen.module_statements,
*codegen.device_function.codegen_helper_functions(),
*kernel_def,
host_def,
Expand Down
1 change: 1 addition & 0 deletions helion/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .device_print import device_print as device_print
from .inline_asm_ops import inline_asm_elementwise as inline_asm_elementwise
from .inline_triton_ops import inline_triton as inline_triton
from .inline_triton_ops import triton_kernel as triton_kernel
from .loops import grid as grid
from .loops import static_range as static_range
from .loops import tile as tile
Expand Down
214 changes: 213 additions & 1 deletion helion/language/inline_triton_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ast
from collections.abc import Mapping
from collections.abc import Sequence
import inspect
import textwrap
from typing import TYPE_CHECKING
from typing import TypeVar
Expand All @@ -17,14 +18,18 @@
from .._compiler.ast_extension import create
from .._compiler.ast_extension import expr_from_string
from .._compiler.ast_extension import statement_from_string
from .._compiler.host_function import HostFunction
from .._compiler.output_header import SOURCE_MODULE
from . import _decorators

if TYPE_CHECKING:
from types import FunctionType

from .._compiler.inductor_lowering import CodegenState

_T = TypeVar("_T")

__all__ = ["inline_triton"]
__all__ = ["inline_triton", "triton_kernel"]


@has_side_effect
Expand Down Expand Up @@ -266,6 +271,111 @@ def _collect_output_metadata(
)


def _ensure_triton_jit_decorator(func_def: ast.FunctionDef) -> ast.FunctionDef:
has_jit = any(
(isinstance(d, ast.Attribute) and d.attr == "jit")
or (isinstance(d, ast.Name) and d.id == "triton")
or (
isinstance(d, ast.Call)
and isinstance(d.func, ast.Attribute)
and d.func.attr == "jit"
)
for d in func_def.decorator_list
)
if has_jit:
return func_def
func_def.decorator_list.insert(0, cast("ast.expr", expr_from_string("triton.jit")))
return func_def


def _get_or_add_triton_function_preamble(
state: CodegenState, triton_source_or_fn: object
) -> str:
"""
Parse a @triton.jit function definition from source and add it once to the
device function preamble. Returns the (possibly renamed) function name to call.
"""
if isinstance(triton_source_or_fn, str):
candidate = textwrap.dedent(triton_source_or_fn).strip()
# If looks like a bare identifier (function name), resolve from kernel globals
if (
candidate
and candidate.isidentifier()
and "\n" not in candidate
and "def " not in candidate
):
hf = HostFunction.current()
# Ensure SOURCE_MODULE is registered
hf.global_scope_origin("")
module_obj = hf.global_imports[SOURCE_MODULE].value
module_scope = cast("dict[str, object]", module_obj.__dict__)
fn_obj = module_scope[candidate]
func_obj = fn_obj if inspect.isfunction(fn_obj) else fn_obj.fn # type: ignore[attr-defined]
func_obj_typed: FunctionType = cast("FunctionType", func_obj)
try:
src = textwrap.dedent(inspect.getsource(func_obj_typed)).strip()
except OSError as exc_value:
raise exc.InvalidAPIUsage(
f"Could not get source for Triton function '{candidate}': {exc_value}"
) from exc_value
base_name_hint = func_obj_typed.__name__
else:
src = candidate
base_name_hint = None
else:
# Expect a function object (already unwrapped by to_fake)
func_obj = triton_source_or_fn
func_obj_typed: FunctionType = cast("FunctionType", func_obj)
try:
src = textwrap.dedent(inspect.getsource(func_obj_typed)).strip()
except OSError as exc_value:
raise exc.InvalidAPIUsage(
f"Could not get source for Triton function: {exc_value}"
) from exc_value
base_name_hint = func_obj_typed.__name__
if not src:
raise exc.InvalidAPIUsage("triton_kernel source must contain a function")

try:
module = ast.parse(src)
except SyntaxError as exc_value:
raise exc.InvalidAPIUsage(
f"Failed to parse triton_kernel source: {exc_value}"
) from exc_value

func_defs = [node for node in module.body if isinstance(node, ast.FunctionDef)]
if len(func_defs) != 1:
raise exc.InvalidAPIUsage(
f"triton_kernel expects exactly one function definition, found {len(func_defs)}"
)
fn_def = cast("ast.FunctionDef", convert(func_defs[0]))
fn_def = _ensure_triton_jit_decorator(fn_def)

# Cache to avoid duplicate definitions
cache_name = "_added_triton_kernel_defs"
added: dict[str, str] = getattr(state.device_function, cache_name, {})

# Use the function name plus source as key to avoid collisions on same name different body
# Use function name if available, else the parsed name
parsed_name = fn_def.name
if base_name_hint and isinstance(base_name_hint, str):
parsed_name = base_name_hint

key = f"{parsed_name}:{src}"
if key in added:
return added[key]

# Ensure uniqueness of function name in module scope
unique_name = state.device_function.new_var(parsed_name)
fn_def.name = unique_name

# Define the Triton function at module scope to avoid nested jit def issues
state.codegen.module_statements.append(fn_def)
added[key] = unique_name
setattr(state.device_function, cache_name, added)
return unique_name


def _emit_output_assertions(
state: CodegenState,
result_name: str,
Expand Down Expand Up @@ -383,3 +493,105 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]:
return [expr_from_string(f"{result_name}[{i}]") for i in range(len(dtypes))]

return expr_from_string(result_name)


@_decorators.api(is_device_only=True, allow_host_tensor=True)
def triton_kernel(
triton_source_or_fn: object,
args: Sequence[object] | Mapping[str, object],
output_like: _T,
) -> _T:
"""
Define (once) and call a @triton.jit function from Helion device code.

Args:
triton_source_or_fn: Source for a single @triton.jit function definition,
or a Python function object defining a @triton.jit kernel.
args: Positional or keyword placeholders that will be substituted via
name resolution of Helion variables.
output_like: Example tensor(s) describing the expected outputs for shape/dtype checks.
"""
raise exc.NotInsideKernel


@_decorators.register_fake(triton_kernel)
def _(
triton_source_or_fn: object,
args: object,
output_like: object,
) -> object:
if not (
isinstance(triton_source_or_fn, str) or inspect.isfunction(triton_source_or_fn)
):
raise exc.InvalidAPIUsage(
f"triton_kernel expects a string source or a function, got {type(triton_source_or_fn)}"
)
_validate_args(args)
return _fake_outputs(output_like)


@_decorators.codegen(triton_kernel, "triton")
def _(state: CodegenState) -> ast.AST | list[ast.AST]:
triton_source_or_fn = state.proxy_arg(0)
args_obj = state.proxy_arg(1)
output_like = state.proxy_arg(2)

if not (
isinstance(triton_source_or_fn, str) or inspect.isfunction(triton_source_or_fn)
):
raise exc.InvalidAPIUsage(
f"triton_kernel expects a string source or a function, got {type(triton_source_or_fn)}"
)
_validate_args(args_obj)

# Install the Triton function into preamble (once) and get the callable name
fn_name = _get_or_add_triton_function_preamble(state, triton_source_or_fn)

# Resolve argument names similar to inline_triton formatting
call_args_src = ""
if isinstance(state.ast_args[1], dict):
kw_pairs: list[str] = []
mapping = cast("Mapping[str, object]", args_obj)
for key, node in state.ast_args[1].items():
kw_pairs.append(f"{key}=" + _ensure_name(state, node, mapping[key]))
call_args_src = ", ".join(kw_pairs)
else:
if not isinstance(state.ast_args[1], (ast.List, ast.Tuple, list, tuple)):
raise exc.InvalidAPIUsage(
"triton_kernel expects a literal list/tuple for positional args"
)
arg_nodes = (
state.ast_args[1].elts
if isinstance(state.ast_args[1], (ast.List, ast.Tuple))
else list(state.ast_args[1])
)
names = [
_ensure_name(state, node, arg)
for node, arg in zip(
arg_nodes, cast("Sequence[object]", args_obj), strict=False
)
]
call_args_src = ", ".join(names)

call_expr = expr_from_string(f"{fn_name}({call_args_src})")

if output_like is None:
state.add_statement(create(ast.Expr, value=call_expr))
return create(ast.Constant, value=None)

result_name = state.device_function.new_var("triton_kernel_result")
assign = create(
ast.Assign,
targets=[create(ast.Name, id=result_name, ctx=ast.Store())],
value=call_expr,
)
state.add_statement(assign)

dtypes, output_nodes, is_multi = _collect_output_metadata(
output_like, state.ast_args[2]
)
_emit_output_assertions(state, result_name, dtypes, output_nodes, is_multi)

if is_multi:
return [expr_from_string(f"{result_name}[{i}]") for i in range(len(dtypes))]
return expr_from_string(result_name)
Loading
Loading