77 changes: 76 additions & 1 deletion helion/_compiler/device_ir.py
@@ -5,6 +5,7 @@
import contextlib
import dataclasses
import functools
import itertools
import operator
import re
import textwrap
@@ -450,7 +451,18 @@ def _body(self, body: list[ast.stmt]) -> None:
self.visit(stmt)

def visit_BinOp(self, node: ast.BinOp) -> object:
return _eval_binary(node.op, self.visit(node.left), self.visit(node.right))
left = self.visit(node.left)
right = self.visit(node.right)
# Special handling for Tile + offset: expand to tile.index + offset
# and mark with metadata for indexing strategies to recognize
if (
isinstance(node.op, ast.Add)
and isinstance(left, Tile)
and isinstance(right, (int, torch.SymInt))
):
# Implicitly expand to tile.index + offset
left = hl.tile_index(left)
return _eval_binary(node.op, left, right)

def visit_UnaryOp(self, node: ast.UnaryOp) -> object:
return _eval_unary(node.op, self.visit(node.operand))
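
For context, a minimal kernel sketch of the pattern this branch enables, assuming the usual `@helion.kernel` / `hl.tile` loop API; the `shift_right` kernel is illustrative only and not part of this change:

```python
import torch

import helion
import helion.language as hl


# Hypothetical example (not from this PR): indexing with `tile + 1` is the
# Tile-plus-int case handled above; visit_BinOp evaluates it as
# hl.tile_index(tile) + 1 rather than trying to add 1 to the Tile object.
@helion.kernel()
def shift_right(x: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(x)
    for tile in hl.tile(x.size(0) - 1):
        # `tile + 1` triggers the new Add branch: left is a Tile, right an int.
        out[tile + 1] = x[tile]
    return out
```
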
@@ -1128,6 +1140,7 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
prepare_graph_lowerings(graph.graph)
for graph in device_ir.graphs:
validate_host_tensor_usage(graph.graph)
add_tile_with_offset_metadata(graph)
remove_unnecessary_tile_index(graph.graph)
remove_unnecessary_masking(graph.graph)
device_ir.build_rolled_reductions()
@@ -1193,6 +1206,68 @@ def validate_host_tensor_usage(graph: torch.fx.Graph) -> None:
raise exc.HostTensorDirectUsage(scalar_tensor_name, op_name)


def add_tile_with_offset_metadata(graph_info: GraphInfo) -> None:
"""
Recognize tile.index + offset patterns and add metadata to enable tensor descriptor indexing.

This pass identifies FX nodes that represent `tile.index + offset` (where offset is an
integer or SymInt), and adds the `tile_with_offset` metadata to those nodes so that
indexing strategies can generate efficient code (e.g., tensor descriptors) for them.
"""
graph = graph_info.graph
env = CompileEnvironment.current()

for node in itertools.chain(
graph.find_nodes(op="call_function", target=operator.add),
graph.find_nodes(op="call_function", target=torch.ops.aten.add.Tensor),
):
# Check if this is tile.index + offset pattern
# args[0] should be tile_index result, args[1] should be int/SymInt
if len(node.args) != 2 or node.kwargs:
continue
left_arg, right_arg = node.args

# Check if left argument is a tile_index call
if (
not isinstance(left_arg, torch.fx.Node)
or left_arg.op != "call_function"
or left_arg.target != hl.tile_index
):
continue

# Check if right argument is an integer offset
# It could be a constant, SymInt node, or another value
# We accept int, SymInt, or nodes that represent them
offset = None
if isinstance(right_arg, (int, torch.SymInt)):
offset = right_arg
elif isinstance(right_arg, torch.fx.Node):
# Check the node's metadata for the value
val = right_arg.meta.get("val")
if isinstance(val, (int, torch.SymInt)):
offset = val

if offset is None:
continue

# Extract the block_id from the tile_index call
tile_arg = left_arg.args[0]
block_id = None
if isinstance(tile_arg, torch.fx.Node) and isinstance(
tile_arg.meta["val"], torch.SymInt
):
block_id = env.get_block_id(tile_arg.meta["val"])

if block_id is None:
continue

# Add metadata to mark this as a tile+offset node
node.meta["tile_with_offset"] = {
"block_id": block_id,
"offset": offset,
}
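
Downstream consumers only need the metadata layout written above. A minimal sketch of how an indexing strategy might read it back, assuming nothing beyond the `tile_with_offset` dict set in this pass (the helper name is made up for illustration):

```python
from __future__ import annotations

import torch
import torch.fx


# Hypothetical helper (not part of this PR): recover (block_id, offset) from a
# node tagged by add_tile_with_offset_metadata, or None if it was not tagged.
def maybe_tile_with_offset(
    node: torch.fx.Node,
) -> tuple[int, int | torch.SymInt] | None:
    info = node.meta.get("tile_with_offset")
    if info is None:
        return None
    return info["block_id"], info["offset"]
```

A strategy that generates tensor descriptors could then fold the constant offset into the descriptor's base offset for `block_id` instead of materializing the `tile.index + offset` tensor.
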


def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
"""
Remove unnecessary tile_index nodes from the graph.