12 changes: 11 additions & 1 deletion helion/_compiler/compile_environment.py
@@ -281,7 +281,17 @@ def to_fake(self, obj: object, origin: Origin) -> object:
             with self.shape_env.ignore_fresh_unbacked_symbols():
                 return self.shape_env.create_unbacked_symbool()
         if isinstance(obj, int):
-            return self.create_unbacked_symint()
+            # Preserve the concrete value as the initial hint so that
+            # subsequent hl.specialize() calls can recover the real value
+            # rather than falling back to the generic size hint.
+            sym = self.create_unbacked_symint(hint=obj)
+            try:
+                source = origin.to_source()
+            except NotImplementedError:
+                pass
+            else:
+                self.shape_env.var_to_sources[sym._sympy_()] = [source]
+            return sym
         if isinstance(obj, float):
             with self.shape_env.ignore_fresh_unbacked_symbols():
                 return self.shape_env.create_unbacked_symfloat()
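The key change is the hint=obj argument: the unbacked symint created for an int kernel argument now remembers the concrete value it was built from, and its source is registered in var_to_sources. A minimal standalone sketch of the behavior this enables (simplified stand-in classes, not Helion's actual implementation; the fallback value is an assumption):

# Simplified sketch (hypothetical names): why preserving the concrete
# value as a hint lets a later specialize() call recover it.
from typing import Optional

class FakeSymInt:
    """Stand-in for an unbacked symbolic int carrying an optional hint."""

    def __init__(self, name: str, hint: Optional[int] = None) -> None:
        self.name = name
        self.hint = hint

def specialize(sym: FakeSymInt, generic_size_hint: int = 8192) -> int:
    # hl.specialize() needs a concrete value to bake into the kernel.
    # With the hint preserved it returns the real call-site value
    # instead of the generic size hint (8192 is an assumed default).
    return sym.hint if sym.hint is not None else generic_size_hint

# to_fake(32, origin) now behaves like:
chunk_size = FakeSymInt("u0", hint=32)
assert specialize(chunk_size) == 32  # the real value, not the fallback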
47 changes: 47 additions & 0 deletions test/test_views.expected
@@ -1,6 +1,53 @@
 This file is automatically generated by assertExpectedJournal calls in test_views.py.
 Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

+--- assertExpectedJournal(TestViews.test_specialize_reshape)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(reshaped, out, _BLOCK_SIZE_2: tl.constexpr):
+    # src[test_views.py:N]: for tile in hl.tile(reshaped.size()):
+    num_blocks_0 = 2
+    num_blocks_1 = 3
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
+    offset_1 = pid_1
+    indices_1 = offset_1 + tl.zeros([1], tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    # src[test_views.py:N]: out[tile] = reshaped[tile] + 1
+    load = tl.load(reshaped + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), None)
+    v_0 = 1.0
+    v_1 = load + v_0
+    tl.store(out + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), v_1, None)
+
+def fn(x: torch.Tensor, chunk_size: int, *, _launcher=_default_launcher):
+    # src[test_views.py:N]: batch, seqlen = x.shape
+    batch, seqlen = x.shape
+    # src[test_views.py:N]: chunk_size = hl.specialize(chunk_size)
+    chunk_size = 32
+    # src[test_views.py:N]: nchunks = (seqlen + chunk_size - 1) // chunk_size
+    nchunks = (seqlen + chunk_size - 1) // chunk_size
+    # src[test_views.py:N]: reshaped = x.reshape(batch, nchunks, chunk_size)
+    reshaped = x.reshape(batch, nchunks, chunk_size)
+    # src[test_views.py:N]: out = torch.empty_like(reshaped)
+    out = torch.empty_like(reshaped)
+    # src[test_views.py:N]: for tile in hl.tile(reshaped.size()):
+    _BLOCK_SIZE_2 = 32
+    # src[test_views.py:N]: for tile in hl.tile(reshaped.size()):
+    # src[test_views.py:N]: out[tile] = reshaped[tile] + 1
+    _launcher(_helion_fn, (2 * 3 * triton.cdiv(32, _BLOCK_SIZE_2),), reshaped, out, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+    # src[test_views.py:N]: return out.reshape(batch, seqlen)
+    return out.reshape(batch, seqlen)
+
 --- assertExpectedJournal(TestViews.test_reshape_sum)
 from __future__ import annotations

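The generated kernel flattens the 3-D tile grid into a 1-D launch grid and decodes each linear program id with modulo/divide arithmetic. A plain-Python sketch of that decode, using the same num_blocks_0 = 2 and num_blocks_1 = 3 as the kernel above:

# Decode a linear program id into (pid_0, pid_1, pid_2), mirroring
# the arithmetic in _helion_fn; grid sizes match the generated code.
num_blocks_0, num_blocks_1, num_blocks_2 = 2, 3, 1
for pid in range(num_blocks_0 * num_blocks_1 * num_blocks_2):
    pid_0 = pid % num_blocks_0
    pid_1 = pid // num_blocks_0 % num_blocks_1
    pid_2 = pid // (num_blocks_0 * num_blocks_1)
    print(pid, "->", (pid_0, pid_1, pid_2))  # e.g. 3 -> (1, 1, 0)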
22 changes: 22 additions & 0 deletions test/test_views.py
@@ -16,6 +16,28 @@


 class TestViews(RefEagerTestBase, TestCase):
+    def test_specialize_reshape(self):
+        @helion.kernel()
+        def fn(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
+            batch, seqlen = x.shape
+            chunk_size = hl.specialize(chunk_size)
+            nchunks = (seqlen + chunk_size - 1) // chunk_size
+            reshaped = x.reshape(batch, nchunks, chunk_size)
+            out = torch.empty_like(reshaped)
+            for tile in hl.tile(reshaped.size()):
+                out[tile] = reshaped[tile] + 1
+            return out.reshape(batch, seqlen)
+
+        chunk_size = 32
+        x = torch.randn(2, chunk_size * 3, device=DEVICE)
+        code, result = code_and_output(
+            fn,
+            (x, chunk_size),
+            block_sizes=[1, 1, 32],
+        )
+        torch.testing.assert_close(result, x + 1)
+        self.assertExpectedJournal(code)
+
     def test_softmax_unsqueeze(self):
         @helion.kernel(config={"block_size": 1})
         def softmax(x: torch.Tensor) -> torch.Tensor:
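For completeness, a hedged sketch of invoking the kernel outside the test harness (assumes a CUDA device; the recompile-on-new-value behavior is an assumption about how hl.specialize interacts with Helion's caching):

# Hypothetical direct call to the fn kernel defined in the test above.
# hl.specialize(chunk_size) bakes the concrete 32 into the compiled
# kernel, so a call with a different chunk_size would be served by a
# separate specialization (assumption about Helion's caching).
import torch

x = torch.randn(2, 96, device="cuda")  # seqlen = 3 * chunk_size
out = fn(x, 32)
torch.testing.assert_close(out, x + 1)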