From cc8cbc64cb7e441b5c9c6481a4406995a68de588 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 30 Oct 2025 16:47:14 -0700 Subject: [PATCH 1/2] test --- test/test_rng.expected | 36 ++++++------- test/test_rng.py | 114 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 125 insertions(+), 25 deletions(-) diff --git a/test/test_rng.expected b/test/test_rng.expected index 7fbf80627..0de84e365 100644 --- a/test/test_rng.expected +++ b/test/test_rng.expected @@ -10,38 +10,36 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal_stride_0, normal_stride_1, rand1_stride_0, rand1_stride_1, rand2_stride_0, rand2_stride_1, randn_a_stride_0, randn_a_stride_1, randn_b_stride_0, randn_b_stride_1, randn_c_stride_0, randn_c_stride_1, uniform_stride_0, uniform_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, rng_seed_buffer): +def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, rng_seed_buffer): # src[test_rng.py:N]: for tile_m, tile_n in hl.tile([m, n]): - num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0) pid_0 = tl.program_id(0) % num_blocks_0 pid_1 = tl.program_id(0) // num_blocks_0 offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) - mask_0 = indices_0 < m offset_1 = pid_1 * _BLOCK_SIZE_1 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) - mask_1 = indices_1 < n # src[test_rng.py:N]: rand1[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n]) - rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) - tl.store(rand1 + (indices_0[:, None] * rand1_stride_0 + indices_1[None, :] * rand1_stride_1), rand, mask_0[:, None] & mask_1[None, :]) + rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32) + tl.store(rand1 + (indices_0[:, None] * 64 + indices_1[None, :] * 1), rand, None) # src[test_rng.py:N]: rand2[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n]) - rand_1 = tl.rand(tl.load(rng_seed_buffer + 1), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) - tl.store(rand2 + (indices_0[:, None] * rand2_stride_0 + indices_1[None, :] * rand2_stride_1), rand_1, mask_0[:, None] & mask_1[None, :]) + rand_1 = tl.rand(tl.load(rng_seed_buffer + 1), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32) + tl.store(rand2 + (indices_0[:, None] * 64 + indices_1[None, :] * 1), rand_1, None) # src[test_rng.py:N]: uniform[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n]) - rand_2 = tl.rand(tl.load(rng_seed_buffer + 2), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) - tl.store(uniform + (indices_0[:, None] * uniform_stride_0 + indices_1[None, :] * uniform_stride_1), rand_2, mask_0[:, None] & mask_1[None, :]) + rand_2 = tl.rand(tl.load(rng_seed_buffer + 2), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32) + tl.store(uniform + (indices_0[:, None] * 64 + indices_1[None, :] * 1), rand_2, None) # src[test_rng.py:N]: normal[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n]) - randn = tl.randn(tl.load(rng_seed_buffer + 3), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) - tl.store(normal + (indices_0[:, None] * normal_stride_0 + indices_1[None, :] * normal_stride_1), randn, mask_0[:, None] & 
mask_1[None, :]) + randn = tl.randn(tl.load(rng_seed_buffer + 3), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32) + tl.store(normal + (indices_0[:, None] * 64 + indices_1[None, :] * 1), randn, None) # src[test_rng.py:N]: randn_a[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n]) - randn_1 = tl.randn(tl.load(rng_seed_buffer + 4), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) - tl.store(randn_a + (indices_0[:, None] * randn_a_stride_0 + indices_1[None, :] * randn_a_stride_1), randn_1, mask_0[:, None] & mask_1[None, :]) + randn_1 = tl.randn(tl.load(rng_seed_buffer + 4), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32) + tl.store(randn_a + (indices_0[:, None] * 64 + indices_1[None, :] * 1), randn_1, None) # src[test_rng.py:N]: randn_b[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n]) - randn_2 = tl.randn(tl.load(rng_seed_buffer + 5), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) - tl.store(randn_b + (indices_0[:, None] * randn_b_stride_0 + indices_1[None, :] * randn_b_stride_1), randn_2, mask_0[:, None] & mask_1[None, :]) + randn_2 = tl.randn(tl.load(rng_seed_buffer + 5), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32) + tl.store(randn_b + (indices_0[:, None] * 64 + indices_1[None, :] * 1), randn_2, None) # src[test_rng.py:N]: randn_c[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n]) - randn_3 = tl.randn(tl.load(rng_seed_buffer + 6), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) - tl.store(randn_c + (indices_0[:, None] * randn_c_stride_0 + indices_1[None, :] * randn_c_stride_1), randn_3, mask_0[:, None] & mask_1[None, :]) + randn_3 = tl.randn(tl.load(rng_seed_buffer + 6), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32) + tl.store(randn_c + (indices_0[:, None] * 64 + indices_1[None, :] * 1), randn_3, None) def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): from torch._inductor import inductor_prims @@ -73,7 +71,7 @@ def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_rng.py:N]: # Two independent rand operations # src[test_rng.py:N]: rand1[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n]) # src[test_rng.py:N-N]: ... 
- _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=1) + _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=1) # src[test_rng.py:N]: randn_sum = randn_a + randn_b + randn_c randn_sum = randn_a + randn_b + randn_c # src[test_rng.py:N]: return rand1, rand2, uniform, normal, randn_sum diff --git a/test/test_rng.py b/test/test_rng.py index c11c75eb6..3fec8d6d1 100644 --- a/test/test_rng.py +++ b/test/test_rng.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import Callable import unittest import torch @@ -16,7 +17,7 @@ class TestRNG(RefEagerTestBase, TestCase): def test_rand(self): """Test RNG seeding behavior, reproducibility, output range, and distribution.""" - @helion.kernel(static_shapes=False) + @helion.kernel(static_shapes=True, autotune_effort="none") def rand_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor: output = torch.zeros_like(x) m, n = x.shape @@ -87,7 +88,7 @@ def rand_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor: def test_rand_3d_tensor(self): """Test 3D RNG with tiled operations.""" - @helion.kernel(static_shapes=False) + @helion.kernel(static_shapes=True, autotune_effort="none") def rand_kernel_3d(x: torch.Tensor) -> torch.Tensor: output = torch.zeros_like(x) b, m, n = x.shape @@ -135,7 +136,7 @@ def rand_kernel_3d(x: torch.Tensor) -> torch.Tensor: def test_multiple_rng_ops(self): """Test multiple RNG operations: independence, reproducibility, mixed rand/randn.""" - @helion.kernel(static_shapes=False) + @helion.kernel(static_shapes=True, autotune_effort="none") def multiple_rng_ops_kernel( x: torch.Tensor, ) -> tuple[ @@ -258,7 +259,7 @@ def multiple_rng_ops_kernel( def test_randn_different_seeds_tiled(self): """Test that different torch.manual_seed values produce different outputs for randn.""" - @helion.kernel(static_shapes=False) + @helion.kernel(static_shapes=True, autotune_effort="none") def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor: output = torch.zeros_like(x) m, n = x.shape @@ -280,7 +281,7 @@ def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor: def test_randn_normal_distribution(self): """Test that torch.randn_like produces normal distribution (mean≈0, std≈1).""" - @helion.kernel(static_shapes=False) + @helion.kernel(static_shapes=True, autotune_effort="none") def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor: output = torch.zeros_like(x) m, n = x.shape @@ -315,7 +316,7 @@ def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor: def test_randn_3d_tensor(self): """Test 3D randn with tiled operations.""" - @helion.kernel(static_shapes=False) + @helion.kernel(static_shapes=True, autotune_effort="none") def randn_kernel_3d(x: torch.Tensor) -> torch.Tensor: output = torch.zeros_like(x) b, m, n = x.shape @@ -348,6 +349,107 @@ def randn_kernel_3d(x: torch.Tensor) -> torch.Tensor: f"Slice {b_idx} std {slice_std} is not well distributed", ) + def _test_rng_with_dynamic_tile_sizes(self, 
rng_func, is_uniform, rng_name): + """Common test logic for RNG operations with dynamic tile sizes.""" + + # Single kernel that takes an RNG callable as a parameter + @helion.kernel(static_shapes=True, autotune_effort="none") + def rng_kernel( + x: torch.Tensor, + rng_func: Callable[[int, int, torch.dtype], torch.Tensor], + ) -> torch.Tensor: + output = torch.zeros_like(x) + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + output[tile_m, tile_n] = rng_func(tile_m, tile_n, x.dtype) + return output + + x = torch.ones(64, 64, device=DEVICE) + torch.manual_seed(42) + _code, output = code_and_output(rng_kernel, (x, rng_func)) + + # Check distribution properties based on RNG type + if is_uniform: + # For rand: values in [0, 1), mean ~0.5 + self.assertTrue( + torch.all(output >= 0.0), f"{rng_name}: All values should be >= 0" + ) + self.assertTrue( + torch.all(output < 1.0), f"{rng_name}: All values should be < 1" + ) + mean_val = output.mean().item() + self.assertTrue( + 0.4 < mean_val < 0.6, + f"{rng_name}: Mean {mean_val:.3f} should be ~0.5", + ) + else: + # For randn: mean ~0, std ~1 + mean_val = output.mean().item() + std_val = output.std().item() + self.assertTrue( + -0.15 < mean_val < 0.15, f"{rng_name}: Mean {mean_val:.3f} should be ~0" + ) + self.assertTrue( + 0.9 < std_val < 1.1, f"{rng_name}: Std {std_val:.3f} should be ~1" + ) + + # Test reproducibility with same seed + torch.manual_seed(42) + _code2, output2 = code_and_output(rng_kernel, (x, rng_func)) + torch.testing.assert_close( + output, + output2, + msg=f"{rng_name}: Same seed should produce identical outputs", + ) + + # Test that different seeds produce different outputs + torch.manual_seed(99) + _code3, output3 = code_and_output(rng_kernel, (x, rng_func)) + self.assertFalse( + torch.allclose(output, output3), + f"{rng_name}: Different seeds should produce different outputs", + ) + + def test_rand_with_dynamic_tile_sizes(self): + """Test torch.rand with dynamic tile dimensions.""" + self._test_rng_with_dynamic_tile_sizes( + rng_func=lambda tile_m, tile_n, dtype: torch.rand( + (tile_m, tile_n), dtype=dtype, device=DEVICE + ), + is_uniform=True, + rng_name="rand", + ) + + def test_rand_like_with_dynamic_tile_sizes(self): + """Test torch.rand_like with dynamic tile dimensions.""" + self._test_rng_with_dynamic_tile_sizes( + rng_func=lambda tile_m, tile_n, dtype: torch.rand_like( + torch.ones((tile_m, tile_n), dtype=dtype, device=DEVICE) + ), + is_uniform=True, + rng_name="rand_like", + ) + + def test_randn_with_dynamic_tile_sizes(self): + """Test torch.randn with dynamic tile dimensions.""" + self._test_rng_with_dynamic_tile_sizes( + rng_func=lambda tile_m, tile_n, dtype: torch.randn( + (tile_m, tile_n), dtype=dtype, device=DEVICE + ), + is_uniform=False, + rng_name="randn", + ) + + def test_randn_like_with_dynamic_tile_sizes(self): + """Test torch.randn_like with dynamic tile dimensions.""" + self._test_rng_with_dynamic_tile_sizes( + rng_func=lambda tile_m, tile_n, dtype: torch.randn_like( + torch.ones((tile_m, tile_n), dtype=dtype, device=DEVICE) + ), + is_uniform=False, + rng_name="randn_like", + ) + if __name__ == "__main__": unittest.main() From 4ed0f9b947b827961043576b240a155be810b8ea Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 30 Oct 2025 19:45:32 -0700 Subject: [PATCH 2/2] fix --- helion/_compiler/inductor_lowering.py | 23 ++++++++++++----------- helion/language/ref_tile.py | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/helion/_compiler/inductor_lowering.py 
b/helion/_compiler/inductor_lowering.py index 001c0e78f..49db6d664 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -52,7 +52,6 @@ from .ast_extension import statement_from_string from .compile_environment import CompileEnvironment from .compile_environment import FixedBlockSizeSource -from .device_function import SymbolArgument from .device_function import VarInfo from .device_function import contains_only_block_size_symbols from .dtype_utils import cast_ast @@ -1559,7 +1558,10 @@ def _codegen_rng_op( node: The FX node for this operation rng_function: Either "rand" or "randn" """ + from .generate_ast import GenerateAST + assert rng_function in ["rand", "randn"] + assert isinstance(ctx.cg, GenerateAST) # Get unique seed index for this RNG operation device_fn = ctx.cg.device_function @@ -1567,19 +1569,18 @@ def _codegen_rng_op( # Get dimensionality and dtype assert hasattr(node, "meta") and "val" in node.meta - ndim = node.meta["val"].ndim + fake_value = node.meta["val"] + ndim = fake_value.ndim dtype = node.kwargs.get("dtype", None) - # Get the dimension variable names from the device function's symbol arguments - device_fn = ctx.cg.device_function - symbol_args = [ - arg for arg in device_fn.arguments if isinstance(arg, SymbolArgument) - ] - - # Extract dimension names - they should be the last ndim symbol arguments + # Get dimension names for offset calculation + env = CompileEnvironment.current() dim_names = [] - assert len(symbol_args) >= ndim, "Not enough symbol arguments for dimensions" - dim_names = [arg.name for arg in symbol_args[-ndim:]] + for size in fake_value.size(): + block_id = env.get_block_id(size) + assert block_id is not None + block_size = env.block_sizes[block_id].size + dim_names.append(device_fn.literal_expr(block_size)) offset_parts = [] diff --git a/helion/language/ref_tile.py b/helion/language/ref_tile.py index e2f3cc323..63a51441c 100644 --- a/helion/language/ref_tile.py +++ b/helion/language/ref_tile.py @@ -1,8 +1,10 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import TypeVar import torch +from torch.utils._pytree import tree_map_only from .. import exc from .._utils import convert_tile_indices_to_slices @@ -12,6 +14,8 @@ if TYPE_CHECKING: from collections.abc import Callable +_T = TypeVar("_T") + _ADD_OPS: set[object] = { torch.add, @@ -71,8 +75,24 @@ def __torch_function__( if func in _SUB_OPS: return cls._handle_sub(args) + # For any other torch.* function or torch.Tensor.* method, convert tiles to sizes + is_torch_func = getattr(func, "__module__", "") == "torch" + is_tensor_method = hasattr(torch.Tensor, getattr(func, "__name__", "")) + if is_torch_func or is_tensor_method: + new_args = cls._tiles_to_sizes(args) + new_kwargs = cls._tiles_to_sizes(kwargs) if kwargs else {} + return func(*new_args, **new_kwargs) + raise exc.IncorrectTileUsage(func) + @classmethod + def _tiles_to_sizes(cls, it: _T) -> _T: + return tree_map_only(RefTile, cls._tile_to_size, it) + + @staticmethod + def _tile_to_size(tile: RefTile) -> int: + return tile.block_size + @classmethod def _handle_add(cls, args: tuple[object, ...]) -> torch.Tensor: tile, offset, flipped = cls._extract_tile_and_offset(args, torch.add)
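
Reviewer note (not part of the patch): below is a minimal, illustrative sketch of the usage pattern the new tests exercise, i.e. calling torch.rand with dynamic tile sizes inside a Helion kernel. It assumes the helion / helion.language (hl) API already shown in this diff; the kernel name, input shape, and device handling are illustrative only, not taken from the patch.

import torch
import helion
import helion.language as hl

@helion.kernel(static_shapes=True, autotune_effort="none")
def rand_dynamic_tiles(x: torch.Tensor) -> torch.Tensor:
    # torch.rand receives tile objects as sizes; with the RefTile
    # __torch_function__ change in PATCH 2/2, tiles are mapped to their
    # block sizes before the RNG op is lowered.
    output = torch.zeros_like(x)
    m, n = x.shape
    for tile_m, tile_n in hl.tile([m, n]):
        output[tile_m, tile_n] = torch.rand(
            (tile_m, tile_n), dtype=x.dtype, device=x.device
        )
    return output

# Illustrative usage, mirroring the assertions in the new tests:
#   torch.manual_seed(42)
#   out = rand_dynamic_tiles(torch.ones(64, 64, device="cuda"))
#   assert (out >= 0).all() and (out < 1).all()

The same pattern applies to torch.rand_like, torch.randn, and torch.randn_like, which is why the tests funnel all four through a shared _test_rng_with_dynamic_tile_sizes helper.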