From bce71664dc526d35b1e2b8eb8878061cc64f7d85 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Fri, 3 Oct 2025 16:25:27 -0700
Subject: [PATCH 1/2] Add packed-RHS matmul test and update expected journals

Add a matmul test where the RHS is stored packed (B holds K/2 rows that
are unpacked on-device via stack+reshape before the dot), and update the
expected journals affected by the revised codegen.

---
 test/test_constexpr.expected |  8 +++---
 test/test_examples.expected  | 25 +++++++++--------
 test/test_loops.expected     | 29 +++++++++++---------
 test/test_loops.py           |  3 +--
 test/test_matmul.expected    | 52 ++++++++++++++++++++++++++++++++++++
 test/test_matmul.py          | 36 +++++++++++++++++++++++++
 6 files changed, 121 insertions(+), 32 deletions(-)

diff --git a/test/test_constexpr.expected b/test/test_constexpr.expected
index 1a3c225c5..29a2470e0 100644
--- a/test/test_constexpr.expected
+++ b/test/test_constexpr.expected
@@ -31,11 +31,11 @@ def _helion_matmul_int4_block_expr(B, A, C, _NUM_SM: tl.constexpr, _BLOCK_SIZE_2
     offset_2 = pid_1 * _BLOCK_SIZE_2
     indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
     acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
-    for offset_3 in tl.range(0, 16, loop_unroll_factor=4, num_stages=1, disallow_acc_multi_buffer=True, flatten=True):
-        indices_3 = offset_3 + tl.arange(0, 1).to(tl.int32)
+    for offset_0 in tl.range(0, 16, loop_unroll_factor=4, num_stages=1, disallow_acc_multi_buffer=True, flatten=True):
+        indices_0 = offset_0 + tl.arange(0, 1).to(tl.int32)
         acc_copy = acc
         acc_copy_0 = acc_copy
-        packed = tl.load(B + (indices_3[:, None] * 16 + indices_2[None, :] * 1), None)
+        packed = tl.load(B + (indices_0[:, None] * 16 + indices_2[None, :] * 1), None)
         v_0 = tl.full([], 4, tl.int8)
         v_1 = packed << v_0
         v_2 = tl.full([], 4, tl.int8)
@@ -54,7 +54,7 @@ def _helion_matmul_int4_block_expr(B, A, C, _NUM_SM: tl.constexpr, _BLOCK_SIZE_2
         mask_1 = broadcast_idx == 1
         stacked_result = tl.where(mask_1, expanded_1, stacked_result)
         unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
-        mul_5 = 2 * offset_3
+        mul_5 = 2 * offset_0
         iota = mul_5 + tl.arange(0, mul)
         a_tile = tl.load(A + (indices_1[:, None] * 32 + iota[None, :] * 1), None)
         dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
diff --git a/test/test_examples.expected b/test/test_examples.expected
index d7d69b397..b36008989 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -1353,12 +1353,12 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri
     mask_2 = indices_2 < N
     acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
     floordiv = triton_helpers.div_floor_integer(K, 2)
-    for offset_3 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
-        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
-        mask_0 = indices_3 < floordiv
+    for offset_0 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
+        indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+        mask_0 = indices_0 < floordiv
         acc_copy = acc
         acc_copy_0 = acc_copy
-        b_tile = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        b_tile = tl.load(B + (indices_0[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
         v_0 = tl.full([], 4, tl.int8)
         v_1 = b_tile << v_0
         v_2 = tl.full([], 4, tl.int8)
@@ -1372,12 +1372,12 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri
         expanded_0 = tl.expand_dims(v_6, 1)
         expanded_1 = tl.expand_dims(v_7, 1)
         stacked_result = tl.zeros_like(expanded_0)
-        mask_4 = broadcast_idx == 0
-        stacked_result = tl.where(mask_4, expanded_0, stacked_result)
-        mask_5 = broadcast_idx == 1
-        stacked_result = tl.where(mask_5, expanded_1, stacked_result)
+        mask_3 = broadcast_idx == 0
+        stacked_result = tl.where(mask_3, expanded_0, stacked_result)
+        mask_4 = broadcast_idx == 1
+        stacked_result = tl.where(mask_4, expanded_1, stacked_result)
         b_unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
-        mul_5 = 2 * offset_3
+        mul_5 = 2 * offset_0
         iota = mul_5 + tl.arange(0, mul)
         a_tile = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
         dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(b_unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
@@ -1406,7 +1406,6 @@ def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
     C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
     _BLOCK_SIZE_1 = 64
     _BLOCK_SIZE_2 = 32
-    _RDIM_SIZE_3 = triton.next_power_of_2(K)
     _BLOCK_SIZE_0 = 64
     _launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), B, A, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return C
@@ -3124,7 +3123,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_rms_norm_bwd(x, grad_out, rsqrt, weight, grad_x, grad_weight, x_size_0, grad_out_stride_0, grad_out_stride_1, grad_weight_stride_1, grad_x_stride_0, grad_x_stride_1, rsqrt_stride_0, weight_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
+def _helion_rms_norm_bwd(x, grad_out, rsqrt, weight, grad_x, grad_weight, x_size_0, grad_out_stride_0, grad_out_stride_1, grad_weight_stride_0, grad_weight_stride_1, grad_x_stride_0, grad_x_stride_1, rsqrt_stride_0, weight_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
@@ -3162,7 +3161,7 @@ def _helion_rms_norm_bwd(x, grad_out, rsqrt, weight, grad_x, grad_weight, x_size
     v_18 = tl.cast(v_17, tl.float16)
     tl.store(grad_x + (indices_2[:, None] * grad_x_stride_0 + indices_3[None, :] * grad_x_stride_1), v_18, None)
     tile_id = offset_0 // _BLOCK_SIZE_0
-    tl.store(grad_weight + indices_3 * grad_weight_stride_1, grad_w_m, None)
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), grad_w_m, None)
 
 def rms_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor, rsqrt: torch.Tensor, *, _launcher=_default_launcher):
     """
@@ -3187,7 +3186,7 @@ def rms_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor,
     grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, *weight.shape], dtype=torch.float32)
     _BLOCK_SIZE_0 = 32
     _RDIM_SIZE_2 = 64
-    _launcher(_helion_rms_norm_bwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_out, rsqrt, weight, grad_x, grad_weight, x.size(0), grad_out.stride(0), grad_out.stride(1), grad_weight.stride(1), grad_x.stride(0), grad_x.stride(1), rsqrt.stride(0), weight.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=3)
+    _launcher(_helion_rms_norm_bwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_out, rsqrt, weight, grad_x, grad_weight, x.size(0), grad_out.stride(0), grad_out.stride(1), grad_weight.stride(0), grad_weight.stride(1), grad_x.stride(0), grad_x.stride(1), rsqrt.stride(0), weight.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=3)
     return (grad_x, grad_weight.sum(0).to(weight.dtype))
 
 --- assertExpectedJournal(TestExamples.test_rms_norm_bwd_dw)
diff --git a/test/test_loops.expected b/test/test_loops.expected
index cdfa4523c..d4c9ac063 100644
--- a/test/test_loops.expected
+++ b/test/test_loops.expected
@@ -982,20 +982,22 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_kernel_fixed_block_size(loss_sum, y_true, kl_loss, loss, loss_sum_stride_0, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+def _helion_kernel_fixed_block_size(loss_sum, y_true, kl_loss, loss, loss_sum_stride_0, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _RDIM_SIZE_3: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_1 = pid_0 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
-    indices_4 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
-    full = tl.full([64, 64], 0.0, tl.float32)
-    tl.store(loss_sum + (indices_4[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), full, None)
-    for offset_2 in tl.range(0, 128, _BLOCK_SIZE_3):
-        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
-        y_true_val = tl.load(y_true + (indices_1[:, None] * 128 + indices_2[None, :] * 1), None)
-        tl.store(kl_loss + (indices_1[:, None] * 128 + indices_2[None, :] * 1), y_true_val, None)
-        load_1 = tl.load(kl_loss + (indices_1[:, None] * 128 + indices_2[None, :] * 1), None)
-        tl.atomic_add(loss_sum + (indices_1[:, None] * loss_sum_stride_0 + indices_2[None, :] * 1), load_1, mask=None, sem='relaxed')
-    load = tl.load(loss_sum + (indices_4[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), None)
+    indices_5 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    indices_6 = tl.arange(0, _RDIM_SIZE_3).to(tl.int32)
+    mask_3 = indices_6 < _BLOCK_SIZE_0
+    full = tl.full([64, _BLOCK_SIZE_0], 0.0, tl.float32)
+    tl.store(loss_sum + (indices_5[:, None] * loss_sum_stride_0 + indices_6[None, :] * 1), full, mask_3[None, :])
+    for offset_4 in tl.range(0, 128, _BLOCK_SIZE_0):
+        indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+        y_true_val = tl.load(y_true + (indices_1[:, None] * 128 + indices_4[None, :] * 1), None)
+        tl.store(kl_loss + (indices_1[:, None] * 128 + indices_4[None, :] * 1), y_true_val, None)
+        load_1 = tl.load(kl_loss + (indices_1[:, None] * 128 + indices_4[None, :] * 1), None)
+        tl.atomic_add(loss_sum + (indices_1[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), load_1, mask=None, sem='relaxed')
+    load = tl.load(loss_sum + (indices_5[:, None] * loss_sum_stride_0 + indices_6[None, :] * 1), mask_3[None, :], other=0)
     sum_1 = tl.cast(tl.sum(load, 1), tl.float32)
     tl.store(loss + indices_1 * 1, sum_1, None)
 
@@ -1008,8 +1010,9 @@ def kernel_fixed_block_size(y_pred: torch.Tensor, y_true: torch.Tensor, *, _laun
     loss_sum = torch.zeros([BT_SIZE, block_size_n], dtype=torch.float32, device=y_pred.device)
     _BLOCK_SIZE_1 = 64
     _RDIM_SIZE_2 = 64
-    _BLOCK_SIZE_3 = 64
-    _launcher(_helion_kernel_fixed_block_size, (triton.cdiv(64, _BLOCK_SIZE_1),), loss_sum, y_true, kl_loss, loss, loss_sum.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    _BLOCK_SIZE_0 = 128
+    _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_0)
+    _launcher(_helion_kernel_fixed_block_size, (triton.cdiv(64, _BLOCK_SIZE_1),), loss_sum, y_true, kl_loss, loss, loss_sum.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_2, _RDIM_SIZE_3, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return torch.sum(loss) / BT
 
 --- assertExpectedJournal(TestLoops.test_reorder_with_register_block_size)
diff --git a/test/test_loops.py b/test/test_loops.py
index f6c602316..c68d1b3ab 100644
--- a/test/test_loops.py
+++ b/test/test_loops.py
@@ -332,7 +332,6 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         self.assertEqual(spec.min_size, 32)
         self.assertEqual(spec.max_size, 256)
 
-    @skipIfRefEager("Triton codegen is disabled in ref eager mode")
     def test_register_block_size_codegen_size_hint(self):
         @helion.kernel(static_shapes=True)
         def kernel_fixed_block_size(
@@ -368,7 +367,7 @@ def kernel_fixed_block_size(
         code, result = code_and_output(kernel_fixed_block_size, args, block_sizes=[128])
         self.assertExpectedJournal(code)
 
-        expected = y_true[:, : y_pred.size(0)].sum() / y_pred.size(0)
+        expected = y_true[:, :].sum() / y_pred.size(0)
         torch.testing.assert_close(result, expected)
 
     def test_reorder_with_register_block_size(self):
diff --git a/test/test_matmul.expected b/test/test_matmul.expected
index bf49a0ef0..5787ffbd9 100644
--- a/test/test_matmul.expected
+++ b/test/test_matmul.expected
@@ -193,6 +193,58 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]]
     _launcher(_helion_matmul, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestMatmul.test_matmul_packed_rhs)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime import triton_helpers
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_matmul_with_packed_b(A, B, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_1 = pid_0 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < M
+    offset_2 = pid_1 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    mask_2 = indices_2 < N
+    acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
+    floordiv = triton_helpers.div_floor_integer(K, 2)
+    for offset_0 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
+        indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+        mask_0 = indices_0 < floordiv
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        mul = 2 * offset_0
+        iota = mul + tl.arange(0, mul_2)
+        lhs = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
+        packed = tl.load(B + (indices_0[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        stack_idx = tl.arange(0, 2)
+        broadcast_idx = stack_idx[None, :, None]
+        expanded_0 = tl.expand_dims(packed, 1)
+        expanded_1 = tl.expand_dims(packed, 1)
+        stacked_result = tl.zeros_like(expanded_0)
+        mask_3 = broadcast_idx == 0
+        stacked_result = tl.where(mask_3, expanded_0, stacked_result)
+        mask_4 = broadcast_idx == 1
+        stacked_result = tl.where(mask_4, expanded_1, stacked_result)
+        rhs = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
+        acc = tl.dot(tl.cast(lhs, tl.float32), tl.cast(rhs, tl.float32), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), acc, mask_1[:, None] & mask_2[None, :])
+
+def matmul_with_packed_b(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, *, _launcher=_default_launcher):
+    M, K = A.shape
+    _, N = B.shape
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _BLOCK_SIZE_0 = 16
+    _launcher(_helion_matmul_with_packed_b, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), A, B, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+
 --- assertExpectedJournal(TestMatmul.test_matmul_split_k)
 from __future__ import annotations
 
diff --git a/test/test_matmul.py b/test/test_matmul.py
index d135fbd27..f8e1cc4cc 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -224,6 +224,42 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
         self.assertExpectedJournal(code)
 
+    def test_matmul_packed_rhs(self):
+        @helion.kernel(static_shapes=False)
+        def matmul_with_packed_b(
+            A: torch.Tensor, B: torch.Tensor, C: torch.Tensor
+        ) -> None:
+            M, K = A.shape
+            _, N = B.shape
+
+            block_size_k = hl.register_block_size(K // 2)
+
+            for tile_m, tile_n in hl.tile([M, N]):
+                acc = hl.zeros([tile_m, tile_n], dtype=A.dtype)
+
+                for tile_k in hl.tile(K // 2, block_size=block_size_k):
+                    lhs = A[
+                        tile_m,
+                        tile_k.begin * 2 : tile_k.begin * 2 + tile_k.block_size * 2,
+                    ]
+                    packed = B[tile_k, tile_n]
+                    rhs = torch.stack([packed, packed], dim=1).reshape(
+                        tile_k.block_size * 2, tile_n.block_size
+                    )
+                    acc = torch.addmm(acc, lhs, rhs)
+
+                C[tile_m, tile_n] = acc
+
+        M, K, N = 32, 64, 32
+        A = torch.randn(M, K, device=DEVICE)
+        B = torch.randn(K // 2, N, device=DEVICE)
+        C = torch.empty(M, N, device=DEVICE)
+        code, _ = code_and_output(matmul_with_packed_b, (A, B, C))
+        B_unpacked = torch.stack([B, B], dim=1).reshape(K, N)
+        expected = A @ B_unpacked
+        torch.testing.assert_close(C, expected, atol=5e-2, rtol=1e-3)
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()

From 7ecaeb02fe334ee9b724dd3808b79435b1b9e5b8 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Tue, 7 Oct 2025 10:42:46 -0700
Subject: [PATCH 2/2] Fix slice-size handling for block-size expressions

Handle slice sizes built from block-size symbols: compute slice sizes
from start/stop in compute_slice_size, compare sizes against 1
symbolically via env.known_equal when generating indexing, reuse
block-size symbols instead of allocating a new reduction dimension, and
refresh the var_to_val hint when a registered block size is resolved.

---
 helion/_compiler/compile_environment.py |  5 ++++-
 helion/_compiler/indexing_strategy.py   | 10 +++++++---
 helion/_compiler/type_propagation.py    | 11 +++++++++++
 helion/_compiler/utils.py               |  4 +++-
 4 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index b633bc621..6f7088c2f 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -522,11 +522,14 @@ def mark_alternate_size(self, size: torch.SymInt | int | None) -> None:
             self.size = size
             if size is not None:
                 env = CompileEnvironment.current()
+                # Refresh the var_to_val hint to match the resolved block size
+                hint = env.size_hint(size)
+                env.shape_env.var_to_val[self.symbol()] = sympy.Integer(hint)
                 with contextlib.suppress(KeyError):
                     # update the size hint now that we know the size
                     env.config_spec.block_sizes.block_id_lookup(
                         self.block_id
-                    ).update_hint(env.size_hint(size))
+                    ).update_hint(hint)
         elif size is None or self.size is None or self.size != size:
             self.size = None
 
diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py
index 2fc1910e7..df8206b88 100644
--- a/helion/_compiler/indexing_strategy.py
+++ b/helion/_compiler/indexing_strategy.py
@@ -527,6 +527,10 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+
+        def _is_size_one(size: int | torch.SymInt) -> bool:
+            return env.known_equal(size, 1)
+
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
@@ -544,7 +548,7 @@
                 index_values.append(f"({index_var}){expand}")
                 if (
                     mask := state.codegen.mask_var(origin.origin.block_id)
-                ) and fake_value.size(i) != 1:
+                ) and not _is_size_one(fake_value.size(i)):
                     mask_values.setdefault(f"({mask}){expand}")
                 output_idx += 1
             else:
@@ -576,7 +580,7 @@
                     index_values.append(f"{start}{expand}")
                 else:
                     # Full slice or slice without step
-                    if size != 1:
+                    if not _is_size_one(size):
                         rdim = env.allocate_reduction_dimension(size)
                         block_idx = rdim.block_id
                         index_var = state.codegen.index_var(block_idx)
@@ -620,7 +624,7 @@
         assert len(index_values) == fake_value.ndim
         index_expr = []
         for i, idx in enumerate(index_values):
-            if fake_value.size(i) != 1:
+            if not _is_size_one(fake_value.size(i)):
                 stride = state.device_function.tensor_stride(fake_value, i).name
                 index_expr.append(f"{idx} * {stride}")
         if not index_expr:
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index edf2f4faf..3a408f438 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -37,6 +37,7 @@
 from .compile_environment import FixedBlockSizeSource
 from .compile_environment import LoopSpecBlockSizeSource
 from .compile_environment import warning
+from .device_function import contains_only_block_size_symbols
 from .host_function import HostFunction
 from .host_function import SymbolOrigin
 from .output_header import library_imports
@@ -473,6 +474,16 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
                 if self.origin.is_device():
                     output_sizes.append(output_size)
                 elif output_size != 1:
+                    # If all symbols in output_size are block size symbols, we reuse them
+                    if isinstance(output_size, torch.SymInt):
+                        expr = output_size._sympy_()
+                        if (
+                            isinstance(expr, sympy.Expr)
+                            and expr.free_symbols
+                            and contains_only_block_size_symbols(expr)
+                        ):
+                            output_sizes.append(output_size)
+                            continue
                     rdim = CompileEnvironment.current().allocate_reduction_dimension(
                         output_size
                     )
diff --git a/helion/_compiler/utils.py b/helion/_compiler/utils.py
index f5260cd7f..0992514af 100644
--- a/helion/_compiler/utils.py
+++ b/helion/_compiler/utils.py
@@ -26,4 +26,6 @@ def compute_slice_size(
         step = slice_obj.step
         return (stop - start + step - 1) // step
     # Full slice or slice without step
-    return original_size
+    start = slice_obj.start if slice_obj.start is not None else 0
+    stop = slice_obj.stop if slice_obj.stop is not None else original_size
+    return stop - start