In [1]:
import torch

In [1]:
import torch
import numpy as np

def gptq_quantize(tensor, num_bits=8):
    qmin = 0
    qmax = (1 << num_bits) - 1
    min_val, max_val = tensor.min(), tensor.max()

    # Scale and zero point
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = qmin - min_val / scale

    # Perform quantization
    quantized = torch.clamp(torch.round(tensor / scale + zero_point), qmin, qmax)
    dequantized = (quantized - zero_point) * scale

    return quantized, dequantized, scale, zero_point

def truncate_to_4bit(quantized_8bit):
    # Convert to integers for bit manipulation
    int_8bit = quantized_8bit.to(torch.int32)

    # Extract the 4 most significant bits using bitwise shift
    int_4bit = (int_8bit >> 4) & 0xF

    # Convert back to float for further calculations
    scale_4bit = 15 / 255.0  # Since we use the MSBs, rescale to 4-bit range
    dequantized_4bit = int_4bit * scale_4bit
    return int_4bit, dequantized_4bit

def calculate_error(original, reconstructed):
    return torch.mean((original - reconstructed) ** 2)

def main():
    # Step 1: Generate a random FP16 tensor
    tensor = torch.randn(10, dtype=torch.float16)
    print("Original FP16 Tensor:", tensor)

    # Step 2: 8-bit Quantization
    quantized_8bit, dequantized_8bit, scale_8bit, zp_8bit = gptq_quantize(tensor)
    print("8-bit Quantized Tensor:", quantized_8bit)
    print("8-bit Dequantized Tensor:", dequantized_8bit)

    error_8bit = calculate_error(tensor, dequantized_8bit)
    print("8-bit Quantization Error:", error_8bit)

    # Step 3: Convert to 4-bit using MSBs
    quantized_4bit, dequantized_4bit = truncate_to_4bit(quantized_8bit)
    print("4-bit Quantized Tensor:", quantized_4bit)
    print("4-bit Dequantized Tensor:", dequantized_4bit)

    error_4bit = calculate_error(tensor, dequantized_4bit)
    print("4-bit Quantization Error:", error_4bit)

if __name__ == "__main__":
    main()


Original FP16 Tensor: tensor([ 5.8154e-01, -9.8145e-01, -1.5293e+00,  6.4014e-01,  1.6346e-03,
         2.4646e-01, -3.2275e-01, -4.3237e-01, -2.0977e+00, -1.4443e+00],
       dtype=torch.float16)
8-bit Quantized Tensor: tensor([250., 104.,  53., 255., 196., 218., 165., 155.,   0.,  61.],
       dtype=torch.float16)
8-bit Dequantized Tensor: tensor([ 0.5864, -0.9810, -1.5283,  0.6401,  0.0067,  0.2429, -0.3262, -0.4333,
        -2.0977, -1.4424], dtype=torch.float16)
8-bit Quantization Error: tensor(7.9870e-06, dtype=torch.float16)
4-bit Quantized Tensor: tensor([15,  6,  3, 15, 12, 13, 10,  9,  0,  3], dtype=torch.int32)
4-bit Dequantized Tensor: tensor([0.8824, 0.3529, 0.1765, 0.8824, 0.7059, 0.7647, 0.5882, 0.5294, 0.0000,
        0.1765])
4-bit Quantization Error: tensor(1.4386)


In [3]:
import torch
import numpy as np

def gptq_quantize(tensor, num_bits=8):
    qmin = 0
    qmax = (1 << num_bits) - 1
    min_val, max_val = tensor.min(), tensor.max()

    # Scale and zero point
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = qmin - min_val / scale

    # Perform quantization
    quantized = torch.clamp(torch.round(tensor / scale + zero_point), qmin, qmax)
    dequantized = (quantized - zero_point) * scale

    return quantized, dequantized, scale, zero_point

def truncate_to_4bit(quantized_8bit):
    # Convert to integers for bit manipulation
    int_8bit = quantized_8bit.to(torch.int32)

    # Extract the 4 most significant bits using bitwise shift
    int_4bit = (int_8bit >> 4) & 0xF

    # Convert back to float for further calculations
    scale_4bit = 15 / 255.0  # Since we use the MSBs, rescale to 4-bit range
    dequantized_4bit = int_4bit * scale_4bit
    return int_4bit, dequantized_4bit

def quantize_4bit_direct(tensor):
    return gptq_quantize(tensor, num_bits=4)

def calculate_error(original, reconstructed):
    return torch.mean((original - reconstructed) ** 2)

def main():
    # Step 1: Generate a random FP16 tensor
    tensor = torch.randn(100, dtype=torch.float16)
    print("Original FP16 Tensor:", tensor)

    # Step 2: 8-bit Quantization
    quantized_8bit, dequantized_8bit, scale_8bit, zp_8bit = gptq_quantize(tensor)
    print("8-bit Quantized Tensor:", quantized_8bit)
    print("8-bit Dequantized Tensor:", dequantized_8bit)

    error_8bit = calculate_error(tensor, dequantized_8bit)
    print("8-bit Quantization Error:", error_8bit)

    # Step 3: Convert to 4-bit using MSBs
    quantized_4bit, dequantized_4bit = truncate_to_4bit(quantized_8bit)
    print("4-bit Quantized Tensor (Truncated):", quantized_4bit)
    print("4-bit Dequantized Tensor (Truncated):", dequantized_4bit)

    error_4bit_trunc = calculate_error(tensor, dequantized_4bit)
    print("4-bit Quantization Error (Truncated):", error_4bit_trunc)

    # Step 4: Direct 4-bit Quantization
    quantized_4bit_direct, dequantized_4bit_direct, scale_4bit_direct, zp_4bit_direct = quantize_4bit_direct(tensor)
    print("4-bit Quantized Tensor (Direct):", quantized_4bit_direct)
    print("4-bit Dequantized Tensor (Direct):", dequantized_4bit_direct)

    error_4bit_direct = calculate_error(tensor, dequantized_4bit_direct)
    print("4-bit Quantization Error (Direct):", error_4bit_direct)

if __name__ == "__main__":
    main()


Original FP16 Tensor: tensor([ 0.2395, -1.2275,  0.8218, -0.3242, -0.7002,  0.8267, -0.9517, -1.2881,
         1.6172, -0.4983, -0.5249,  0.6733, -1.3516,  1.8643, -0.3481,  0.4285,
         0.8452, -0.5103, -0.6406, -1.3027,  0.3516, -1.0391, -0.0948,  3.4883,
         0.0701,  0.3892,  0.2957, -0.3340, -0.8672, -0.3958,  0.4685,  0.5503,
        -1.2285,  0.2272,  0.6924,  0.2269,  0.7070,  1.2891,  1.6953,  0.1940,
        -0.3435, -0.5029,  1.0088,  0.0308, -1.0684, -0.1390,  1.3232, -0.9883,
        -1.5957,  1.6719, -0.1266,  0.1552, -0.4001,  0.4436, -1.4746,  1.1240,
        -0.3530,  1.0449,  0.1497, -0.8423,  0.5576,  0.2115,  0.8115, -1.2686,
        -0.3835, -1.5059,  0.2529, -0.9888, -1.2100, -0.2747, -0.1814, -0.1918,
         0.6865,  0.9927,  0.9966,  1.9365, -0.5586, -0.3181, -0.2612,  1.8379,
         1.8428, -1.5352,  0.1803,  1.5244,  0.8569, -0.6719, -0.2135,  1.0977,
        -0.5669, -0.3411, -0.0305, -2.2754,  0.3923,  0.2622, -0.7246,  0.9468,
         0.8062, -