Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions helion/_compiler/inductor_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,10 @@ def input_asts(self, ctx: LoweringContext, node: torch.fx.Node) -> list[ast.AST]
def visit(n: torch.fx.Node) -> None:
ast_val = cast("ast.AST", ctx.env[n])
if isinstance(fake_val := n.meta["val"], torch.Tensor):
- if fake_val.ndim < ndim:
- # Broadcast to force ranks to match
+ # Don't expand scalars (0-D tensors) - let Triton handle broadcasting naturally
+ # Expanding scalars with [None, None] creates incorrect broadcast shapes
+ if fake_val.ndim < ndim and fake_val.ndim > 0:
+ # Broadcast to force ranks to match (but only for non-scalar tensors)
expand = ["None"] * (ndim - fake_val.ndim) + [":"] * fake_val.ndim
ast_val = expr_from_string(
"{tensor}[" + ", ".join(expand) + "]", tensor=ast_val
Expand Down
20 changes: 9 additions & 11 deletions test/test_control_flow.expected
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,18 @@ def _helion_mul_relu_block_backward_kernel(x, y, dz, dx, dy, _BLOCK_SIZE_0: tl.c
# src[test_control_flow.py:N]: relu_grad = torch.where(relu_mask, 1, 0)
v_3 = tl.full([], 0, tl.int64)
v_4 = tl.full([], 1, tl.int64)
- v_5 = v_4[None, None]
- v_6 = v_3[None, None]
- v_7 = tl.where(v_2, v_5, v_6)
+ v_5 = tl.where(v_2, v_4, v_3)
# src[test_control_flow.py:N]: dx[tile_i, tile_j] = dz_tile * relu_grad * y_tile[:, None]
- v_8 = tl.cast(v_7, tl.float32)
- v_9 = dz_tile * v_8
+ v_6 = tl.cast(v_5, tl.float32)
+ v_7 = dz_tile * v_6
subscript_1 = y_tile[:, None]
- v_10 = v_9 * subscript_1
- tl.store(dx + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_10, None)
+ v_8 = v_7 * subscript_1
+ tl.store(dx + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_8, None)
# src[test_control_flow.py:N]: local_dy_grad = torch.sum(dz_tile * relu_grad * x_tile, dim=1)
- v_11 = tl.cast(v_7, tl.float32)
- v_12 = dz_tile * v_11
- v_13 = v_12 * x_tile
- local_dy_grad = tl.cast(tl.sum(v_13, 1), tl.float32)
+ v_9 = tl.cast(v_5, tl.float32)
+ v_10 = dz_tile * v_9
+ v_11 = v_10 * x_tile
+ local_dy_grad = tl.cast(tl.sum(v_11, 1), tl.float32)
# src[test_control_flow.py:N]: hl.atomic_add(dy, [tile_i], local_dy_grad)
tl.atomic_add(dy + indices_0 * 1, local_dy_grad, mask=None, sem='relaxed')

Expand Down
Loading
Loading