-
Notifications
You must be signed in to change notification settings - Fork 67
Closed
Labels
Description
For example:
    for i in hl.static_range(WORLD_SIZE):
        buffer_rank = buf_tuple[i]
        ...
This is generally useful for Helion + distributed examples where we have a tuple of length world_size. Example: https://github.com/meta-pytorch/kraken/blob/693f252a3ec39309703e65ae47d0de144adfaeac/kraken/fused/gemm_one_shot_all_reduce_fused.py#L87
Repro:
"""
Comparison of tuple indexing in Helion vs Triton
This script demonstrates that indexing a tuple with an hl.static_range loop
variable fails in Helion, while the same pattern with tl.static_range works in Triton.
"""
import torch
import triton
import triton.language as tl
import helion
import helion.language as hl
# ============================================================================
# HELION KERNELS (demonstrating the limitation)
# ============================================================================
@helion.kernel(autotune_effort="none")
def helion_static_tuple_indexing_unrolled_works(
    buf_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    """Helion: This works - static indexing with constants

    Element-wise sum of the four tensors in ``buf_tuple``, indexing the tuple
    with literal constants 0..3. This is the manually unrolled form of the
    loop used in ``helion_static_tuple_indexing_rolled_fails``.

    Args:
        buf_tuple: Exactly four tensors (per the annotation). The shape of
            ``buf_tuple[0]`` is unpacked as ``M, N``, so it must be 2-D;
            presumably all four share that shape (not checked here).

    Returns:
        A tensor created with ``torch.zeros_like(buf_tuple[0])`` (same shape,
        dtype and device as the first input) holding the element-wise sum.
    """
    M, N = buf_tuple[0].shape
    result = torch.zeros_like(buf_tuple[0])
    for tile_m, tile_n in hl.tile([M, N]):
        # Static indexing works fine
        val0 = buf_tuple[0][tile_m, tile_n]  # OK - constant index
        val1 = buf_tuple[1][tile_m, tile_n]  # OK - constant index
        val2 = buf_tuple[2][tile_m, tile_n]  # OK - constant index
        val3 = buf_tuple[3][tile_m, tile_n]  # OK - constant index
        result[tile_m, tile_n] = val0 + val1 + val2 + val3
    return result
@helion.kernel(autotune_effort="none")
def helion_static_tuple_indexing_rolled_fails(
    buf_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    WORLD_SIZE: hl.constexpr,
) -> torch.Tensor:
    """Helion: This fails - static indexing with hl.static_range

    Intended to compute the same element-wise sum as
    ``helion_static_tuple_indexing_unrolled_works``, but by looping
    ``for i in hl.static_range(WORLD_SIZE)`` and indexing ``buf_tuple[i]``.
    This rolled pattern is what the repro asks Helion to support.

    Args:
        buf_tuple: Four tensors (per the annotation). The shape of
            ``buf_tuple[0]`` is unpacked as ``M, N``, so it must be 2-D.
        WORLD_SIZE: Compile-time trip count of the static loop; presumably
            should equal ``len(buf_tuple)`` (not checked here).

    Returns:
        A tensor created with ``torch.zeros_like(buf_tuple[0])``. The tile
        accumulator is float32 (``hl.zeros(..., dtype=torch.float32)``) and
        is stored into ``result``, which has ``buf_tuple[0]``'s dtype.

    Raises:
        NOTE(review): per this repro's report, Helion raises a
        TypeInferenceError because ``i`` is typed as a SymIntType rather than
        a constant.
    """
    M, N = buf_tuple[0].shape
    result = torch.zeros_like(buf_tuple[0])
    for tile_m, tile_n in hl.tile([M, N]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        # This will fail with TypeInferenceError because 'i' is a SymIntType
        # even though we're using static_range (reported behavior: type
        # inference runs before the loop is unrolled).
        for i in hl.static_range(WORLD_SIZE):
            acc += buf_tuple[i][tile_m, tile_n]  # ERROR: Cannot index tuple with SymIntType
        result[tile_m, tile_n] = acc
    return result
# ============================================================================
# TRITON KERNEL (reference: the same pattern works with tl.static_range)
# ============================================================================
@triton.jit
def triton_static_tuple_indexing_kernel(
    buf_tuple,  # Tuple of buffer pointers
    output_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    world_size: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Triton: This works - tl.static_range allows tuple indexing

    Sums ``world_size`` (M, N) buffers element-wise into ``output_ptr``.
    Each program instance handles one BLOCK_SIZE_M x BLOCK_SIZE_N tile
    selected by program ids (0, 1). Elements are addressed as
    ``row * N + col``, so every buffer and the output must be row-major
    contiguous with row stride N. Accumulation is in float32.
    """
    # Program ids select the tile.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Tile offsets and the bounds mask are identical for every input buffer
    # and for the output store, so compute them once instead of per iteration.
    tile_offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # tl.static_range unrolls at compile time, so `i` is a constexpr and
    # buf_tuple[i] is a legal tuple index.
    for i in tl.static_range(world_size):
        buffer_ptr = buf_tuple[i]
        vals = tl.load(buffer_ptr + tile_offs, mask=mask, other=0.0)
        acc += vals
    tl.store(output_ptr + tile_offs, acc, mask=mask)
def triton_tuple_reduction(buf_tuple: tuple[torch.Tensor, ...]) -> torch.Tensor:
    """Wrapper to run the Triton kernel.

    Computes the element-wise sum of every tensor in ``buf_tuple`` with
    ``triton_static_tuple_indexing_kernel``.

    Args:
        buf_tuple: Non-empty tuple of 2-D tensors that all share one shape
            and one device. Non-contiguous tensors are made contiguous first,
            because the kernel addresses elements as ``row * N + col``.

    Returns:
        A new float32 tensor of shape (M, N) on the inputs' device. The
        kernel accumulates in float32, so the output is float32 regardless
        of the input dtype.

    Raises:
        ValueError: If ``buf_tuple`` is empty, its first tensor is not 2-D,
            or any tensor differs from the first in shape or device.
    """
    if not buf_tuple:
        raise ValueError("buf_tuple must contain at least one tensor")
    first = buf_tuple[0]
    if first.dim() != 2:
        raise ValueError(f"expected 2-D tensors, got shape {tuple(first.shape)}")
    for idx, t in enumerate(buf_tuple):
        if t.shape != first.shape or t.device != first.device:
            raise ValueError(
                f"buf_tuple[{idx}] has shape {tuple(t.shape)} on {t.device}; "
                f"expected shape {tuple(first.shape)} on {first.device}"
            )
    # The kernel assumes row stride N; .contiguous() is a no-op for tensors
    # that already satisfy that, and fixes silently-wrong sums otherwise.
    buf_tuple = tuple(t.contiguous() for t in buf_tuple)
    M, N = first.shape
    world_size = len(buf_tuple)
    output = torch.zeros((M, N), device=first.device, dtype=torch.float32)
    BLOCK_SIZE_M = 64
    BLOCK_SIZE_N = 64
    grid = (
        triton.cdiv(M, BLOCK_SIZE_M),
        triton.cdiv(N, BLOCK_SIZE_N),
    )
    triton_static_tuple_indexing_kernel[grid](
        buf_tuple,
        output,
        M=M,
        N=N,
        world_size=world_size,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
    )
    return output
# ============================================================================
# COMPARISON TEST
# ============================================================================
def _run_case(title, success_msg, fn, expected, failure_hints=()):
    """Print a banner for one comparison case, run ``fn()`` and report it.

    Args:
        title: Heading printed between two separator lines.
        success_msg: Line printed when ``fn()`` returns without raising.
        fn: Zero-argument callable expected to return a 2-D tensor.
        expected: Value every element of the result should equal.
        failure_hints: Extra lines printed after a failure report.

    Any exception is caught and summarized (message truncated to 200
    characters) so one failing case does not stop the remaining ones.
    """
    print("\n" + "-" * 80)
    print(title)
    print("-" * 80)
    try:
        result = fn()
        print(success_msg)
        print(f" Result: {result[0, 0].item()}")
        # Only [0, 0] is printed; check the whole tensor so a partially
        # wrong result is not reported as a clean success.
        if not bool((result == expected).all()):
            print(f" WARNING: result does not match expected sum {expected}")
    except Exception as e:  # deliberate: report and continue with next case
        print(f"✗ FAILED: {type(e).__name__}")
        print(f" {str(e)[:200]}...")
        for hint in failure_hints:
            print(hint)


def main() -> None:
    """Run the Helion-vs-Triton tuple-indexing comparison on the GPU.

    Builds ``world_size`` (M, N) float32 CUDA tensors where tensor ``i`` is
    filled with ``i + 1``, then runs the unrolled Helion kernel, the
    ``hl.static_range`` Helion kernel and the Triton kernel, printing each
    outcome and checking it against the expected element-wise sum.
    Returns early with a message when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available; this comparison requires a GPU.")
        return
    # Create test tensors
    world_size = 4
    M, N = 128, 128
    tensors = tuple(
        torch.ones((M, N), device='cuda', dtype=torch.float32) * (i + 1)
        for i in range(world_size)
    )
    expected = sum(i + 1 for i in range(world_size))
    print("\nTest setup:")
    print(f" - {world_size} tensors of size {M}x{N}")
    print(" - tensor[0] filled with 1.0, tensor[1] with 2.0, etc.")
    print(f" - Expected sum: {expected}")

    # Test 1: Helion with static indexing (unrolled)
    _run_case(
        "TEST 1: Helion with static indexing (buf_tuple[0], buf_tuple[1], ...)",
        "✓ SUCCESS: Helion with static indexing (unrolled) works!",
        lambda: helion_static_tuple_indexing_unrolled_works(tensors),
        expected,
    )
    # Test 2: Helion with hl.static_range (rolled) - the reported failure
    _run_case(
        "TEST 2: Helion with hl.static_range (buf_tuple[i] where i from static_range)",
        "✓ SUCCESS: Helion with hl.static_range indexing works!",
        lambda: helion_static_tuple_indexing_rolled_fails(tensors, WORLD_SIZE=world_size),
        expected,
        failure_hints=(
            "\n Key issue: 'i' is treated as SymIntType, not a constant",
            " This happens because type checking occurs BEFORE loop unrolling",
        ),
    )
    # Test 3: Triton with tl.static_range
    _run_case(
        "TEST 3: Triton with tl.static_range (buf_tuple[i] where i from static_range)",
        "✓ SUCCESS: Triton with tl.static_range indexing works!",
        lambda: triton_tuple_reduction(tensors),
        expected,
    )
# Script entry point: run the comparison only when executed directly.
if __name__ == "__main__":
    main()
# Branch: tuple_indexing_by_static_range