In [None]:
import torch
import torch.nn as nn
import torch.quantization as tq
from transformers import AutoModelForCausalLM, AutoTokenizer

# Step 1: Load a small causal LM (DistilGPT2 keeps the demo lightweight)
model_name = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Step 2: Build the QAT config from two fake-quantize factories
#   activations: unsigned 8-bit range, per-tensor affine, moving-average min/max observer
#   weights:     signed 8-bit range, per-tensor symmetric, plain min/max observer
activation_fake_quant = tq.FakeQuantize.with_args(
    observer=tq.MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
)
weight_fake_quant = tq.FakeQuantize.with_args(
    observer=tq.MinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
)
qat_config = tq.QConfig(activation=activation_fake_quant, weight=weight_fake_quant)

# Opt every embedding layer out: an explicit `qconfig = None` is set on each
# one before the model-wide config is assigned below.
for module in model.modules():
    if isinstance(module, nn.Embedding):
        module.qconfig = None

# Step 3: Assign the config on the root module and insert the fake-quant ops.
# The model is switched to training mode before prepare_qat is called.
model.qconfig = qat_config
model.train()
tq.prepare_qat(model, inplace=True)

# Step 4: Fine-tune the prepared model on a single sentence; the input ids
# double as the language-modeling labels.
inputs = tokenizer("Quantization Aware Training on LLMs!", return_tensors="pt")
labels = inputs["input_ids"]
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

for step in range(50):  # tiny fine-tuning loop
    outputs = model(**inputs, labels=labels)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if not step % 10:
        print(f"Step {step} | Loss: {loss.item()}")

# Step 5: Switch to eval mode and convert a copy of the model
model.eval()
qat_model = tq.convert(model, inplace=False)


In [None]:
 import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling
from datasets import Dataset
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Method 1: Using BitsAndBytes QAT (Recommended)
def qat_with_bitsandbytes():
    """Fine-tune DistilGPT2 loaded in 4-bit NF4 precision through BitsAndBytes.

    The model is loaded with a ``BitsAndBytesConfig`` (4-bit NF4, double
    quantization, float16 compute), trained for three epochs on three sample
    sentences with the Hugging Face ``Trainer``, and then used to generate a
    short continuation of a test prompt.

    Returns:
        tuple: ``(model, tokenizer)`` on success, or ``(None, None)`` when one
        of the imports at the top of the function raises ``ImportError``.
    """
    try:
        from transformers import BitsAndBytesConfig
        import bitsandbytes  # noqa: F401  (imported only to check that the package is installed)

        # 4-bit loading configuration
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )

        model_name = "distilgpt2"

        # Load model with quantization config
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16
        )

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            # Reuse EOS as the padding token when the tokenizer defines none.
            tokenizer.pad_token = tokenizer.eos_token

        # Prepare for training
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()

        # Sample training data
        texts = [
            "Quantization Aware Training helps reduce model size while maintaining performance.",
            "Large Language Models can be efficiently quantized using various techniques.",
            "QAT allows models to adapt to quantization during training process."
        ]

        def tokenize_function(examples):
            """Tokenize a batch of texts, padded/truncated to 128 tokens."""
            return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)

        dataset = Dataset.from_dict({"text": texts})
        tokenized_dataset = dataset.map(tokenize_function, batched=True)

        # Training arguments
        training_args = TrainingArguments(
            output_dir="./qat_model",
            overwrite_output_dir=True,
            num_train_epochs=3,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            warmup_steps=10,
            logging_steps=1,
            learning_rate=5e-5,
            fp16=True,
            save_strategy="no",
        )

        # mlm=False selects causal (next-token) language modeling.
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        # NOTE(review): some transformers releases may refuse to run Trainer on a
        # purely 4-bit model that has no trainable adapters attached — confirm
        # against the installed version.
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=data_collator,
        )

        print("Starting QAT with BitsAndBytes...")
        trainer.train()

        # Test the quantized model. Leave training mode first, as the other
        # methods in this notebook do before generating.
        model.eval()
        test_input = "Quantization makes models"
        # The model was placed with device_map="auto", so the prompt tensors
        # must be moved to whatever device the model ended up on.
        inputs = tokenizer(test_input, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=20, do_sample=True, temperature=0.7)
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            print(f"Generated text: {generated_text}")

        return model, tokenizer

    except ImportError:
        print("BitsAndBytes not available. Install with: pip install bitsandbytes")
        return None, None

