diff --git a/benchmarks/run.py b/benchmarks/run.py
index 578808d03..ed712aeae 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -166,6 +166,11 @@ class RunResult:
         "examples.welford",
         "welford",
     ),
+    "int4_gemm": (
+        "tritonbench.operators.int4_gemm.int4_gemm",
+        "examples.int4_gemm",
+        "int4_gemm_tritonbench",
+    ),
 }
@@ -266,6 +271,14 @@ class RunResult:
         "helion_kl_div_tritonbench-speedup": "helion_speedup",
         "helion_kl_div_tritonbench-accuracy": "helion_accuracy",
     },
+    "int4_gemm": {
+        "triton_int4_gemm-speedup": "triton_speedup",
+        "triton_int4_gemm-accuracy": "triton_accuracy",
+        "torch_compile_int4_gemm-speedup": "torch_compile_speedup",
+        "torch_compile_int4_gemm-accuracy": "torch_compile_accuracy",
+        "helion_int4_gemm_tritonbench-speedup": "helion_speedup",
+        "helion_int4_gemm_tritonbench-accuracy": "helion_accuracy",
+    },
 }
diff --git a/examples/int4_gemm.py b/examples/int4_gemm.py
new file mode 100644
index 000000000..a7346226b
--- /dev/null
+++ b/examples/int4_gemm.py
@@ -0,0 +1,178 @@
+"""
+INT4 General Matrix Multiplication (GEMM) with Helion
+=====================================================
+This example demonstrates an INT4 GEMM kernel implemented in Helion. The kernel performs
+matrix multiplication where the second matrix B is packed with two 4-bit values per byte.
+The kernel unpacks the int4 values, converts them to bfloat16, and performs the matmul
+with the bfloat16 matrix A.
+"""
+
+# %%
+# Imports
+# -------
+from __future__ import annotations
+
+from typing import Callable
+
+import torch
+from torch import Tensor
+
+import helion
+import helion.language as hl
+
+
+# %%
+# INT4 GEMM Kernel
+# ----------------
+@helion.kernel(
+    use_default_config=True,
+    static_shapes=False,  # Allow dynamic shapes to handle different input sizes
+)
+def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
+    """
+    BFloat16 x INT4 General Matrix Multiplication (GEMM).
+
+    This kernel performs matrix multiplication where:
+    - A is a bfloat16 matrix of shape [M, K]
+    - B is an int8 matrix of shape [K//2, N] containing packed int4 values
+      (two 4-bit values packed into each int8)
+
+    Args:
+        A (Tensor): Input tensor of shape [M, K] in bfloat16 format.
+        B (Tensor): Packed int4 tensor of shape [K//2, N] in int8 format.
+
+    Returns:
+        Tensor: Output tensor of shape [M, N] in bfloat16 format.
+ """ + M, K = A.shape + _, N = B.shape + + C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device) + block_size_k_packed = hl.register_block_size(K // 2) + + # Use Helion to tile the computation + for tile_m, tile_n in hl.tile([M, N]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + + for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed): + # Load packed int8 data from B + b_tile = B[tile_k_packed, tile_n] # [BLOCK_SIZE_K//2, BLOCK_SIZE_N] + + # Extract low and high 4-bit values with sign extension + # Low nibble: sign-extend from 4-bit to 8-bit using left shift then arithmetic right shift + b_lo = ((b_tile << 4) >> 4).to(torch.int8) # Sign-extend low 4 bits + b_hi = (b_tile >> 4).to(torch.int8) # Sign-extend high 4 bits + + # Convert to bfloat16 + b_lo_bf16 = b_lo.to(torch.bfloat16) # [BLOCK_SIZE_K//2, BLOCK_SIZE_N] + b_hi_bf16 = b_hi.to(torch.bfloat16) # [BLOCK_SIZE_K//2, BLOCK_SIZE_N] + + # Stack and reshape to interleave low and high bits + # Stack along a new dimension to get [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N] + b_stacked = torch.stack([b_lo_bf16, b_hi_bf16], dim=1) + + # Reshape to interleave: [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N] -> [BLOCK_SIZE_K, BLOCK_SIZE_N] + # This will place elements in the order: b_lo[0], b_hi[0], b_lo[1], b_hi[1], ... + b_unpacked = b_stacked.reshape( + tile_k_packed.block_size * 2, tile_n.block_size + ) + + # Load corresponding tiles from A (need to load twice the packed tile size) + # We need to map tile_k_packed to the corresponding range in A + a_tile_begin = tile_k_packed.begin * 2 + a_tile_len = tile_k_packed.block_size * 2 + a_tile = A[ + tile_m, a_tile_begin : (a_tile_begin + a_tile_len) + ] # [BLOCK_SIZE_M, BLOCK_SIZE_K] + + acc = acc + hl.dot(a_tile, b_unpacked) # [BLOCK_SIZE_M, BLOCK_SIZE_N] + + C[tile_m, tile_n] = acc.to(torch.bfloat16) + + return C + + +# %% +# TritonBench Wrapper +# ------------------- +def int4_gemm_tritonbench(tb_op: object, x: torch.Tensor, w: torch.Tensor) -> Callable: + """ + Wrapper for TritonBench compatibility. + + Args: + tb_op: TritonBench operator instance + x (torch.Tensor): Left input tensor in bfloat16 format. + w (torch.Tensor): Right input tensor of shape [K, N] containing int4 values. + Will be packed to int4 format. + + Returns: + Callable: A function that performs the int4 gemm. + """ + + def run_kernel() -> torch.Tensor: + x_2d = x.reshape(-1, x.size(-1)) + + # Pack w to int4 format (two 4-bit values per int8 byte) + w_int8 = w.to(torch.int8) + w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2) + w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8) + + return matmul_bf16_int4(x_2d, w_packed) + + return run_kernel + + +# %% +# Verification Function +# --------------------- +def check(m: int, k: int, n: int) -> None: + """ + Test the INT4 GEMM implementation. + + Args: + m (int): Number of rows in the left input matrix. + k (int): Shared dimension (must be even). + n (int): Number of columns in the right input matrix. 
+ """ + # Create test matrices + A = torch.randn(m, k, dtype=torch.bfloat16, device="cuda") + + # Create packed int4 matrix B (K//2 x N) + # Generate random int4 values in range [-8, 7] and pack them + B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device="cuda") + + # Pack using the same format as tritonbench + B_reshaped = B_unpacked.reshape(k // 2, 2, n).permute(1, 0, 2) + B_packed = ((B_reshaped[0] & 0xF) | (B_reshaped[1] << 4)).to(torch.int8) + + # Convert unpacked values to bfloat16 for reference + B_unpacked_bf16 = B_unpacked.to(torch.bfloat16) + + # Compute reference result + expected = torch.matmul(A, B_unpacked_bf16) + + # Run the kernel + result = matmul_bf16_int4(A, B_packed) + + # Check accuracy with appropriate tolerance + torch.testing.assert_close(result, expected, rtol=2e-1, atol=1.0) + print(f"Test passed for shapes: M={m}, K={k}, N={n}") + + +# %% +# Main Function +# ------------- +def main() -> None: + """ + Main function to run tests with different matrix sizes. + """ + check(256, 512, 256) + check(512, 512, 512) + check(1024, 1024, 1024) + + +# %% +# Run Example +# ----------- +if __name__ == "__main__": + main() diff --git a/test/test_examples.expected b/test/test_examples.expected index cf081f79d..51ca95eb1 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -900,6 +900,86 @@ def geglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher): _launcher(_helion_geglu, (triton.cdiv(total_elements, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, a_flat.stride(0), b_flat.stride(0), out_flat.stride(0), total_elements, _BLOCK_SIZE_0, num_warps=4, num_stages=3) return out +--- assertExpectedJournal(TestExamples.test_int4_gemm) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul: tl.constexpr): + num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_1) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_1 = pid_0 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < M + offset_2 = pid_1 * _BLOCK_SIZE_2 + indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) + mask_2 = indices_2 < N + acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32) + floordiv = triton_helpers.div_floor_integer(K, 2) + for offset_3 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0): + indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) + mask_0 = indices_3 < floordiv + acc_copy = acc + acc_copy_0 = acc_copy + b_tile = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + v_0 = tl.full([], 4, tl.int8) + v_1 = b_tile << v_0 + v_2 = tl.full([], 4, tl.int8) + v_3 = v_1 >> v_2 + v_4 = tl.full([], 4, tl.int8) + v_5 = b_tile >> v_4 + v_6 = tl.cast(v_3, tl.bfloat16) + v_7 = tl.cast(v_5, tl.bfloat16) + stack_idx = tl.arange(0, 2) + broadcast_idx = stack_idx[None, :, None] + expanded_0 = tl.expand_dims(v_6, 1) + expanded_1 = tl.expand_dims(v_7, 1) + stacked_result = tl.zeros_like(expanded_0) + mask_4 = broadcast_idx == 0 + stacked_result = tl.where(mask_4, expanded_0, stacked_result) + mask_5 = broadcast_idx == 1 + 
+        stacked_result = tl.where(mask_5, expanded_1, stacked_result)
+        b_unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
+        mul_5 = 2 * offset_3
+        iota = mul_5 + tl.arange(0, mul)
+        a_tile = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
+        dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(b_unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
+        acc = acc_copy_0 + dot
+    v_9 = tl.cast(acc, tl.bfloat16)
+    tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_9, mask_1[:, None] & mask_2[None, :])
+
+def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
+    """
+    BFloat16 x INT4 General Matrix Multiplication (GEMM).
+
+    This kernel performs matrix multiplication where:
+    - A is a bfloat16 matrix of shape [M, K]
+    - B is an int8 matrix of shape [K//2, N] containing packed int4 values
+      (two 4-bit values packed into each int8)
+
+    Args:
+        A (Tensor): Input tensor of shape [M, K] in bfloat16 format.
+        B (Tensor): Packed int4 tensor of shape [K//2, N] in int8 format.
+
+    Returns:
+        Tensor: Output tensor of shape [M, N] in bfloat16 format.
+    """
+    M, K = A.shape
+    _, N = B.shape
+    C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
+    _BLOCK_SIZE_1 = 64
+    _BLOCK_SIZE_2 = 32
+    _RDIM_SIZE_3 = triton.next_power_of_2(K)
+    _BLOCK_SIZE_0 = 64
+    _launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), B, A, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return C
+
 --- assertExpectedJournal(TestExamples.test_jagged_dense_add)
 from __future__ import annotations
diff --git a/test/test_examples.py b/test/test_examples.py
index a566ba379..287f5d4db 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -1112,6 +1112,41 @@ def test_kl_div(self):
             )
         )
 
+    def test_int4_gemm(self):
+        # Matrix dimensions
+        M, K, N = 256, 512, 256
+
+        # Create bfloat16 matrix A
+        A = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)
+
+        # Create packed int4 matrix B
+        # Generate random int4 values in range [-8, 7]
+        B_unpacked = torch.randint(-8, 8, (K, N), dtype=torch.int8, device=DEVICE)
+
+        # Pack two int4 values per int8
+        B_reshaped = B_unpacked.reshape(K // 2, 2, N).permute(1, 0, 2)
+        B_packed = ((B_reshaped[0] & 0xF) | (B_reshaped[1] << 4)).to(torch.int8)
+
+        # Convert unpacked to bfloat16 for expected result
+        B_unpacked_bf16 = B_unpacked.to(torch.bfloat16)
+        expected = torch.matmul(A, B_unpacked_bf16)
+
+        args = (A, B_packed)
+
+        self.assertExpectedJournal(
+            check_example(
+                "int4_gemm",
+                args,
+                expected,
+                fn_name="matmul_bf16_int4",
+                block_sizes=[64, 64, 32],
+                num_warps=4,
+                num_stages=3,
+                rtol=2e-1,
+                atol=1.0,
+            )
+        )
+
 if __name__ == "__main__":
     unittest.main()
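The packing convention this patch relies on (even K indices in the low nibble, odd K indices in the high nibble, both sign-extended by shifts) can be sanity-checked outside the test suite. The following is a minimal standalone sketch in plain PyTorch; the pack_int4/unpack_int4 helper names are illustrative only and are not part of the patch.

# Standalone sanity check for the int4 packing convention used above.
# Assumes values in [-8, 7] and an even K; helper names are hypothetical.
import torch

def pack_int4(w: torch.Tensor) -> torch.Tensor:
    # [K, N] int8 in [-8, 7] -> [K//2, N] int8, two nibbles per byte
    pairs = w.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
    return ((pairs[0] & 0xF) | (pairs[1] << 4)).to(torch.int8)

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    # Invert pack_int4 using the same shifts as the kernel
    lo = ((packed << 4) >> 4).to(torch.int8)  # sign-extend low nibble
    hi = (packed >> 4).to(torch.int8)  # arithmetic shift sign-extends high nibble
    return torch.stack([lo, hi], dim=1).reshape(-1, packed.shape[1])

w = torch.randint(-8, 8, (8, 4), dtype=torch.int8)
assert torch.equal(unpack_int4(pack_int4(w)), w)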