From c112edba47063d88e6983d36f3cfc57345cc56aa Mon Sep 17 00:00:00 2001
From: karthickai
Date: Sun, 21 Sep 2025 19:43:13 -0700
Subject: [PATCH] Add hl.rand op with seed arg lowering to tl.rand

stack-info: PR: https://github.com/pytorch/helion/pull/652, branch: karthickai/stack/2
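
Example usage (a minimal sketch, not part of this patch; the seeded-dropout
kernel below is illustrative and assumes torch.where is usable on tile-local
tensors):

    import torch
    import helion
    import helion.language as hl


    @helion.kernel
    def seeded_dropout(x: torch.Tensor, p: float, seed: int) -> torch.Tensor:
        out = torch.empty_like(x)
        (m,) = x.shape
        for (tile_m,) in hl.tile([m]):
            # Same seed -> same mask on every run, unlike torch.rand_like,
            # which can only be seeded through the global RNG state.
            r = hl.rand([tile_m], seed=seed)
            out[tile_m] = torch.where(r >= p, x[tile_m] / (1.0 - p), 0.0)
        return out


    x = torch.randn(1024, device="cuda")
    y1 = seeded_dropout(x, 0.5, 42)
    y2 = seeded_dropout(x, 0.5, 42)
    assert torch.equal(y1, y2)  # deterministic for a fixed seed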
---
 helion/language/__init__.py   |   1 +
 helion/language/random_ops.py | 122 ++++++++++++++++++++++++++++++++++
 test/test_rng.py              |  91 +++++++++++++++++++++++++
 3 files changed, 214 insertions(+)
 create mode 100644 helion/language/random_ops.py

diff --git a/helion/language/__init__.py b/helion/language/__init__.py
index 47f6945fc..856897808 100644
--- a/helion/language/__init__.py
+++ b/helion/language/__init__.py
@@ -21,6 +21,7 @@
 from .matmul_ops import dot as dot
 from .memory_ops import load as load
 from .memory_ops import store as store
+from .random_ops import rand as rand
 from .reduce_ops import reduce as reduce
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
diff --git a/helion/language/random_ops.py b/helion/language/random_ops.py
new file mode 100644
index 000000000..bf72fd86e
--- /dev/null
+++ b/helion/language/random_ops.py
@@ -0,0 +1,122 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+from .._compiler.ast_extension import expr_from_string
+from .._compiler.compile_environment import CompileEnvironment
+from ..exc import NotInsideKernel
+from . import _decorators
+from .ref_tile import RefTile
+
+if TYPE_CHECKING:
+    import ast
+
+    from .._compiler.inductor_lowering import CodegenState
+
+__all__ = ["rand"]
+
+
+@_decorators.api(tiles_as_sizes=True)
+def rand(
+    shape: list[object],
+    seed: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | None = None,
+) -> torch.Tensor:
+    """
+    The main purpose of ``hl.rand`` is to explicitly pass a seed argument for deterministic
+    randomness in Helion kernels, whereas ``torch.rand_like`` does not take a seed argument
+    (though it can be seeded globally). ``hl.rand`` lowers to ``tl.rand(seed, offset)`` with ``offset``
+    built from a linear range over the allocation and reshaped to the given shape.
+
+    Note:
+        Only use within ``hl.tile()`` loops for creating local tensors.
+        For host allocations, use ``torch.rand()``.
+
+    Args:
+        shape: A list of sizes; tile indices are accepted and converted to their sizes
+        seed: int seed for the random number generator
+        dtype: currently only ``torch.float32`` is supported
+        device: device of the output tensor (defaults to the compile environment's device)
+
+    Returns:
+        torch.Tensor: A device tensor of the given shape and dtype filled with random values
+
+    Examples:
+        .. code-block:: python
+
+            @helion.kernel
+            def process_kernel(x: torch.Tensor, seed: int) -> torch.Tensor:
+                output = torch.zeros_like(x)
+                (m,) = x.shape
+                for (tile_m,) in hl.tile([m]):
+                    output[tile_m] = hl.rand([tile_m], seed=seed)
+                return output
+
+    """
+    raise NotInsideKernel
+
+
+@_decorators.register_fake(rand)
+def _rand_fake(
+    shape: list[int | torch.SymInt],
+    seed: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | None = None,
+) -> torch.Tensor:
+    if not isinstance(shape, (list, tuple)):
+        raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
+    env = CompileEnvironment.current()
+    env.add_kernel_tensor_size(shape)
+    return torch.empty(
+        [*shape],
+        dtype=dtype,
+        device=env.device if device is None else device,
+    )
+
+
+@_decorators.codegen(rand)
+def _rand_codegen(state: CodegenState) -> ast.AST:
+    fake_value = state.fake_value
+    assert isinstance(fake_value, torch.Tensor)
+    shape_str = state.device_function.tile_strategy.shape_str(fake_value.size())
+
+    numel = " * ".join(shape_str.strip("[]").split(","))
+    seed_ast = state.ast_arg(1)
+    offs_expr = f"tl.arange(0, {numel}).reshape({shape_str})"
+    expr = f"tl.rand({{seed}}, {offs_expr})"
+
+    return expr_from_string(expr, seed=seed_ast)
+
+
+@_decorators.get_masked_value(rand)
+def _(
+    node: torch.fx.Node,
+) -> float:
+    return 0
+
+
+@_decorators.ref(rand)
+def _(
+    shape: list[int | RefTile],
+    seed: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | None = None,
+) -> torch.Tensor:
+    processed_shape: list[int] = []
+    for s in shape:
+        if isinstance(s, RefTile):
+            processed_shape.append(s.end - s.begin)
+        else:
+            processed_shape.append(int(s))
+    env = CompileEnvironment.current()
+    gen = torch.Generator(device=env.device if device is None else device)
+    gen.manual_seed(seed)
+    return torch.rand(
+        processed_shape,
+        dtype=dtype,
+        generator=gen,
+        device=env.device if device is None else device,
+    )
diff --git a/test/test_rng.py b/test/test_rng.py
index 585af03e2..7822d3fcf 100644
--- a/test/test_rng.py
+++ b/test/test_rng.py
@@ -348,6 +348,97 @@ def randn_kernel_3d(x: torch.Tensor) -> torch.Tensor:
                 f"Slice {b_idx} std {slice_std} is not well distributed",
             )
 
+    def test_hl_rand_1d(self):
+        @helion.kernel
+        def rand_kernel_tiled_1d(x: torch.Tensor, seed: int) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            (m,) = x.shape
+            for (tile_m,) in hl.tile([m]):
+                output[tile_m] = hl.rand([tile_m], seed=seed)
+            return output
+
+        x_small = torch.ones(128, device=DEVICE)
+        _, output = code_and_output(rand_kernel_tiled_1d, (x_small, 42))
+        _, output2 = code_and_output(rand_kernel_tiled_1d, (x_small, 1337))
+
+        self.assertFalse(
+            torch.allclose(output, output2),
+            "Different seeds should produce different outputs",
+        )
+
+        _, output3 = code_and_output(rand_kernel_tiled_1d, (x_small, 42))
+        self.assertTrue(
+            torch.allclose(output, output3),
+            "Same seed should produce identical outputs",
+        )
+
+        # Check that all values are in [0, 1) range
+        self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
+        self.assertTrue(torch.all(output < 1.0), "All values should be < 1")
+
+    def test_hl_rand_2d(self):
+        @helion.kernel
+        def rand_kernel_tiled_2d(x: torch.Tensor, seed: int) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            m, n = x.shape
+            for tile_m, tile_n in hl.tile([m, n]):
+                output[tile_m, tile_n] = hl.rand([tile_m, tile_n], seed=seed)
+            return output
+
+        x_small = torch.ones(128, 128, device=DEVICE)
+        _, output = code_and_output(rand_kernel_tiled_2d, (x_small, 42))
+        _, output2 = code_and_output(rand_kernel_tiled_2d, (x_small, 1337))
+
+        self.assertFalse(
+            torch.allclose(output, output2),
+            "Different seeds should produce different outputs",
+        )
+
+        _, output3 = code_and_output(rand_kernel_tiled_2d, (x_small, 42))
+        self.assertTrue(
+            torch.allclose(output, output3),
+            "Same seed should produce identical outputs",
+        )
+
+        self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
+        self.assertTrue(torch.all(output < 1.0), "All values should be < 1")
+
+    def test_hl_rand_3d(self):
+        @helion.kernel
+        def rand_kernel_tiled_3d(x: torch.Tensor, seed: int) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            b, m, n = x.shape
+            for tile_b, tile_m, tile_n in hl.tile([b, m, n]):
+                output[tile_b, tile_m, tile_n] = hl.rand(
+                    [tile_b, tile_m, tile_n], seed=seed
+                )
+            return output
+
+        x_small = torch.ones(16, 32, 64, device=DEVICE)
+        _, output = code_and_output(rand_kernel_tiled_3d, (x_small, 42))
+        _, output2 = code_and_output(rand_kernel_tiled_3d, (x_small, 1337))
+
+        self.assertFalse(
+            torch.allclose(output, output2),
+            "Different seeds should produce different outputs",
+        )
+
+        _, output3 = code_and_output(rand_kernel_tiled_3d, (x_small, 42))
+        self.assertTrue(
+            torch.allclose(output, output3),
+            "Same seed should produce identical outputs",
+        )
+
+        self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
+        self.assertTrue(torch.all(output < 1.0), "All values should be < 1")
+
+        # Check distribution properties
+        mean_val = output.mean().item()
+        self.assertTrue(
+            0.4 < mean_val < 0.6,
+            f"Mean {mean_val:.3f} should be around 0.5 for uniform distribution",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
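
Note on the generated code: with the codegen above, a 2D tile lowers to a
Triton expression of roughly the following shape (an illustrative sketch; the
_BLOCK_SIZE_0 / _BLOCK_SIZE_1 names are assumptions, as the actual identifiers
come from the tile strategy's shape_str):

    # offsets enumerate the flat allocation, then take the tile's shape
    offs = tl.arange(0, _BLOCK_SIZE_0 * _BLOCK_SIZE_1).reshape([_BLOCK_SIZE_0, _BLOCK_SIZE_1])
    vals = tl.rand(seed, offs)  # uniform float32 in [0, 1); deterministic per (seed, offset)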