In [1]:
import os

# os.environ['TRITON_INTERPRET'] = '1' 

import triton
import triton.language as tl
import torch



In [15]:
@triton.jit
def grouped_matmul_k(
    x_ptr, l1_ptr, xl1_ptr,x_quant_ptr , scale_ptr,
    m,n,k,
    bm: tl.constexpr, bn: tl.constexpr, bk: tl.constexpr, group_sz: tl.constexpr , bk_quant: tl.constexpr , group_size: tl.constexpr
):
    
    #mat mul 
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    num_pid_m, num_pid_n = tl.num_programs(0), tl.num_programs(1)
    pid_m_new , pid_n_new=tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, group_sz)  # Weirdness: tl.swizzle2d doesn't work when simulating on CPU
    # pid_m_new , pid_n_new=pid_m,pid_n

    acc = tl.zeros((bm, bn), dtype=tl.float32)
    x_ptr_offset=k*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1))+ tl.expand_dims(tl.arange(0,bk), 0)
    l1_ptr_offset=(pid_n_new*bn+tl.expand_dims(tl.arange(0,bn), 0))+tl.expand_dims(tl.arange(0,bk), 1)*n
    x_ptr_mask=k*tl.expand_dims(pid_m_new*bm+tl.arange(1,bm+1),1)
    x_ptr_mask=tl.where(x_ptr_mask<m*k , x_ptr_mask , m*k)
    l1_mask=(k-1)*n+pid_n_new*bn+tl.expand_dims(tl.arange(1,bn+1), 0)
    l1_mask=tl.where(l1_mask<k*n , l1_mask , k*n)
    for i in range ( 0 , tl.cdiv(k , bk)):
        # print(i,l1_ptr_offset)
        x_loaded=tl.load(x_ptr+x_ptr_offset ,x_ptr_offset<x_ptr_mask ,0)
        l1_loaded=tl.load(l1_ptr+l1_ptr_offset ,l1_ptr_offset<l1_mask ,0)
        acc += tl.dot(x_loaded, l1_loaded)
        x_ptr_offset+=bk
        l1_ptr_offset+=n*bk
    xl1_offset=n*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1))+ pid_n_new*bn+tl.expand_dims(tl.arange(0,bn), 0)
    xl1_mask=n*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1)+1)
    acc=acc.to(tl.float16)
    tl.store(xl1_ptr+xl1_offset , acc , xl1_offset<xl1_mask)

    #quant
    xptr_quant_start_point=k*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1))+ pid_n_new*bk_quant
    x_ptr_quant_offset_1=xptr_quant_start_point + 2*tl.expand_dims(tl.arange(0,group_size), 0)
    x_ptr_quant_offset_2=xptr_quant_start_point + 2*tl.expand_dims(tl.arange(0,group_size), 0)+1
    x_ptr_quant_mask=k*tl.expand_dims(pid_m_new*bm+tl.arange(1,bm+1),1)
    x_ptr_quant_mask=tl.where(x_ptr_quant_mask<m*k , x_ptr_quant_mask , m*k)
    
    x_quant_store_offset=tl.cdiv(k,2)*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1))+pid_n_new*bk_quant+tl.expand_dims(tl.arange(0,group_size), 0)
    x_quant_store_mask=tl.cdiv(k,2)*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1)+1)
    
    scale_quant_store_offset=tl.cdiv(k,2*group_size)*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1))+pid_n_new*bk_quant
    scale_quant_store_mask=tl.cdiv(k,(2*group_size))*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1)+1)
    if pid_n_new==0:
        for i in range ( 0,tl.cdiv(k , 2*group_size)):
            x_loaded_1=tl.load(x_ptr+x_ptr_quant_offset_1 ,x_ptr_quant_offset_1<x_ptr_quant_mask ,0)
            x_loaded_2=tl.load(x_ptr+x_ptr_quant_offset_2 ,x_ptr_quant_offset_2<x_ptr_quant_mask ,0)
            max_val_1=tl.max(tl.abs(x_loaded_1) , axis=1 , keep_dims=True)
            max_val_2=tl.max(tl.abs(x_loaded_2) , axis=1 , keep_dims=True)
            max_val=tl.where(max_val_1>max_val_2 , max_val_1 , max_val_2)
            scale = 7.0 / (max_val + 1e-6)
            scaled_1 = x_loaded_1 * scale
            scaled_2 = x_loaded_2 * scale
            clamped_vals_1 = tl.clamp(scaled_1, -8.0, 7.0)
            clamped_vals_2 = tl.clamp(scaled_2, -8.0, 7.0)
            int4_vals_1 = tl.where(clamped_vals_1 >= 0,
                                 tl.floor(clamped_vals_1 + 0.5),
                                 tl.ceil(clamped_vals_1 - 0.5)).to(tl.int8)
                
            int4_vals_2 = tl.where(clamped_vals_2 >= 0,
                                 tl.floor(clamped_vals_2 + 0.5),
                                 tl.ceil(clamped_vals_2 - 0.5)).to(tl.int8)
            int8block = int4_vals_1 & 0x0F  # Lower 4 bits
            int8block2 = int4_vals_2 & 0x0F  # Upper 4 bits
            packed = (int8block2.to(tl.int8) << 4) | int8block.to(tl.int8)
            packed=packed.to(tl.int8)
            tl.store(x_quant_ptr +x_quant_store_offset,packed  , x_quant_store_offset<x_quant_store_mask)
            tl.store(scale_ptr+scale_quant_store_offset ,max_val ,  scale_quant_store_offset<scale_quant_store_mask)
    
            x_ptr_quant_offset_1+=2*group_size
            x_ptr_quant_offset_2+=2*group_size
            x_quant_store_offset+=group_size
            scale_quant_store_offset+=1
        
        

In [16]:
def cdiv(a,b):
    return (a+b-1) //b
def matmul(x1, l1):
    assert x1.shape[1] == l1.shape[0], "k should be same "

    m=x1.shape[0]
    n=l1.shape[1]
    k=l1.shape[0]

    batch_size=16
    half_quant_group_size=32
    bk_quant=max(half_quant_group_size*2 , cdiv(k*batch_size,max(n , batch_size)))
    grid = lambda meta: (triton.cdiv(m, meta['bm']),  triton.cdiv(n, meta['bn']))
    xl1 = torch.empty((m, n), dtype=torch.float16).cuda().contiguous()
    x1_quant = torch.zeros((m, cdiv(k,2)), dtype=torch.int8).cuda().contiguous()
    x1_quant_scale=torch.zeros((m, cdiv(k,half_quant_group_size*2)), dtype=torch.float16).cuda().contiguous()

    # print(x1_quant_scale.shape , bk_quant)
    xl1 = torch.empty((m, n), dtype=torch.float16).cuda().contiguous()
    # Launch kernel
    
    grouped_matmul_k[grid](
        x_ptr=x1,l1_ptr=l1,xl1_ptr=xl1,
        
        x_quant_ptr=x1_quant , scale_ptr=x1_quant_scale,
        
        m=m,n=n,k=k,
        bm=batch_size , bn=batch_size , bk=batch_size , group_sz=batch_size , bk_quant=bk_quant,group_size=half_quant_group_size
    )
    return xl1 , x1_quant , x1_quant_scale

In [17]:
m,r,k=160 , 128, 620
x1 = torch.randn((m, k),dtype=torch.float16).cuda().contiguous()
l1 = torch.randn((k, r), dtype=torch.float16).cuda().contiguous()

triton_output , triton_quant ,x1_quant_scale =matmul(x1,l1)
torch_output=x1@l1
if torch.allclose(triton_output, torch_output, atol=5e-2, rtol=0):
    print('you rock')
else:
    print("❌ Triton and Torch differ")


print(torch.max(triton_output-torch_output))
print(torch.mean(triton_output-torch_output))

you rock
tensor(0.0078, device='cuda:0', dtype=torch.float16)
tensor(-1.7881e-07, device='cuda:0', dtype=torch.float16)


In [18]:
x1_quant_scale

tensor([[2.1660, 2.7773, 2.6387,  ..., 2.4297, 2.1875, 1.6777],
        [2.6484, 2.4004, 2.5938,  ..., 3.4727, 2.7715, 2.7930],
        [2.4102, 2.5664, 2.7480,  ..., 2.2910, 2.9023, 3.2246],
        ...,
        [2.6152, 2.4766, 3.0020,  ..., 3.4102, 2.5996, 2.5664],
        [2.7988, 1.9678, 2.8809,  ..., 2.0430, 2.3379, 2.6367],
        [2.9902, 2.6641, 2.5430,  ..., 2.6621, 2.6309, 3.6484]],
       device='cuda:0', dtype=torch.float16)

In [19]:
import torch
import math

def grouped_matmul_quant_pytorch(x1, l1, group_size=32):
    """
    Performs matrix multiplication and quantizes x1 in PyTorch.
    
    Args:
        x1 (torch.Tensor): Input tensor of shape (m, k), dtype float
        l1 (torch.Tensor): Weight tensor of shape (k, n), dtype float
        group_size (int): Number of elements per quantization group (default: 32)
    
    Returns:
        xl1 (torch.Tensor): Result of x1 @ l1, shape (m, n)
        x1_quant (torch.Tensor): Quantized x1, shape (m, ceil(k / 2)), dtype int8
        scales (torch.Tensor): Scales per group, shape (m, ceil(k / (2 * group_size)))
    """
    m, k = x1.shape
    _, n = l1.shape
    device = x1.device
    assert x1.shape[1] == l1.shape[0], "k dimensions must match for matrix multiplication"

    # 1. Matrix Multiplication
    xl1 = torch.matmul(x1, l1)

    # 2. Quantization of x1
    # Number of groups per row
    num_groups = math.ceil(k / (2 * group_size))
    padded_k = num_groups * 2 * group_size

    # Pad x1 to a multiple of 2 * group_size along the k dimension
    x1_padded = torch.nn.functional.pad(x1, (0, padded_k - k), mode='constant', value=0)
    # Reshape into (m, num_groups, 2 * group_size)
    x1_groups = x1_padded.view(m, num_groups, 2 * group_size)

    # Compute maximum absolute value per group
    max_vals = torch.max(torch.abs(x1_groups+1e-6), dim=2).values  # Shape: (m, num_groups)
    scales = 7.0 / (max_vals + 1e-6)                   # Shape: (m, num_groups)

    # Scale and clamp
    scaled = x1_groups * scales.unsqueeze(2)           # Shape: (m, num_groups, 2 * group_size)
    clamped = torch.clamp(scaled, -8.0, 7.0)

    # Round to nearest integer, with ties rounding away from zero
    rounded = torch.where(
        clamped >= 0,
        torch.floor(clamped + 0.5),
        torch.ceil(clamped - 0.5)
    )
    int8_vals = rounded.to(torch.int8)                 # Shape: (m, num_groups, 2 * group_size)

    # Trim to original k and prepare for packing
    q_vals = int8_vals.view(m, -1)[:, :k]              # Shape: (m, k)
    num_pairs = k // 2
    x1_quant = torch.zeros((m, math.ceil(k / 2)), dtype=torch.int8, device=device)

    # Pack pairs of int4 values into int8
    if num_pairs > 0:
        q0 = q_vals[:, 0:2*num_pairs:2]                # Even indices
        q1 = q_vals[:, 1:2*num_pairs:2]                # Odd indices
        packed_pairs = (
            (torch.bitwise_and(q1, 0x0F).to(torch.uint8) << 4) |
            torch.bitwise_and(q0, 0x0F).to(torch.uint8)
        )
        x1_quant[:, :num_pairs] = packed_pairs.to(torch.int8)
    if k % 2 == 1:
        q_last = q_vals[:, -1]
        x1_quant[:, -1] = torch.bitwise_and(q_last, 0x0F).to(torch.uint8)

    return xl1, x1_quant, max_vals

In [20]:
import torch
import triton
import time

# Your Triton matmul function (assumed to be defined as in the query)
# def matmul(x1, l1): ...  # Returns xl1, x1_quant, x1_quant_scale

# Test parameters
m, k, n = 400,624, 400
group_size = 32
device = torch.device("cuda")

# Generate random inputs
x1 = torch.randn(m, k, dtype=torch.float16, device=device)
l1 = torch.randn(k, n, dtype=torch.float16, device=device)

# Warm-up runs
for _ in range(10):
    matmul(x1, l1)
    grouped_matmul_quant_pytorch(x1, l1, group_size)

# Time Triton kernel
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    xl1_triton, x1_quant_triton, scales_triton = matmul(x1, l1)
torch.cuda.synchronize()
triton_time = (time.time() - start) / 100
print(f"Triton average time: {triton_time:.6f} seconds")

