diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py
index d3611d1ab..932235661 100644
--- a/helion/_compiler/device_ir.py
+++ b/helion/_compiler/device_ir.py
@@ -5,7 +5,6 @@
 import contextlib
 import dataclasses
 import functools
-import itertools
 import operator
 import re
 import textwrap
@@ -1216,55 +1215,82 @@ def add_tile_with_offset_metadata(graph_info: GraphInfo) -> None:
     """
     graph = graph_info.graph
     env = CompileEnvironment.current()
-
-    for node in itertools.chain(
-        graph.find_nodes(op="call_function", target=operator.add),
-        graph.find_nodes(op="call_function", target=torch.ops.aten.add.Tensor),
-    ):
-        # Check if this is tile.index + offset pattern
-        # args[0] should be tile_index result, args[1] should be int/SymInt
-        if len(node.args) != 2 and not node.kwargs:
-            continue
-        left_arg, right_arg = node.args
-
-        # Check if left argument is a tile_index call
+    add_targets = (operator.add, torch.ops.aten.add.Tensor)
+    offset_types = (int, torch.SymInt)
+    for node in graph.nodes:
         if (
-            not isinstance(left_arg, torch.fx.Node)
-            or left_arg.op != "call_function"
-            or left_arg.target != hl.tile_index
+            node.op != "call_function"
+            or node.target not in add_targets
+            or node.kwargs
+            or len(node.args) != 2
         ):
             continue
 
-        # Check if right argument is an integer offset
-        # It could be a constant, SymInt node, or another value
-        # We accept int, SymInt, or nodes that represent them
-        offset = None
-        if isinstance(right_arg, (int, torch.SymInt)):
-            offset = right_arg
-        elif isinstance(right_arg, torch.fx.Node):
-            # Check the node's metadata for the value
-            val = right_arg.meta.get("val")
-            if isinstance(val, (int, torch.SymInt)):
-                offset = val
-
-        if offset is None:
-            continue
+        block_id: int | None = None
+        total_offset: int | torch.SymInt = 0
+        valid = True
 
-        # Extract the block_id from the tile_index call
-        tile_arg = left_arg.args[0]
-        block_id = None
-        if isinstance(tile_arg, torch.fx.Node) and isinstance(
-            tile_arg.meta["val"], torch.SymInt
-        ):
-            block_id = env.get_block_id(tile_arg.meta["val"])
+        for arg in node.args:
+            tile_offset_value: int | torch.SymInt | None = None
+            arg_block_id: int | None = None
+
+            if isinstance(arg, torch.fx.Node):
+                meta_tile = arg.meta.get("tile_with_offset")
+                if meta_tile is not None:
+                    arg_block_id = meta_tile.get("block_id")
+                    if arg_block_id is None:
+                        valid = False
+                        break
+                    tile_offset_value = meta_tile.get("offset", 0)
+                elif (
+                    arg.op == "call_function"
+                    and arg.target == hl.tile_index
+                    and arg.args
+                    and isinstance(arg.args[0], torch.fx.Node)
+                ):
+                    tile_val = arg.args[0].meta.get("val")
+                    if isinstance(tile_val, torch.SymInt):
+                        arg_block_id = env.get_block_id(tile_val)
+                        if arg_block_id is None:
+                            valid = False
+                            break
+                        tile_offset_value = 0
+                else:
+                    val = arg.meta.get("val")
+                    if isinstance(val, offset_types):
+                        total_offset = total_offset + val
+                        continue
+
+                if arg_block_id is not None:
+                    if block_id is not None:
+                        valid = False
+                        break
+                    if tile_offset_value is None:
+                        tile_offset_value = 0
+                    block_id = arg_block_id
+                    total_offset = total_offset + tile_offset_value
+                    continue
+
+                val = arg.meta.get("val")
+                if isinstance(val, offset_types):
+                    total_offset = total_offset + val
+                    continue
+
+                valid = False
+                break
+
+            if isinstance(arg, offset_types):
+                total_offset = total_offset + arg
+                continue
+            valid = False
+            break
 
-        if block_id is None:
+        if not valid or block_id is None:
             continue
 
-        # Add metadata to mark this as a tile+offset node
         node.meta["tile_with_offset"] = {
             "block_id": block_id,
-            "offset": offset,
+            "offset": total_offset,
         }
 
 
diff --git a/test/test_indexing.expected b/test/test_indexing.expected
index 8e3a7752d..6843d16ba 100644
--- a/test/test_indexing.expected
+++ b/test/test_indexing.expected
@@ -347,6 +347,36 @@ def pairwise_add(x: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_pairwise_add, (triton.cdiv(499, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
     return out
 
+--- assertExpectedJournal(TestIndexing.test_pairwise_add_commuted_and_multi_offset)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_pairwise_add_variants(x, out, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    v_0 = tl.full([], 1, tl.int32)
+    v_1 = indices_0 + v_0
+    left = tl.load(tl.make_block_ptr(x, [256], [1], [offset_0 + 1], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero')
+    v_2 = tl.full([], 1, tl.int32)
+    v_3 = indices_0 + v_2
+    v_4 = tl.full([], 2, tl.int32)
+    v_5 = v_3 + v_4
+    right = tl.load(tl.make_block_ptr(x, [256], [1], [offset_0 + 3], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero')
+    v_6 = left + right
+    tl.store(tl.make_block_ptr(out, [253], [1], [offset_0], [_BLOCK_SIZE_0], [0]), v_6, boundary_check=[0])
+
+def pairwise_add_variants(x: torch.Tensor, *, _launcher=_default_launcher):
+    out = x.new_empty([x.size(0) - 3])
+    _BLOCK_SIZE_0 = 32
+    _launcher(_helion_pairwise_add_variants, (triton.cdiv(253, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
+    return out
+
 --- assertExpectedJournal(TestIndexing.test_reduction_tensor_descriptor_indexing_block_size)
 from __future__ import annotations
 
diff --git a/test/test_indexing.py b/test/test_indexing.py
index b9695b4bb..d7df03078 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -204,6 +204,27 @@ def pairwise_add(x: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result, x[:-1] + x[1:])
         self.assertExpectedJournal(code)
 
+    def test_pairwise_add_commuted_and_multi_offset(self):
+        @helion.kernel()
+        def pairwise_add_variants(x: torch.Tensor) -> torch.Tensor:
+            out = x.new_empty([x.size(0) - 3])
+            for tile in hl.tile(out.size(0)):
+                left = x[1 + tile.index]
+                right = x[tile.index + 1 + 2]
+                out[tile] = left + right
+            return out
+
+        x = torch.randn([256], device=DEVICE)
+        code, result = code_and_output(
+            pairwise_add_variants,
+            (x,),
+            block_size=32,
+            indexing="block_ptr",
+        )
+        expected = x[1:-2] + x[3:]
+        torch.testing.assert_close(result, expected)
+        self.assertExpectedJournal(code)
+
     def test_mask_store(self):
         @helion.kernel
         def masked_store(x: torch.Tensor) -> torch.Tensor:
@@ -434,6 +455,19 @@ def run_case(
         small_shape = (128, 128)
         large_shape = (51200, 51200)
 
+        if DEVICE.type == "cuda":
+            free_bytes, _ = torch.cuda.mem_get_info()
+            element_size = 2  # torch.bfloat16 element size in bytes
+            # Worst case: inputs, kernel output, reference output, and temporary buffers.
+            # Give ourselves margin by budgeting for 5 tensors of this shape.
+            required_bytes = 5 * math.prod(large_shape) * element_size
+            if free_bytes < required_bytes:
+                required_gib = required_bytes / (1024**3)
+                available_gib = free_bytes / (1024**3)
+                self.skipTest(
+                    f"Large BF16 add needs ~{required_gib:.1f} GiB free, only {available_gib:.1f} GiB available"
+                )
+
         run_case(
             small_shape,
             index_dtype=torch.int32,
diff --git a/test/test_persistent_kernels.expected b/test/test_persistent_kernels.expected
index 5f2453bcf..3d53434da 100644
--- a/test/test_persistent_kernels.expected
+++ b/test/test_persistent_kernels.expected
@@ -1598,32 +1598,30 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_test_kernel(x, result, x_size_0, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
-    total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
+def _helion_test_kernel(x, result, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    total_pids = tl.cdiv(64, _BLOCK_SIZE_0) * tl.cdiv(96, _BLOCK_SIZE_1)
     block_size = tl.cdiv(total_pids, _NUM_SM)
     start_pid = tl.program_id(0) * block_size
     end_pid = tl.minimum(start_pid + block_size, total_pids)
     for virtual_pid in tl.range(start_pid, end_pid, warp_specialize=True):
-        num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+        num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
         pid_0 = virtual_pid % num_blocks_0
         pid_1 = virtual_pid // num_blocks_0
         offset_0 = pid_0 * _BLOCK_SIZE_0
         indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
-        mask_0 = indices_0 < x_size_0
         offset_1 = pid_1 * _BLOCK_SIZE_1
         indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
-        mask_1 = indices_1 < x_size_1
-        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        load = tl.load(x + (indices_0[:, None] * 96 + indices_1[None, :] * 1), None)
         v_0 = 1.0
         v_1 = load + v_0
-        tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+        tl.store(result + (indices_0[:, None] * 96 + indices_1[None, :] * 1), v_1, None)
 
 def test_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
     result = x.new_empty(x.size())
     _NUM_SM = helion.runtime.get_num_sm(x.device)
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_1 = 16
-    _launcher(_helion_test_kernel, (_NUM_SM,), x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
+    _launcher(_helion_test_kernel, (_NUM_SM,), x, result, _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
     return result
 
 --- assertExpectedJournal(TestPersistentKernels.test_persistent_loop_variable_names)
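Note: the rewritten pass in device_ir.py folds every constant addend around a tile index into a single tile_with_offset metadata entry, so commuted forms like 1 + tile.index and chained forms like tile.index + 1 + 2 lower to one block-pointer offset. The snippet below is a minimal standalone sketch of that folding idea on a plain torch.fx graph, not Helion's actual pass; fold_constant_adds and the "folded_offset" meta key are illustrative names, and the real pass additionally records the tile's block_id via CompileEnvironment.get_block_id.

# Standalone sketch: accumulate constant addends of operator.add nodes into
# node metadata. Names here are hypothetical and for illustration only.
import operator

import torch
import torch.fx


def fold_constant_adds(gm: torch.fx.GraphModule) -> None:
    """Annotate each operator.add node with the total constant offset it applies."""
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target is not operator.add:
            continue
        total = 0
        for arg in node.args:
            if isinstance(arg, int):
                total += arg
            elif isinstance(arg, torch.fx.Node):
                # Nodes are visited in topological order, so an upstream add
                # node has already been annotated and its offset can be reused.
                total += arg.meta.get("folded_offset", 0)
        node.meta["folded_offset"] = total


def f(t):
    # Commuted and chained constant offsets, mirroring the new test case.
    return 1 + t + 2


gm = torch.fx.symbolic_trace(f)
fold_constant_adds(gm)
for node in gm.graph.nodes:
    if node.op == "call_function":
        print(node.name, "->", node.meta.get("folded_offset"))  # add -> 1, add_1 -> 3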