From ff8d4e7c175e72915ee483dd0f05fefb8bd948e1 Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Thu, 20 Nov 2025 23:11:00 -0800 Subject: [PATCH] Fix scalar broadcast bug in inductor lowering Bug: InductorLowering was incorrectly expanding scalars (0-D tensors) with [None, None] to match the max ndim of all inputs. This created broadcast shape mismatches in generated Triton code like `scale_val[None, None]` when multiplying a 2D tensor by a scalar. Fix: Skip dimension expansion for 0-D tensors (fake_val.ndim > 0 check). Triton naturally handles scalar broadcasting without explicit expansion, following standard NumPy broadcasting rules. Added regression test test_scalar_broadcast_2d() with a config known to trigger the bug (block_sizes=[2, 64], flatten_loops=[True]). --- helion/_compiler/inductor_lowering.py | 6 +- test/test_control_flow.expected | 20 +- test/test_examples.expected | 299 ++++++++++++-------------- test/test_misc.expected | 8 +- test/test_views.py | 24 +++ 5 files changed, 178 insertions(+), 179 deletions(-) diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index d223cbd51..f7f7695cf 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -358,8 +358,10 @@ def input_asts(self, ctx: LoweringContext, node: torch.fx.Node) -> list[ast.AST] def visit(n: torch.fx.Node) -> None: ast_val = cast("ast.AST", ctx.env[n]) if isinstance(fake_val := n.meta["val"], torch.Tensor): - if fake_val.ndim < ndim: - # Broadcast to force ranks to match + # Don't expand scalars (0-D tensors) - let Triton handle broadcasting naturally + # Expanding scalars with [None, None] creates incorrect broadcast shapes + if fake_val.ndim < ndim and fake_val.ndim > 0: + # Broadcast to force ranks to match (but only for non-scalar tensors) expand = ["None"] * (ndim - fake_val.ndim) + [":"] * fake_val.ndim ast_val = expr_from_string( "{tensor}[" + ", ".join(expand) + "]", tensor=ast_val diff --git a/test/test_control_flow.expected b/test/test_control_flow.expected index 9df928683..c1b10224b 100644 --- a/test/test_control_flow.expected +++ b/test/test_control_flow.expected @@ -102,20 +102,18 @@ def _helion_mul_relu_block_backward_kernel(x, y, dz, dx, dy, _BLOCK_SIZE_0: tl.c # src[test_control_flow.py:N]: relu_grad = torch.where(relu_mask, 1, 0) v_3 = tl.full([], 0, tl.int64) v_4 = tl.full([], 1, tl.int64) - v_5 = v_4[None, None] - v_6 = v_3[None, None] - v_7 = tl.where(v_2, v_5, v_6) + v_5 = tl.where(v_2, v_4, v_3) # src[test_control_flow.py:N]: dx[tile_i, tile_j] = dz_tile * relu_grad * y_tile[:, None] - v_8 = tl.cast(v_7, tl.float32) - v_9 = dz_tile * v_8 + v_6 = tl.cast(v_5, tl.float32) + v_7 = dz_tile * v_6 subscript_1 = y_tile[:, None] - v_10 = v_9 * subscript_1 - tl.store(dx + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_10, None) + v_8 = v_7 * subscript_1 + tl.store(dx + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_8, None) # src[test_control_flow.py:N]: local_dy_grad = torch.sum(dz_tile * relu_grad * x_tile, dim=1) - v_11 = tl.cast(v_7, tl.float32) - v_12 = dz_tile * v_11 - v_13 = v_12 * x_tile - local_dy_grad = tl.cast(tl.sum(v_13, 1), tl.float32) + v_9 = tl.cast(v_5, tl.float32) + v_10 = dz_tile * v_9 + v_11 = v_10 * x_tile + local_dy_grad = tl.cast(tl.sum(v_11, 1), tl.float32) # src[test_control_flow.py:N]: hl.atomic_add(dy, [tile_i], local_dy_grad) tl.atomic_add(dy + indices_0 * 1, local_dy_grad, mask=None, sem='relaxed') diff --git a/test/test_examples.expected b/test/test_examples.expected index 
fc2b53f4f..aa54e2ce6 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -1945,18 +1945,16 @@ def _helion_grouped_gemm_jagged(group_offsets, A_packed, B, out, A_packed_stride start_copy_0_copy_0_copy_0 = start_copy_0_copy_0_copy acc_copy_0 = acc_copy # src[grouped_gemm.py:N]: a_blk = A_packed[start + tile_m.index, tile_k] - v_3 = start_copy_0_copy_0_copy_0[None] - v_4 = v_3 + indices_1 - a_blk = tl.load(A_packed + (v_4[:, None] * A_packed_stride_0 + indices_3[None, :] * A_packed_stride_1), mask_1[:, None] & mask_3[None, :], other=0) + v_3 = start_copy_0_copy_0_copy_0 + indices_1 + a_blk = tl.load(A_packed + (v_3[:, None] * A_packed_stride_0 + indices_3[None, :] * A_packed_stride_1), mask_1[:, None] & mask_3[None, :], other=0) # src[grouped_gemm.py:N]: b_blk = B[tile_k, tile_n] b_blk = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_3[:, None] & mask_2[None, :], other=0) # src[grouped_gemm.py:N]: acc = torch.addmm(acc, a_blk, b_blk) acc = tl.dot(tl.cast(a_blk, tl.bfloat16), tl.cast(b_blk, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) # src[grouped_gemm.py:N]: out[start + tile_m.index, tile_n] = acc.to(out.dtype) - v_5 = tl.cast(acc, tl.bfloat16) - v_6 = start_copy_0_copy_0[None] - v_7 = v_6 + indices_1 - tl.store(out + (v_7[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_5, mask_1[:, None] & mask_2[None, :]) + v_4 = tl.cast(acc, tl.bfloat16) + v_5 = start_copy_0_copy_0 + indices_1 + tl.store(out + (v_5[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_4, mask_1[:, None] & mask_2[None, :]) def grouped_gemm_jagged(A_packed: torch.Tensor, B: torch.Tensor, group_offsets: torch.Tensor, *, _launcher=_default_launcher): """ @@ -2117,18 +2115,15 @@ def _helion_grouped_gemm_jagged_persistent(group_offsets, A_packed, B, out, A_pa v_29 = v_28 * _BLOCK_SIZE_1_ # src[grouped_gemm.py:N]: row_idx = base_row + hl.arange(BLOCK_M) iota = tl.arange(0, _BLOCK_SIZE_0) - v_30 = v_27[None] - v_31 = v_30 + iota + v_30 = v_27 + iota # src[grouped_gemm.py:N]: col_idx = base_col + hl.arange(BLOCK_N) iota_1 = tl.arange(0, _BLOCK_SIZE_1) - v_32 = v_29[None] - v_33 = v_32 + iota_1 + v_31 = v_29 + iota_1 # src[grouped_gemm.py:N]: rows_valid = row_idx < group_end - v_34 = group_end_copy_0_copy_0_copy_0[None] - v_35 = v_31 < v_34 + v_32 = v_30 < group_end_copy_0_copy_0_copy_0 # src[grouped_gemm.py:N]: cols_valid = col_idx < N - v_36 = tl.cast(N, tl.int32) - v_37 = v_33 < v_36 + v_33 = tl.cast(N, tl.int32) + v_34 = v_31 < v_33 # src[grouped_gemm.py:N]: acc = hl.zeros([BLOCK_M, BLOCK_N], dtype=torch.float32) acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) # src[grouped_gemm.py:N]: for k_tile in hl.tile(K): @@ -2137,43 +2132,43 @@ def _helion_grouped_gemm_jagged_persistent(group_offsets, A_packed, B, out, A_pa for offset_5 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_5): indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_5).to(tl.int32) mask_5 = indices_5 < K + v_30_copy = v_30 + v_32_copy = v_32 v_31_copy = v_31 - v_35_copy = v_35 - v_33_copy = v_33 - v_37_copy = v_37 + v_34_copy = v_34 acc_copy = acc + v_30_copy_0 = v_30_copy + v_32_copy_0 = v_32_copy v_31_copy_0 = v_31_copy - v_35_copy_0 = v_35_copy - v_33_copy_0 = v_33_copy - v_37_copy_0 = v_37_copy + v_34_copy_0 = v_34_copy acc_copy_0 = acc_copy # src[grouped_gemm.py:N]: extra_mask=rows_valid[:, None], - subscript = v_35_copy_0[:, None] + subscript = v_32_copy_0[:, None] # src[grouped_gemm.py:N]: a_blk = hl.load( # src[grouped_gemm.py:N]: 
A_packed, # src[grouped_gemm.py:N]: [row_idx, k_idx], # src[grouped_gemm.py:N-N]: ... - a_blk = tl.load(A_packed + (v_31_copy_0[:, None] * A_packed_stride_0 + indices_5[None, :] * A_packed_stride_1), mask_5[None, :] & subscript, other=0) + a_blk = tl.load(A_packed + (v_30_copy_0[:, None] * A_packed_stride_0 + indices_5[None, :] * A_packed_stride_1), mask_5[None, :] & subscript, other=0) # src[grouped_gemm.py:N]: extra_mask=cols_valid[None, :], - subscript_1 = v_37_copy_0[None, :] + subscript_1 = v_34_copy_0[None, :] # src[grouped_gemm.py:N]: b_blk = hl.load( # src[grouped_gemm.py:N]: B, # src[grouped_gemm.py:N]: [k_idx, col_idx], # src[grouped_gemm.py:N-N]: ... - b_blk = tl.load(B + (indices_5[:, None] * B_stride_0 + v_33_copy_0[None, :] * B_stride_1), mask_5[:, None] & subscript_1, other=0) + b_blk = tl.load(B + (indices_5[:, None] * B_stride_0 + v_31_copy_0[None, :] * B_stride_1), mask_5[:, None] & subscript_1, other=0) # src[grouped_gemm.py:N]: acc = torch.addmm(acc, a_blk, b_blk) acc = tl.dot(tl.cast(a_blk, tl.bfloat16), tl.cast(b_blk, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) # src[grouped_gemm.py:N]: valid_2d = rows_valid[:, None] & cols_valid[None, :] - subscript_2 = v_35[:, None] - subscript_3 = v_37[None, :] - v_38 = subscript_2 & subscript_3 + subscript_2 = v_32[:, None] + subscript_3 = v_34[None, :] + v_35 = subscript_2 & subscript_3 # src[grouped_gemm.py:N]: acc.to(out.dtype), - v_39 = tl.cast(acc, tl.bfloat16) + v_36 = tl.cast(acc, tl.bfloat16) # src[grouped_gemm.py:N]: hl.store( # src[grouped_gemm.py:N]: out, # src[grouped_gemm.py:N]: [row_idx, col_idx], # src[grouped_gemm.py:N-N]: ... - tl.store(out + (v_31[:, None] * out_stride_0 + v_33[None, :] * out_stride_1), v_39, v_38) + tl.store(out + (v_30[:, None] * out_stride_0 + v_31[None, :] * out_stride_1), v_36, v_35) def grouped_gemm_jagged_persistent(A_packed: torch.Tensor, B: torch.Tensor, group_offsets: torch.Tensor, *, _launcher=_default_launcher): """ @@ -2844,14 +2839,12 @@ def _helion__helion_jagged_attention_kernel(seq_offsets, q, k, v, out, max_seq_l v_0_copy_0 = v_0_copy starts_copy_0 = starts_copy # src[jagged_hstu_attn.py:N]: mask_q = tile_q.index < seq_len - v_2 = v_0_copy_0[None] - v_3 = tl.cast(v_2, tl.int32) - v_4 = indices_2 < v_3 + v_2 = tl.cast(v_0_copy_0, tl.int32) + v_3 = indices_2 < v_2 # src[jagged_hstu_attn.py:N]: q_blk = q[tile_q.index + starts, tile_h.begin, :] - v_5 = starts_copy_0[None] - v_6 = tl.cast(v_5, tl.int32) - v_7 = indices_2 + v_6 - q_blk = tl.load(q + (v_7[:, None] * 256 + offset_1 * 32 + indices_5[None, :] * 1), mask_2[:, None], other=0) + v_4 = tl.cast(starts_copy_0, tl.int32) + v_5 = indices_2 + v_4 + q_blk = tl.load(q + (v_5[:, None] * 256 + offset_1 * 32 + indices_5[None, :] * 1), mask_2[:, None], other=0) # src[jagged_hstu_attn.py:N]: acc = hl.zeros([tile_q, dimV], dtype=torch.float32) acc = tl.full([_BLOCK_SIZE_2, 32], 0.0, tl.float32) # src[jagged_hstu_attn.py:N]: for tile_kv in hl.tile(0, tile_q.end, block_size=None): @@ -2866,73 +2859,68 @@ def _helion__helion_jagged_attention_kernel(seq_offsets, q, k, v, out, max_seq_l v_0_copy_0_copy = v_0_copy_0 starts_copy_0_copy = starts_copy_0 q_blk_copy = q_blk - v_4_copy = v_4 + v_3_copy = v_3 acc_copy = acc v_0_copy_0_copy_0 = v_0_copy_0_copy starts_copy_0_copy_0 = starts_copy_0_copy q_blk_copy_0 = q_blk_copy - v_4_copy_0 = v_4_copy + v_3_copy_0 = v_3_copy acc_copy_0 = acc_copy # src[jagged_hstu_attn.py:N]: mask_kv = tile_kv.index < seq_len - v_8 = v_0_copy_0_copy_0[None] - v_9 = tl.cast(v_8, tl.int32) 
- v_10 = indices_3 < v_9 + v_6 = tl.cast(v_0_copy_0_copy_0, tl.int32) + v_7 = indices_3 < v_6 # src[jagged_hstu_attn.py:N]: k_blk = k[tile_kv.index + starts, tile_h.begin, :] - v_11 = starts_copy_0_copy_0[None] - v_12 = tl.cast(v_11, tl.int32) - v_13 = indices_3 + v_12 - k_blk = tl.load(k + (v_13[:, None] * 256 + offset_1 * 32 + indices_5[None, :] * 1), mask_4[:, None], other=0) + v_8 = tl.cast(starts_copy_0_copy_0, tl.int32) + v_9 = indices_3 + v_8 + k_blk = tl.load(k + (v_9[:, None] * 256 + offset_1 * 32 + indices_5[None, :] * 1), mask_4[:, None], other=0) # src[jagged_hstu_attn.py:N]: v_blk = v[tile_kv.index + starts, tile_h.begin, :] - v_14 = starts_copy_0_copy_0[None] - v_15 = tl.cast(v_14, tl.int32) - v_16 = indices_3 + v_15 - v_blk = tl.load(v + (v_16[:, None] * 256 + offset_1 * 32 + indices_5[None, :] * 1), mask_4[:, None], other=0) + v_10 = tl.cast(starts_copy_0_copy_0, tl.int32) + v_11 = indices_3 + v_10 + v_blk = tl.load(v + (v_11[:, None] * 256 + offset_1 * 32 + indices_5[None, :] * 1), mask_4[:, None], other=0) # src[jagged_hstu_attn.py:N]: torch.nn.functional.silu(torch.matmul(q_blk, k_blk.T) * alpha) permute = tl.permute(k_blk, [1, 0]) mm = tl.cast(tl.dot(tl.cast(q_blk_copy_0, tl.bfloat16), tl.cast(permute, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16) - v_17 = tl.cast(alpha, tl.bfloat16) - v_18 = mm * v_17 - v_19 = tl.cast(v_18, tl.float32) - v_20 = tl.sigmoid(tl.cast(v_19, tl.float32)) - v_21 = v_19 * v_20 - v_22 = tl.cast(v_21, tl.bfloat16) + v_12 = tl.cast(alpha, tl.bfloat16) + v_13 = mm * v_12 + v_14 = tl.cast(v_13, tl.float32) + v_15 = tl.sigmoid(tl.cast(v_14, tl.float32)) + v_16 = v_14 * v_15 + v_17 = tl.cast(v_16, tl.bfloat16) # src[jagged_hstu_attn.py:N]: torch.nn.functional.silu(torch.matmul(q_blk, k_blk.T) * alpha) # src[jagged_hstu_attn.py:N]: * scale - v_23 = tl.cast(scale, tl.bfloat16) - v_24 = v_22 * v_23 + v_18 = tl.cast(scale, tl.bfloat16) + v_19 = v_17 * v_18 # src[jagged_hstu_attn.py:N]: (tile_q.index.unsqueeze(1) > tile_kv.index.unsqueeze(0)) unsqueeze = indices_2[:, None] unsqueeze_1 = indices_3[None, :] - v_25 = unsqueeze > unsqueeze_1 + v_20 = unsqueeze > unsqueeze_1 # src[jagged_hstu_attn.py:N]: & mask_q[:, None] - subscript = v_4_copy_0[:, None] + subscript = v_3_copy_0[:, None] # src[jagged_hstu_attn.py:N]: (tile_q.index.unsqueeze(1) > tile_kv.index.unsqueeze(0)) # src[jagged_hstu_attn.py:N]: & mask_q[:, None] - v_26 = v_25 & subscript + v_21 = v_20 & subscript # src[jagged_hstu_attn.py:N]: & mask_kv[None, :], - subscript_1 = v_10[None, :] + subscript_1 = v_7[None, :] # src[jagged_hstu_attn.py:N]: (tile_q.index.unsqueeze(1) > tile_kv.index.unsqueeze(0)) # src[jagged_hstu_attn.py:N]: & mask_q[:, None] # src[jagged_hstu_attn.py:N]: & mask_kv[None, :], - v_27 = v_26 & subscript_1 + v_22 = v_21 & subscript_1 # src[jagged_hstu_attn.py:N]: scores = torch.where( # src[jagged_hstu_attn.py:N]: (tile_q.index.unsqueeze(1) > tile_kv.index.unsqueeze(0)) # src[jagged_hstu_attn.py:N]: & mask_q[:, None] # src[jagged_hstu_attn.py:N-N]: ... 
- v_28 = 0.0 - v_29 = v_28[None, None] - v_30 = tl.where(v_27, v_24, v_29) + v_23 = 0.0 + v_24 = tl.where(v_22, v_19, v_23) # src[jagged_hstu_attn.py:N]: acc += torch.matmul(scores.to(v.dtype), v_blk) - _mask_to_2 = tl.where(mask_2[:, None] & mask_4[None, :], v_30, tl.full([], 0, tl.bfloat16)) + _mask_to_2 = tl.where(mask_2[:, None] & mask_4[None, :], v_24, tl.full([], 0, tl.bfloat16)) mm_1 = tl.cast(tl.dot(tl.cast(_mask_to_2, tl.bfloat16), tl.cast(v_blk, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16) - v_31 = tl.cast(mm_1, tl.float32) - acc = acc_copy_0 + v_31 + v_25 = tl.cast(mm_1, tl.float32) + acc = acc_copy_0 + v_25 # src[jagged_hstu_attn.py:N]: out[tile_q.index + starts, tile_h.begin, :] = acc.to(out.dtype) - v_33 = tl.cast(acc, tl.bfloat16) - v_34 = starts_copy_0[None] - v_35 = tl.cast(v_34, tl.int32) - v_36 = indices_2 + v_35 - tl.store(out + (v_36[:, None] * 256 + offset_1 * 32 + indices_5[None, :] * 1), v_33, mask_2[:, None]) + v_27 = tl.cast(acc, tl.bfloat16) + v_28 = tl.cast(starts_copy_0, tl.int32) + v_29 = indices_2 + v_28 + tl.store(out + (v_29[:, None] * 256 + offset_1 * 32 + indices_5[None, :] * 1), v_27, mask_2[:, None]) def _helion_jagged_attention_kernel(max_seq_len: int, alpha: float, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_offsets: torch.Tensor, *, _launcher=_default_launcher): """Helion implementation of HSTU jagged attention""" @@ -3116,23 +3104,22 @@ def _helion_jagged_layer_norm_kernel(x_offsets, x_flat, out_flat, eps, _BLOCK_SI # src[jagged_layer_norm.py:N]: x_slice.to(torch.float32) - mean_acc[:, None, None], # src[jagged_layer_norm.py:N-N]: ... v_26 = 0.0 - v_27 = v_26[None, None, None] - v_28 = tl.where(combined_mask_1, v_25, v_27) + v_27 = tl.where(combined_mask_1, v_25, v_26) # src[jagged_layer_norm.py:N]: var_sums = var_sums + (centered * centered).sum(dim=1) - v_29 = v_28 * v_28 - _mask_to = tl.where(tl.broadcast_to(mask_4[None, :, None], [_BLOCK_SIZE_0, _BLOCK_SIZE_4, _BLOCK_SIZE_3]), v_29, tl.full([], 0, tl.float32)) + v_28 = v_27 * v_27 + _mask_to = tl.where(tl.broadcast_to(mask_4[None, :, None], [_BLOCK_SIZE_0, _BLOCK_SIZE_4, _BLOCK_SIZE_3]), v_28, tl.full([], 0, tl.float32)) sum_3 = tl.cast(tl.sum(_mask_to, 1), tl.float32) var_sums = var_sums_copy_0 + sum_3 # src[jagged_layer_norm.py:N]: var_acc = var_acc + var_sums.sum(dim=1) sum_4 = tl.cast(tl.sum(var_sums, 1), tl.float32) var_acc = var_acc_copy_0 + sum_4 # src[jagged_layer_norm.py:N]: variance = var_acc / (seq_lengths_float * M) - v_32 = 8.0 - v_33 = v_13 * v_32 - v_34 = var_acc / v_33 + v_31 = 8.0 + v_32 = v_13 * v_31 + v_33 = var_acc / v_32 # src[jagged_layer_norm.py:N]: rstd = torch.rsqrt(variance + eps) - v_35 = v_34 + eps - v_36 = tl.rsqrt(v_35) + v_34 = v_33 + eps + v_35 = tl.rsqrt(v_34) # src[jagged_layer_norm.py:N]: for tile_m in hl.tile(M): # src[jagged_layer_norm.py:N]: for tile_k in hl.tile(0, max_seq_len): # src[jagged_layer_norm.py:N]: # Compute indices into x_values @@ -3143,12 +3130,12 @@ def _helion_jagged_layer_norm_kernel(x_offsets, x_flat, out_flat, eps, _BLOCK_SI starts_copy_2 = starts v_2_copy_2 = v_2 v_16_copy_1 = v_16 - v_36_copy = v_36 + v_35_copy = v_35 max_seq_len_copy_2_0 = max_seq_len_copy_2 starts_copy_2_0 = starts_copy_2 v_2_copy_2_0 = v_2_copy_2 v_16_copy_1_0 = v_16_copy_1 - v_36_copy_0 = v_36_copy + v_35_copy_0 = v_35_copy # src[jagged_layer_norm.py:N]: for tile_k in hl.tile(0, max_seq_len): # src[jagged_layer_norm.py:N]: # Compute indices into x_values # src[jagged_layer_norm.py:N]: indices = starts[:, None] + 
tile_k.index[None, :] @@ -3159,55 +3146,54 @@ def _helion_jagged_layer_norm_kernel(x_offsets, x_flat, out_flat, eps, _BLOCK_SI starts_copy_2_0_copy = starts_copy_2_0 v_2_copy_2_0_copy = v_2_copy_2_0 v_16_copy_1_0_copy = v_16_copy_1_0 - v_36_copy_0_copy = v_36_copy_0 + v_35_copy_0_copy = v_35_copy_0 starts_copy_2_0_copy_0 = starts_copy_2_0_copy v_2_copy_2_0_copy_0 = v_2_copy_2_0_copy v_16_copy_1_0_copy_0 = v_16_copy_1_0_copy - v_36_copy_0_copy_0 = v_36_copy_0_copy + v_35_copy_0_copy_0 = v_35_copy_0_copy # src[jagged_layer_norm.py:N]: indices = starts[:, None] + tile_k.index[None, :] subscript_13 = starts_copy_2_0_copy_0[:, None] subscript_14 = indices_6[None, :] - v_37 = tl.cast(subscript_14, tl.int64) - v_38 = subscript_13 + v_37 + v_36 = tl.cast(subscript_14, tl.int64) + v_37 = subscript_13 + v_36 # src[jagged_layer_norm.py:N]: flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :] - subscript_15 = v_38[:, :, None] - v_39 = tl.full([], 8, tl.int64) - v_40 = tl.cast(subscript_15 * v_39, tl.int64) + subscript_15 = v_37[:, :, None] + v_38 = tl.full([], 8, tl.int64) + v_39 = tl.cast(subscript_15 * v_38, tl.int64) subscript_16 = indices_5[None, None, :] - v_41 = tl.cast(subscript_16, tl.int64) - v_42 = v_40 + v_41 + v_40 = tl.cast(subscript_16, tl.int64) + v_41 = v_39 + v_40 # src[jagged_layer_norm.py:N]: row_mask = tile_k.index[None, :] < seq_lengths[:, None] subscript_17 = indices_6[None, :] subscript_18 = v_2_copy_2_0_copy_0[:, None] - v_43 = tl.cast(subscript_17, tl.int64) - v_44 = v_43 < subscript_18 + v_42 = tl.cast(subscript_17, tl.int64) + v_43 = v_42 < subscript_18 # src[jagged_layer_norm.py:N]: combined_mask = row_mask[:, :, None] - combined_mask_2 = v_44[:, :, None] + combined_mask_2 = v_43[:, :, None] # src[jagged_layer_norm.py:N]: x_slice = hl.load( # src[jagged_layer_norm.py:N]: x_flat, # src[jagged_layer_norm.py:N]: [flat_indices], # src[jagged_layer_norm.py:N-N]: ... - x_slice_2 = tl.load(x_flat + v_42 * 1, mask_6[None, :, None] & combined_mask_2, other=0) + x_slice_2 = tl.load(x_flat + v_41 * 1, mask_6[None, :, None] & combined_mask_2, other=0) # src[jagged_layer_norm.py:N]: (x_slice.to(torch.float32) - mean_acc[:, None, None]) subscript_19 = v_16_copy_1_0_copy_0[:, None, None] - v_45 = x_slice_2 - subscript_19 + v_44 = x_slice_2 - subscript_19 # src[jagged_layer_norm.py:N]: * rstd[:, None, None], - subscript_20 = v_36_copy_0_copy_0[:, None, None] + subscript_20 = v_35_copy_0_copy_0[:, None, None] # src[jagged_layer_norm.py:N]: (x_slice.to(torch.float32) - mean_acc[:, None, None]) # src[jagged_layer_norm.py:N]: * rstd[:, None, None], - v_46 = v_45 * subscript_20 + v_45 = v_44 * subscript_20 # src[jagged_layer_norm.py:N]: normalized = torch.where( # src[jagged_layer_norm.py:N]: combined_mask, # src[jagged_layer_norm.py:N]: (x_slice.to(torch.float32) - mean_acc[:, None, None]) # src[jagged_layer_norm.py:N-N]: ... - v_47 = 0.0 - v_48 = v_47[None, None, None] - v_49 = tl.where(combined_mask_2, v_46, v_48) + v_46 = 0.0 + v_47 = tl.where(combined_mask_2, v_45, v_46) # src[jagged_layer_norm.py:N]: hl.store( # src[jagged_layer_norm.py:N]: out_flat, # src[jagged_layer_norm.py:N]: [flat_indices], # src[jagged_layer_norm.py:N-N]: ... 
- tl.store(out_flat + v_42 * 1, v_49, mask_6[None, :, None] & combined_mask_2) + tl.store(out_flat + v_41 * 1, v_47, mask_6[None, :, None] & combined_mask_2) def jagged_layer_norm_kernel(x_values: torch.Tensor, x_offsets: torch.Tensor, eps: float=1e-06, *, _launcher=_default_launcher): """ @@ -3375,13 +3361,11 @@ def _helion_jagged_mean_kernel(x_offsets, x_feature_counts, x_flat, out, out_str v_16 = nnz_expanded > v_15 v_17 = row_sums / nnz_expanded v_18 = 0.0 - v_19 = v_18[None, None] - v_20 = tl.where(v_16, v_17, v_19) + v_19 = tl.where(v_16, v_17, v_18) # src[jagged_mean.py:N]: out[tile_b, tile_m] = torch.where(feature_valid, result, 0.0) - v_21 = 0.0 - v_22 = v_21[None, None] - v_23 = tl.where(v_4, v_20, v_22) - tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * 1), v_23, mask_1[None, :]) + v_20 = 0.0 + v_21 = tl.where(v_4, v_19, v_20) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * 1), v_21, mask_1[None, :]) def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_counts: torch.Tensor, max_M: int, *, _launcher=_default_launcher): """ @@ -3512,37 +3496,35 @@ def _helion_jagged_softmax_kernel(x_offsets, x_flat, out, _BLOCK_SIZE_0: tl.cons x_slice = tl.load(x_flat + v_8 * 1, mask_2[None, :, None] & v_13, other=0) # src[jagged_softmax.py:N]: slice_max = torch.where(combined_mask, x_slice, float("-inf")).amax( v_14 = float('-inf') - v_15 = v_14[None, None, None] - v_16 = tl.where(v_13, x_slice, v_15) + v_15 = tl.where(v_13, x_slice, v_14) # src[jagged_softmax.py:N]: slice_max = torch.where(combined_mask, x_slice, float("-inf")).amax( # src[jagged_softmax.py:N]: dim=1 # src[jagged_softmax.py:N]: ) - _mask_to = tl.where(tl.broadcast_to(mask_2[None, :, None], [_BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1]), v_16, tl.full([], float('-inf'), tl.float32)) + _mask_to = tl.where(tl.broadcast_to(mask_2[None, :, None], [_BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1]), v_15, tl.full([], float('-inf'), tl.float32)) slice_max = tl.cast(tl.max(_mask_to, 1), tl.float32) # src[jagged_softmax.py:N]: block_new_max = torch.maximum(block_max, slice_max) block_new_max = triton_helpers.maximum(block_max_copy_0, slice_max) # src[jagged_softmax.py:N]: block_L *= torch.exp(block_max - block_new_max) - v_18 = block_max_copy_0 - block_new_max - v_19 = libdevice.exp(v_18) - v_20 = block_L_copy_0 * v_19 + v_17 = block_max_copy_0 - block_new_max + v_18 = libdevice.exp(v_17) + v_19 = block_L_copy_0 * v_18 # src[jagged_softmax.py:N]: x_slice - block_new_max[:, None, :], subscript_8 = block_new_max[:, None, :] - v_21 = x_slice - subscript_8 + v_20 = x_slice - subscript_8 # src[jagged_softmax.py:N]: torch.where( # src[jagged_softmax.py:N]: combined_mask, # src[jagged_softmax.py:N]: x_slice - block_new_max[:, None, :], # src[jagged_softmax.py:N-N]: ... - v_22 = float('-inf') - v_23 = v_22[None, None, None] - v_24 = tl.where(v_13, v_21, v_23) + v_21 = float('-inf') + v_22 = tl.where(v_13, v_20, v_21) # src[jagged_softmax.py:N]: block_L += torch.exp( # src[jagged_softmax.py:N]: torch.where( # src[jagged_softmax.py:N]: combined_mask, # src[jagged_softmax.py:N-N]: ... 
- v_25 = libdevice.exp(v_24) - _mask_to_1 = tl.where(tl.broadcast_to(mask_2[None, :, None], [_BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1]), v_25, tl.full([], 0, tl.float32)) + v_23 = libdevice.exp(v_22) + _mask_to_1 = tl.where(tl.broadcast_to(mask_2[None, :, None], [_BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1]), v_23, tl.full([], 0, tl.float32)) sum_1 = tl.cast(tl.sum(_mask_to_1, 1), tl.float32) - block_L = v_20 + sum_1 + block_L = v_19 + sum_1 # src[jagged_softmax.py:N]: block_max = block_new_max block_max = block_new_max # src[jagged_softmax.py:N]: for tile_k in hl.tile(max_seqlen): @@ -3563,42 +3545,42 @@ def _helion_jagged_softmax_kernel(x_offsets, x_flat, out, _BLOCK_SIZE_0: tl.cons # src[jagged_softmax.py:N]: base_indices = starts[:, None] + tile_k.index[None, :] subscript_9 = starts_copy_0_copy_1_0[:, None] subscript_10 = indices_3[None, :] - v_27 = tl.cast(subscript_10, tl.int64) - v_28 = subscript_9 + v_27 + v_25 = tl.cast(subscript_10, tl.int64) + v_26 = subscript_9 + v_25 # src[jagged_softmax.py:N]: base_indices[:, :, None] * M + tile_m.index[None, None, :] - subscript_11 = v_28[:, :, None] - v_29 = tl.full([], 8, tl.int64) - v_30 = tl.cast(subscript_11 * v_29, tl.int64) + subscript_11 = v_26[:, :, None] + v_27 = tl.full([], 8, tl.int64) + v_28 = tl.cast(subscript_11 * v_27, tl.int64) subscript_12 = indices_1[None, None, :] - v_31 = tl.cast(subscript_12, tl.int64) - v_32 = v_30 + v_31 + v_29 = tl.cast(subscript_12, tl.int64) + v_30 = v_28 + v_29 # src[jagged_softmax.py:N]: row_mask = tile_k.index[None, :] < seqlens[:, None] subscript_13 = indices_3[None, :] subscript_14 = v_2_copy_0_copy_1_0[:, None] - v_33 = tl.cast(subscript_13, tl.int64) - v_34 = v_33 < subscript_14 + v_31 = tl.cast(subscript_13, tl.int64) + v_32 = v_31 < subscript_14 # src[jagged_softmax.py:N]: combined_mask = row_mask[:, :, None] & (tile_m.index < M)[None, None, :] - subscript_15 = v_34[:, :, None] - v_35 = tl.full([], 8, tl.int32) - v_36 = indices_1 < v_35 - subscript_16 = v_36[None, None, :] - v_37 = subscript_15 & subscript_16 + subscript_15 = v_32[:, :, None] + v_33 = tl.full([], 8, tl.int32) + v_34 = indices_1 < v_33 + subscript_16 = v_34[None, None, :] + v_35 = subscript_15 & subscript_16 # src[jagged_softmax.py:N]: x_slice = hl.load( # src[jagged_softmax.py:N]: x_flat, # src[jagged_softmax.py:N]: [flat_indices], # src[jagged_softmax.py:N-N]: ... - x_slice_1 = tl.load(x_flat + v_32 * 1, mask_3[None, :, None] & v_37, other=0) + x_slice_1 = tl.load(x_flat + v_30 * 1, mask_3[None, :, None] & v_35, other=0) # src[jagged_softmax.py:N]: torch.exp(x_slice - block_max[:, None, :]) / block_L[:, None, :] subscript_17 = block_max_copy_1_0[:, None, :] - v_38 = x_slice_1 - subscript_17 - v_39 = libdevice.exp(v_38) + v_36 = x_slice_1 - subscript_17 + v_37 = libdevice.exp(v_36) subscript_18 = block_L_copy_1_0[:, None, :] - v_40 = v_39 / subscript_18 + v_38 = v_37 / subscript_18 # src[jagged_softmax.py:N]: hl.store( # src[jagged_softmax.py:N]: out, # src[jagged_softmax.py:N]: [flat_indices], # src[jagged_softmax.py:N-N]: ... 
- tl.store(out + v_32 * 1, v_40, mask_3[None, :, None] & v_37) + tl.store(out + v_30 * 1, v_38, mask_3[None, :, None] & v_35) def jagged_softmax_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launcher=_default_launcher): """ @@ -4587,10 +4569,9 @@ def _helion_low_mem_dropout(x_flat, out_flat, out_flat_stride_0, x_flat_stride_0 v_1 = xi * scale # src[low_mem_dropout.py:N]: yi = torch.where(keep, yscaled, 0.0) v_2 = 0.0 - v_3 = v_2[None] - v_4 = tl.where(v_0, v_1, v_3) + v_3 = tl.where(v_0, v_1, v_2) # src[low_mem_dropout.py:N]: out_flat[tidx] = yi.to(x.dtype) - tl.store(out_flat + indices_0 * out_flat_stride_0, v_4, mask_0) + tl.store(out_flat + indices_0 * out_flat_stride_0, v_3, mask_0) def low_mem_dropout(p: float, x: torch.Tensor, seed: int, *, _launcher=_default_launcher): """ @@ -5165,19 +5146,16 @@ def _helion_moe_matmul_ogs(expert_token_offsets, expert_token_counts, sorted_to_ num_tokens_copy_0_copy_0 = num_tokens_copy_0_copy start_copy_0_copy_0 = start_copy_0_copy # src[moe_matmul_ogs.py:N]: token_valid = local_token_offsets < num_tokens - v_2 = num_tokens_copy_0_copy_0[None] - v_3 = indices_1 < v_2 + v_2 = indices_1 < num_tokens_copy_0_copy_0 # src[moe_matmul_ogs.py:N]: local_token_offsets_valid = torch.where( # src[moe_matmul_ogs.py:N]: token_valid, local_token_offsets, 0 # src[moe_matmul_ogs.py:N]: ) - v_4 = tl.full([], 0, tl.int32) - v_5 = v_4[None] - v_6 = tl.where(v_3, indices_1, v_5) + v_3 = tl.full([], 0, tl.int32) + v_4 = tl.where(v_2, indices_1, v_3) # src[moe_matmul_ogs.py:N]: expert_sorted_token_indices = start + local_token_offsets_valid - v_7 = start_copy_0_copy_0[None] - v_8 = v_7 + v_6 + v_5 = start_copy_0_copy_0 + v_4 # src[moe_matmul_ogs.py:N]: expert_sorted_token_indices.squeeze(0) - squeeze = tl.reshape(v_8, [_BLOCK_SIZE_1]) + squeeze = tl.reshape(v_5, [_BLOCK_SIZE_1]) # src[moe_matmul_ogs.py:N]: expert_orig_token_indices = sorted_to_orig_token_idx[ # src[moe_matmul_ogs.py:N]: expert_sorted_token_indices.squeeze(0) # src[moe_matmul_ogs.py:N]: ] @@ -5204,15 +5182,15 @@ def _helion_moe_matmul_ogs(expert_token_offsets, expert_token_counts, sorted_to_ # src[moe_matmul_ogs.py:N]: existing_values = C[expert_orig_token_indices, tile_n] existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0) # src[moe_matmul_ogs.py:N]: mask_2d = token_valid.view(block_T, 1).expand(block_T, block_N) - view = tl.reshape(v_3, [_BLOCK_SIZE_1, 1]) + view = tl.reshape(v_2, [_BLOCK_SIZE_1, 1]) mask_2d = tl.broadcast_to(view, [_BLOCK_SIZE_1, _BLOCK_SIZE_2]) # src[moe_matmul_ogs.py:N]: mask_2d, acc.to(C.dtype), existing_values - v_9 = tl.cast(acc, tl.float16) + v_6 = tl.cast(acc, tl.float16) # src[moe_matmul_ogs.py:N]: C[expert_orig_token_indices, tile_n] = torch.where( # src[moe_matmul_ogs.py:N]: mask_2d, acc.to(C.dtype), existing_values # src[moe_matmul_ogs.py:N]: ) - v_10 = tl.where(mask_2d, v_9, existing_values) - tl.store(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_10, mask_1[:, None] & mask_2[None, :]) + v_7 = tl.where(mask_2d, v_6, existing_values) + tl.store(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_7, mask_1[:, None] & mask_2[None, :]) def moe_matmul_ogs(A: torch.Tensor, W: torch.Tensor, expert_token_counts: torch.Tensor, expert_token_offsets: torch.Tensor, sorted_to_orig_token_idx: torch.Tensor, max_T_per_expert: int, *, _launcher=_default_launcher): """ @@ -5514,10 +5492,9 @@ 
def _helion_segmented_reduction_helion(input_data, indices, output, _BLOCK_SIZE_ # src[segment_reduction.py:N]: segment_vals = torch.where(mask.unsqueeze(1), out_vals, 0.0) unsqueeze_1 = v_18[:, None] v_19 = 0.0 - v_20 = v_19[None, None] - v_21 = tl.where(unsqueeze_1, out_vals, v_20) + v_20 = tl.where(unsqueeze_1, out_vals, v_19) # src[segment_reduction.py:N]: hl.atomic_add(output, [idxs, tile_f], segment_vals) - tl.atomic_add(output + (idxs[:, None] * 32 + indices_1[None, :] * 1), v_21, mask=mask_0[:, None], sem='relaxed') + tl.atomic_add(output + (idxs[:, None] * 32 + indices_1[None, :] * 1), v_20, mask=mask_0[:, None], sem='relaxed') def segmented_reduction_helion(indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int, *, _launcher=_default_launcher): """ diff --git a/test/test_misc.expected b/test/test_misc.expected index f8109a64a..c89b6923b 100644 --- a/test/test_misc.expected +++ b/test/test_misc.expected @@ -287,11 +287,9 @@ def _helion_test_tile_block_size_usage(out, _BLOCK_SIZE_0: tl.constexpr): # src[test_misc.py:N]: out[tile] = torch.where(mask, 1, 0) v_12 = tl.full([], 0, tl.int64) v_13 = tl.full([], 1, tl.int64) - v_14 = v_13[None] - v_15 = v_12[None] - v_16 = tl.where(v_11, v_14, v_15) - v_17 = tl.cast(v_16, tl.int32) - tl.store(out + indices_0 * 1, v_17, None) + v_14 = tl.where(v_11, v_13, v_12) + v_15 = tl.cast(v_14, tl.int32) + tl.store(out + indices_0 * 1, v_15, None) def test_tile_block_size_usage(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_misc.py:N]: out = torch.zeros_like(x, dtype=torch.int32) diff --git a/test/test_views.py b/test/test_views.py index 7b0c1a27e..d91c7a6da 100644 --- a/test/test_views.py +++ b/test/test_views.py @@ -246,6 +246,30 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected) self.assertIn("tl.join", code) + def test_scalar_broadcast_2d(self): + """Test that scalars broadcast correctly with 2D tensors.""" + + @helion.kernel( + config=helion.Config( + block_sizes=[2, 64], + flatten_loops=[True], + indexing=["pointer", "pointer", "tensor_descriptor"], + ) + ) + def scalar_multiply(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + m, n = x.shape + out = torch.empty_like(x) + for tile_idx in hl.tile(out.shape): + scale_val = hl.load(scale, [0]) + out[tile_idx] = x[tile_idx] * scale_val + return out + + input_tensor = torch.randn([4, 128], device=DEVICE) + scale_tensor = torch.tensor([2.0], device=DEVICE) + result = scalar_multiply(input_tensor, scale_tensor) + expected = input_tensor * scale_tensor[0] + torch.testing.assert_close(result, expected) + def test_reshape_input_types(self): @helion.kernel(static_shapes=True) def reshape_reduction_dim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
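
For context, a minimal standalone sketch of the guarded rank-expansion rule added in
input_asts() above. The pad_rank_expr helper and its string-in/string-out signature are
hypothetical illustrations only; the real code builds AST nodes via expr_from_string and
keys off fake_val.ndim, it does not return strings.

    def pad_rank_expr(tensor_expr: str, fake_ndim: int, target_ndim: int) -> str:
        # Only pad tensors that already carry at least one dimension; 0-D scalars
        # are left untouched so Triton's implicit NumPy-style broadcasting applies.
        if 0 < fake_ndim < target_ndim:
            expand = ["None"] * (target_ndim - fake_ndim) + [":"] * fake_ndim
            return f"{tensor_expr}[{', '.join(expand)}]"
        return tensor_expr

    # pad_rank_expr("y_tile", 1, 2)    -> "y_tile[None, :]"   (rank still padded)
    # pad_rank_expr("scale_val", 0, 2) -> "scale_val"         (scalar left alone)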