
Random number construction fails within helion loop #1041

@tomasruizt


Describe the bug
Calling torch.rand with Helion tiles as sizes fails, e.g. torch.rand((tile_m, tile_n)).
This contrasts with other supported torch APIs such as torch.ones((tile_m, tile_n)), which work with the same tile arguments.
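
For comparison, in eager PyTorch both factories accept the same size tuple and dtype argument. Below is a minimal sketch that runs outside Helion on the CPU. The plain integers block_m and block_n are hypothetical stand-ins for the tile sizes.

```python
import torch

# Hypothetical stand-ins for tile_m / tile_n; in eager mode these are plain ints.
block_m, block_n = 4, 8

ones = torch.ones((block_m, block_n), dtype=torch.float32)   # every element is 1.0
noise = torch.rand((block_m, block_n), dtype=torch.float32)  # uniform samples in [0, 1)

print(ones.shape, noise.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```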

To Reproduce
Below is a simple matrix-multiplication kernel. Inside the reduction loop over tile_k, I add 1 + noise to the accumulator before the tile is stored. Constructing the ones works, but constructing the noise fails.

import torch
import helion
import helion.language as hl
from torch import Tensor
from typing import Callable


@helion.kernel(autotune_effort="none")
def matmul(
    x: Tensor,
    y: Tensor,
    epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor] = lambda acc, tile: acc,
) -> Tensor:
    """
    Performs matrix multiplication of x and y with an optional epilogue function.
    Args:
        x (Tensor): Left matrix of shape [m, k].
        y (Tensor): Right matrix of shape [k, n].
        epilogue (Callable, optional): Function applied to the accumulator and tile indices
            after the matmul. Defaults to identity (no change).
    Returns:
        Tensor: Resulting matrix of shape [m, n].
    """
    m, k = x.size()
    k2, n = y.size()
    assert k == k2, f"size mismatch {k} != {k2}"
    out = torch.empty(
        [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
    )
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
            ones = torch.ones((tile_m, tile_n), dtype=torch.float32)  # <<< succeeds
            noise = torch.rand((tile_m, tile_n), dtype=torch.float32)  # <<< fails
            acc = acc + ones + noise
        out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n))
    return out


torch.set_default_device("cuda:0")
x = torch.randn(10, 10)
y = torch.randn(10, 10)
x = matmul(x, y)
print(x.shape)
print("finished")
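
The following eager CPU reference mirrors the tiled loop with explicit block sizes. It shows what the kernel is expected to compute once torch.rand is supported. This is only a sketch and does not reproduce Helion's lowering. The function matmul_with_noise_reference and the block sizes block_m, block_n, and block_k are hypothetical, since Helion chooses its own block sizes. A seeded torch.Generator stands in for whatever RNG state Helion would use.

```python
import torch


def matmul_with_noise_reference(x, y, block_m=4, block_n=4, block_k=4, generator=None):
    """Hypothetical eager reference for the repro kernel (CPU, explicit block sizes)."""
    m, k = x.size()
    k2, n = y.size()
    assert k == k2, f"size mismatch {k} != {k2}"
    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype))
    for m0 in range(0, m, block_m):
        for n0 in range(0, n, block_n):
            # Edge tiles may be smaller than the block size.
            rows = slice(m0, min(m0 + block_m, m))
            cols = slice(n0, min(n0 + block_n, n))
            tile_shape = (rows.stop - rows.start, cols.stop - cols.start)
            acc = torch.zeros(tile_shape, dtype=torch.float32)
            for k0 in range(0, k, block_k):
                ks = slice(k0, min(k0 + block_k, k))
                acc = torch.addmm(acc, x[rows, ks], y[ks, cols])
                # Same size tuple for both factories, as in the Helion kernel.
                ones = torch.ones(tile_shape, dtype=torch.float32)
                noise = torch.rand(tile_shape, dtype=torch.float32, generator=generator)
                acc = acc + ones + noise
            out[rows, cols] = acc
    return out


gen = torch.Generator().manual_seed(0)
x = torch.randn(10, 10, generator=gen)
y = torch.randn(10, 10, generator=gen)
out = matmul_with_noise_reference(x, y, generator=gen)

num_k_tiles = -(-10 // 4)  # ceil(10 / 4) = 3 tiles along k
extra = out - x @ y        # sum of (1 + noise) over the k tiles, per element
print(out.shape, extra.min().item(), extra.max().item())
```

Because 1 + noise is added once per k tile, each element of extra should fall in [num_k_tiles, 2 * num_k_tiles), up to float32 rounding.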

Expected behavior
torch.rand() should accept tile sizes inside a Helion loop, just as torch.ones() does.

Versions
helion=0.2.1
torch=2.9.0

Running on an RTX 3090 GPU.

Metadata

Labels

random ops: Related to random number generation in Helion (rng generator)
