int4 gemm kernel #508

@yf225

Description

Example:

import torch
from torch import Tensor
import helion
import helion.language as hl

@helion.kernel(use_default_config=True, static_shapes=True)
def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
    """
    A: (M, K) bf16
    B: (K, N) int4. assume b is packed with 2 `int4` elements per K. i.e., it's a
        (K//2)xNx(2xint4) matrix, represented in Triton as (K//2)xNxi8.
    """
    M, K = A.shape
    _, N = B.shape

    C = torch.empty(M, N, dtype=torch.bfloat16, device=A.device)
    block_size_k_packed = hl.register_block_size(K // 2)
    block_size_n = hl.register_block_size(N)
    b_bf16 = torch.empty([block_size_k_packed, 2, block_size_n], dtype=torch.bfloat16, device=A.device)

    # Use Helion to tile the computation
    for tile_m in hl.tile(M):
        for tile_n in hl.tile(N, block_size=block_size_n):
            acc = hl.zeros((tile_m, tile_n), dtype=torch.bfloat16)

            for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed):
                # Load packed int8 data from B
                b_tile = B[tile_k_packed, tile_n]  # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
                
                # Extract low and high 4-bit values
                b_lo = b_tile & 0x0F  # Extract low 4 bits
                b_hi = (b_tile >> 4) & 0x0F  # Extract high 4 bits
                
                # Convert to bf16 and interleave into the [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N] staging buffer
                b_lo_bf16 = b_lo.to(torch.bfloat16)  # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
                b_hi_bf16 = b_hi.to(torch.bfloat16)  # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
                b_bf16[tile_k_packed, 0, tile_n] = b_lo_bf16  # low nibble -> even K within the tile
                b_bf16[tile_k_packed, 1, tile_n] = b_hi_bf16  # high nibble -> odd K within the tile

                # Reshape to [BLOCK_SIZE_K, BLOCK_SIZE_N] - unpacking the int4 values
                b_bf16_reshaped = b_bf16[tile_k_packed, :, tile_n].reshape([tile_k_packed.block_size * 2, tile_n.block_size])
                
                # Load the corresponding tile from A: twice the packed tile size along K,
                # mapping tile_k_packed to the matching unpacked K range
                a_start = tile_k_packed.begin * 2
                a_end = a_start + tile_k_packed.block_size * 2
                a_tile = A[tile_m, a_start:a_end]  # [BLOCK_SIZE_M, BLOCK_SIZE_K]
                
                acc = acc + hl.dot(a_tile, b_bf16_reshaped).to(torch.bfloat16)  # [BLOCK_SIZE_M, BLOCK_SIZE_N]

            C[tile_m, tile_n] = acc

    return C

# Test the kernel
A = torch.randn(8192, 8192, dtype=torch.bfloat16, device="cuda")
B = torch.randint(0, 16, (4096, 8192), dtype=torch.int8, device="cuda")
C = matmul_bf16_int4(A, B)
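
Note that randint(0, 16) only populates the low nibble of each int8, so the high-nibble path above is exercised with zeros. A small helper along the following lines could build a fully packed B from an unpacked (K, N) tensor of int4 values; this is a minimal sketch, and pack_int4 / B_unpacked are hypothetical names, not part of the example:

def pack_int4(b_int4: Tensor) -> Tensor:
    # Pack a (K, N) tensor with values in [0, 16) into the (K//2, N) int8 layout
    # the kernel expects: even K rows -> low nibble, odd K rows -> high nibble.
    K, N = b_int4.shape
    b = b_int4.reshape(K // 2, 2, N).to(torch.uint8)
    packed = b[:, 0] | (b[:, 1] << 4)  # low nibble = even K, high nibble = odd K
    return packed.view(torch.int8)     # reinterpret the bits as int8 for the kernel

B_unpacked = torch.randint(0, 16, (8192, 8192), dtype=torch.int8, device="cuda")
B_packed = pack_int4(B_unpacked)
C_packed = matmul_bf16_int4(A, B_packed)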

This kernel packs two int4 values into one int8, so we need custom slicing beyond a single tile's length (e.g. A[tile_m, a_start:a_end], which covers twice the packed tile size).
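
For sanity-checking the output, an eager-mode reference of the same unpack-then-matmul computation might look like the sketch below; matmul_bf16_int4_ref is a hypothetical name, and since the kernel accumulates in bf16 tile by tile, only approximate agreement with a single eager matmul should be expected:

def matmul_bf16_int4_ref(A: Tensor, B_packed: Tensor) -> Tensor:
    # Unpack: low nibble -> even K indices, high nibble -> odd K indices, matching the kernel
    b_lo = (B_packed & 0x0F).to(torch.bfloat16)
    b_hi = ((B_packed >> 4) & 0x0F).to(torch.bfloat16)
    b_unpacked = torch.stack([b_lo, b_hi], dim=1).reshape(-1, B_packed.shape[1])  # (K, N)
    return (A @ b_unpacked).to(torch.bfloat16)

# e.g. torch.testing.assert_close(C, matmul_bf16_int4_ref(A, B), rtol=1e-1, atol=1e-1)
# (tolerances are illustrative)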

Metadata

Labels

- matmul: matmul / gemm / mm / bmm / tl.dot / hl.dot related issues
- ptc2025
- view ops: Support for slicing / view / reshape / permute / transpose / flatten / etc. ops
