diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index 09d63af11..f7ae06fcc 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -281,7 +281,17 @@ def to_fake(self, obj: object, origin: Origin) -> object:
             with self.shape_env.ignore_fresh_unbacked_symbols():
                 return self.shape_env.create_unbacked_symbool()
         if isinstance(obj, int):
-            return self.create_unbacked_symint()
+            # Preserve the concrete value as the initial hint so that
+            # subsequent hl.specialize() calls can recover the real value
+            # rather than falling back to the generic size hint.
+            sym = self.create_unbacked_symint(hint=obj)
+            try:
+                source = origin.to_source()
+            except NotImplementedError:
+                pass
+            else:
+                self.shape_env.var_to_sources[sym._sympy_()] = [source]
+            return sym
         if isinstance(obj, float):
             with self.shape_env.ignore_fresh_unbacked_symbols():
                 return self.shape_env.create_unbacked_symfloat()
diff --git a/test/test_views.expected b/test/test_views.expected
index e9c86e941..8a055bbe0 100644
--- a/test/test_views.expected
+++ b/test/test_views.expected
@@ -1,6 +1,53 @@
 This file is automatically generated by assertExpectedJournal calls in test_views.py.
 Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
 
+--- assertExpectedJournal(TestViews.test_specialize_reshape)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(reshaped, out, _BLOCK_SIZE_2: tl.constexpr):
+    # src[test_views.py:N]: for tile in hl.tile(reshaped.size()):
+    num_blocks_0 = 2
+    num_blocks_1 = 3
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
+    offset_1 = pid_1
+    indices_1 = offset_1 + tl.zeros([1], tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    # src[test_views.py:N]: out[tile] = reshaped[tile] + 1
+    load = tl.load(reshaped + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), None)
+    v_0 = 1.0
+    v_1 = load + v_0
+    tl.store(out + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), v_1, None)
+
+def fn(x: torch.Tensor, chunk_size: int, *, _launcher=_default_launcher):
+    # src[test_views.py:N]: batch, seqlen = x.shape
+    batch, seqlen = x.shape
+    # src[test_views.py:N]: chunk_size = hl.specialize(chunk_size)
+    chunk_size = 32
+    # src[test_views.py:N]: nchunks = (seqlen + chunk_size - 1) // chunk_size
+    nchunks = (seqlen + chunk_size - 1) // chunk_size
+    # src[test_views.py:N]: reshaped = x.reshape(batch, nchunks, chunk_size)
+    reshaped = x.reshape(batch, nchunks, chunk_size)
+    # src[test_views.py:N]: out = torch.empty_like(reshaped)
+    out = torch.empty_like(reshaped)
+    # src[test_views.py:N]: for tile in hl.tile(reshaped.size()):
+    _BLOCK_SIZE_2 = 32
+    # src[test_views.py:N]: for tile in hl.tile(reshaped.size()):
+    # src[test_views.py:N]: out[tile] = reshaped[tile] + 1
+    _launcher(_helion_fn, (2 * 3 * triton.cdiv(32, _BLOCK_SIZE_2),), reshaped, out, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+    # src[test_views.py:N]: return out.reshape(batch, seqlen)
+    return out.reshape(batch, seqlen)
+
 --- assertExpectedJournal(TestViews.test_reshape_sum)
 from __future__ import annotations
 
diff --git a/test/test_views.py b/test/test_views.py
index 51fe0535d..894106fc8 100644
--- a/test/test_views.py
+++ b/test/test_views.py
@@ -16,6 +16,28 @@
 
 
 class TestViews(RefEagerTestBase, TestCase):
+    def test_specialize_reshape(self):
+        @helion.kernel()
+        def fn(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
+            batch, seqlen = x.shape
+            chunk_size = hl.specialize(chunk_size)
+            nchunks = (seqlen + chunk_size - 1) // chunk_size
+            reshaped = x.reshape(batch, nchunks, chunk_size)
+            out = torch.empty_like(reshaped)
+            for tile in hl.tile(reshaped.size()):
+                out[tile] = reshaped[tile] + 1
+            return out.reshape(batch, seqlen)
+
+        chunk_size = 32
+        x = torch.randn(2, chunk_size * 3, device=DEVICE)
+        code, result = code_and_output(
+            fn,
+            (x, chunk_size),
+            block_sizes=[1, 1, 32],
+        )
+        torch.testing.assert_close(result, x + 1)
+        self.assertExpectedJournal(code)
+
     def test_softmax_unsqueeze(self):
         @helion.kernel(config={"block_size": 1})
         def softmax(x: torch.Tensor) -> torch.Tensor: