In [13]:
import triton
import triton.language as tl
import paddle

@triton.jit
def gemm_kernel(
    # 指针参数
    a_ptr,                                # 激活矩阵指针 (float16)
    w_ptr,                               # 权重矩阵指针 (int4)
    output_ptr,                          # 输出矩阵指针 (float32)
    # 矩阵维度参数
    M, N, K,
    # 步长参数
    stride_am, stride_ak,                # 激活矩阵的步长
    stride_wk, stride_wn,                # 权重矩阵的步长
    stride_om, stride_on,                # 输出矩阵的步长
    # 块大小参数
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # group size参数
    GROUP_SIZE_M: tl.constexpr,
):
    """
    计算 GEMM: output = activation @ weight.T
    参数:
        a_ptr: 激活矩阵指针 (M, K)，float16类型
        w_ptr: 权重矩阵指针 (N, K)，int4类型
        output_ptr: 输出矩阵指针 (M, N)，float32类型
    """
    # -----------------------------------------------------------
    # 计算程序ID
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # -----------------------------------------------------------
    # 计算当前块的偏移
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    
    # -----------------------------------------------------------
    # 创建指针
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    w_ptrs = w_ptr + offs_k[:, None] * stride_wk + offs_bn[None, :] * stride_wn
    
    # -----------------------------------------------------------
    # 初始化累加器
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    # -----------------------------------------------------------
    # 主循环
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # 加载激活值
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        w = tl.load(w_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        
        # 累加结果
        accumulator += tl.dot(a.to(tl.float32), w.to(tl.float32))
        
        # 更新指针
        a_ptrs += BLOCK_K * stride_ak
        w_ptrs += BLOCK_K * stride_wk
    
    # -----------------------------------------------------------
    # 写回结果
    offs_om = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_on = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    out_ptrs = output_ptr + offs_om[:, None] * stride_om + offs_on[None, :] * stride_on
    o_mask = (offs_om[:, None] < M) & (offs_on[None, :] < N)
    tl.store(out_ptrs, accumulator, mask=o_mask)

# 包装函数
def gemm(activation, weight):
    """
    执行GEMM运算
    参数:
        activation: 激活矩阵 (M, K)，float16类型
        weight: 权重矩阵 (N, K)，int4类型（以uint8打包存储）
    返回:
        output: 输出矩阵 (M, N)，float32类型
    """
    M, K = activation.shape
    #N, K_ = weight.shape
    #assert K == K_, f"维度不匹配: activation.K={K} != weight.K={K_}"
    
    # 输出分配
    output = paddle.zeros((M, N), dtype='float32')
    
    # 计算 grid 和 block 大小
    BLOCK_M = 16
    BLOCK_N = 16
    BLOCK_K = 32
    GROUP_SIZE_M = 8
    
    grid = lambda META: (
        triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
    )
    
    gemm_kernel[grid](
        activation, weight, output,
        M, N, K,
        activation.strides[0], activation.strides[1],
        weight.strides[0], weight.strides[1],
        output.strides[0], output.strides[1],
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        GROUP_SIZE_M=GROUP_SIZE_M,
    )
    
    return output 

In [25]:
import triton
import triton.language as tl
import paddle

@triton.jit
def wint2_gemm_kernel(
    # 指针参数
    a_ptr,                                # 激活矩阵指针 (float16)
    w_ptr,                               # 权重矩阵指针 (int4)
    output_ptr,                          # 输出矩阵指针 (float32)
    scale,
    # 矩阵维度参数
    M, N, K,
    # 步长参数
    stride_am, stride_ak,                # 激活矩阵的步长
    stride_wk, stride_wn,                # 权重矩阵的步长
    stride_om, stride_on,                # 输出矩阵的步长
    # 块大小参数
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    W_BLOCK_K: tl.constexpr,
    # group size参数
    GROUP_SIZE_M: tl.constexpr,
):
    """
    计算 GEMM: output = activation @ weight.T
    参数:
        a_ptr: 激活矩阵指针 (M, K)，float16类型
        w_ptr: 权重矩阵指针 (N, K)，int4类型
        output_ptr: 输出矩阵指针 (M, N)，float32类型
    """
    # -----------------------------------------------------------
    # 计算程序ID
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # -----------------------------------------------------------
    # 计算当前块的偏移
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    a_offs_k = tl.arange(0, BLOCK_K)
    w_offs_k = tl.arange(0, W_BLOCK_K)
    
    # -----------------------------------------------------------
    # 创建指针
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + a_offs_k[None, :] * stride_ak
    w_ptrs = w_ptr + w_offs_k[:, None] * stride_wk + offs_bn[None, :] * stride_wn
    
    # -----------------------------------------------------------
    # 初始化累加器
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    # -----------------------------------------------------------
    # 主循环
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # 加载激活值
        a = tl.load(a_ptrs, mask=a_offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        w_packed = tl.load(w_ptrs, mask=w_offs_k[:, None] < K/4 - k * BLOCK_K/4, other=0.0)
        
        # 加载并解压int4权重
        w_packed = tl.load(w_ptrs, mask=w_offs_k[:, None] < K, other=0)

        unpacked_3 = w_packed & 0b11
        unpacked_2 = (w_packed >> 2) & 0b11
        unpacked_1 = (w_packed >> 4) & 0b11
        unpacked_0 = (w_packed >> 6) & 0b11

        
        # 反量化权重
        w_float_0 = (unpacked_0.to(tl.float16) - 1) * scale
        w_float_1 = (unpacked_1.to(tl.float16) - 1) * scale
        w_float_2 = (unpacked_2.to(tl.float16) - 1) * scale
        w_float_3 = (unpacked_3.to(tl.float16) - 1) * scale
        
        # 累加结果
        accumulator += tl.dot(a.to(tl.float32), w_float_0.to(tl.float32))
        accumulator += tl.dot(a.to(tl.float32), w_float_1.to(tl.float32))
        accumulator += tl.dot(a.to(tl.float32), w_float_2.to(tl.float32))
        accumulator += tl.dot(a.to(tl.float32), w_float_3.to(tl.float32))

        # 更新指针
        a_ptrs += BLOCK_K * stride_ak
        w_ptrs += BLOCK_K * stride_wk
    
    # -----------------------------------------------------------
    # 写回结果
    offs_om = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_on = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    out_ptrs = output_ptr + offs_om[:, None] * stride_om + offs_on[None, :] * stride_on
    o_mask = (offs_om[:, None] < M) & (offs_on[None, :] < N)
    tl.store(out_ptrs, accumulator, mask=o_mask)

# 包装函数
def wint2_gemm(activation, weight, scale):
    """
    执行WINT2 GEMM运算
    参数:
        activation: 激活矩阵 (M, K)，float16类型
        weight: 权重矩阵 (N, K)，int2类型（以uint8打包存储）
    返回:
        output: 输出矩阵 (M, N)，float32类型
    """
    M, K = activation.shape
    #N, K_ = weight.shape
    #assert K == K_, f"维度不匹配: activation.K={K} != weight.K={K_}"
    
    # 输出分配
    output = paddle.zeros((M, N), dtype='float32')
    
    # 计算 grid 和 block 大小
    BLOCK_M = 16
    BLOCK_N = 16
    BLOCK_K = 32
    W_BLOCK_K = 8
    GROUP_SIZE_M = 8
    
    grid = lambda META: (
        triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
    )
    
    wint2_gemm_kernel[grid](
        activation, weight, output,
        scale,
        M, N, K,
        activation.strides[0], activation.strides[1],
        weight.strides[0], weight.strides[1],
        output.strides[0], output.strides[1],
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, W_BLOCK_K=W_BLOCK_K,
        GROUP_SIZE_M=GROUP_SIZE_M,
    )
    
    return output 

In [26]:
import paddle
import numpy as np
ZERO_POINT = 1
def quantize_to_int2(tensor, scale=None):
    if scale is None:
        # 计算量化scale，使用绝对值最大值除以(2^2-1)=3
        scale = float(paddle.max(paddle.abs(tensor)).item() / 3)
    # 确保scale不为0
    scale = max(scale, 1e-8)
    # 量化到0-3范围
    quant = paddle.clip(paddle.round(tensor / scale) + ZERO_POINT, 0, 3).astype('uint8')
    # 重塑张量以准备打包
    new_shape = list(quant.shape[:-1]) + [-1, 4]
    quant = quant.reshape(new_shape)
    # 打包4个int2值到一个int8
    quant_np = quant.numpy()
    packed = (quant_np[..., 0] << 6) | \
             (quant_np[..., 1] << 4) | \
             (quant_np[..., 2] << 2) | \
             quant_np[..., 3]
    packed = paddle.to_tensor(packed, dtype='uint8')
    
    return packed, scale

def dequantize_from_int2(packed, scale, original_shape):
    # 解包int2值
    packed_np = packed.numpy()
    unpacked_3 = packed_np & 0b11
    unpacked_2 = (packed_np >> 2) & 0b11
    unpacked_1 = (packed_np >> 4) & 0b11
    unpacked_0 = (packed_np >> 6) & 0b11
    # 堆叠解包的值
    unpacked = paddle.stack([paddle.to_tensor(unpacked_0), paddle.to_tensor(unpacked_1), paddle.to_tensor(unpacked_2), paddle.to_tensor(unpacked_3)], axis=-1)
    # 重塑为原始形状
    unpacked = unpacked.reshape(original_shape[:-1] + [-1])
    # 反量化
    tensor = (unpacked.astype('float32')- ZERO_POINT) * scale
    return tensor


# 创建测试张量
x = paddle.randn([2, 8], dtype='float32')
print("原始张量:")
print(x)

# 量化
packed, scale = quantize_to_int2(x)
print("\n打包后的张量:")
print(packed)
print("量化scale:", scale)

# 反量化
reconstructed = dequantize_from_int2(packed, scale, x.shape)
print("\n重建的张量:")
print(reconstructed)

# # 计算误差
# error = paddle.abs(x - reconstructed)
# print("\n最大绝对误差:", paddle.max(error).item())
# print("平均绝对误差:", paddle.mean(error).item())

原始张量:
Tensor(shape=[2, 8], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       [[-0.40147418,  0.84105122, -1.44239163, -0.12111820,  0.57649255,
          0.20772524,  0.45755649,  0.08082178],
        [ 2.34299016,  0.00141706, -1.09204257, -0.13947889,  0.54678589,
          2.05279350,  0.40679640, -0.35385308]])

打包后的张量:
Tensor(shape=[2, 2], dtype=uint8, place=Place(gpu:0), stop_gradient=True,
       [[33 , 153],
        [209, 185]])
量化scale: 0.7809967199961344

重建的张量:
Tensor(shape=[2, 8], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       [[-0.78099674,  0.78099674, -0.78099674,  0.        ,  0.78099674,
          0.        ,  0.78099674,  0.        ],
        [ 1.56199348,  0.        , -0.78099674,  0.        ,  0.78099674,
          1.56199348,  0.78099674,  0.        ]])


In [27]:
import paddle
import numpy as np

# 设置随机种子以保证结果可复现
paddle.seed(4)
np.random.seed(4)

# 创建示例数据
M, N, K = 128, 64, 32

# 创建激活矩阵 (M, K)
activation = paddle.randn([M, K], dtype='float32')

# 创建权重矩阵 (N, K)
weight = paddle.randn([N, K])

# 量化权重到int2
packed_weight, scale = quantize_to_int2(weight)


# 使用WINT4 GEMM kernel计算结果
output = wint2_gemm(activation, packed_weight.T, scale)

dequant_weight = dequantize_from_int2(packed_weight, scale, weight.shape)
# 使用反量化后的权重计算参考结果
ref_output = paddle.matmul(activation.astype('float32'), 
                            dequant_weight.T.astype('float32'))

# 计算相对误差
rel_error = paddle.abs(output - ref_output) / (paddle.abs(ref_output) + 1e-7)
max_rel_error = paddle.max(rel_error)
avg_rel_error = paddle.mean(rel_error)

print(f"最大相对误差: {max_rel_error.item():.6f}")
print(f"平均相对误差: {avg_rel_error.item():.6f}")


CompilationError: at 79:61:        unpacked_1 = (w_packed >> 4) & 0b11
        unpacked_0 = (w_packed >> 6) & 0b11


        # 反量化权重
        w_float_0 = (unpacked_0.to(tl.float16) - 1) * scale
        w_float_1 = (unpacked_1.to(tl.float16) - 1) * scale
        w_float_2 = (unpacked_0.to(tl.float16) - 1) * scale
        w_float_3 = (unpacked_1.to(tl.float16) - 1) * scale

        # 累加结果
        accumulator += tl.dot(a.to(tl.float32), w_float_0.to(tl.float32))
                                                             ^
AssertionError('First input shape ([constexpr[16], constexpr[32]]) and second input shape [constexpr[8], constexpr[16]] are not compatible for matmul (second index of first shape (32) must be equal to first index of second shape (8)')

In [12]:
import paddle
import numpy as np
ZERO_POINT = 1
def quantize_to_int2(tensor, scale=None):
    if scale is None:
        # 计算量化scale，使用绝对值最大值除以(2^2-1)=3
        scale = float(paddle.max(paddle.abs(tensor)).item() / 3)
    # 确保scale不为0
    scale = max(scale, 1e-8)
    # 量化到0-3范围
    quant = paddle.clip(paddle.round(tensor / scale) + ZERO_POINT, 0, 3).astype('uint8')
    # 重塑张量以准备打包
    new_shape = list(quant.shape[:-1]) + [-1, 4]
    quant = quant.reshape(new_shape)
    # 打包4个int2值到一个int8
    quant_np = quant.numpy()
    packed = (quant_np[..., 0] << 6) | \
             (quant_np[..., 1] << 4) | \
             (quant_np[..., 2] << 2) | \
             quant_np[..., 3]
    packed = paddle.to_tensor(packed, dtype='uint8')
    
    return packed, scale

def dequantize_from_int2(packed, scale, original_shape):
    # 解包int2值
    packed_np = packed.numpy()
    unpacked_3 = packed_np & 0b11
    unpacked_2 = (packed_np >> 2) & 0b11
    unpacked_1 = (packed_np >> 4) & 0b11
    unpacked_0 = (packed_np >> 6) & 0b11
    # 堆叠解包的值
    unpacked = paddle.stack([paddle.to_tensor(unpacked_0), paddle.to_tensor(unpacked_1), paddle.to_tensor(unpacked_2), paddle.to_tensor(unpacked_3)], axis=-1)
    # 重塑为原始形状
    unpacked = unpacked.reshape(original_shape[:-1] + [-1])
    # 反量化
    tensor = (unpacked.astype('float32')- ZERO_POINT) * scale
    return tensor


# 创建测试张量
x = paddle.randn([2, 8], dtype='float32')
print("原始张量:")
print(x)

# 量化
packed, scale = quantize_to_int2(x)
print("\n打包后的张量:")
print(packed)
print("量化scale:", scale)

# 反量化
reconstructed = dequantize_from_int2(packed, scale, x.shape)
print("\n重建的张量:")
print(reconstructed)

# # 计算误差
# error = paddle.abs(x - reconstructed)
# print("\n最大绝对误差:", paddle.max(error).item())
# print("平均绝对误差:", paddle.mean(error).item())

原始张量:
Tensor(shape=[2, 8], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       [[-1.37485266,  0.25441313,  0.23476355,  0.57191873, -1.08756435,
         -0.89395362, -1.17351019,  0.37844741],
        [ 1.74194264,  0.32449320,  1.28417206, -1.27401042,  1.71044135,
          0.39335835, -0.70837349,  0.77900213]])

打包后的张量:
Tensor(shape=[2, 2], dtype=uint8, place=Place(gpu:0), stop_gradient=True,
       [[22 , 2  ],
        [236, 226]])
量化scale: 0.5806475480397543

重建的张量:
Tensor(shape=[2, 8], dtype=float32, place=Place(gpu:0), stop_gradient=True,
       [[-0.58064753,  0.        ,  0.        ,  0.58064753, -0.58064753,
         -0.58064753, -0.58064753,  0.58064753],
        [ 1.16129506,  0.58064753,  1.16129506, -0.58064753,  1.16129506,
          0.58064753, -0.58064753,  0.58064753]])
