# Verify Training vs Eval Quantization Equivalence

This notebook checks that models are quantized identically during training and evaluation.

In [1]:
import gc
import sys
sys.path.insert(0, '..')

import torch
from transformers import AutoTokenizer
from peft import PeftModel
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from config import TrainConfig, EvalConfig

  from .autonotebook import tqdm as notebook_tqdm


TMA benchmarks will be running without grid constant TMA descriptor.


In [2]:
# Build the training and eval configs with the same quantization type so the
# two code paths can be compared like for like.
QUANT_TYPE = "int4"
train_cfg = TrainConfig(quant_type=QUANT_TYPE)
eval_cfg = EvalConfig(quant_type=QUANT_TYPE)

## 1. Verify QAT configs are identical

In [3]:
train_qat_config = train_cfg.get_qat_config()
eval_qat_config = eval_cfg.get_qat_config()

# Print both configs with the same layout so they can be compared by eye.
for header, qat_config in (
    ("Training QAT config:", train_qat_config),
    ("\nEval QAT config:", eval_qat_config),
):
    print(header)
    print(f"  {qat_config}")

# The check is on the string representations of the two configs.
configs_identical = str(train_qat_config) == str(eval_qat_config)
assert configs_identical
print("\n✓ QAT configs match!")

Training QAT config:
  QATConfig(base_config=Int4WeightOnlyConfig(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False, zero_point_domain=<ZeroPointDomain.NONE: 3>, set_inductor_config=True, preserve_zero=None, int4_packing_format=<Int4PackingFormat.PLAIN: 'plain'>, int4_choose_qparams_algorithm=<Int4ChooseQParamsAlgorithm.TINYGEMM: 'tinygemm'>, version=2), activation_config=None, weight_config=None, step='prepare')

Eval QAT config:
  QATConfig(base_config=Int4WeightOnlyConfig(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False, zero_point_domain=<ZeroPointDomain.NONE: 3>, set_inductor_config=True, preserve_zero=None, int4_packing_format=<Int4PackingFormat.PLAIN: 'plain'>, int4_choose_qparams_algorithm=<Int4ChooseQParamsAlgorithm.TINYGEMM: 'tinygemm'>, version=2), activation_config=None, weight_config=None, step='prepare')

✓ QAT configs match!


## 2. Compare model weights after QAT prepare

In [4]:
# Load two fresh models, one through each config's load_model(), so the
# training and eval code paths each start from their own model instance.
model_train = train_cfg.load_model()
model_eval = eval_cfg.load_model()

# Apply each config's own QAT config to its model. quantize_ modifies the model
# it is given (the return value is not used); get_qat_config() is called again
# here rather than reusing previously created config objects.
quantize_(model_train, train_cfg.get_qat_config())
quantize_(model_eval, eval_cfg.get_qat_config())

Loading checkpoint shards:   0%|                                                                                                                                                         | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:  33%|████████████████████████████████████████████████▎                                                                                                | 1/3 [00:01<00:02,  1.11s/it]

Loading checkpoint shards:  67%|████████████████████████████████████████████████████████████████████████████████████████████████▋                                                | 2/3 [00:02<00:01,  1.12s/it]

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.32it/s]




Loading checkpoint shards:   0%|                                                                                                                                                         | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:  33%|████████████████████████████████████████████████▎                                                                                                | 1/3 [00:01<00:02,  1.08s/it]

Loading checkpoint shards:  67%|████████████████████████████████████████████████████████████████████████████████████████████████▋                                                | 2/3 [00:02<00:01,  1.10s/it]

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.35it/s]




In [5]:
# Compare all parameters
def compare_models(m1, m2, name1="model1", name2="model2"):
    """Compare two models parameter by parameter and print a short report.

    Parameters are paired in ``named_parameters()`` order. A pair counts as a
    mismatch when the shapes differ or when ``torch.equal`` is False. Parameters
    that exist in only one of the models are reported as mismatches too.

    Args:
        m1, m2: Modules exposing ``named_parameters()``.
        name1, name2: Labels used for the two models in the printed report.

    Returns:
        True if both models have the same number of parameters and every pair
        is exactly equal; False otherwise.

    Raises:
        AssertionError: If two parameters at the same position have different names.
    """
    params1 = list(m1.named_parameters())
    params2 = list(m2.named_parameters())

    mismatches = []  # (parameter name, human-readable description of the difference)
    # no_grad: the subtraction below is for reporting only and must not build an autograd graph.
    with torch.no_grad():
        for (n1, p1), (n2, p2) in zip(params1, params2):
            assert n1 == n2, f"Parameter name mismatch: {n1} vs {n2}"
            if p1.shape != p2.shape:
                # Subtracting tensors of different shapes would raise or silently broadcast.
                mismatches.append((n1, f"shape {tuple(p1.shape)} vs {tuple(p2.shape)}"))
            elif not torch.equal(p1, p2):
                max_diff = (p1 - p2).abs().max().item()
                mismatches.append((n1, f"max diff = {max_diff:.6e}"))

    # zip() stops at the shorter list, so leftover parameters must be reported explicitly.
    shared = min(len(params1), len(params2))
    for owner, params in ((name1, params1), (name2, params2)):
        for extra_name, _ in params[shared:]:
            mismatches.append((extra_name, f"only present in {owner}"))

    if mismatches:
        print(f"✗ {len(mismatches)} parameter mismatches between {name1} and {name2}:")
        # Only the first few entries are printed to keep the cell output readable.
        for name, detail in mismatches[:10]:
            print(f"  {name}: {detail}")
    else:
        print(f"✓ All parameters match between {name1} and {name2}")
    return len(mismatches) == 0

# Compare the two QAT-prepared models; the boolean result (True when every
# parameter matches) is shown as this cell's output.
compare_models(model_train, model_eval, "train_model", "eval_model")

✓ All parameters match between train_model and eval_model


True

## 3. Compare forward pass outputs

In [6]:
tokenizer = AutoTokenizer.from_pretrained(train_cfg.model_name)
test_input = tokenizer("The quick brown fox", return_tensors="pt").to(model_train.device)

# Run the same prompt through both models without tracking gradients.
with torch.no_grad():
    out_train = model_train(**test_input)
    out_eval = model_eval(**test_input)

# Element-wise absolute difference between the two logit tensors.
logit_diff = (out_train.logits - out_eval.logits).abs()
max_abs_diff = logit_diff.max()
mean_abs_diff = logit_diff.mean()
print(f"Logit diff: max={max_abs_diff:.6e}, mean={mean_abs_diff:.6e}")

outputs_match = max_abs_diff < 1e-5
print("✓ Forward pass outputs match!" if outputs_match else "✗ Forward pass outputs differ!")

Logit diff: max=0.000000e+00, mean=0.000000e+00
✓ Forward pass outputs match!


In [7]:
# Cleanup: drop both model references, then force a garbage-collection pass so
# that objects kept alive only by reference cycles are actually destroyed.
# empty_cache() can only release cached CUDA memory that is no longer occupied
# by a live tensor, so it has to run after the collection.
del model_train, model_eval
gc.collect()
torch.cuda.empty_cache()

## 4. Test full eval pipeline with a trained checkpoint

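Load a trained checkpoint the way eval does (QAT prepare → LoRA → merge) to obtain eval's reconstruction of the training student model (before final save). Note that the cell below only loads the model and reports a weight dtype; it does not yet run a comparison against the training student model.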

In [8]:
from pathlib import Path

# Point to a trained checkpoint
CHECKPOINT_PATH = "../dump/lmbda0_lr5e-6_beta0/checkpoint-1000"  # adjust as needed

if Path(CHECKPOINT_PATH).exists():
    # Simulate eval.py's loading (lines 260-269):
    # fresh model -> QAT prepare -> attach the LoRA checkpoint -> merge the adapter.
    model_eval_style = eval_cfg.load_model()
    quantize_(model_eval_style, eval_cfg.get_qat_config())
    model_eval_style = PeftModel.from_pretrained(model_eval_style, CHECKPOINT_PATH).merge_and_unload()

    print("Model loaded in eval style (QAT prepare → LoRA → merge)")
    print(f"First layer weight dtype: {model_eval_style.model.layers[0].self_attn.q_proj.weight.dtype}")
else:
    # Nothing to load; the rest of this section is skipped.
    print(f"Checkpoint not found at {CHECKPOINT_PATH}, skipping this test")

Loading checkpoint shards:   0%|                                                                                                                                                         | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:  33%|████████████████████████████████████████████████▎                                                                                                | 1/3 [00:01<00:02,  1.12s/it]

Loading checkpoint shards:  67%|████████████████████████████████████████████████████████████████████████████████████████████████▋                                                | 2/3 [00:02<00:01,  1.14s/it]

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.30it/s]




Model loaded in eval style (QAT prepare → LoRA → merge)
First layer weight dtype: torch.bfloat16


## 5. Check: Does convert step change outputs?

Compare QAT fake-quant vs real INT4 after convert.

In [9]:
import copy

# Fake-quant model: a fresh load followed by the training config's QAT config.
model_fake = train_cfg.load_model()
quantize_(model_fake, train_cfg.get_qat_config())

# Real int4 model: deep copy of the prepared model, then the QAT "convert" step.
model_real = copy.deepcopy(model_fake)
convert_config = QATConfig(train_cfg._get_torchao_config(), step="convert")
quantize_(model_real, convert_config)

# Report which module class backs the first layer's q_proj in each model.
fake_q_proj_type = type(model_fake.model.layers[0].self_attn.q_proj)
real_q_proj_type = type(model_real.model.layers[0].self_attn.q_proj)

print("Fake quant model (training):")
print(f"  q_proj type: {fake_q_proj_type}")

print("\nReal int4 model (after convert):")
print(f"  q_proj type: {real_q_proj_type}")

Loading checkpoint shards:   0%|                                                                                                                                                         | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:  33%|████████████████████████████████████████████████▎                                                                                                | 1/3 [00:01<00:02,  1.05s/it]

Loading checkpoint shards:  67%|████████████████████████████████████████████████████████████████████████████████████████████████▋                                                | 2/3 [00:02<00:01,  1.07s/it]

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.39it/s]




Fake quant model (training):
  q_proj type: <class 'torchao.quantization.qat.linear.FakeQuantizedLinear'>

Real int4 model (after convert):
  q_proj type: <class 'torch.nn.modules.linear.Linear'>


In [10]:
# Run one prompt through the fake-quant and the converted model and compare logits.
test_input = tokenizer("The quick brown fox", return_tensors="pt").to(model_fake.device)

with torch.no_grad():
    out_fake = model_fake(**test_input)
    out_real = model_real(**test_input)

logit_diff = (out_fake.logits - out_real.logits).abs()
max_abs_diff = logit_diff.max()
mean_abs_diff = logit_diff.mean()
print(f"Fake vs Real logit diff: max={max_abs_diff:.6e}, mean={mean_abs_diff:.6e}")

# Highest-scoring token id at each position, compared across the two models.
pred_fake = out_fake.logits.argmax(dim=-1)
pred_real = out_real.logits.argmax(dim=-1)
print(f"Predictions match: {torch.equal(pred_fake, pred_real)}")

Fake vs Real logit diff: max=3.437500e-01, mean=5.834961e-02
Predictions match: True


## Summary

If all checks pass:
- Training and eval use identical quantization configs
- Fresh models quantize identically
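- The `convert` step (fake → real int4) preserves the model's argmax predictions on the test prompt, although the logits themselves are not bit-identical (a small numerical difference is expected)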

If checks fail, investigate:
1. Different random seeds affecting initialization
2. QAT config differences between train/eval code paths
3. LoRA merge order affecting quantization