In [None]:
# Method 2: Custom QAT Implementation
def custom_qat_llm():
    """Run a hand-written quantization-aware training demo on DistilGPT2.

    Every ``nn.Linear`` reachable through ``named_children`` is replaced by a
    ``QuantizedLinear`` that fake-quantizes its weight and its input to the
    int8 range in the forward pass. The model is then trained for three epochs
    on three sample sentences and used to generate a short continuation.

    NOTE(review): only modules that are instances of ``nn.Linear`` are
    replaced; if the checkpoint implements some projections with a different
    layer class, those layers stay in full precision — confirm which layers
    are actually swapped.

    Returns:
        tuple: ``(model, tokenizer)`` after training.
    """
    INT8_MIN, INT8_MAX = -128, 127
    SCALE_EPS = 1e-8  # lower bound that keeps scales away from zero

    class RoundSTE(torch.autograd.Function):
        """Round in the forward pass; pass the gradient through unchanged.

        ``torch.round`` has a zero gradient, which would stop any training
        signal from reaching the weights, the scales, or earlier layers.
        """

        @staticmethod
        def forward(ctx, tensor):
            return torch.round(tensor)

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    class QuantizedLinear(nn.Module):
        """Linear layer with learnable int8 fake quantization of weight and input."""

        def __init__(self, in_features, out_features, bias=True):
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features

            # Learnable quantization parameters
            self.weight_scale = nn.Parameter(torch.ones(1))
            self.weight_zero_point = nn.Parameter(torch.zeros(1))
            self.input_scale = nn.Parameter(torch.ones(1))
            self.input_zero_point = nn.Parameter(torch.zeros(1))
            # Set once the input scale has been initialized from real data.
            self.register_buffer("input_calibrated", torch.tensor(False))

            self.weight = nn.Parameter(torch.randn(out_features, in_features))
            if bias:
                self.bias = nn.Parameter(torch.randn(out_features))
            else:
                self.register_parameter('bias', None)

        @staticmethod
        def range_to_scale(tensor):
            """Return the scale mapping the tensor's largest magnitude to INT8_MAX."""
            return (tensor.detach().abs().max() / INT8_MAX).clamp_min(SCALE_EPS)

        def calibrate_weight(self):
            """Initialize the weight scale from the current weight values.

            A scale of 1.0 would round every weight with magnitude below 0.5
            to zero.
            """
            if self.weight.numel() > 0:
                with torch.no_grad():
                    self.weight_scale.copy_(self.range_to_scale(self.weight))

        def quantize_tensor(self, tensor, scale, zero_point):
            """Fake-quantize ``tensor``: quantize to int8 levels, then dequantize."""
            safe_scale = scale.abs().clamp_min(SCALE_EPS)  # avoid division by zero
            quantized = RoundSTE.apply(tensor / safe_scale + zero_point)
            quantized = torch.clamp(quantized, INT8_MIN, INT8_MAX)  # int8 range
            return (quantized - zero_point) * safe_scale

        def forward(self, x):
            # Initialize the input scale from the first batch this layer sees.
            if not self.input_calibrated and x.numel() > 0:
                with torch.no_grad():
                    self.input_scale.copy_(self.range_to_scale(x))
                    self.input_calibrated.fill_(True)

            # Quantize weights and inputs during forward pass
            q_weight = self.quantize_tensor(self.weight, self.weight_scale, self.weight_zero_point)
            q_input = self.quantize_tensor(x, self.input_scale, self.input_zero_point)
            return nn.functional.linear(q_input, q_weight, self.bias)

    def replace_linear_layers(model):
        """Recursively replace nn.Linear children with QuantizedLinear copies."""
        # Iterate over a snapshot because children are replaced inside the loop.
        for name, module in list(model.named_children()):
            if isinstance(module, nn.Linear):
                quant_layer = QuantizedLinear(module.in_features, module.out_features,
                                              module.bias is not None)
                quant_layer.weight.data = module.weight.data.clone()
                if module.bias is not None:
                    quant_layer.bias.data = module.bias.data.clone()
                quant_layer.calibrate_weight()
                setattr(model, name, quant_layer)
            else:
                replace_linear_layers(module)

    model_name = "distilgpt2"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Replace linear layers with quantized versions
    print("Replacing linear layers with quantized versions...")
    replace_linear_layers(model)

    # The optimizer is created after the replacement so that it tracks the
    # parameters of the new layers (including scales and zero points).
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    model.train()

    training_texts = [
        "Quantization Aware Training on LLMs reduces model size significantly.",
        "Custom QAT implementations allow fine control over quantization process.",
        "Fake quantization during training helps models adapt to quantization noise."
    ]

    print("Starting custom QAT training...")
    for epoch in range(3):
        total_loss = 0
        for text in training_texts:
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
            labels = inputs["input_ids"].clone()

            outputs = model(**inputs, labels=labels)
            loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        avg_loss = total_loss / len(training_texts)
        print(f"Epoch {epoch + 1} | Average Loss: {avg_loss:.4f}")

    # Test the model
    model.eval()
    test_input = "Quantization helps"
    inputs = tokenizer(test_input, return_tensors="pt")

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=15, do_sample=True, temperature=0.7)
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated text: {generated_text}")

    return model, tokenizer

In [None]:
# Method 3: Using PEFT with QLoRA (Most Practical)
def qat_with_peft():
    """Fine-tune DistilGPT2 with LoRA adapters via PEFT.

    The base model is loaded in float16 without any quantization config, LoRA
    adapters are attached to the ``c_attn`` and ``c_proj`` modules, and the
    model is trained for two epochs on three sample sentences before
    generating a short continuation of a test prompt. No quantization is
    applied in this function.

    Returns:
        tuple: ``(model, tokenizer)`` on success, or ``(None, None)`` when
        ``peft`` cannot be imported.
    """
    try:
        from peft import LoraConfig, get_peft_model, TaskType

        model_name = "distilgpt2"
        # NOTE(review): float16 weights are kept on the default device here;
        # confirm that half-precision training is supported in the target
        # environment.
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # LoRA configuration
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=16,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["c_attn", "c_proj"]  # DistilGPT2 specific
        )

        # Apply PEFT
        model = get_peft_model(model, lora_config)
        # print_trainable_parameters() prints its own report and returns None,
        # so it must be called directly rather than formatted into a string.
        model.print_trainable_parameters()

        # Simple training: hand the optimizer only the parameters that train.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)
        model.train()

        training_texts = [
            "PEFT with LoRA enables efficient quantization-aware training.",
            "QLoRA combines quantization with low-rank adaptation effectively.",
            "Parameter-efficient methods reduce memory usage during QAT."
        ]

        print("Starting PEFT QAT training...")
        for epoch in range(2):
            total_loss = 0
            for text in training_texts:
                inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
                labels = inputs["input_ids"].clone()

                outputs = model(**inputs, labels=labels)
                loss = outputs.loss

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                total_loss += loss.item()

            avg_loss = total_loss / len(training_texts)
            print(f"Epoch {epoch + 1} | Average Loss: {avg_loss:.4f}")

        # Test
        model.eval()
        test_input = "PEFT enables"
        inputs = tokenizer(test_input, return_tensors="pt")

        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=15, do_sample=True, temperature=0.7)
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            print(f"Generated text: {generated_text}")

        return model, tokenizer

    except ImportError:
        print("PEFT not available. Install with: pip install peft")
        return None, None


In [None]:
print("=== Quantization-Aware Training for LLMs ===\n")

# Method 1 (BitsAndBytes) is attempted first.
print("Method 1: BitsAndBytes QAT")
model1, tokenizer1 = qat_with_bitsandbytes()

# Methods 2 and 3 run only as a fallback, i.e. when Method 1 returned no model.
bitsandbytes_unavailable = model1 is None
if bitsandbytes_unavailable:
    print("\nMethod 2: Custom QAT Implementation")
    model2, tokenizer2 = custom_qat_llm()

    print("\nMethod 3: PEFT with QLoRA")
    model3, tokenizer3 = qat_with_peft()

print("\nQAT implementation completed!")

What is True Quantization-Aware Training (QAT)?

QAT simulates quantization during training with "fake quantization" modules: the forward pass mimics low-precision behavior, while gradient updates are still applied to the full-precision weights.

It trains both weights and quantizer parameters (like scale and zero-point).

Mostly used in image models, but can be adapted for LLMs or Transformers with custom setup.

It is typically implemented with PyTorch's torch.ao.quantization, or with toolkits such as Brevitas, Intel Neural Compressor, or NVIDIA's quantization toolkit.

