Skip to content

Support tuple indexing with hl.static_range(size) iterator within device loop #1118

@yf225

Description

@yf225

For example:

for i in hl.static_range(WORLD_SIZE):
    buffer_rank = buf_tuple[i]
    ...

This is generally useful for Helion distributed examples, which often take a tuple of tensors whose length equals world_size. Example: https://github.com/meta-pytorch/kraken/blob/693f252a3ec39309703e65ae47d0de144adfaeac/kraken/fused/gemm_one_shot_all_reduce_fused.py#L87

Repro:

"""
Compare tuple indexing in Helion and Triton.

This script demonstrates that indexing a tuple of tensors with the loop
variable of ``hl.static_range`` fails in Helion, while the equivalent
``tl.static_range`` loop works in Triton.
"""

import torch
import triton
import triton.language as tl
import helion
import helion.language as hl


# ============================================================================
# HELION KERNELS (demonstrating the limitation)
# ============================================================================

@helion.kernel(autotune_effort="none")
def helion_static_tuple_indexing_unrolled_works(
    buf_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    """Helion: This works - static indexing with constants.

    Baseline case for the issue: sums four tensors elementwise while indexing
    ``buf_tuple`` only with literal constants (0, 1, 2, 3).

    Args:
        buf_tuple: Four tensors. The iteration space comes from
            ``buf_tuple[0].shape``, which is unpacked as ``M, N`` and so must
            be 2-D. The other three are presumably the same shape; this is
            not checked here.

    Returns:
        A tensor created with ``torch.zeros_like(buf_tuple[0])`` holding
        ``buf_tuple[0] + buf_tuple[1] + buf_tuple[2] + buf_tuple[3]``.
    """
    # NOTE(review): autotune_effort="none" presumably skips autotuning so the
    # repro compiles quickly. Confirm against Helion's settings docs.
    M, N = buf_tuple[0].shape
    result = torch.zeros_like(buf_tuple[0])

    # Device loop: iterate over tiles of the (M, N) index space.
    for tile_m, tile_n in hl.tile([M, N]):
        # Static indexing works fine
        val0 = buf_tuple[0][tile_m, tile_n]  # OK - constant index
        val1 = buf_tuple[1][tile_m, tile_n]  # OK - constant index
        val2 = buf_tuple[2][tile_m, tile_n]  # OK - constant index
        val3 = buf_tuple[3][tile_m, tile_n]  # OK - constant index

        result[tile_m, tile_n] = val0 + val1 + val2 + val3

    return result


@helion.kernel(autotune_effort="none")
def helion_static_tuple_indexing_rolled_fails(
    buf_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    WORLD_SIZE: hl.constexpr,
) -> torch.Tensor:
    """Helion: This fails - static indexing with hl.static_range.

    Reproducer for the issue. It computes the same elementwise sum as the
    unrolled version, but indexes ``buf_tuple`` with the loop variable of
    ``hl.static_range(WORLD_SIZE)``. The issue reports that Helion rejects
    this during type inference.

    Args:
        buf_tuple: Tensors to sum. The iteration space comes from
            ``buf_tuple[0].shape`` (unpacked as ``M, N``).
        WORLD_SIZE: Compile-time constant giving the number of tuple entries
            to sum. It is expected to equal ``len(buf_tuple)``; this is not
            checked here.

    Returns:
        A tensor created with ``torch.zeros_like(buf_tuple[0])`` holding the
        elementwise sum of ``buf_tuple[0..WORLD_SIZE-1]`` (once supported).
    """
    M, N = buf_tuple[0].shape
    result = torch.zeros_like(buf_tuple[0])

    # Device loop: iterate over tiles of the (M, N) index space.
    for tile_m, tile_n in hl.tile([M, N]):
        # Per-tile accumulator, explicitly float32.
        # NOTE(review): it is stored into `result`, which has buf_tuple[0]'s
        # dtype, so a cast may occur for non-float32 inputs. Confirm intended.
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)

        # This will fail with TypeInferenceError because 'i' is a SymIntType
        # even though we're using static_range (as reported in the issue).
        for i in hl.static_range(WORLD_SIZE):
            acc += buf_tuple[i][tile_m, tile_n]  # ERROR: Cannot index tuple with SymIntType

        result[tile_m, tile_n] = acc

    return result


