In [None]:
from huggingface_hub import snapshot_download

PATH = "../QuEST-800M-INT4"
snapshot_download(repo_id="ISTA-DASLab/QuEST-800M-INT4", local_dir=PATH)

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

summary.json:   0%|          | 0.00/1.83k [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

main.pt:   0%|          | 0.00/12.3G [00:00<?, ?B/s]

'/nfs/scistore19/alistgrp/apanfero/QuEST/QuEST-800M-INT4'

In [2]:
import numpy as np
from scipy.stats import norm

def compute_mse(grid):
    """
    Compute the Mean Squared Error (MSE) for a given scaling parameter 'a' and number of quantization levels 'N'.
    """
    print(f"quant center: {grid}")
    q = [-np.inf] + [(grid[i] + grid[i+1]) / 2 for i in range(len(grid) - 1)] + [np.inf]  # Quantization boundaries
    print(f"q values: {q}")
    MSE = 0.0
    for i in range(len(grid)):
        left = q[i]
        right = q[i + 1]
        center = grid[i]

        # Probability of the interval
        P_i = norm.cdf(right) - norm.cdf(left)

        # First and second moments over the interval
        M1_i = norm.expect(lambda t: t, loc=0, scale=1, lb=left, ub=right)
        M2_i = norm.expect(lambda t: t**2, loc=0, scale=1, lb=left, ub=right)

        # MSE for the i-th interval
        E_i = M2_i - 2 * center * M1_i + center**2 * P_i
        MSE += E_i

    # Total MSE
    return MSE

def get_uniform_grid(a, N):
    return np.linspace(-a, a, N)

from scipy.optimize import minimize
from tqdm.auto import tqdm
bits = 4
N = 2**bits  # You can change this value as needed

# Objective function for minimization
def objective(a):
  return compute_mse(get_uniform_grid(a[0], N))

# Initial guess for 'a'
a0 = [2.0]

# Bounds for 'a' to ensure it's positive
bounds = [(0.1, 10.0)]

# Minimize the MSE
result = minimize(objective, a0, bounds=bounds, method='L-BFGS-B')

# Optimal scaling parameter and corresponding MSE
optimal_a = result.x[0]
minimum_mse = result.fun

print(f"bits: {bits}: optimal scaling parameter (a): {optimal_a}, Minimum MSE: {minimum_mse}")

quant center: [-2.         -1.73333333 -1.46666667 -1.2        -0.93333333 -0.66666667
 -0.4        -0.13333333  0.13333333  0.4         0.66666667  0.93333333
  1.2         1.46666667  1.73333333  2.        ]
q values: [-inf, np.float64(-1.8666666666666667), np.float64(-1.6), np.float64(-1.3333333333333335), np.float64(-1.0666666666666667), np.float64(-0.8), np.float64(-0.5333333333333333), np.float64(-0.2666666666666666), np.float64(0.0), np.float64(0.2666666666666666), np.float64(0.5333333333333332), np.float64(0.7999999999999998), np.float64(1.0666666666666667), np.float64(1.3333333333333335), np.float64(1.6), np.float64(1.8666666666666667), inf]
quant center: [-2.00000001 -1.73333334 -1.46666667 -1.20000001 -0.93333334 -0.66666667
 -0.4        -0.13333333  0.13333333  0.4         0.66666667  0.93333334
  1.20000001  1.46666667  1.73333334  2.00000001]
q values: [-inf, np.float64(-1.866666676), np.float64(-1.6000000079999999), np.float64(-1.3333333399999998), np.float64(-1.06666667

In [5]:
def get_fp4_grid(a:float=1):
    zeros = [+0, -0]
    normal = [sign * (1+m)/2 * 2**(e-1) for sign in [1, -1] for e in range(1,4) for m in range(1,3)]
    subnormal = [sign * (0+m) * 2**(-1) for m in range(1,2) for sign in [1, -1]]
    return a * np.array(sorted(zeros + normal + subnormal))

fp4_grid = get_fp4_grid()
print(fp4_grid)

# Objective function for minimization
def objective(a):
    return compute_mse(get_fp4_grid(a[0]))

# Initial guess for 'a'
a0 = [1.0]

# Bounds for 'a' to ensure it's positive
bounds = [(0.1, 10.0)]

# Minimize the MSE
result = minimize(objective, a0, bounds=bounds, method='L-BFGS-B')

# Optimal scaling parameter and corresponding MSE
optimal_a = result.x[0]
minimum_mse = result.fun

print(f"Optimal scaling parameter (a): {optimal_a}")
print(f"Minimum MSE: {minimum_mse}")

[-6.  -4.  -3.  -2.  -1.5 -1.  -0.5  0.   0.   0.5  1.   1.5  2.   3.
  4.   6. ]
quant center: [-6.  -4.  -3.  -2.  -1.5 -1.  -0.5  0.   0.   0.5  1.   1.5  2.   3.
  4.   6. ]
q values: [-inf, np.float64(-5.0), np.float64(-3.5), np.float64(-2.5), np.float64(-1.75), np.float64(-1.25), np.float64(-0.75), np.float64(-0.25), np.float64(0.0), np.float64(0.25), np.float64(0.75), np.float64(1.25), np.float64(1.75), np.float64(2.5), np.float64(3.5), np.float64(5.0), inf]
quant center: [-6.00000006 -4.00000004 -3.00000003 -2.00000002 -1.50000001 -1.00000001
 -0.5         0.          0.          0.5         1.00000001  1.50000001
  2.00000002  3.00000003  4.00000004  6.00000006]
q values: [-inf, np.float64(-5.00000005), np.float64(-3.500000035), np.float64(-2.500000025), np.float64(-1.7500000175), np.float64(-1.2500000125), np.float64(-0.7500000075), np.float64(-0.2500000025), np.float64(0.0), np.float64(0.2500000025), np.float64(0.7500000075), np.float64(1.2500000125), np.float64(1.7500000175

In [6]:
import numpy as np
from scipy.linalg import hadamard

def hadamard_transform_cpu(x):
    """
    纯 CPU 实现 Hadamard 变换（仅支持 2 的幂次维度）
    x: 形状为 [..., n] 的张量，n 必须是 2 的幂
    """
    n = x.shape[-1]
    assert (n & (n - 1)) == 0, f"维度 {n} 不是 2 的幂，Hadamard 矩阵仅支持 2^k 维度"

    H = hadamard(n, dtype=np.float32) / np.sqrt(n)  # 归一化
    return x @ H.T

In [7]:
import numpy as np

X = np.random.rand(32, 64)
print("矩阵 X 的形状:", X.shape)
xhat = hadamard_transform_cpu(X)
print("变换后的矩阵 xhat 的形状:", xhat.shape)

矩阵 X 的形状: (32, 64)
变换后的矩阵 xhat 的形状: (32, 64)


In [11]:
import torch
import numpy as np

X = np.random.rand(32, 64)
print("矩阵 X 的形状:", X.shape)
xhat = hadamard_transform_cpu(X)
print("变换后的矩阵 xhat 的形状:", xhat.shape)
# Convert numpy array xhat to torch.Tensor
x_had_torch = torch.from_numpy(xhat)

# Perform the calculation
std = torch.sqrt(torch.mean(x_had_torch**2, dim=-1, keepdim=True)) + 1e-8

print("x_had_torch 的形状:", x_had_torch.shape)

OPTIMAL_GAUSSIAN_SCALES = {
    1: 0.7978845587140913,
    1.585: 1.2240089519030855,
    2: 1.4935346200015913,
    3: 2.051068354131873,
    4: 2.513930578568423,
    5: 2.9160938834961225,
    6: 3.276597282593217,
    7: 3.6010497188221655,
    8: 3.884938678807525,
}

scale = OPTIMAL_GAUSSIAN_SCALES[4] * std
n_levels = 2 ** 4
step = 2 * scale / (n_levels - 1)
x_clip = torch.clamp(x_had_torch, -scale, scale)
xq = torch.round((x_clip + scale) / step)
xq_dequant = xq * step - scale
print(f"original xhat: {xhat}")
print(f"xq dequant: {xq_dequant}")

矩阵 X 的形状: (32, 64)
变换后的矩阵 xhat 的形状: (32, 64)
x_had_torch 的形状: torch.Size([32, 64])
original xhat: [[ 3.86878332  0.18963859 -0.07150455 ... -0.2114696  -0.04286694
   0.10988797]
 [ 4.00476104 -0.15742877 -0.41666549 ... -0.17256879  0.12637896
   0.05964679]
 [ 4.5204274   0.08438341 -0.21845222 ...  0.19240287  0.27501169
  -0.27515666]
 ...
 [ 4.29916205  0.15445018 -0.91097936 ...  0.01469263 -0.09612832
   0.16471343]
 [ 4.25988561  0.28805873  0.18535192 ...  0.29734724  0.13099063
  -0.31029949]
 [ 4.15420156 -0.26030808 -0.1824393  ...  0.1875418   0.08439724
   0.01126734]]
xq dequant: tensor([[ 1.4121,  0.2824, -0.0941,  ..., -0.2824, -0.0941,  0.0941],
        [ 1.4439, -0.0963, -0.4813,  ..., -0.0963,  0.0963,  0.0963],
        [ 1.5949,  0.1063, -0.3190,  ...,  0.1063,  0.3190, -0.3190],
        ...,
        [ 1.5430,  0.1029, -0.9258,  ...,  0.1029, -0.1029,  0.1029],
        [ 1.5289,  0.3058,  0.1019,  ...,  0.3058,  0.1019, -0.3058],
        [ 1.5109, -0.3022, -0.1007,

In [12]:
import torch
import numpy as np

# Ensure X is a torch tensor for comparison with xq_dequant
X_torch = torch.from_numpy(xhat)

# Calculate Mean Squared Error (MSE)
mse = torch.mean((X_torch - xq_dequant)**2)

# Calculate Mean Absolute Error (MAE)
mae = torch.mean(torch.abs(X_torch - xq_dequant))

print(f"MSE 误差: {mse.item()}")
print(f"MAE 误差: {mae.item()}")

MSE 误差: 0.10371601384252502
MAE 误差: 0.08587475992192264


In [None]:
import torch
from torch import nn
import torch.nn.functional as F

from fast_hadamard_transform import hadamard_transform

from models.quantization.base_linear import OPTIMAL_GAUSSIAN_SCALES, HadamardTrustQuantizer, HalfHadamardTrustQuantizer


def quantize_pack_hadamard_dense(x: torch.Tensor, quantizer: HadamardTrustQuantizer):
    assert quantizer.centered
    x_had = hadamard_transform(x.reshape(-1, 128), scale=2 ** (-7/2)).reshape(x.shape)

    std = torch.sqrt(torch.mean(x_had**2, dim=-1, keepdim=True)) + 1e-8
    scale = OPTIMAL_GAUSSIAN_SCALES[quantizer.bits] * std

    step = 2 * scale / (quantizer.n_levels - 1)
    x_clip = torch.clamp(x_had, -scale, scale)
    xq = torch.round((x_clip + scale) / step)

    assert xq.min() >= 0 and xq.max() < quantizer.n_levels
    return xq, scale, step
    # ^ note: xq is in rotated space!

def dequantize_dense(xq, scale, step):
    return xq * step - scale

weight = torch.rand(2, 128).cuda()
quantizer = HadamardTrustQuantizer(bits=4)
ref = quantizer(weight)
xq, scale, step = quantize_pack_hadamard_dense(weight, quantizer)
deq = dequantize_dense(xq, scale, step)

torch.testing.assert_close(hadamard_transform(ref.reshape(-1, 128), scale=2 ** (-7/2)).reshape(ref.shape), deq, rtol=1e-3, atol=1e-3)

In [None]:
from models.quantization.base_linear import QuantizedLinear

class Linear4bit(nn.Module):
    def __init__(self, quantizer_linear):
        super().__init__()

        assert isinstance(quantizer_linear.weight_quantizer, HadamardTrustQuantizer)
        assert isinstance(quantizer_linear.activation_quantizer, HadamardTrustQuantizer)

        self.activation_quantizer = quantizer_linear.activation_quantizer

        wq = dequantize_dense(*quantize_pack_hadamard_dense(quantizer_linear.weight, quantizer_linear.weight_quantizer))
        self.register_buffer("wq", wq)
        self.bias = quantizer_linear.bias

    def forward(self, x):
        x = dequantize_dense(*quantize_pack_hadamard_dense(x, self.activation_quantizer))
        return F.linear(x, self.wq, self.bias)


def replace_linears(model):
    for name, module in model.named_children():
        if isinstance(module, QuantizedLinear):
            model._modules[name] = Linear4bit(module)
        else:
            replace_linears(module)
    return model

In [None]:
class PseudoDdp(nn.Module):
    def __init__(self, model):
        super().__init__()
        self._orig_mod = nn.ModuleDict({
            "module": model,
        })

class PseudoLoader:
    def load_state_dict(self, *args, **kwargs):
        pass

model = PseudoDdp(get_model(DotDict(config['args'])))
model.load_state_dict(torch.load(f"{PATH}/main.pt"))
model = model.cuda()
model = model._orig_mod["module"]
model = replace_linears(model)

  model.load_state_dict(torch.load(f"{PATH}/main.pt"))


In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

In [None]:
def generate_text_greedily(model, tokenizer, prompt, max_length=50, device='cuda'):
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids, get_logits=True)
            logits = outputs['logits'][:, -1, :]

        next_token_id = torch.argmax(logits, dim=-1).unsqueeze(-1)
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

generated_text = generate_text_greedily(model, tokenizer, "Hi!", max_length=20)
print(generated_text)


Hi! Sign in to let us know how The Coffee House was?
by jennifer1


In [None]:
numel = 0
for name, param in model.named_buffers():
    numel += param.numel()
    # print(name, param.numel())

print(numel/1e6)

822.083584
