From 4e27dfaea2973504fdf0f035bee7ab554b973d2d Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 20 Nov 2025 17:05:17 -0800 Subject: [PATCH 1/2] Add regression test for loop offsets hoisted out of loops (#1155) --- test/test_misc.py | 51 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/test/test_misc.py b/test/test_misc.py index f5bbdb961..909e0709a 100644 --- a/test/test_misc.py +++ b/test/test_misc.py @@ -54,6 +54,57 @@ def kernel_with_duplicate_refs(x: torch.Tensor) -> torch.Tensor: code, result = code_and_output(kernel_with_duplicate_refs, (x,)) torch.testing.assert_close(result, expected) + @skipIfRefEager("block_size=1 doesn't work in ref eager mode") + def test_min_hoist(self): + """Test case to reproduce issue #1155: offsets are hoisted out of loops""" + + @helion.kernel(autotune_effort="none") + def kernel( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + batch, seqlen, nheads = g.shape + dstate = u.shape[-1] + chunk_size = hl.specialize(chunk_size) + nchunks = (seqlen + chunk_size - 1) // chunk_size + out = torch.empty( + (batch, nchunks, nheads, dstate), device=g.device, dtype=g.dtype + ) + block_v = hl.register_block_size(dstate) + for tile_b, tile_h, tile_v in hl.tile( + [batch, nheads, dstate], block_size=[1, 1, block_v] + ): + for t_i in hl.tile(seqlen, block_size=chunk_size): + last = min(t_i.begin + chunk_size - 1, seqlen - 1) + g_scalar = g[tile_b.begin, last, tile_h.begin] + out[tile_b.begin, t_i.id, tile_h.begin, tile_v] = ( + g_scalar + hl.zeros([tile_v], dtype=g.dtype) + ) + return out + + batch, seqlen, nheads, dhead, dstate = 1, 10, 1, 1, 2 + chunk_size = 4 + k = torch.zeros( + batch, seqlen, nheads, dhead, device=DEVICE, dtype=torch.float32 + ) + w = torch.zeros_like(k) + u = torch.zeros( + batch, seqlen, nheads, dstate, device=DEVICE, dtype=torch.float32 + ) + g = torch.arange(seqlen, device=DEVICE, dtype=torch.float32).view( + batch, seqlen, nheads + ) + + expected = torch.tensor( +
[[[[3, 3]], [[7, 7]], [[9, 9]]]], device=DEVICE, dtype=torch.float32 + ) + + result = kernel(k, w, u, g, chunk_size) + torch.testing.assert_close(result, expected) + def test_torch_alloc(self): @helion.kernel(config={"block_sizes": [64, 64]}) def fn(x: torch.Tensor) -> torch.Tensor: From 0e049bdd9e9ba94c28aef6b9121b31a83e390f01 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 20 Nov 2025 17:05:23 -0800 Subject: [PATCH 2/2] Resolve GridOrigin in get_block_id; substitute block sizes only for BlockSizeOrigin --- helion/_compiler/compile_environment.py | 5 +++-- helion/_compiler/device_function.py | 10 +++++++--- helion/language/_tracing_ops.py | 10 +++++++--- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 04159bf97..8324df313 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -25,6 +25,7 @@ from .source_location import SourceLocation from .source_location import current_location from .variable_origin import BlockSizeOrigin +from .variable_origin import GridOrigin from .variable_origin import Origin if TYPE_CHECKING: @@ -453,7 +454,7 @@ def get_block_id(self, size: int | torch.SymInt | sympy.Basic) -> int | None: Get the block ID associated with a given size expression. This method determines if a size expression corresponds to a registered block size - in the current compilation environment. It looks up the origin information of + or grid index in the current compilation environment. It looks up the origin information of symbolic expressions to find their associated block IDs. 
Args: @@ -470,7 +471,7 @@ def get_block_id(self, size: int | torch.SymInt | sympy.Basic) -> int | None: origin_info = HostFunction.current().expr_to_origin.get(size) if origin_info is not None and isinstance( origin_info.origin, - BlockSizeOrigin, + (BlockSizeOrigin, GridOrigin), ): return origin_info.origin.block_id return None diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 95d48ab3e..94e54e087 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -414,12 +414,16 @@ def _lift_sympy_arg(self, expr: sympy.Expr) -> str: def user_sympy_expr(self, expr: sympy.Expr) -> str: """A sympy expression that flows into user computations.""" + expr_to_origin = HostFunction.current().expr_to_origin replacements = {} for sym in sorted(expr.free_symbols, key=lambda s: s.name): assert isinstance(sym, sympy.Symbol) - block_idx = CompileEnvironment.current().get_block_id(sym) - if block_idx is not None: - replacements[sym] = self.tile_strategy.user_size(block_idx) + origin_info = expr_to_origin.get(sym) + if origin_info is None: + continue + origin = origin_info.origin + if isinstance(origin, BlockSizeOrigin): + replacements[sym] = self.tile_strategy.user_size(origin.block_id) if replacements: # pyrefly: ignore [bad-assignment] expr = expr.xreplace(replacements) diff --git a/helion/language/_tracing_ops.py b/helion/language/_tracing_ops.py index a4ddae9a8..e77f2aba9 100644 --- a/helion/language/_tracing_ops.py +++ b/helion/language/_tracing_ops.py @@ -16,6 +16,7 @@ from .._compiler.ast_extension import statement_from_string from .._compiler.compile_environment import CompileEnvironment from .._compiler.host_function import HostFunction +from .._compiler.variable_origin import BlockSizeOrigin from ..exc import NotInsideKernel from . 
import _decorators from .tile_proxy import Tile @@ -50,13 +51,16 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(str(val)) assert isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)), val + sym_expr = val._sympy_() + origin_info = HostFunction.current().expr_to_origin.get(sym_expr) # pyrefly: ignore [bad-argument-type] - if (block_idx := CompileEnvironment.current().get_block_id(val)) is not None: - block_size_var = state.device_function.block_size_var(block_idx) + if origin_info is not None and isinstance(origin_info.origin, BlockSizeOrigin): + block_size_var = state.device_function.block_size_var( + origin_info.origin.block_id + ) if block_size_var is None: return expr_from_string("1") return expr_from_string(block_size_var) - sym_expr = val._sympy_() return state.codegen.lift_symnode( expr_from_string(state.sympy_expr(sym_expr)), sym_expr,