# Time PyTorch implementation
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    xl1_pytorch, x1_quant_pytorch, scales_pytorch = grouped_matmul_quant_pytorch(x1, l1, group_size)
torch.cuda.synchronize()
pytorch_time = (time.time() - start) / 100
print(f"PyTorch average time: {pytorch_time:.6f} seconds")

# Verify correctness (optional)
print("Matrix multiplication matches:", torch.allclose(xl1_triton.float(), xl1_pytorch.float(), atol=5e-2, rtol=0))
print("Quantized values match:", torch.allclose(x1_quant_triton.float(), x1_quant_pytorch.float(), atol=5e-2, rtol=0))
print("Scales match:", torch.allclose(scales_triton.float(), scales_pytorch.float(), atol=5e-2, rtol=0))

Triton average time: 0.001160 seconds
PyTorch average time: 0.000438 seconds
Matrix multiplication matches: False
Quantized values match: False
Scales match: True


In [None]:
print(torch.mean(xl1_triton.float()))
print(torch.mean(xl1_triton.float()))
print(torch.mean(xl1_triton.float()-xl1_pytorch.float()))
print(torch.max(xl1_triton.float()-xl1_pytorch.float()))
xl1_triton.float()-xl1_pytorch.float()

In [None]:
# i=
x1_quant_pytorch.float()-x1_quant_triton.float()

In [None]:
scales_triton[-1]

In [None]:
even_Values = x1_quant_triton & 0x0F  # 0 , 2 , 4 , 8 ..
odd_Values = (x1_quant_triton >> 4) & 0x0F  # 1 , 2 , 3 , 4 ....
x1quant_raw=merge_even_odd(even_Values , odd_Values , m , k) 
x1quant_raw=torch.where(x1quant_raw<8 , x1quant_raw , x1quant_raw-16)
expanded_scale=expand_scale(scales_triton , m , k , 32)
x1quant=x1quant_raw * (expanded_scale/7.0  )
# print(expanded_scale[-1]/7.0)
i=9
j=15
print((x1quant[j][64*i:64*(i+1)] - x1[j][64*i:64*(i+1)])<scales_triton[j][i]/7)
print(x1_quant_pytorch.float()[j][32*i:32*(i+1)]-x1_quant_triton.float()[j][32*i:32*(i+1)])

In [None]:
x1_quant_pytorch.float()[j][64*i:64*(i+1)]

In [None]:
scales_pytorch[-1]

In [None]:
torch.max(x1[-1])

In [None]:
torch_output

In [None]:
triton_output

In [None]:
def merge_even_odd(even_Values, odd_Values, m, k):
    
    device = even_Values.device
    dtype = even_Values.dtype
    
    
    merged = torch.zeros((m, k), dtype=dtype, device=device)
    
    
    merged[:, 0::2] = even_Values  # Even indices
    
    merged[:, 1::2] = odd_Values   # Odd indices
    
    return merged

def expand_scale(x1_quant_scale, m, k, group_size):

    num_groups = x1_quant_scale.shape[1]  # k//(2*group_size)
    # Reshape to (m, num_groups, 1) and repeat across 2*group_size elements
    scale_expanded = x1_quant_scale.unsqueeze(-1).repeat(1, 1, 2 * group_size)
    # Flatten to (m, k), trimming excess if k is not perfectly divisible
    scale_expanded = scale_expanded.reshape(m, -1)[:, :k]
    return scale_expanded
    

even_Values = triton_quant & 0x0F  # 0 , 2 , 4 , 8 ..
odd_Values = (triton_quant >> 4) & 0x0F  # 1 , 2 , 3 , 4 ....
x1quant=merge_even_odd(even_Values , odd_Values , m , k) 
x1quant=torch.where(x1quant<8 , x1quant , x1quant-16)
expanded_scale=expand_scale(x1_quant_scale , m , k , 32)
x1quant=x1quant * (expanded_scale/7.0  )

In [None]:
x1quant[-1]

In [None]:
x1[-1]

In [None]:
triton_quant.shape , m , k



In [None]:
x1_quant.shape

In [None]:
x1

In [None]:
x1_quant_scale

In [None]:
print(torch.max(triton_output-torch_output))
print(torch.mean(triton_output-torch_output))

In [None]:
import random
for i in range(0,100):
    point1=random.randint(0,m-1)
    point2=random.randint(0,r-1)
    print(triton_output[point1][point2])
    print(torch_output[point1][point2])