diff --git a/docs/api/language.md b/docs/api/language.md
index 1b8e327b4..412e86ab4 100644
--- a/docs/api/language.md
+++ b/docs/api/language.md
@@ -143,6 +143,36 @@ Executes target-specific inline assembly on elements of one or more tensors with
 
 Embeds small Triton code snippets directly inside a Helion kernel. Common indentation is removed automatically, placeholders are replaced using ``str.format`` with tuple or dict arguments, and the final line in the snippet becomes the return value. Provide tensors (or tuples of tensors) via ``output_like`` so Helion knows the type of the return value.
 
+### triton_kernel()
+
+```{eval-rst}
+.. autofunction:: triton_kernel
+```
+
+Defines a ``@triton.jit`` function once at module scope and calls it from Helion device code.
+
+- Accepts either:
+  - a source string containing a single Triton function definition,
+  - a function name string referring to a ``@triton.jit`` function in the kernel's module, or
+  - a Python function object (or a Triton ``JITFunction``, which is unwrapped via ``.fn``; see the second example below).
+- The function is emitted at module scope once and then invoked from the kernel body.
+- Pass ``output_like`` tensors to get the same shape/dtype checks as ``inline_triton``.
+
+Example (by name):
+
+```python
+@triton.jit
+def add_pairs(a, b):
+    return a + b
+
+@helion.kernel()
+def k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    out = torch.empty_like(x)
+    for tile in hl.tile(x.shape):
+        out[tile] = hl.triton_kernel("add_pairs", args=(x[tile], y[tile]), output_like=x[tile])
+    return out
+```
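+
+The function object can also be passed directly, and ``args`` may be a mapping to
+call with keyword arguments. A minimal sketch reusing ``add_pairs`` from above
+(the kernel name ``k2`` is illustrative only):
+
+```python
+@helion.kernel()
+def k2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    out = torch.empty_like(x)
+    for tile in hl.tile(x.shape):
+        # Helion unwraps the JITFunction to its Python function via .fn
+        out[tile] = hl.triton_kernel(
+            add_pairs,
+            args={"a": x[tile], "b": y[tile]},
+            output_like=x[tile],
+        )
+    return out
+```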
+
 ## Tensor Creation
 
 ### zeros()
diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index 09d63af11..0d01c0280 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -299,11 +299,15 @@ def to_fake(self, obj: object, origin: Origin) -> object:
         # Handle functions and Kernel objects
         from ..runtime.kernel import Kernel
 
-        if isinstance(obj, (types.FunctionType, Kernel)):
+        if isinstance(obj, (types.FunctionType, Kernel)) or hasattr(obj, "fn"):
             from .helper_function import extract_helper_function
             from .lift_closures import lift_closures
 
-            fn = extract_helper_function(obj)
+            # If a Triton JITFunction is passed, unwrap it to the underlying Python function
+            if hasattr(obj, "fn") and isinstance(obj.fn, types.FunctionType):
+                fn = obj.fn
+            else:
+                fn = extract_helper_function(obj)
             return lift_closures(fn, origin)
         # Handle GraphModule - treat it like a function
         if isinstance(obj, torch.fx.GraphModule):
diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py
index f07f41ded..00b8f8083 100644
--- a/helion/_compiler/generate_ast.py
+++ b/helion/_compiler/generate_ast.py
@@ -47,6 +47,7 @@ def __init__(self, func: HostFunction, config: Config) -> None:
         # Initialize our attributes
         self.host_function = func
         self.host_statements: list[ast.AST] = []
+        self.module_statements: list[ast.stmt] = []
         self.statements_stack: list[list[ast.AST]] = [self.host_statements]
         self.on_device = False
         self.active_device_loops: dict[int, list[DeviceLoopOrGridState]] = (
@@ -502,6 +503,7 @@ def generate_ast(
     result = ast.Module(
         [
             *func.codegen_imports(),
+            *codegen.module_statements,
             *codegen.device_function.codegen_helper_functions(),
             *kernel_def,
             host_def,
diff --git a/helion/language/__init__.py b/helion/language/__init__.py
index c8324ebe5..27b3582b7 100644
--- a/helion/language/__init__.py
+++ b/helion/language/__init__.py
@@ -19,6 +19,7 @@
 from .device_print import device_print as device_print
 from .inline_asm_ops import inline_asm_elementwise as inline_asm_elementwise
 from .inline_triton_ops import inline_triton as inline_triton
+from .inline_triton_ops import triton_kernel as triton_kernel
 from .loops import grid as grid
 from .loops import static_range as static_range
 from .loops import tile as tile
diff --git a/helion/language/inline_triton_ops.py b/helion/language/inline_triton_ops.py
index c47219bbc..a1d7baa5e 100644
--- a/helion/language/inline_triton_ops.py
+++ b/helion/language/inline_triton_ops.py
@@ -3,6 +3,7 @@
 import ast
 from collections.abc import Mapping
 from collections.abc import Sequence
+import inspect
 import textwrap
 from typing import TYPE_CHECKING
 from typing import TypeVar
@@ -17,14 +18,18 @@
 from .._compiler.ast_extension import create
 from .._compiler.ast_extension import expr_from_string
 from .._compiler.ast_extension import statement_from_string
+from .._compiler.host_function import HostFunction
+from .._compiler.output_header import SOURCE_MODULE
 from . import _decorators
 
 if TYPE_CHECKING:
+    from types import FunctionType
+
     from .._compiler.inductor_lowering import CodegenState
 
 _T = TypeVar("_T")
 
-__all__ = ["inline_triton"]
+__all__ = ["inline_triton", "triton_kernel"]
 
 
 @has_side_effect
@@ -266,6 +271,111 @@ def _collect_output_metadata(
     )
 
 
+def _ensure_triton_jit_decorator(func_def: ast.FunctionDef) -> ast.FunctionDef:
+    has_jit = any(
+        (isinstance(d, ast.Attribute) and d.attr == "jit")
+        or (isinstance(d, ast.Name) and d.id == "triton")
+        or (
+            isinstance(d, ast.Call)
+            and isinstance(d.func, ast.Attribute)
+            and d.func.attr == "jit"
+        )
+        for d in func_def.decorator_list
+    )
+    if has_jit:
+        return func_def
+    func_def.decorator_list.insert(0, cast("ast.expr", expr_from_string("triton.jit")))
+    return func_def
+
+
+def _get_or_add_triton_function_preamble(
+    state: CodegenState, triton_source_or_fn: object
+) -> str:
+    """
+    Parse a @triton.jit function definition from source, emit it once at module
+    scope, and return the (possibly renamed) function name to call.
+ """ + if isinstance(triton_source_or_fn, str): + candidate = textwrap.dedent(triton_source_or_fn).strip() + # If looks like a bare identifier (function name), resolve from kernel globals + if ( + candidate + and candidate.isidentifier() + and "\n" not in candidate + and "def " not in candidate + ): + hf = HostFunction.current() + # Ensure SOURCE_MODULE is registered + hf.global_scope_origin("") + module_obj = hf.global_imports[SOURCE_MODULE].value + module_scope = cast("dict[str, object]", module_obj.__dict__) + fn_obj = module_scope[candidate] + func_obj = fn_obj if inspect.isfunction(fn_obj) else fn_obj.fn # type: ignore[attr-defined] + func_obj_typed: FunctionType = cast("FunctionType", func_obj) + try: + src = textwrap.dedent(inspect.getsource(func_obj_typed)).strip() + except OSError as exc_value: + raise exc.InvalidAPIUsage( + f"Could not get source for Triton function '{candidate}': {exc_value}" + ) from exc_value + base_name_hint = func_obj_typed.__name__ + else: + src = candidate + base_name_hint = None + else: + # Expect a function object (already unwrapped by to_fake) + func_obj = triton_source_or_fn + func_obj_typed: FunctionType = cast("FunctionType", func_obj) + try: + src = textwrap.dedent(inspect.getsource(func_obj_typed)).strip() + except OSError as exc_value: + raise exc.InvalidAPIUsage( + f"Could not get source for Triton function: {exc_value}" + ) from exc_value + base_name_hint = func_obj_typed.__name__ + if not src: + raise exc.InvalidAPIUsage("triton_kernel source must contain a function") + + try: + module = ast.parse(src) + except SyntaxError as exc_value: + raise exc.InvalidAPIUsage( + f"Failed to parse triton_kernel source: {exc_value}" + ) from exc_value + + func_defs = [node for node in module.body if isinstance(node, ast.FunctionDef)] + if len(func_defs) != 1: + raise exc.InvalidAPIUsage( + f"triton_kernel expects exactly one function definition, found {len(func_defs)}" + ) + fn_def = cast("ast.FunctionDef", convert(func_defs[0])) + fn_def = _ensure_triton_jit_decorator(fn_def) + + # Cache to avoid duplicate definitions + cache_name = "_added_triton_kernel_defs" + added: dict[str, str] = getattr(state.device_function, cache_name, {}) + + # Use the function name plus source as key to avoid collisions on same name different body + # Use function name if available, else the parsed name + parsed_name = fn_def.name + if base_name_hint and isinstance(base_name_hint, str): + parsed_name = base_name_hint + + key = f"{parsed_name}:{src}" + if key in added: + return added[key] + + # Ensure uniqueness of function name in module scope + unique_name = state.device_function.new_var(parsed_name) + fn_def.name = unique_name + + # Define the Triton function at module scope to avoid nested jit def issues + state.codegen.module_statements.append(fn_def) + added[key] = unique_name + setattr(state.device_function, cache_name, added) + return unique_name + + def _emit_output_assertions( state: CodegenState, result_name: str, @@ -383,3 +493,105 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]: return [expr_from_string(f"{result_name}[{i}]") for i in range(len(dtypes))] return expr_from_string(result_name) + + +@_decorators.api(is_device_only=True, allow_host_tensor=True) +def triton_kernel( + triton_source_or_fn: object, + args: Sequence[object] | Mapping[str, object], + output_like: _T, +) -> _T: + """ + Define (once) and call a @triton.jit function from Helion device code. 
+
+    Args:
+        triton_source_or_fn: Source for a single @triton.jit function definition,
+            its name in the kernel's module, or a Python function object defining it.
+        args: Positional (sequence) or keyword (mapping) arguments to pass to the
+            Triton function; each is resolved to the matching Helion variable name.
+        output_like: Example tensor(s) describing the expected outputs for shape/dtype checks.
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.register_fake(triton_kernel)
+def _(
+    triton_source_or_fn: object,
+    args: object,
+    output_like: object,
+) -> object:
+    if not (
+        isinstance(triton_source_or_fn, str) or inspect.isfunction(triton_source_or_fn)
+    ):
+        raise exc.InvalidAPIUsage(
+            f"triton_kernel expects a string source or a function, got {type(triton_source_or_fn)}"
+        )
+    _validate_args(args)
+    return _fake_outputs(output_like)
+
+
+@_decorators.codegen(triton_kernel, "triton")
+def _(state: CodegenState) -> ast.AST | list[ast.AST]:
+    triton_source_or_fn = state.proxy_arg(0)
+    args_obj = state.proxy_arg(1)
+    output_like = state.proxy_arg(2)
+
+    if not (
+        isinstance(triton_source_or_fn, str) or inspect.isfunction(triton_source_or_fn)
+    ):
+        raise exc.InvalidAPIUsage(
+            f"triton_kernel expects a string source or a function, got {type(triton_source_or_fn)}"
+        )
+    _validate_args(args_obj)
+
+    # Emit the Triton function at module scope (once) and get the name to call
+    fn_name = _get_or_add_triton_function_preamble(state, triton_source_or_fn)
+
+    # Resolve argument names, mirroring inline_triton's placeholder handling
+    call_args_src = ""
+    if isinstance(state.ast_args[1], dict):
+        kw_pairs: list[str] = []
+        mapping = cast("Mapping[str, object]", args_obj)
+        for key, node in state.ast_args[1].items():
+            kw_pairs.append(f"{key}=" + _ensure_name(state, node, mapping[key]))
+        call_args_src = ", ".join(kw_pairs)
+    else:
+        if not isinstance(state.ast_args[1], (ast.List, ast.Tuple, list, tuple)):
+            raise exc.InvalidAPIUsage(
+                "triton_kernel expects a literal list/tuple for positional args"
+            )
+        arg_nodes = (
+            state.ast_args[1].elts
+            if isinstance(state.ast_args[1], (ast.List, ast.Tuple))
+            else list(state.ast_args[1])
+        )
+        names = [
+            _ensure_name(state, node, arg)
+            for node, arg in zip(
+                arg_nodes, cast("Sequence[object]", args_obj), strict=False
+            )
+        ]
+        call_args_src = ", ".join(names)
+
+    call_expr = expr_from_string(f"{fn_name}({call_args_src})")
+
+    if output_like is None:
+        state.add_statement(create(ast.Expr, value=call_expr))
+        return create(ast.Constant, value=None)
+
+    result_name = state.device_function.new_var("triton_kernel_result")
+    assign = create(
+        ast.Assign,
+        targets=[create(ast.Name, id=result_name, ctx=ast.Store())],
+        value=call_expr,
+    )
+    state.add_statement(assign)
+
+    dtypes, output_nodes, is_multi = _collect_output_metadata(
+        output_like, state.ast_args[2]
+    )
+    _emit_output_assertions(state, result_name, dtypes, output_nodes, is_multi)
+
+    if is_multi:
+        return [expr_from_string(f"{result_name}[{i}]") for i in range(len(dtypes))]
+    return expr_from_string(result_name)
diff --git a/test/test_triton_kernel.expected b/test/test_triton_kernel.expected
new file mode 100644
index 000000000..94e71d05f
--- /dev/null
+++ b/test/test_triton_kernel.expected
@@ -0,0 +1,160 @@
+This file is automatically generated by assertExpectedJournal calls in test_triton_kernel.py.
+Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
+
+--- assertExpectedJournal(TestTritonKernel.test_triton_kernel_multi_output)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import test.test_triton_kernel as _source_module
+
+@triton.jit
+def pairwise_ops(a, b):
+    # src[test_triton_kernel.py:N]: sum_out = torch.empty_like(x)
+    sum_val = a + b
+    # src[test_triton_kernel.py:N]: prod_out = torch.empty_like(x)
+    prod_val = a * b
+    # src[test_triton_kernel.py:N]: for tile in hl.tile(x.shape):
+    return (sum_val, prod_val)
+
+@triton.jit
+def _helion_k(x, y, sum_out, prod_out, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    # src[test_triton_kernel.py:N]: x_val = x[tile]
+    x_val = tl.load(x + indices_0 * 1, None)
+    # src[test_triton_kernel.py:N]: y_val = y[tile]
+    y_val = tl.load(y + indices_0 * 1, None)
+    # src[test_triton_kernel.py:N]: sum_val, prod_val = hl.triton_kernel(
+    # src[test_triton_kernel.py:N]:     "pairwise_ops",
+    # src[test_triton_kernel.py:N]:     args=(x_val, y_val),
+    # src[test_triton_kernel.py:N-N]: ...
+    triton_kernel_result = pairwise_ops(x_val, y_val)
+    tl.static_assert(len(triton_kernel_result) == 2, 'inline_triton expected 2 outputs')
+    tl.static_assert(triton_kernel_result[0].dtype == tl.float32, 'inline_triton output 0 dtype mismatch; expected torch.float32')
+    tl.static_assert(triton_kernel_result[0].shape == x_val.shape, 'inline_triton output 0 shape mismatch')
+    tl.static_assert(triton_kernel_result[1].dtype == tl.float32, 'inline_triton output 1 dtype mismatch; expected torch.float32')
+    tl.static_assert(triton_kernel_result[1].shape == x_val.shape, 'inline_triton output 1 shape mismatch')
+    sum_val = triton_kernel_result[0]
+    prod_val = triton_kernel_result[1]
+    # src[test_triton_kernel.py:N]: sum_out[tile] = sum_val
+    tl.store(sum_out + indices_0 * 1, sum_val, None)
+    # src[test_triton_kernel.py:N]: prod_out[tile] = prod_val
+    tl.store(prod_out + indices_0 * 1, prod_val, None)
+
+def k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_triton_kernel.py:N]: sum_out = torch.empty_like(x)
+    sum_out = torch.empty_like(x)
+    # src[test_triton_kernel.py:N]: prod_out = torch.empty_like(x)
+    prod_out = torch.empty_like(x)
+    # src[test_triton_kernel.py:N]: for tile in hl.tile(x.shape):
+    _BLOCK_SIZE_0 = 32
+    # src[test_triton_kernel.py:N]: for tile in hl.tile(x.shape):
+    # src[test_triton_kernel.py:N]:     x_val = x[tile]
+    # src[test_triton_kernel.py:N]:     y_val = y[tile]
+    # src[test_triton_kernel.py:N-N]: ...
+    _launcher(_helion_k, (triton.cdiv(64, _BLOCK_SIZE_0),), x, y, sum_out, prod_out, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
+    # src[test_triton_kernel.py:N]: return sum_out, prod_out
+    return (sum_out, prod_out)
+
+--- assertExpectedJournal(TestTritonKernel.test_triton_kernel_simple_add)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import test.test_triton_kernel as _source_module
+
+@triton.jit
+def add_pairs(a, b):
+    # src[test_triton_kernel.py:N]: out = torch.empty_like(x)
+    return a + b
+
+@triton.jit
+def _helion_triton_kernel_add_pairs(x, y, out, _BLOCK_SIZE_0: tl.constexpr):
+    # src[test_triton_kernel.py:N]: for tile in hl.tile(x.shape):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    # src[test_triton_kernel.py:N]: x_val = x[tile]
+    x_val = tl.load(x + indices_0 * 1, None)
+    # src[test_triton_kernel.py:N]: y_val = y[tile]
+    y_val = tl.load(y + indices_0 * 1, None)
+    # src[test_triton_kernel.py:N]: result = hl.triton_kernel("add_pairs", args=(x_val, y_val), output_like=x_val)
+    triton_kernel_result = add_pairs(x_val, y_val)
+    tl.static_assert(triton_kernel_result.dtype == tl.float32, 'inline_triton output dtype mismatch; expected torch.float32')
+    tl.static_assert(triton_kernel_result.shape == x_val.shape, 'inline_triton output shape mismatch')
+    # src[test_triton_kernel.py:N]: out[tile] = result
+    tl.store(out + indices_0 * 1, triton_kernel_result, None)
+
+def triton_kernel_add_pairs(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_triton_kernel.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_triton_kernel.py:N]: for tile in hl.tile(x.shape):
+    _BLOCK_SIZE_0 = 32
+    # src[test_triton_kernel.py:N]: for tile in hl.tile(x.shape):
+    # src[test_triton_kernel.py:N]:     x_val = x[tile]
+    # src[test_triton_kernel.py:N]:     y_val = y[tile]
+    # src[test_triton_kernel.py:N-N]: ...
+    _launcher(_helion_triton_kernel_add_pairs, (triton.cdiv(128, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
+    # src[test_triton_kernel.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestTritonKernel.test_triton_kernel_tl_ops)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import test.test_triton_kernel as _source_module
+
+@triton.jit
+def vector_mix_norm(a, b):
+    # src[test_triton_kernel.py:N]: out = torch.empty_like(x)
+    mixed = tl.where(a > b, a - b, a + b)
+    # src[test_triton_kernel.py:N]: for tile in hl.tile(x.shape):
+    l2sq = tl.sum(mixed * mixed)
+    # src[test_triton_kernel.py:N]: x_val = x[tile]
+    inv = tl.rsqrt(tl.maximum(l2sq, 1e-12))
+    # src[test_triton_kernel.py:N]: y_val = y[tile]
+    return mixed * inv
+
+@triton.jit
+def _helion_k(x, y, out, _BLOCK_SIZE_0: tl.constexpr):
+    # src[test_triton_kernel.py:N]: for tile in hl.tile(x.shape):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    # src[test_triton_kernel.py:N]: x_val = x[tile]
+    x_val = tl.load(x + indices_0 * 1, None)
+    # src[test_triton_kernel.py:N]: y_val = y[tile]
+    y_val = tl.load(y + indices_0 * 1, None)
+    # src[test_triton_kernel.py:N]: out[tile] = hl.triton_kernel(
+    # src[test_triton_kernel.py:N]:     "vector_mix_norm",
+    # src[test_triton_kernel.py:N]:     args=(x_val, y_val),
+    # src[test_triton_kernel.py:N-N]: ...
+    triton_kernel_result = vector_mix_norm(x_val, y_val)
+    tl.static_assert(triton_kernel_result.dtype == tl.float32, 'inline_triton output dtype mismatch; expected torch.float32')
+    tl.static_assert(triton_kernel_result.shape == x_val.shape, 'inline_triton output shape mismatch')
+    tl.store(out + indices_0 * 1, triton_kernel_result, None)
+
+def k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_triton_kernel.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_triton_kernel.py:N]: for tile in hl.tile(x.shape):
+    _BLOCK_SIZE_0 = 32
+    # src[test_triton_kernel.py:N]: for tile in hl.tile(x.shape):
+    # src[test_triton_kernel.py:N]:     x_val = x[tile]
+    # src[test_triton_kernel.py:N]:     y_val = y[tile]
+    # src[test_triton_kernel.py:N-N]: ...
+    _launcher(_helion_k, (triton.cdiv(96, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
+    # src[test_triton_kernel.py:N]: return out
+    return out
diff --git a/test/test_triton_kernel.py b/test/test_triton_kernel.py
new file mode 100644
index 000000000..06437d028
--- /dev/null
+++ b/test/test_triton_kernel.py
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+import helion
+from helion._testing import DEVICE
+from helion._testing import RefEagerTestDisabled
+from helion._testing import TestCase
+from helion._testing import code_and_output
+import helion.language as hl
+
+
+@triton.jit
+def add_pairs(a, b):
+    return a + b
+
+
+@triton.jit
+def pairwise_ops(a, b):
+    sum_val = a + b
+    prod_val = a * b
+    return sum_val, prod_val
+
+
+@triton.jit
+def vector_mix_norm(a, b):
+    mixed = tl.where(a > b, a - b, a + b)
+    l2sq = tl.sum(mixed * mixed)
+    inv = tl.rsqrt(tl.maximum(l2sq, 1e-12))
+    return mixed * inv
+
+
+@helion.kernel(autotune_effort="none")
+def triton_kernel_add_pairs(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    out = torch.empty_like(x)
+    for tile in hl.tile(x.shape):
+        x_val = x[tile]
+        y_val = y[tile]
+        # Pass by function name string to avoid closures in kernels
+        result = hl.triton_kernel("add_pairs", args=(x_val, y_val), output_like=x_val)
+        out[tile] = result
+    return out
+
+
+class TestTritonKernel(RefEagerTestDisabled, TestCase):
+    def test_triton_kernel_simple_add(self) -> None:
+        x = torch.randn(128, device=DEVICE, dtype=torch.float32)
+        y = torch.randn_like(x)
+        code, result = code_and_output(triton_kernel_add_pairs, (x, y))
+        self.assertIn("@triton.jit", code)
+        self.assertIn("add_pairs", code)
+        torch.testing.assert_close(result, x + y)
+        self.assertExpectedJournal(code)
+
+    def test_triton_kernel_multi_output(self) -> None:
+        @helion.kernel(autotune_effort="none")
+        def k(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+            sum_out = torch.empty_like(x)
+            prod_out = torch.empty_like(x)
+            for tile in hl.tile(x.shape):
+                x_val = x[tile]
+                y_val = y[tile]
+                sum_val, prod_val = hl.triton_kernel(
+                    "pairwise_ops",
+                    args=(x_val, y_val),
+                    output_like=(x_val, x_val),
+                )
+                sum_out[tile] = sum_val
+                prod_out[tile] = prod_val
+            return sum_out, prod_out
+
+        x = torch.randn(64, device=DEVICE, dtype=torch.float32)
+        y = torch.randn_like(x)
+        code, (sum_result, prod_result) = code_and_output(k, (x, y))
+        self.assertIn("pairwise_ops", code)
+        torch.testing.assert_close(sum_result, x + y)
+        torch.testing.assert_close(prod_result, x * y)
+        self.assertExpectedJournal(code)
+
+    def test_triton_kernel_tl_ops(self) -> None:
+        @helion.kernel(autotune_effort="none")
+        def k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.shape):
+                x_val = x[tile]
+                y_val = y[tile]
+                out[tile] = hl.triton_kernel(
+                    "vector_mix_norm",
+                    args=(x_val, y_val),
+                    output_like=x_val,
+                )
+            return out
+
+        x = torch.randn(96, device=DEVICE, dtype=torch.float32)
+        y = torch.randn_like(x)
+        code, result = code_and_output(k, (x, y))
+        self.assertIn("vector_mix_norm", code)
+
+        bs = 32
+        expected = torch.empty_like(x)
+        for i in range(0, x.numel(), bs):
+            xa = x[i : i + bs]
+            ya = y[i : i + bs]
+            mixed = torch.where(xa > ya, xa - ya, xa + ya)
+            l2sq = torch.sum(mixed * mixed)
+            inv = torch.rsqrt(
+                torch.maximum(l2sq, torch.tensor(1e-12, device=DEVICE, dtype=x.dtype))
+            )
+            expected[i : i + bs] = mixed * inv
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
+
+
+if __name__ == "__main__":
+    import unittest
+
+    unittest.main()