Support calling a Helion device loop defined in a helper function #865

@yf225

Description

Raised by @v0i0: this would be helpful for consolidating common logic across attention variants, e.g. in #764.

Example:

from __future__ import annotations

import torch

import helion
import helion.language as hl
from helion.language import Tile

# TODO: we could add a decorator here to explicitly mark this as a Helion
# device-loop helper, e.g. `@helion.device_loop()`
def inner_device_loop(tile: Tile, x_chunk: torch.Tensor, y_chunk: torch.Tensor) -> torch.Tensor:
    """Device helper that performs its own hl.tile iteration."""
    tmp = torch.empty_like(x_chunk)

    # Second-level device loop: iterate over the elements owned by ``tile``
    for local_tile in hl.tile(tile.block_size, block_size=32):
        tmp[local_tile] = x_chunk[local_tile] + y_chunk[local_tile]

    return tmp


@helion.kernel()
def nested_device_loops(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Outer kernel that delegates a chunk of work to ``inner_device_loop``."""
    assert x.shape == y.shape
    out = torch.empty_like(x)

    # First-level device loop tiles the full iteration space.
    for tile in hl.tile(x.numel(), block_size=128):
        x_chunk = x[tile]
        y_chunk = y[tile]

        # Call into a helper that contains another device loop.
        out[tile] = inner_device_loop(tile, x_chunk, y_chunk)

    return out


def main() -> None:
    if not torch.cuda.is_available():
        raise RuntimeError("This example expects a CUDA-capable device.")

    size = 1 << 12
    x = torch.randn(size, device="cuda", dtype=torch.float32)
    y = torch.randn(size, device="cuda", dtype=torch.float32)

    out = nested_device_loops(x, y)
    torch.testing.assert_close(out, x + y)


if __name__ == "__main__":
    main()
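
For comparison, here is the same computation with the second-level loop written inline in the kernel body, which is the shape this logic has to take today. This is a sketch, assuming nested hl.tile loops inside a single kernel body are supported; the kernel name nested_device_loops_inlined is hypothetical. The request above is essentially about letting that inner loop be factored out into inner_device_loop instead.

@helion.kernel()
def nested_device_loops_inlined(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Same computation as ``nested_device_loops``, with the helper inlined."""
    assert x.shape == y.shape
    out = torch.empty_like(x)

    for tile in hl.tile(x.numel(), block_size=128):
        x_chunk = x[tile]
        y_chunk = y[tile]
        tmp = torch.empty_like(x_chunk)

        # Body of ``inner_device_loop``, written inline (hypothetical sketch).
        for local_tile in hl.tile(tile.block_size, block_size=32):
            tmp[local_tile] = x_chunk[local_tile] + y_chunk[local_tile]

        out[tile] = tmp

    return out

If the decorator route from the TODO is taken, inner_device_loop would only need `@helion.device_loop()` added, and call sites like `out[tile] = inner_device_loop(tile, x_chunk, y_chunk)` would stay unchanged.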
