diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py
index 914aafa25..8a37f177d 100644
--- a/helion/_compiler/device_function.py
+++ b/helion/_compiler/device_function.py
@@ -241,6 +241,29 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
         self.tile_strategy: TileStrategyDispatch = TileStrategyDispatch(self, config)
         self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config)
 
+        self.rng_seed_count: int = 0
+        # Name of the RNG seed buffer parameter in the kernel signature
+        self.rng_seed_buffer_param_name: str | None = None
+
+    def has_rng_ops(self) -> bool:
+        """Check if this kernel uses any RNG operations."""
+        return self.rng_seed_count > 0 and self.rng_seed_buffer_param_name is not None
+
+    def allocate_rng_seed(self) -> int:
+        """Allocate a new RNG seed index and ensure the buffer argument exists.
+
+        Returns:
+            The seed index for this RNG operation.
+        """
+        seed_index = self.rng_seed_count
+        self.rng_seed_count += 1
+
+        # Ensure the seed buffer parameter name exists
+        if self.rng_seed_buffer_param_name is None:
+            self.rng_seed_buffer_param_name = self.new_var("rng_seed_buffer")
+
+        return seed_index
+
     def block_size_var(self, block_id: int) -> str | None:
         return self.block_size_var_cache.get((block_id,))
 
@@ -487,15 +510,20 @@ def codegen_function_def(self) -> list[ast.stmt]:
             prefix.append(
                 statement_from_string("helion.runtime.set_triton_allocator()")
             )
+
+        args = [arg.arg_def_node() for arg in self.sorted_args()]
+        if self.has_rng_ops():
+            # Add the seed buffer as a pointer parameter to the kernel signature
+            assert self.rng_seed_buffer_param_name is not None
+            args.append(create_arg(self.rng_seed_buffer_param_name))
+
         return [
             *prefix,
             ast_rename(
                 create(
                     ast.FunctionDef,
                     name=self.name,
-                    args=create_arguments(
-                        [arg.arg_def_node() for arg in self.sorted_args()]
-                    ),
+                    args=create_arguments(args),
                     body=[*self.preamble, *self.body],
                     decorator_list=[expr_from_string("triton.jit")],
                     type_params=[],
@@ -507,6 +535,10 @@ def codegen_function_call(self) -> ast.AST:
         args = [arg.host_str() for arg in self.sorted_args()]
+        if self.has_rng_ops():
+            # Pass the host-side seed buffer variable to the kernel
+            args.append("_rng_seed_buffer")
+
         # Workaround for triton bug: warp_specialize requires at least 4 warps
         # See: https://github.com/triton-lang/triton/issues/7354
         num_warps = self.config.num_warps
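Note on the bookkeeping above: seed allocation is allocate-on-first-use. Every lowered RNG op reserves the next consecutive index, and the first allocation also names the extra kernel parameter, which later ops reuse. A toy mimic of the pattern (illustrative only, not the Helion API):

    class SeedAllocator:
        """Minimal sketch of DeviceFunction's RNG seed bookkeeping."""

        def __init__(self) -> None:
            self.rng_seed_count = 0
            self.rng_seed_buffer_param_name: str | None = None

        def allocate_rng_seed(self) -> int:
            seed_index = self.rng_seed_count
            self.rng_seed_count += 1
            if self.rng_seed_buffer_param_name is None:
                # The first RNG op names the kernel parameter; later ops reuse it.
                self.rng_seed_buffer_param_name = "rng_seed_buffer"
            return seed_index

    fn = SeedAllocator()
    assert [fn.allocate_rng_seed() for _ in range(3)] == [0, 1, 2]
    assert fn.rng_seed_buffer_param_name == "rng_seed_buffer"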
diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py
index 2963a97ae..e4caf8e02 100644
--- a/helion/_compiler/generate_ast.py
+++ b/helion/_compiler/generate_ast.py
@@ -74,6 +74,18 @@ def add_statement(self, stmt: ast.AST | str | None) -> None:
             stmt = statement_from_string(stmt)
         self.statements_stack[-1].append(stmt)
 
+    def get_rng_seed_buffer_statements(self) -> list[ast.AST]:
+        import_stmt = statement_from_string(
+            "from torch._inductor import inductor_prims"
+        )
+
+        # Create the host-side seed buffer with the required number of seeds
+        seed_buffer_stmt = statement_from_string(
+            f"_rng_seed_buffer = inductor_prims.seeds({self.device_function.rng_seed_count}, torch.device('cuda'))"
+        )
+
+        return [import_stmt, seed_buffer_stmt]
+
     def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name:
         if isinstance(expr, ast.Name):
             return expr
@@ -395,7 +407,16 @@ def generate_ast(
             codegen.add_statement(codegen.visit(stmt))
         kernel_def = codegen.device_function.codegen_function_def()
         codegen.host_dead_code_elimination()
-        host_def = func.codegen_function_def(codegen.host_statements)
+
+        # Inject RNG seed buffer creation if needed
+        rng_statements = (
+            codegen.get_rng_seed_buffer_statements()
+            if codegen.device_function.has_rng_ops()
+            else []
+        )
+        final_host_statements = rng_statements + codegen.host_statements
+
+        host_def = func.codegen_function_def(final_host_statements)
         call_def = []
         main_def = []
 
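The injected prologue amounts to the two host statements below (shown for the seven-op kernel in test_rng.expected further down). `inductor_prims.seeds(n, device)` effectively draws `n` fresh int64 seeds from the default PyTorch generator, which is why `torch.manual_seed` makes runs reproducible and why back-to-back calls differ as the generator state advances. Note the buffer is currently allocated on a hardcoded `'cuda'` device. Sketch of the generated statements:

    import torch
    from torch._inductor import inductor_prims

    # One int64 seed per RNG op in the kernel, drawn from the default
    # generator, so torch.manual_seed(...) controls reproducibility.
    _rng_seed_buffer = inductor_prims.seeds(7, torch.device('cuda'))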
diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py
index e4c8c32c7..29cb21efa 100644
--- a/helion/_compiler/inductor_lowering.py
+++ b/helion/_compiler/inductor_lowering.py
@@ -50,6 +50,7 @@
 from .ast_extension import expr_from_string
 from .ast_extension import statement_from_string
 from .compile_environment import CompileEnvironment
+from .device_function import SymbolArgument
 from .device_function import VarInfo
 from .device_function import contains_only_block_size_symbols
 from .node_masking import apply_masking
@@ -1307,3 +1308,98 @@ def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         step=ctx.to_ast(step),
         length=ctx.to_ast(length_arg),
     )
+
+
+def _codegen_rng_op(
+    ctx: GraphInterpreter,
+    node: torch.fx.Node,
+    rng_function: str,
+) -> object:
+    """Common codegen implementation for all RNG operations.
+
+    Args:
+        ctx: The graph interpreter context
+        node: The FX node for this operation
+        rng_function: Either "rand" or "randn"
+    """
+    assert rng_function in ("rand", "randn")
+
+    # Get a unique seed index for this RNG operation
+    device_fn = ctx.cg.device_function
+    seed_index = device_fn.allocate_rng_seed()
+
+    # Get dimensionality and dtype
+    assert hasattr(node, "meta") and "val" in node.meta
+    ndim = node.meta["val"].ndim
+    dtype = node.kwargs.get("dtype", None)
+
+    # Get the dimension variable names from the device function's symbol
+    # arguments; they should be the last ndim symbol arguments
+    symbol_args = [
+        arg for arg in device_fn.arguments if isinstance(arg, SymbolArgument)
+    ]
+    assert len(symbol_args) >= ndim, "Not enough symbol arguments for dimensions"
+    dim_names = [arg.name for arg in symbol_args[-ndim:]]
+
+    offset_parts = []
+
+    for i in range(ndim):
+        # Create the index variable with proper broadcasting
+        index_expr = f"indices_{i}"
+
+        # Add broadcasting slices for this dimension
+        # For 1D tensors, this will just be indices_0 with no slicing
+        slice_parts = []
+        for j in range(ndim):
+            if j == i:
+                slice_parts.append(":")
+            else:
+                slice_parts.append("None")
+
+        # Create the broadcasted index expression
+        if ndim == 1:
+            # For 1D, no broadcasting needed
+            broadcasted_index = index_expr
+        else:
+            broadcasted_index = f"{index_expr}[{', '.join(slice_parts)}]"
+
+        # Calculate the stride (product of the dimensions after this one)
+        if i < ndim - 1:
+            stride_expr = " * ".join(dim_names[i + 1 :])
+            offset_parts.append(f"{broadcasted_index} * {stride_expr}")
+        else:
+            # The last dimension has no stride multiplication
+            offset_parts.append(broadcasted_index)
+
+    offset_expr = expr_from_string(" + ".join(offset_parts))
+
+    # Load the seed from the buffer using the kernel parameter name
+    assert device_fn.rng_seed_buffer_param_name is not None
+    seed_expr = expr_from_string(
+        "tl.load({buffer} + {index})",
+        buffer=expr_from_string(device_fn.rng_seed_buffer_param_name),
+        index=create(ast.Constant, value=seed_index),
+    )
+
+    # Generate the RNG call
+    # Note: tl.rand() and tl.randn() always return float32
+    rng_expr = expr_from_string(
+        f"tl.{rng_function}({{seed}}, {{offset}})", seed=seed_expr, offset=offset_expr
+    )
+
+    # Cast to the target dtype only if explicitly specified
+    if dtype is not None:
+        assert isinstance(dtype, torch.dtype)
+        rng_expr = expr_from_string(f"{{val}}.to({triton_type(dtype)})", val=rng_expr)
+
+    return rng_expr
+
+
+@register_lowering(torch.ops.aten.rand.default)  # pyright: ignore[reportAttributeAccessIssue]
+def codegen_rand(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
+    return _codegen_rng_op(ctx, node, "rand")
+
+
+@register_lowering(torch.ops.aten.randn.default)  # pyright: ignore[reportAttributeAccessIssue]
+def codegen_randn(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
+    return _codegen_rng_op(ctx, node, "randn")
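The offset built above linearizes each element's coordinates in row-major order (offset = sum over i of indices_i * prod(dims[i+1:])), so every element feeds a distinct counter of the counter-based generator for a given seed. A self-contained sketch of the same string construction (illustrative; the `indices_i` names match the generated kernel):

    def rng_offset(dim_names: list[str]) -> str:
        """Row-major linear offset expression, mirroring _codegen_rng_op."""
        ndim = len(dim_names)
        parts = []
        for i in range(ndim):
            sl = ", ".join(":" if j == i else "None" for j in range(ndim))
            idx = f"indices_{i}" if ndim == 1 else f"indices_{i}[{sl}]"
            stride = " * ".join(dim_names[i + 1 :])
            parts.append(f"{idx} * {stride}" if stride else idx)
        return " + ".join(parts)

    # Matches the 2D expression in the expected journal below:
    assert rng_offset(["m", "n"]) == "indices_0[:, None] * n + indices_1[None, :]"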
diff --git a/helion/_testing.py b/helion/_testing.py
index 57a574825..d9d43741a 100644
--- a/helion/_testing.py
+++ b/helion/_testing.py
@@ -124,6 +124,13 @@ class RefEagerTestBase:
     _original_skip_test_func = None
     # Class-level tracking for pytest.raises patching
     _original_pytest_raises = None
+    # Class-level tracking for assertTrue/assertFalse/assertGreater patching
+    _assert_true_count = 0
+    _original_assert_true_func = None
+    _assert_false_count = 0
+    _original_assert_false_func = None
+    _assert_greater_count = 0
+    _original_assert_greater_func = None
 
     def setUp(self) -> None:
         """Common setup for all ref eager tests."""
@@ -142,6 +149,10 @@ def setUp(self) -> None:
         RefEagerTestBase._assert_raises_count = 0
         # Reset skipTest counter for this test
         RefEagerTestBase._skip_test_count = 0
+        # Reset assertTrue/assertFalse/assertGreater counters
+        RefEagerTestBase._assert_true_count = 0
+        RefEagerTestBase._assert_false_count = 0
+        RefEagerTestBase._assert_greater_count = 0
 
         # Patch torch.testing.assert_close to count calls
         if RefEagerTestBase._original_assert_close_func is None:
@@ -189,6 +200,36 @@ def counting_pytest_raises(*args: object, **kwargs: object) -> object:
 
             pytest.raises = counting_pytest_raises  # type: ignore[assignment]
 
+        # Patch self.assertTrue to count calls
+        if RefEagerTestBase._original_assert_true_func is None:
+            RefEagerTestBase._original_assert_true_func = self.assertTrue
+
+            def counting_assert_true(*args: object, **kwargs: object) -> None:
+                RefEagerTestBase._assert_true_count += 1
+                return RefEagerTestBase._original_assert_true_func(*args, **kwargs)  # type: ignore[misc]
+
+            self.assertTrue = counting_assert_true  # type: ignore[assignment]
+
+        # Patch self.assertFalse to count calls
+        if RefEagerTestBase._original_assert_false_func is None:
+            RefEagerTestBase._original_assert_false_func = self.assertFalse
+
+            def counting_assert_false(*args: object, **kwargs: object) -> None:
+                RefEagerTestBase._assert_false_count += 1
+                return RefEagerTestBase._original_assert_false_func(*args, **kwargs)  # type: ignore[misc]
+
+            self.assertFalse = counting_assert_false  # type: ignore[assignment]
+
+        # Patch self.assertGreater to count calls
+        if RefEagerTestBase._original_assert_greater_func is None:
+            RefEagerTestBase._original_assert_greater_func = self.assertGreater
+
+            def counting_assert_greater(*args: object, **kwargs: object) -> None:
+                RefEagerTestBase._assert_greater_count += 1
+                return RefEagerTestBase._original_assert_greater_func(*args, **kwargs)  # type: ignore[misc]
+
+            self.assertGreater = counting_assert_greater  # type: ignore[assignment]
+
     def tearDown(self) -> None:
         """Common teardown with assertion counting check."""
         # If not in ref eager mode, skip the teardown logic
@@ -215,17 +256,27 @@ def tearDown(self) -> None:
             )
 
             if not is_skipped:
-                # Check that either assert_close, assertRaises, or skipTest was called
+                # Check that assert_close, assertRaises, skipTest, assertTrue, assertFalse, or assertGreater was called
                 total_assertions = (
                     RefEagerTestBase._assert_close_count
                     + RefEagerTestBase._assert_raises_count
                     + RefEagerTestBase._skip_test_count
+                    + RefEagerTestBase._assert_true_count
+                    + RefEagerTestBase._assert_false_count
+                    + RefEagerTestBase._assert_greater_count
                 )
-                self.assertGreater(  # type: ignore[attr-defined]
-                    total_assertions,
-                    0,
-                    f"Test {self._testMethodName} did not call torch.testing.assert_close, assertRaises, or skipTest",  # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
-                )
+                # Use the original assertGreater so this check is not itself counted
+                if RefEagerTestBase._original_assert_greater_func is not None:
+                    RefEagerTestBase._original_assert_greater_func(  # type: ignore[misc]
+                        total_assertions,
+                        0,
+                        f"Test {self._testMethodName} did not call torch.testing.assert_close, assertRaises, skipTest, assertTrue, assertFalse, or assertGreater",  # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
+                    )
+                else:
+                    # Fallback if the original was never captured
+                    assert total_assertions > 0, (
+                        f"Test {self._testMethodName} did not call any assertion methods"  # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
+                    )
         finally:
             # Restore the original assert_close function
             if RefEagerTestBase._original_assert_close_func is not None:
@@ -245,6 +296,21 @@ def tearDown(self) -> None:
             if RefEagerTestBase._original_pytest_raises is not None:  # pyright: ignore[reportAttributeAccessIssue]
                 pytest.raises = RefEagerTestBase._original_pytest_raises  # pyright: ignore[reportAttributeAccessIssue]
 
+            # Restore the original assertTrue and clear the cache so the next setUp re-patches
+            if RefEagerTestBase._original_assert_true_func is not None:
+                self.assertTrue = RefEagerTestBase._original_assert_true_func
+                RefEagerTestBase._original_assert_true_func = None
+
+            # Restore the original assertFalse
+            if RefEagerTestBase._original_assert_false_func is not None:
+                self.assertFalse = RefEagerTestBase._original_assert_false_func
+                RefEagerTestBase._original_assert_false_func = None
+
+            # Restore the original assertGreater
+            if RefEagerTestBase._original_assert_greater_func is not None:
+                self.assertGreater = RefEagerTestBase._original_assert_greater_func
+                RefEagerTestBase._original_assert_greater_func = None
+
             super().tearDown()  # type: ignore[misc]
 
 # NOTE: We no-op these methods because they commonly check behaviors that are not relevant in ref eager mode.
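The test-harness change extends an existing wrap-and-count pattern to three unittest assertion methods: stash the original, install a counting wrapper as an instance attribute for the duration of the test, and restore it in tearDown. Stripped to its core, the pattern looks like this (minimal sketch, not the Helion harness):

    import unittest

    class CountingExample(unittest.TestCase):
        calls = 0

        def setUp(self) -> None:
            original = self.assertTrue

            def counting(*args: object, **kwargs: object) -> None:
                type(self).calls += 1
                return original(*args, **kwargs)

            # The instance attribute shadows the bound method for this test only.
            self.assertTrue = counting  # type: ignore[method-assign]

        def test_counts(self) -> None:
            self.assertTrue(True)
            assert type(self).calls == 1

    if __name__ == "__main__":
        unittest.main()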
diff --git a/test/test_rng.expected b/test/test_rng.expected
new file mode 100644
index 000000000..da0721a6a
--- /dev/null
+++ b/test/test_rng.expected
@@ -0,0 +1,53 @@
+This file is automatically generated by assertExpectedJournal calls in test_rng.py.
+Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
+
+--- assertExpectedJournal(TestRNG.test_multiple_rng_ops)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal_stride_0, normal_stride_1, rand1_stride_0, rand1_stride_1, rand2_stride_0, rand2_stride_1, randn_a_stride_0, randn_a_stride_1, randn_b_stride_0, randn_b_stride_1, randn_c_stride_0, randn_c_stride_1, uniform_stride_0, uniform_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, rng_seed_buffer):
+    num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < n
+    rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
+    tl.store(rand1 + (indices_0[:, None] * rand1_stride_0 + indices_1[None, :] * rand1_stride_1), rand, mask_0[:, None] & mask_1[None, :])
+    rand_1 = tl.rand(tl.load(rng_seed_buffer + 1), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
+    tl.store(rand2 + (indices_0[:, None] * rand2_stride_0 + indices_1[None, :] * rand2_stride_1), rand_1, mask_0[:, None] & mask_1[None, :])
+    rand_2 = tl.rand(tl.load(rng_seed_buffer + 2), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
+    tl.store(uniform + (indices_0[:, None] * uniform_stride_0 + indices_1[None, :] * uniform_stride_1), rand_2, mask_0[:, None] & mask_1[None, :])
+    randn = tl.randn(tl.load(rng_seed_buffer + 3), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
+    tl.store(normal + (indices_0[:, None] * normal_stride_0 + indices_1[None, :] * normal_stride_1), randn, mask_0[:, None] & mask_1[None, :])
+    randn_1 = tl.randn(tl.load(rng_seed_buffer + 4), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
+    tl.store(randn_a + (indices_0[:, None] * randn_a_stride_0 + indices_1[None, :] * randn_a_stride_1), randn_1, mask_0[:, None] & mask_1[None, :])
+    randn_2 = tl.randn(tl.load(rng_seed_buffer + 5), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
+    tl.store(randn_b + (indices_0[:, None] * randn_b_stride_0 + indices_1[None, :] * randn_b_stride_1), randn_2, mask_0[:, None] & mask_1[None, :])
+    randn_3 = tl.randn(tl.load(rng_seed_buffer + 6), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
+    tl.store(randn_c + (indices_0[:, None] * randn_c_stride_0 + indices_1[None, :] * randn_c_stride_1), randn_3, mask_0[:, None] & mask_1[None, :])
+
+def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+    from torch._inductor import inductor_prims
+    _rng_seed_buffer = inductor_prims.seeds(7, torch.device('cuda'))
+    rand1 = torch.zeros_like(x)
+    rand2 = torch.zeros_like(x)
+    uniform = torch.zeros_like(x)
+    normal = torch.zeros_like(x)
+    randn_a = torch.zeros_like(x)
+    randn_b = torch.zeros_like(x)
+    randn_c = torch.zeros_like(x)
+    m, n = x.shape
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=3)
+    randn_sum = randn_a + randn_b + randn_c
+    return (rand1, rand2, uniform, normal, randn_sum)
diff --git a/test/test_rng.py b/test/test_rng.py
new file mode 100644
index 000000000..585af03e2
--- /dev/null
+++ b/test/test_rng.py
@@ -0,0 +1,360 @@
+from __future__ import annotations
+
+import unittest
+
+import torch
+
+import helion
+from helion._testing import DEVICE
+from helion._testing import RefEagerTestBase
+from helion._testing import TestCase
+from helion._testing import code_and_output
+import helion.language as hl
+
+
+class TestRNG(RefEagerTestBase, TestCase):
+    def test_rand(self):
+        """Test RNG seeding behavior, reproducibility, output range, and distribution."""
+
+        @helion.kernel
+        def rand_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            m, n = x.shape
+            for tile_m, tile_n in hl.tile([m, n]):
+                output[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
+            return output
+
+        # Test with different tensor sizes for different aspects
+        x_small = torch.ones(32, 32, device=DEVICE)  # For distribution tests
+        x_large = torch.ones(64, 64, device=DEVICE)  # For seeding tests
+
+        # Test 1: Different seeds produce different outputs
+        torch.manual_seed(42)
+        _code1, output1 = code_and_output(rand_kernel_tiled_2d, (x_large,))
+
+        torch.manual_seed(123)
+        _code2, output2 = code_and_output(rand_kernel_tiled_2d, (x_large,))
+
+        self.assertFalse(
+            torch.allclose(output1, output2),
+            "Different seeds should produce different outputs",
+        )
+
+        # Test 2: Same seed produces identical outputs (reproducibility)
+        torch.manual_seed(42)
+        _code3, output3 = code_and_output(rand_kernel_tiled_2d, (x_large,))
+
+        torch.testing.assert_close(
+            output1, output3, msg="Same seed should produce identical outputs"
+        )
+
+        # Test 3: RNG state advances between calls
+        torch.manual_seed(42)
+        _code4, output4 = code_and_output(rand_kernel_tiled_2d, (x_large,))
+        # No manual_seed here - the RNG state should advance
+        _code5, output5 = code_and_output(rand_kernel_tiled_2d, (x_large,))
+
+        self.assertFalse(
+            torch.allclose(output4, output5),
+            "Sequential calls should produce different outputs (RNG state advanced)",
+        )
+
+        # Test 4: Output range and distribution properties
+        torch.manual_seed(42)
+        _code6, output6 = code_and_output(rand_kernel_tiled_2d, (x_small,))
+
+        # All values should be in the [0, 1) range
+        self.assertTrue(torch.all(output6 >= 0.0), "All values should be >= 0")
+        self.assertTrue(torch.all(output6 < 1.0), "All values should be < 1")
+
+        # Check distribution properties
+        mean_val = output6.mean().item()
+        self.assertTrue(
+            0.4 < mean_val < 0.6,
+            f"Mean {mean_val:.3f} should be around 0.5 for a uniform distribution",
+        )
+
+        # Check the spread of values
+        min_val = output6.min().item()
+        max_val = output6.max().item()
+        self.assertTrue(
+            min_val < 0.2, f"Min value {min_val:.3f} should be < 0.2 for good spread"
+        )
+        self.assertTrue(
+            max_val > 0.8, f"Max value {max_val:.3f} should be > 0.8 for good spread"
+        )
+
+    def test_rand_3d_tensor(self):
+        """Test 3D RNG with tiled operations."""
+
+        @helion.kernel
+        def rand_kernel_3d(x: torch.Tensor) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            b, m, n = x.shape
+            for tile_b, tile_m, tile_n in hl.tile([b, m, n]):
+                output[tile_b, tile_m, tile_n] = torch.rand_like(
+                    x[tile_b, tile_m, tile_n]
+                )
+            return output
+
+        x = torch.ones(16, 32, 64, device=DEVICE)  # 3D tensor
+        torch.manual_seed(77)
+        _code, output = code_and_output(rand_kernel_3d, (x,))
+
+        # All values should be in the [0, 1) range
+        self.assertTrue(torch.all(output >= 0.0))
+        self.assertTrue(torch.all(output < 1.0))
+
+        # Check uniqueness - 3D should generate different values for each element
+        unique_values = output.unique().numel()
+        total_values = output.numel()
+
+        # With a good RNG, we should have mostly unique values
+        uniqueness_ratio = unique_values / total_values
+        print(
+            f"3D Unique values: {unique_values}, Total: {total_values}, Percentage: {uniqueness_ratio * 100:.2f}%"
+        )
+
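+        # Even a perfect generator repeats values here: float32 uniforms in
+        # [0, 1) take roughly 2**24 distinct values, so a few collisions among
+        # 32768 samples are expected; 95% uniqueness is a deliberately loose bound.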
+        # Expect at least 95% unique values for good 3D RNG
+        self.assertGreater(uniqueness_ratio, 0.95)
+
+        # Check distribution across dimensions
+        # Mean should be around 0.5 for each 2D slice
+        for b_idx in range(x.shape[0]):
+            slice_mean = output[b_idx].mean().item()
+            self.assertTrue(
+                0.35 < slice_mean < 0.65,
+                f"Slice {b_idx} mean {slice_mean} is not well distributed",
+            )
+
+        # Verify different seeds produce different results
+        torch.manual_seed(88)
+        _code2, output2 = code_and_output(rand_kernel_3d, (x,))
+        self.assertFalse(torch.allclose(output, output2))
+
+    def test_multiple_rng_ops(self):
+        """Test multiple RNG operations: independence, reproducibility, mixed rand/randn."""
+
+        @helion.kernel
+        def multiple_rng_ops_kernel(
+            x: torch.Tensor,
+        ) -> tuple[
+            torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
+        ]:
+            # Two independent rand operations
+            rand1 = torch.zeros_like(x)
+            rand2 = torch.zeros_like(x)
+
+            # Mixed rand and randn
+            uniform = torch.zeros_like(x)
+            normal = torch.zeros_like(x)
+
+            # Multiple randn for the sum
+            randn_a = torch.zeros_like(x)
+            randn_b = torch.zeros_like(x)
+            randn_c = torch.zeros_like(x)
+
+            m, n = x.shape
+            for tile_m, tile_n in hl.tile([m, n]):
+                # Two independent rand operations
+                rand1[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
+                rand2[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
+
+                # Mixed rand and randn
+                uniform[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
+                normal[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
+
+                # Multiple randn
+                randn_a[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
+                randn_b[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
+                randn_c[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
+
+            # Combine the three randn outside the loop
+            randn_sum = randn_a + randn_b + randn_c
+
+            return rand1, rand2, uniform, normal, randn_sum
+
+        x = torch.ones(64, 64, device=DEVICE)
+
+        # Test 1: Independence and distribution properties
+        torch.manual_seed(42)
+        _code1, (rand1, rand2, uniform, normal, randn_sum) = code_and_output(
+            multiple_rng_ops_kernel, (x,)
+        )
+
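+        # Each rand/randn call above was lowered with its own seed slot
+        # (rng_seed_buffer + 0..6 in the journal), so the streams are independent.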
+        # Check two independent rand operations
+        self.assertTrue(
+            torch.all(rand1 >= 0.0) and torch.all(rand1 < 1.0),
+            "First rand output should be in [0, 1)",
+        )
+        self.assertTrue(
+            torch.all(rand2 >= 0.0) and torch.all(rand2 < 1.0),
+            "Second rand output should be in [0, 1)",
+        )
+        self.assertFalse(
+            torch.allclose(rand1, rand2),
+            "Two independent RNG ops should produce different outputs",
+        )
+        self.assertTrue(
+            0.45 < rand1.mean().item() < 0.55,
+            f"First rand mean {rand1.mean().item():.3f} should be ~0.5",
+        )
+        self.assertTrue(
+            0.45 < rand2.mean().item() < 0.55,
+            f"Second rand mean {rand2.mean().item():.3f} should be ~0.5",
+        )
+
+        # Check mixed rand and randn
+        self.assertTrue(
+            torch.all(uniform >= 0.0) and torch.all(uniform < 1.0),
+            "Uniform (rand) values should be in [0, 1)",
+        )
+        self.assertTrue(
+            0.4 < uniform.mean().item() < 0.6,
+            f"Uniform mean {uniform.mean().item():.3f} should be ~0.5",
+        )
+        self.assertTrue(
+            -0.2 < normal.mean().item() < 0.2,
+            f"Normal mean {normal.mean().item():.3f} should be ~0",
+        )
+        self.assertTrue(
+            0.9 < normal.std().item() < 1.1,
+            f"Normal std {normal.std().item():.3f} should be ~1",
+        )
+        self.assertTrue(
+            torch.any(normal < 0.0), "Normal distribution should have negative values"
+        )
+        self.assertFalse(
+            torch.allclose(uniform, normal),
+            "Uniform and normal distributions should be different",
+        )
+
+        # Check the sum of multiple randn
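+        # Variances of independent normals add: Var(a + b + c) = 3, so the
+        # sum's std should be sqrt(3), about 1.732.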
+        expected_std = 3**0.5
+        mean = randn_sum.mean().item()
+        std = randn_sum.std().item()
+        self.assertTrue(-0.2 < mean < 0.2, f"Combined mean {mean:.3f} should be ~0")
+        self.assertTrue(
+            expected_std * 0.9 < std < expected_std * 1.1,
+            f"Combined std {std:.3f} should be ~{expected_std:.3f}",
+        )
+
+        # Test 2: Reproducibility with the same seed
+        torch.manual_seed(42)
+        _code2, outputs_a = code_and_output(multiple_rng_ops_kernel, (x,))
+
+        torch.manual_seed(42)
+        _code3, outputs_b = code_and_output(multiple_rng_ops_kernel, (x,))
+
+        # All outputs should be identical with the same seed
+        for i, (a, b) in enumerate(zip(outputs_a, outputs_b, strict=False)):
+            torch.testing.assert_close(
+                a, b, msg=f"Output {i} should be identical with same seed"
+            )
+
+        # Verify the generated code against the expected journal
+        self.assertExpectedJournal(_code1)
+
+    def test_randn_different_seeds_tiled(self):
+        """Test that different torch.manual_seed values produce different outputs for randn."""
+
+        @helion.kernel
+        def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            m, n = x.shape
+            for tile_m, tile_n in hl.tile([m, n]):
+                output[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
+            return output
+
+        x = torch.ones(64, 64, device=DEVICE)
+
+        torch.manual_seed(42)
+        _code1, output1 = code_and_output(randn_kernel_tiled_2d, (x,))
+
+        torch.manual_seed(123)
+        _code2, output2 = code_and_output(randn_kernel_tiled_2d, (x,))
+
+        # Different seeds should produce different outputs
+        self.assertFalse(torch.allclose(output1, output2))
+
+    def test_randn_normal_distribution(self):
+        """Test that torch.randn_like produces a normal distribution (mean≈0, std≈1)."""
+
+        @helion.kernel
+        def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            m, n = x.shape
+            for tile_m, tile_n in hl.tile([m, n]):
+                output[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
+            return output
+
+        x = torch.ones(128, 128, device=DEVICE)  # 16384 samples for better statistics
+        torch.manual_seed(42)
+        _code, output = code_and_output(randn_kernel_tiled_2d, (x,))
+
+        # Check the mean is close to 0
+        mean = output.mean().item()
+        self.assertTrue(-0.1 < mean < 0.1, f"Mean {mean} is not close to 0")
+
+        # Check the std is close to 1
+        std = output.std().item()
+        self.assertTrue(0.95 < std < 1.05, f"Std {std} is not close to 1")
+
+        # Check we have values outside [-1, 1] (characteristic of a normal distribution)
+        self.assertTrue(torch.any(output < -1.0))
+        self.assertTrue(torch.any(output > 1.0))
+
+        # Roughly 68% of samples should fall within one standard deviation
+        within_1_std = (
+            torch.logical_and(output > -1.0, output < 1.0).float().mean().item()
+        )
+        self.assertTrue(
+            0.63 < within_1_std < 0.73, f"Values within 1 std: {within_1_std}"
+        )
+
+    def test_randn_3d_tensor(self):
+        """Test 3D randn with tiled operations."""
+
+        @helion.kernel
+        def randn_kernel_3d(x: torch.Tensor) -> torch.Tensor:
+            output = torch.zeros_like(x)
+            b, m, n = x.shape
+            for tile_b, tile_m, tile_n in hl.tile([b, m, n]):
+                output[tile_b, tile_m, tile_n] = torch.randn_like(
+                    x[tile_b, tile_m, tile_n]
+                )
+            return output
+
+        x = torch.ones(8, 32, 64, device=DEVICE)  # 3D tensor
+        torch.manual_seed(77)
+        _code, output = code_and_output(randn_kernel_3d, (x,))
+
+        # Check the overall distribution
+        mean = output.mean().item()
+        std = output.std().item()
+        self.assertTrue(-0.1 < mean < 0.1, f"3D mean {mean} not close to 0")
+        self.assertTrue(0.95 < std < 1.05, f"3D std {std} not close to 1")
+
+        # Check the distribution across dimensions
+        for b_idx in range(x.shape[0]):
+            slice_mean = output[b_idx].mean().item()
+            slice_std = output[b_idx].std().item()
+            self.assertTrue(
+                -0.3 < slice_mean < 0.3,
+                f"Slice {b_idx} mean {slice_mean} is not well distributed",
+            )
+            self.assertTrue(
+                0.85 < slice_std < 1.15,
+                f"Slice {b_idx} std {slice_std} is not well distributed",
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()