From 4e4bd80663f12eb30103661f6ce5578e5dcc4c77 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Wed, 5 Nov 2025 14:48:45 -0800
Subject: [PATCH 1/3] Add tests for builtin min()/max() on symbolic ints

---
 test/test_misc.expected | 65 ++++++++++++++++++++++++++++++++++++++
 test/test_misc.py       | 70 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 135 insertions(+)

diff --git a/test/test_misc.expected b/test/test_misc.expected
index 1fc8e8a86..f8109a64a 100644
--- a/test/test_misc.expected
+++ b/test/test_misc.expected
@@ -1,6 +1,71 @@
 This file is automatically generated by assertExpectedJournal calls in test_misc.py.
 Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
 
+--- assertExpectedJournal(TestMisc.test_builtin_max)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_helion_max_kernel(x_c, out):
+    # src[test_misc.py:N]: for chunk in hl.grid(nchunks):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    # src[test_misc.py:N]: out[chunk] = x_c[last_idx // chunk_size, last_idx % chunk_size]
+    floordiv = (5 * (5 >= 2 * offset_0) + 2 * offset_0 * (2 * offset_0 > 5)) // 2
+    mod = (5 * (5 >= 2 * offset_0) + 2 * offset_0 * (2 * offset_0 > 5)) % 2
+    load = tl.load(x_c + (floordiv * 2 + mod * 1), None)
+    tl.store(out + offset_0 * 1, load, None)
+
+def helion_max_kernel(x_c, *, _launcher=_default_launcher):
+    # src[test_misc.py:N]: nchunks, chunk_size = x_c.shape
+    nchunks, chunk_size = x_c.shape
+    # src[test_misc.py:N]: out = torch.zeros(nchunks, dtype=x_c.dtype, device=x_c.device)
+    out = torch.zeros(nchunks, dtype=x_c.dtype, device=x_c.device)
+    # src[test_misc.py:N]: for chunk in hl.grid(nchunks):
+    # src[test_misc.py:N]: first_idx = chunk * chunk_size
+    # src[test_misc.py:N]: last_idx = max(first_idx, seqlen - 1)
+    # src[test_misc.py:N-N]: ...
+    _launcher(_helion_helion_max_kernel, (3,), x_c, out, num_warps=4, num_stages=1)
+    # src[test_misc.py:N]: return out
+    return out
+
+--- assertExpectedJournal(TestMisc.test_builtin_min)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime import triton_helpers
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_helion_min_kernel(x_c, out):
+    # src[test_misc.py:N]: for chunk in hl.grid(nchunks):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    # src[test_misc.py:N]: last_idx = min((chunk + 1) * chunk_size, seqlen) - 1
+    sub = -1 + (6 * (6 <= 2 + 2 * offset_0) + (2 + 2 * offset_0) * (2 + 2 * offset_0 < 6))
+    # src[test_misc.py:N]: out[chunk] = x_c[last_idx // chunk_size, last_idx % chunk_size]
+    floordiv = triton_helpers.div_floor_integer(-1 + (6 * (6 <= 2 + 2 * offset_0) + (2 + 2 * offset_0) * (2 + 2 * offset_0 < 6)), 2)
+    load = tl.load(x_c + (floordiv * 2 + 1 * 1), None)
+    tl.store(out + offset_0 * 1, load, None)
+
+def helion_min_kernel(x_c, *, _launcher=_default_launcher):
+    # src[test_misc.py:N]: nchunks, chunk_size = x_c.shape
+    nchunks, chunk_size = x_c.shape
+    # src[test_misc.py:N]: out = torch.zeros(nchunks, dtype=x_c.dtype, device=x_c.device)
+    out = torch.zeros(nchunks, dtype=x_c.dtype, device=x_c.device)
+    # src[test_misc.py:N]: for chunk in hl.grid(nchunks):
+    # src[test_misc.py:N]: last_idx = min((chunk + 1) * chunk_size, seqlen) - 1
+    # src[test_misc.py:N]: out[chunk] = x_c[last_idx // chunk_size, last_idx % chunk_size]
+    _launcher(_helion_helion_min_kernel, (3,), x_c, out, num_warps=4, num_stages=1)
+    # src[test_misc.py:N]: return out
+    return out
+
 --- assertExpectedJournal(TestMisc.test_inputs)
 from __future__ import annotations
 
diff --git a/test/test_misc.py b/test/test_misc.py
index d2d409124..f5bbdb961 100644
--- a/test/test_misc.py
+++ b/test/test_misc.py
@@ -28,6 +28,7 @@
 from helion._testing import code_and_output
 from helion._testing import import_path
 from helion._testing import skipIfCpu
+from helion._testing import skipIfPyTorchBaseVerLessThan
 from helion._testing import skipIfRefEager
 import helion.language as hl
 
@@ -581,6 +582,75 @@ def kernel(fn, t: torch.Tensor):
         )
         ast.parse(code)
 
+    @skipIfPyTorchBaseVerLessThan("2.10")
+    def test_builtin_min(self) -> None:
+        @helion.kernel(autotune_effort="none")
+        def helion_min_kernel(x_c):
+            nchunks, chunk_size = x_c.shape
+            chunk_size = hl.specialize(chunk_size)
+            seqlen = chunk_size * nchunks
+            out = torch.zeros(nchunks, dtype=x_c.dtype, device=x_c.device)
+            for chunk in hl.grid(nchunks):
+                last_idx = min((chunk + 1) * chunk_size, seqlen) - 1
+                out[chunk] = x_c[last_idx // chunk_size, last_idx % chunk_size]
+            return out
+
+        def ref_min(x):
+            nchunks, chunk_size = x.shape
+            chunk_size = int(chunk_size)
+            seqlen = chunk_size * nchunks
+            out = torch.zeros(nchunks, dtype=x.dtype, device=x.device)
+            for chunk in range(nchunks):
+                last_idx = min((chunk + 1) * chunk_size, seqlen) - 1
+                out[chunk] = x[last_idx // chunk_size, last_idx % chunk_size]
+            return out
+
+        nchunks, chunk_size = 3, 2
+        x = torch.arange(
+            nchunks * chunk_size, dtype=torch.float32, device=DEVICE
+        ).reshape(nchunks, chunk_size)
+
+        code, helion_out = code_and_output(helion_min_kernel, (x,))
+        ref_out = ref_min(x)
+
+        torch.testing.assert_close(helion_out, ref_out, rtol=1e-3, atol=1e-3)
+        self.assertExpectedJournal(code)
+
+    def test_builtin_max(self) -> None:
+        @helion.kernel(autotune_effort="none")
+        def helion_max_kernel(x_c):
+            nchunks, chunk_size = x_c.shape
+            chunk_size = hl.specialize(chunk_size)
+            seqlen = chunk_size * nchunks
+            out = torch.zeros(nchunks, dtype=x_c.dtype, device=x_c.device)
+            for chunk in hl.grid(nchunks):
+                first_idx = chunk * chunk_size
+                last_idx = max(first_idx, seqlen - 1)
+                out[chunk] = x_c[last_idx // chunk_size, last_idx % chunk_size]
+            return out
+
+        def ref_max(x):
+            nchunks, chunk_size = x.shape
+            chunk_size = int(chunk_size)
+            seqlen = chunk_size * nchunks
+            out = torch.zeros(nchunks, dtype=x.dtype, device=x.device)
+            for chunk in range(nchunks):
+                first_idx = chunk * chunk_size
+                last_idx = max(first_idx, seqlen - 1)
+                out[chunk] = x[last_idx // chunk_size, last_idx % chunk_size]
+            return out
+
+        nchunks, chunk_size = 3, 2
+        x = torch.arange(
+            nchunks * chunk_size, dtype=torch.float32, device=DEVICE
+        ).reshape(nchunks, chunk_size)
+
+        code, helion_out = code_and_output(helion_max_kernel, (x,))
+        ref_out = ref_max(x)
+
+        torch.testing.assert_close(helion_out, ref_out, rtol=1e-3, atol=1e-3)
+        self.assertExpectedJournal(code)
+
 
 instantiate_parametrized_tests(TestMisc)
 

From 39c206d4b3481fb94b6be9ccabd31f2e322521a9 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Wed, 5 Nov 2025 14:48:48 -0800
Subject: [PATCH 2/3] Support Python builtin min()/max() on symbolic ints

---
 helion/_compiler/compile_environment.py | 26 +++++-----
 helion/_testing.py                      | 22 +++++++++
 helion/language/__init__.py             |  2 +
 helion/language/builtin_ops.py          | 65 +++++++++++++++++++++++++
 4 files changed, 100 insertions(+), 15 deletions(-)
 create mode 100644 helion/language/builtin_ops.py

diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index d04e44bb1..392fb1e3a 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -450,22 +450,14 @@ def resolve_block_id(self, size: object) -> int | None:
         cannot resolve the identifier directly.
""" - if isinstance(size, (int, torch.SymInt, sympy.Expr)): - block_id = self.get_block_id(size) - if block_id is not None: - return block_id - else: - block_id = None + if not isinstance(size, (int, torch.SymInt, sympy.Expr)): + return None - if isinstance(size, torch.SymInt): - expr: sympy.Expr | None = size._sympy_() - elif isinstance(size, int): - expr = sympy.Integer(size) - elif isinstance(size, sympy.Expr): - expr = sympy.simplify(size) - else: - expr = None + block_id = self.get_block_id(size) + if block_id is not None: + return block_id + expr = _to_sympy(size) if expr is None or getattr(expr, "free_symbols", None): return None @@ -623,9 +615,13 @@ def warning(warning: exc.BaseWarning | type[exc.BaseWarning]) -> None: print(f"WARNING[{type(warning).__name__}]: {warning.args[0]}", file=sys.stderr) -def _to_sympy(x: int | torch.SymInt) -> sympy.Expr: +def _to_sympy(x: int | torch.SymInt | sympy.Expr) -> sympy.Expr: if isinstance(x, torch.SymInt): return x._sympy_() + if isinstance(x, int): + return sympy.Integer(x) + if isinstance(x, sympy.Expr): + return sympy.simplify(x) return sympy.sympify(x) diff --git a/helion/_testing.py b/helion/_testing.py index 87f459558..c5411d080 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -17,6 +17,7 @@ from typing import Generator import unittest +from packaging import version import pytest import torch from torch.utils._pytree import tree_map @@ -175,6 +176,27 @@ def skipIfPy314(reason: str) -> Callable[[Callable], Callable]: return unittest.skipIf(sys.version_info >= (3, 14), reason) +def skipIfPyTorchBaseVerLessThan(min_version: str) -> Callable[[Callable], Callable]: + """Skip test if PyTorch base version is less than the specified version. + + Uses the base version for comparison, which ignores pre-release/dev/post suffixes. + This allows development versions like "2.10.0.dev20251104" to pass when checking >= "2.10". + + Args: + min_version: Minimum required PyTorch version (e.g., "2.10") + + Returns: + Decorator that skips the test if PyTorch base version is below min_version + """ + current_version = version.parse(torch.__version__.split("+")[0]) + required_version = version.parse(min_version) + current_base = version.parse(current_version.base_version) + return unittest.skipIf( + current_base < required_version, + f"PyTorch version {min_version} or higher required", + ) + + @contextlib.contextmanager def track_run_ref_calls() -> Generator[list[int], None, None]: """Context manager that tracks BoundKernel.run_ref calls. diff --git a/helion/language/__init__.py b/helion/language/__init__.py index 00f002d55..c8324ebe5 100644 --- a/helion/language/__init__.py +++ b/helion/language/__init__.py @@ -8,6 +8,8 @@ from .atomic_ops import atomic_or as atomic_or from .atomic_ops import atomic_xchg as atomic_xchg from .atomic_ops import atomic_xor as atomic_xor +from .builtin_ops import _builtin_max as _builtin_max +from .builtin_ops import _builtin_min as _builtin_min from .constexpr import ConstExpr as constexpr # noqa: F401 from .constexpr import specialize as specialize from .creation_ops import arange as arange diff --git a/helion/language/builtin_ops.py b/helion/language/builtin_ops.py new file mode 100644 index 000000000..89d547549 --- /dev/null +++ b/helion/language/builtin_ops.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import builtins +from typing import TYPE_CHECKING + +import sympy + +from .._compiler.compile_environment import CompileEnvironment +from .._compiler.compile_environment import _to_sympy +from . 
import _decorators + +if TYPE_CHECKING: + import torch + + +def compute_symbolic_min_max( + args: tuple[int | torch.SymInt, ...], op: object +) -> torch.SymInt | int: + env = CompileEnvironment.current() + shape_env = env.shape_env + sympy_op = sympy.Min if op is builtins.min else sympy.Max + hint_fn = min if op is builtins.min else max + + expr = _to_sympy(args[0]) + hint = env.size_hint(args[0]) + + for arg in args[1:]: + rhs_expr = _to_sympy(arg) + rhs_hint = env.size_hint(arg) + expr = sympy_op(expr, rhs_expr) # type: ignore[call-arg] + hint = hint_fn(hint, rhs_hint) # type: ignore[arg-type] + + return shape_env.create_symintnode(expr, hint=hint) # type: ignore[return-value] + + +@_decorators.device_func_replacement(builtins.min) +def _builtin_min(*args: int | torch.SymInt) -> torch.SymInt | int: + """Device replacement for builtin min() that supports symbolic integers. + + Returns the minimum value among the provided arguments, preserving + symbolic integer expressions when present. + + Args: + *args: Integer arguments, which may be concrete ints or symbolic SymInts + + Returns: + The minimum value, as a SymInt if any argument is symbolic, otherwise int + """ + return compute_symbolic_min_max(args, op=builtins.min) + + +@_decorators.device_func_replacement(builtins.max) +def _builtin_max(*args: int | torch.SymInt) -> torch.SymInt | int: + """Device replacement for builtin max() that supports symbolic integers. + + Returns the maximum value among the provided arguments, preserving + symbolic integer expressions when present. + + Args: + *args: Integer arguments, which may be concrete ints or symbolic SymInts + + Returns: + The maximum value, as a SymInt if any argument is symbolic, otherwise int + """ + return compute_symbolic_min_max(args, op=builtins.max) From 7ffc45f10d29689a6c11b9f6e5bcd874db228bd6 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 6 Nov 2025 21:27:17 -0800 Subject: [PATCH 3/3] up --- helion/_compiler/compile_environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 392fb1e3a..8dcf59e6f 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -621,7 +621,7 @@ def _to_sympy(x: int | torch.SymInt | sympy.Expr) -> sympy.Expr: if isinstance(x, int): return sympy.Integer(x) if isinstance(x, sympy.Expr): - return sympy.simplify(x) + return x return sympy.sympify(x)
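
Reviewer note (not part of the patches): with PATCH 2/3 applied, Python's
builtin min()/max() inside device loops lower to sympy.Min/sympy.Max over
SymInts with a matching size hint, so the result stays symbolic. Below is a
minimal usage sketch mirroring test_builtin_min above; the kernel name is
hypothetical, it assumes a CUDA device, and it uses only the helion APIs
exercised by the new tests:

    import torch
    import helion
    import helion.language as hl

    @helion.kernel(autotune_effort="none")
    def last_per_chunk(x):  # hypothetical example kernel
        nchunks, chunk_size = x.shape
        chunk_size = hl.specialize(chunk_size)  # specialize on chunk_size
        seqlen = chunk_size * nchunks  # may be a SymInt
        out = torch.zeros(nchunks, dtype=x.dtype, device=x.device)
        for chunk in hl.grid(nchunks):
            # builtin min() now works on SymInt expressions in the device loop
            last_idx = min((chunk + 1) * chunk_size, seqlen) - 1
            out[chunk] = x[last_idx // chunk_size, last_idx % chunk_size]
        return out

    x = torch.arange(6, dtype=torch.float32, device="cuda").reshape(3, 2)
    print(last_per_chunk(x))  # tensor([1., 3., 5.], device='cuda:0')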