❌ Why is it hard for LLMs?

Hugging Face Transformers doesn’t directly support QAT for LLMs yet.

There is no ready-made torch.ao.quantization.prepare_qat path for large nn.Module models such as GPT-2, LLaMA, or Falcon.

But, for small models like DistilGPT2 or BERT, we can demo it.

✅ Example: True QAT using PyTorch (Fake Quant + QConfig) on DistilGPT2

This is a simplified setup using torch.ao.quantization.

| QAT Method          | Models / Scale       | Key Benefit                                               |
| ------------------- | -------------------- | --------------------------------------------------------- |
| PyTorch QAT         | LLaMA‑3 8B           | Nearly PTQ-level accuracy with int4/int8 quantization     |
| LLM‑QAT (Data‑Free) | LLaMA 7B / 13B / 30B | Better low-bit performance without requiring real data    |
| EfficientQAT        | LLaMA‑2 up to 70B    | Full 2‑bit QAT on a single GPU with minimal accuracy loss |
| DL‑QAT              | LLaMA / LLaMA‑2 (7B) | Efficient LoRA-style updates, stronger quality at 3-bit   |


| Model Source / Project          | Provided Resources                                 |
| ------------------------------- | -------------------------------------------------- |
| **EfficientQAT** (OpenGVLab)    | Pre‑trained QAT LLaMA‑2 checkpoints (Hugging Face) |
| **Gemma 3** (Google)            | QAT‑trained Gemma 3 variants publicly released     |
| **LLM‑QAT** (Facebook Research) | Training code to produce your own QAT models       |


google/gemma-3-1b-it-qat-q4_0-gguf

google/gemma-3-4b-it-qat-q4_0-gguf

google/gemma-3-12b-it-qat-q4_0-gguf

google/gemma-3-27b-it-qat-q4_0-gguf

In [None]:
!pip install transformers torch --quiet


In [None]:
import torch
import torch.nn as nn
# fake_quantize_per_tensor_affine is imported from the top-level torch
# namespace; the other quantization helpers come from torch.ao.quantization.
from torch import fake_quantize_per_tensor_affine
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.ao.quantization import (
    QConfig,
    QuantWrapper,
    default_qat_qconfig,
    prepare_qat,
    convert,
)


def wrap_linear_for_qat(module, qconfig):
    """Recursively wrap every nn.Linear child of ``module`` in a QuantWrapper.

    Each wrapped layer gets ``qconfig`` assigned; all other modules are left
    without a qconfig. QuantWrapper surrounds the layer with quantize and
    dequantize stubs, so the layer can take and return float tensors after
    conversion.

    Args:
        module: Root module whose children are searched (modified in place).
        qconfig: Quantization config assigned to each wrapped linear layer.
    """
    # Iterate over a snapshot because children are replaced inside the loop.
    for child_name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            child.qconfig = qconfig
            wrapped = QuantWrapper(child)
            wrapped.qconfig = qconfig
            setattr(module, child_name, wrapped)
        else:
            wrap_linear_for_qat(child, qconfig)


# 1. Load a small model (DistilGPT2 for demo)
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    # padding=True below needs a pad token; reuse EOS as the other cells do.
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

# 2. Attach the QAT config to the linear layers only. Embeddings and every
#    other layer type receive no qconfig and therefore stay in float.
wrap_linear_for_qat(model, default_qat_qconfig)

# 3. Prepare model for QAT (inserts fake quant nodes); requires training mode
model.train()
qat_model = prepare_qat(model)

# 4. Training loop (demo on toy data)
text = "Quantization aware training is powerful."
inputs = tokenizer(text, return_tensors="pt", padding=True)

optimizer = torch.optim.Adam(qat_model.parameters(), lr=1e-4)

print("Starting QAT training...")
for epoch in range(3):
    outputs = qat_model(**inputs, labels=inputs["input_ids"])
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"Epoch {epoch+1} | Loss: {loss.item():.4f}")

# 5. Convert to real quantized model
qat_model.eval()
quantized_model = convert(qat_model)

# 6. Test inference
with torch.no_grad():
    output = quantized_model.generate(
        **inputs,
        max_new_tokens=10,
        pad_token_id=tokenizer.pad_token_id,
    )
    print("\nGenerated:", tokenizer.decode(output[0]))
