13 changes: 13 additions & 0 deletions benchmarks/run.py
@@ -166,6 +166,11 @@ class RunResult:
"examples.welford",
"welford",
),
"int4_gemm": (
"tritonbench.operators.int4_gemm.int4_gemm",
"examples.int4_gemm",
"int4_gemm_tritonbench",
),
}


@@ -266,6 +271,14 @@ class RunResult:
"helion_kl_div_tritonbench-speedup": "helion_speedup",
"helion_kl_div_tritonbench-accuracy": "helion_accuracy",
},
"int4_gemm": {
"triton_int4_gemm-speedup": "triton_speedup",
"triton_int4_gemm-accuracy": "triton_accuracy",
"torch_compile_int4_gemm-speedup": "torch_compile_speedup",
"torch_compile_int4_gemm-accuracy": "torch_compile_accuracy",
"helion_int4_gemm_tritonbench-speedup": "helion_speedup",
"helion_int4_gemm_tritonbench-accuracy": "helion_accuracy",
},
}


178 changes: 178 additions & 0 deletions examples/int4_gemm.py
@@ -0,0 +1,178 @@
"""
INT4 General Matrix Multiplication (GEMM) with Helion
=====================================================
This example demonstrates an INT4 GEMM kernel implemented in Helion. The kernel performs
matrix multiplication where the second matrix B is packed with two 4-bit values per byte.
It unpacks the int4 values, converts them to bfloat16, and multiplies them against
the bfloat16 matrix A.
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import Callable

import torch
from torch import Tensor

import helion
import helion.language as hl


# %%
# INT4 GEMM Kernel
# ----------------
@helion.kernel(
use_default_config=True,
static_shapes=False, # Allow dynamic shapes to handle different input sizes
)
def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
"""
BFloat16 x INT4 General Matrix Multiplication (GEMM).

This kernel performs matrix multiplication where:
- A is a bfloat16 matrix of shape [M, K]
- B is an int8 matrix of shape [K//2, N] containing packed int4 values
(two 4-bit values packed into each int8)

Args:
A (Tensor): Input tensor of shape [M, K] in bfloat16 format.
B (Tensor): Packed int4 tensor of shape [K//2, N] in int8 format.

Returns:
Tensor: Output tensor of shape [M, N] in bfloat16 format.
"""
M, K = A.shape
_, N = B.shape

C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
block_size_k_packed = hl.register_block_size(K // 2)

# Use Helion to tile the computation
for tile_m, tile_n in hl.tile([M, N]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)

for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed):
# Load packed int8 data from B
b_tile = B[tile_k_packed, tile_n] # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]

# Extract low and high 4-bit values with sign extension
# Low nibble: sign-extend from 4-bit to 8-bit using left shift then arithmetic right shift
b_lo = ((b_tile << 4) >> 4).to(torch.int8) # Sign-extend low 4 bits
b_hi = (b_tile >> 4).to(torch.int8) # Sign-extend high 4 bits

# Convert to bfloat16
b_lo_bf16 = b_lo.to(torch.bfloat16) # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
b_hi_bf16 = b_hi.to(torch.bfloat16) # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]

# Stack and reshape to interleave low and high bits
# Stack along a new dimension to get [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N]
b_stacked = torch.stack([b_lo_bf16, b_hi_bf16], dim=1)

# Reshape to interleave: [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N] -> [BLOCK_SIZE_K, BLOCK_SIZE_N]
# This will place elements in the order: b_lo[0], b_hi[0], b_lo[1], b_hi[1], ...
b_unpacked = b_stacked.reshape(
tile_k_packed.block_size * 2, tile_n.block_size
)

# Load corresponding tiles from A (need to load twice the packed tile size)
# We need to map tile_k_packed to the corresponding range in A
a_tile_begin = tile_k_packed.begin * 2
a_tile_len = tile_k_packed.block_size * 2
a_tile = A[
tile_m, a_tile_begin : (a_tile_begin + a_tile_len)
] # [BLOCK_SIZE_M, BLOCK_SIZE_K]

acc = acc + hl.dot(a_tile, b_unpacked) # [BLOCK_SIZE_M, BLOCK_SIZE_N]

C[tile_m, tile_n] = acc.to(torch.bfloat16)

return C


# %%
# TritonBench Wrapper
# -------------------
def int4_gemm_tritonbench(tb_op: object, x: torch.Tensor, w: torch.Tensor) -> Callable:
"""
Wrapper for TritonBench compatibility.

Args:
tb_op: TritonBench operator instance
x (torch.Tensor): Left input tensor in bfloat16 format.
w (torch.Tensor): Right input tensor of shape [K, N] containing int4 values
stored as int8; it is packed two values per byte before the kernel runs.

Returns:
Callable: A function that performs the int4 gemm.
"""

def run_kernel() -> torch.Tensor:
x_2d = x.reshape(-1, x.size(-1))

# Pack w to int4 format (two 4-bit values per int8 byte)
w_int8 = w.to(torch.int8)
w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)

return matmul_bf16_int4(x_2d, w_packed)

return run_kernel


# %%
# Verification Function
# ---------------------
def check(m: int, k: int, n: int) -> None:
"""
Test the INT4 GEMM implementation.

Args:
m (int): Number of rows in the left input matrix.
k (int): Shared dimension (must be even).
n (int): Number of columns in the right input matrix.
"""
# Create test matrices
A = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")

# Create packed int4 matrix B (K//2 x N)
# Generate random int4 values in range [-8, 7] and pack them
B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device="cuda")

# Pack using the same format as tritonbench
B_reshaped = B_unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
B_packed = ((B_reshaped[0] & 0xF) | (B_reshaped[1] << 4)).to(torch.int8)

# Convert unpacked values to bfloat16 for reference
B_unpacked_bf16 = B_unpacked.to(torch.bfloat16)

# Compute reference result
expected = torch.matmul(A, B_unpacked_bf16)

# Run the kernel
result = matmul_bf16_int4(A, B_packed)

# Check accuracy with appropriate tolerance
torch.testing.assert_close(result, expected, rtol=2e-1, atol=1.0)
print(f"Test passed for shapes: M={m}, K={k}, N={n}")


# %%
# Main Function
# -------------
def main() -> None:
"""
Main function to run tests with different matrix sizes.
"""
check(256, 512, 256)
check(512, 512, 512)
check(1024, 1024, 1024)


# %%
# Run Example
# -----------
if __name__ == "__main__":
main()
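
The packing convention used by the wrapper, the check helper, and the unit test stores the element at each even K index in the low nibble and the element at the following odd K index in the high nibble of a single int8. The standalone sketch below is illustrative only (it is not part of the PR) and round-trips that convention with the same shift-based sign extension the kernel performs on each tile.

# Illustrative round-trip of the int4 packing scheme; requires only PyTorch.
import torch

k, n = 8, 4
vals = torch.randint(-8, 8, (k, n), dtype=torch.int8)  # signed int4 range [-8, 7]

# Pack: even-K rows land in the low nibble, odd-K rows in the high nibble.
pairs = vals.reshape(k // 2, 2, n).permute(1, 0, 2)
packed = ((pairs[0] & 0xF) | (pairs[1] << 4)).to(torch.int8)  # shape [k // 2, n]

# Unpack: left shift then arithmetic right shift sign-extends the low nibble;
# a single arithmetic right shift sign-extends the high nibble.
lo = ((packed << 4) >> 4).to(torch.int8)
hi = (packed >> 4).to(torch.int8)

# Interleave lo/hi along K (lo[0], hi[0], lo[1], hi[1], ...) to restore the original order.
unpacked = torch.stack([lo, hi], dim=1).reshape(k, n)
assert torch.equal(unpacked, vals)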
80 changes: 80 additions & 0 deletions test/test_examples.expected
@@ -900,6 +900,86 @@ def geglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
_launcher(_helion_geglu, (triton.cdiv(total_elements, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, a_flat.stride(0), b_flat.stride(0), out_flat.stride(0), total_elements, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestExamples.test_int4_gemm)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul: tl.constexpr):
num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_1)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
offset_1 = pid_0 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
mask_1 = indices_1 < M
offset_2 = pid_1 * _BLOCK_SIZE_2
indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
mask_2 = indices_2 < N
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
floordiv = triton_helpers.div_floor_integer(K, 2)
for offset_3 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_3 < floordiv
acc_copy = acc
acc_copy_0 = acc_copy
b_tile = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
v_0 = tl.full([], 4, tl.int8)
v_1 = b_tile << v_0
v_2 = tl.full([], 4, tl.int8)
v_3 = v_1 >> v_2
v_4 = tl.full([], 4, tl.int8)
v_5 = b_tile >> v_4
v_6 = tl.cast(v_3, tl.bfloat16)
v_7 = tl.cast(v_5, tl.bfloat16)
stack_idx = tl.arange(0, 2)
broadcast_idx = stack_idx[None, :, None]
expanded_0 = tl.expand_dims(v_6, 1)
expanded_1 = tl.expand_dims(v_7, 1)
stacked_result = tl.zeros_like(expanded_0)
mask_4 = broadcast_idx == 0
stacked_result = tl.where(mask_4, expanded_0, stacked_result)
mask_5 = broadcast_idx == 1
stacked_result = tl.where(mask_5, expanded_1, stacked_result)
b_unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
mul_5 = 2 * offset_3
iota = mul_5 + tl.arange(0, mul)
a_tile = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(b_unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
acc = acc_copy_0 + dot
v_9 = tl.cast(acc, tl.bfloat16)
tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_9, mask_1[:, None] & mask_2[None, :])

def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
"""
BFloat16 x INT4 General Matrix Multiplication (GEMM).

This kernel performs matrix multiplication where:
- A is a bfloat16 matrix of shape [M, K]
- B is an int8 matrix of shape [K//2, N] containing packed int4 values
(two 4-bit values packed into each int8)

Args:
A (Tensor): Input tensor of shape [M, K] in bfloat16 format.
B (Tensor): Packed int4 tensor of shape [K//2, N] in int8 format.

Returns:
Tensor: Output tensor of shape [M, N] in bfloat16 format.
"""
M, K = A.shape
_, N = B.shape
C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
_BLOCK_SIZE_1 = 64
_BLOCK_SIZE_2 = 32
_RDIM_SIZE_3 = triton.next_power_of_2(K)
_BLOCK_SIZE_0 = 64
_launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), B, A, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return C

--- assertExpectedJournal(TestExamples.test_jagged_dense_add)
from __future__ import annotations

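
A note on the generated kernel in the expected journal above, illustrative only and not part of the PR: the stack_idx / broadcast_idx / tl.where sequence is how the example's torch.stack is expressed in Triton, writing expanded_0 where the new axis index is 0 and expanded_1 where it is 1 before reshaping to the interleaved [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2] tile. The small torch sketch below shows the masked-select form agrees with stack-then-reshape.

# Equivalence of the where-based stack emulation and torch.stack; PyTorch only.
import torch

lo = torch.arange(6, dtype=torch.float32).reshape(3, 2)         # stand-in for v_6
hi = torch.arange(6, dtype=torch.float32).reshape(3, 2) + 10.0  # stand-in for v_7

stack_idx = torch.arange(2)               # like tl.arange(0, 2)
broadcast_idx = stack_idx[None, :, None]  # shape [1, 2, 1]
expanded_0 = lo.unsqueeze(1)              # shape [3, 1, 2]
expanded_1 = hi.unsqueeze(1)
stacked = torch.zeros(3, 2, 2)
stacked = torch.where(broadcast_idx == 0, expanded_0, stacked)
stacked = torch.where(broadcast_idx == 1, expanded_1, stacked)

reference = torch.stack([lo, hi], dim=1)  # what the example computes
assert torch.equal(stacked, reference)
assert torch.equal(stacked.reshape(6, 2), reference.reshape(6, 2))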
35 changes: 35 additions & 0 deletions test/test_examples.py
@@ -1112,6 +1112,41 @@ def test_kl_div(self):
)
)

def test_int4_gemm(self):
# Matrix dimensions
M, K, N = 256, 512, 256

# Create bfloat16 matrix A
A = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)

# Create packed int4 matrix B
# Generate random int4 values in range [-8, 7]
B_unpacked = torch.randint(-8, 8, (K, N), dtype=torch.int8, device=DEVICE)

# Pack two int4 values per int8
B_reshaped = B_unpacked.reshape(K // 2, 2, N).permute(1, 0, 2)
B_packed = ((B_reshaped[0] & 0xF) | (B_reshaped[1] << 4)).to(torch.int8)

# Convert unpacked to bfloat16 for expected result
B_unpacked_bf16 = B_unpacked.to(torch.bfloat16)
expected = torch.matmul(A, B_unpacked_bf16)

args = (A, B_packed)

self.assertExpectedJournal(
check_example(
"int4_gemm",
args,
expected,
fn_name="matmul_bf16_int4",
block_sizes=[64, 64, 32],
num_warps=4,
num_stages=3,
rtol=2e-1,
atol=1.0,
)
)


if __name__ == "__main__":
unittest.main()