From 65b56f11f5545fcd4e1c973578721e38b4a9d1cb Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Tue, 14 Oct 2025 09:44:27 -0700 Subject: [PATCH] Support tile+offset and tensor descriptors stack-info: PR: https://github.com/pytorch/helion/pull/928, branch: jansel/stack/189 --- helion/_compiler/device_ir.py | 77 ++++++++++++- helion/_compiler/indexing_strategy.py | 159 ++++++++++++++++++++++++-- helion/_compiler/type_propagation.py | 21 +++- helion/language/ref_tile.py | 78 +++++++++++++ test/test_examples.expected | 12 +- test/test_indexing.expected | 82 ++++++++++++- test/test_indexing.py | 68 +++++++++++ 7 files changed, 479 insertions(+), 18 deletions(-) diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index fd89f3ba5..d3611d1ab 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -5,6 +5,7 @@ import contextlib import dataclasses import functools +import itertools import operator import re import textwrap @@ -450,7 +451,18 @@ def _body(self, body: list[ast.stmt]) -> None: self.visit(stmt) def visit_BinOp(self, node: ast.BinOp) -> object: - return _eval_binary(node.op, self.visit(node.left), self.visit(node.right)) + left = self.visit(node.left) + right = self.visit(node.right) + # Special handling for Tile + offset: expand to tile.index + offset + # and mark with metadata for indexing strategies to recognize + if ( + isinstance(node.op, ast.Add) + and isinstance(left, Tile) + and isinstance(right, (int, torch.SymInt)) + ): + # Implicitly expand to tile.index + offset + left = hl.tile_index(left) + return _eval_binary(node.op, left, right) def visit_UnaryOp(self, node: ast.UnaryOp) -> object: return _eval_unary(node.op, self.visit(node.operand)) @@ -1128,6 +1140,7 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR: prepare_graph_lowerings(graph.graph) for graph in device_ir.graphs: validate_host_tensor_usage(graph.graph) + add_tile_with_offset_metadata(graph) remove_unnecessary_tile_index(graph.graph) remove_unnecessary_masking(graph.graph) device_ir.build_rolled_reductions() @@ -1193,6 +1206,68 @@ def validate_host_tensor_usage(graph: torch.fx.Graph) -> None: raise exc.HostTensorDirectUsage(scalar_tensor_name, op_name) +def add_tile_with_offset_metadata(graph_info: GraphInfo) -> None: + """ + Recognize tile.index + offset patterns and add metadata to enable tensor descriptor indexing. + + This pass identifies FX nodes that represent `tile.index + offset` (where offset is an + integer or SymInt), and adds the `tile_with_offset` metadata to those nodes so that + indexing strategies can generate efficient code (e.g., tensor descriptors) for them. 
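+
+    For example, device code such as `tile + 10` (a Tile plus an int or SymInt)
+    is expanded during tracing by `visit_BinOp` into `hl.tile_index(tile) + 10`;
+    this pass then tags the resulting add node so that the pointer, block_ptr,
+    and tensor descriptor strategies can fold the constant offset directly into
+    the generated index or block offset.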
+ """ + graph = graph_info.graph + env = CompileEnvironment.current() + + for node in itertools.chain( + graph.find_nodes(op="call_function", target=operator.add), + graph.find_nodes(op="call_function", target=torch.ops.aten.add.Tensor), + ): + # Check if this is tile.index + offset pattern + # args[0] should be tile_index result, args[1] should be int/SymInt + if len(node.args) != 2 and not node.kwargs: + continue + left_arg, right_arg = node.args + + # Check if left argument is a tile_index call + if ( + not isinstance(left_arg, torch.fx.Node) + or left_arg.op != "call_function" + or left_arg.target != hl.tile_index + ): + continue + + # Check if right argument is an integer offset + # It could be a constant, SymInt node, or another value + # We accept int, SymInt, or nodes that represent them + offset = None + if isinstance(right_arg, (int, torch.SymInt)): + offset = right_arg + elif isinstance(right_arg, torch.fx.Node): + # Check the node's metadata for the value + val = right_arg.meta.get("val") + if isinstance(val, (int, torch.SymInt)): + offset = val + + if offset is None: + continue + + # Extract the block_id from the tile_index call + tile_arg = left_arg.args[0] + block_id = None + if isinstance(tile_arg, torch.fx.Node) and isinstance( + tile_arg.meta["val"], torch.SymInt + ): + block_id = env.get_block_id(tile_arg.meta["val"]) + + if block_id is None: + continue + + # Add metadata to mark this as a tile+offset node + node.meta["tile_with_offset"] = { + "block_id": block_id, + "offset": offset, + } + + def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None: """ Remove unnecessary tile_index nodes from the graph. diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index 32b479f9a..e10d34037 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -56,6 +56,51 @@ def _get_padded_iota_original_length( return length_arg except (AttributeError, IndexError, TypeError): pass + + return None + + +def _get_tile_with_offset_info( + k: object, state: CodegenState, k_index: int +) -> tuple[int, int | torch.SymInt] | None: + """Check if k is a tensor marked as tile.index + offset, return (block_id, offset) if so. 
+ + Args: + k: The subscript element (fake value) + state: The codegen state containing the FX node + k_index: The index of k in the subscript list + """ + if not isinstance(k, torch.Tensor): + return None + + # During codegen, we don't have proxy mode, but we have the FX graph + # The state.fx_node is the load/store node, and its second argument (args[1]) + # is the list of subscript indices as FX nodes + if state.fx_node is None: + return None + + # Get the subscript list from the FX node's arguments + # args[0] is the tensor, args[1] is the subscript list + if len(state.fx_node.args) < 2: + return None + + subscript_arg = state.fx_node.args[1] + if not isinstance(subscript_arg, (list, tuple)): + return None + + # Find the FX node corresponding to this subscript element + if k_index >= len(subscript_arg): + return None + + fx_subscript_node = subscript_arg[k_index] + if not isinstance(fx_subscript_node, torch.fx.Node): + return None + + # Check if this FX node has the tile_with_offset metadata + meta = fx_subscript_node.meta.get("tile_with_offset") + if meta is not None: + return (meta["block_id"], meta["offset"]) + return None @@ -261,6 +306,7 @@ def valid_block_size( strides = fake_tensor.stride() size_stride = collections.deque(zip(sizes, strides, strict=True)) config = DeviceFunction.current().config + k_index = 0 # Track position for finding FX nodes for i, k in enumerate(subscript): if k is None: continue @@ -272,6 +318,16 @@ def valid_block_size( block_size = env.allocate_reduction_dimension(size).from_config(config) if not valid_block_size(block_size, stride, i): return False + k_index += 1 + elif ( + tile_info := _get_tile_with_offset_info(k, state, k_index) + ) is not None: + # Tensor marked as tile.index + offset + block_id, _ = tile_info + block_size = env.block_sizes[block_id].from_config(config) + if not valid_block_size(block_size, stride, i): + return False + k_index += 1 elif isinstance(k, torch.SymInt): block_id = env.get_block_id(k) if block_id is None: @@ -279,6 +335,7 @@ def valid_block_size( block_size = env.block_sizes[block_id].from_config(config) if not valid_block_size(block_size, stride, i): return False + k_index += 1 return True @@ -435,7 +492,9 @@ def codegen_load( ) -> ast.AST: tensor_like, dev_ptrs = stack_tensor indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask) - subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript) + subscripts_shape = SubscriptIndexing.compute_shape( + tensor_like, subscript, state + ) stack_shape = [*dev_ptrs.size()] mask_expr = StackIndexingStrategy.get_mask_expr( @@ -471,7 +530,9 @@ def codegen_store( ) -> ast.AST: tensor_like, dev_ptrs = stack_tensor indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask) - subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript) + subscripts_shape = SubscriptIndexing.compute_shape( + tensor_like, subscript, state + ) stack_shape = [*dev_ptrs.size()] mask_expr = StackIndexingStrategy.get_mask_expr( @@ -505,18 +566,33 @@ def has_mask(self) -> bool: @staticmethod def compute_shape( - tensor: torch.Tensor, index: list[object] + tensor: torch.Tensor, index: list[object], state: CodegenState | None = None ) -> list[int | torch.SymInt]: assert isinstance(tensor, torch.Tensor) assert isinstance(index, (list, tuple)), index input_size = collections.deque(tensor.size()) output_size = [] env = CompileEnvironment.current() + k_index = 0 for k in index: if k is None: output_size.append(1) elif isinstance(k, int): 
input_size.popleft() + elif ( + state is not None + and (tile_info := _get_tile_with_offset_info(k, state, k_index)) + is not None + ): + # Tensor marked as tile.index + offset + input_size.popleft() + block_id, _ = tile_info + block_size = env.block_sizes[block_id].var + if tensor.size(tensor.ndim - len(input_size) - 1) != 1: + output_size.append(block_size) + else: + output_size.append(1) + k_index += 1 elif isinstance(k, torch.SymInt): input_size.popleft() symbol = k._sympy_() @@ -527,6 +603,7 @@ def compute_shape( output_size.append(k) else: output_size.append(1) + k_index += 1 elif isinstance(k, slice): size = input_size.popleft() # Handle slices with steps @@ -537,11 +614,13 @@ def compute_shape( output_size.append(rdim.var) else: output_size.append(1) + k_index += 1 elif isinstance(k, torch.Tensor) and ( k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1) ): input_size.popleft() output_size.extend(k.size()) + k_index += 1 else: raise exc.InvalidIndexingType(k) assert len(input_size) == 0, "invalid subscript" @@ -583,7 +662,7 @@ def create( output_idx = 0 index_values = [] mask_values = {} - output_size = SubscriptIndexing.compute_shape(fake_value, index) + output_size = SubscriptIndexing.compute_shape(fake_value, index, state) env = CompileEnvironment.current() dtype = env.triton_index_type() if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value): @@ -592,11 +671,29 @@ def create( def _is_size_one(size: int | torch.SymInt) -> bool: return env.known_equal(size, 1) + k_index = 0 for n, k in enumerate(index): if k is None: output_idx += 1 elif isinstance(k, int): index_values.append(repr(k)) + elif ( + tile_info := _get_tile_with_offset_info(k, state, k_index) + ) is not None: + # Tensor marked as tile.index + offset + block_id, offset = tile_info + index_var = state.codegen.index_var(block_id) + offset_expr = state.device_function.literal_expr(offset) + expand = tile_strategy.expand_str(output_size, output_idx) + i = len(index_values) + index_values.append(f"(({index_var}) + {offset_expr}){expand}") + # Use the same mask as the underlying tile + if (mask := state.codegen.mask_var(block_id)) and not _is_size_one( + fake_value.size(i) + ): + mask_values.setdefault(f"({mask}){expand}") + output_idx += 1 + k_index += 1 elif isinstance(k, torch.SymInt): symbol = k._sympy_() origin = None @@ -612,6 +709,7 @@ def _is_size_one(size: int | torch.SymInt) -> bool: ) and not _is_size_one(fake_value.size(i)): mask_values.setdefault(f"({mask}){expand}") output_idx += 1 + k_index += 1 else: # When the index is a scalar (no BlockSizeOrigin), the corresponding dim is eliminated. 
val = state.device_function.literal_expr(k) @@ -651,6 +749,7 @@ def _is_size_one(size: int | torch.SymInt) -> bool: else: index_values.append(f"tl.zeros([1], {dtype}){expand}") output_idx += 1 + k_index += 1 elif isinstance(k, torch.Tensor) and k.ndim == 1: expand = tile_strategy.expand_str(output_size, output_idx) ast_index = state.ast_args[1] @@ -667,6 +766,7 @@ def _is_size_one(size: int | torch.SymInt) -> bool: ) is not None: mask_values.setdefault(f"({index_var} < {original_length}){expand}") output_idx += 1 + k_index += 1 elif ( isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1 ): @@ -684,6 +784,7 @@ def _is_size_one(size: int | torch.SymInt) -> bool: mask_values.setdefault( f"({mask}){tile_strategy.expand_str(output_size, n)}" ) + k_index += 1 else: raise exc.InvalidIndexingType(type(k)) assert len(output_size) == output_idx @@ -812,9 +913,31 @@ def is_supported( # TODO(jansel): support block_ptr with extra_mask return False input_sizes = collections.deque(fake_tensor.size()) + k_index = 0 for k in index: input_size = 1 if k is None else input_sizes.popleft() - if isinstance(k, torch.SymInt): + # Check for tile+offset tensor first before other checks + if ( + isinstance(k, torch.Tensor) + and (tile_info := _get_tile_with_offset_info(k, state, k_index)) + is not None + ): + # Tensor marked as tile.index + offset - treat like TileWithOffset + block_index, _ = tile_info + try: + state.codegen.offset_var(block_index) + except NotImplementedError: + return False + loop_state = state.codegen.active_device_loops[block_index][-1] + if isinstance(loop_state, DeviceLoopState): + if not loop_state.block_id_to_info[block_index].is_end_matching( + input_size + ): + assert state.fx_node is not None + if "masked_value" in state.fx_node.meta: + return False + k_index += 1 + elif isinstance(k, torch.SymInt): symbol = k._sympy_() origin = None if isinstance(symbol, sympy.Symbol): @@ -840,10 +963,11 @@ def is_supported( # TODO(jansel): in this case we should be able to lower to block_ptr+tl.where # see test/test_loops.py::TestLoops::test_data_dependent_bounds2 return False - if isinstance(k, torch.Tensor): + k_index += 1 + elif isinstance(k, torch.Tensor): # indirect loads don't work with block_ptr return False - output_shape = SubscriptIndexing.compute_shape(fake_tensor, index) + output_shape = SubscriptIndexing.compute_shape(fake_tensor, index, state) return len(output_shape) != 0 def validate(self) -> None: @@ -861,14 +985,30 @@ def create( ) -> BlockedSubscriptIndexing: res = BlockedSubscriptIndexing( fake_value, - reshaped_size=SubscriptIndexing.compute_shape(fake_value, index), + reshaped_size=SubscriptIndexing.compute_shape(fake_value, index, state), ) + env = CompileEnvironment.current() + k_index = 0 for k in index: if k is None: pass # handled by reshaped_size elif isinstance(k, int): res.offsets.append(repr(k)) res.block_shape.append(1) + elif ( + tile_info := _get_tile_with_offset_info(k, state, k_index) + ) is not None: + # Tensor marked as tile.index + offset + if fake_value.size(len(res.offsets)) != 1: + block_id, offset = tile_info + offset_var = state.codegen.offset_var(block_id) + offset_expr = state.device_function.literal_expr(offset) + res.offsets.append(f"({offset_var} + {offset_expr})") + res.block_shape.append(env.block_sizes[block_id].var) + else: + res.offsets.append("0") + res.block_shape.append(1) + k_index += 1 elif isinstance(k, torch.SymInt): symbol = k._sympy_() origin = HostFunction.current().expr_to_origin.get(symbol) @@ -881,6 +1021,7 @@ def 
create( else: res.offsets.append("0") res.block_shape.append(1) + k_index += 1 else: res.offsets.append(state.device_function.literal_expr(k)) res.block_shape.append(1) @@ -894,13 +1035,13 @@ def create( ) # Full slice or slice without step if size != 1: - env = CompileEnvironment.current() rdim = env.allocate_reduction_dimension(size) res.offsets.append(state.codegen.offset_var(rdim.block_id)) res.block_shape.append(rdim.var) else: res.offsets.append("0") res.block_shape.append(1) + k_index += 1 else: raise exc.InvalidIndexingType(k) res.validate() diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index ae6bd811b..11a9fa358 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -1906,13 +1906,32 @@ def visit_BinOp(self, node: ast.BinOp) -> TypeInfo: pass else: try: + # Special case: if this is Tile + offset pattern, expand to tile.index + offset + if ( + isinstance(node.op, ast.Add) + and isinstance(left, TileIndexType) + and isinstance(right, (SymIntType, LiteralType, NumericType)) + ): + # Expand tile + offset to tile.index + offset + tile_index = left.propagate_attribute( + "index", AttributeOrigin(self.origin(), "index") + ) + return TypeInfo.from_example( + _eval_binary(node.op, tile_index.proxy(), right.proxy()), + self.origin(), + ) + return TypeInfo.from_example( _eval_binary(node.op, left_example, right_example), self.origin(), ) except exc.Base: raise - except Exception as e: + except TypeError as e: + # Re-raise as TorchOpTracingError for proper error handling in visit_AugAssign + raise exc.TorchOpTracingError(e) from e + except RuntimeError as e: + # Re-raise as TorchOpTracingError for proper error handling in visit_AugAssign raise exc.TorchOpTracingError(e) from e raise exc.TypeInferenceError( diff --git a/helion/language/ref_tile.py b/helion/language/ref_tile.py index 144d99657..e2f3cc323 100644 --- a/helion/language/ref_tile.py +++ b/helion/language/ref_tile.py @@ -13,6 +13,28 @@ from collections.abc import Callable +_ADD_OPS: set[object] = { + torch.add, + torch.Tensor.add, + torch.Tensor.add_, + torch.Tensor.__add__, + torch.Tensor.__radd__, +} +_SUB_OPS: set[object] = { + torch.sub, + torch.Tensor.sub, + torch.Tensor.sub_, + torch.Tensor.__sub__, + torch.Tensor.__rsub__, +} + +try: + _ADD_OPS.add(torch.ops.aten.add.Tensor) + _SUB_OPS.add(torch.ops.aten.sub.Tensor) +except AttributeError: # pragma: no cover - aten fallback not always defined + pass + + class RefTile(TileInterface, torch.Tensor): _slice: slice _block_size: int @@ -43,6 +65,62 @@ def __torch_function__( if func is torch.Tensor.__format__: return repr(args[0]) + if func in _ADD_OPS: + return cls._handle_add(args) + + if func in _SUB_OPS: + return cls._handle_sub(args) + + raise exc.IncorrectTileUsage(func) + + @classmethod + def _handle_add(cls, args: tuple[object, ...]) -> torch.Tensor: + tile, offset, flipped = cls._extract_tile_and_offset(args, torch.add) + return tile.index + offset if not flipped else offset + tile.index + + @classmethod + def _handle_sub(cls, args: tuple[object, ...]) -> torch.Tensor: + tile, offset, flipped = cls._extract_tile_and_offset(args, torch.sub) + return ( + tile.index - offset + if not flipped + else offset - tile.index # pragma: no cover - defensive + ) + + @classmethod + def _extract_tile_and_offset( + cls, args: tuple[object, ...], func: object + ) -> tuple[RefTile, int, bool]: + if len(args) != 2: + raise exc.IncorrectTileUsage(func) + + lhs, rhs = args + flipped = False + + if 
isinstance(lhs, RefTile) and cls._is_valid_offset(rhs): + tile = lhs + offset = cls._to_int(rhs, func) + elif isinstance(rhs, RefTile) and cls._is_valid_offset(lhs): + tile = rhs + offset = cls._to_int(lhs, func) + flipped = True + else: + raise exc.IncorrectTileUsage(func) + + return tile, offset, flipped + + @staticmethod + def _is_valid_offset(value: object) -> bool: + if isinstance(value, int): + return True + return bool(isinstance(value, torch.Tensor) and value.ndim == 0) + + @staticmethod + def _to_int(value: object, func: object) -> int: + if isinstance(value, int): + return value + if isinstance(value, torch.Tensor) and value.ndim == 0: + return int(value.item()) raise exc.IncorrectTileUsage(func) @classmethod diff --git a/test/test_examples.expected b/test/test_examples.expected index d6aac9b01..56ba95b8c 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -1625,7 +1625,7 @@ def _helion_jagged_dense_add_2d(x_offsets, x_data, y, out, y_size_1, out_stride_ starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0) v_0 = tl.full([], 1, tl.int32) v_1 = indices_0 + v_0 - ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0) + ends = tl.load(x_offsets + (indices_0 + 1) * x_offsets_stride_0, mask_0, other=0) v_2 = ends - starts _mask_to = tl.where(mask_0, v_2, tl.full([], -9223372036854775808, tl.int64)) max_nnz = tl.cast(tl.max(_mask_to, 0), tl.int64) @@ -1798,7 +1798,7 @@ def _helion_jagged_layer_norm_kernel(x_offsets, x_flat, out_flat, out_flat_strid starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0) v_0 = tl.full([], 1, tl.int32) v_1 = indices_0 + v_0 - ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0) + ends = tl.load(x_offsets + (indices_0 + 1) * x_offsets_stride_0, mask_0, other=0) v_2 = ends - starts _mask_to = tl.where(mask_0, v_2, tl.full([], -9223372036854775808, tl.int64)) max_seq_len = tl.cast(tl.max(_mask_to, 0), tl.int64) @@ -2001,7 +2001,7 @@ def _helion_jagged_mean_kernel(x_offsets, x_feature_counts, x_flat, out, out_str starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0) v_0 = tl.full([], 1, tl.int32) v_1 = indices_0 + v_0 - ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0) + ends = tl.load(x_offsets + (indices_0 + 1) * x_offsets_stride_0, mask_0, other=0) v_2 = ends - starts _mask_to = tl.where(mask_0, v_2, tl.full([], -9223372036854775808, tl.int64)) max_nnz = tl.cast(tl.max(_mask_to, 0), tl.int64) @@ -2106,7 +2106,7 @@ def _helion_jagged_softmax_kernel(x_offsets, x_flat, out, out_stride_0, x_flat_s starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0) v_0 = tl.full([], 1, tl.int32) v_1 = indices_0 + v_0 - ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0) + ends = tl.load(x_offsets + (indices_0 + 1) * x_offsets_stride_0, mask_0, other=0) v_2 = ends - starts _mask_to = tl.where(mask_0, v_2, tl.full([], -9223372036854775808, tl.int64)) max_seqlen = tl.cast(tl.max(_mask_to, 0), tl.int64) @@ -2248,7 +2248,7 @@ def _helion_jagged_sum_kernel(x_offsets, x_flat, out, out_stride_0, out_stride_1 starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0) v_0 = tl.full([], 1, tl.int32) v_1 = indices_0 + v_0 - ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0) + ends = tl.load(x_offsets + (indices_0 + 1) * x_offsets_stride_0, mask_0, other=0) v_2 = ends - starts _mask_to = tl.where(mask_0, v_2, tl.full([], -9223372036854775808, tl.int64)) max_nnz = 
tl.cast(tl.max(_mask_to, 0), tl.int64) @@ -3706,7 +3706,7 @@ def _helion_segmented_reduction_helion(input_data, indices, output, indices_stri sub = -1 + num_elements v_2 = tl.cast(sub, tl.int32) v_3 = indices_0 < v_2 - idxs_next = tl.load(indices + v_1 * indices_stride_0, mask_0 & v_3, other=0) + idxs_next = tl.load(indices + (indices_0 + 1) * indices_stride_0, mask_0 & v_3, other=0) v_4 = tl.cast(idxs, tl.float32) unsqueeze = v_4[:, None] expand = tl.broadcast_to(unsqueeze, [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) diff --git a/test/test_indexing.expected b/test/test_indexing.expected index a0e8ca6b6..adeca9668 100644 --- a/test/test_indexing.expected +++ b/test/test_indexing.expected @@ -341,7 +341,7 @@ def _helion_pairwise_add(out, x, out_size_0, out_stride_0, x_stride_0, _BLOCK_SI load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) v_0 = tl.full([], 1, tl.int32) v_1 = indices_0 + v_0 - load_1 = tl.load(x + v_1 * x_stride_0, mask_0, other=0) + load_1 = tl.load(x + (indices_0 + 1) * x_stride_0, mask_0, other=0) v_2 = load + load_1 tl.store(out + indices_0 * out_stride_0, v_2, mask_0) @@ -437,3 +437,83 @@ def arange_block_size_mul(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 64 _launcher(_helion_arange_block_size_mul, (triton.cdiv(64, _BLOCK_SIZE_0),), ones, out, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=2) return out + +--- assertExpectedJournal(TestIndexing.test_tile_with_offset_block_ptr) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_tile_offset_kernel(out, x, out_size_0, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + v_0 = tl.full([], 10, tl.int32) + v_1 = indices_0 + v_0 + load = tl.load(tl.make_block_ptr(x, [x_size_0], [x_stride_0], [offset_0 + 10], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero') + tl.store(tl.make_block_ptr(out, [out_size_0], [out_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), load, boundary_check=[0]) + +def tile_offset_kernel(x: torch.Tensor, *, _launcher=_default_launcher): + out = x.new_empty(x.size(0) - 10) + _BLOCK_SIZE_0 = 32 + _launcher(_helion_tile_offset_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, x, out.size(0), x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2) + return out + +--- assertExpectedJournal(TestIndexing.test_tile_with_offset_pointer) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_tile_offset_kernel(out, x, out_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < out_size_0 + v_0 = tl.full([], 10, tl.int32) + v_1 = indices_0 + v_0 + load = tl.load(x + (indices_0 + 10) * x_stride_0, mask_0, other=0) + tl.store(out + indices_0 * out_stride_0, load, mask_0) + +def tile_offset_kernel(x: torch.Tensor, *, _launcher=_default_launcher): + out = x.new_empty(x.size(0) - 10) + _BLOCK_SIZE_0 = 32 + _launcher(_helion_tile_offset_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, x, out.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2) 
+ return out + +--- assertExpectedJournal(TestIndexing.test_tile_with_offset_tensor_descriptor) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +helion.runtime.set_triton_allocator() + +@triton.jit +def _helion_tile_offset_2d_kernel(out, x, out_size_0, out_size_1, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): + x_desc = tl.make_tensor_descriptor(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [_BLOCK_SIZE_0, _RDIM_SIZE_1]) + out_desc = tl.make_tensor_descriptor(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [_BLOCK_SIZE_0, _RDIM_SIZE_1]) + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + v_0 = tl.full([], 10, tl.int32) + v_1 = indices_0 + v_0 + load = x_desc.load([offset_0 + 10, 0]) + out_desc.store([offset_0, 0], load) + +def tile_offset_2d_kernel(x: torch.Tensor, *, _launcher=_default_launcher): + M, N = x.size() + out = x.new_empty(M - 10, N) + _BLOCK_SIZE_0 = 32 + _RDIM_SIZE_1 = triton.next_power_of_2(N) + _launcher(_helion_tile_offset_2d_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, x, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2) + return out diff --git a/test/test_indexing.py b/test/test_indexing.py index 1f3ccce96..8e9bd8b2f 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -1234,6 +1234,74 @@ def kernel( torch.testing.assert_close(src_result, expected_src) torch.testing.assert_close(dst_result, expected_dst) + def test_tile_with_offset_pointer(self): + """Test Tile+offset with pointer indexing""" + + @helion.kernel() + def tile_offset_kernel(x: torch.Tensor) -> torch.Tensor: + out = x.new_empty(x.size(0) - 10) + for tile in hl.tile(out.size(0)): + # Use tile + offset pattern + tile_offset = tile + 10 + out[tile] = x[tile_offset] + return out + + x = torch.randn([200], device=DEVICE) + code, result = code_and_output( + tile_offset_kernel, + (x,), + indexing="pointer", + block_size=32, + ) + torch.testing.assert_close(result, x[10:]) + self.assertExpectedJournal(code) + + def test_tile_with_offset_block_ptr(self): + """Test Tile+offset with block_ptr indexing""" + + @helion.kernel() + def tile_offset_kernel(x: torch.Tensor) -> torch.Tensor: + out = x.new_empty(x.size(0) - 10) + for tile in hl.tile(out.size(0)): + # Use tile + offset pattern + tile_offset = tile + 10 + out[tile] = x[tile_offset] + return out + + x = torch.randn([200], device=DEVICE) + code, result = code_and_output( + tile_offset_kernel, + (x,), + indexing="block_ptr", + block_size=32, + ) + torch.testing.assert_close(result, x[10:]) + self.assertExpectedJournal(code) + + @unittest.skipIf(not supports_tensor_descriptor(), "TensorDescriptor not supported") + def test_tile_with_offset_tensor_descriptor(self): + """Test Tile+offset with tensor_descriptor indexing for 2D tensors""" + + @helion.kernel() + def tile_offset_2d_kernel(x: torch.Tensor) -> torch.Tensor: + M, N = x.size() + out = x.new_empty(M - 10, N) + for tile_m in hl.tile(out.size(0)): + # Use tile + offset pattern + tile_offset = tile_m + 10 + out[tile_m, :] = x[tile_offset, :] + return out + + x = torch.randn([128, 64], device=DEVICE) + code, result = code_and_output( + tile_offset_2d_kernel, + (x,), + 
indexing="tensor_descriptor", + block_size=32, + ) + torch.testing.assert_close(result, x[10:, :]) + self.assertExpectedJournal(code) + if __name__ == "__main__": unittest.main()