Status: Closed
Labels: random ops (related to random number generation in Helion, rng generator)
Description
Describe the bug
Calling torch.rand with Helion tiles as sizes fails, e.g. torch.rand((tile_m, tile_n)).
This contrasts with other supported torch APIs such as torch.ones((tile_m, tile_n)), which accept tiles as sizes.
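For context, hl.tile([m, n]) splits the [m, n] iteration space into blocks, and each (tile_m, tile_n) pair stands for one block. A tile-shaped allocation like torch.ones((tile_m, tile_n)) is therefore conceptually a block whose extent is the length of each tile. The pure-Python sketch below models only those logical extents. The helper names and the block size of 4 are illustrative assumptions: Helion's config or autotuner chooses the real block sizes, and the generated kernel may pad and mask edge tiles.

```python
def iter_tiles(size, block):
    """Hypothetical helper: yield half-open (begin, end) ranges covering range(size).

    The last range is shorter when `block` does not divide `size`.
    """
    for begin in range(0, size, block):
        yield begin, min(begin + block, size)


def tile_shapes(m, n, block_m, block_n):
    """Return the logical (rows, cols) extent of every tile of an [m, n] space."""
    shapes = []
    for m0, m1 in iter_tiles(m, block_m):
        for n0, n1 in iter_tiles(n, block_n):
            # Logical shape a tile-sized factory call such as ones((tile_m, tile_n))
            # would cover for this (tile_m, tile_n) pair.
            shapes.append((m1 - m0, n1 - n0))
    return shapes


shapes = tile_shapes(10, 10, 4, 4)
print(shapes[0], shapes[-1])  # (4, 4) (2, 2)
```

With m = n = 10 and an assumed block size of 4, the last row and column of tiles are only 2 wide, so a tile's extent is not always the block size.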
To Reproduce
Below is the simple matrix-multiply example, modified to add 1 + noise to each output tile before it is stored. Creating the tile of ones works, but constructing the noise with torch.rand fails.
```python
import torch
import helion
import helion.language as hl
from torch import Tensor
from typing import Callable


@helion.kernel(autotune_effort="none")
def matmul(
    x: Tensor,
    y: Tensor,
    epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor] = lambda acc, tile: acc,
) -> Tensor:
    """
    Performs matrix multiplication of x and y with an optional epilogue function.

    Args:
        x (Tensor): Left matrix of shape [m, k].
        y (Tensor): Right matrix of shape [k, n].
        epilogue (Callable, optional): Function applied to the accumulator and tile indices
            after the matmul. Defaults to identity (no change).

    Returns:
        Tensor: Resulting matrix of shape [m, n].
    """
    m, k = x.size()
    k2, n = y.size()
    assert k == k2, f"size mismatch {k} != {k2}"
    out = torch.empty(
        [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
    )
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        ones = torch.ones((tile_m, tile_n), dtype=torch.float32)  # <<< succeeds
        noise = torch.rand((tile_m, tile_n), dtype=torch.float32)  # <<< fails
        acc = acc + ones + noise
        out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n))
    return out


torch.set_default_device("cuda:0")
x = torch.randn(10, 10)
y = torch.randn(10, 10)
x = matmul(x, y)
print(x.shape)
print("finished")
```

Expected behavior
torch.rand() should accept Helion tiles as sizes, just like torch.ones() does, producing a tile of random values drawn uniformly from [0, 1).
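To make the expected result concrete, here is an eager, pure-Python sketch of what the kernel above should compute once torch.rand accepts tiles: out = x @ y + 1 + noise, with noise drawn uniformly from [0, 1) for each element of each tile. It is not Helion code. The function name, the block and seed parameters, and the use of Python's random module in place of torch.rand are assumptions for illustration, and the actual random values will differ from any GPU implementation.

```python
import random


def matmul_plus_noise_reference(x, y, block=4, seed=0):
    """Hypothetical eager reference: out = x @ y + 1 + U[0, 1) noise, filled tile by tile.

    `block` and `seed` are illustrative; the real kernel's block sizes come from
    Helion's config, and its random stream is different.
    """
    rng = random.Random(seed)
    m, k = len(x), len(x[0])
    k2, n = len(y), len(y[0])
    assert k == k2, f"size mismatch {k} != {k2}"
    out = [[0.0] * n for _ in range(m)]
    for m0 in range(0, m, block):
        for n0 in range(0, n, block):
            # One (tile_m, tile_n) block; edge tiles are clipped to the matrix bounds.
            for i in range(m0, min(m0 + block, m)):
                for j in range(n0, min(n0 + block, n)):
                    acc = sum(x[i][p] * y[p][j] for p in range(k))
                    # Mirrors `acc + ones + noise`, with noise uniform on [0, 1) like torch.rand.
                    out[i][j] = acc + 1.0 + rng.random()
    return out


x = [[1.0, 2.0], [3.0, 4.0]]
y = [[5.0, 6.0], [7.0, 8.0]]
out = matmul_plus_noise_reference(x, y, block=1, seed=123)
```

Because only the distribution of the noise is specified, a check against the real kernel would compare out - (x @ y + 1) against the [0, 1) range rather than against exact values.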
Versions
helion=0.2.1
torch=2.9.0
Running on RTX 3090 GPU.
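Version details like the ones above can be gathered programmatically when filing similar reports. This sketch uses only the standard library's importlib.metadata plus an optional torch import for the GPU name. The function name and the choice to also report triton (the backend Helion generates kernels for) are assumptions.

```python
import importlib.metadata
import platform


def collect_versions(packages=("helion", "torch", "triton")):
    """Hypothetical helper: return a dict of Python, package, and (if available) GPU info."""
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            info[name] = "not installed"
    try:
        import torch  # optional: only used to report the GPU model

        if torch.cuda.is_available():
            info["gpu"] = torch.cuda.get_device_name(0)
    except ImportError:
        pass
    return info


print(collect_versions())
```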