From 4daaeea65d6d141e6f006d2ec57c550a071ab68b Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 29 Aug 2025 20:01:45 -0700 Subject: [PATCH 1/5] torch.rand_like and torch.randn_like support --- helion/_compiler/device_function.py | 38 ++- helion/_compiler/generate_ast.py | 23 +- helion/_compiler/inductor_lowering.py | 104 ++++++++ helion/_testing.py | 75 +++++- test/test_rng.expected | 53 ++++ test/test_rng.py | 353 ++++++++++++++++++++++++++ 6 files changed, 636 insertions(+), 10 deletions(-) create mode 100644 test/test_rng.expected create mode 100644 test/test_rng.py diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 914aafa25..2aeac256b 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -241,6 +241,29 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None: self.tile_strategy: TileStrategyDispatch = TileStrategyDispatch(self, config) self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config) + self.rng_seed_count = 0 + # Name of the RNG seed buffer parameter in kernel signature + self.rng_seed_buffer_param_name = None + + def has_rng_ops(self) -> bool: + """Check if this kernel uses any RNG operations.""" + return self.rng_seed_count > 0 and self.rng_seed_buffer_param_name is not None + + def allocate_rng_seed(self) -> int: + """Allocate a new RNG seed index and ensure buffer argument exists. + + Returns: + The seed index for this RNG operation. + """ + seed_index = self.rng_seed_count + self.rng_seed_count += 1 + + # Ensure seed buffer parameter name exists + if self.rng_seed_buffer_param_name is None: + self.rng_seed_buffer_param_name = self.new_var("rng_seed_buffer") + + return seed_index + def block_size_var(self, block_id: int) -> str | None: return self.block_size_var_cache.get((block_id,)) @@ -487,15 +510,20 @@ def codegen_function_def(self) -> list[ast.stmt]: prefix.append( statement_from_string("helion.runtime.set_triton_allocator()") ) + + args = [arg.arg_def_node() for arg in self.sorted_args()] + if self.has_rng_ops(): + # Add the seed buffer as a pointer parameter to kernel signature + assert self.rng_seed_buffer_param_name is not None + args.append(create_arg(self.rng_seed_buffer_param_name)) + return [ *prefix, ast_rename( create( ast.FunctionDef, name=self.name, - args=create_arguments( - [arg.arg_def_node() for arg in self.sorted_args()] - ), + args=create_arguments(args), body=[*self.preamble, *self.body], decorator_list=[expr_from_string("triton.jit")], type_params=[], @@ -507,6 +535,10 @@ def codegen_function_def(self) -> list[ast.stmt]: def codegen_function_call(self) -> ast.AST: args = [arg.host_str() for arg in self.sorted_args()] + if self.has_rng_ops(): + # Pass the host-side seed buffer variable to the kernel + args.append("_rng_seed_buffer_host") + # Workaround for triton bug: warp_specialize requires at least 4 warps # See: https://github.com/triton-lang/triton/issues/7354 num_warps = self.config.num_warps diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index 2963a97ae..749115591 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -74,6 +74,18 @@ def add_statement(self, stmt: ast.AST | str | None) -> None: stmt = statement_from_string(stmt) self.statements_stack[-1].append(stmt) + def get_rng_seed_buffer_statements(self) -> list[ast.AST]: + import_stmt = statement_from_string( + "from torch._inductor import inductor_prims" + ) + + # Create host-side seed buffer with 
the required number of seeds + seed_buffer_stmt = statement_from_string( + f"_rng_seed_buffer_host = inductor_prims.seeds({self.device_function.rng_seed_count}, torch.device('cuda'))" + ) + + return [import_stmt, seed_buffer_stmt] + def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name: if isinstance(expr, ast.Name): return expr @@ -395,7 +407,16 @@ def generate_ast( codegen.add_statement(codegen.visit(stmt)) kernel_def = codegen.device_function.codegen_function_def() codegen.host_dead_code_elimination() - host_def = func.codegen_function_def(codegen.host_statements) + + # Inject RNG seed buffer creation if needed + rng_statements = ( + codegen.get_rng_seed_buffer_statements() + if codegen.device_function.has_rng_ops() + else [] + ) + final_host_statements = rng_statements + codegen.host_statements + + host_def = func.codegen_function_def(final_host_statements) call_def = [] main_def = [] diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index e4c8c32c7..920742a60 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -1307,3 +1307,107 @@ def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object: step=ctx.to_ast(step), length=ctx.to_ast(length_arg), ) + + +def _codegen_rng_op( + ctx: GraphInterpreter, + node: torch.fx.Node, + rng_function: str, +) -> object: + """Common codegen implementation for all RNG operations. + + Args: + ctx: The graph interpreter context + node: The FX node for this operation + rng_function: Either "rand" or "randn" + """ + assert rng_function in ["rand", "randn"] + + # Get unique seed index for this RNG operation + device_fn = ctx.cg.device_function + seed_index = device_fn.allocate_rng_seed() + + # Get dimensionality and dtype + assert hasattr(node, "meta") and "val" in node.meta + ndim = node.meta["val"].ndim + dtype = node.kwargs.get("dtype", None) + + # Get the dimension variable names from the device function's symbol arguments + device_fn = ctx.cg.device_function + symbol_args = [ + arg + for arg in device_fn.arguments + if hasattr(arg, "__class__") and arg.__class__.__name__ == "SymbolArgument" + ] + + # Extract dimension names - they should be the last ndim symbol arguments + dim_names = [] + assert len(symbol_args) >= ndim, "Not enough symbol arguments for dimensions" + dim_names = [arg.name for arg in symbol_args[-ndim:]] + + offset_parts = [] + + for i in range(ndim): + # Create the index variable with proper broadcasting + index_expr = f"indices_{i}" + + # Add broadcasting slices for this dimension + # For 1D tensors, this will just be indices_0 with no slicing + slice_parts = [] + for j in range(ndim): + if j < i: + slice_parts.append("None") + elif j == i: + slice_parts.append(":") + else: + slice_parts.append("None") + + # Create the broadcasted index expression + if ndim == 1: + # For 1D, no broadcasting needed + broadcasted_index = index_expr + else: + broadcasted_index = f"{index_expr}[{', '.join(slice_parts)}]" + + # Calculate stride (product of dimensions after this one) + if i < ndim - 1: + # Use the actual dimension variable names + stride_parts = dim_names[i + 1 :] + stride_expr = " * ".join(stride_parts) + offset_parts.append(f"{broadcasted_index} * {stride_expr}") + else: + # Last dimension has no stride multiplication + offset_parts.append(broadcasted_index) + + offset_expr = expr_from_string(" + ".join(offset_parts)) + + # Load seed from buffer using the kernel parameter name + assert 
device_fn.rng_seed_buffer_param_name is not None + seed_expr = expr_from_string( + "tl.load({buffer} + {index})", + buffer=expr_from_string(device_fn.rng_seed_buffer_param_name), + index=create(ast.Constant, value=seed_index), + ) + + # Generate the RNG call + # Note: tl.rand() and tl.randn() always return float32 + rng_expr = expr_from_string( + f"tl.{rng_function}({{seed}}, {{offset}})", seed=seed_expr, offset=offset_expr + ) + + # Cast to target dtype only if explicitly specified + if dtype is not None: + assert isinstance(dtype, torch.dtype) + rng_expr = expr_from_string(f"{{val}}.to({triton_type(dtype)})", val=rng_expr) + + return rng_expr + + +@register_lowering(torch.ops.aten.rand.default) # pyright: ignore[reportAttributeAccessIssue] +def codegen_rand(ctx: GraphInterpreter, node: torch.fx.Node) -> object: + return _codegen_rng_op(ctx, node, "rand") + + +@register_lowering(torch.ops.aten.randn.default) # pyright: ignore[reportAttributeAccessIssue] +def codegen_randn(ctx: GraphInterpreter, node: torch.fx.Node) -> object: + return _codegen_rng_op(ctx, node, "randn") diff --git a/helion/_testing.py b/helion/_testing.py index 57a574825..d9d43741a 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -124,6 +124,13 @@ class RefEagerTestBase: _original_skip_test_func = None # Class-level tracking for pytest.raises patching _original_pytest_raises = None + # Class-level tracking for assertTrue/assertFalse/assertGreater + _assert_true_count = 0 + _original_assert_true_func = None + _assert_false_count = 0 + _original_assert_false_func = None + _assert_greater_count = 0 + _original_assert_greater_func = None def setUp(self) -> None: """Common setup for all ref eager tests.""" @@ -142,6 +149,10 @@ def setUp(self) -> None: RefEagerTestBase._assert_raises_count = 0 # Reset skipTest counter for this test RefEagerTestBase._skip_test_count = 0 + # Reset assertTrue/assertFalse/assertGreater counters + RefEagerTestBase._assert_true_count = 0 + RefEagerTestBase._assert_false_count = 0 + RefEagerTestBase._assert_greater_count = 0 # Patch torch.testing.assert_close to count calls if RefEagerTestBase._original_assert_close_func is None: @@ -189,6 +200,36 @@ def counting_pytest_raises(*args: object, **kwargs: object) -> object: pytest.raises = counting_pytest_raises # type: ignore[assignment] + # Patch self.assertTrue to count calls + if RefEagerTestBase._original_assert_true_func is None: + RefEagerTestBase._original_assert_true_func = self.assertTrue + + def counting_assert_true(*args: object, **kwargs: object) -> None: + RefEagerTestBase._assert_true_count += 1 + return RefEagerTestBase._original_assert_true_func(*args, **kwargs) # type: ignore[misc] + + self.assertTrue = counting_assert_true # type: ignore[assignment] + + # Patch self.assertFalse to count calls + if RefEagerTestBase._original_assert_false_func is None: + RefEagerTestBase._original_assert_false_func = self.assertFalse + + def counting_assert_false(*args: object, **kwargs: object) -> None: + RefEagerTestBase._assert_false_count += 1 + return RefEagerTestBase._original_assert_false_func(*args, **kwargs) # type: ignore[misc] + + self.assertFalse = counting_assert_false # type: ignore[assignment] + + # Patch self.assertGreater to count calls + if RefEagerTestBase._original_assert_greater_func is None: + RefEagerTestBase._original_assert_greater_func = self.assertGreater + + def counting_assert_greater(*args: object, **kwargs: object) -> None: + RefEagerTestBase._assert_greater_count += 1 + return 
RefEagerTestBase._original_assert_greater_func(*args, **kwargs) # type: ignore[misc] + + self.assertGreater = counting_assert_greater # type: ignore[assignment] + def tearDown(self) -> None: """Common teardown with assertion counting check.""" # If not in ref eager mode, skip the teardown logic @@ -215,17 +256,27 @@ def tearDown(self) -> None: ) if not is_skipped: - # Check that either assert_close, assertRaises, or skipTest was called + # Check that either assert_close, assertRaises, skipTest, assertTrue, assertFalse, or assertGreater was called total_assertions = ( RefEagerTestBase._assert_close_count + RefEagerTestBase._assert_raises_count + RefEagerTestBase._skip_test_count + + RefEagerTestBase._assert_true_count + + RefEagerTestBase._assert_false_count + + RefEagerTestBase._assert_greater_count ) - self.assertGreater( # type: ignore[attr-defined] - total_assertions, - 0, - f"Test {self._testMethodName} did not call torch.testing.assert_close, assertRaises, or skipTest", # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - ) + # Need to use the original assertGreater to avoid recursion + if RefEagerTestBase._original_assert_greater_func is not None: + RefEagerTestBase._original_assert_greater_func( # type: ignore[misc] + total_assertions, + 0, + f"Test {self._testMethodName} did not call torch.testing.assert_close, assertRaises, skipTest, assertTrue, assertFalse, or assertGreater", # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + ) + else: + # Fallback if original not available + assert total_assertions > 0, ( + f"Test {self._testMethodName} did not call any assertion methods" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + ) finally: # Restore the original assert_close function if RefEagerTestBase._original_assert_close_func is not None: @@ -245,6 +296,18 @@ def tearDown(self) -> None: if RefEagerTestBase._original_pytest_raises is not None: # pyright: ignore[reportAttributeAccessIssue] pytest.raises = RefEagerTestBase._original_pytest_raises # pyright: ignore[reportAttributeAccessIssue] + # Restore the original assertTrue function + if RefEagerTestBase._original_assert_true_func is not None: + self.assertTrue = RefEagerTestBase._original_assert_true_func + + # Restore the original assertFalse function + if RefEagerTestBase._original_assert_false_func is not None: + self.assertFalse = RefEagerTestBase._original_assert_false_func + + # Restore the original assertGreater function + if RefEagerTestBase._original_assert_greater_func is not None: + self.assertGreater = RefEagerTestBase._original_assert_greater_func + super().tearDown() # type: ignore[misc] # NOTE: We no-op these methods because they commonly check behaviors that are not relevant in ref eager mode. diff --git a/test/test_rng.expected b/test/test_rng.expected new file mode 100644 index 000000000..5395c6228 --- /dev/null +++ b/test/test_rng.expected @@ -0,0 +1,53 @@ +This file is automatically generated by assertExpectedJournal calls in test_rng.py. +Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. 
+ +--- assertExpectedJournal(TestRNG.test_multiple_rng_ops) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal_stride_0, normal_stride_1, rand1_stride_0, rand1_stride_1, rand2_stride_0, rand2_stride_1, randn_a_stride_0, randn_a_stride_1, randn_b_stride_0, randn_b_stride_1, randn_c_stride_0, randn_c_stride_1, uniform_stride_0, uniform_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, rng_seed_buffer): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) + tl.store(rand1 + (indices_0[:, None] * rand1_stride_0 + indices_1[None, :] * rand1_stride_1), rand, mask_0[:, None] & mask_1[None, :]) + rand_1 = tl.rand(tl.load(rng_seed_buffer + 1), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) + tl.store(rand2 + (indices_0[:, None] * rand2_stride_0 + indices_1[None, :] * rand2_stride_1), rand_1, mask_0[:, None] & mask_1[None, :]) + rand_2 = tl.rand(tl.load(rng_seed_buffer + 2), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) + tl.store(uniform + (indices_0[:, None] * uniform_stride_0 + indices_1[None, :] * uniform_stride_1), rand_2, mask_0[:, None] & mask_1[None, :]) + randn = tl.randn(tl.load(rng_seed_buffer + 3), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) + tl.store(normal + (indices_0[:, None] * normal_stride_0 + indices_1[None, :] * normal_stride_1), randn, mask_0[:, None] & mask_1[None, :]) + randn_1 = tl.randn(tl.load(rng_seed_buffer + 4), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) + tl.store(randn_a + (indices_0[:, None] * randn_a_stride_0 + indices_1[None, :] * randn_a_stride_1), randn_1, mask_0[:, None] & mask_1[None, :]) + randn_2 = tl.randn(tl.load(rng_seed_buffer + 5), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) + tl.store(randn_b + (indices_0[:, None] * randn_b_stride_0 + indices_1[None, :] * randn_b_stride_1), randn_2, mask_0[:, None] & mask_1[None, :]) + randn_3 = tl.randn(tl.load(rng_seed_buffer + 6), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32) + tl.store(randn_c + (indices_0[:, None] * randn_c_stride_0 + indices_1[None, :] * randn_c_stride_1), randn_3, mask_0[:, None] & mask_1[None, :]) + +def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): + from torch._inductor import inductor_prims + _rng_seed_buffer_host = inductor_prims.seeds(7, torch.device('cuda')) + rand1 = torch.zeros_like(x) + rand2 = torch.zeros_like(x) + uniform = torch.zeros_like(x) + normal = torch.zeros_like(x) + randn_a = torch.zeros_like(x) + randn_b = torch.zeros_like(x) + randn_c = torch.zeros_like(x) + m, n = x.shape + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), 
rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer_host, num_warps=4, num_stages=3) + randn_sum = randn_a + randn_b + randn_c + return (rand1, rand2, uniform, normal, randn_sum) diff --git a/test/test_rng.py b/test/test_rng.py new file mode 100644 index 000000000..585af03e2 --- /dev/null +++ b/test/test_rng.py @@ -0,0 +1,353 @@ +from __future__ import annotations + +import unittest + +import torch + +import helion +from helion._testing import DEVICE +from helion._testing import RefEagerTestBase +from helion._testing import TestCase +from helion._testing import code_and_output +import helion.language as hl + + +class TestRNG(RefEagerTestBase, TestCase): + def test_rand(self): + """Test RNG seeding behavior, reproducibility, output range, and distribution.""" + + @helion.kernel + def rand_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor: + output = torch.zeros_like(x) + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + output[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n]) + return output + + # Test with different tensor sizes for different aspects + x_small = torch.ones(32, 32, device=DEVICE) # For distribution tests + x_large = torch.ones(64, 64, device=DEVICE) # For seeding tests + + # Test 1: Different seeds produce different outputs + torch.manual_seed(42) + _code1, output1 = code_and_output(rand_kernel_tiled_2d, (x_large,)) + + torch.manual_seed(123) + _code2, output2 = code_and_output(rand_kernel_tiled_2d, (x_large,)) + + self.assertFalse( + torch.allclose(output1, output2), + "Different seeds should produce different outputs", + ) + + # Test 2: Same seed produces identical outputs (reproducibility) + torch.manual_seed(42) + _code3, output3 = code_and_output(rand_kernel_tiled_2d, (x_large,)) + + torch.testing.assert_close( + output1, output3, msg="Same seed should produce identical outputs" + ) + + # Test 3: RNG state advances between calls + torch.manual_seed(42) + _code4, output4 = code_and_output(rand_kernel_tiled_2d, (x_large,)) + # No manual_seed here - RNG state should advance + _code5, output5 = code_and_output(rand_kernel_tiled_2d, (x_large,)) + + self.assertFalse( + torch.allclose(output4, output5), + "Sequential calls should produce different outputs (RNG state advanced)", + ) + + # Test 4: Output range and distribution properties + torch.manual_seed(42) + _code6, output6 = code_and_output(rand_kernel_tiled_2d, (x_small,)) + + # All values should be in [0, 1) range + self.assertTrue(torch.all(output6 >= 0.0), "All values should be >= 0") + self.assertTrue(torch.all(output6 < 1.0), "All values should be < 1") + + # Check distribution properties + mean_val = output6.mean().item() + self.assertTrue( + 0.4 < mean_val < 0.6, + f"Mean {mean_val:.3f} should be around 0.5 for uniform distribution", + ) + + # Check spread of values + min_val = output6.min().item() + max_val = output6.max().item() + self.assertTrue( + min_val < 0.2, f"Min value {min_val:.3f} should be < 0.2 for good spread" + ) + self.assertTrue( + max_val > 0.8, f"Max value {max_val:.3f} should be > 0.8 for good spread" + ) + + def test_rand_3d_tensor(self): + """Test 3D RNG with tiled operations.""" + + @helion.kernel + def rand_kernel_3d(x: torch.Tensor) -> torch.Tensor: + output = torch.zeros_like(x) + b, m, n = x.shape + for tile_b, tile_m, tile_n in hl.tile([b, m, n]): + output[tile_b, tile_m, tile_n] = torch.rand_like( + x[tile_b, 
tile_m, tile_n] + ) + return output + + x = torch.ones(16, 32, 64, device=DEVICE) # 3D tensor + torch.manual_seed(77) + _code, output = code_and_output(rand_kernel_3d, (x,)) + + # All values should be in [0, 1) range + self.assertTrue(torch.all(output >= 0.0)) + self.assertTrue(torch.all(output < 1.0)) + + # Check uniqueness - 3D should generate different values for each element + unique_values = output.unique().numel() + total_values = output.numel() + + # With a good RNG, we should have mostly unique values + uniqueness_ratio = unique_values / total_values + print( + f"3D Unique values: {unique_values}, Total: {total_values}, Percentage: {uniqueness_ratio * 100:.2f}%" + ) + + # Expect at least 95% unique values for good 3D RNG + self.assertGreater(uniqueness_ratio, 0.95) + + # Check distribution across dimensions + # Mean should be around 0.5 for each 2D slice + for b_idx in range(x.shape[0]): + slice_mean = output[b_idx].mean().item() + self.assertTrue( + 0.35 < slice_mean < 0.65, + f"Slice {b_idx} mean {slice_mean} is not well distributed", + ) + + # Verify different seeds produce different results + torch.manual_seed(88) + _code2, output2 = code_and_output(rand_kernel_3d, (x,)) + self.assertFalse(torch.allclose(output, output2)) + + def test_multiple_rng_ops(self): + """Test multiple RNG operations: independence, reproducibility, mixed rand/randn.""" + + @helion.kernel + def multiple_rng_ops_kernel( + x: torch.Tensor, + ) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor + ]: + # Two independent rand operations + rand1 = torch.zeros_like(x) + rand2 = torch.zeros_like(x) + + # Mixed rand and randn + uniform = torch.zeros_like(x) + normal = torch.zeros_like(x) + + # Multiple randn for sum + randn_a = torch.zeros_like(x) + randn_b = torch.zeros_like(x) + randn_c = torch.zeros_like(x) + + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + # Two independent rand operations + rand1[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n]) + rand2[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n]) + + # Mixed rand and randn + uniform[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n]) + normal[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n]) + + # Multiple randn + randn_a[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n]) + randn_b[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n]) + randn_c[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n]) + + # Combine the three randn outside the loop + randn_sum = randn_a + randn_b + randn_c + + return rand1, rand2, uniform, normal, randn_sum + + x = torch.ones(64, 64, device=DEVICE) + + # Test 1: Independence and distribution properties + torch.manual_seed(42) + _code1, (rand1, rand2, uniform, normal, randn_sum) = code_and_output( + multiple_rng_ops_kernel, (x,) + ) + + # Check two independent rand operations + self.assertTrue( + torch.all(rand1 >= 0.0) and torch.all(rand1 < 1.0), + "First rand output should be in [0, 1)", + ) + self.assertTrue( + torch.all(rand2 >= 0.0) and torch.all(rand2 < 1.0), + "Second rand output should be in [0, 1)", + ) + self.assertFalse( + torch.allclose(rand1, rand2), + "Two independent RNG ops should produce different outputs", + ) + self.assertTrue( + 0.45 < rand1.mean().item() < 0.55, + f"First rand mean {rand1.mean().item():.3f} should be ~0.5", + ) + self.assertTrue( + 0.45 < rand2.mean().item() < 0.55, + f"Second rand mean {rand2.mean().item():.3f} should be ~0.5", + ) + + # Check mixed rand and randn + self.assertTrue( + torch.all(uniform >= 0.0) and torch.all(uniform 
< 1.0), + "Uniform (rand) values should be in [0, 1)", + ) + self.assertTrue( + 0.4 < uniform.mean().item() < 0.6, + f"Uniform mean {uniform.mean().item():.3f} should be ~0.5", + ) + self.assertTrue( + -0.2 < normal.mean().item() < 0.2, + f"Normal mean {normal.mean().item():.3f} should be ~0", + ) + self.assertTrue( + 0.9 < normal.std().item() < 1.1, + f"Normal std {normal.std().item():.3f} should be ~1", + ) + self.assertTrue( + torch.any(normal < 0.0), "Normal distribution should have negative values" + ) + self.assertFalse( + torch.allclose(uniform, normal), + "Uniform and normal distributions should be different", + ) + + # Check sum of multiple randn + expected_std = 3**0.5 + mean = randn_sum.mean().item() + std = randn_sum.std().item() + self.assertTrue(-0.2 < mean < 0.2, f"Combined mean {mean:.3f} should be ~0") + self.assertTrue( + expected_std * 0.9 < std < expected_std * 1.1, + f"Combined std {std:.3f} should be ~{expected_std:.3f}", + ) + + # Test 2: Reproducibility with same seed + torch.manual_seed(42) + _code2, outputs_a = code_and_output(multiple_rng_ops_kernel, (x,)) + + torch.manual_seed(42) + _code3, outputs_b = code_and_output(multiple_rng_ops_kernel, (x,)) + + # All outputs should be identical with same seed + for i, (a, b) in enumerate(zip(outputs_a, outputs_b, strict=False)): + torch.testing.assert_close( + a, b, msg=f"Output {i} should be identical with same seed" + ) + + # Verify generated code with expected journal + self.assertExpectedJournal(_code1) + + def test_randn_different_seeds_tiled(self): + """Test that different torch.manual_seed values produce different outputs for randn.""" + + @helion.kernel + def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor: + output = torch.zeros_like(x) + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + output[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n]) + return output + + x = torch.ones(64, 64, device=DEVICE) + + torch.manual_seed(42) + _code1, output1 = code_and_output(randn_kernel_tiled_2d, (x,)) + + torch.manual_seed(123) + _code2, output2 = code_and_output(randn_kernel_tiled_2d, (x,)) + + # Different seeds should produce different outputs + self.assertFalse(torch.allclose(output1, output2)) + + def test_randn_normal_distribution(self): + """Test that torch.randn_like produces normal distribution (mean≈0, std≈1).""" + + @helion.kernel + def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor: + output = torch.zeros_like(x) + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + output[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n]) + return output + + x = torch.ones(128, 128, device=DEVICE) # 16384 samples for better statistics + torch.manual_seed(42) + _code, output = code_and_output(randn_kernel_tiled_2d, (x,)) + + # Check mean is close to 0 + mean = output.mean().item() + self.assertTrue(-0.1 < mean < 0.1, f"Mean {mean} is not close to 0") + + # Check std is close to 1 + std = output.std().item() + self.assertTrue(0.95 < std < 1.05, f"Std {std} is not close to 1") + + # Check we have values outside [-1, 1] (characteristic of normal distribution) + self.assertTrue(torch.any(output < -1.0)) + self.assertTrue(torch.any(output > 1.0)) + + # Roughly 68% should be within 1 std + within_1_std = ( + torch.logical_and(output > -1.0, output < 1.0).float().mean().item() + ) + self.assertTrue( + 0.63 < within_1_std < 0.73, f"Values within 1 std: {within_1_std}" + ) + + def test_randn_3d_tensor(self): + """Test 3D randn with tiled operations.""" + + @helion.kernel + def randn_kernel_3d(x: 
torch.Tensor) -> torch.Tensor: + output = torch.zeros_like(x) + b, m, n = x.shape + for tile_b, tile_m, tile_n in hl.tile([b, m, n]): + output[tile_b, tile_m, tile_n] = torch.randn_like( + x[tile_b, tile_m, tile_n] + ) + return output + + x = torch.ones(8, 32, 64, device=DEVICE) # 3D tensor + torch.manual_seed(77) + _code, output = code_and_output(randn_kernel_3d, (x,)) + + # Check overall distribution + mean = output.mean().item() + std = output.std().item() + self.assertTrue(-0.1 < mean < 0.1, f"3D mean {mean} not close to 0") + self.assertTrue(0.95 < std < 1.05, f"3D std {std} not close to 1") + + # Check distribution across dimensions + for b_idx in range(x.shape[0]): + slice_mean = output[b_idx].mean().item() + slice_std = output[b_idx].std().item() + self.assertTrue( + -0.3 < slice_mean < 0.3, + f"Slice {b_idx} mean {slice_mean} is not well distributed", + ) + self.assertTrue( + 0.85 < slice_std < 1.15, + f"Slice {b_idx} std {slice_std} is not well distributed", + ) + + +if __name__ == "__main__": + unittest.main() From 5fe0fe483bfc885eaf62e7d3d8d2e3d8174b5871 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Tue, 2 Sep 2025 12:54:37 -0700 Subject: [PATCH 2/5] try improve --- helion/_compiler/device_function.py | 36 +--- helion/_compiler/generate_ast.py | 38 ++--- helion/_compiler/inductor_lowering.py | 100 ++--------- helion/_compiler/transforms/__init__.py | 6 + helion/_compiler/transforms/base.py | 30 ++++ helion/_compiler/transforms/rng_transform.py | 165 +++++++++++++++++++ test/test_rng.expected | 4 +- 7 files changed, 240 insertions(+), 139 deletions(-) create mode 100644 helion/_compiler/transforms/__init__.py create mode 100644 helion/_compiler/transforms/base.py create mode 100644 helion/_compiler/transforms/rng_transform.py diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 2aeac256b..c688f25ec 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -41,6 +41,7 @@ from .device_ir import HelperFunctionGraphInfo from .generate_ast import GenerateAST from .program_id import ProgramIDs + from .transforms import TransformPass _P = TypeVar("_P", bound="TensorPropertyArg") @@ -241,28 +242,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None: self.tile_strategy: TileStrategyDispatch = TileStrategyDispatch(self, config) self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config) - self.rng_seed_count = 0 - # Name of the RNG seed buffer parameter in kernel signature - self.rng_seed_buffer_param_name = None - - def has_rng_ops(self) -> bool: - """Check if this kernel uses any RNG operations.""" - return self.rng_seed_count > 0 and self.rng_seed_buffer_param_name is not None - - def allocate_rng_seed(self) -> int: - """Allocate a new RNG seed index and ensure buffer argument exists. - - Returns: - The seed index for this RNG operation. 
- """ - seed_index = self.rng_seed_count - self.rng_seed_count += 1 - - # Ensure seed buffer parameter name exists - if self.rng_seed_buffer_param_name is None: - self.rng_seed_buffer_param_name = self.new_var("rng_seed_buffer") - - return seed_index + self.transform_passes: list[TransformPass] = [] def block_size_var(self, block_id: int) -> str | None: return self.block_size_var_cache.get((block_id,)) @@ -512,10 +492,9 @@ def codegen_function_def(self) -> list[ast.stmt]: ) args = [arg.arg_def_node() for arg in self.sorted_args()] - if self.has_rng_ops(): - # Add the seed buffer as a pointer parameter to kernel signature - assert self.rng_seed_buffer_param_name is not None - args.append(create_arg(self.rng_seed_buffer_param_name)) + + for transform_pass in self.transform_passes: + transform_pass.add_kernel_arguments(args) return [ *prefix, @@ -535,9 +514,8 @@ def codegen_function_def(self) -> list[ast.stmt]: def codegen_function_call(self) -> ast.AST: args = [arg.host_str() for arg in self.sorted_args()] - if self.has_rng_ops(): - # Pass the host-side seed buffer variable to the kernel - args.append("_rng_seed_buffer_host") + for transform_pass in self.transform_passes: + transform_pass.add_host_arguments(args) # Workaround for triton bug: warp_specialize requires at least 4 warps # See: https://github.com/triton-lang/triton/issues/7354 diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index 749115591..1c5984ab9 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -25,6 +25,7 @@ from .inductor_lowering import CodegenState from .inductor_lowering import codegen_call_with_graph from .program_id import ForEachProgramID +from .transforms import RNGTransformPass from .variable_origin import ArgumentOrigin if TYPE_CHECKING: @@ -56,6 +57,18 @@ def __init__(self, func: HostFunction, config: Config) -> None: self.device_function = DeviceFunction(f"_helion_{func.name}", config, self) CodegenInterface.__init__(self, self.device_function) + # Initialize transformation passes + if RNGTransformPass.has_rng_ops(func): + # Check if RNGTransformPass is already added + has_rng_pass = any( + isinstance(p, RNGTransformPass) + for p in self.device_function.transform_passes + ) + if not has_rng_pass: + self.device_function.transform_passes.append( + RNGTransformPass(self.device_function) + ) + def offset_var(self, block_idx: int) -> str: return self.active_device_loops[block_idx][-1].strategy.offset_var(block_idx) @@ -74,18 +87,6 @@ def add_statement(self, stmt: ast.AST | str | None) -> None: stmt = statement_from_string(stmt) self.statements_stack[-1].append(stmt) - def get_rng_seed_buffer_statements(self) -> list[ast.AST]: - import_stmt = statement_from_string( - "from torch._inductor import inductor_prims" - ) - - # Create host-side seed buffer with the required number of seeds - seed_buffer_stmt = statement_from_string( - f"_rng_seed_buffer_host = inductor_prims.seeds({self.device_function.rng_seed_count}, torch.device('cuda'))" - ) - - return [import_stmt, seed_buffer_stmt] - def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name: if isinstance(expr, ast.Name): return expr @@ -408,13 +409,12 @@ def generate_ast( kernel_def = codegen.device_function.codegen_function_def() codegen.host_dead_code_elimination() - # Inject RNG seed buffer creation if needed - rng_statements = ( - codegen.get_rng_seed_buffer_statements() - if codegen.device_function.has_rng_ops() - else [] - ) - final_host_statements = rng_statements + 
codegen.host_statements + preamble_statements = [] + for transform_pass in codegen.device_function.transform_passes: + preamble_statements.extend( + transform_pass.get_host_preamble_statements() + ) + final_host_statements = preamble_statements + codegen.host_statements host_def = func.codegen_function_def(final_host_statements) diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index 920742a60..b0e04b2a6 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -57,6 +57,7 @@ from .node_masking import getitem_masked_value from .node_masking import inductor_masked_value from .node_masking import mask_node_inputs +from .transforms import RNGTransformPass if TYPE_CHECKING: from collections.abc import Callable @@ -1310,97 +1311,18 @@ def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object: def _codegen_rng_op( - ctx: GraphInterpreter, - node: torch.fx.Node, - rng_function: str, -) -> object: - """Common codegen implementation for all RNG operations. - - Args: - ctx: The graph interpreter context - node: The FX node for this operation - rng_function: Either "rand" or "randn" - """ - assert rng_function in ["rand", "randn"] - - # Get unique seed index for this RNG operation - device_fn = ctx.cg.device_function - seed_index = device_fn.allocate_rng_seed() - - # Get dimensionality and dtype - assert hasattr(node, "meta") and "val" in node.meta - ndim = node.meta["val"].ndim - dtype = node.kwargs.get("dtype", None) - - # Get the dimension variable names from the device function's symbol arguments - device_fn = ctx.cg.device_function - symbol_args = [ - arg - for arg in device_fn.arguments - if hasattr(arg, "__class__") and arg.__class__.__name__ == "SymbolArgument" - ] - - # Extract dimension names - they should be the last ndim symbol arguments - dim_names = [] - assert len(symbol_args) >= ndim, "Not enough symbol arguments for dimensions" - dim_names = [arg.name for arg in symbol_args[-ndim:]] - - offset_parts = [] - - for i in range(ndim): - # Create the index variable with proper broadcasting - index_expr = f"indices_{i}" - - # Add broadcasting slices for this dimension - # For 1D tensors, this will just be indices_0 with no slicing - slice_parts = [] - for j in range(ndim): - if j < i: - slice_parts.append("None") - elif j == i: - slice_parts.append(":") - else: - slice_parts.append("None") - - # Create the broadcasted index expression - if ndim == 1: - # For 1D, no broadcasting needed - broadcasted_index = index_expr - else: - broadcasted_index = f"{index_expr}[{', '.join(slice_parts)}]" - - # Calculate stride (product of dimensions after this one) - if i < ndim - 1: - # Use the actual dimension variable names - stride_parts = dim_names[i + 1 :] - stride_expr = " * ".join(stride_parts) - offset_parts.append(f"{broadcasted_index} * {stride_expr}") - else: - # Last dimension has no stride multiplication - offset_parts.append(broadcasted_index) - - offset_expr = expr_from_string(" + ".join(offset_parts)) - - # Load seed from buffer using the kernel parameter name - assert device_fn.rng_seed_buffer_param_name is not None - seed_expr = expr_from_string( - "tl.load({buffer} + {index})", - buffer=expr_from_string(device_fn.rng_seed_buffer_param_name), - index=create(ast.Constant, value=seed_index), - ) - - # Generate the RNG call - # Note: tl.rand() and tl.randn() always return float32 - rng_expr = expr_from_string( - f"tl.{rng_function}({{seed}}, {{offset}})", seed=seed_expr, offset=offset_expr - ) + 
ctx: GraphInterpreter, node: torch.fx.Node, rng_function: str +) -> ast.AST: + rng_pass = None + for transform_pass in ctx.cg.device_function.transform_passes: + if isinstance(transform_pass, RNGTransformPass): + rng_pass = transform_pass + break - # Cast to target dtype only if explicitly specified - if dtype is not None: - assert isinstance(dtype, torch.dtype) - rng_expr = expr_from_string(f"{{val}}.to({triton_type(dtype)})", val=rng_expr) + if rng_pass is None: + raise RuntimeError("RNG operation found but RNGTransformPass not initialized") - return rng_expr + return rng_pass.codegen_rng_op(ctx, node, rng_function) @register_lowering(torch.ops.aten.rand.default) # pyright: ignore[reportAttributeAccessIssue] diff --git a/helion/_compiler/transforms/__init__.py b/helion/_compiler/transforms/__init__.py new file mode 100644 index 000000000..09ea6a904 --- /dev/null +++ b/helion/_compiler/transforms/__init__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from .base import TransformPass +from .rng_transform import RNGTransformPass + +__all__ = ["RNGTransformPass", "TransformPass"] diff --git a/helion/_compiler/transforms/base.py b/helion/_compiler/transforms/base.py new file mode 100644 index 000000000..b435d2974 --- /dev/null +++ b/helion/_compiler/transforms/base.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import ast + + +class TransformPass: + def add_kernel_arguments(self, args: list[ast.arg]) -> None: + """Add any required arguments to the kernel signature. + + Args: + args: List of kernel arguments to modify + """ + + def add_host_arguments(self, args: list[str]) -> None: + """Add any required arguments to the host function call. + + Args: + args: List of host arguments to modify + """ + + def get_host_preamble_statements(self) -> list[ast.AST]: + """Get statements to inject into the host function preamble. 
+ + Returns: + List of AST statements to add to the host function + """ + return [] diff --git a/helion/_compiler/transforms/rng_transform.py b/helion/_compiler/transforms/rng_transform.py new file mode 100644 index 000000000..87a71f154 --- /dev/null +++ b/helion/_compiler/transforms/rng_transform.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import ast +from typing import TYPE_CHECKING + +import torch +import torch.fx + +from ..ast_extension import create +from ..ast_extension import create_arg +from ..ast_extension import expr_from_string +from ..generate_ast import statement_from_string +from .base import TransformPass + +if TYPE_CHECKING: + from ..device_function import DeviceFunction + from ..host_function import HostFunction + from ..inductor_lowering import GraphInterpreter + + +class RNGTransformPass(TransformPass): + def __init__(self, device_function: DeviceFunction) -> None: + self.rng_seed_buffer_param_name = device_function.new_var("rng_seed_buffer") + self.seed_index_map: dict[torch.fx.Node, int] = {} + + @staticmethod + def has_rng_ops(host_function: HostFunction) -> bool: + """Check if any graph in the device IR contains RNG operations.""" + rng_ops = { + torch.ops.aten.rand.default, + torch.ops.aten.randn.default, + } + + for graph_info in host_function.device_ir.graphs: + for node in graph_info.graph.nodes: + if node.op == "call_function" and node.target in rng_ops: + return True + + return False + + def get_or_allocate_seed_index(self, node: torch.fx.Node) -> int: + """Get the seed index for a given RNG operation node, allocating if necessary.""" + if node not in self.seed_index_map: + seed_index = len(self.seed_index_map) + self.seed_index_map[node] = seed_index + + return self.seed_index_map[node] + + def add_kernel_arguments(self, args: list[ast.arg]) -> None: + """Add RNG seed buffer argument to kernel signature.""" + args.append(create_arg(self.rng_seed_buffer_param_name)) + + def add_host_arguments(self, args: list[str]) -> None: + """Add RNG seed buffer argument to host function call.""" + args.append("_rng_seed_buffer") + + def get_host_preamble_statements(self) -> list[ast.AST]: + """Get statements to inject into host function preamble.""" + import_stmt = statement_from_string( + "from torch._inductor import inductor_prims" + ) + + # Create host-side seed buffer with the required number of seeds + seed_buffer_stmt = statement_from_string( + f"_rng_seed_buffer = inductor_prims.seeds({len(self.seed_index_map)}, torch.device('cuda'))" + ) + + return [import_stmt, seed_buffer_stmt] + + def codegen_rng_op( + self, ctx: GraphInterpreter, node: torch.fx.Node, rng_function: str + ) -> ast.AST: + """Generate Triton code for an RNG op. 
+ + Args: + ctx: The graph interpreter context + node: The FX node for this operation + rng_function: Either "rand" or "randn" + + Returns: + AST expression for the RNG operation + """ + assert rng_function in ["rand", "randn"] + + # Get the seed index for this operation + seed_index = self.get_or_allocate_seed_index(node) + + # Get dimensionality and dtype + assert hasattr(node, "meta") and "val" in node.meta + ndim = node.meta["val"].ndim + dtype = node.kwargs.get("dtype", None) + + # Get the dimension variable names from the device function's symbol arguments + device_fn = ctx.cg.device_function + symbol_args = [ + arg + for arg in device_fn.arguments + if hasattr(arg, "__class__") and arg.__class__.__name__ == "SymbolArgument" + ] + + # Extract dimension names - they should be the last ndim symbol arguments + dim_names = [] + assert len(symbol_args) >= ndim, "Not enough symbol arguments for dimensions" + dim_names = [arg.name for arg in symbol_args[-ndim:]] + + offset_parts = [] + + for i in range(ndim): + # Create the index variable with proper broadcasting + index_expr = f"indices_{i}" + + # Add broadcasting slices for this dimension + slice_parts = [] + for j in range(ndim): + if j < i: + slice_parts.append("None") + elif j == i: + slice_parts.append(":") + else: + slice_parts.append("None") + + # Create the broadcasted index expression + if ndim == 1: + # For 1D, no broadcasting needed + broadcasted_index = index_expr + else: + broadcasted_index = f"{index_expr}[{', '.join(slice_parts)}]" + + # Calculate stride (product of dimensions after this one) + if i < ndim - 1: + # Use the actual dimension variable names + stride_parts = dim_names[i + 1 :] + stride_expr = " * ".join(stride_parts) + offset_parts.append(f"{broadcasted_index} * {stride_expr}") + else: + # Last dimension has no stride multiplication + offset_parts.append(broadcasted_index) + + offset_expr = expr_from_string(" + ".join(offset_parts)) + + # Load seed from buffer using the kernel parameter name + seed_expr = expr_from_string( + "tl.load({buffer} + {index})", + buffer=expr_from_string(self.rng_seed_buffer_param_name), + index=create(ast.Constant, value=seed_index), + ) + + # Generate the RNG call + # Note: tl.rand() and tl.randn() always return float32 + rng_expr = expr_from_string( + f"tl.{rng_function}({{seed}}, {{offset}})", + seed=seed_expr, + offset=offset_expr, + ) + + # Cast to target dtype only if explicitly specified + if dtype is not None: + assert isinstance(dtype, torch.dtype) + from torch._inductor.utils import triton_type + + rng_expr = expr_from_string( + f"{{val}}.to({triton_type(dtype)})", val=rng_expr + ) + + return rng_expr diff --git a/test/test_rng.expected b/test/test_rng.expected index 5395c6228..da0721a6a 100644 --- a/test/test_rng.expected +++ b/test/test_rng.expected @@ -37,7 +37,7 @@ def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, rand def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): from torch._inductor import inductor_prims - _rng_seed_buffer_host = inductor_prims.seeds(7, torch.device('cuda')) + _rng_seed_buffer = inductor_prims.seeds(7, torch.device('cuda')) rand1 = torch.zeros_like(x) rand2 = torch.zeros_like(x) uniform = torch.zeros_like(x) @@ -48,6 +48,6 @@ def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): m, n = x.shape _BLOCK_SIZE_0 = 32 _BLOCK_SIZE_1 = 32 - _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, 
randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer_host, num_warps=4, num_stages=3) + _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=3) randn_sum = randn_a + randn_b + randn_c return (rand1, rand2, uniform, normal, randn_sum) From 3fbadd0e450bb9cba36dc917a2c55803364468b0 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 3 Sep 2025 15:37:41 -0700 Subject: [PATCH 3/5] Revert "try improve" This reverts commit 5fe0fe483bfc885eaf62e7d3d8d2e3d8174b5871. --- helion/_compiler/device_function.py | 36 +++- helion/_compiler/generate_ast.py | 38 ++--- helion/_compiler/inductor_lowering.py | 100 +++++++++-- helion/_compiler/transforms/__init__.py | 6 - helion/_compiler/transforms/base.py | 30 ---- helion/_compiler/transforms/rng_transform.py | 165 ------------------- test/test_rng.expected | 4 +- 7 files changed, 139 insertions(+), 240 deletions(-) delete mode 100644 helion/_compiler/transforms/__init__.py delete mode 100644 helion/_compiler/transforms/base.py delete mode 100644 helion/_compiler/transforms/rng_transform.py diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index c688f25ec..2aeac256b 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -41,7 +41,6 @@ from .device_ir import HelperFunctionGraphInfo from .generate_ast import GenerateAST from .program_id import ProgramIDs - from .transforms import TransformPass _P = TypeVar("_P", bound="TensorPropertyArg") @@ -242,7 +241,28 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None: self.tile_strategy: TileStrategyDispatch = TileStrategyDispatch(self, config) self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config) - self.transform_passes: list[TransformPass] = [] + self.rng_seed_count = 0 + # Name of the RNG seed buffer parameter in kernel signature + self.rng_seed_buffer_param_name = None + + def has_rng_ops(self) -> bool: + """Check if this kernel uses any RNG operations.""" + return self.rng_seed_count > 0 and self.rng_seed_buffer_param_name is not None + + def allocate_rng_seed(self) -> int: + """Allocate a new RNG seed index and ensure buffer argument exists. + + Returns: + The seed index for this RNG operation. 
+ """ + seed_index = self.rng_seed_count + self.rng_seed_count += 1 + + # Ensure seed buffer parameter name exists + if self.rng_seed_buffer_param_name is None: + self.rng_seed_buffer_param_name = self.new_var("rng_seed_buffer") + + return seed_index def block_size_var(self, block_id: int) -> str | None: return self.block_size_var_cache.get((block_id,)) @@ -492,9 +512,10 @@ def codegen_function_def(self) -> list[ast.stmt]: ) args = [arg.arg_def_node() for arg in self.sorted_args()] - - for transform_pass in self.transform_passes: - transform_pass.add_kernel_arguments(args) + if self.has_rng_ops(): + # Add the seed buffer as a pointer parameter to kernel signature + assert self.rng_seed_buffer_param_name is not None + args.append(create_arg(self.rng_seed_buffer_param_name)) return [ *prefix, @@ -514,8 +535,9 @@ def codegen_function_def(self) -> list[ast.stmt]: def codegen_function_call(self) -> ast.AST: args = [arg.host_str() for arg in self.sorted_args()] - for transform_pass in self.transform_passes: - transform_pass.add_host_arguments(args) + if self.has_rng_ops(): + # Pass the host-side seed buffer variable to the kernel + args.append("_rng_seed_buffer_host") # Workaround for triton bug: warp_specialize requires at least 4 warps # See: https://github.com/triton-lang/triton/issues/7354 diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index 1c5984ab9..749115591 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -25,7 +25,6 @@ from .inductor_lowering import CodegenState from .inductor_lowering import codegen_call_with_graph from .program_id import ForEachProgramID -from .transforms import RNGTransformPass from .variable_origin import ArgumentOrigin if TYPE_CHECKING: @@ -57,18 +56,6 @@ def __init__(self, func: HostFunction, config: Config) -> None: self.device_function = DeviceFunction(f"_helion_{func.name}", config, self) CodegenInterface.__init__(self, self.device_function) - # Initialize transformation passes - if RNGTransformPass.has_rng_ops(func): - # Check if RNGTransformPass is already added - has_rng_pass = any( - isinstance(p, RNGTransformPass) - for p in self.device_function.transform_passes - ) - if not has_rng_pass: - self.device_function.transform_passes.append( - RNGTransformPass(self.device_function) - ) - def offset_var(self, block_idx: int) -> str: return self.active_device_loops[block_idx][-1].strategy.offset_var(block_idx) @@ -87,6 +74,18 @@ def add_statement(self, stmt: ast.AST | str | None) -> None: stmt = statement_from_string(stmt) self.statements_stack[-1].append(stmt) + def get_rng_seed_buffer_statements(self) -> list[ast.AST]: + import_stmt = statement_from_string( + "from torch._inductor import inductor_prims" + ) + + # Create host-side seed buffer with the required number of seeds + seed_buffer_stmt = statement_from_string( + f"_rng_seed_buffer_host = inductor_prims.seeds({self.device_function.rng_seed_count}, torch.device('cuda'))" + ) + + return [import_stmt, seed_buffer_stmt] + def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name: if isinstance(expr, ast.Name): return expr @@ -409,12 +408,13 @@ def generate_ast( kernel_def = codegen.device_function.codegen_function_def() codegen.host_dead_code_elimination() - preamble_statements = [] - for transform_pass in codegen.device_function.transform_passes: - preamble_statements.extend( - transform_pass.get_host_preamble_statements() - ) - final_host_statements = preamble_statements + codegen.host_statements + # Inject 
RNG seed buffer creation if needed + rng_statements = ( + codegen.get_rng_seed_buffer_statements() + if codegen.device_function.has_rng_ops() + else [] + ) + final_host_statements = rng_statements + codegen.host_statements host_def = func.codegen_function_def(final_host_statements) diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index b0e04b2a6..920742a60 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -57,7 +57,6 @@ from .node_masking import getitem_masked_value from .node_masking import inductor_masked_value from .node_masking import mask_node_inputs -from .transforms import RNGTransformPass if TYPE_CHECKING: from collections.abc import Callable @@ -1311,18 +1310,97 @@ def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object: def _codegen_rng_op( - ctx: GraphInterpreter, node: torch.fx.Node, rng_function: str -) -> ast.AST: - rng_pass = None - for transform_pass in ctx.cg.device_function.transform_passes: - if isinstance(transform_pass, RNGTransformPass): - rng_pass = transform_pass - break + ctx: GraphInterpreter, + node: torch.fx.Node, + rng_function: str, +) -> object: + """Common codegen implementation for all RNG operations. + + Args: + ctx: The graph interpreter context + node: The FX node for this operation + rng_function: Either "rand" or "randn" + """ + assert rng_function in ["rand", "randn"] + + # Get unique seed index for this RNG operation + device_fn = ctx.cg.device_function + seed_index = device_fn.allocate_rng_seed() + + # Get dimensionality and dtype + assert hasattr(node, "meta") and "val" in node.meta + ndim = node.meta["val"].ndim + dtype = node.kwargs.get("dtype", None) + + # Get the dimension variable names from the device function's symbol arguments + device_fn = ctx.cg.device_function + symbol_args = [ + arg + for arg in device_fn.arguments + if hasattr(arg, "__class__") and arg.__class__.__name__ == "SymbolArgument" + ] + + # Extract dimension names - they should be the last ndim symbol arguments + dim_names = [] + assert len(symbol_args) >= ndim, "Not enough symbol arguments for dimensions" + dim_names = [arg.name for arg in symbol_args[-ndim:]] + + offset_parts = [] + + for i in range(ndim): + # Create the index variable with proper broadcasting + index_expr = f"indices_{i}" + + # Add broadcasting slices for this dimension + # For 1D tensors, this will just be indices_0 with no slicing + slice_parts = [] + for j in range(ndim): + if j < i: + slice_parts.append("None") + elif j == i: + slice_parts.append(":") + else: + slice_parts.append("None") + + # Create the broadcasted index expression + if ndim == 1: + # For 1D, no broadcasting needed + broadcasted_index = index_expr + else: + broadcasted_index = f"{index_expr}[{', '.join(slice_parts)}]" + + # Calculate stride (product of dimensions after this one) + if i < ndim - 1: + # Use the actual dimension variable names + stride_parts = dim_names[i + 1 :] + stride_expr = " * ".join(stride_parts) + offset_parts.append(f"{broadcasted_index} * {stride_expr}") + else: + # Last dimension has no stride multiplication + offset_parts.append(broadcasted_index) + + offset_expr = expr_from_string(" + ".join(offset_parts)) + + # Load seed from buffer using the kernel parameter name + assert device_fn.rng_seed_buffer_param_name is not None + seed_expr = expr_from_string( + "tl.load({buffer} + {index})", + buffer=expr_from_string(device_fn.rng_seed_buffer_param_name), + index=create(ast.Constant, value=seed_index), + ) + + # 
Generate the RNG call + # Note: tl.rand() and tl.randn() always return float32 + rng_expr = expr_from_string( + f"tl.{rng_function}({{seed}}, {{offset}})", seed=seed_expr, offset=offset_expr + ) - if rng_pass is None: - raise RuntimeError("RNG operation found but RNGTransformPass not initialized") + # Cast to target dtype only if explicitly specified + if dtype is not None: + assert isinstance(dtype, torch.dtype) + rng_expr = expr_from_string(f"{{val}}.to({triton_type(dtype)})", val=rng_expr) - return rng_pass.codegen_rng_op(ctx, node, rng_function) + return rng_expr @register_lowering(torch.ops.aten.rand.default) # pyright: ignore[reportAttributeAccessIssue] diff --git a/helion/_compiler/transforms/__init__.py b/helion/_compiler/transforms/__init__.py deleted file mode 100644 index 09ea6a904..000000000 --- a/helion/_compiler/transforms/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from __future__ import annotations - -from .base import TransformPass -from .rng_transform import RNGTransformPass - -__all__ = ["RNGTransformPass", "TransformPass"] diff --git a/helion/_compiler/transforms/base.py b/helion/_compiler/transforms/base.py deleted file mode 100644 index b435d2974..000000000 --- a/helion/_compiler/transforms/base.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - import ast - - -class TransformPass: - def add_kernel_arguments(self, args: list[ast.arg]) -> None: - """Add any required arguments to the kernel signature. - - Args: - args: List of kernel arguments to modify - """ - - def add_host_arguments(self, args: list[str]) -> None: - """Add any required arguments to the host function call. - - Args: - args: List of host arguments to modify - """ - - def get_host_preamble_statements(self) -> list[ast.AST]: - """Get statements to inject into the host function preamble. 
- - Returns: - List of AST statements to add to the host function - """ - return [] diff --git a/helion/_compiler/transforms/rng_transform.py b/helion/_compiler/transforms/rng_transform.py deleted file mode 100644 index 87a71f154..000000000 --- a/helion/_compiler/transforms/rng_transform.py +++ /dev/null @@ -1,165 +0,0 @@ -from __future__ import annotations - -import ast -from typing import TYPE_CHECKING - -import torch -import torch.fx - -from ..ast_extension import create -from ..ast_extension import create_arg -from ..ast_extension import expr_from_string -from ..generate_ast import statement_from_string -from .base import TransformPass - -if TYPE_CHECKING: - from ..device_function import DeviceFunction - from ..host_function import HostFunction - from ..inductor_lowering import GraphInterpreter - - -class RNGTransformPass(TransformPass): - def __init__(self, device_function: DeviceFunction) -> None: - self.rng_seed_buffer_param_name = device_function.new_var("rng_seed_buffer") - self.seed_index_map: dict[torch.fx.Node, int] = {} - - @staticmethod - def has_rng_ops(host_function: HostFunction) -> bool: - """Check if any graph in the device IR contains RNG operations.""" - rng_ops = { - torch.ops.aten.rand.default, - torch.ops.aten.randn.default, - } - - for graph_info in host_function.device_ir.graphs: - for node in graph_info.graph.nodes: - if node.op == "call_function" and node.target in rng_ops: - return True - - return False - - def get_or_allocate_seed_index(self, node: torch.fx.Node) -> int: - """Get the seed index for a given RNG operation node, allocating if necessary.""" - if node not in self.seed_index_map: - seed_index = len(self.seed_index_map) - self.seed_index_map[node] = seed_index - - return self.seed_index_map[node] - - def add_kernel_arguments(self, args: list[ast.arg]) -> None: - """Add RNG seed buffer argument to kernel signature.""" - args.append(create_arg(self.rng_seed_buffer_param_name)) - - def add_host_arguments(self, args: list[str]) -> None: - """Add RNG seed buffer argument to host function call.""" - args.append("_rng_seed_buffer") - - def get_host_preamble_statements(self) -> list[ast.AST]: - """Get statements to inject into host function preamble.""" - import_stmt = statement_from_string( - "from torch._inductor import inductor_prims" - ) - - # Create host-side seed buffer with the required number of seeds - seed_buffer_stmt = statement_from_string( - f"_rng_seed_buffer = inductor_prims.seeds({len(self.seed_index_map)}, torch.device('cuda'))" - ) - - return [import_stmt, seed_buffer_stmt] - - def codegen_rng_op( - self, ctx: GraphInterpreter, node: torch.fx.Node, rng_function: str - ) -> ast.AST: - """Generate Triton code for an RNG op. 
- - Args: - ctx: The graph interpreter context - node: The FX node for this operation - rng_function: Either "rand" or "randn" - - Returns: - AST expression for the RNG operation - """ - assert rng_function in ["rand", "randn"] - - # Get the seed index for this operation - seed_index = self.get_or_allocate_seed_index(node) - - # Get dimensionality and dtype - assert hasattr(node, "meta") and "val" in node.meta - ndim = node.meta["val"].ndim - dtype = node.kwargs.get("dtype", None) - - # Get the dimension variable names from the device function's symbol arguments - device_fn = ctx.cg.device_function - symbol_args = [ - arg - for arg in device_fn.arguments - if hasattr(arg, "__class__") and arg.__class__.__name__ == "SymbolArgument" - ] - - # Extract dimension names - they should be the last ndim symbol arguments - dim_names = [] - assert len(symbol_args) >= ndim, "Not enough symbol arguments for dimensions" - dim_names = [arg.name for arg in symbol_args[-ndim:]] - - offset_parts = [] - - for i in range(ndim): - # Create the index variable with proper broadcasting - index_expr = f"indices_{i}" - - # Add broadcasting slices for this dimension - slice_parts = [] - for j in range(ndim): - if j < i: - slice_parts.append("None") - elif j == i: - slice_parts.append(":") - else: - slice_parts.append("None") - - # Create the broadcasted index expression - if ndim == 1: - # For 1D, no broadcasting needed - broadcasted_index = index_expr - else: - broadcasted_index = f"{index_expr}[{', '.join(slice_parts)}]" - - # Calculate stride (product of dimensions after this one) - if i < ndim - 1: - # Use the actual dimension variable names - stride_parts = dim_names[i + 1 :] - stride_expr = " * ".join(stride_parts) - offset_parts.append(f"{broadcasted_index} * {stride_expr}") - else: - # Last dimension has no stride multiplication - offset_parts.append(broadcasted_index) - - offset_expr = expr_from_string(" + ".join(offset_parts)) - - # Load seed from buffer using the kernel parameter name - seed_expr = expr_from_string( - "tl.load({buffer} + {index})", - buffer=expr_from_string(self.rng_seed_buffer_param_name), - index=create(ast.Constant, value=seed_index), - ) - - # Generate the RNG call - # Note: tl.rand() and tl.randn() always return float32 - rng_expr = expr_from_string( - f"tl.{rng_function}({{seed}}, {{offset}})", - seed=seed_expr, - offset=offset_expr, - ) - - # Cast to target dtype only if explicitly specified - if dtype is not None: - assert isinstance(dtype, torch.dtype) - from torch._inductor.utils import triton_type - - rng_expr = expr_from_string( - f"{{val}}.to({triton_type(dtype)})", val=rng_expr - ) - - return rng_expr diff --git a/test/test_rng.expected b/test/test_rng.expected index da0721a6a..5395c6228 100644 --- a/test/test_rng.expected +++ b/test/test_rng.expected @@ -37,7 +37,7 @@ def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, rand def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): from torch._inductor import inductor_prims - _rng_seed_buffer = inductor_prims.seeds(7, torch.device('cuda')) + _rng_seed_buffer_host = inductor_prims.seeds(7, torch.device('cuda')) rand1 = torch.zeros_like(x) rand2 = torch.zeros_like(x) uniform = torch.zeros_like(x) @@ -48,6 +48,6 @@ def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): m, n = x.shape _BLOCK_SIZE_0 = 32 _BLOCK_SIZE_1 = 32 - _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, 
randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=3) + _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer_host, num_warps=4, num_stages=3) randn_sum = randn_a + randn_b + randn_c return (rand1, rand2, uniform, normal, randn_sum) From 4f4018ca7e88b477736f6b7f8461e396671b1d41 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 3 Sep 2025 15:43:46 -0700 Subject: [PATCH 4/5] clean --- helion/_compiler/device_function.py | 2 +- helion/_compiler/generate_ast.py | 2 +- test/test_rng.expected | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 2aeac256b..8a37f177d 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -537,7 +537,7 @@ def codegen_function_call(self) -> ast.AST: if self.has_rng_ops(): # Pass the host-side seed buffer variable to the kernel - args.append("_rng_seed_buffer_host") + args.append("_rng_seed_buffer") # Workaround for triton bug: warp_specialize requires at least 4 warps # See: https://github.com/triton-lang/triton/issues/7354 diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index 749115591..e4caf8e02 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -81,7 +81,7 @@ def get_rng_seed_buffer_statements(self) -> list[ast.AST]: # Create host-side seed buffer with the required number of seeds seed_buffer_stmt = statement_from_string( - f"_rng_seed_buffer_host = inductor_prims.seeds({self.device_function.rng_seed_count}, torch.device('cuda'))" + f"_rng_seed_buffer = inductor_prims.seeds({self.device_function.rng_seed_count}, torch.device('cuda'))" ) return [import_stmt, seed_buffer_stmt] diff --git a/test/test_rng.expected b/test/test_rng.expected index 5395c6228..da0721a6a 100644 --- a/test/test_rng.expected +++ b/test/test_rng.expected @@ -37,7 +37,7 @@ def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, rand def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): from torch._inductor import inductor_prims - _rng_seed_buffer_host = inductor_prims.seeds(7, torch.device('cuda')) + _rng_seed_buffer = inductor_prims.seeds(7, torch.device('cuda')) rand1 = torch.zeros_like(x) rand2 = torch.zeros_like(x) uniform = torch.zeros_like(x) @@ -48,6 +48,6 @@ def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): m, n = x.shape _BLOCK_SIZE_0 = 32 _BLOCK_SIZE_1 = 32 - _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), 
randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer_host, num_warps=4, num_stages=3) + _launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=3) randn_sum = randn_a + randn_b + randn_c return (rand1, rand2, uniform, normal, randn_sum) From 4f27af2a1459263c32b33d234b79175552a2b919 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 3 Sep 2025 15:47:54 -0700 Subject: [PATCH 5/5] clean up --- helion/_compiler/inductor_lowering.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index 920742a60..29cb21efa 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -50,6 +50,7 @@ from .ast_extension import expr_from_string from .ast_extension import statement_from_string from .compile_environment import CompileEnvironment +from .device_function import SymbolArgument from .device_function import VarInfo from .device_function import contains_only_block_size_symbols from .node_masking import apply_masking @@ -1335,9 +1336,7 @@ def _codegen_rng_op( # Get the dimension variable names from the device function's symbol arguments device_fn = ctx.cg.device_function symbol_args = [ - arg - for arg in device_fn.arguments - if hasattr(arg, "__class__") and arg.__class__.__name__ == "SymbolArgument" + arg for arg in device_fn.arguments if isinstance(arg, SymbolArgument) ] # Extract dimension names - they should be the last ndim symbol arguments
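
The offset construction in _codegen_rng_op amounts to building, as a string, a row-major linear index from the per-dimension tile index variables indices_0 .. indices_{ndim-1} and the trailing dimension sizes. A minimal standalone sketch of that assembly logic (illustrative only; the helper name and the dimension names "m"/"n" are hypothetical, the real code lives in inductor_lowering.py above):

    def build_offset_expr(ndim, dim_names):
        """Mirror of the offset-string construction in _codegen_rng_op (sketch)."""
        parts = []
        for i in range(ndim):
            idx = f"indices_{i}"
            if ndim > 1:
                # Broadcast this 1-D index across the other dimensions.
                slices = ["None"] * ndim
                slices[i] = ":"
                idx = f"{idx}[{', '.join(slices)}]"
            if i < ndim - 1:
                # Stride = product of the dimension sizes after this one.
                parts.append(f"{idx} * {' * '.join(dim_names[i + 1:])}")
            else:
                # Last dimension contributes its index directly.
                parts.append(idx)
        return " + ".join(parts)

    assert build_offset_expr(1, ["n"]) == "indices_0"
    assert build_offset_expr(2, ["m", "n"]) == (
        "indices_0[:, None] * n + indices_1[None, :]"
    )

For the k-th RNG op in such a kernel, the emitted Triton expression then takes the form tl.rand(tl.load(<seed buffer param> + k), <offset expr>) (or tl.randn), with a trailing .to(<triton dtype>) cast appended only when an explicit dtype was requested; the exact seed-buffer parameter name is produced by new_var("rng_seed_buffer") and may carry a uniquifying suffix.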