1 change: 1 addition & 0 deletions helion/language/__init__.py
@@ -21,6 +21,7 @@
from .matmul_ops import dot as dot
from .memory_ops import load as load
from .memory_ops import store as store
from .random_ops import rand as rand
from .reduce_ops import reduce as reduce
from .scan_ops import associative_scan as associative_scan
from .scan_ops import cumprod as cumprod
121 changes: 121 additions & 0 deletions helion/language/random_ops.py
@@ -0,0 +1,121 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from .._compiler.ast_extension import expr_from_string
from .._compiler.compile_environment import CompileEnvironment
from ..exc import NotInsideKernel
from . import _decorators
from .ref_tile import RefTile

if TYPE_CHECKING:
    import ast

    from .._compiler.inductor_lowering import CodegenState

__all__ = ["rand"]


@_decorators.api(tiles_as_sizes=True)
def rand(
    shape: list[object],
    seed: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
"""
The main propose of ``hl.rand`` is to explicitly pass a seed arg for deterministic
randomness in helion kernels, whereas ``torch.rand_like`` doesn't take seed arg
(though it can seeded globally)`. ``hl.rand`` lower to ``tl.rand(seed, offset)`` with ``offset``
built from a linear range over the allocation and reshaped to the given shape.

Note:
Only use within ``hl.tile()`` loops for creating local tensors.
For host allocations, use ``torch.rand()``.

Args:
shape: A list of sizes
seed: int seed for the random number generator
dtype: currently only float32 supported

Returns:
torch.Tensor: A device tensor of the given shape and dtype filled with random values

Examples:
.. code-block:: python

@helion.kernel
def process_kernel(x: torch.Tensor) -> torch.Tensor:
output = torch.zeros_like(x)
(m,) = x.shape
for (tile_m,) in hl.tile([m]):
output[tile_m] = hl.rand([tile_m], seed=seed)
return output

"""
raise NotInsideKernel


@_decorators.register_fake(rand)
def _rand_fake(
    shape: list[int | torch.SymInt],
    seed: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    if not isinstance(shape, (list, tuple)):
        raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
    env = CompileEnvironment.current()
    env.add_kernel_tensor_size(shape)
    return torch.empty(
        [*shape],
        dtype=dtype,
        device=env.device if device is None else device,
    )


@_decorators.codegen(rand)
def _rand_codegen(state: CodegenState) -> ast.AST:
    fake_value = state.fake_value
    assert isinstance(fake_value, torch.Tensor)
    shape_str = state.device_function.tile_strategy.shape_str(fake_value.size())

    numel = " * ".join(shape_str.strip("[]").split(","))
    seed_ast = state.ast_arg(1)
    offs_expr = f"tl.arange(0, {numel}).reshape({shape_str})"
Contributor:
This is incorrect.

  1. Every tile will get the same RNG values.
  2. The RNG values will depend on the tile size due to the reshape

    expr = f"tl.rand({{seed}}, {offs_expr})"

    return expr_from_string(expr, seed=seed_ast)
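To see what the strings above assemble into, suppose ``shape_str`` comes back as ``"[_BLOCK_SIZE_0,_BLOCK_SIZE_1]"`` for a 2D tile (the block-size names are illustrative, not taken from this PR):

```python
# Illustrative trace of the codegen string assembly above.
shape_str = "[_BLOCK_SIZE_0,_BLOCK_SIZE_1]"
numel = " * ".join(shape_str.strip("[]").split(","))
# numel == "_BLOCK_SIZE_0 * _BLOCK_SIZE_1"
offs_expr = f"tl.arange(0, {numel}).reshape({shape_str})"
# offs_expr == "tl.arange(0, _BLOCK_SIZE_0 * _BLOCK_SIZE_1).reshape([_BLOCK_SIZE_0,_BLOCK_SIZE_1])"
expr = f"tl.rand({{seed}}, {offs_expr})"
# expr == "tl.rand({seed}, ...)" -- the {seed} placeholder is later bound
# to the seed AST node by expr_from_string.
```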


@_decorators.get_masked_value(rand)
def _(
    node: torch.fx.Node,
) -> float:
    return 0


@_decorators.ref(rand)
def _(
    shape: list[int | RefTile],
    seed: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
) -> torch.Tensor:
    processed_shape: list[int] = []
    for s in shape:
        if isinstance(s, RefTile):
            processed_shape.append(s.end - s.begin)
        else:
            processed_shape.append(int(s))
    env = CompileEnvironment.current()
    gen = torch.Generator(device=env.device if device is None else device)
    gen.manual_seed(seed)
    return torch.rand(
        processed_shape,
        dtype=dtype,
        generator=gen,
        device=env.device if device is None else device,
    )
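One behavioral note implied by the two paths above (my reading of the code, not a claim from the PR): the ref implementation draws from a seeded ``torch.Generator`` while the compiled path draws from ``tl.rand``, so a given seed is reproducible within each mode but the two modes will not agree bitwise. The seeded-generator determinism the ref path relies on is just:

```python
import torch

# Re-seeding the same generator replays the identical stream.
gen = torch.Generator(device="cpu")
gen.manual_seed(42)
a = torch.rand(4, generator=gen)
gen.manual_seed(42)
b = torch.rand(4, generator=gen)
assert torch.equal(a, b)
```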
91 changes: 91 additions & 0 deletions test/test_rng.py
@@ -348,6 +348,97 @@ def randn_kernel_3d(x: torch.Tensor) -> torch.Tensor:
f"Slice {b_idx} std {slice_std} is not well distributed",
)

    def test_hl_rand_1d(self):
        @helion.kernel
        def rand_kernel_tiled_1d(x: torch.Tensor, seed: int) -> torch.Tensor:
            output = torch.zeros_like(x)
            (m,) = x.shape
            for (tile_m,) in hl.tile([m]):
                output[tile_m] = hl.rand([tile_m], seed=seed)
            return output

        x_small = torch.ones(128, device=DEVICE)
        _, output = code_and_output(rand_kernel_tiled_1d, (x_small, 42))
        _, output2 = code_and_output(rand_kernel_tiled_1d, (x_small, 1337))

        self.assertFalse(
            torch.allclose(output, output2),
            "Different seeds should produce different outputs",
        )

        _, output3 = code_and_output(rand_kernel_tiled_1d, (x_small, 42))
        self.assertTrue(
            torch.allclose(output, output3),
            "Same seed should produce identical outputs",
        )

        # Check that all values are in [0, 1) range
        self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
        self.assertTrue(torch.all(output < 1.0), "All values should be < 1")

    def test_hl_rand_2d(self):
        @helion.kernel
        def rand_kernel_tiled_2d(x: torch.Tensor, seed: int) -> torch.Tensor:
            output = torch.zeros_like(x)
            m, n = x.shape
            for tile_m, tile_n in hl.tile([m, n]):
                output[tile_m, tile_n] = hl.rand([tile_m, tile_n], seed=seed)
            return output

        x_small = torch.ones(128, 128, device=DEVICE)
        _, output = code_and_output(rand_kernel_tiled_2d, (x_small, 42))
        _, output2 = code_and_output(rand_kernel_tiled_2d, (x_small, 1337))

        self.assertFalse(
            torch.allclose(output, output2),
            "Different seeds should produce different outputs",
        )

        _, output3 = code_and_output(rand_kernel_tiled_2d, (x_small, 42))
        self.assertTrue(
            torch.allclose(output, output3),
            "Same seed should produce identical outputs",
        )

        self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
        self.assertTrue(torch.all(output < 1.0), "All values should be < 1")

    def test_hl_rand_3d(self):
        @helion.kernel
        def rand_kernel_tiled_3d(x: torch.Tensor, seed: int) -> torch.Tensor:
            output = torch.zeros_like(x)
            b, m, n = x.shape
            for tile_b, tile_m, tile_n in hl.tile([b, m, n]):
                output[tile_b, tile_m, tile_n] = hl.rand(
                    [tile_b, tile_m, tile_n], seed=seed
                )
            return output

        x_small = torch.ones(16, 32, 64, device=DEVICE)
        _, output = code_and_output(rand_kernel_tiled_3d, (x_small, 42))
        _, output2 = code_and_output(rand_kernel_tiled_3d, (x_small, 1337))

        self.assertFalse(
            torch.allclose(output, output2),
            "Different seeds should produce different outputs",
        )

        _, output3 = code_and_output(rand_kernel_tiled_3d, (x_small, 42))
        self.assertTrue(
            torch.allclose(output, output3),
            "Same seed should produce identical outputs",
        )

        self.assertTrue(torch.all(output >= 0.0), "All values should be >= 0")
        self.assertTrue(torch.all(output < 1.0), "All values should be < 1")

        # Check distribution properties
        mean_val = output.mean().item()
        self.assertTrue(
            0.4 < mean_val < 0.6,
            f"Mean {mean_val:.3f} should be around 0.5 for uniform distribution",
        )


if __name__ == "__main__":
    unittest.main()