diff --git a/helion/language/__init__.py b/helion/language/__init__.py
index 856897808..47f6945fc 100644
--- a/helion/language/__init__.py
+++ b/helion/language/__init__.py
@@ -21,7 +21,6 @@
 from .matmul_ops import dot as dot
 from .memory_ops import load as load
 from .memory_ops import store as store
-from .random_ops import rand as rand
 from .reduce_ops import reduce as reduce
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
diff --git a/helion/language/random_ops.py b/helion/language/random_ops.py
deleted file mode 100644
index bf72fd86e..000000000
--- a/helion/language/random_ops.py
+++ /dev/null
@@ -1,121 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-import torch
-
-from .._compiler.ast_extension import expr_from_string
-from .._compiler.compile_environment import CompileEnvironment
-from ..exc import NotInsideKernel
-from . import _decorators
-from .ref_tile import RefTile
-
-if TYPE_CHECKING:
-    import ast
-
-    from .._compiler.inductor_lowering import CodegenState
-
-__all__ = ["rand"]
-
-
-@_decorators.api(tiles_as_sizes=True)
-def rand(
-    shape: list[object],
-    seed: int,
-    dtype: torch.dtype = torch.float32,
-    device: torch.device | None = None,
-) -> torch.Tensor:
-    """
-    The main purpose of ``hl.rand`` is to explicitly pass a seed arg for deterministic
-    randomness in Helion kernels, whereas ``torch.rand_like`` does not take a seed arg
-    (though it can be seeded globally). ``hl.rand`` lowers to ``tl.rand(seed, offset)`` with
-    ``offset`` built from a linear range over the allocation and reshaped to the given shape.
-
-    Note:
-        Only use within ``hl.tile()`` loops for creating local tensors.
-        For host allocations, use ``torch.rand()``.
-
-    Args:
-        shape: A list of sizes
-        seed: int seed for the random number generator
-        dtype: currently only float32 supported
-
-    Returns:
-        torch.Tensor: A device tensor of the given shape and dtype filled with random values
-
-    Examples:
-        .. code-block:: python
-
-            @helion.kernel
-            def process_kernel(x: torch.Tensor, seed: int) -> torch.Tensor:
-                output = torch.zeros_like(x)
-                (m,) = x.shape
-                for (tile_m,) in hl.tile([m]):
-                    output[tile_m] = hl.rand([tile_m], seed=seed)
-                return output
-
-    """
-    raise NotInsideKernel
-
-
-@_decorators.register_fake(rand)
-def _rand_fake(
-    shape: list[int | torch.SymInt],
-    seed: int,
-    dtype: torch.dtype = torch.float32,
-    device: torch.device | None = None,
-) -> torch.Tensor:
-    if not isinstance(shape, (list, tuple)):
-        raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
-    env = CompileEnvironment.current()
-    env.add_kernel_tensor_size(shape)
-    return torch.empty(
-        [*shape],
-        dtype=dtype,
-        device=env.device if device is None else device,
-    )
-
-
-@_decorators.codegen(rand)
-def _rand_codegen(state: CodegenState) -> ast.AST:
-    fake_value = state.fake_value
-    assert isinstance(fake_value, torch.Tensor)
-    shape_str = state.device_function.tile_strategy.shape_str(fake_value.size())
-
-    numel = " * ".join(shape_str.strip("[]").split(","))
-    seed_ast = state.ast_arg(1)
-    offs_expr = f"tl.arange(0, {numel}).reshape({shape_str})"
-    expr = f"tl.rand({{seed}}, {offs_expr})"
-
-    return expr_from_string(expr, seed=seed_ast)
-
-
-@_decorators.get_masked_value(rand)
-def _(
-    node: torch.fx.Node,
-) -> float:
-    return 0
-
-
-@_decorators.ref(rand)
-def _(
-    shape: list[int | RefTile],
-    seed: int,
-    dtype: torch.dtype = torch.float32,
-    device: torch.device | None = None,
-) -> torch.Tensor:
-    processed_shape: list[int] = []
-    for s in shape:
-        if isinstance(s, RefTile):
-            processed_shape.append(s.end - s.begin)
-        else:
-            processed_shape.append(int(s))
-    env = CompileEnvironment.current()
-    gen = torch.Generator(device=env.device if device is None else device)
-    gen.manual_seed(seed)
-    return torch.rand(
-        processed_shape,
-        dtype=dtype,
-        generator=gen,
-        device=env.device if device is None else device,
-    )
diff --git a/test/test_rng.py b/test/test_rng.py
index 7822d3fcf..585af03e2 100644
--- a/test/test_rng.py
+++ b/test/test_rng.py
@@ -348,97 +348,6 @@
             f"Slice {b_idx} std {slice_std} is not well distributed",
         )
 
-    def test_hl_rand_1d(self):
-        @helion.kernel
-        def rand_kernel_tiled_1d(x: torch.Tensor, seed: int) -> torch.Tensor:
-            output = torch.zeros_like(x)
-            (m,) = x.shape
-            for (tile_m,) in hl.tile([m]):
-                output[tile_m] = hl.rand([tile_m], seed=seed)
-            return output
-
-        x_small = torch.ones(128, device=DEVICE)
-        _, output = code_and_output(rand_kernel_tiled_1d, (x_small, 42))
-        _, output2 = code_and_output(rand_kernel_tiled_1d, (x_small, 1337))
-
-        self.assertFalse(
-            torch.allclose(output, output2),
-            "Different seeds should produce different outputs",
-        )
-
-        _, output3 = code_and_output(rand_kernel_tiled_1d, (x_small, 42))
-        self.assertTrue(
-            torch.allclose(output, output3),
-            "Same seed should produce identical outputs",
-        )
-
-        # Check that all values are in [0, 1) range
-        self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
-        self.assertTrue(torch.all(output < 1.0), "All values should be < 1")
-
-    def test_hl_rand_2d(self):
-        @helion.kernel
-        def rand_kernel_tiled_2d(x: torch.Tensor, seed: int) -> torch.Tensor:
-            output = torch.zeros_like(x)
-            m, n = x.shape
-            for tile_m, tile_n in hl.tile([m, n]):
-                output[tile_m, tile_n] = hl.rand([tile_m, tile_n], seed=seed)
-            return output
-
-        x_small = torch.ones(128, 128, device=DEVICE)
-        _, output = code_and_output(rand_kernel_tiled_2d, (x_small, 42))
-        _, output2 = code_and_output(rand_kernel_tiled_2d, (x_small, 1337))
-
-        self.assertFalse(
-            torch.allclose(output, output2),
-            "Different seeds should produce different outputs",
-        )
-
-        _, output3 = code_and_output(rand_kernel_tiled_2d, (x_small, 42))
-        self.assertTrue(
-            torch.allclose(output, output3),
-            "Same seed should produce identical outputs",
-        )
-
-        self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
-        self.assertTrue(torch.all(output < 1.0), "All values should be < 1")
-
-    def test_hl_rand_3d(self):
-        @helion.kernel
-        def rand_kernel_tiled_3d(x: torch.Tensor, seed: int) -> torch.Tensor:
-            output = torch.zeros_like(x)
-            b, m, n = x.shape
-            for tile_b, tile_m, tile_n in hl.tile([b, m, n]):
-                output[tile_b, tile_m, tile_n] = hl.rand(
-                    [tile_b, tile_m, tile_n], seed=seed
-                )
-            return output
-
-        x_small = torch.ones(16, 32, 64, device=DEVICE)
-        _, output = code_and_output(rand_kernel_tiled_3d, (x_small, 42))
-        _, output2 = code_and_output(rand_kernel_tiled_3d, (x_small, 1337))
-
-        self.assertFalse(
-            torch.allclose(output, output2),
-            "Different seeds should produce different outputs",
-        )
-
-        _, output3 = code_and_output(rand_kernel_tiled_3d, (x_small, 42))
-        self.assertTrue(
-            torch.allclose(output, output3),
-            "Same seed should produce identical outputs",
-        )
-
-        self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
-        self.assertTrue(torch.all(output < 1.0), "All values should be < 1")
-
-        # Check distribution properties
-        mean_val = output.mean().item()
-        self.assertTrue(
-            0.4 < mean_val < 0.6,
-            f"Mean {mean_val:.3f} should be around 0.5 for uniform distribution",
-        )
-
 
 if __name__ == "__main__":
     unittest.main()
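The removed `@_decorators.ref(rand)` path defined the seeding semantics the deleted tests asserted: a fresh `torch.Generator` seeded per call, so the output is a pure function of shape and seed. A minimal standalone sketch of that behavior, assuming plain PyTorch outside Helion (the `rand_ref` name and the `"cuda"` device default are illustrative, not part of the library):

```python
import torch


def rand_ref(
    shape: list[int],
    seed: int,
    dtype: torch.dtype = torch.float32,
    device: str = "cuda",  # assumed; the real op defaulted to the compile environment's device
) -> torch.Tensor:
    # Mirror of the deleted ref implementation: seeding a fresh
    # generator per call makes the result deterministic in `seed`.
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)
    return torch.rand(shape, dtype=dtype, generator=gen, device=device)


a = rand_ref([128], seed=42)
b = rand_ref([128], seed=42)
c = rand_ref([128], seed=1337)
assert torch.equal(a, b)  # same seed -> identical outputs
assert not torch.allclose(a, c)  # different seeds -> different outputs
assert (a >= 0).all() and (a < 1).all()  # values lie in [0, 1)
```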
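On the device side, the removed `_rand_codegen` emitted `tl.rand(seed, offsets)`, with offsets taken from `tl.arange(0, numel)` reshaped to the block shape. A sketch of that lowering pattern as a standalone Triton kernel, assuming a single 1D block (the kernel name, grid, and `BLOCK` size are illustrative, not the code Helion generated):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def rand_block_kernel(out_ptr, seed, BLOCK: tl.constexpr):
    # Same pattern as the removed codegen: a linear offset range serves
    # as the per-element counter argument to tl.rand(seed, offset).
    offs = tl.arange(0, BLOCK)
    vals = tl.rand(seed, offs)
    tl.store(out_ptr + offs, vals)


out = torch.empty(128, device="cuda")
rand_block_kernel[(1,)](out, 42, BLOCK=128)
```

Because the offset range restarts at zero for every launch, the same seed reproduces the same block of values, which is what the deleted determinism tests checked.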