From 95c6eb0b7caf244f7787d2dfbb0c7b1121f95aed Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Nov 2025 08:24:20 -0800 Subject: [PATCH] Refactor _decorators.codegen to allow multiple backends stack-info: PR: https://github.com/pytorch/helion/pull/1099, branch: jansel/stack/220 --- helion/_compiler/compile_environment.py | 2 ++ helion/_compiler/generate_ast.py | 24 +++++++++++++++++------- helion/_compiler/inductor_lowering.py | 12 ++++++++---- helion/exc.py | 7 +++++++ helion/language/_decorators.py | 13 +++++++------ helion/language/_tracing_ops.py | 22 +++++++++++----------- helion/language/atomic_ops.py | 16 ++++++++-------- helion/language/constexpr.py | 2 +- helion/language/creation_ops.py | 2 +- helion/language/debug_ops.py | 2 +- helion/language/device_print.py | 2 +- helion/language/inline_asm_ops.py | 2 +- helion/language/inline_triton_ops.py | 2 +- helion/language/loops.py | 4 ++-- helion/language/matmul_ops.py | 2 +- helion/language/memory_ops.py | 4 ++-- helion/language/random_ops.py | 2 +- helion/language/reduce_ops.py | 2 +- helion/language/scan_ops.py | 2 +- helion/language/signal_wait.py | 4 ++-- helion/language/tile_ops.py | 10 +++++----- helion/language/tunable_ops.py | 4 ++-- helion/language/view_ops.py | 6 +++--- 23 files changed, 86 insertions(+), 62 deletions(-) diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 8dcf59e6f..26dd19859 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -79,6 +79,8 @@ def __init__(self, device: torch.device, settings: Settings) -> None: super().__init__() self.device = device self.settings = settings + # TODO(jansel): make backend configurable + self.backend = "triton" self.shape_env = ShapeEnv( specialize_zero_one=True, duck_shape=False, diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index c1b58e5a9..51681e9a1 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -272,7 +272,13 @@ def visit_For(self, node: ast.For) -> ast.AST | None: assert fn_node._type_info is not None fn = fn_node._type_info.proxy() assert is_api_func(fn) - assert fn._codegen is not None + env = CompileEnvironment.current() + codegen_fn = fn._codegen.get(env.backend) + if codegen_fn is None: + raise exc.BackendImplementationMissing( + env.backend, + f"codegen for API function {fn.__qualname__}", + ) bound = fn._signature.bind(*args, **kwargs) bound.apply_defaults() @@ -285,7 +291,7 @@ def visit_For(self, node: ast.For) -> ast.AST | None: ast_args=None, # pyright: ignore[reportArgumentType] ) - fn._codegen(state) + codegen_fn(state) assert node._root_id is not None codegen_call_with_graph( self, @@ -376,11 +382,15 @@ def visit_Call(self, node: ast.Call) -> ast.AST: [x.from_config(self.device_function.config) for x in block_infos] ) ) - elif ( - isinstance(fn_type_info := func_node._type_info, CallableType) - and is_api_func(api := fn_type_info.value) - and api._codegen is not None + elif isinstance(fn_type_info := func_node._type_info, CallableType) and ( + is_api_func(api := fn_type_info.value) ): + codegen_fn = api._codegen.get(env.backend) + if codegen_fn is None: + raise exc.BackendImplementationMissing( + env.backend, + f"codegen for API function {api.__qualname__}", + ) ast_args = [] ast_kwargs = {} proxy_args = [] @@ -401,7 +411,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST: proxy_params = api._signature.bind(*proxy_args, **proxy_kwargs) ast_params.apply_defaults() 
proxy_params.apply_defaults() - return api._codegen( # pyright: ignore[reportReturnType] + return codegen_fn( # pyright: ignore[reportReturnType] CodegenState( self, None, diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index 49db6d664..0931dd44b 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -626,7 +626,6 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object: self.buffer.data.inner_fn(indices, reduction_indices) ) - from .. import exc from .generate_ast import GenerateAST if not isinstance(ctx.cg, GenerateAST): @@ -744,14 +743,19 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object: ast_args = [*map_arg(node.args, lambda arg: ctx.env[arg])] proxy_args = [*map_arg(node.args, lambda arg: arg.meta["val"])] - assert self.api_func._codegen is not None - from .. import exc + env = CompileEnvironment.current() + codegen_fn = self.api_func._codegen.get(env.backend) + if codegen_fn is None: + raise exc.BackendImplementationMissing( + env.backend, + f"codegen for API function {self.api_func.__qualname__}", + ) from .generate_ast import GenerateAST if not isinstance(ctx.cg, GenerateAST): raise exc.NotAllowedInHelperFunction - return self.api_func._codegen( + return codegen_fn( CodegenState( ctx.cg, fx_node=node, diff --git a/helion/exc.py b/helion/exc.py index 6ce6b1ea5..fbc20b234 100644 --- a/helion/exc.py +++ b/helion/exc.py @@ -52,6 +52,13 @@ class AutotuneError(BaseError): message = "{0}" +class BackendImplementationMissing(BaseError): + message = "Backend '{backend}' is missing required implementation: {detail}" + + def __init__(self, backend: str, detail: str) -> None: + super().__init__(backend=backend, detail=detail) + + class CacheAssertionError(BaseError): message = "Expected cache hit for kernel '{0}', but got cache miss. See stderr for diagnostic information." diff --git a/helion/language/_decorators.py b/helion/language/_decorators.py index ca8d653de..89a073f63 100644 --- a/helion/language/_decorators.py +++ b/helion/language/_decorators.py @@ -55,7 +55,7 @@ class APIFunc(Protocol): _cache_type: Whether to cache the type information for repeated calls. _type_function: A callable that determines the return type of this function during type propagation phase. - _codegen: A callable that generates the device code for this function. + _codegen: Mapping of backend names to callables that generate device code. _fake_fn: A callable that provides a "fake" implementation used during tracing and compilation. 
_prepare_args: A callable that preprocesses the arguments before they're @@ -72,7 +72,7 @@ class APIFunc(Protocol): _tiles_as_sizes: bool _cache_type: bool _type_function: Callable[..., TypeInfo] | None - _codegen: Callable[[CodegenState], object] | None + _codegen: dict[str, Callable[[CodegenState], object]] _fake_fn: Callable[..., object] | None _prepare_args: Callable[[tuple[object, ...]], tuple[object, ...]] _get_masked_value: Callable[[torch.fx.Node], float | bool | None] | None @@ -189,7 +189,7 @@ def wrapper(*args: object, **kwargs: object) -> object: api._prepare_args = no_op_prepare_args api._cache_type = cache_type api._type_function = None - api._codegen = None + api._codegen = {} api._fake_fn = None api._get_masked_value = None api._to_device_ir = None @@ -254,15 +254,16 @@ def _impl( def codegen( original_fn: Callable[..., object], + backend: str, ) -> _NoReturnDecorator[object]: def _impl(codegen_fn: Callable[[CodegenState], object]) -> Callable[..., Never]: assert is_api_func(original_fn), ( f"{type_propagation.__qualname__} can only be used on API functions" ) - assert original_fn._codegen is None, ( - "codegen can only be used once per function" + assert backend not in original_fn._codegen, ( + f"codegen already registered for backend {backend!r}" ) - original_fn._codegen = codegen_fn + original_fn._codegen[backend] = codegen_fn return _no_call return _impl # pyright: ignore[reportReturnType] diff --git a/helion/language/_tracing_ops.py b/helion/language/_tracing_ops.py index a7bf1a0f7..0d5939b35 100644 --- a/helion/language/_tracing_ops.py +++ b/helion/language/_tracing_ops.py @@ -40,7 +40,7 @@ def _get_symnode(debug_name: str) -> int: raise AssertionError("this should never be called") -@_decorators.codegen(_get_symnode) +@_decorators.codegen(_get_symnode, "triton") def _(state: CodegenState) -> ast.AST: val = state.fx_node.meta["val"] # pyright: ignore[reportOptionalMemberAccess] @@ -69,7 +69,7 @@ def _host_tensor(debug_name: str) -> torch.Tensor: raise AssertionError("this should never be called") -@_decorators.codegen(_host_tensor) +@_decorators.codegen(_host_tensor, "triton") def _(state: CodegenState) -> ast.AST: return expr_from_string("_host_tensor") # should be unused @@ -83,7 +83,7 @@ def _for_loop( raise AssertionError("this should never be called") -@_decorators.codegen(_for_loop) +@_decorators.codegen(_for_loop, "triton") def _(state: CodegenState) -> None: return HostFunction.current().device_ir.graphs[state.proxy_arg(0)].codegen(state) # pyright: ignore[reportArgumentType,reportCallIssue] @@ -100,7 +100,7 @@ def _while_loop( raise AssertionError("this should never be called") -@_decorators.codegen(_while_loop) +@_decorators.codegen(_while_loop, "triton") def _(state: CodegenState) -> None: return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state) # pyright: ignore[reportArgumentType,reportCallIssue] @@ -112,7 +112,7 @@ def _if(test: object, graph_id: int, args: list[object]) -> list[object]: raise AssertionError("this should never be called") -@_decorators.codegen(_if) +@_decorators.codegen(_if, "triton") def _(state: CodegenState) -> None: return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state) # pyright: ignore[reportArgumentType,reportCallIssue] @@ -139,7 +139,7 @@ def _(lhs: object, rhs: object) -> object: return torch.empty_like(lhs) -@_decorators.codegen(_phi) +@_decorators.codegen(_phi, "triton") def _(state: CodegenState) -> ast.Name: lhs = state.ast_arg(0) assert isinstance(lhs, ast.Name), lhs @@ 
-180,7 +180,7 @@ def _and(left: object, right: object) -> object: raise NotInsideKernel -@_decorators.codegen(_and) +@_decorators.codegen(_and, "triton") def _(state: CodegenState) -> None: return expr_from_string( "{lhs} and {rhs}", lhs=state.ast_arg(0), rhs=state.ast_arg(1) @@ -233,7 +233,7 @@ def _(left: object, right: object) -> object: return env.shape_env.create_unbacked_symbool() -@_decorators.codegen(_or) +@_decorators.codegen(_or, "triton") def _(state: CodegenState) -> None: return expr_from_string( "{lhs} or {rhs}", lhs=state.ast_arg(0), rhs=state.ast_arg(1) @@ -258,7 +258,7 @@ def _(left: object) -> object: return env.shape_env.create_unbacked_symbool() -@_decorators.codegen(_not) +@_decorators.codegen(_not, "triton") def _(state: CodegenState) -> ast.AST: return expr_from_string( "not {lhs}", @@ -289,7 +289,7 @@ def _(tensor: torch.Tensor, other: float) -> torch.Tensor: return torch.empty_like(tensor) -@_decorators.codegen(_mask_to) +@_decorators.codegen(_mask_to, "triton") def _(state: CodegenState) -> ast.AST: tensor = state.proxy_arg(0) assert isinstance(tensor, torch.Tensor) @@ -351,7 +351,7 @@ def _(value: _T) -> _T: raise NotImplementedError(f"Unsupported type for _new_var: {type(value)}") -@_decorators.codegen(_new_var) +@_decorators.codegen(_new_var, "triton") def _(state: CodegenState) -> ast.AST: value = state.ast_arg(0) assert isinstance(value, ast.AST) diff --git a/helion/language/atomic_ops.py b/helion/language/atomic_ops.py index 417daf25d..2aa16331b 100644 --- a/helion/language/atomic_ops.py +++ b/helion/language/atomic_ops.py @@ -263,7 +263,7 @@ def apply(t: torch.Tensor, idx_tuple: tuple, v: object) -> None: return prev -@_decorators.codegen(atomic_add) +@_decorators.codegen(atomic_add, "triton") def _(state: CodegenState) -> ast.AST: value_expr = state.ast_args[2] return _codegen_common("atomic_add", state, _to_ast_values([value_expr])) @@ -343,7 +343,7 @@ def _( return prev -@_decorators.codegen(atomic_xchg) +@_decorators.codegen(atomic_xchg, "triton") def _(state: CodegenState) -> ast.AST: value_expr = state.ast_args[2] return _codegen_common("atomic_xchg", state, _to_ast_values([value_expr])) @@ -420,7 +420,7 @@ def _( return prev -@_decorators.codegen(atomic_and) +@_decorators.codegen(atomic_and, "triton") def _(state: CodegenState) -> ast.AST: value_expr = state.ast_args[2] return _codegen_common("atomic_and", state, _to_ast_values([value_expr])) @@ -494,7 +494,7 @@ def _( return prev -@_decorators.codegen(atomic_or) +@_decorators.codegen(atomic_or, "triton") def _(state: CodegenState) -> ast.AST: value_expr = state.ast_args[2] return _codegen_common("atomic_or", state, _to_ast_values([value_expr])) @@ -568,7 +568,7 @@ def _( return prev -@_decorators.codegen(atomic_xor) +@_decorators.codegen(atomic_xor, "triton") def _(state: CodegenState) -> ast.AST: value_expr = state.ast_args[2] return _codegen_common("atomic_xor", state, _to_ast_values([value_expr])) @@ -634,7 +634,7 @@ def apply(t: torch.Tensor, idx: tuple, v: object) -> None: _ref_apply(target, index, apply, value) -@_decorators.codegen(atomic_max) +@_decorators.codegen(atomic_max, "triton") def _(state: CodegenState) -> ast.AST: value_expr = state.ast_args[2] return _codegen_common("atomic_max", state, _to_ast_values([value_expr])) @@ -709,7 +709,7 @@ def _( return prev -@_decorators.codegen(atomic_min) +@_decorators.codegen(atomic_min, "triton") def _(state: CodegenState) -> ast.AST: value_expr = state.ast_args[2] return _codegen_common("atomic_min", state, _to_ast_values([value_expr])) @@ 
-810,7 +810,7 @@ def _( return prev -@_decorators.codegen(atomic_cas) +@_decorators.codegen(atomic_cas, "triton") def _(state: CodegenState) -> ast.AST: exp_expr = state.ast_args[2] val_expr = state.ast_args[3] diff --git a/helion/language/constexpr.py b/helion/language/constexpr.py index 21d0de5ab..8c73171f5 100644 --- a/helion/language/constexpr.py +++ b/helion/language/constexpr.py @@ -94,7 +94,7 @@ def handle_symint(symint: torch.SymInt) -> int: return TypeInfo.from_example(specialized, origin=origin) -@_decorators.codegen(specialize) +@_decorators.codegen(specialize, "triton") def _(state: CodegenState) -> ast.AST: value = state.proxy_arg(0) specialized = _convert_specializable(value) diff --git a/helion/language/creation_ops.py b/helion/language/creation_ops.py index 52d897d22..ef12d3168 100644 --- a/helion/language/creation_ops.py +++ b/helion/language/creation_ops.py @@ -128,7 +128,7 @@ def _full_fake( ) -@_decorators.codegen(full) +@_decorators.codegen(full, "triton") def _full_codegen(state: CodegenState) -> ast.AST: fake_value = state.fake_value assert isinstance(fake_value, torch.Tensor) diff --git a/helion/language/debug_ops.py b/helion/language/debug_ops.py index bb900ac2f..849a40970 100644 --- a/helion/language/debug_ops.py +++ b/helion/language/debug_ops.py @@ -47,7 +47,7 @@ def _(*args: object, origin: Origin, **kwargs: object) -> TypeInfo: return LiteralType(origin, None) -@_decorators.codegen(breakpoint) +@_decorators.codegen(breakpoint, "triton") def _(state: CodegenState) -> None: state.add_statement("breakpoint()") diff --git a/helion/language/device_print.py b/helion/language/device_print.py index 99d9f356f..394f6c366 100644 --- a/helion/language/device_print.py +++ b/helion/language/device_print.py @@ -67,7 +67,7 @@ def _(*args: object, origin: Origin, **kwargs: object) -> TypeInfo: return NoType(origin) -@_decorators.codegen(device_print) +@_decorators.codegen(device_print, "triton") def _(state: CodegenState) -> None: prefix = state.proxy_arg(0) call_args: list[ast.AST] = [create(ast.Constant, value=prefix)] diff --git a/helion/language/inline_asm_ops.py b/helion/language/inline_asm_ops.py index 09923b932..dcb834a36 100644 --- a/helion/language/inline_asm_ops.py +++ b/helion/language/inline_asm_ops.py @@ -148,7 +148,7 @@ def _( return torch.empty(broadcast_shape, dtype=dtypes[0], device=env.device) -@_decorators.codegen(inline_asm_elementwise) +@_decorators.codegen(inline_asm_elementwise, "triton") def _(state: CodegenState) -> ast.AST | list[ast.AST]: # Get arguments asm_str = state.proxy_arg(0) diff --git a/helion/language/inline_triton_ops.py b/helion/language/inline_triton_ops.py index 76eac1763..c47219bbc 100644 --- a/helion/language/inline_triton_ops.py +++ b/helion/language/inline_triton_ops.py @@ -337,7 +337,7 @@ def _emit_output_assertions( ) -@_decorators.codegen(inline_triton) +@_decorators.codegen(inline_triton, "triton") def _(state: CodegenState) -> ast.AST | list[ast.AST]: triton_source = state.proxy_arg(0) args_obj = state.proxy_arg(1) diff --git a/helion/language/loops.py b/helion/language/loops.py index 304c6cfb7..506391850 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -460,7 +460,7 @@ def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool: return hint < get_max_y_grid() -@_decorators.codegen(tile) +@_decorators.codegen(tile, "triton") def _(state: CodegenState) -> ast.AST: return _codegen_loop_helper(state) @@ -753,7 +753,7 @@ def _( return IterType(origin, result) -@_decorators.codegen(grid) 
+@_decorators.codegen(grid, "triton") def _(state: CodegenState) -> ast.AST: return _codegen_loop_helper(state) diff --git a/helion/language/matmul_ops.py b/helion/language/matmul_ops.py index 11f6e5e4f..07a554ef3 100644 --- a/helion/language/matmul_ops.py +++ b/helion/language/matmul_ops.py @@ -207,7 +207,7 @@ def _( return torch.empty(result_shape, dtype=resolved_out_dtype, device=mat1.device) -@_decorators.codegen(dot) +@_decorators.codegen(dot, "triton") def _(state: CodegenState) -> object: # Get the AST representations of our arguments lhs_ast = state.ast_arg(0) diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py index cf281a30e..88bad4532 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ -86,7 +86,7 @@ def _( return None -@_decorators.codegen(store) +@_decorators.codegen(store, "triton") def _(state: CodegenState) -> ast.AST: tensor = state.proxy_arg(0) subscript = state.proxy_arg(1) @@ -245,7 +245,7 @@ def _( raise NotImplementedError(f"Unsupported tensor type: {type(tensor)}") -@_decorators.codegen(load) +@_decorators.codegen(load, "triton") def _(state: CodegenState) -> ast.AST: tensor = state.proxy_arg(0) subscript = state.proxy_arg(1) diff --git a/helion/language/random_ops.py b/helion/language/random_ops.py index bb1403c20..4b65b6154 100644 --- a/helion/language/random_ops.py +++ b/helion/language/random_ops.py @@ -69,7 +69,7 @@ def _rand_fake( ) -@_decorators.codegen(rand) +@_decorators.codegen(rand, "triton") def _rand_codegen(state: CodegenState) -> ast.AST: """ Generate tl.rand() code with global indices for deterministic RNG per element. diff --git a/helion/language/reduce_ops.py b/helion/language/reduce_ops.py index 27d7ef92f..0412dbe7a 100644 --- a/helion/language/reduce_ops.py +++ b/helion/language/reduce_ops.py @@ -451,7 +451,7 @@ def _( return _fake_reduce_tensor(input_tensor, dim, keep_dims) -@_decorators.codegen(_reduce) +@_decorators.codegen(_reduce, "triton") def _(state: CodegenState) -> ast.AST | list[ast.AST]: """Generate code for reduce with combine function.""" diff --git a/helion/language/scan_ops.py b/helion/language/scan_ops.py index 18dec8af6..554c6eb73 100644 --- a/helion/language/scan_ops.py +++ b/helion/language/scan_ops.py @@ -327,7 +327,7 @@ def _( return torch.empty_like(input_tensor) -@_decorators.codegen(_associative_scan) +@_decorators.codegen(_associative_scan, "triton") def _(state: CodegenState) -> ast.AST | list[ast.AST]: """Generate code for associative scan with combine function.""" diff --git a/helion/language/signal_wait.py b/helion/language/signal_wait.py index 9392f7158..43386daca 100644 --- a/helion/language/signal_wait.py +++ b/helion/language/signal_wait.py @@ -92,7 +92,7 @@ def _( return None -@_decorators.codegen(wait) +@_decorators.codegen(wait, "triton") def _(state: CodegenState) -> ast.AST: import ast @@ -256,7 +256,7 @@ def _( return signal_pad.new_empty(shape) -@_decorators.codegen(signal) +@_decorators.codegen(signal, "triton") def _(state: CodegenState) -> ast.AST: import ast diff --git a/helion/language/tile_ops.py b/helion/language/tile_ops.py index 43f9e5f6c..11bace3bd 100644 --- a/helion/language/tile_ops.py +++ b/helion/language/tile_ops.py @@ -53,7 +53,7 @@ def _(tile: torch.SymInt) -> torch.Tensor: return torch.empty([tile], dtype=env.settings.index_dtype, device=env.device) -@_decorators.codegen(tile_index) +@_decorators.codegen(tile_index, "triton") def _(state: CodegenState) -> ast.AST: index = _disable_flatten_get_tile(state.proxy_arg(0)) return 
expr_from_string(state.codegen.index_var(index)) @@ -97,7 +97,7 @@ def _disable_flatten_get_tile(tile: object) -> int: return index -@_decorators.codegen(tile_begin) +@_decorators.codegen(tile_begin, "triton") def _(state: CodegenState) -> ast.AST: index = _disable_flatten_get_tile(state.proxy_arg(0)) return expr_from_string(state.codegen.offset_var(index)) @@ -129,7 +129,7 @@ def _(tile: torch.SymInt) -> torch.SymInt: return result -@_decorators.codegen(tile_end) +@_decorators.codegen(tile_end, "triton") def _(state: CodegenState) -> ast.AST: index = _disable_flatten_get_tile(state.proxy_arg(0)) offset_var = state.codegen.offset_var(index) @@ -200,7 +200,7 @@ def _(tile: torch.SymInt) -> torch.SymInt: return result -@_decorators.codegen(tile_count) +@_decorators.codegen(tile_count, "triton") def _(state: CodegenState) -> ast.AST: index = _disable_flatten_get_tile(state.proxy_arg(0)) # Use device loop metadata to get end and block size @@ -245,7 +245,7 @@ def _(tile: torch.SymInt) -> torch.SymInt: return result -@_decorators.codegen(tile_id) +@_decorators.codegen(tile_id, "triton") def _(state: CodegenState) -> ast.AST: index = _disable_flatten_get_tile(state.proxy_arg(0)) offset = state.codegen.offset_var(index) diff --git a/helion/language/tunable_ops.py b/helion/language/tunable_ops.py index 4b60ec382..b030216ce 100644 --- a/helion/language/tunable_ops.py +++ b/helion/language/tunable_ops.py @@ -111,7 +111,7 @@ def _block_id_from_state(state: CodegenState) -> int: return block_id -@_decorators.codegen(register_block_size) +@_decorators.codegen(register_block_size, "triton") def _(state: CodegenState) -> ast.AST: env = CompileEnvironment.current() block_size = env.config_spec.block_sizes.config_get( @@ -176,7 +176,7 @@ def _register_tunable_type( return NumericType.subtype(python_type).new_unbacked(origin) -@_decorators.codegen(register_tunable) +@_decorators.codegen(register_tunable, "triton") def _register_tunable_codegen(state: CodegenState) -> ast.AST: name = state.proxy_arg(0) assert isinstance(name, str) diff --git a/helion/language/view_ops.py b/helion/language/view_ops.py index 7c6e52d88..863877be5 100644 --- a/helion/language/view_ops.py +++ b/helion/language/view_ops.py @@ -87,7 +87,7 @@ def _(tensor: torch.Tensor, index: list[object]) -> torch.Tensor: return tensor.new_empty(output_size) -@_decorators.codegen(subscript) +@_decorators.codegen(subscript, "triton") def _(state: CodegenState) -> ast.AST: output_keys = [] for val in state.proxy_arg(1): # pyright: ignore[reportGeneralTypeIssues] @@ -144,7 +144,7 @@ def _(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ) -@_decorators.codegen(split) +@_decorators.codegen(split, "triton") def _(state: CodegenState) -> list[ast.AST]: split_call = expr_from_string("tl.split({tensor})", tensor=state.ast_arg(0)) return [ @@ -192,7 +192,7 @@ def _(tensor0: torch.Tensor, tensor1: torch.Tensor) -> torch.Tensor: return tensor0.new_empty([*broadcast_shape, 2]) -@_decorators.codegen(join) +@_decorators.codegen(join, "triton") def _(state: CodegenState) -> ast.AST: return expr_from_string( "tl.join({tensor0}, {tensor1})",
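# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch). It mirrors the pattern this
# diff introduces: each API function now carries a dict of per-backend codegen
# callables instead of a single optional callable, registration happens via
# `_decorators.codegen(fn, backend)`, and dispatch looks the callable up by
# `CompileEnvironment.backend`, raising BackendImplementationMissing when the
# entry is absent. Everything below (`emit_triton`, `lower`, the "cpu"
# backend name) is a made-up stand-in for illustration, not helion's real API.
# ---------------------------------------------------------------------------
from typing import Callable


class BackendImplementationMissing(Exception):
    """Stand-in for helion.exc.BackendImplementationMissing."""

    def __init__(self, backend: str, detail: str) -> None:
        super().__init__(
            f"Backend {backend!r} is missing required implementation: {detail}"
        )


_codegen: dict[str, Callable[[], str]] = {}  # plays the role of api_func._codegen


def codegen(backend: str) -> Callable[[Callable[[], str]], Callable[[], str]]:
    """Mirror of the refactored decorator; the real _decorators.codegen also
    takes the API function as its first argument."""

    def _impl(codegen_fn: Callable[[], str]) -> Callable[[], str]:
        assert backend not in _codegen, (
            f"codegen already registered for backend {backend!r}"
        )
        _codegen[backend] = codegen_fn
        return codegen_fn

    return _impl


@codegen("triton")
def emit_triton() -> str:
    return "tl.load(...)"  # placeholder for a real Triton lowering


def lower(backend: str) -> str:
    """Mirror of the lookup added in generate_ast.py / inductor_lowering.py."""
    codegen_fn = _codegen.get(backend)
    if codegen_fn is None:
        raise BackendImplementationMissing(backend, "codegen for API function <fn>")
    return codegen_fn()


print(lower("triton"))  # -> tl.load(...)
try:
    lower("cpu")  # no "cpu" entry registered; in the patch itself only
except BackendImplementationMissing as err:  # "triton" is wired up so far
    print(err)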