diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py
index 2fc1910e7..3b6c2e2bd 100644
--- a/helion/_compiler/indexing_strategy.py
+++ b/helion/_compiler/indexing_strategy.py
@@ -352,6 +352,12 @@ def get_broadcast_str(
         return stack_broadcast, tensor_broadcast
 
+    @staticmethod
+    def get_element_broadcast_slice(dim_index: int, total_dims: int) -> str:
+        broadcast_keys = ["None"] * total_dims
+        broadcast_keys[dim_index] = ":"
+        return f"[{', '.join(broadcast_keys)}]"
+
     @staticmethod
     def get_mask_expr(
         state: CodegenState,
diff --git a/helion/language/__init__.py b/helion/language/__init__.py
index 9f33822b1..06f02fa22 100644
--- a/helion/language/__init__.py
+++ b/helion/language/__init__.py
@@ -22,6 +22,7 @@ from .matmul_ops import dot as dot
 from .memory_ops import load as load
 from .memory_ops import store as store
+from .random_ops import rand as rand
 from .reduce_ops import reduce as reduce
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
diff --git a/helion/language/random_ops.py b/helion/language/random_ops.py
new file mode 100644
index 000000000..412ec1f6c
--- /dev/null
+++ b/helion/language/random_ops.py
@@ -0,0 +1,160 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+from .._compiler.ast_extension import expr_from_string
+from .._compiler.compile_environment import CompileEnvironment
+from .._compiler.indexing_strategy import StackIndexingStrategy
+from ..exc import NotInsideKernel
+from . import _decorators
+from .ref_tile import RefTile
+
+if TYPE_CHECKING:
+    import ast
+
+    from .._compiler.inductor_lowering import CodegenState
+
+__all__ = ["rand"]
+
+
+@_decorators.api(tiles_as_sizes=True)
+def rand(
+    shape: list[object],
+    seed: int | torch.Tensor,
+    device: torch.device | None = None,
+) -> torch.Tensor:
+    """
+    hl.rand is a Philox-based pseudorandom number generator (PRNG) that operates
+    independently of PyTorch's global random seed and instead takes an explicit
+    seed argument. Offsets are derived from the full logical sizes of the tiles
+    given in the shape argument, so the values do not depend on the chosen block sizes.
+
+    Args:
+        shape: A list of sizes for the output tensor
+        seed: A single element int64 tensor or int literal
+        device: The device of the output tensor; defaults to the kernel's device
+
+    Returns:
+        torch.Tensor: A device tensor of float32 dtype filled with uniform random values in [0, 1)
+
+    Examples:
+        .. code-block:: python
+
+            @helion.kernel
+            def process_kernel(x: torch.Tensor) -> torch.Tensor:
+                output = torch.zeros_like(x)
+                (m,) = x.shape
+                for tile_m in hl.tile(m):
+                    output[tile_m] = hl.rand([tile_m], seed=42)
+                return output
+
+    """
+    raise NotInsideKernel
+
+
+@_decorators.register_fake(rand)
+def _rand_fake(
+    shape: list[int | torch.SymInt],
+    seed: int | torch.Tensor,
+    device: torch.device | None = None,
+) -> torch.Tensor:
+    if not isinstance(shape, (list, tuple)):
+        raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
+    env = CompileEnvironment.current()
+    env.add_kernel_tensor_size(shape)
+    return torch.empty(
+        [*shape],
+        dtype=torch.float32,
+        device=env.device if device is None else device,
+    )
+
+
+@_decorators.codegen(rand)
+def _rand_codegen(state: CodegenState) -> ast.AST:
+    """
+    Generate a tl.rand() call whose offset is the global (row-major) linear index
+    of each output element, giving a deterministic random value per element.
+
+    The offset is computed from the full logical size of each dimension rather
+    than the block size, so the generated values are independent of the chosen
+    block sizes and of how the iteration space is tiled.
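+
+    For example, for a 2-D shape ``[m, n]`` the emitted code is
+    ``tl.rand(seed, indices_0[:, None] * n + indices_1[None, :])``, matching the
+    kernels in test_random.expected; for a 1-D shape the tile index is passed to
+    ``tl.rand`` directly.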
+    """
+    fake_value = state.fake_value
+    assert isinstance(fake_value, torch.Tensor)
+
+    env = CompileEnvironment.current()
+    tensor_shape = fake_value.size()
+    ndim = len(tensor_shape)
+    if ndim == 0:
+        raise ValueError("hl.rand() requires at least one dimension")
+
+    seed_ast = state.ast_arg(1)
+
+    index_vars = []
+    size_names = []
+    for i in range(ndim):
+        size = tensor_shape[i]
+        block_id = env.get_block_id(size)
+        if block_id is not None:
+            index_vars.append(state.codegen.index_var(block_id))
+            original_tensor_size = env.block_sizes[block_id].size
+            assert isinstance(original_tensor_size, torch.SymInt), (
+                f"Expected SymInt, got {type(original_tensor_size)}"
+            )
+            size_names.append(
+                state.device_function.sympy_expr(original_tensor_size._sympy_())
+            )
+        else:
+            rdim = env.allocate_reduction_dimension(size)
+            index_vars.append(state.codegen.index_var(rdim.block_id))
+            assert isinstance(rdim.var, torch.SymInt), (
+                f"Expected SymInt, got {type(rdim.var)}"
+            )
+            size_names.append(state.device_function.sympy_expr(rdim.var._sympy_()))
+
+    if ndim == 1:
+        offset_expr = expr_from_string(index_vars[0])
+    else:
+        offset_parts = []
+        for i in range(ndim):
+            broadcast_slice = StackIndexingStrategy.get_element_broadcast_slice(i, ndim)
+            broadcasted_index = f"{index_vars[i]}{broadcast_slice}"
+            if i < ndim - 1:
+                stride_expr = " * ".join(map("({})".format, size_names[i + 1 :]))
+                offset_parts.append(f"{broadcasted_index} * {stride_expr}")
+            else:
+                offset_parts.append(broadcasted_index)
+        offset_expr = expr_from_string(" + ".join(offset_parts))
+    return expr_from_string(
+        "tl.rand({seed}, {offset})", seed=seed_ast, offset=offset_expr
+    )
+
+
+@_decorators.get_masked_value(rand)
+def _(
+    node: torch.fx.Node,
+) -> float:
+    return 0
+
+
+@_decorators.ref(rand)
+def _(
+    shape: list[int | RefTile],
+    seed: int | torch.Tensor,
+    device: torch.device | None = None,
+) -> torch.Tensor:
+    processed_shape: list[int] = []
+    for s in shape:
+        if isinstance(s, RefTile):
+            processed_shape.append(s.end - s.begin)
+        else:
+            processed_shape.append(int(s))
+    env = CompileEnvironment.current()
+    gen = torch.Generator(device=env.device if device is None else device)
+    if isinstance(seed, torch.Tensor):
+        gen.manual_seed(int(seed.item()))
+    else:
+        gen.manual_seed(seed)
+    return torch.rand(
+        processed_shape,
+        dtype=torch.float32,
+        generator=gen,
+        device=env.device if device is None else device,
+    )
diff --git a/test/test_random.expected b/test/test_random.expected
new file mode 100644
index 000000000..9975307c0
--- /dev/null
+++ b/test/test_random.expected
@@ -0,0 +1,262 @@
+This file is automatically generated by assertExpectedJournal calls in test_random.py.
+Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
+ +--- assertExpectedJournal(TestRandom.test_hl_rand_1d) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_rand_kernel_tiled_1d(output, output_stride_0, m, seed, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + rand = tl.rand(seed, indices_0) + tl.store(output + indices_0 * output_stride_0, rand, mask_0) + +def rand_kernel_tiled_1d(x: torch.Tensor, seed: int, *, _launcher=_default_launcher): + output = torch.zeros_like(x) + m, = x.shape + _BLOCK_SIZE_0 = 128 + _launcher(_helion_rand_kernel_tiled_1d, (triton.cdiv(m, _BLOCK_SIZE_0),), output, output.stride(0), m, seed, _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return output + +--- assertExpectedJournal(TestRandom.test_hl_rand_2d) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_rand_kernel_tiled_2d(output, output_stride_0, output_stride_1, m, n, seed, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + rand = tl.rand(seed, indices_0[:, None] * n + indices_1[None, :]) + tl.store(output + (indices_0[:, None] * output_stride_0 + indices_1[None, :] * output_stride_1), rand, mask_0[:, None] & mask_1[None, :]) + +def rand_kernel_tiled_2d(x: torch.Tensor, seed: int, *, _launcher=_default_launcher): + output = torch.zeros_like(x) + m, n = x.shape + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _launcher(_helion_rand_kernel_tiled_2d, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), output, output.stride(0), output.stride(1), m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + return output + +--- assertExpectedJournal(TestRandom.test_hl_rand_3d) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_rand_kernel_tiled_3d(output, output_stride_0, output_stride_1, output_stride_2, b, m, n, seed, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(b, _BLOCK_SIZE_0) + num_blocks_1 = tl.cdiv(m, _BLOCK_SIZE_1) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1 + pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < b + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < m + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) + mask_2 = indices_2 < n + rand = tl.rand(seed, indices_0[:, None, None] * m * n + indices_1[None, :, None] * n + indices_2[None, None, :]) + tl.store(output + (indices_0[:, None, None] * output_stride_0 + 
indices_1[None, :, None] * output_stride_1 + indices_2[None, None, :] * output_stride_2), rand, mask_0[:, None, None] & mask_1[None, :, None] & mask_2[None, None, :]) + +def rand_kernel_tiled_3d(x: torch.Tensor, seed: int, *, _launcher=_default_launcher): + output = torch.zeros_like(x) + b, m, n = x.shape + _BLOCK_SIZE_0 = 16 + _BLOCK_SIZE_1 = 16 + _BLOCK_SIZE_2 = 16 + _launcher(_helion_rand_kernel_tiled_3d, (triton.cdiv(b, _BLOCK_SIZE_0) * triton.cdiv(m, _BLOCK_SIZE_1) * triton.cdiv(n, _BLOCK_SIZE_2),), output, output.stride(0), output.stride(1), output.stride(2), b, m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return output + +--- assertExpectedJournal(TestRandom.test_hl_rand_mixed_argument_order) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_rand_kernel_normal_order(output, output_stride_0, output_stride_1, output_stride_2, m, n, k, seed, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + num_blocks_1 = tl.cdiv(n, _BLOCK_SIZE_1) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1 + pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) + mask_2 = indices_2 < k + rand = tl.rand(seed, indices_0[:, None, None] * n * k + indices_1[None, :, None] * k + indices_2[None, None, :]) + tl.store(output + (indices_0[:, None, None] * output_stride_0 + indices_1[None, :, None] * output_stride_1 + indices_2[None, None, :] * output_stride_2), rand, mask_0[:, None, None] & mask_1[None, :, None] & mask_2[None, None, :]) + +def rand_kernel_normal_order(x: torch.Tensor, seed: int, *, _launcher=_default_launcher): + output = torch.zeros_like(x) + m, n, k = x.shape + _BLOCK_SIZE_0 = 16 + _BLOCK_SIZE_1 = 16 + _BLOCK_SIZE_2 = 16 + _launcher(_helion_rand_kernel_normal_order, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1) * triton.cdiv(k, _BLOCK_SIZE_2),), output, output.stride(0), output.stride(1), output.stride(2), m, n, k, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return output + +--- assertExpectedJournal(TestRandom.test_hl_rand_mixed_argument_order) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_rand_kernel_mixed_order(output, output_stride_0, output_stride_1, output_stride_2, k, m, n, seed, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(k, _BLOCK_SIZE_0) + num_blocks_1 = tl.cdiv(m, _BLOCK_SIZE_1) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1 + pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < k + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + 
mask_1 = indices_1 < m + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) + mask_2 = indices_2 < n + rand = tl.rand(seed, indices_1[:, None, None] * n * k + indices_2[None, :, None] * k + indices_0[None, None, :]) + tl.store(output + (indices_1[:, None, None] * output_stride_0 + indices_2[None, :, None] * output_stride_1 + indices_0[None, None, :] * output_stride_2), rand, mask_1[:, None, None] & mask_2[None, :, None] & mask_0[None, None, :]) + +def rand_kernel_mixed_order(x: torch.Tensor, seed: int, *, _launcher=_default_launcher): + output = torch.zeros_like(x) + m, n, k = x.shape + _BLOCK_SIZE_0 = 16 + _BLOCK_SIZE_1 = 16 + _BLOCK_SIZE_2 = 16 + _launcher(_helion_rand_kernel_mixed_order, (triton.cdiv(k, _BLOCK_SIZE_0) * triton.cdiv(m, _BLOCK_SIZE_1) * triton.cdiv(n, _BLOCK_SIZE_2),), output, output.stride(0), output.stride(1), output.stride(2), k, m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return output + +--- assertExpectedJournal(TestRandom.test_hl_rand_non_tiled_dimensions) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_rand_kernel_partial_tile(output, output_stride_0, output_stride_1, output_stride_2, m, n, seed, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + indices_2 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32) + rand = tl.rand(seed, indices_0[:, None, None] * n * _RDIM_SIZE_2 + indices_1[None, :, None] * _RDIM_SIZE_2 + indices_2[None, None, :]) + tl.store(output + (indices_0[:, None, None] * output_stride_0 + indices_1[None, :, None] * output_stride_1 + indices_2[None, None, :] * output_stride_2), rand, mask_0[:, None, None] & mask_1[None, :, None]) + +def rand_kernel_partial_tile(x: torch.Tensor, seed: int, *, _launcher=_default_launcher): + output = torch.zeros_like(x) + m, n, k = x.shape + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _RDIM_SIZE_2 = 8 + _launcher(_helion_rand_kernel_partial_tile, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), output, output.stride(0), output.stride(1), output.stride(2), m, n, seed, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=3) + return output + +--- assertExpectedJournal(TestRandom.test_hl_rand_rolled_reductions) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_rand_kernel_with_reduction(x, output, output_stride_0, x_stride_0, x_stride_1, m, n, seed, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + mask_1 = indices_1 < n + tile_values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + rand_values = tl.rand(seed, indices_0) + 
mean_val_extra = tl.cast(tl.sum(tile_values, 1), tl.float32) + v_0 = mean_val_extra / n.to(tl.float32) + v_1 = rand_values * v_0 + tl.store(output + indices_0 * output_stride_0, v_1, mask_0) + +def rand_kernel_with_reduction(x: torch.Tensor, seed: int, *, _launcher=_default_launcher): + m, n = x.shape + output = torch.zeros([m], device=x.device) + _BLOCK_SIZE_0 = 32 + _RDIM_SIZE_1 = triton.next_power_of_2(n) + _launcher(_helion_rand_kernel_with_reduction, (triton.cdiv(m, _BLOCK_SIZE_0),), x, output, output.stride(0), x.stride(0), x.stride(1), m, n, seed, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + return output + +--- assertExpectedJournal(TestRandom.test_hl_rand_rolled_reductions) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_rand_kernel_with_reduction(x, output, output_stride_0, x_stride_0, x_stride_1, m, seed, n, _BLOCK_SIZE_0: tl.constexpr, _REDUCTION_BLOCK_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + rand_values = tl.rand(seed, indices_0) + mean_val_extra_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32) + for roffset_1 in tl.range(0, n, _REDUCTION_BLOCK_1): + rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32) + mask_1 = rindex_1 < n + tile_values = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + v_0 = mean_val_extra_acc + tile_values + mean_val_extra_acc = v_0 + mean_val_extra = tl.cast(tl.sum(mean_val_extra_acc, 1), tl.float32) + v_1 = mean_val_extra / n.to(tl.float32) + v_2 = rand_values * v_1 + tl.store(output + indices_0 * output_stride_0, v_2, mask_0) + +def rand_kernel_with_reduction(x: torch.Tensor, seed: int, *, _launcher=_default_launcher): + m, n = x.shape + output = torch.zeros([m], device=x.device) + _BLOCK_SIZE_0 = 32 + _REDUCTION_BLOCK_1 = 64 + _launcher(_helion_rand_kernel_with_reduction, (triton.cdiv(m, _BLOCK_SIZE_0),), x, output, output.stride(0), x.stride(0), x.stride(1), m, seed, n, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3) + return output diff --git a/test/test_random.py b/test/test_random.py new file mode 100644 index 000000000..317c8c098 --- /dev/null +++ b/test/test_random.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import unittest + +import torch + +import helion +from helion._testing import DEVICE +from helion._testing import RefEagerTestBase +from helion._testing import TestCase +from helion._testing import code_and_output +import helion.language as hl + + +class TestRandom(RefEagerTestBase, TestCase): + def test_hl_rand_1d(self): + @helion.kernel + def rand_kernel_tiled_1d(x: torch.Tensor, seed: int) -> torch.Tensor: + output = torch.zeros_like(x) + (m,) = x.shape + for tile_m in hl.tile(m): + output[tile_m] = hl.rand([tile_m], seed=seed) + return output + + x_small = torch.ones(128, device=DEVICE) + _, output = code_and_output(rand_kernel_tiled_1d, (x_small, 42)) + _, output2 = code_and_output(rand_kernel_tiled_1d, (x_small, 1337)) + + self.assertFalse( + torch.allclose(output, output2), + "Different seeds should produce different outputs", + ) + + code3, output3 = code_and_output(rand_kernel_tiled_1d, (x_small, 42)) + self.assertTrue( + torch.allclose(output, output3), + "Same seed should produce identical outputs", + ) + + # 
+        # Check that all values are in [0, 1) range
+        self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
+        self.assertTrue(torch.all(output < 1.0), "All values should be < 1")
+
+        self.assertExpectedJournal(code3)
+
+    def test_hl_rand_2d(self):
+        @helion.kernel
+        def rand_kernel_tiled_2d(x: torch.Tensor, seed: int) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            m, n = x.shape
+            for tile_m, tile_n in hl.tile([m, n]):
+                output[tile_m, tile_n] = hl.rand([tile_m, tile_n], seed=seed)
+            return output
+
+        x_small = torch.ones(128, 128, device=DEVICE)
+        _, output = code_and_output(rand_kernel_tiled_2d, (x_small, 42))
+        _, output2 = code_and_output(rand_kernel_tiled_2d, (x_small, 1337))
+
+        self.assertFalse(
+            torch.allclose(output, output2),
+            "Different seeds should produce different outputs",
+        )
+
+        code3, output3 = code_and_output(rand_kernel_tiled_2d, (x_small, 42))
+        self.assertTrue(
+            torch.allclose(output, output3),
+            "Same seed should produce identical outputs",
+        )
+
+        self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
+        self.assertTrue(torch.all(output < 1.0), "All values should be < 1")
+        self.assertExpectedJournal(code3)
+
+    def test_hl_rand_3d(self):
+        @helion.kernel
+        def rand_kernel_tiled_3d(x: torch.Tensor, seed: int) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            b, m, n = x.shape
+            for tile_b, tile_m, tile_n in hl.tile([b, m, n]):
+                output[tile_b, tile_m, tile_n] = hl.rand(
+                    [tile_b, tile_m, tile_n], seed=seed
+                )
+            return output
+
+        x_small = torch.ones(16, 32, 64, device=DEVICE)
+        _, output = code_and_output(rand_kernel_tiled_3d, (x_small, 42))
+        _, output2 = code_and_output(rand_kernel_tiled_3d, (x_small, 1337))
+
+        self.assertFalse(
+            torch.allclose(output, output2),
+            "Different seeds should produce different outputs",
+        )
+
+        code3, output3 = code_and_output(rand_kernel_tiled_3d, (x_small, 42))
+        self.assertTrue(
+            torch.allclose(output, output3),
+            "Same seed should produce identical outputs",
+        )
+
+        self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
+        self.assertTrue(torch.all(output < 1.0), "All values should be < 1")
+
+        # Check distribution properties
+        mean_val = output.mean().item()
+        self.assertTrue(
+            0.4 < mean_val < 0.6,
+            f"Mean {mean_val:.3f} should be around 0.5 for uniform distribution",
+        )
+        self.assertExpectedJournal(code3)
+
+    def test_hl_rand_block_size_determinism(self):
+        @helion.kernel
+        def rand_kernel_2d(x: torch.Tensor, seed: int) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            m, n = x.shape
+            for tile_m, tile_n in hl.tile([m, n]):
+                output[tile_m, tile_n] = hl.rand([tile_m, tile_n], seed=seed)
+            return output
+
+        x = torch.ones(128, 256, device=DEVICE)
+        seed = 42
+
+        _, output_32_32 = code_and_output(
+            rand_kernel_2d, (x, seed), block_sizes=[32, 32]
+        )
+        _, output_64_64 = code_and_output(
+            rand_kernel_2d, (x, seed), block_sizes=[64, 64]
+        )
+        _, output_128_128 = code_and_output(
+            rand_kernel_2d, (x, seed), block_sizes=[128, 128]
+        )
+        _, output_16_32 = code_and_output(
+            rand_kernel_2d, (x, seed), block_sizes=[16, 32]
+        )
+
+        torch.testing.assert_close(
+            output_32_32,
+            output_64_64,
+            msg="rand should be deterministic across different block sizes (32x32 vs 64x64)",
+        )
+        torch.testing.assert_close(
+            output_32_32,
+            output_128_128,
+            msg="rand should be deterministic across different block sizes (32x32 vs 128x128)",
+        )
+        torch.testing.assert_close(
+            output_32_32,
+            output_16_32,
+            msg="rand should be deterministic across different block sizes (32x32 vs 16x32)",
+        )
+
+        self.assertTrue(torch.all(output_32_32 >= 0.0))
+        self.assertTrue(torch.all(output_32_32 < 1.0))
+
+    def test_hl_rand_uniqueness_distribution(self):
+        @helion.kernel
+        def rand_kernel(x: torch.Tensor, seed: int) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            m, n = x.shape
+            for tile_m, tile_n in hl.tile([m, n]):
+                output[tile_m, tile_n] = hl.rand([tile_m, tile_n], seed=seed)
+            return output
+
+        x = torch.ones(256, 256, device=DEVICE)
+        seed = 1337
+
+        _, output = code_and_output(rand_kernel, (x, seed))
+
+        sorted_values = torch.sort(output.flatten()).values
+
+        unique_values = torch.unique(sorted_values)
+        total_values = output.numel()
+        uniqueness_ratio = len(unique_values) / total_values
+
+        self.assertGreater(
+            uniqueness_ratio,
+            0.99,
+            f"Expected >99% unique values, got {uniqueness_ratio:.4f}",
+        )
+
+        n_quartile = total_values // 4
+        q1_val = sorted_values[n_quartile].item()
+        q2_val = sorted_values[2 * n_quartile].item()
+        q3_val = sorted_values[3 * n_quartile].item()
+
+        self.assertTrue(
+            0.2 < q1_val < 0.3, f"First quartile {q1_val:.3f} should be around 0.25"
+        )
+        self.assertTrue(
+            0.45 < q2_val < 0.55, f"Median {q2_val:.3f} should be around 0.5"
+        )
+        self.assertTrue(
+            0.7 < q3_val < 0.8, f"Third quartile {q3_val:.3f} should be around 0.75"
+        )
+
+    def test_hl_rand_non_tiled_dimensions(self):
+        @helion.kernel
+        def rand_kernel_partial_tile(x: torch.Tensor, seed: int) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            m, n, k = x.shape
+            k = hl.specialize(k)
+            for tile_m, tile_n in hl.tile([m, n]):
+                output[tile_m, tile_n, :] = hl.rand([tile_m, tile_n, k], seed=seed)
+            return output
+
+        x = torch.ones(64, 64, 8, device=DEVICE)
+        seed = 1337
+
+        _, output = code_and_output(rand_kernel_partial_tile, (x, seed))
+
+        self.assertTrue(torch.all(output >= 0.0))
+        self.assertTrue(torch.all(output < 1.0))
+
+        code2, output2 = code_and_output(rand_kernel_partial_tile, (x, seed))
+        torch.testing.assert_close(output, output2, msg="it should be deterministic")
+
+        self.assertExpectedJournal(code2)
+
+    def test_hl_rand_mixed_argument_order(self):
+        @helion.kernel
+        def rand_kernel_normal_order(x: torch.Tensor, seed: int) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            m, n, k = x.shape
+            for tile_m, tile_n, tile_k in hl.tile([m, n, k]):
+                output[tile_m, tile_n, tile_k] = hl.rand(
+                    [tile_m, tile_n, tile_k], seed=seed
+                )
+            return output
+
+        @helion.kernel
+        def rand_kernel_mixed_order(x: torch.Tensor, seed: int) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            m, n, k = x.shape
+            for tile_k, tile_m, tile_n in hl.tile([k, m, n]):
+                output[tile_m, tile_n, tile_k] = hl.rand(
+                    [tile_m, tile_n, tile_k], seed=seed
+                )
+            return output
+
+        x = torch.ones(32, 64, 16, device=DEVICE)
+        seed = 1337
+
+        code1, output1 = code_and_output(rand_kernel_normal_order, (x, seed))
+        code2, output2 = code_and_output(rand_kernel_mixed_order, (x, seed))
+        self.assertExpectedJournal(code1)
+        self.assertExpectedJournal(code2)
+
+        torch.testing.assert_close(
+            output1,
+            output2,
+            msg="Mixed tile argument order should produce identical results",
+        )
+
+    def test_hl_rand_rolled_reductions(self):
+        @helion.kernel
+        def rand_kernel_with_reduction(x: torch.Tensor, seed: int) -> torch.Tensor:
+            m, n = x.shape
+            output = torch.zeros([m], device=x.device)
+            for tile_m in hl.tile(m):
+                tile_values = x[tile_m, :]
+                rand_values = hl.rand([tile_m], seed=seed)
+                mean_val = tile_values.mean(-1)
+                output[tile_m] = rand_values * mean_val
+            return output
+
+        x = torch.ones(64, 128, device=DEVICE)
+        seed = 42
+
+        code1, output_persistent = code_and_output(
+            rand_kernel_with_reduction,
+            (x, seed),
+            block_sizes=[32],
+            reduction_loops=[None],
+        )
+        code2, output_rolled = code_and_output(
+            rand_kernel_with_reduction,
+            (x, seed),
+            block_sizes=[32],
+            reduction_loops=[64],
+        )
+        self.assertExpectedJournal(code1)
+        self.assertExpectedJournal(code2)
+
+        torch.testing.assert_close(
+            output_persistent,
+            output_rolled,
+            msg="Persistent and rolled reductions should produce identical results",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
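Usage sketch (illustration only, not part of the diff): assuming the hl.rand API added above, a seeded dropout-style kernel could look roughly like the following. The kernel name and signature are hypothetical, and the use of torch.where with a scalar inside the device loop is assumed to be supported as in other Helion kernels.

    import torch
    import helion
    import helion.language as hl


    @helion.kernel
    def dropout_kernel(x: torch.Tensor, p: float, seed: int) -> torch.Tensor:
        # hl.rand is keyed only by the seed and each element's global index,
        # so the keep/drop pattern is reproducible across block-size choices.
        out = torch.zeros_like(x)
        (m,) = x.shape
        for tile_m in hl.tile(m):
            r = hl.rand([tile_m], seed=seed)
            keep = r >= p
            out[tile_m] = torch.where(keep, x[tile_m] / (1.0 - p), 0.0)
        return out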