# ============================================================================
# TRITON KERNEL (showing the solution)
# ============================================================================

@triton.jit
def triton_static_tuple_indexing_kernel(
    buf_tuple,  # Tuple of buffer pointers
    output_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    world_size: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Triton: This works - tl.static_range allows tuple indexing.

    Each program sums one ``BLOCK_SIZE_M x BLOCK_SIZE_N`` tile across all
    ``world_size`` buffers and writes the tile to ``output_ptr``.

    Args:
        buf_tuple: Tuple of pointers to ``M x N`` buffers. Element ``(m, n)``
            is read at ``base + m * N + n``, so each buffer must be
            contiguous and row-major.
        output_ptr: Pointer to the ``M x N`` output, addressed the same way.
        M, N: Logical output shape, used for address computation and masking.
        world_size: Number of buffers in ``buf_tuple`` to sum. It must be a
            constexpr so ``tl.static_range`` can unroll over it.
        BLOCK_SIZE_M, BLOCK_SIZE_N: Tile size handled by one program.
    """

    # Get program IDs: axis 0 selects the row tile, axis 1 the column tile.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Calculate offsets: row/column indices covered by this tile.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # Initialize accumulator (float32 regardless of the input dtype)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # This works! tl.static_range properly unrolls before type checking
    # allowing buf_tuple[i] to work
    for i in tl.static_range(world_size):
        buffer_ptr = buf_tuple[i]  # This works in Triton!

        # Load from this buffer; out-of-bounds lanes of edge tiles read 0.0.
        # NOTE(review): `mask` is identical on every iteration and could be
        # hoisted above the loop. The compiler may already do this; confirm.
        ptrs = buffer_ptr + offs_m[:, None] * N + offs_n[None, :]
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        vals = tl.load(ptrs, mask=mask, other=0.0)

        # Accumulate
        acc += vals

    # Store result, masking off lanes outside the M x N output.
    output_ptrs = output_ptr + offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(output_ptrs, acc, mask=mask)


def triton_tuple_reduction(buf_tuple: tuple[torch.Tensor, ...]) -> torch.Tensor:
    """Sum a tuple of 2-D tensors elementwise with the Triton tuple-indexing kernel.

    Args:
        buf_tuple: Non-empty tuple of 2-D tensors that share one shape and
            device. Its length is passed to the kernel as the compile-time
            ``world_size`` that ``tl.static_range`` unrolls over.

    Returns:
        A new float32 tensor of shape ``(M, N)`` on the inputs' device. Each
        element is the sum of the corresponding input elements.

    Raises:
        ValueError: If ``buf_tuple`` is empty, holds a tensor that is not 2-D,
            or mixes shapes or devices.
    """
    if not buf_tuple:
        raise ValueError("buf_tuple must contain at least one tensor")
    first = buf_tuple[0]
    if first.dim() != 2:
        raise ValueError(f"expected 2-D tensors, got shape {tuple(first.shape)}")
    for idx, tensor in enumerate(buf_tuple):
        if tensor.shape != first.shape or tensor.device != first.device:
            raise ValueError(
                f"buf_tuple[{idx}] has shape {tuple(tensor.shape)} on {tensor.device}; "
                f"expected shape {tuple(first.shape)} on {first.device}"
            )

    # The kernel addresses element (m, n) as base + m * N + n, which assumes
    # a contiguous row-major layout. .contiguous() returns the tensor itself
    # when it already is contiguous, so only inputs that need it are copied.
    buf_tuple = tuple(t.contiguous() for t in buf_tuple)

    M, N = first.shape
    world_size = len(buf_tuple)

    output = torch.zeros((M, N), device=first.device, dtype=torch.float32)

    BLOCK_SIZE_M = 64
    BLOCK_SIZE_N = 64

    # One program per BLOCK_SIZE_M x BLOCK_SIZE_N output tile.
    grid = (
        triton.cdiv(M, BLOCK_SIZE_M),
        triton.cdiv(N, BLOCK_SIZE_N),
    )

    triton_static_tuple_indexing_kernel[grid](
        buf_tuple,
        output,
        M=M,
        N=N,
        world_size=world_size,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
    )

    return output


# ============================================================================
# COMPARISON TEST
# ============================================================================

def _run_case(title, success_msg, run, expected, failure_notes=()):
    """Run one comparison case, printing a banner and its outcome.

    Any exception from ``run`` is reported (its type plus a truncated message)
    rather than propagated, so the remaining cases still execute.

    Args:
        title: Banner line describing the case.
        success_msg: Text printed after "✓ SUCCESS: " when the case runs.
        run: Zero-argument callable that launches the kernel and returns the
            2-D result tensor.
        expected: Value every element of the result should equal.
        failure_notes: Extra lines printed after a failure to explain it.
    """
    print("\n" + "-" * 80)
    print(title)
    print("-" * 80)
    try:
        result = run()
        # CUDA launches are asynchronous. .item() synchronizes with the device,
        # so kernel errors surface here, inside the try.
        value = result[0, 0].item()
        all_match = bool((result == expected).all().item())
    except Exception as e:  # top-level repro boundary: report and continue
        message = str(e)
        suffix = "..." if len(message) > 200 else ""
        print(f"✗ FAILED: {type(e).__name__}")
        print(f"  {message[:200]}{suffix}")
        for note in failure_notes:
            print(note)
        return
    print(f"✓ SUCCESS: {success_msg}")
    print(f"  Result: {value}")
    if not all_match:
        print(f"  ✗ MISMATCH: not every element equals the expected sum {expected}")


def main():
    """Run the Helion (unrolled and static_range) and Triton tuple-indexing cases.

    Builds ``world_size`` CUDA tensors, where tensor ``i`` is filled with
    ``i + 1``. Every output element should therefore equal
    ``1 + 2 + ... + world_size``.

    Raises:
        SystemExit: If no CUDA device is available.
    """
    # Create test tensors
    world_size = 4
    M, N = 128, 128
    expected = sum(range(1, world_size + 1))

    if not torch.cuda.is_available():
        raise SystemExit("CUDA is not available; this comparison needs a GPU.")

    tensors = tuple(
        torch.ones((M, N), device="cuda", dtype=torch.float32) * (i + 1)
        for i in range(world_size)
    )

    print("\nTest setup:")
    print(f"  - {world_size} tensors of size {M}x{N}")
    print("  - tensor[0] filled with 1.0, tensor[1] with 2.0, etc.")
    print(f"  - Expected sum: {expected}")

    # Test 1: Helion with static indexing (unrolled)
    _run_case(
        "TEST 1: Helion with static indexing (buf_tuple[0], buf_tuple[1], ...)",
        "Helion with static indexing (unrolled) works!",
        lambda: helion_static_tuple_indexing_unrolled_works(tensors),
        expected,
    )

    # Test 2: Helion with hl.static_range (rolled); expected to fail today
    _run_case(
        "TEST 2: Helion with hl.static_range (buf_tuple[i] where i from static_range)",
        "Helion with hl.static_range indexing works!",
        lambda: helion_static_tuple_indexing_rolled_fails(
            tensors, WORLD_SIZE=world_size
        ),
        expected,
        failure_notes=(
            "\n  Key issue: 'i' is treated as SymIntType, not a constant",
            "  This happens because type checking occurs BEFORE loop unrolling",
        ),
    )

    # Test 3: Triton with tl.static_range
    _run_case(
        "TEST 3: Triton with tl.static_range (buf_tuple[i] where i from static_range)",
        "Triton with tl.static_range indexing works!",
        lambda: triton_tuple_reduction(tensors),
        expected,
    )


if __name__ == "__main__":
    main()

Branch: tuple_indexing_by_static_range

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions