Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 19 additions & 22 deletions examples/int4_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)

for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed):
# Load corresponding tiles from A (need to load twice the packed tile size)
# We need to map tile_k_packed to the corresponding range in A
a_tile_begin = tile_k_packed.begin * 2
a_tile_len = block_size_k_packed * 2
a_tile = A[tile_m, a_tile_begin : (a_tile_begin + a_tile_len)].to(
torch.float32
) # [BLOCK_SIZE_M, BLOCK_SIZE_K]

# Load packed int8 data from B
b_tile = B[tile_k_packed, tile_n] # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]

Expand All @@ -60,29 +68,19 @@ def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
b_lo = ((b_tile << 4) >> 4).to(torch.int8) # Sign-extend low 4 bits
b_hi = (b_tile >> 4).to(torch.int8) # Sign-extend high 4 bits

# Convert to bfloat16
b_lo_bf16 = b_lo.to(torch.bfloat16) # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
b_hi_bf16 = b_hi.to(torch.bfloat16) # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]

# Stack and reshape to interleave low and high bits
# Stack along a new dimension to get [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N]
b_stacked = torch.stack([b_lo_bf16, b_hi_bf16], dim=1)
b_stacked = torch.stack([b_lo, b_hi], dim=1)

# Reshape to interleave: [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N] -> [BLOCK_SIZE_K, BLOCK_SIZE_N]
# This will place elements in the order: b_lo[0], b_hi[0], b_lo[1], b_hi[1], ...
b_unpacked = b_stacked.reshape(
tile_k_packed.block_size * 2, tile_n.block_size
)

# Load corresponding tiles from A (need to load twice the packed tile size)
# We need to map tile_k_packed to the corresponding range in A
a_tile_begin = tile_k_packed.begin * 2
a_tile_len = tile_k_packed.block_size * 2
a_tile = A[
tile_m, a_tile_begin : (a_tile_begin + a_tile_len)
] # [BLOCK_SIZE_M, BLOCK_SIZE_K]
).to(torch.float32)

acc = acc + hl.dot(a_tile, b_unpacked) # [BLOCK_SIZE_M, BLOCK_SIZE_N]
a_tile = a_tile.unsqueeze(2) # [BLOCK_SIZE_M, BLOCK_SIZE_K, 1]
b_unpacked = b_unpacked.unsqueeze(0)
acc = acc + (a_tile * b_unpacked).sum(dim=1) # [BLOCK_SIZE_M, BLOCK_SIZE_N]

C[tile_m, tile_n] = acc.to(torch.bfloat16)

Expand All @@ -106,14 +104,13 @@ def int4_gemm_tritonbench(tb_op: object, x: torch.Tensor, w: torch.Tensor) -> Ca
Callable: A function that performs the int4 gemm.
"""

def run_kernel() -> torch.Tensor:
x_2d = x.reshape(-1, x.size(-1))

# Pack w to int4 format (two 4-bit values per int8 byte)
w_int8 = w.to(torch.int8)
w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)
# Pack w to int4 format (two 4-bit values per int8 byte)
x_2d = x.reshape(-1, x.size(-1))
w_int8 = w.to(torch.int8)
w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious to double-check: do other backends in TritonBench also run this preprocessing step outside of the measured kernel?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yf225 Yes, they do. There are two versions: one preprocessed and one not. The comparisons I did were against the preprocessed pt2 version.


def run_kernel() -> torch.Tensor:
return matmul_bf16_int4(x_2d, w_packed)

return run_kernel
Expand Down
60 changes: 32 additions & 28 deletions test/test_examples.expected
Original file line number Diff line number Diff line change
Expand Up @@ -1343,7 +1343,7 @@ from torch._inductor.runtime import triton_helpers
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul: tl.constexpr):
def _helion_matmul_bf16_int4(A, B, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul_1: tl.constexpr):
num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_1)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
Expand All @@ -1355,37 +1355,40 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri
mask_2 = indices_2 < N
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
floordiv = triton_helpers.div_floor_integer(K, 2)
for offset_0 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_0 < floordiv
for offset_3 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
mask_0 = indices_3 < floordiv
acc_copy = acc
acc_copy_0 = acc_copy
b_tile = tl.load(B + (indices_0[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
v_0 = tl.full([], 4, tl.int8)
v_1 = b_tile << v_0
v_2 = tl.full([], 4, tl.int8)
v_3 = v_1 >> v_2
v_4 = tl.full([], 4, tl.int8)
v_5 = b_tile >> v_4
v_6 = tl.cast(v_3, tl.bfloat16)
v_7 = tl.cast(v_5, tl.bfloat16)
mul = 2 * offset_3
iota = mul + tl.arange(0, mul_1)
load = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
v_0 = tl.cast(load, tl.float32)
b_tile = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
v_1 = tl.full([], 4, tl.int8)
v_2 = b_tile << v_1
v_3 = tl.full([], 4, tl.int8)
v_4 = v_2 >> v_3
v_5 = tl.full([], 4, tl.int8)
v_6 = b_tile >> v_5
stack_idx = tl.arange(0, 2)
broadcast_idx = stack_idx[None, :, None]
expanded_0 = tl.expand_dims(v_6, 1)
expanded_1 = tl.expand_dims(v_7, 1)
expanded_0 = tl.expand_dims(v_4, 1)
expanded_1 = tl.expand_dims(v_6, 1)
stacked_result = tl.zeros_like(expanded_0)
mask_3 = broadcast_idx == 0
stacked_result = tl.where(mask_3, expanded_0, stacked_result)
mask_4 = broadcast_idx == 1
stacked_result = tl.where(mask_4, expanded_1, stacked_result)
b_unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
mul_5 = 2 * offset_0
iota = mul_5 + tl.arange(0, mul)
a_tile = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(b_unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
acc = acc_copy_0 + dot
v_9 = tl.cast(acc, tl.bfloat16)
tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_9, mask_1[:, None] & mask_2[None, :])
mask_4 = broadcast_idx == 0
stacked_result = tl.where(mask_4, expanded_0, stacked_result)
mask_5 = broadcast_idx == 1
stacked_result = tl.where(mask_5, expanded_1, stacked_result)
view = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
v_7 = tl.cast(view, tl.float32)
a_tile_1 = v_0[:, :, None]
b_unpacked_1 = v_7[None, :, :]
v_8 = a_tile_1 * b_unpacked_1
sum_1 = tl.cast(tl.sum(v_8, 1), tl.float32)
acc = acc_copy_0 + sum_1
v_10 = tl.cast(acc, tl.bfloat16)
tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_10, mask_1[:, None] & mask_2[None, :])

def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
"""
Expand All @@ -1409,7 +1412,8 @@ def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
_BLOCK_SIZE_1 = 64
_BLOCK_SIZE_2 = 32
_BLOCK_SIZE_0 = 64
_launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), B, A, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
_RDIM_SIZE_3 = triton.next_power_of_2(2 * _BLOCK_SIZE_0)
_launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), A, B, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return C

--- assertExpectedJournal(TestExamples.test_jagged_dense_add)
Expand Down
Loading