26 changes: 11 additions & 15 deletions helion/_compiler/compile_environment.py
@@ -450,22 +450,14 @@ def resolve_block_id(self, size: object) -> int | None:
cannot resolve the identifier directly.
"""

if isinstance(size, (int, torch.SymInt, sympy.Expr)):
block_id = self.get_block_id(size)
if block_id is not None:
return block_id
else:
block_id = None
if not isinstance(size, (int, torch.SymInt, sympy.Expr)):
return None

if isinstance(size, torch.SymInt):
expr: sympy.Expr | None = size._sympy_()
elif isinstance(size, int):
expr = sympy.Integer(size)
elif isinstance(size, sympy.Expr):
expr = sympy.simplify(size)
else:
expr = None
block_id = self.get_block_id(size)
if block_id is not None:
return block_id

expr = _to_sympy(size)
if expr is None or getattr(expr, "free_symbols", None):
return None

@@ -623,9 +615,13 @@ def warning(warning: exc.BaseWarning | type[exc.BaseWarning]) -> None:
print(f"WARNING[{type(warning).__name__}]: {warning.args[0]}", file=sys.stderr)


def _to_sympy(x: int | torch.SymInt) -> sympy.Expr:
def _to_sympy(x: int | torch.SymInt | sympy.Expr) -> sympy.Expr:
if isinstance(x, torch.SymInt):
return x._sympy_()
if isinstance(x, int):
return sympy.Integer(x)
if isinstance(x, sympy.Expr):
return x
return sympy.sympify(x)


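Reviewer note: a minimal, runnable sketch (under an assumed name, to_sympy_sketch; not the PR code itself) of the normalization the refactor centralizes in _to_sympy, plus the free-symbols check that resolve_block_id now applies. The torch.SymInt branch is omitted so the snippet needs only sympy:

import sympy

def to_sympy_sketch(x):
    # Mirrors _to_sympy's dispatch for the non-SymInt cases.
    if isinstance(x, int):
        return sympy.Integer(x)
    if isinstance(x, sympy.Expr):
        return x  # already symbolic; returned unchanged
    return sympy.sympify(x)

n = sympy.Symbol("n")
assert to_sympy_sketch(4) == sympy.Integer(4)
assert to_sympy_sketch(n + 1) == n + 1
# resolve_block_id gives up (returns None) when free symbols remain,
# i.e. when the size cannot be reduced to a concrete value:
assert (n + 1).free_symbols        # non-empty set -> bail out
assert not sympy.Integer(4).free_symbols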
22 changes: 22 additions & 0 deletions helion/_testing.py
@@ -17,6 +17,7 @@
from typing import Generator
import unittest

from packaging import version
import pytest
import torch
from torch.utils._pytree import tree_map
@@ -175,6 +176,27 @@ def skipIfPy314(reason: str) -> Callable[[Callable], Callable]:
return unittest.skipIf(sys.version_info >= (3, 14), reason)


def skipIfPyTorchBaseVerLessThan(min_version: str) -> Callable[[Callable], Callable]:
"""Skip test if PyTorch base version is less than the specified version.

Uses the base version for comparison, which ignores pre-release/dev/post suffixes.
This allows development versions like "2.10.0.dev20251104" to pass when checking >= "2.10".

Args:
min_version: Minimum required PyTorch version (e.g., "2.10")

Returns:
Decorator that skips the test if PyTorch base version is below min_version
"""
current_version = version.parse(torch.__version__.split("+")[0])
required_version = version.parse(min_version)
current_base = version.parse(current_version.base_version)
return unittest.skipIf(
current_base < required_version,
f"PyTorch version {min_version} or higher required",
)


@contextlib.contextmanager
def track_run_ref_calls() -> Generator[list[int], None, None]:
"""Context manager that tracks BoundKernel.run_ref calls.
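Reviewer note: a small sketch of the packaging.version comparison the new helper performs. Comparing on base_version strips pre-release/dev/post suffixes, which is why a nightly such as "2.10.0.dev20251104" counts as satisfying ">= 2.10":

from packaging import version

current = version.parse("2.10.0.dev20251104".split("+")[0])
required = version.parse("2.10")
# base_version is "2.10.0", so the .dev suffix is ignored:
assert version.parse(current.base_version) >= required
# A direct comparison would wrongly skip the test, since dev releases
# sort before the corresponding final release under PEP 440:
assert current < version.parse("2.10.0")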
2 changes: 2 additions & 0 deletions helion/language/__init__.py
@@ -8,6 +8,8 @@
from .atomic_ops import atomic_or as atomic_or
from .atomic_ops import atomic_xchg as atomic_xchg
from .atomic_ops import atomic_xor as atomic_xor
from .builtin_ops import _builtin_max as _builtin_max
from .builtin_ops import _builtin_min as _builtin_min
from .constexpr import ConstExpr as constexpr # noqa: F401
from .constexpr import specialize as specialize
from .creation_ops import arange as arange
65 changes: 65 additions & 0 deletions helion/language/builtin_ops.py
@@ -0,0 +1,65 @@
from __future__ import annotations

import builtins
from typing import TYPE_CHECKING

import sympy

from .._compiler.compile_environment import CompileEnvironment
from .._compiler.compile_environment import _to_sympy
from . import _decorators

if TYPE_CHECKING:
import torch


def compute_symbolic_min_max(
args: tuple[int | torch.SymInt, ...], op: object
) -> torch.SymInt | int:
env = CompileEnvironment.current()
shape_env = env.shape_env
sympy_op = sympy.Min if op is builtins.min else sympy.Max
hint_fn = min if op is builtins.min else max

expr = _to_sympy(args[0])
hint = env.size_hint(args[0])

for arg in args[1:]:
rhs_expr = _to_sympy(arg)
rhs_hint = env.size_hint(arg)
expr = sympy_op(expr, rhs_expr) # type: ignore[call-arg]
hint = hint_fn(hint, rhs_hint) # type: ignore[arg-type]

return shape_env.create_symintnode(expr, hint=hint) # type: ignore[return-value]


@_decorators.device_func_replacement(builtins.min)
def _builtin_min(*args: int | torch.SymInt) -> torch.SymInt | int:
"""Device replacement for builtin min() that supports symbolic integers.

Returns the minimum value among the provided arguments, preserving
symbolic integer expressions when present.

Args:
*args: Integer arguments, which may be concrete ints or symbolic SymInts

Returns:
The minimum value, as a SymInt if any argument is symbolic, otherwise int
"""
return compute_symbolic_min_max(args, op=builtins.min)


@_decorators.device_func_replacement(builtins.max)
def _builtin_max(*args: int | torch.SymInt) -> torch.SymInt | int:
"""Device replacement for builtin max() that supports symbolic integers.

Returns the maximum value among the provided arguments, preserving
symbolic integer expressions when present.

Args:
*args: Integer arguments, which may be concrete ints or symbolic SymInts

Returns:
The maximum value, as a SymInt if any argument is symbolic, otherwise int
"""
return compute_symbolic_min_max(args, op=builtins.max)
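Reviewer note: the sympy half of compute_symbolic_min_max, sketched without the ShapeEnv dependency (shape_env.create_symintnode needs a live tracing context, so only the pairwise fold of expressions and hints is shown; the symbol s and the constants stand in for the test's chunk/seqlen values):

import sympy

s = sympy.Symbol("s", integer=True, positive=True)

# min((chunk + 1) * chunk_size, seqlen) with chunk symbolic: fold the
# expressions with sympy.Min and the integer size hints with the real
# builtin min, pairwise over the arguments.
expr = sympy.Min(2 * (s + 1), 6)
hint = min(2 * (0 + 1), 6)  # hint when s has hint 0
print(expr)  # an unevaluated Min, e.g. Min(6, 2*s + 2)
print(hint)  # 2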
65 changes: 65 additions & 0 deletions test/test_misc.expected
@@ -1,6 +1,71 @@
This file is automatically generated by assertExpectedJournal calls in test_misc.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

--- assertExpectedJournal(TestMisc.test_builtin_max)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_helion_max_kernel(x_c, out):
# src[test_misc.py:N]: for chunk in hl.grid(nchunks):
pid_0 = tl.program_id(0)
offset_0 = pid_0
# src[test_misc.py:N]: out[chunk] = x_c[last_idx // chunk_size, last_idx % chunk_size]
floordiv = (5 * (5 >= 2 * offset_0) + 2 * offset_0 * (2 * offset_0 > 5)) // 2
mod = (5 * (5 >= 2 * offset_0) + 2 * offset_0 * (2 * offset_0 > 5)) % 2
load = tl.load(x_c + (floordiv * 2 + mod * 1), None)
tl.store(out + offset_0 * 1, load, None)

def helion_max_kernel(x_c, *, _launcher=_default_launcher):
# src[test_misc.py:N]: nchunks, chunk_size = x_c.shape
nchunks, chunk_size = x_c.shape
# src[test_misc.py:N]: out = torch.zeros(nchunks, dtype=x_c.dtype, device=x_c.device)
out = torch.zeros(nchunks, dtype=x_c.dtype, device=x_c.device)
# src[test_misc.py:N]: for chunk in hl.grid(nchunks):
# src[test_misc.py:N]: first_idx = chunk * chunk_size
# src[test_misc.py:N]: last_idx = max(first_idx, seqlen - 1)
# src[test_misc.py:N-N]: ...
_launcher(_helion_helion_max_kernel, (3,), x_c, out, num_warps=4, num_stages=1)
# src[test_misc.py:N]: return out
return out

--- assertExpectedJournal(TestMisc.test_builtin_min)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_helion_min_kernel(x_c, out):
# src[test_misc.py:N]: for chunk in hl.grid(nchunks):
pid_0 = tl.program_id(0)
offset_0 = pid_0
# src[test_misc.py:N]: last_idx = min((chunk + 1) * chunk_size, seqlen) - 1
sub = -1 + (6 * (6 <= 2 + 2 * offset_0) + (2 + 2 * offset_0) * (2 + 2 * offset_0 < 6))
# src[test_misc.py:N]: out[chunk] = x_c[last_idx // chunk_size, last_idx % chunk_size]
floordiv = triton_helpers.div_floor_integer(-1 + (6 * (6 <= 2 + 2 * offset_0) + (2 + 2 * offset_0) * (2 + 2 * offset_0 < 6)), 2)
load = tl.load(x_c + (floordiv * 2 + 1 * 1), None)
tl.store(out + offset_0 * 1, load, None)

def helion_min_kernel(x_c, *, _launcher=_default_launcher):
# src[test_misc.py:N]: nchunks, chunk_size = x_c.shape
nchunks, chunk_size = x_c.shape
# src[test_misc.py:N]: out = torch.zeros(nchunks, dtype=x_c.dtype, device=x_c.device)
out = torch.zeros(nchunks, dtype=x_c.dtype, device=x_c.device)
# src[test_misc.py:N]: for chunk in hl.grid(nchunks):
# src[test_misc.py:N]: last_idx = min((chunk + 1) * chunk_size, seqlen) - 1
# src[test_misc.py:N]: out[chunk] = x_c[last_idx // chunk_size, last_idx % chunk_size]
_launcher(_helion_helion_min_kernel, (3,), x_c, out, num_warps=4, num_stages=1)
# src[test_misc.py:N]: return out
return out

--- assertExpectedJournal(TestMisc.test_inputs)
from __future__ import annotations

70 changes: 70 additions & 0 deletions test/test_misc.py
@@ -28,6 +28,7 @@
from helion._testing import code_and_output
from helion._testing import import_path
from helion._testing import skipIfCpu
from helion._testing import skipIfPyTorchBaseVerLessThan
from helion._testing import skipIfRefEager
import helion.language as hl

@@ -581,6 +582,75 @@ def kernel(fn, t: torch.Tensor):
)
ast.parse(code)

@skipIfPyTorchBaseVerLessThan("2.10")
@yf225 (Contributor Author) commented on Nov 5, 2025:

PyTorch 2.9 has an indexing-expression bug that is only fixed in the PyTorch nightly: pytorch/pytorch#131761. Patching around it in the Helion framework would be quite complicated, so we skip this unit test unless we are on 2.10 or later (including the current nightly).

def test_builtin_min(self) -> None:
@helion.kernel(autotune_effort="none")
def helion_min_kernel(x_c):
nchunks, chunk_size = x_c.shape
chunk_size = hl.specialize(chunk_size)
seqlen = chunk_size * nchunks
out = torch.zeros(nchunks, dtype=x_c.dtype, device=x_c.device)
for chunk in hl.grid(nchunks):
last_idx = min((chunk + 1) * chunk_size, seqlen) - 1
out[chunk] = x_c[last_idx // chunk_size, last_idx % chunk_size]
return out

def ref_min(x):
nchunks, chunk_size = x.shape
chunk_size = int(chunk_size)
seqlen = chunk_size * nchunks
out = torch.zeros(nchunks, dtype=x.dtype, device=x.device)
for chunk in range(nchunks):
last_idx = min((chunk + 1) * chunk_size, seqlen) - 1
out[chunk] = x[last_idx // chunk_size, last_idx % chunk_size]
return out

nchunks, chunk_size = 3, 2
x = torch.arange(
nchunks * chunk_size, dtype=torch.float32, device=DEVICE
).reshape(nchunks, chunk_size)

code, helion_out = code_and_output(helion_min_kernel, (x,))
ref_out = ref_min(x)

torch.testing.assert_close(helion_out, ref_out, rtol=1e-3, atol=1e-3)
self.assertExpectedJournal(code)

def test_builtin_max(self) -> None:
@helion.kernel(autotune_effort="none")
def helion_max_kernel(x_c):
nchunks, chunk_size = x_c.shape
chunk_size = hl.specialize(chunk_size)
seqlen = chunk_size * nchunks
out = torch.zeros(nchunks, dtype=x_c.dtype, device=x_c.device)
for chunk in hl.grid(nchunks):
first_idx = chunk * chunk_size
last_idx = max(first_idx, seqlen - 1)
out[chunk] = x_c[last_idx // chunk_size, last_idx % chunk_size]
return out

def ref_max(x):
nchunks, chunk_size = x.shape
chunk_size = int(chunk_size)
seqlen = chunk_size * nchunks
out = torch.zeros(nchunks, dtype=x.dtype, device=x.device)
for chunk in range(nchunks):
first_idx = chunk * chunk_size
last_idx = max(first_idx, seqlen - 1)
out[chunk] = x[last_idx // chunk_size, last_idx % chunk_size]
return out

nchunks, chunk_size = 3, 2
x = torch.arange(
nchunks * chunk_size, dtype=torch.float32, device=DEVICE
).reshape(nchunks, chunk_size)

code, helion_out = code_and_output(helion_max_kernel, (x,))
ref_out = ref_max(x)

torch.testing.assert_close(helion_out, ref_out, rtol=1e-3, atol=1e-3)
self.assertExpectedJournal(code)


instantiate_parametrized_tests(TestMisc)

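Reviewer note: a worked example of the index arithmetic both new tests exercise, with the shapes the tests use (nchunks=3, chunk_size=2, so seqlen=6). The divmod pairs match the floordiv/mod expressions visible in the generated Triton above:

nchunks, chunk_size = 3, 2
seqlen = nchunks * chunk_size
for chunk in range(nchunks):
    last_idx = min((chunk + 1) * chunk_size, seqlen) - 1
    print(chunk, last_idx, divmod(last_idx, chunk_size))
# prints:
# 0 1 (0, 1)
# 1 3 (1, 1)
# 2 5 (2, 1)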