106 changes: 66 additions & 40 deletions helion/_compiler/device_ir.py
@@ -5,7 +5,6 @@
 import contextlib
 import dataclasses
 import functools
-import itertools
 import operator
 import re
 import textwrap
@@ -1216,55 +1215,82 @@ def add_tile_with_offset_metadata(graph_info: GraphInfo) -> None:
     """
     graph = graph_info.graph
     env = CompileEnvironment.current()

-    for node in itertools.chain(
-        graph.find_nodes(op="call_function", target=operator.add),
-        graph.find_nodes(op="call_function", target=torch.ops.aten.add.Tensor),
-    ):
-        # Check if this is tile.index + offset pattern
-        # args[0] should be tile_index result, args[1] should be int/SymInt
-        if len(node.args) != 2 and not node.kwargs:
-            continue
-        left_arg, right_arg = node.args
-
-        # Check if left argument is a tile_index call
+    add_targets = (operator.add, torch.ops.aten.add.Tensor)
+    offset_types = (int, torch.SymInt)
+    for node in graph.nodes:
         if (
-            not isinstance(left_arg, torch.fx.Node)
-            or left_arg.op != "call_function"
-            or left_arg.target != hl.tile_index
+            node.op != "call_function"
+            or node.target not in add_targets
+            or node.kwargs
+            or len(node.args) != 2
         ):
             continue

-        # Check if right argument is an integer offset
-        # It could be a constant, SymInt node, or another value
+        # We accept int, SymInt, or nodes that represent them
-        offset = None
-        if isinstance(right_arg, (int, torch.SymInt)):
-            offset = right_arg
-        elif isinstance(right_arg, torch.fx.Node):
-            # Check the node's metadata for the value
-            val = right_arg.meta.get("val")
-            if isinstance(val, (int, torch.SymInt)):
-                offset = val
-
-        if offset is None:
-            continue
+        block_id: int | None = None
+        total_offset: int | torch.SymInt = 0
+        valid = True

-        # Extract the block_id from the tile_index call
-        tile_arg = left_arg.args[0]
-        block_id = None
-        if isinstance(tile_arg, torch.fx.Node) and isinstance(
-            tile_arg.meta["val"], torch.SymInt
-        ):
-            block_id = env.get_block_id(tile_arg.meta["val"])
+        for arg in node.args:
+            tile_offset_value: int | torch.SymInt | None = None
+            arg_block_id: int | None = None
+
+            if isinstance(arg, torch.fx.Node):
+                meta_tile = arg.meta.get("tile_with_offset")
+                if meta_tile is not None:
+                    arg_block_id = meta_tile.get("block_id")
+                    if arg_block_id is None:
+                        valid = False
+                        break
+                    tile_offset_value = meta_tile.get("offset", 0)
+                elif (
+                    arg.op == "call_function"
+                    and arg.target == hl.tile_index
+                    and arg.args
+                    and isinstance(arg.args[0], torch.fx.Node)
+                ):
+                    tile_val = arg.args[0].meta.get("val")
+                    if isinstance(tile_val, torch.SymInt):
+                        arg_block_id = env.get_block_id(tile_val)
+                    if arg_block_id is None:
+                        valid = False
+                        break
+                    tile_offset_value = 0
+                else:
+                    val = arg.meta.get("val")
+                    if isinstance(val, offset_types):
+                        total_offset = total_offset + val
+                        continue
+
+                if arg_block_id is not None:
+                    if block_id is not None:
+                        valid = False
+                        break
+                    if tile_offset_value is None:
+                        tile_offset_value = 0
+                    block_id = arg_block_id
+                    total_offset = total_offset + tile_offset_value
+                    continue
+
+                val = arg.meta.get("val")
+                if isinstance(val, offset_types):
+                    total_offset = total_offset + val
+                    continue
+
+                valid = False
+                break
+
+            if isinstance(arg, offset_types):
+                total_offset = total_offset + arg
+                continue
+            valid = False
+            break

-        if block_id is None:
+        if not valid or block_id is None:
             continue

         # Add metadata to mark this as a tile+offset node
         node.meta["tile_with_offset"] = {
             "block_id": block_id,
-            "offset": offset,
+            "offset": total_offset,
         }


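The rewrite generalizes the old `tile_index + constant` match in two ways: either operand of the add may now be the tile index (the commuted form, e.g. `1 + tile.index`), and an operand may itself already carry `tile_with_offset` metadata, so chained adds such as `tile.index + 1 + 2` fold into a single accumulated offset. A minimal standalone sketch of that folding idea, using a toy `Node` class in place of real `torch.fx` nodes (all names here are illustrative, not Helion API):

from dataclasses import dataclass, field

@dataclass
class Node:
    """Toy stand-in for a torch.fx.Node that carries metadata."""
    meta: dict = field(default_factory=dict)

def fold_add(args):
    """Fold a two-operand add: at most one tile operand, summing int offsets."""
    block_id = None
    total_offset = 0
    for arg in args:
        if isinstance(arg, Node) and "tile_with_offset" in arg.meta:
            if block_id is not None:  # two tile operands: not foldable
                return None
            info = arg.meta["tile_with_offset"]
            block_id = info["block_id"]
            total_offset += info.get("offset", 0)
        elif isinstance(arg, int):
            total_offset += arg
        else:
            return None  # any other operand invalidates the pattern
    if block_id is None:
        return None
    return {"block_id": block_id, "offset": total_offset}

# tile.index + 1 folds to offset 1; 2 + (tile.index + 1) then folds to
# offset 3, because the inner add's metadata feeds the outer add.
tile = Node(meta={"tile_with_offset": {"block_id": 0, "offset": 0}})
inner = Node(meta={"tile_with_offset": fold_add([tile, 1])})
assert fold_add([2, inner]) == {"block_id": 0, "offset": 3}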
30 changes: 30 additions & 0 deletions test/test_indexing.expected
@@ -347,6 +347,36 @@ def pairwise_add(x: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_pairwise_add, (triton.cdiv(499, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
     return out

+--- assertExpectedJournal(TestIndexing.test_pairwise_add_commuted_and_multi_offset)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_pairwise_add_variants(x, out, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    v_0 = tl.full([], 1, tl.int32)
+    v_1 = indices_0 + v_0
+    left = tl.load(tl.make_block_ptr(x, [256], [1], [offset_0 + 1], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero')
+    v_2 = tl.full([], 1, tl.int32)
+    v_3 = indices_0 + v_2
+    v_4 = tl.full([], 2, tl.int32)
+    v_5 = v_3 + v_4
+    right = tl.load(tl.make_block_ptr(x, [256], [1], [offset_0 + 3], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero')
+    v_6 = left + right
+    tl.store(tl.make_block_ptr(out, [253], [1], [offset_0], [_BLOCK_SIZE_0], [0]), v_6, boundary_check=[0])
+
+def pairwise_add_variants(x: torch.Tensor, *, _launcher=_default_launcher):
+    out = x.new_empty([x.size(0) - 3])
+    _BLOCK_SIZE_0 = 32
+    _launcher(_helion_pairwise_add_variants, (triton.cdiv(253, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
+    return out
+
 --- assertExpectedJournal(TestIndexing.test_reduction_tensor_descriptor_indexing_block_size)
 from __future__ import annotations

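Worth noting in the journal above: the two block-pointer loads bake the folded offsets into their bases, `offset_0 + 1` for `1 + tile.index` and `offset_0 + 3` for `tile.index + 1 + 2`, while the scalar `v_*` adds appear to be left behind as dead values. A small check that those base offsets reproduce the expected slices (plain PyTorch, sizes copied from the test):

import torch

x = torch.arange(256, dtype=torch.float32)
out_len = x.numel() - 3      # matches out = x.new_empty([x.size(0) - 3])
left = x[1 : 1 + out_len]    # load at block base offset 1
right = x[3 : 3 + out_len]   # load at block base offset 3 (= 1 + 2 folded)
torch.testing.assert_close(left + right, x[1:-2] + x[3:])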
34 changes: 34 additions & 0 deletions test/test_indexing.py
@@ -204,6 +204,27 @@ def pairwise_add(x: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result, x[:-1] + x[1:])
         self.assertExpectedJournal(code)

+    def test_pairwise_add_commuted_and_multi_offset(self):
+        @helion.kernel()
+        def pairwise_add_variants(x: torch.Tensor) -> torch.Tensor:
+            out = x.new_empty([x.size(0) - 3])
+            for tile in hl.tile(out.size(0)):
+                left = x[1 + tile.index]
+                right = x[tile.index + 1 + 2]
+                out[tile] = left + right
+            return out
+
+        x = torch.randn([256], device=DEVICE)
+        code, result = code_and_output(
+            pairwise_add_variants,
+            (x,),
+            block_size=32,
+            indexing="block_ptr",
+        )
+        expected = x[1:-2] + x[3:]
+        torch.testing.assert_close(result, expected)
+        self.assertExpectedJournal(code)
+
     def test_mask_store(self):
         @helion.kernel
         def masked_store(x: torch.Tensor) -> torch.Tensor:
@@ -434,6 +455,19 @@ def run_case(
         small_shape = (128, 128)
         large_shape = (51200, 51200)

+        if DEVICE.type == "cuda":
+            free_bytes, _ = torch.cuda.mem_get_info()
+            element_size = 2  # torch.bfloat16 element size in bytes
+            # Worst case: inputs, kernel output, reference output, and temporary buffers.
+            # Give ourselves margin by budgeting for 5 tensors of this shape.
+            required_bytes = 5 * math.prod(large_shape) * element_size
+            if free_bytes < required_bytes:
+                required_gib = required_bytes / (1024**3)
+                available_gib = free_bytes / (1024**3)
+                self.skipTest(
+                    f"Large BF16 add needs ~{required_gib:.1f} GiB free, only {available_gib:.1f} GiB available"
+                )
+
         run_case(
             small_shape,
             index_dtype=torch.int32,
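For scale, the budget above works out to roughly 24 GiB for the large case. A quick check of the same arithmetic, with the constants copied from the test:

import math

large_shape = (51200, 51200)
element_size = 2  # bytes per bfloat16 element
required_bytes = 5 * math.prod(large_shape) * element_size
print(required_bytes / 1024**3)  # ~24.4 GiB must be free or the test skips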
14 changes: 6 additions & 8 deletions test/test_persistent_kernels.expected
@@ -1598,32 +1598,30 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher

 @triton.jit
-def _helion_test_kernel(x, result, x_size_0, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
-    total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
+def _helion_test_kernel(x, result, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    total_pids = tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(96, _BLOCK_SIZE_1)
     block_size = tl.cdiv(total_pids, _NUM_SM)
     start_pid = tl.program_id(0) * block_size
     end_pid = tl.minimum(start_pid + block_size, total_pids)
     for virtual_pid in tl.range(start_pid, end_pid, warp_specialize=True):
-        num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+        num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
         pid_0 = virtual_pid % num_blocks_0
         pid_1 = virtual_pid // num_blocks_0
         offset_0 = pid_0 * _BLOCK_SIZE_0
         indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
-        mask_0 = indices_0 < x_size_0
         offset_1 = pid_1 * _BLOCK_SIZE_1
         indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
-        mask_1 = indices_1 < x_size_1
-        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        load = tl.load(x + (indices_0[:, None] * 96 + indices_1[None, :] * 1), None)
         v_0 = 1.0
         v_1 = load + v_0
-        tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+        tl.store(result + (indices_0[:, None] * 96 + indices_1[None, :] * 1), v_1, None)

 def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
     result = x.new_empty(x.size())
     _NUM_SM = helion.runtime.get_num_sm(x.device)
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_1 = 16
-    _launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
+    _launcher(_helion_test_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
     return result

 --- assertExpectedJournal(TestPersistentKernels.test_persistent_loop_variable_names)
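This journal change reflects shape specialization: with the sizes baked in as the 64x96 constants, the bounds masks and the size/stride launcher arguments drop out entirely. A quick sanity check of the persistent-grid arithmetic with the block sizes above:

def cdiv(a: int, b: int) -> int:
    """Ceiling division, mirroring tl.cdiv."""
    return -(-a // b)

# 64 x 96 problem tiled by 32 x 16 blocks: 2 * 6 = 12 virtual programs,
# split across the _NUM_SM persistent programs.
total_pids = cdiv(64, 32) * cdiv(96, 16)
assert total_pids == 12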