Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions helion/_compiler/device_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,29 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
self.tile_strategy: TileStrategyDispatch = TileStrategyDispatch(self, config)
self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config)

self.rng_seed_count = 0
# Name of the RNG seed buffer parameter in kernel signature
self.rng_seed_buffer_param_name = None

def has_rng_ops(self) -> bool:
    """Report whether this kernel needs an RNG seed buffer.

    Returns:
        True only when at least one seed index has been allocated and the
        seed buffer parameter name has been assigned; False otherwise.
    """
    if self.rng_seed_count <= 0:
        return False
    return self.rng_seed_buffer_param_name is not None

def allocate_rng_seed(self) -> int:
    """Reserve the next RNG seed slot for one RNG operation.

    The first call also assigns ``rng_seed_buffer_param_name`` using
    ``self.new_var("rng_seed_buffer")``; later calls reuse that name.

    Returns:
        The index of the reserved seed, i.e. the value of ``rng_seed_count``
        before this call incremented it.
    """
    allocated = self.rng_seed_count
    self.rng_seed_count = allocated + 1

    # Lazily create the seed buffer parameter name on first use.
    if self.rng_seed_buffer_param_name is None:
        self.rng_seed_buffer_param_name = self.new_var("rng_seed_buffer")

    return allocated

def block_size_var(self, block_id: int) -> str | None:
    """Look up the cached block-size variable for a single block id.

    Args:
        block_id: Block identifier; the cache is keyed by the 1-tuple
            ``(block_id,)``.

    Returns:
        The cached entry, or None when the cache has no entry for that key.
    """
    cache_key = (block_id,)
    return self.block_size_var_cache.get(cache_key)

Expand Down Expand Up @@ -487,15 +510,20 @@ def codegen_function_def(self) -> list[ast.stmt]:
prefix.append(
statement_from_string("helion.runtime.set_triton_allocator()")
)

args = [arg.arg_def_node() for arg in self.sorted_args()]
if self.has_rng_ops():
# Add the seed buffer as a pointer parameter to kernel signature
assert self.rng_seed_buffer_param_name is not None
args.append(create_arg(self.rng_seed_buffer_param_name))

