5 changes: 3 additions & 2 deletions helion/_compiler/compile_environment.py
@@ -25,6 +25,7 @@
 from .source_location import SourceLocation
 from .source_location import current_location
 from .variable_origin import BlockSizeOrigin
+from .variable_origin import GridOrigin
 from .variable_origin import Origin
 
 if TYPE_CHECKING:
@@ -453,7 +454,7 @@ def get_block_id(self, size: int | torch.SymInt | sympy.Basic) -> int | None:
         Get the block ID associated with a given size expression.
 
         This method determines if a size expression corresponds to a registered block size
-        in the current compilation environment. It looks up the origin information of
+        or grid index in the current compilation environment. It looks up the origin information of
         symbolic expressions to find their associated block IDs.
 
         Args:
@@ -470,7 +471,7 @@ def get_block_id(self, size: int | torch.SymInt | sympy.Basic) -> int | None:
         origin_info = HostFunction.current().expr_to_origin.get(size)
         if origin_info is not None and isinstance(
             origin_info.origin,
-            BlockSizeOrigin,
+            (BlockSizeOrigin, GridOrigin),
         ):
             return origin_info.origin.block_id
         return None
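After this change, `get_block_id` resolves a symbol to its block id when the symbol's origin is either a block size or a grid index. A minimal, self-contained sketch of the widened lookup, using stub dataclasses in place of helion's `variable_origin` types (the stubs and the `expr_to_origin` mapping here are illustrative, not the real internals):

```python
from dataclasses import dataclass

# Stub stand-ins for helion's variable_origin classes.
@dataclass
class BlockSizeOrigin:
    block_id: int

@dataclass
class GridOrigin:
    block_id: int

@dataclass
class OriginInfo:
    origin: object

def get_block_id(expr_to_origin: dict, size) -> int | None:
    # Widened check: both block-size and grid-index origins carry a block id.
    origin_info = expr_to_origin.get(size)
    if origin_info is not None and isinstance(
        origin_info.origin, (BlockSizeOrigin, GridOrigin)
    ):
        return origin_info.origin.block_id
    return None

expr_to_origin = {
    "s0": OriginInfo(BlockSizeOrigin(0)),
    "s1": OriginInfo(GridOrigin(1)),
}
assert get_block_id(expr_to_origin, "s0") == 0
assert get_block_id(expr_to_origin, "s1") == 1  # previously returned None
assert get_block_id(expr_to_origin, "s2") is None
```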
10 changes: 7 additions & 3 deletions helion/_compiler/device_function.py
@@ -414,12 +414,16 @@ def _lift_sympy_arg(self, expr: sympy.Expr) -> str:
 
     def user_sympy_expr(self, expr: sympy.Expr) -> str:
         """A sympy expression that flows into user computations."""
+        expr_to_origin = HostFunction.current().expr_to_origin
         replacements = {}
         for sym in sorted(expr.free_symbols, key=lambda s: s.name):
             assert isinstance(sym, sympy.Symbol)
-            block_idx = CompileEnvironment.current().get_block_id(sym)
-            if block_idx is not None:
-                replacements[sym] = self.tile_strategy.user_size(block_idx)
+            origin_info = expr_to_origin.get(sym)
+            if origin_info is None:
+                continue
+            origin = origin_info.origin
+            if isinstance(origin, BlockSizeOrigin):
+                replacements[sym] = self.tile_strategy.user_size(origin.block_id)
         if replacements:
             # pyrefly: ignore [bad-assignment]
             expr = expr.xreplace(replacements)
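This is the flip side of the widened `get_block_id`: since that helper now also matches grid indices, `user_sympy_expr` (and the `_get_symnode` codegen in the next file) filter on `BlockSizeOrigin` directly, so grid-index symbols are no longer rewritten into tile sizes. An illustrative sketch, reusing the stub classes and `expr_to_origin` mapping from the previous example (the helper name is invented for illustration):

```python
def should_replace_with_user_size(expr_to_origin: dict, sym) -> bool:
    # Only block-size symbols get swapped for the user-facing tile size;
    # grid-index symbols (GridOrigin) stay in the expression untouched.
    origin_info = expr_to_origin.get(sym)
    return origin_info is not None and isinstance(origin_info.origin, BlockSizeOrigin)

assert should_replace_with_user_size(expr_to_origin, "s0") is True   # block size
assert should_replace_with_user_size(expr_to_origin, "s1") is False  # grid index
```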
10 changes: 7 additions & 3 deletions helion/language/_tracing_ops.py
@@ -16,6 +16,7 @@
 from .._compiler.ast_extension import statement_from_string
 from .._compiler.compile_environment import CompileEnvironment
 from .._compiler.host_function import HostFunction
+from .._compiler.variable_origin import BlockSizeOrigin
 from ..exc import NotInsideKernel
 from . import _decorators
 from .tile_proxy import Tile
@@ -50,13 +51,16 @@ def _(state: CodegenState) -> ast.AST:
         return expr_from_string(str(val))
 
     assert isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)), val
+    sym_expr = val._sympy_()
+    origin_info = HostFunction.current().expr_to_origin.get(sym_expr)
     # pyrefly: ignore [bad-argument-type]
-    if (block_idx := CompileEnvironment.current().get_block_id(val)) is not None:
-        block_size_var = state.device_function.block_size_var(block_idx)
+    if origin_info is not None and isinstance(origin_info.origin, BlockSizeOrigin):
+        block_size_var = state.device_function.block_size_var(
+            origin_info.origin.block_id
+        )
         if block_size_var is None:
             return expr_from_string("1")
         return expr_from_string(block_size_var)
-    sym_expr = val._sympy_()
     return state.codegen.lift_symnode(
         expr_from_string(state.sympy_expr(sym_expr)),
         sym_expr,
51 changes: 51 additions & 0 deletions test/test_misc.py
@@ -54,6 +54,57 @@ def kernel_with_duplicate_refs(x: torch.Tensor) -> torch.Tensor:
         code, result = code_and_output(kernel_with_duplicate_refs, (x,))
         torch.testing.assert_close(result, expected)
 
+    @skipIfRefEager("block_size=1 doesn't work in ref eager mode")
+    def test_min_hoist(self):
+        """Test case to reproduce issue #1155: offsets are hoisted out of loops"""
+
+        @helion.kernel(autotune_effort="none")
+        def kernel(
+            k: torch.Tensor,
+            w: torch.Tensor,
+            u: torch.Tensor,
+            g: torch.Tensor,
+            chunk_size: int,
+        ) -> torch.Tensor:
+            batch, seqlen, nheads = g.shape
+            dstate = u.shape[-1]
+            chunk_size = hl.specialize(chunk_size)
+            nchunks = (seqlen + chunk_size - 1) // chunk_size
+            out = torch.empty(
+                (batch, nchunks, nheads, dstate), device=g.device, dtype=g.dtype
+            )
+            block_v = hl.register_block_size(dstate)
+            for tile_b, tile_h, tile_v in hl.tile(
+                [batch, nheads, dstate], block_size=[1, 1, block_v]
+            ):
+                for t_i in hl.tile(seqlen, block_size=chunk_size):
+                    last = min(t_i.begin + chunk_size - 1, seqlen - 1)
+                    g_scalar = g[tile_b.begin, last, tile_h.begin]
+                    out[tile_b.begin, t_i.id, tile_h.begin, tile_v] = (
+                        g_scalar + hl.zeros([tile_v], dtype=g.dtype)
+                    )
+            return out
+
+        batch, seqlen, nheads, dhead, dstate = 1, 10, 1, 1, 2
+        chunk_size = 4
+        k = torch.zeros(
+            batch, seqlen, nheads, dhead, device=DEVICE, dtype=torch.float32
+        )
+        w = torch.zeros_like(k)
+        u = torch.zeros(
+            batch, seqlen, nheads, dstate, device=DEVICE, dtype=torch.float32
+        )
+        g = torch.arange(seqlen, device=DEVICE, dtype=torch.float32).view(
+            batch, seqlen, nheads
+        )
+
+        expected = torch.tensor(
+            [[[[3, 3]], [[7, 7]], [[9, 9]]]], device=DEVICE, dtype=torch.float32
+        )
+
+        result = kernel(k, w, u, g, chunk_size)
+        torch.testing.assert_close(result, expected)
+
     def test_torch_alloc(self):
         @helion.kernel(config={"block_sizes": [64, 64]})
         def fn(x: torch.Tensor) -> torch.Tensor: