23 changes: 12 additions & 11 deletions helion/_compiler/inductor_lowering.py
@@ -52,7 +52,6 @@
from .ast_extension import statement_from_string
from .compile_environment import CompileEnvironment
from .compile_environment import FixedBlockSizeSource
from .device_function import SymbolArgument
from .device_function import VarInfo
from .device_function import contains_only_block_size_symbols
from .dtype_utils import cast_ast
@@ -1559,27 +1558,29 @@ def _codegen_rng_op(
node: The FX node for this operation
rng_function: Either "rand" or "randn"
"""
from .generate_ast import GenerateAST

assert rng_function in ["rand", "randn"]
assert isinstance(ctx.cg, GenerateAST)

# Get unique seed index for this RNG operation
device_fn = ctx.cg.device_function
seed_index = device_fn.allocate_rng_seed()

# Get dimensionality and dtype
assert hasattr(node, "meta") and "val" in node.meta
ndim = node.meta["val"].ndim
fake_value = node.meta["val"]
ndim = fake_value.ndim
dtype = node.kwargs.get("dtype", None)

# Get the dimension variable names from the device function's symbol arguments
device_fn = ctx.cg.device_function
symbol_args = [
arg for arg in device_fn.arguments if isinstance(arg, SymbolArgument)
]

# Extract dimension names - they should be the last ndim symbol arguments
# Get dimension names for offset calculation
env = CompileEnvironment.current()
dim_names = []
assert len(symbol_args) >= ndim, "Not enough symbol arguments for dimensions"
dim_names = [arg.name for arg in symbol_args[-ndim:]]
for size in fake_value.size():
block_id = env.get_block_id(size)
assert block_id is not None
block_size = env.block_sizes[block_id].size
dim_names.append(device_fn.literal_expr(block_size))

offset_parts = []

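Note: as an illustrative sketch only (not Helion's actual codegen; build_offset_expr and the index-name strings are hypothetical), the per-dimension size expressions collected in dim_names can be combined into a row-major flattened offset of the form seen in the expected Triton output further down.

# Sketch, not the real implementation: assemble a row-major flattened offset
# expression from per-dimension size expressions, matching the
# "indices_0[:, None] * 64 + indices_1[None, :]" pattern in test_rng.expected.
def build_offset_expr(dim_names: list[str], index_names: list[str]) -> str:
    parts = []
    for i, idx in enumerate(index_names):
        trailing = dim_names[i + 1 :]  # sizes of all later dimensions
        if trailing:
            # stride of dimension i is the product of the trailing sizes
            parts.append(f"{idx} * {' * '.join(trailing)}")
        else:
            parts.append(idx)  # innermost dimension has stride 1
    return " + ".join(parts)

# build_offset_expr(["64", "64"], ["indices_0[:, None]", "indices_1[None, :]"])
# -> "indices_0[:, None] * 64 + indices_1[None, :]"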
20 changes: 20 additions & 0 deletions helion/language/ref_tile.py
@@ -1,8 +1,10 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TypeVar

import torch
from torch.utils._pytree import tree_map_only

from .. import exc
from .._utils import convert_tile_indices_to_slices
@@ -12,6 +14,8 @@
if TYPE_CHECKING:
from collections.abc import Callable

_T = TypeVar("_T")


_ADD_OPS: set[object] = {
torch.add,
@@ -71,8 +75,24 @@ def __torch_function__(
if func in _SUB_OPS:
return cls._handle_sub(args)

# For any other torch.* function or torch.Tensor.* method, convert tiles to sizes
is_torch_func = getattr(func, "__module__", "") == "torch"
is_tensor_method = hasattr(torch.Tensor, getattr(func, "__name__", ""))
if is_torch_func or is_tensor_method:
new_args = cls._tiles_to_sizes(args)
new_kwargs = cls._tiles_to_sizes(kwargs) if kwargs else {}
return func(*new_args, **new_kwargs)

raise exc.IncorrectTileUsage(func)

@classmethod
def _tiles_to_sizes(cls, it: _T) -> _T:
return tree_map_only(RefTile, cls._tile_to_size, it)

@staticmethod
def _tile_to_size(tile: RefTile) -> int:
return tile.block_size

@classmethod
def _handle_add(cls, args: tuple[object, ...]) -> torch.Tensor:
tile, offset, flipped = cls._extract_tile_and_offset(args, torch.add)
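Note: a minimal sketch of the tree_map_only pattern used by _tiles_to_sizes above (FakeTile is a hypothetical stand-in for RefTile): tree_map_only walks nested args/kwargs as a pytree and rewrites only leaves of the given type, leaving every other leaf untouched.

# Sketch only; FakeTile stands in for RefTile to show the pytree mapping.
from torch.utils._pytree import tree_map_only

class FakeTile:
    def __init__(self, block_size: int) -> None:
        self.block_size = block_size

args = (FakeTile(16), {"size": FakeTile(32)}, 3.0)
# Only FakeTile leaves are replaced by their block_size; other leaves pass through.
sizes = tree_map_only(FakeTile, lambda t: t.block_size, args)
assert sizes == (16, {"size": 32}, 3.0)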
36 changes: 17 additions & 19 deletions test/test_rng.expected
@@ -10,38 +10,36 @@ import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal_stride_0, normal_stride_1, rand1_stride_0, rand1_stride_1, rand2_stride_0, rand2_stride_1, randn_a_stride_0, randn_a_stride_1, randn_b_stride_0, randn_b_stride_1, randn_c_stride_0, randn_c_stride_1, uniform_stride_0, uniform_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, rng_seed_buffer):
def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, rng_seed_buffer):
# src[test_rng.py:N]: for tile_m, tile_n in hl.tile([m, n]):
num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0)
num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < m
offset_1 = pid_1 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
mask_1 = indices_1 < n
# src[test_rng.py:N]: rand1[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(rand1 + (indices_0[:, None] * rand1_stride_0 + indices_1[None, :] * rand1_stride_1), rand, mask_0[:, None] & mask_1[None, :])
rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
tl.store(rand1 + (indices_0[:, None] * 64 + indices_1[None, :] * 1), rand, None)
# src[test_rng.py:N]: rand2[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
rand_1 = tl.rand(tl.load(rng_seed_buffer + 1), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(rand2 + (indices_0[:, None] * rand2_stride_0 + indices_1[None, :] * rand2_stride_1), rand_1, mask_0[:, None] & mask_1[None, :])
rand_1 = tl.rand(tl.load(rng_seed_buffer + 1), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
tl.store(rand2 + (indices_0[:, None] * 64 + indices_1[None, :] * 1), rand_1, None)
# src[test_rng.py:N]: uniform[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
rand_2 = tl.rand(tl.load(rng_seed_buffer + 2), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(uniform + (indices_0[:, None] * uniform_stride_0 + indices_1[None, :] * uniform_stride_1), rand_2, mask_0[:, None] & mask_1[None, :])
rand_2 = tl.rand(tl.load(rng_seed_buffer + 2), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
tl.store(uniform + (indices_0[:, None] * 64 + indices_1[None, :] * 1), rand_2, None)
# src[test_rng.py:N]: normal[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
randn = tl.randn(tl.load(rng_seed_buffer + 3), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(normal + (indices_0[:, None] * normal_stride_0 + indices_1[None, :] * normal_stride_1), randn, mask_0[:, None] & mask_1[None, :])
randn = tl.randn(tl.load(rng_seed_buffer + 3), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
tl.store(normal + (indices_0[:, None] * 64 + indices_1[None, :] * 1), randn, None)
# src[test_rng.py:N]: randn_a[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
randn_1 = tl.randn(tl.load(rng_seed_buffer + 4), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(randn_a + (indices_0[:, None] * randn_a_stride_0 + indices_1[None, :] * randn_a_stride_1), randn_1, mask_0[:, None] & mask_1[None, :])
randn_1 = tl.randn(tl.load(rng_seed_buffer + 4), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
tl.store(randn_a + (indices_0[:, None] * 64 + indices_1[None, :] * 1), randn_1, None)
# src[test_rng.py:N]: randn_b[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
randn_2 = tl.randn(tl.load(rng_seed_buffer + 5), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(randn_b + (indices_0[:, None] * randn_b_stride_0 + indices_1[None, :] * randn_b_stride_1), randn_2, mask_0[:, None] & mask_1[None, :])
randn_2 = tl.randn(tl.load(rng_seed_buffer + 5), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
tl.store(randn_b + (indices_0[:, None] * 64 + indices_1[None, :] * 1), randn_2, None)
# src[test_rng.py:N]: randn_c[tile_m, tile_n] = torch.randn_like(x[tile_m, tile_n])
randn_3 = tl.randn(tl.load(rng_seed_buffer + 6), indices_0[:, None] * n + indices_1[None, :]).to(tl.float32)
tl.store(randn_c + (indices_0[:, None] * randn_c_stride_0 + indices_1[None, :] * randn_c_stride_1), randn_3, mask_0[:, None] & mask_1[None, :])
randn_3 = tl.randn(tl.load(rng_seed_buffer + 6), indices_0[:, None] * 64 + indices_1[None, :]).to(tl.float32)
tl.store(randn_c + (indices_0[:, None] * 64 + indices_1[None, :] * 1), randn_3, None)

def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
from torch._inductor import inductor_prims
@@ -73,7 +71,7 @@ def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_rng.py:N]: # Two independent rand operations
# src[test_rng.py:N]: rand1[tile_m, tile_n] = torch.rand_like(x[tile_m, tile_n])
# src[test_rng.py:N-N]: ...
_launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, normal.stride(0), normal.stride(1), rand1.stride(0), rand1.stride(1), rand2.stride(0), rand2.stride(1), randn_a.stride(0), randn_a.stride(1), randn_b.stride(0), randn_b.stride(1), randn_c.stride(0), randn_c.stride(1), uniform.stride(0), uniform.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=1)
_launcher(_helion_multiple_rng_ops_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1),), rand1, rand2, uniform, normal, randn_a, randn_b, randn_c, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=1)
# src[test_rng.py:N]: randn_sum = randn_a + randn_b + randn_c
randn_sum = randn_a + randn_b + randn_c
# src[test_rng.py:N]: return rand1, rand2, uniform, normal, randn_sum
114 changes: 108 additions & 6 deletions test/test_rng.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from typing import Callable
import unittest

import torch
@@ -16,7 +17,7 @@ class TestRNG(RefEagerTestBase, TestCase):
def test_rand(self):
"""Test RNG seeding behavior, reproducibility, output range, and distribution."""

@helion.kernel(static_shapes=False)
@helion.kernel(static_shapes=True, autotune_effort="none")
def rand_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
output = torch.zeros_like(x)
m, n = x.shape
@@ -87,7 +88,7 @@ def rand_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
def test_rand_3d_tensor(self):
"""Test 3D RNG with tiled operations."""

@helion.kernel(static_shapes=False)
@helion.kernel(static_shapes=True, autotune_effort="none")
def rand_kernel_3d(x: torch.Tensor) -> torch.Tensor:
output = torch.zeros_like(x)
b, m, n = x.shape
@@ -135,7 +136,7 @@ def rand_kernel_3d(x: torch.Tensor) -> torch.Tensor:
def test_multiple_rng_ops(self):
"""Test multiple RNG operations: independence, reproducibility, mixed rand/randn."""

@helion.kernel(static_shapes=False)
@helion.kernel(static_shapes=True, autotune_effort="none")
def multiple_rng_ops_kernel(
x: torch.Tensor,
) -> tuple[
@@ -258,7 +259,7 @@ def multiple_rng_ops_kernel(
def test_randn_different_seeds_tiled(self):
"""Test that different torch.manual_seed values produce different outputs for randn."""

@helion.kernel(static_shapes=False)
@helion.kernel(static_shapes=True, autotune_effort="none")
def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
output = torch.zeros_like(x)
m, n = x.shape
@@ -280,7 +281,7 @@ def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
def test_randn_normal_distribution(self):
"""Test that torch.randn_like produces normal distribution (mean≈0, std≈1)."""

@helion.kernel(static_shapes=False)
@helion.kernel(static_shapes=True, autotune_effort="none")
def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
output = torch.zeros_like(x)
m, n = x.shape
@@ -315,7 +316,7 @@ def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
def test_randn_3d_tensor(self):
"""Test 3D randn with tiled operations."""

@helion.kernel(static_shapes=False)
@helion.kernel(static_shapes=True, autotune_effort="none")
def randn_kernel_3d(x: torch.Tensor) -> torch.Tensor:
output = torch.zeros_like(x)
b, m, n = x.shape
@@ -348,6 +349,107 @@ def randn_kernel_3d(x: torch.Tensor) -> torch.Tensor:
f"Slice {b_idx} std {slice_std} is not well distributed",
)

def _test_rng_with_dynamic_tile_sizes(self, rng_func, is_uniform, rng_name):
"""Common test logic for RNG operations with dynamic tile sizes."""

# Single kernel that takes an RNG callable as a parameter
@helion.kernel(static_shapes=True, autotune_effort="none")
def rng_kernel(
x: torch.Tensor,
rng_func: Callable[[int, int, torch.dtype], torch.Tensor],
) -> torch.Tensor:
output = torch.zeros_like(x)
m, n = x.shape
for tile_m, tile_n in hl.tile([m, n]):
output[tile_m, tile_n] = rng_func(tile_m, tile_n, x.dtype)
return output

x = torch.ones(64, 64, device=DEVICE)
torch.manual_seed(42)
_code, output = code_and_output(rng_kernel, (x, rng_func))

# Check distribution properties based on RNG type
if is_uniform:
# For rand: values in [0, 1), mean ~0.5
self.assertTrue(
torch.all(output >= 0.0), f"{rng_name}: All values should be >= 0"
)
self.assertTrue(
torch.all(output < 1.0), f"{rng_name}: All values should be < 1"
)
mean_val = output.mean().item()
self.assertTrue(
0.4 < mean_val < 0.6,
f"{rng_name}: Mean {mean_val:.3f} should be ~0.5",
)
else:
# For randn: mean ~0, std ~1
mean_val = output.mean().item()
std_val = output.std().item()
self.assertTrue(
-0.15 < mean_val < 0.15, f"{rng_name}: Mean {mean_val:.3f} should be ~0"
)
self.assertTrue(
0.9 < std_val < 1.1, f"{rng_name}: Std {std_val:.3f} should be ~1"
)

# Test reproducibility with same seed
torch.manual_seed(42)
_code2, output2 = code_and_output(rng_kernel, (x, rng_func))
torch.testing.assert_close(
output,
output2,
msg=f"{rng_name}: Same seed should produce identical outputs",
)

# Test that different seeds produce different outputs
torch.manual_seed(99)
_code3, output3 = code_and_output(rng_kernel, (x, rng_func))
self.assertFalse(
torch.allclose(output, output3),
f"{rng_name}: Different seeds should produce different outputs",
)

def test_rand_with_dynamic_tile_sizes(self):
"""Test torch.rand with dynamic tile dimensions."""
self._test_rng_with_dynamic_tile_sizes(
rng_func=lambda tile_m, tile_n, dtype: torch.rand(
(tile_m, tile_n), dtype=dtype, device=DEVICE
),
is_uniform=True,
rng_name="rand",
)

def test_rand_like_with_dynamic_tile_sizes(self):
"""Test torch.rand_like with dynamic tile dimensions."""
self._test_rng_with_dynamic_tile_sizes(
rng_func=lambda tile_m, tile_n, dtype: torch.rand_like(
torch.ones((tile_m, tile_n), dtype=dtype, device=DEVICE)
),
is_uniform=True,
rng_name="rand_like",
)

def test_randn_with_dynamic_tile_sizes(self):
"""Test torch.randn with dynamic tile dimensions."""
self._test_rng_with_dynamic_tile_sizes(
rng_func=lambda tile_m, tile_n, dtype: torch.randn(
(tile_m, tile_n), dtype=dtype, device=DEVICE
),
is_uniform=False,
rng_name="randn",
)

def test_randn_like_with_dynamic_tile_sizes(self):
"""Test torch.randn_like with dynamic tile dimensions."""
self._test_rng_with_dynamic_tile_sizes(
rng_func=lambda tile_m, tile_n, dtype: torch.randn_like(
torch.ones((tile_m, tile_n), dtype=dtype, device=DEVICE)
),
is_uniform=False,
rng_name="randn_like",
)


if __name__ == "__main__":
unittest.main()