In [1]:
import os

# os.environ['TRITON_INTERPRET'] = '1' 

import triton
import triton.language as tl
import torch



@triton.jit
def tryfunc():
    x1 = torch.randn(160, dtype=torch.float32).cuda()  # Generate in float32 on GPU
    x1 = (x1 * 25.5 + 127.5).clamp(0, 255).to(torch.uint8).contiguous()  # Scale to [0, 255] and convert
    grid = lambda meta: (triton.cdiv(100, meta['bm'])  , )
    batchsize=16
    halfbatchsize=8
    print(x1)
    exp[grid](x1 , bm=batchsize , bmh=halfbatchsize)
@triton.jit
def exp(
    xl1_ptr,
    bm:tl.constexpr,bmh:tl.constexpr,
):
    # print(bm)
    offs_k = tl.arange(0, bm)
    offs_k=offs_k//2
    print(offs_k)
    off_k_half=tl.arange(0, bm)
    shifter = (off_k_half % 2) * 4
    b=tl.load(xl1_ptr +offs_k , offs_k<100 ,0)
    print(shifter[None, :])
    print(b)
    b = (b >> shifter[None:,]) & 0x0F 
    print(b)

In [2]:




@triton.jit
def dequant_vertical(
    tensor_ptr,
    bn,bk,
    stride,
    tensor_startpoint,
    m,n , 
    scale_ptr,
    scale_start,
    scale_group_size,
):
    offset_tensor = tensor_startpoint + tl.expand_dims(tl.arange(0,bn) , 0)+(stride)*(tl.expand_dims(tl.arange(0,bk) , 1)//2)
    offset_tensor_mask=tensor_startpoint+(n)*((tl.expand_dims(tl.arange(0,bk) , 1))//2+1)
    offset_tensor_mask=tl.where(offset_tensor_mask<m*n//2 ,offset_tensor_mask , m*n//2 )
    

    # print(offset_tensor , offset_tensor<offset_tensor_mask)
    offset_scale=scale_start + tl.expand_dims(tl.arange(0,bn) , 0)
    offset_scale_mask=scale_start + min(bn,n)
    
    tensor=tl.load(tensor_ptr+offset_tensor  ,offset_tensor<offset_tensor_mask , 0)
    scale=tl.load(scale_ptr+offset_scale  ,offset_scale<offset_scale_mask , 0)
    
    shifter=(0*tl.expand_dims(tl.arange(0,bn) , 0)+(tl.expand_dims(tl.arange(0,bk) , 1))%2)*4
    tensor = (tensor >> shifter) & 0x0F 
    tensor=tl.where(tensor<8 , tensor , tensor-16).to(tl.int8)
    # print("vertical" , tensor , scale)
    return tensor,scale


@triton.jit
def dequant_horizontal(
    tensor_ptr,
    bm , bk,
    stride,
    tensor_startpoint,
    m,n,
    scale_ptr,
    scale_start,
    scale_group_size
):
    offset_tensor=tensor_startpoint+(stride//2)*(tl.expand_dims(tl.arange(0,bm) , 1)) + tl.expand_dims(tl.arange(0,bk), 0)//2
    offset_tensor_mask=tensor_startpoint+stride//2*(tl.expand_dims(tl.arange(1,bm+1) , 1))
    offset_tensor_mask=tl.where(offset_tensor_mask<m*n//2 ,offset_tensor_mask , m*n//2 )
    
    offset_scale=scale_start+(tl.cdiv(stride,scale_group_size)*(tl.expand_dims(tl.arange(0,bm) , 1)))
    offset_scale_mask=scale_start+(tl.cdiv(stride,scale_group_size)*(tl.expand_dims(tl.arange(1,bm+1) , 1)))
    offset_scale_mask=tl.where(offset_scale_mask<m*tl.cdiv(stride,scale_group_size) ,offset_scale_mask , m*(tl.cdiv(stride,scale_group_size) ))
    
    # print(offset_scale , offset_scale_mask)
    tensor=tl.load(tensor_ptr+offset_tensor  ,offset_tensor<offset_tensor_mask , 0)
    scale=tl.load(scale_ptr+offset_scale  ,offset_scale<offset_scale_mask, 0)
    shift_offset=bk*(tl.expand_dims(tl.arange(0,bm) , 1)) + tl.expand_dims(tl.arange(0,bk), 0)
    shifter = (shift_offset % 2) * 4
    
    tensor = (tensor >> shifter) & 0x0F 
    tensor=tl.where(tensor<8 , tensor , tensor-16).to(tl.int8)
    # print("horizontal" ,  tensor , scale)
    return tensor , scale



@triton.jit
def grouped_matmul_k(
    xl1_ptr , x1_quant_ptr , x1_scale_ptr , w_quant_ptr , w_scale_ptr , l2_ptr , output_ptr,
    m,n,k ,r,
    bm: tl.constexpr, bn: tl.constexpr, bk: tl.constexpr, group_sz: tl.constexpr , quant_group_size: tl.constexpr
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    num_pid_m, num_pid_n = tl.num_programs(0), tl.num_programs(1)
    pid_m_new , pid_n_new=tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, group_sz)  # Weirdness: tl.swizzle2d doesn't work when simulating on CPU
    
    acc = tl.zeros((bm, bn), dtype=tl.float32)
    
    
    # x_ptr_offset=k*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1))+ tl.expand_dims(tl.arange(0,bk), 0)
    # x_ptr_mask=k*tl.expand_dims(pid_m_new*bm+tl.arange(1,bm+1),1)
    # x_ptr_mask=tl.where(x_ptr_mask<m*k , x_ptr_mask , m*k)
    # x_ptr_scale_offset=k*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1))+ tl.expand_dims(tl.arange(0,bk), 0)
    
    # w_ptr_offset=(pid_n_new*bn+tl.expand_dims(tl.arange(0,bn), 0))+tl.expand_dims(tl.arange(0,bk), 1)*n
    # w_mask=(k-1)*n+pid_n_new*bn+tl.expand_dims(tl.arange(1,bn+1), 0)
    # w_mask=tl.where(l1_mask<k*n , l1_mask , k*n)

    xl1_ptr_offset=r*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1))+ tl.expand_dims(tl.arange(0,bk), 0)
    x_l1_mask=r*tl.expand_dims(pid_m_new*bm+tl.arange(1,bm+1),1)
    x_l1_mask=tl.where(x_l1_mask<m*r , x_l1_mask , m*r)
    
    l2_offset=(pid_n_new*bn+tl.expand_dims(tl.arange(0,bn), 0))+tl.expand_dims(tl.arange(0,bk), 1)*n
    l2_mask=pid_n_new*bn+tl.expand_dims(tl.arange(1,bk+1), 1)*n
    l2_mask=tl.where(l2_mask<r*n , l2_mask , r*n)

    # print(l2_offset)
    startpoint_x=k//2*(pid_m_new*bm)
    startpoint_x_scale=pid_m_new*bm*tl.cdiv(k,quant_group_size)
    
    startpoint_w=pid_n_new*bn
    startpoint_w_scale=pid_n_new*bn
    
    acc = tl.zeros((bm, bn), dtype=tl.float32)
    zeros = tl.zeros((bm, bn), dtype=tl.float32)
    # print('pid_m_new' , pid_m_new,'pid_n_new',pid_n_new)
    # print("w_start" , startpoint_w )
    # print("startpoint_w_scale" , startpoint_w_scale )
    # print("x_start" , startpoint_x)
    for i in range ( 0 , tl.cdiv(k , bk)):
        
        dequant_x1 , scale_x1=dequant_horizontal(x1_quant_ptr ,  bm , bk , k ,startpoint_x , m , k ,x1_scale_ptr , startpoint_x_scale+(i*bk)//quant_group_size ,  quant_group_size)
        
        dequant_w , scale_w=dequant_vertical(w_quant_ptr ,  bk , bn , n ,startpoint_w , k , n , w_scale_ptr ,startpoint_w_scale+n*((i*bk)//quant_group_size) , quant_group_size)
    
        acc_x_w=  tl.dot(dequant_x1, dequant_w)
        # print(acc_x_w)
        acc_x_scale_w_scale = scale_x1*scale_w
        startpoint_x+=bk//2
        startpoint_w+=bk*n//2
        if i*bk<r:
            # print(xl1_ptr_offset , xl1_ptr_offset<x_l1_mask)
            xl1_loaded = tl.load(xl1_ptr+xl1_ptr_offset , xl1_ptr_offset<x_l1_mask , 0)
            l2_loaded = tl.load(l2_ptr+l2_offset , l2_offset<l2_mask , 0)
            # print(xl1_loaded , l2_loaded)
            acc_lora=  tl.dot(xl1_loaded, l2_loaded)
            acc+=acc_lora
            xl1_ptr_offset+=bk
            l2_offset+=n*bk
            l2_mask+=n*bk
            l2_mask=tl.where(l2_mask<r*n , l2_mask , r*n)
        acc+=tl.fma(acc_x_w , acc_x_scale_w_scale , zeros)
    
    answer_offset=n*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1))+ pid_n_new*bn+tl.expand_dims(tl.arange(0,bn), 0)
    answer_mask=n*(pid_m_new*bm + tl.expand_dims(tl.arange(0,bm), 1)+1)
    answer_mask=tl.where(answer_mask<m*n , answer_mask , m*n)
    tl.store(output_ptr+answer_offset , acc , answer_offset<answer_mask)    
    # acc=acc.to(tl.float16)
    

In [3]:
def cdiv(a,b):
    return (a+b-1) //b
def matmul(x1l1, x1Quant  ,x1scale , WQuant , Wscale , l2 ):
    m=x1l1.shape[0]
    n=l2.shape[1]
    k=2*x1Quant.shape[1]
    r=l2.shape[0]
    answer = torch.zeros((m, n), dtype=torch.float16).cuda().contiguous()
    batch_size=64
    quant_group_size=64
    
    grid = lambda meta: (triton.cdiv(m, meta['bm']),  triton.cdiv(n, meta['bn']))

    grouped_matmul_k[grid](
        xl1_ptr=x1l1 , x1_quant_ptr=x1Quant , x1_scale_ptr=x1scale , w_quant_ptr=WQuant , w_scale_ptr=Wscale , l2_ptr=l2 , 
    m=m,n=n,k=k ,r=r,
    bm= batch_size, bn= batch_size, bk= batch_size, group_sz= batch_size , quant_group_size= quant_group_size , output_ptr=answer
    )
    return answer



In [4]:
k=124
k& 0x0F,(k >> 4) & 0x0F

(12, 7)

In [5]:
m,k_half,n,r=10 , 4 , 4,4

x1_quant = torch.randn((m,k_half), dtype=torch.float16).cuda()  # Generate in float32 on GPU
x1_quant = (x1_quant * 25.5 + 127.5).clamp(0, 255).to(torch.uint8).contiguous()  # Scale to [0, 255] and convert
x1_scale = torch.randn((m,cdiv(2*k_half,64)), dtype=torch.float16).cuda().contiguous()  # Generate in float32 on GPU

w_quant = torch.randn((k_half , n), dtype=torch.float16).cuda()  # Generate in float32 on GPU
w_quant = (w_quant * 25.5 + 127.5).clamp(0, 255).to(torch.uint8).contiguous()  # Scale to [0, 255] and convert
w_scale = torch.randn((cdiv(2*k_half,64) , n), dtype=torch.float16).cuda().contiguous()  # Generate in float32 on GPU

x1l1 = torch.randn((m , r), dtype=torch.float16).cuda().contiguous()  # Generate in float32 on GPU
l2 = torch.randn((r , n), dtype=torch.float16).cuda().contiguous()  # Generate in float32 on GPU

matmul(x1l1 , x1_quant , x1_scale , w_quant , w_scale , l2 )

tensor([[   4.1484,    8.5312, -115.3125,  -15.0391],
        [   2.5039,   -2.8086,   -5.7070,   -6.1797],
        [   0.3042,   16.4219,  -58.0625,    4.8828],
        [   5.3203,  -35.6250,    5.3086,  -14.5234],
        [  -7.0742,   33.8750,   14.9609,   10.7109],
        [   3.2148,    1.0303,  -11.8828,   14.8047],
        [  14.2656,   67.1250,   -7.7500,   47.1250],
        [  11.6406,  -23.3594,  -68.6875,   -2.5684],
        [  -6.9727,   39.6562,   -5.5664,   15.5547],
        [   7.0273,   14.2109,    7.5547,   14.8594]], device='cuda:0',
       dtype=torch.float16)

In [6]:
def merge_even_odd(even_Values, odd_Values, m, k):
    
    device = even_Values.device
    dtype = even_Values.dtype
    
    
    merged = torch.zeros((m, k), dtype=dtype, device=device)
    
    
    merged[:, 0::2] = even_Values  # Even indices
    
    merged[:, 1::2] = odd_Values   # Odd indices
    
    return merged

def expand_scale(x1_quant_scale, m, k, group_size):

    num_groups = x1_quant_scale.shape[1]  # k//(2*group_size)
    # Reshape to (m, num_groups, 1) and repeat across 2*group_size elements
    scale_expanded = x1_quant_scale.unsqueeze(-1).repeat(1, 1, 2 * group_size)
    # Flatten to (m, k), trimming excess if k is not perfectly divisible
    scale_expanded = scale_expanded.reshape(m, -1)[:, :k]
    return scale_expanded
    

even_Values = x1_quant & 0x0F  # 0 , 2 , 4 , 8 ..
odd_Values = (x1_quant >> 4) & 0x0F  # 1 , 2 , 3 , 4 ....
x1quant=merge_even_odd(even_Values , odd_Values , m , 32) .to(torch.int8)
x1quant=torch.where(x1quant<8 , x1quant , x1quant-16)
expanded_scale=expand_scale(x1_scale , m , 32 , 32)
print(x1quant , expanded_scale)
# x1quant=x1quant * (expanded_scale  )

RuntimeError: The expanded size of the tensor (16) must match the existing size (4) at non-singleton dimension 1.  Target sizes: [10, 16].  Tensor sizes: [10, 4]

In [37]:
import torch
import random
import time
import torch

def dequantize_vertical_pytorch(tensor_quant, m, n, scale, quant_group_size):
    """
    Dequantizes a vertically tiled quantized tensor similar to dequant_vertical.
    tensor_quant: (k//2, n) quantized tensor (uint8, two 4-bit values per byte)
    scale: (k//quant_group_size, n) scaling factors
    """
    k = 2 * tensor_quant.shape[0]  # Full dimension after unpacking
    # Convert to int16 to prevent overflow during bit operations
    tensor_quant_int16 = tensor_quant.to(torch.int16)
    
    # Unpack all 4-bit values at once
    # Reshape to (k//2, n, 1) and expand to separate high and low bits
    low_bits = tensor_quant_int16 & 0x0F  # Lower 4 bits
    high_bits = (tensor_quant_int16 >> 4) & 0x0F  # Upper 4 bits
    
    # Stack and reshape: interleave low and high bits
    tensor = torch.stack([low_bits, high_bits], dim=1).reshape(k, n)
    
    # Convert to signed int8 (0-7 stay as is, 8-15 become -8 to -1)
    tensor = torch.where(tensor < 8, tensor, tensor - 16).to(torch.int8)
    
    # Apply scales
    scale_expanded = torch.repeat_interleave(scale, quant_group_size, dim=0)[:k]
    tensor = tensor.to(torch.float32) * scale_expanded
    
    return tensor

def dequantize_horizontal_pytorch(tensor_quant, m, k, scale, quant_group_size):
    """
    Dequantizes a horizontally tiled quantized tensor similar to dequant_horizontal.
    tensor_quant: (m, k//2) quantized tensor (uint8, two 4-bit values per byte)
    scale: (m, k//quant_group_size) scaling factors
    """
    # Convert to int16 to prevent overflow during bit operations
    tensor_quant_int16 = tensor_quant.to(torch.int16)
    
    # Unpack all 4-bit values at once
    low_bits = tensor_quant_int16 & 0x0F  # Lower 4 bits
    high_bits = (tensor_quant_int16 >> 4) & 0x0F  # Upper 4 bits
    
    # Stack and reshape: interleave low and high bits
    tensor = torch.stack([low_bits, high_bits], dim=2).reshape(m, k)
    
    # Convert to signed int8
    tensor = torch.where(tensor < 8, tensor, tensor - 16).to(torch.int8)
    
    # Apply scales
    scale_expanded = torch.repeat_interleave(scale, quant_group_size, dim=1)[:, :k]
    tensor = tensor.to(torch.float32) * scale_expanded
    
    return tensor

def matmul_pytorch(x1l1, x1Quant, x1scale, WQuant, Wscale, l2):
    """
    PyTorch equivalent of the grouped_matmul_k kernel.
    """
    m, r = x1l1.shape
    k = 2 * x1Quant.shape[1]  # Full k dimension after dequantization
    n = l2.shape[1]
    quant_group_size = 64  # Match your Triton's quant_group_size
    
    # Dequantize x1Quant (m x k//2) -> (m x k)
    x1_dequant = dequantize_horizontal_pytorch(x1Quant, m, k, x1scale, quant_group_size)
    # Dequantize WQuant (k//2 x n) -> (k x n)
    w_dequant = dequantize_vertical_pytorch(WQuant, k, n, Wscale, quant_group_size)

    # Perform matrix multiplication: x1 @ W
    result = x1_dequant @ w_dequant
    
    # Add LoRA contribution: xl1 @ l2
    # print(x1l1 , l2)
    lora = x1l1 @ l2
    result += lora
    
    return result

# Your Triton matmul function (unchanged)
def matmul_triton(x1l1, x1Quant, x1scale, WQuant, Wscale, l2):
    m = x1l1.shape[0]
    n = l2.shape[1]
    k = 2 * x1Quant.shape[1]
    r = l2.shape[0]
    answer = torch.zeros((m, n), dtype=torch.float16, device='cuda').contiguous()
    batch_size = 32
    quant_group_size = 64
    
    grid = lambda meta: (triton.cdiv(m, meta['bm']), triton.cdiv(n, meta['bn']))
    grouped_matmul_k[grid](
        xl1_ptr=x1l1, x1_quant_ptr=x1Quant, x1_scale_ptr=x1scale, w_quant_ptr=WQuant,
        w_scale_ptr=Wscale, l2_ptr=l2, output_ptr=answer,
        m=m, n=n, k=k, r=r,
        bm=batch_size, bn=batch_size, bk=batch_size, group_sz=batch_size,
        quant_group_size=quant_group_size
    )
    return answer

# Test and compare
def test_accuracy():
    # Generate sample inputs
    # m = random.randrange(50, 201, 2)  # Range: 50 to 200, step 2 (even numbers)
    # n = random.randrange(50, 201, 2)  # Range: 50 to 200, step 2
    # k = random.randrange(50, 201, 2)  # Range: 50 to 200, step 2
    # r = random.randrange(2, k, 2)     # Range: 2 to k-1, step 2, ensures r < k
    m,n,k,r=1000 , 100 , 1024 , 32
    print(f"Generated dimensions: m={m}, n={n}, k={k}, r={r}")
    quant_group_size = 64
    
    x1l1 = torch.randn(m, r, dtype=torch.float16, device='cuda').contiguous()
    x1Quant = torch.randint(0, 256, (m, k//2), dtype=torch.uint8, device='cuda').contiguous()  # 0-255 for uint8
    x1scale = torch.randn(m, cdiv(k,quant_group_size), dtype=torch.float32, device='cuda').contiguous()
    WQuant = torch.randint(0, 256, (k//2, n), dtype=torch.uint8, device='cuda').contiguous()
    Wscale = torch.randn(cdiv(k,quant_group_size), n, dtype=torch.float32, device='cuda').contiguous()
    l2 = torch.randn(r, n, dtype=torch.float16, device='cuda').contiguous()

    for _ in range(5):
        matmul_pytorch(x1l1, x1Quant, x1scale, WQuant, Wscale, l2)
        matmul_triton(x1l1, x1Quant, x1scale, WQuant, Wscale, l2)
    torch.cuda.synchronize()  # Ensure warm-up is complete

    # Time PyTorch
    pytorch_times = []
    for _ in range(10):  # Run 10 iterations for better averaging
        start = time.time()
        pytorch_result = matmul_pytorch(x1l1, x1Quant, x1scale, WQuant, Wscale, l2)
        torch.cuda.synchronize()  # Wait for GPU to finish
        end = time.time()
        pytorch_times.append(end - start)
    torch.cuda.synchronize()
    pytorch_avg_time = sum(pytorch_times) / len(pytorch_times)
    
    # Time Triton
    triton_times = []
    for _ in range(10):  # Run 10 iterations for better averaging
        start = time.time()
        triton_result = matmul_triton(x1l1, x1Quant, x1scale, WQuant, Wscale, l2)
        torch.cuda.synchronize()  # Wait for GPU to finish
        end = time.time()
        triton_times.append(end - start)
    torch.cuda.synchronize()
    triton_avg_time = sum(triton_times) / len(triton_times)
    
    # Print timing results
    print(f"PyTorch average time: {pytorch_avg_time:.6f} seconds")
    print(f"Triton average time: {triton_avg_time:.6f} seconds")
    speedup = pytorch_avg_time / triton_avg_time if triton_avg_time > 0 else float('inf')
    print(f"Speedup (PyTorch/Triton): {speedup:.2f}x")

    
    # Compute with PyTorch
    pytorch_result = matmul_pytorch(x1l1, x1Quant, x1scale, WQuant, Wscale, l2)
   
    # Compute with Triton
    triton_result = matmul_triton(x1l1, x1Quant, x1scale, WQuant, Wscale, l2)
   
    # Compare with relative tolerance
    diff = torch.abs(pytorch_result - triton_result.to(torch.float32))
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()
    
    # Compute relative difference
    magnitude = torch.max(torch.abs(pytorch_result), torch.abs(triton_result.to(torch.float32)))
    
    relative_diff = diff / (magnitude + 1e-8)  # Add small epsilon to avoid division by zero
    max_relative_diff = relative_diff.max().item()
    mean_relative_diff = relative_diff.mean().item()
    
    # Find the coordinates of the maximum absolute difference
    max_diff_index = torch.argmax(diff)  # Get the flat index of the max difference
    coords = torch.unravel_index(max_diff_index, diff.shape)  # Convert to coordinates
    max_diff_point = tuple(coord.item() for coord in coords)  # Convert to tuple of integers
    
    # Print results including the problematic point
    print(f"Max absolute difference: {max_diff}")
    print(f"Mean absolute difference: {mean_diff}")
    print(f"Max relative difference: {max_relative_diff}")
    print(f"Mean relative difference: {mean_relative_diff}")
    print(f"Max magnitude: {magnitude.max().item()}")
    print(f"Mean magnitude: {magnitude.mean().item()}")
    print(f"Point with max absolute difference: {max_diff_point}")
    print(f"PyTorch value at this point: {pytorch_result[max_diff_point].item()}")
    print(f"Triton value at this point: {triton_result[max_diff_point].item()}")
    # Tolerance settings
    abs_tolerance = 1e-2  # Absolute tolerance (small for near-zero values)
    rel_tolerance = 1e-3  # Relative tolerance (0.1% of the magnitude)
    
    # Check if differences are within tolerance
    # Element-wise check: acceptable if diff <= abs_tolerance + rel_tolerance * magnitude
    tolerance = abs_tolerance + rel_tolerance * magnitude
    within_tolerance = torch.all(diff <= tolerance)
    
    if within_tolerance:
        print("Results match within tolerance!")
    else:
        print("Results differ beyond tolerance.")
    
    return diff
if __name__ == "__main__":
    diff=test_accuracy()

Generated dimensions: m=1000, n=100, k=1024, r=32
PyTorch average time: 0.000530 seconds
Triton average time: 0.000212 seconds
Speedup (PyTorch/Triton): 2.50x
Max absolute difference: 1.892578125
Mean absolute difference: 0.09210830181837082
Max relative difference: 0.04272996261715889
Mean relative difference: 0.00018081805319525301
Max magnitude: 4421.56201171875
Mean magnitude: 524.3851928710938
Point with max absolute difference: (548, 66)
PyTorch value at this point: 4165.892578125
Triton value at this point: 4164.0
Results match within tolerance!


In [17]:
diff.max()

tensor(1.9985, device='cuda:0')

In [None]:
k=33
k& 0x0F,(k >> 4) & 0x0F

In [None]:
for i in range(0,90):
    print(i)
    print(torch.max(diff[i]))