diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index cb7c8d0481..ed782a7350 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -955,6 +955,15 @@ def linear_forward_8da4w( precision, ) + print(f"w_dq dequantized: {w_dq}") + q_tensors = torch.load("/home/jackzhxng/torchrepos/executorch/fake_quantized_and_original_weights.pt") + correct_dequantized = q_tensors['q_after_quant_dequant'].to(torch.float32) + torch.testing.assert_close(correct_dequantized, w_dq) + snr = 20 * torch.log10(torch.norm(w_dq, p=2) / torch.norm(w_dq - correct_dequantized, p=2) + assert snr.item() == 0 + print("Weights quantized properly") + exit() + # x = x.to(torch.float16) # w_dq = w_dq.to(torch.float16) c = torch.nn.functional.linear(x, w_dq)