# Model Distillation for Pretrained LLMs

---

## **1. Introduction to Model Distillation**

**Model distillation** (also called *knowledge distillation*) is a model compression technique in which a smaller, simpler model (the *student*) is trained to replicate the behavior of a larger, often more accurate model (the *teacher*). Popularized by Hinton et al. (2015), building on earlier model-compression work by Buciluă et al. (2006), distillation has become a cornerstone of resource-efficient machine learning systems.

In traditional supervised learning, models learn from hard labels. Model distillation instead transfers knowledge through *soft targets*: the probability distributions obtained by applying a softmax to the teacher's logits, often after dividing them by a temperature. Training on these distributions teaches the student the correct answer and also the teacher's *relative confidence* in every possible output. That confidence carries rich information about the teacher's decision boundary.
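The following minimal pure-Python sketch makes the difference concrete. The three classes and the teacher logits are hypothetical values chosen for illustration. The one-hot label treats "dog" and "car" as equally wrong, while the teacher's soft targets rank them.

```python
import math


def softmax(logits):
    """Turn raw logits into probabilities (subtracting the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


classes = ["cat", "dog", "car"]
teacher_logits = [5.0, 3.0, -2.0]      # hypothetical teacher outputs for one image

hard_target = [1.0, 0.0, 0.0]          # one-hot label: "cat", and nothing else
soft_target = softmax(teacher_logits)  # also encodes "dog is far more plausible than car"

for name, hard, soft in zip(classes, hard_target, soft_target):
    print(f"{name}: hard={hard:.1f} soft={soft:.4f}")
```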

### Benefits of Model Distillation

* Reduces **inference latency** and **memory footprint**
* Makes deployment to **edge devices** feasible
* Allows **faster fine-tuning and inference** on downstream tasks
* Can improve **generalization** relative to training the same small model on hard labels alone, by transferring the teacher's inductive biases

---

## **2. Model Distillation for Pretrained Large Language Models (LLMs)**

Pretrained language models range from hundreds of millions of parameters (BERT) to tens or hundreds of billions (e.g., the larger GPT, LLaMA, and Falcon models). The largest ones require immense resources for inference and fine-tuning. Distillation makes it possible to produce smaller, faster models that retain much of the teacher's performance, enabling:

* Real-time inference in production
* Use in resource-constrained environments
* Democratization of powerful LLMs to broader audiences

### Types of Distillation in LLMs

| Type                  | Description                                                     |
| --------------------- | --------------------------------------------------------------- |
| **Logit-based**       | Student mimics teacher’s output probabilities (most common)     |
| **Feature-based**     | Student mimics hidden representations from intermediate layers  |
| **Task-specific**     | Student learns from teacher on a specific downstream task       |
| **Self-distillation** | Teacher and student are the same model at different checkpoints |
| **Multi-teacher**     | Combines knowledge from multiple teacher models                 |
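For the multi-teacher row, one simple way to combine teachers is a weighted average of their output distributions, which then serves as the soft target. The sketch below assumes that averaging scheme and uses hypothetical logits. Averaging logits instead, or weighting teachers by validation accuracy, would be equally plausible choices.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def multi_teacher_target(all_teacher_logits, weights=None):
    """Weighted average of several teachers' probability distributions."""
    n = len(all_teacher_logits)
    if weights is None:
        weights = [1.0 / n] * n  # default: trust every teacher equally
    if len(weights) != n or abs(sum(weights) - 1.0) > 1e-9:
        raise ValueError("need one weight per teacher, summing to 1")
    probs = [softmax(logits) for logits in all_teacher_logits]
    num_classes = len(probs[0])
    return [sum(w * p[k] for w, p in zip(weights, probs)) for k in range(num_classes)]


teacher_a = [2.0, 1.0, 0.0]  # hypothetical: prefers class 0
teacher_b = [0.0, 2.0, 1.0]  # hypothetical: prefers class 1
target = multi_teacher_target([teacher_a, teacher_b])
print([round(p, 3) for p in target])
```

Because the average is taken over probabilities, the combined target remains a valid distribution and can be plugged into the same KL-based loss used for a single teacher.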

---

## **3. Best Practices in LLM Distillation**

### a. **Pretraining vs. Finetuning Distillation**

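* **Pretraining Distillation**: Distilling on a large general-purpose corpus, typically the teacher's original pretraining corpus or a similar one. This is expensive and therefore less common, although DistilBERT was produced this way.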
* **Post-hoc or Finetuning Distillation**: Much more common; distills a finetuned teacher (e.g., on QA or summarization) to a smaller student.

### b. **Distillation Objective Functions**

* **Kullback–Leibler (KL) divergence**: The most common loss for matching the teacher's (temperature-softened) output distribution.
* **Cosine similarity / MSE**: Used for feature-based distillation, comparing hidden states or attention maps.
* **Combined Loss**: A weighted mix of a supervised loss (e.g., cross-entropy with the ground truth) and the distillation loss against the teacher's outputs.
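Each of these losses fits in a few lines of plain Python. This sketch rests on simplifying assumptions: a single example, logits as Python lists, natural logarithms, and `alpha` weighting the hard loss (the convention the lab below also follows).

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(label, student_logits):
    """Hard-label loss: -log p_student(label)."""
    return -math.log(softmax(student_logits)[label])


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) = sum_k p_k * log(p_k / q_k), with p = teacher and q = student."""
    return sum(pk * math.log(pk / max(qk, eps)) for pk, qk in zip(p, q) if pk > 0)


def combined_loss(label, student_logits, teacher_logits, alpha=0.5):
    """alpha * CE(ground truth) + (1 - alpha) * KL(teacher || student)."""
    hard = cross_entropy(label, student_logits)
    soft = kl_divergence(softmax(teacher_logits), softmax(student_logits))
    return alpha * hard + (1 - alpha) * soft


def mse(a, b):
    """Feature-based loss: mean squared error between two hidden vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def cosine_distance(a, b):
    """Feature-based loss: 1 - cosine similarity (0 when the vectors point the same way)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


print(combined_loss(0, student_logits=[1.0, 0.2, -0.5], teacher_logits=[2.0, 0.5, -1.0], alpha=0.5))
```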

### c. **Temperature Scaling**

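* Dividing the logits by a *temperature* T > 1 before the softmax flattens the teacher's distribution. This exposes the relative probabilities of the non-top outputs, which a near-one-hot distribution would hide. The student's logits are divided by the same T in the soft loss. That loss is commonly multiplied by T² (as in Hinton et al., 2015) so its gradients stay on the same scale as the hard-label loss.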
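Here is a small sketch of temperature scaling with hypothetical teacher logits. As T grows, the entropy of the distribution rises while the top class stays the same. `distillation_loss` applies the same T to both models and multiplies the KL term by T². Hinton et al. (2015) use that rescaling to keep gradient magnitudes comparable across temperatures.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """softmax(z / T): T > 1 flattens the distribution, T < 1 sharpens it."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    """Shannon entropy in nats; higher means a softer distribution."""
    return -sum(pk * math.log(pk) for pk in p if pk > 0)


def distillation_loss(teacher_logits, student_logits, T=2.0):
    """T**2 * KL(softmax(teacher / T) || softmax(student / T))."""
    p = softmax_with_temperature(teacher_logits, T)
    q = softmax_with_temperature(student_logits, T)
    kl = sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)
    return (T ** 2) * kl


teacher_logits = [6.0, 2.0, 1.0, -1.0]  # hypothetical
for T in (1.0, 2.0, 5.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T}: probs={[round(x, 3) for x in probs]}, entropy={entropy(probs):.3f}")
```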

### d. **Curriculum Learning**

* Begin training on easy examples, then progressively introduce harder ones to stabilize training.
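A curriculum needs a difficulty score. The sketch below assumes one hypothetical proxy: the entropy of the teacher's prediction, where a confident teacher means an easy example. It then grows the training pool in stages. Sequence length or the student's current loss would be alternative proxies.

```python
import math


def entropy(p):
    return -sum(pk * math.log(pk) for pk in p if pk > 0)


def curriculum_stages(examples, difficulty, num_stages=3):
    """Sort examples from easy to hard and return growing training pools.

    Stage k uses the easiest k / num_stages fraction of the data, so the final
    stage covers the full dataset.
    """
    ordered = sorted(examples, key=difficulty)
    stages = []
    for k in range(1, num_stages + 1):
        cutoff = math.ceil(len(ordered) * k / num_stages)
        stages.append(ordered[:cutoff])
    return stages


# Hypothetical examples: (id, teacher probability distribution)
examples = [
    ("a", [0.98, 0.01, 0.01]),  # confident teacher -> easy
    ("b", [0.40, 0.35, 0.25]),  # unsure teacher    -> hard
    ("c", [0.70, 0.20, 0.10]),
    ("d", [0.90, 0.05, 0.05]),
]
stages = curriculum_stages(examples, difficulty=lambda ex: entropy(ex[1]))
for i, pool in enumerate(stages, 1):
    print(f"stage {i}: {[ex[0] for ex in pool]}")
```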

### e. **Layer Mapping**

* Choose a strategy to align teacher and student layers. Common ones:

  * Match final outputs only
  * Match every Nth layer
  * Match all corresponding layers (if architectures align)
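Matching every Nth layer needs an explicit index map. This minimal sketch assumes 0-based layer indices and a teacher depth that is a multiple of the student depth. Each student layer maps to the last teacher layer of its block, which resembles the uniform mapping TinyBERT uses.

```python
def every_nth_mapping(num_teacher_layers, num_student_layers):
    """Map student layer i (0-based) to teacher layer (i + 1) * k - 1,
    where k = num_teacher_layers // num_student_layers."""
    if num_teacher_layers % num_student_layers != 0:
        raise ValueError("teacher depth must be a multiple of student depth")
    k = num_teacher_layers // num_student_layers
    return {i: (i + 1) * k - 1 for i in range(num_student_layers)}


print(every_nth_mapping(12, 4))  # {0: 2, 1: 5, 2: 8, 3: 11}
```

The resulting dictionary tells the feature loss which teacher hidden state each student hidden state should be compared with.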

---

## **4. Evaluating Distilled Models**

### a. **Performance Metrics**

* Compare student vs. teacher on:

  * **Accuracy / F1 / BLEU / ROUGE** (task-specific)
  * **Log-likelihood / perplexity** (language modeling)
  * **Win rate in pairwise evaluations** (for generation tasks)
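Two of these metrics reduce to one-liners. The sketch assumes per-token natural-log probabilities for perplexity. For win rate it assumes a hypothetical convention where a tie counts as half a win in pairwise judgments. The inputs are illustrative, not measured results.

```python
import math


def perplexity(token_log_probs):
    """exp of the mean negative log-likelihood per token (natural log)."""
    nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(nll)


def win_rate(judgments):
    """Fraction of pairwise comparisons the student wins; ties count as half.
    `judgments` holds "win", "loss", or "tie" from the student's perspective."""
    score = {"win": 1.0, "tie": 0.5, "loss": 0.0}
    return sum(score[j] for j in judgments) / len(judgments)


# Hypothetical per-token log-probabilities on the same held-out text
student_lp = [math.log(0.25)] * 8
teacher_lp = [math.log(0.5)] * 8
print(perplexity(student_lp), perplexity(teacher_lp))
print(win_rate(["win", "loss", "tie", "loss"]))
```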

### b. **Behavioral Parity Tests**

* Check if the student preserves qualitative behavior like:

  * Instruction-following
  * Ethical constraints
  * Biases (which you may want the student *not* to inherit)

### c. **Distillation Quality Indicators**

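* High KL divergence between teacher and student outputs on held-out data = poor distillation
* Low task accuracy despite similar logits = the student is faithfully copying a weak teacher, or the distillation data does not match the evaluation data; improve the teacher or the data rather than the student
* Similar output distributions and behavior = good distillation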
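The first indicator can be monitored with a short sketch like the one below. It reports the mean teacher-to-student KL divergence and how often the two models pick the same top output on a held-out batch. The function name `distillation_report` and the logits are hypothetical. In practice, the logits would come from running both models on the same held-out inputs.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for two strictly positive distributions."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q))


def distillation_report(teacher_batch, student_batch):
    """Mean KL(teacher || student) and top-1 agreement over a batch of logits."""
    kls, agree = [], 0
    for t_logits, s_logits in zip(teacher_batch, student_batch):
        p, q = softmax(t_logits), softmax(s_logits)
        kls.append(kl(p, q))
        agree += p.index(max(p)) == q.index(max(q))
    n = len(kls)
    return {"mean_kl": sum(kls) / n, "top1_agreement": agree / n}


teacher_batch = [[3.0, 1.0, 0.0], [0.0, 2.0, 1.0]]  # hypothetical
student_batch = [[2.5, 1.2, 0.1], [1.5, 0.5, 0.2]]  # disagrees on the second example
report = distillation_report(teacher_batch, student_batch)
print(report)
```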

---

## **5. Dataset Considerations**

### a. **Do You Need a Dataset?**

* **Yes**, especially for task-specific or fine-tuning distillation.
* **Unlabeled data** is often sufficient when using teacher outputs (i.e., pseudo-labels).
* Some self-distillation strategies or synthetic datasets can reduce the need for large-scale corpora.

### b. **Choosing the Right Dataset**

* **For general-purpose LLMs**:

  * Use a corpus with wide domain coverage (e.g., C4, OpenWebText)
* **For downstream tasks**:

  * Use the same dataset the teacher was fine-tuned on (e.g., SQuAD for QA)
* **For instruction tuning**:

  * Use datasets like Alpaca, Dolly, or FLAN-style prompts

### c. **Data Augmentation Techniques**

* Paraphrasing
* Prompt variation
* Synthetic data generation using the teacher
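Prompt variation can be as simple as crossing each seed question with several phrasings. This sketch assumes a hypothetical, hand-written `TEMPLATES` list. Each resulting prompt would then be sent to the teacher to produce a new training target.

```python
import itertools

# Hypothetical phrasings; a real project would write or generate many more
TEMPLATES = [
    "Answer the question: {q}",
    "Q: {q}\nA:",
    "Please respond concisely. {q}",
]


def prompt_variants(questions, templates=TEMPLATES):
    """Cross every question with every template to multiply the prompt set."""
    return [t.format(q=q) for q, t in itertools.product(questions, templates)]


variants = prompt_variants(["What is distillation?", "Why use a temperature?"])
print(len(variants))  # 6
```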

---

## **6. Resource Requirements**

| Component         | Requirement                                                         |
| ----------------- | ------------------------------------------------------------------- |
| **Teacher model** | GPU/TPU for inference, ideally batched                              |
| **Student model** | Moderate-sized GPU/TPU (can be trained with less memory)            |
| **Storage**       | Needed to cache teacher logits if they are precomputed offline      |
| **Compute**       | Depends on student size and dataset; less than full LLM pretraining |

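* Example: a rough estimate for distilling a 175B-parameter teacher (GPT-3 scale) into a 6B student. GPT-3's weights are not publicly released. In practice, the teacher would be an open-weight model of similar size, or distillation would be limited to generated outputs obtained through an API.

  * Teacher inference: 16-bit weights alone take about 175B × 2 bytes ≈ 350 GB. That means at least five 80 GB A100s before activations and the KV cache are counted, or fewer with quantization.
  * Student training: full fine-tuning with mixed-precision Adam needs roughly 16 bytes per parameter ≈ 96 GB before activations. That means two or more 80 GB A100s, unless optimizer states are sharded or offloaded (e.g., DeepSpeed ZeRO) or a parameter-efficient method such as LoRA is used.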
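The memory figures behind this kind of estimate come from simple arithmetic, sketched below under stated assumptions:

- 2 bytes per parameter for 16-bit weights.
- About 16 bytes per parameter for mixed-precision Adam training (fp16 weights and gradients plus fp32 master weights and two moment buffers).
- 80 GB GPUs.
- A hypothetical 100M-token distillation set with a 50,000-entry vocabulary for the logit-caching case.

Activations and the KV cache are ignored, so real requirements are higher.

```python
import math

GB = 1e9


def weights_gb(num_params, bytes_per_param=2):
    """Memory for the weights alone (2 bytes/param for fp16 or bf16)."""
    return num_params * bytes_per_param / GB


def adam_training_gb(num_params, bytes_per_param=16):
    """Rough mixed-precision Adam footprint: fp16 weights + fp16 grads (4 B) plus
    fp32 master weights, momentum and variance (12 B). Activations NOT included."""
    return num_params * bytes_per_param / GB


def min_gpus(required_gb, gpu_gb=80):
    """Lower bound on the number of GPUs needed just to hold `required_gb`."""
    return math.ceil(required_gb / gpu_gb)


def logit_storage_gb(num_tokens, vocab_size, bytes_per_value=2):
    """Disk needed to cache full-vocabulary teacher logits for every token."""
    return num_tokens * vocab_size * bytes_per_value / GB


teacher_gb = weights_gb(175e9)               # 350.0 GB of weights
student_gb = adam_training_gb(6e9)           # 96.0 GB of weights + optimizer state
logits_gb = logit_storage_gb(100e6, 50_000)  # 10000.0 GB if all logits are cached
print(teacher_gb, min_gpus(teacher_gb))      # 350.0 5
print(student_gb, min_gpus(student_gb))      # 96.0 2
print(logits_gb)
```

The last figure explains why teacher outputs are often truncated to the top-k logits per token, or computed on the fly, rather than cached in full.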

---

## **7. Success Stories in LLM Distillation**

* **DistilBERT** (Sanh et al., 2019): 40% smaller, 60% faster, 97% performance retention from BERT
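* **TinyBERT**: Two-stage Transformer distillation (general distillation during pretraining, then task-specific distillation) that matches embeddings, attention matrices, and hidden states layer by layer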
* **MiniLM**: Distilled with deep self-attention distillation, performs close to BERT with fewer parameters
* **DistilGPT2**: Smaller GPT2 variant used in real-time applications
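* **Alpaca & Vicuna**: LLaMA models fine-tuned on outputs of larger OpenAI models. Alpaca used 52K instructions generated with text-davinci-003, and Vicuna used user-shared ChatGPT conversations from ShareGPT. This is sequence-level distillation *into* LLaMA rather than *from* it.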
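* **Mistral 7B Instruct**: A strong compact instruction-tuned model, though its release describes fine-tuning on publicly available instruction datasets rather than distillation from a specific teacher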

---

## **8. Common Pitfalls**

### a. **Overfitting to Teacher**

* The student may memorize the teacher's outputs instead of generalizing. Remedy: include a ground-truth loss or use more diverse data.

### b. **Poor Dataset Coverage**

* If the student is trained on narrow or biased data, it may fail to generalize or may hallucinate.

### c. **Misaligned Architectures**

* Difficult to align hidden states if student and teacher use different architectures.
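When only the hidden sizes differ, a common workaround (used, for example, by TinyBERT) is a learnable linear projection. It maps student hidden states into the teacher's space before the feature loss is computed. The sketch below shows the forward computation only, with a fixed, hypothetical 3×2 matrix `W`. In real training, `W` is learned jointly with the student.

```python
def project(W, h):
    """Linear map from the student space (len(h)) into the teacher space (len(W))."""
    if any(len(row) != len(h) for row in W):
        raise ValueError("each row of W must match the student hidden size")
    return [sum(w_ij * h_j for w_ij, h_j in zip(row, h)) for row in W]


def feature_mse(student_hidden, teacher_hidden, W):
    """MSE between the projected student state and the teacher state."""
    projected = project(W, student_hidden)
    return sum((a - b) ** 2 for a, b in zip(projected, teacher_hidden)) / len(teacher_hidden)


W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical 3 x 2 projection (learned in practice)
student_h = [0.5, -1.0]                   # student hidden size 2
teacher_h = [0.5, -1.0, -0.5]             # teacher hidden size 3
print(feature_mse(student_h, teacher_h, W))  # 0.0
```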

### d. **Loss of Calibration**

* Distillation can worsen uncertainty estimation (e.g., confidence scores), affecting downstream trust.
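One way to check whether distillation hurt calibration is to compare the expected calibration error (ECE) of the teacher and the student on the same held-out set. Below is a minimal sketch of the common equal-width-binned estimator. The 10-bin default and the toy confidences are assumptions for illustration.

```python
def expected_calibration_error(confidences, correct, num_bins=10):
    """ECE = sum over bins of (bin size / n) * |accuracy(bin) - mean confidence(bin)|,
    using equal-width confidence bins over [0, 1]."""
    n = len(confidences)
    bins = [[] for _ in range(num_bins)]
    for c, ok in zip(confidences, correct):
        idx = min(int(c * num_bins), num_bins - 1)  # confidence 1.0 goes in the top bin
        bins[idx].append((c, ok))
    ece = 0.0
    for b in bins:
        if not b:
            continue
        acc = sum(ok for _, ok in b) / len(b)
        conf = sum(c for c, _ in b) / len(b)
        ece += len(b) / n * abs(acc - conf)
    return ece


# Hypothetical: an overconfident model says 95% but is right only half the time
print(expected_calibration_error([0.95] * 4, [1, 1, 0, 0]))
```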

### e. **Evaluation Gaps**

* Student may perform well on benchmarks but poorly in open-ended or out-of-distribution scenarios.

---

## **9. Tools and Libraries**

* **Hugging Face Transformers + Datasets**
* **PyTorch + PyTorch Lightning**
* **OpenDelta / LoRA**: Efficient fine-tuning strategies usable with distillation
* **DistillToolkit**: Community projects for distillation (e.g., FastDistill)
* **LMFlow**: Open-source toolkit for fine-tuning and inference of large foundation models, usable for training distilled students
* **DeepSpeed / Megatron-LM**: Efficient parallelism for large-scale distillation

---

## **10. Key Takeaways**

Model distillation is an essential tool for making LLMs practical and scalable. It enables smaller models to inherit the capabilities of their larger counterparts with significant savings in computational cost and latency. By carefully selecting the dataset, aligning training objectives, and monitoring the student’s performance, it’s possible to build powerful, distilled models that retain the essence of foundational LLMs.

---


# Lab: Distilling a Pretrained Language Model (BERT) using TensorFlow
### Dataset: SQuAD

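```python
import os

# Transformers' TF models need Keras 2: on TensorFlow >= 2.16, install tf-keras and set
# this flag before TensorFlow is imported (harmless on older versions).
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import numpy as np
import tensorflow as tf
from datasets import load_dataset
from transformers import AutoTokenizer, DefaultDataCollator, TFAutoModelForQuestionAnswering
```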

#### Limit memory usage

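```python
# Allocate GPU memory on demand instead of reserving it all up front
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)
# Cap CPU thread pools (must run before TensorFlow executes its first op)
tf.config.threading.set_intra_op_parallelism_threads(2)
tf.config.threading.set_inter_op_parallelism_threads(2)
```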

#### Load the SQuAD dataset

```python
squad = load_dataset("squad")

# Use a small subset to keep memory and runtime low
train_data = squad["train"].select(range(1000))
val_data = squad["validation"].select(range(200))
```

#### Load the teacher model and tokenizer

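```python
# The teacher must already be fine-tuned for QA: plain "bert-base-uncased" has a randomly
# initialized span-prediction head. This SQuAD checkpoint uses the same uncased vocabulary.
teacher_model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
teacher_model = TFAutoModelForQuestionAnswering.from_pretrained(teacher_model_name)
teacher_model.trainable = False  # the teacher stays frozen during distillation
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
```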

#### Preprocess the data

```python
def preprocess_function(examples):
    """Tokenize question/context pairs and convert each character-level answer
    span into token-level start/end positions."""
    questions = [q.strip() for q in examples["question"]]
    # Long contexts are split into overlapping 384-token windows (128-token stride)
    inputs = tokenizer(
        questions,
        examples["context"],
        truncation="only_second",
        max_length=384,
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # Maps each window back to the example it came from
    sample_mapping = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []
    for i, offset in enumerate(inputs["offset_mapping"]):
        input_ids = inputs["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)
        # 0 = question token, 1 = context token, None = special or padding token
        sequence_ids = inputs.sequence_ids(i)
        sample_index = sample_mapping[i]
        answer = answers[sample_index]
        if len(answer["answer_start"]) == 0:
            # Unanswerable: label both positions with the [CLS] token
            start_positions.append(cls_index)
            end_positions.append(cls_index)
        else:
            start_char = answer["answer_start"][0]
            end_char = start_char + len(answer["text"][0])
            # Find the first and last context tokens in this window
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1
            if not (offset[token_start_index][0] <= start_char and offset[token_end_index][1] >= end_char):
                # The answer is not fully inside this window: label it with [CLS]
                start_positions.append(cls_index)
                end_positions.append(cls_index)
            else:
                # Move inward until the tokens bracket the answer's characters
                while token_start_index < len(offset) and offset[token_start_index][0] <= start_char:
                    token_start_index += 1
                start_positions.append(token_start_index - 1)
                while offset[token_end_index][1] >= end_char:
                    token_end_index -= 1
                end_positions.append(token_end_index + 1)
    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

encoded_train = train_data.map(preprocess_function, batched=True, remove_columns=train_data.column_names)
encoded_val = val_data.map(preprocess_function, batched=True, remove_columns=val_data.column_names)

data_collator = DefaultDataCollator(return_tensors="tf")
```


#### Convert datasets to tf.data

```python
# BERT QA models use token_type_ids to separate question tokens from context tokens,
# so they are passed alongside input_ids and attention_mask.
feature_cols = ["input_ids", "token_type_ids", "attention_mask"]
label_cols = ["start_positions", "end_positions"]

train_tf_dataset = encoded_train.to_tf_dataset(
    columns=feature_cols,
    label_cols=label_cols,
    shuffle=True,
    batch_size=8,
    collate_fn=data_collator
)
val_tf_dataset = encoded_val.to_tf_dataset(
    columns=feature_cols,
    label_cols=label_cols,
    shuffle=False,
    batch_size=8,
    collate_fn=data_collator
)
```


#### Define the student model (smaller architecture)

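```python
from transformers import TFBertForQuestionAnswering

# 4-layer student with BERT-base width. Overriding num_hidden_layers keeps the pretrained
# embeddings and first four encoder layers of bert-base-uncased (the rest are skipped with
# a warning); only the QA head starts from random weights. A student trained fully from
# scratch on 1,000 examples would learn very little.
student_model = TFBertForQuestionAnswering.from_pretrained("bert-base-uncased", num_hidden_layers=4)
```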

#### Distillation loss (teacher-student)

```python
# Hard loss: cross-entropy between the student's logits and the ground-truth positions
loss_fn_hard = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Soft loss: KL(teacher || student). Keras' KLDivergence expects probability
# distributions for BOTH arguments (unlike PyTorch's KLDivLoss, which takes log-probs).
loss_fn_soft = tf.keras.losses.KLDivergence()


class DistillationModel(tf.keras.Model):
    def __init__(self, student, teacher, alpha=0.5, temperature=2.0):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.alpha = alpha              # weight on the hard (ground-truth) loss
        self.temperature = temperature  # T > 1 softens both distributions

    def compile(self, optimizer):
        super().compile(optimizer=optimizer)

    def call(self, inputs, training=False):
        # At inference time only the student is used
        return self.student(inputs, training=training)

    def soft_loss(self, teacher_logits, student_logits):
        T = self.temperature
        p_teacher = tf.nn.softmax(teacher_logits / T, axis=-1)
        p_student = tf.nn.softmax(student_logits / T, axis=-1)
        # Scale by T^2 to keep gradients comparable to the hard loss (Hinton et al., 2015)
        return loss_fn_soft(p_teacher, p_student) * (T ** 2)

    def train_step(self, data):
        x, y = data
        y_start = tf.reshape(y["start_positions"], [-1])
        y_end = tf.reshape(y["end_positions"], [-1])
        # The frozen teacher runs outside the tape, so no gradients are tracked for it
        teacher_outputs = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_outputs = self.student(x, training=True)
            s_start = student_outputs.start_logits
            s_end = student_outputs.end_logits

            loss_start = (self.alpha * loss_fn_hard(y_start, s_start)
                          + (1 - self.alpha) * self.soft_loss(teacher_outputs.start_logits, s_start))
            loss_end = (self.alpha * loss_fn_hard(y_end, s_end)
                        + (1 - self.alpha) * self.soft_loss(teacher_outputs.end_logits, s_end))
            loss = (loss_start + loss_end) / 2

        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))
        return {"loss": loss}

    def test_step(self, data):
        # Validation reports the student's hard-label loss only
        x, y = data
        student_outputs = self.student(x, training=False)
        y_start = tf.reshape(y["start_positions"], [-1])
        y_end = tf.reshape(y["end_positions"], [-1])
        loss_start = loss_fn_hard(y_start, student_outputs.start_logits)
        loss_end = loss_fn_hard(y_end, student_outputs.end_logits)
        loss = (loss_start + loss_end) / 2
        return {"loss": loss}
```

#### Train the distilled model

```python
# alpha=0.7 weights the ground-truth loss; the remaining 0.3 goes to the teacher's soft targets
distilled_model = DistillationModel(student_model, teacher_model, alpha=0.7)
distilled_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
)
distilled_model.fit(train_tf_dataset, epochs=2, validation_data=val_tf_dataset)
```

#### Evaluate and compare size and speed

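```python
import tempfile
import time


def time_forward(model, inputs, warmup=1, runs=5):
    """Average seconds per forward pass, measured after warm-up calls."""
    for _ in range(warmup):
        model(inputs, training=False)
    start = time.perf_counter()
    for _ in range(runs):
        model(inputs, training=False)
    return (time.perf_counter() - start) / runs


def dir_size_mb(path):
    """Total size of all files under `path`, in megabytes."""
    return sum(os.path.getsize(os.path.join(root, f)) for root, _, files in os.walk(path) for f in files) / 1e6


# Measure inference time on one validation batch (features only, labels dropped)
sample_inputs, _ = next(iter(val_tf_dataset))
print("Student inference time (batch, s):", time_forward(student_model, sample_inputs))
print("Teacher inference time (batch, s):", time_forward(teacher_model, sample_inputs))

# Model size comparison: save each model to its own directory and measure what was written
student_dir, teacher_dir = tempfile.mkdtemp(), tempfile.mkdtemp()
student_model.save_pretrained(student_dir)
teacher_model.save_pretrained(teacher_dir)
print("Student model size (MB):", dir_size_mb(student_dir))
print("Teacher model size (MB):", dir_size_mb(teacher_dir))
print("Student / teacher parameters:", student_model.count_params(), "/", teacher_model.count_params())
```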


```python
student_model.summary()
```

```python
teacher_model.summary()
```

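```python
# save_pretrained writes the config and weights in the Hugging Face format that from_pretrained reloads
student_model.save_pretrained("student_model")
tokenizer.save_pretrained("student_model")
teacher_model.save_pretrained("teacher_model")
```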