return [
*prefix,
ast_rename(
create(
ast.FunctionDef,
name=self.name,
args=create_arguments(
[arg.arg_def_node() for arg in self.sorted_args()]
),
args=create_arguments(args),
body=[*self.preamble, *self.body],
decorator_list=[expr_from_string("triton.jit")],
type_params=[],
Expand All @@ -507,6 +535,10 @@ def codegen_function_def(self) -> list[ast.stmt]:
def codegen_function_call(self) -> ast.AST:
args = [arg.host_str() for arg in self.sorted_args()]

if self.has_rng_ops():
# Pass the host-side seed buffer variable to the kernel
args.append("_rng_seed_buffer")

# Workaround for triton bug: warp_specialize requires at least 4 warps
# See: https://github.com/triton-lang/triton/issues/7354
num_warps = self.config.num_warps
Expand Down
23 changes: 22 additions & 1 deletion helion/_compiler/generate_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ def add_statement(self, stmt: ast.AST | str | None) -> None:
stmt = statement_from_string(stmt)
self.statements_stack[-1].append(stmt)

def get_rng_seed_buffer_statements(self) -> list[ast.AST]:
    """Build the host-side statements that create the RNG seed buffer.

    Returns:
        Two statements, in order: the import of ``inductor_prims`` and the
        assignment of ``_rng_seed_buffer``, sized by
        ``self.device_function.rng_seed_count``.
    """
    # Same approach as torch.compile: allocate one seed value per RNG op.
    # NOTE(review): the generated source hard-codes torch.device('cuda').
    return [
        statement_from_string("from torch._inductor import inductor_prims"),
        statement_from_string(
            f"_rng_seed_buffer = inductor_prims.seeds({self.device_function.rng_seed_count}, torch.device('cuda'))"
        ),
    ]

def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name:
if isinstance(expr, ast.Name):
return expr
Expand Down Expand Up @@ -395,7 +407,16 @@ def generate_ast(
codegen.add_statement(codegen.visit(stmt))
kernel_def = codegen.device_function.codegen_function_def()
codegen.host_dead_code_elimination()
host_def = func.codegen_function_def(codegen.host_statements)

# Inject RNG seed buffer creation if needed
rng_statements = (
codegen.get_rng_seed_buffer_statements()
if codegen.device_function.has_rng_ops()
else []
)
final_host_statements = rng_statements + codegen.host_statements

host_def = func.codegen_function_def(final_host_statements)

call_def = []
main_def = []
Expand Down
103 changes: 103 additions & 0 deletions helion/_compiler/inductor_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from .ast_extension import expr_from_string
from .ast_extension import statement_from_string
from .compile_environment import CompileEnvironment
from .device_function import SymbolArgument
from .device_function import VarInfo
from .device_function import contains_only_block_size_symbols
from .node_masking import apply_masking
Expand Down Expand Up @@ -1307,3 +1308,105 @@ def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
step=ctx.to_ast(step),
length=ctx.to_ast(length_arg),
)


def _codegen_rng_op(
ctx: GraphInterpreter,
node: torch.fx.Node,
rng_function: str,
) -> object:
"""Common codegen implementation for all RNG operations.

Args:
ctx: The graph interpreter context
node: The FX node for this operation
rng_function: Either "rand" or "randn"
"""
assert rng_function in ["rand", "randn"]

# Get unique seed index for this RNG operation
device_fn = ctx.cg.device_function
seed_index = device_fn.allocate_rng_seed()

# Get dimensionality and dtype
assert hasattr(node, "meta") and "val" in node.meta
ndim = node.meta["val"].ndim
dtype = node.kwargs.get("dtype", None)

# Get the dimension variable names from the device function's symbol arguments
device_fn = ctx.cg.device_function
symbol_args = [
arg for arg in device_fn.arguments if isinstance(arg, SymbolArgument)
]

# Extract dimension names - they should be the last ndim symbol arguments
dim_names = []
assert len(symbol_args) >= ndim, "Not enough symbol arguments for dimensions"
dim_names = [arg.name for arg in symbol_args[-ndim:]]

offset_parts = []

for i in range(ndim):
# Create the index variable with proper broadcasting
index_expr = f"indices_{i}"

# Add broadcasting slices for this dimension
# For 1D tensors, this will just be indices_0 with no slicing
slice_parts = []
for j in range(ndim):
if j < i:
slice_parts.append("None")
elif j == i:
slice_parts.append(":")
else:
slice_parts.append("None")

# Create the broadcasted index expression
if ndim == 1:
# For 1D, no broadcasting needed
broadcasted_index = index_expr
else:
broadcasted_index = f"{index_expr}[{', '.join(slice_parts)}]"

# Calculate stride (product of dimensions after this one)
if i < ndim - 1:
# Use the actual dimension variable names
stride_parts = dim_names[i + 1 :]
stride_expr = " * ".join(stride_parts)
offset_parts.append(f"{broadcasted_index} * {stride_expr}")
else:
# Last dimension has no stride multiplication
offset_parts.append(broadcasted_index)

offset_expr = expr_from_string(" + ".join(offset_parts))

# Load seed from buffer using the kernel parameter name
assert device_fn.rng_seed_buffer_param_name is not None
seed_expr = expr_from_string(
"tl.load({buffer} + {index})",
buffer=expr_from_string(device_fn.rng_seed_buffer_param_name),
index=create(ast.Constant, value=seed_index),
)

# Generate the RNG call
# Note: tl.rand() and tl.randn() always return float32
rng_expr = expr_from_string(
f"tl.{rng_function}({{seed}}, {{offset}})", seed=seed_expr, offset=offset_expr
)

# Cast to target dtype only if explicitly specified
if dtype is not None:
assert isinstance(dtype, torch.dtype)
rng_expr = expr_from_string(f"{{val}}.to({triton_type(dtype)})", val=rng_expr)

return rng_expr


@register_lowering(torch.ops.aten.rand.default)  # pyright: ignore[reportAttributeAccessIssue]
def codegen_rand(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
    """Lower ``aten.rand.default`` via the shared RNG codegen using "rand"."""
    return _codegen_rng_op(ctx, node, rng_function="rand")


@register_lowering(torch.ops.aten.randn.default)  # pyright: ignore[reportAttributeAccessIssue]
def codegen_randn(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
    """Lower ``aten.randn.default`` via the shared RNG codegen using "randn"."""
    return _codegen_rng_op(ctx, node, rng_function="randn")
75 changes: 69 additions & 6 deletions helion/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ class RefEagerTestBase:
_original_skip_test_func = None
# Class-level tracking for pytest.raises patching
_original_pytest_raises = None
# Class-level tracking for assertTrue/assertFalse/assertGreater
_assert_true_count = 0
_original_assert_true_func = None
_assert_false_count = 0
_original_assert_false_func = None
_assert_greater_count = 0
_original_assert_greater_func = None

def setUp(self) -> None:
"""Common setup for all ref eager tests."""
Expand All @@ -142,6 +149,10 @@ def setUp(self) -> None:
RefEagerTestBase._assert_raises_count = 0
# Reset skipTest counter for this test
RefEagerTestBase._skip_test_count = 0
# Reset assertTrue/assertFalse/assertGreater counters
RefEagerTestBase._assert_true_count = 0
RefEagerTestBase._assert_false_count = 0
RefEagerTestBase._assert_greater_count = 0

# Patch torch.testing.assert_close to count calls
if RefEagerTestBase._original_assert_close_func is None:
Expand Down Expand Up @@ -189,6 +200,36 @@ def counting_pytest_raises(*args: object, **kwargs: object) -> object:

pytest.raises = counting_pytest_raises # type: ignore[assignment]

# Patch self.assertTrue to count calls
if RefEagerTestBase._original_assert_true_func is None:
RefEagerTestBase._original_assert_true_func = self.assertTrue

def counting_assert_true(*args: object, **kwargs: object) -> None:
RefEagerTestBase._assert_true_count += 1
return RefEagerTestBase._original_assert_true_func(*args, **kwargs) # type: ignore[misc]

self.assertTrue = counting_assert_true # type: ignore[assignment]

# Patch self.assertFalse to count calls
if RefEagerTestBase._original_assert_false_func is None:
RefEagerTestBase._original_assert_false_func = self.assertFalse

def counting_assert_false(*args: object, **kwargs: object) -> None:
RefEagerTestBase._assert_false_count += 1
return RefEagerTestBase._original_assert_false_func(*args, **kwargs) # type: ignore[misc]

self.assertFalse = counting_assert_false # type: ignore[assignment]

# Patch self.assertGreater to count calls
if RefEagerTestBase._original_assert_greater_func is None:
RefEagerTestBase._original_assert_greater_func = self.assertGreater

def counting_assert_greater(*args: object, **kwargs: object) -> None:
RefEagerTestBase._assert_greater_count += 1
return RefEagerTestBase._original_assert_greater_func(*args, **kwargs) # type: ignore[misc]

self.assertGreater = counting_assert_greater # type: ignore[assignment]

def tearDown(self) -> None:
"""Common teardown with assertion counting check."""
# If not in ref eager mode, skip the teardown logic
Expand All @@ -215,17 +256,27 @@ def tearDown(self) -> None:
)

if not is_skipped:
# Check that either assert_close, assertRaises, or skipTest was called
# Check that either assert_close, assertRaises, skipTest, assertTrue, assertFalse, or assertGreater was called
total_assertions = (
RefEagerTestBase._assert_close_count
+ RefEagerTestBase._assert_raises_count
+ RefEagerTestBase._skip_test_count
+ RefEagerTestBase._assert_true_count
+ RefEagerTestBase._assert_false_count
+ RefEagerTestBase._assert_greater_count
)
self.assertGreater( # type: ignore[attr-defined]
total_assertions,
0,
f"Test {self._testMethodName} did not call torch.testing.assert_close, assertRaises, or skipTest", # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
)
# Need to use the original assertGreater to avoid recursion
if RefEagerTestBase._original_assert_greater_func is not None:
RefEagerTestBase._original_assert_greater_func( # type: ignore[misc]
total_assertions,
0,
f"Test {self._testMethodName} did not call torch.testing.assert_close, assertRaises, skipTest, assertTrue, assertFalse, or assertGreater", # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
)
else:
# Fallback if original not available
assert total_assertions > 0, (
f"Test {self._testMethodName} did not call any assertion methods" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
)
finally:
# Restore the original assert_close function
if RefEagerTestBase._original_assert_close_func is not None:
Expand All @@ -245,6 +296,18 @@ def tearDown(self) -> None:
if RefEagerTestBase._original_pytest_raises is not None: # pyright: ignore[reportAttributeAccessIssue]
pytest.raises = RefEagerTestBase._original_pytest_raises # pyright: ignore[reportAttributeAccessIssue]

# Restore the original assertTrue function
if RefEagerTestBase._original_assert_true_func is not None:
self.assertTrue = RefEagerTestBase._original_assert_true_func

# Restore the original assertFalse function
if RefEagerTestBase._original_assert_false_func is not None:
self.assertFalse = RefEagerTestBase._original_assert_false_func

# Restore the original assertGreater function
if RefEagerTestBase._original_assert_greater_func is not None:
self.assertGreater = RefEagerTestBase._original_assert_greater_func

super().tearDown() # type: ignore[misc]

# NOTE: We no-op these methods because they commonly check behaviors that are not relevant in ref eager mode.
Expand Down
53 changes: 53 additions & 0 deletions test/test_rng.expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
This file is automatically generated by assertExpectedJournal calls in test_rng.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

--- assertExpectedJournal(TestRNG.test_multiple_rng_ops)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal_stride_0, normal_stride_1, rand1_stride_0, rand1_stride_1, rand2_stride_0, rand2_stride_1, randn_a_stride_0, randn_a_stride_1, randn_b_stride_0, randn_b_stride_1, randn_c_stride_0, randn_c_stride_1, uniform_stride_0, uniform_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, rng_seed_buffer):
num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < m
offset_1 = pid_1 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
mask_1 = indices_1 < n
rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(rand1 + (indices_0[:, None] * rand1_stride_0 + indices_1[None, :] * rand1_stride_1), rand, mask_0[:, None] & mask_1[None, :])
rand_1 = tl.rand(tl.load(rng_seed_buffer + 1), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(rand2 + (indices_0[:, None] * rand2_stride_0 + indices_1[None, :] * rand2_stride_1), rand_1, mask_0[:, None] & mask_1[None, :])
rand_2 = tl.rand(tl.load(rng_seed_buffer + 2), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(uniform + (indices_0[:, None] * uniform_stride_0 + indices_1[None, :] * uniform_stride_1), rand_2, mask_0[:, None] & mask_1[None, :])
randn = tl.randn(tl.load(rng_seed_buffer + 3), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(normal + (indices_0[:, None] * normal_stride_0 + indices_1[None, :] * normal_stride_1), randn, mask_0[:, None] & mask_1[None, :])
randn_1 = tl.randn(tl.load(rng_seed_buffer + 4), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(randn_a + (indices_0[:, None] * randn_a_stride_0 + indices_1[None, :] * randn_a_stride_1), randn_1, mask_0[:, None] & mask_1[None, :])
randn_2 = tl.randn(tl.load(rng_seed_buffer + 5), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(randn_b + (indices_0[:, None] * randn_b_stride_0 + indices_1[None, :] * randn_b_stride_1), randn_2, mask_0[:, None] & mask_1[None, :])
randn_3 = tl.randn(tl.load(rng_seed_buffer + 6), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(randn_c + (indices_0[:, None] * randn_c_stride_0 + indices_1[None, :] * randn_c_stride_1), randn_3, mask_0[:, None] & mask_1[None, :])

def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
from torch._inductor import inductor_prims
_rng_seed_buffer = inductor_prims.seeds(7, torch.device('cuda'))
rand1 = torch.zeros_like(x)
rand2 = torch.zeros_like(x)
uniform = torch.zeros_like(x)
normal = torch.zeros_like(x)
randn_a = torch.zeros_like(x)
randn_b = torch.zeros_like(x)
randn_c = torch.zeros_like(x)
m, n = x.shape
_BLOCK_SIZE_0 = 32
_BLOCK_SIZE_1 = 32
_launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=3)
randn_sum = randn_a + randn_b + randn_c
return (rand1, rand2, uniform, normal, randn_sum)
Loading
Loading