From c9d8521f78e30fc85ce83c5164cc97df380d0a00 Mon Sep 17 00:00:00 2001
From: PaulZhang12
Date: Wed, 1 Oct 2025 14:28:29 -0700
Subject: [PATCH] Faster int4 gemm

stack-info: PR: https://github.com/pytorch/helion/pull/751, branch: PaulZhang12/stack/11
---
 examples/int4_gemm.py       | 41 ++++++++++++-------------
 test/test_examples.expected | 60 ++++++++++++++++++++-----------------
 2 files changed, 51 insertions(+), 50 deletions(-)

diff --git a/examples/int4_gemm.py b/examples/int4_gemm.py
index e1a0ce332..5dd936412 100644
--- a/examples/int4_gemm.py
+++ b/examples/int4_gemm.py
@@ -52,6 +52,14 @@ def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
         acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)

         for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed):
+            # Load the corresponding tile from A up front: twice the packed tile
+            # size, since each packed byte in B covers two elements along K
+            a_tile_begin = tile_k_packed.begin * 2
+            a_tile_len = block_size_k_packed * 2
+            a_tile = A[tile_m, a_tile_begin : (a_tile_begin + a_tile_len)].to(
+                torch.float32
+            )  # [BLOCK_SIZE_M, BLOCK_SIZE_K]
+
             # Load packed int8 data from B
             b_tile = B[tile_k_packed, tile_n]  # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]

@@ -60,29 +68,19 @@ def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
             b_lo = ((b_tile << 4) >> 4).to(torch.int8)  # Sign-extend low 4 bits
             b_hi = (b_tile >> 4).to(torch.int8)  # Sign-extend high 4 bits

-            # Convert to bfloat16
-            b_lo_bf16 = b_lo.to(torch.bfloat16)  # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
-            b_hi_bf16 = b_hi.to(torch.bfloat16)  # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
-
             # Stack and reshape to interleave low and high bits
             # Stack along a new dimension to get [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N]
-            b_stacked = torch.stack([b_lo_bf16, b_hi_bf16], dim=1)
+            b_stacked = torch.stack([b_lo, b_hi], dim=1)

             # Reshape to interleave: [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N] -> [BLOCK_SIZE_K, BLOCK_SIZE_N]
             # This will place elements in the order: b_lo[0], b_hi[0], b_lo[1], b_hi[1], ...
             b_unpacked = b_stacked.reshape(
                 tile_k_packed.block_size * 2, tile_n.block_size
-            )
-
-            # Load corresponding tiles from A (need to load twice the packed tile size)
-            # We need to map tile_k_packed to the corresponding range in A
-            a_tile_begin = tile_k_packed.begin * 2
-            a_tile_len = tile_k_packed.block_size * 2
-            a_tile = A[
-                tile_m, a_tile_begin : (a_tile_begin + a_tile_len)
-            ]  # [BLOCK_SIZE_M, BLOCK_SIZE_K]
+            ).to(torch.float32)

-            acc = acc + hl.dot(a_tile, b_unpacked)  # [BLOCK_SIZE_M, BLOCK_SIZE_N]
+            a_tile = a_tile.unsqueeze(2)  # [BLOCK_SIZE_M, BLOCK_SIZE_K, 1]
+            b_unpacked = b_unpacked.unsqueeze(0)  # [1, BLOCK_SIZE_K, BLOCK_SIZE_N]
+            acc = acc + (a_tile * b_unpacked).sum(dim=1)  # [BLOCK_SIZE_M, BLOCK_SIZE_N]

     C[tile_m, tile_n] = acc.to(torch.bfloat16)

@@ -106,14 +104,13 @@ def int4_gemm_tritonbench(tb_op: object, x: torch.Tensor, w: torch.Tensor) -> Ca
         Callable: A function that performs the int4 gemm.
     """

-    def run_kernel() -> torch.Tensor:
-        x_2d = x.reshape(-1, x.size(-1))
-
-        # Pack w to int4 format (two 4-bit values per int8 byte)
-        w_int8 = w.to(torch.int8)
-        w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
-        w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)
+    x_2d = x.reshape(-1, x.size(-1))
+    # Pack w to int4 format (two 4-bit values per int8 byte) once, outside run_kernel
+    w_int8 = w.to(torch.int8)
+    w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
+    w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)

+    def run_kernel() -> torch.Tensor:
         return matmul_bf16_int4(x_2d, w_packed)

     return run_kernel
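Note: the kernel above relies on two facts: an arithmetic right shift on int8 sign-extends a nibble, and stack-then-reshape restores the original K ordering after the low/high-nibble split. A minimal standalone sketch (plain PyTorch on CPU, not part of the patch) that checks the pack/unpack round trip used here:

import torch

K, N = 8, 4
w = torch.randint(-8, 8, (K, N), dtype=torch.int8)  # int4 value range [-8, 7]

# Pack pairs along K into one byte: low nibble = even row, high nibble = odd row
w_pairs = w.reshape(K // 2, 2, N).permute(1, 0, 2)
w_packed = ((w_pairs[0] & 0xF) | (w_pairs[1] << 4)).to(torch.int8)

# Unpack: arithmetic shifts sign-extend each nibble back to int8
b_lo = ((w_packed << 4) >> 4).to(torch.int8)
b_hi = (w_packed >> 4).to(torch.int8)

# Interleave so rows come back in the original K order:
# b_lo[0], b_hi[0], b_lo[1], b_hi[1], ...
w_unpacked = torch.stack([b_lo, b_hi], dim=1).reshape(K, N)
assert torch.equal(w_unpacked, w)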
""" - def run_kernel() -> torch.Tensor: - x_2d = x.reshape(-1, x.size(-1)) - - # Pack w to int4 format (two 4-bit values per int8 byte) - w_int8 = w.to(torch.int8) - w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2) - w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8) + # Pack w to int4 format (two 4-bit values per int8 byte) + x_2d = x.reshape(-1, x.size(-1)) + w_int8 = w.to(torch.int8) + w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2) + w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8) + def run_kernel() -> torch.Tensor: return matmul_bf16_int4(x_2d, w_packed) return run_kernel diff --git a/test/test_examples.expected b/test/test_examples.expected index 5e56dbe53..dd2c5ba20 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -1343,7 +1343,7 @@ from torch._inductor.runtime import triton_helpers from helion.runtime import default_launcher as _default_launcher @triton.jit -def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul: tl.constexpr): +def _helion_matmul_bf16_int4(A, B, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul_1: tl.constexpr): num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_1) pid_0 = tl.program_id(0) % num_blocks_0 pid_1 = tl.program_id(0) // num_blocks_0 @@ -1355,37 +1355,40 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri mask_2 = indices_2 < N acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32) floordiv = triton_helpers.div_floor_integer(K, 2) - for offset_0 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0): - indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) - mask_0 = indices_0 < floordiv + for offset_3 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0): + indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) + mask_0 = indices_3 < floordiv acc_copy = acc acc_copy_0 = acc_copy - b_tile = tl.load(B + (indices_0[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0) - v_0 = tl.full([], 4, tl.int8) - v_1 = b_tile << v_0 - v_2 = tl.full([], 4, tl.int8) - v_3 = v_1 >> v_2 - v_4 = tl.full([], 4, tl.int8) - v_5 = b_tile >> v_4 - v_6 = tl.cast(v_3, tl.bfloat16) - v_7 = tl.cast(v_5, tl.bfloat16) + mul = 2 * offset_3 + iota = mul + tl.arange(0, mul_1) + load = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0) + v_0 = tl.cast(load, tl.float32) + b_tile = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + v_1 = tl.full([], 4, tl.int8) + v_2 = b_tile << v_1 + v_3 = tl.full([], 4, tl.int8) + v_4 = v_2 >> v_3 + v_5 = tl.full([], 4, tl.int8) + v_6 = b_tile >> v_5 stack_idx = tl.arange(0, 2) broadcast_idx = stack_idx[None, :, None] - expanded_0 = tl.expand_dims(v_6, 1) - expanded_1 = tl.expand_dims(v_7, 1) + expanded_0 = tl.expand_dims(v_4, 1) + expanded_1 = tl.expand_dims(v_6, 1) stacked_result = tl.zeros_like(expanded_0) - mask_3 = broadcast_idx == 0 - stacked_result = tl.where(mask_3, expanded_0, stacked_result) - mask_4 = broadcast_idx == 1 - stacked_result = tl.where(mask_4, expanded_1, stacked_result) - b_unpacked = tl.reshape(stacked_result, 
@@ -1409,7 +1412,8 @@ def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
     _BLOCK_SIZE_1 = 64
     _BLOCK_SIZE_2 = 32
     _BLOCK_SIZE_0 = 64
-    _launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), B, A, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _RDIM_SIZE_3 = triton.next_power_of_2(2 * _BLOCK_SIZE_0)
+    _launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), A, B, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return C

 --- assertExpectedJournal(TestExamples.test_jagged_dense_add)
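Note: for reference, a hypothetical end-to-end smoke test, not part of the patch. It assumes a CUDA device, Helion installed, and examples.int4_gemm importable from the repo root; the tolerances are illustrative, not values taken from the test suite.

import torch
from examples.int4_gemm import matmul_bf16_int4

M, K, N = 256, 512, 128
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randint(-8, 8, (K, N), dtype=torch.int8, device="cuda")  # int4 range

# Pack two int4 values per byte along K, mirroring int4_gemm_tritonbench
w_pairs = w.reshape(K // 2, 2, N).permute(1, 0, 2)
w_packed = ((w_pairs[0] & 0xF) | (w_pairs[1] << 4)).to(torch.int8)

out = matmul_bf16_int4(x, w_packed)
ref = x @ w.to(torch.bfloat16)  # bf16 reference matmul
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-1)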