Original Notebook: [7_1_knowledge_distillation_Llama.ipynb](https://github.com/peremartra/Large-Language-Model-Notebooks-Course/blob/main/6-PRUNING/7_1_knowledge_distillation_Llama.ipynb) by [Pere Martra](https://www.linkedin.com/in/pere-martra/)

**Teacher Model:** [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)

**Student Model:** [oopere/Llama-3.2-1B-pruned-40pct](https://huggingface.co/oopere/Llama-3.2-1B-pruned-40pct)

Vast AI Environment: GPU A100

---
Example in a production environment:
- Teacher: a 10B-parameter LLM
- Student: a 1B-parameter LLM

# Set Up

## Install Libraries & Configure Variables

In [1]:
import sys

In [4]:
!{sys.executable} -m pip install datasets==3.2.0 transformers==4.47.1 torch --quiet


In [5]:
!{sys.executable} -m pip install torchvision --quiet


In [2]:
import torch
# AdamW is taken from torch.optim later on, so it is not imported from transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.nn import functional as F
from torch.utils.data import DataLoader

In [3]:
print(torch.__version__)

2.8.0+cu128


In [4]:
import torchvision
print(torchvision.__version__)

0.23.0+cu128


## Login to Hugging Face

In [5]:
!pip install huggingface_hub --quiet


In [6]:
from getpass import getpass
hf_token = getpass("Hugging Face: ")

Hugging Face:  ········


In [7]:
from huggingface_hub import login
login(token=hf_token)

## Download the Models
The teacher model will be the same model used as the base to create the pruned model we are going to train.

* **Teacher model:** `"meta-llama/Llama-3.2-1B"` (the base model)
* **Student Model:** `"oopere/pruned40-llama-3.2-1B"` (the pruned version of the base model)

In **some** scenarios the teacher model must be the **same** model used to create the pruned version.
- Imagine you have a model that works perfectly and has been trained with proprietary **company data**, thus containing specific knowledge. In this case, if the goal is to replicate the behavior of this model in a smaller one, it wouldn’t make sense to use a larger model that hasn’t been trained on the same data.

In [8]:
# Load teacher and student models and their tokenizers
teacher_model_name = "meta-llama/Llama-3.2-1B"
student_model_name = "oopere/pruned40-llama-3.2-1B"

In [9]:
# Initialize tokenizer
# The same tokenizer can be used for both models since they're both Llama-based
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
tokenizer.pad_token = tokenizer.eos_token

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

In [10]:
# Load model
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name)

config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

In [11]:
# Load model
student_model = AutoModelForCausalLM.from_pretrained(student_model_name)

config.json:   0%|          | 0.00/884 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.83G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/180 [00:00<?, ?B/s]
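
A quick way to confirm the size gap between the two models is to count their parameters. The helper below is a minimal sketch: `count_parameters` is a hypothetical name that does not appear in the original notebook, and it is demonstrated on a tiny stand-in layer so that it runs without downloading anything.

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Return the total number of parameters in a PyTorch module."""
    return sum(p.numel() for p in model.parameters())

# Tiny stand-in layer: 4 * 3 weights + 3 biases = 15 parameters
toy_layer = nn.Linear(4, 3)
print(count_parameters(toy_layer))  # 15
```

In the notebook, the same helper could be called as `count_parameters(teacher_model)` and `count_parameters(student_model)`. The module printouts further down show the MLP intermediate size dropping from 8192 (teacher) to 4916 (student), about 60% of the original width.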

# Load the Data


During pruning, the model inevitably lost some capabilities.
- In this case, the pruned model lost generation capabilities much more than comprehension capabilities.

The **dataset** to use **depends largely on** the results the Knowledge Distillation process is meant to achieve.
- In this case, `Lambada` is the benchmark that showed the most degradation, in both its standard and OpenAI versions.
  - This benchmark evaluates the model's ability to predict the last word of a text. However, these are not simple texts; the model must pay close attention since the last word needs to be inferred by considering the entire story, requiring understanding of broader context, coherence, and fluency.
- Other suitable alternatives could include `BookCorpus`.
- **Dataset used:** a small portion of the `ptb_text_only` dataset, because of both time and memory constraints.
  - It may not be the most suitable dataset, as it is intended (only) for the text generation task.

In [12]:
# Data Loading
dataset = load_dataset("ptb_text_only", "penn_treebank", split="train", trust_remote_code=True)
# Take a subset for faster training/testing
original_dataset = dataset
dataset = dataset.select(range(1000))

README.md: 0.00B [00:00, ?B/s]

ptb_text_only.py: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0.00/5.10M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/400k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/450k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/42068 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3761 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3370 [00:00<?, ? examples/s]

The `tokenize_function` is where the real preprocessing magic happens: it transforms raw text into a format that the models can understand.

The function expects a dictionary input with a key `sentence` that contains the text to be tokenized.
The text is processed using the tokenizer loaded previously, with the following parameters:
- `padding`="max_length": ensures all sequences have the **same length** by adding padding tokens
- `truncation`=True: cuts off sequences that are too long
- `max_length`=128: sets the maximum sequence length, suitable for Llama models
- `return_tensors`="pt": returns **PyTorch tensors** instead of lists

Then, the function prepares output:
- Creates **input_ids**: the numerical representation of input tokens
- Creates **labels**: in this case, identical to input_ids (clone) for language modeling
- Returns **attention masks** to indicate which tokens are **padding** vs. **real content**
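
To make the effect of `padding="max_length"` and `truncation=True` concrete, here is a plain-Python sketch of what happens to a single sequence. `pad_or_truncate` is a hypothetical helper, not a tokenizer API, and right-side padding is assumed; the real tokenizer may be configured to pad on the left.

```python
def pad_or_truncate(token_ids, max_length, pad_id):
    """Mimic padding="max_length" with truncation=True for one sequence."""
    kept = token_ids[:max_length]            # truncation: drop tokens past max_length
    n_pad = max_length - len(kept)           # number of padding tokens needed
    input_ids = kept + [pad_id] * n_pad
    attention_mask = [1] * len(kept) + [0] * n_pad  # 1 = real token, 0 = padding
    return input_ids, attention_mask

short_ids, short_mask = pad_or_truncate([11, 12, 13], max_length=5, pad_id=0)
print(short_ids)   # [11, 12, 13, 0, 0]
print(short_mask)  # [1, 1, 1, 0, 0]

long_ids, long_mask = pad_or_truncate([1, 2, 3, 4, 5, 6, 7], max_length=5, pad_id=0)
print(long_ids)    # [1, 2, 3, 4, 5]
print(long_mask)   # [1, 1, 1, 1, 1]
```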

In **knowledge distillation**, the aim is to transfer knowledge from a larger teacher model to a smaller student model. The **quality** of this process heavily depends on **how data is prepared**.

- The careful **padding** and **truncation** ensure that all sequences are properly formatted for both teacher and student models.
- The **attention masks** help models focus on relevant parts of the input
- The **consistent sequence length** (128 tokens) optimizes memory usage while maintaining enough context for learning

**Note:** We are setting up for language modeling specifically, which is why the labels are identical to the inputs. In language modeling, the task is to predict the next token given the previous tokens, so each input sequence serves as its own target.
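
One refinement worth knowing about, which the original notebook does not apply: when labels feed a cross-entropy loss, padding positions are commonly set to `-100` so that they are ignored. The sketch below assumes this convention (it is the default `ignore_index` of PyTorch's cross-entropy loss); `mask_padding_labels` is a hypothetical helper name.

```python
import torch

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_padding_labels(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Copy input_ids into labels and hide padding positions from the loss."""
    labels = input_ids.clone()
    labels[attention_mask == 0] = IGNORE_INDEX
    return labels

# Toy batch: id 2 plays the role of the pad/EOS token
input_ids = torch.tensor([[5, 6, 7, 2, 2]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
labels = mask_padding_labels(input_ids, attention_mask)
print(labels.tolist())  # [[5, 6, 7, -100, -100]]
```

Because the pad token here is the EOS token, the attention mask (rather than the token id) is used to locate padding.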

In [13]:
# Create a tokenization function
def tokenize_function(examples):
    # Tokenize with padding and truncation
    tokenized = tokenizer(
        examples["sentence"],
        padding="max_length",
        truncation=True,
        max_length=128,  # Adjusted for Llama models
        return_tensors="pt"
    )

    # Create input_ids and labels for language modeling
    input_ids = tokenized["input_ids"]
    labels = input_ids.clone() # create a copy of input ids

    return {
        "input_ids": input_ids,
        "attention_mask": tokenized["attention_mask"],
        "labels": labels
    }

Time to use the `tokenize_function` to tokenize the Dataset. This code uses the datasets `map` function, which is specially **designed for processing large datasets**.

Parameters:
- `batched`=True: Processes multiple examples in batches for efficiency.
- `batch_size`=32: Specifies the size of each batch during mapping. A **smaller batch size** ensures compatibility with **memory constraints**.
- `remove_columns`=dataset.column_names: Removes original columns after tokenization to **avoid redundancy** and reduce memory usage.
- `num_proc`=4: Enables **parallel processing** with four processes, speeding up the operation on large datasets. This parameter is optional and is not set in the cell below.
- `desc`="Processing examples": Displays a description in the progress bar for better clarity.
- `load_from_cache_file`=False: Disables caching to **ensure fresh processing** of the dataset, which is helpful during debugging.

The `tokenized_datasets` object contains the preprocessed data with *input_ids*, *attention_mask*, and *labels* - ready for use in model training!

In [14]:
# Process the dataset with progress bar
print("Tokenizing dataset...")
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    batch_size=32,  # Smaller batch size for mapping
    remove_columns=dataset.column_names,
    desc="Processing examples",
    load_from_cache_file=False  # Disable caching for debugging
)

Tokenizing dataset...


Processing examples:   0%|          | 0/1000 [00:00<?, ? examples/s]

The format has to be converted to `torch`, making it compatible with PyTorch functions.

In [15]:
# Convert to PyTorch format
tokenized_datasets.set_format("torch")

Set up how data will be fed into the models during training:

The `DataLoader` is a PyTorch utility that **efficiently handles batching and iteration** over the dataset.
- `batch_size=8`: Specifies the number of samples per batch.
  - A smaller batch size is used here due to the memory constraints of large models like Llama.
- `shuffle=True`: Randomizes the order of data samples in each epoch.
  - It improves the model’s **generalization** by reducing the likelihood of learning spurious patterns from data order.
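
As a small illustration of how batching works, the sketch below runs a `DataLoader` over a toy dataset of 20 fake "tokenized" examples. The toy tensors are assumptions made for the example; only the `DataLoader` arguments mirror the notebook.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in: 20 examples, each a sequence of 6 token ids
toy_dataset = TensorDataset(torch.arange(120).reshape(20, 6))
toy_loader = DataLoader(toy_dataset, batch_size=8, shuffle=True)

# Number of batches is ceil(num_examples / batch_size)
print(len(toy_loader))  # 3

# The last batch is smaller when the dataset size is not a multiple of batch_size
shapes = [tuple(batch[0].shape) for batch in toy_loader]
print(shapes)  # [(8, 6), (8, 6), (4, 6)]

# With the notebook's 1000 examples and batch_size=8:
print(math.ceil(1000 / 8))  # 125
```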



In [16]:
# Create DataLoader
dataloader = DataLoader(
    tokenized_datasets,
    batch_size=8,  # Reduced batch size due to model size
    shuffle=True
)

# Knowledge Distillation

Start by moving the models to a CUDA device (GPU), if available.

In [17]:
# Move models to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model.to(device)
student_model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=4916, bias=False)
          (up_proj): Linear(in_features=2048, out_features=4916, bias=False)
          (down_proj): Linear(in_features=4916, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm):

## Teacher: Evaluation Mode

**Crucial for knowledge distillation:**


Put the **teacher model** into **evaluation mode**, which:
- Disables dropout layers
- Freezes batch normalization statistics
- Ensures consistent outputs for the same inputs

**Why Important:**
- The teacher model should provide **stable, consistent predictions** to guide the student
- The teacher model is **not trained** anymore - it's only being used to generate "soft targets"
- Any **randomness** (like dropout) would make the **knowledge transfer less reliable**
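
The effect of evaluation mode can be seen on a toy module with dropout. This is only a sketch: `toy_teacher` is a stand-in, not the Llama teacher. The final loop, which freezes the parameters, is an optional extra that the original notebook does not perform.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_teacher = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.randn(4, 8)

# In eval mode dropout is a no-op, so repeated forward passes agree exactly
toy_teacher.eval()
with torch.no_grad():
    out_a = toy_teacher(x)
    out_b = toy_teacher(x)
print(torch.equal(out_a, out_b))  # True

# Optional: make it explicit that the teacher's weights are never updated
for param in toy_teacher.parameters():
    param.requires_grad_(False)
```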

In [18]:
# Set teacher model to evaluation mode
teacher_model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm):

## Student: Training Mode

The training process in knowledge distillation involves transferring knowledge from a larger teacher model to a smaller student model, with the idea that the student model mimics the behaviour of the teacher model.

### Optimizer & Training Loop

The optimizer updates the student model's parameters to minimize the loss, improving its ability to replicate the teacher's outputs.
- **AdamW** is the de facto **standard** optimizer for **transformer**-based models.

Hyperparameters' Role:
- `temperature`: Controls how "soft" the teacher's predictions are.
  - **Higher** temperature (2.0) **smooths out** the probability distributions.
- `alpha`: **Balances** the importance of **matching the teacher's predictions** versus **ground truth**.
- `accumulation_steps`: Allows for **larger effective batch** sizes without increasing memory usage.
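
The smoothing effect of `temperature` is easy to see on a toy logit vector. The logits below are made up for illustration; the only thing taken from the notebook is the value `temperature = 2.0`.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 2.0, 1.0])  # toy logits for a 3-token vocabulary

probs_t1 = F.softmax(logits / 1.0, dim=-1)  # temperature = 1 (no softening)
probs_t2 = F.softmax(logits / 2.0, dim=-1)  # temperature = 2 (as in the notebook)

def entropy(p: torch.Tensor) -> float:
    """Shannon entropy in nats; higher means a flatter distribution."""
    return float(-(p * p.log()).sum())

print(probs_t1, entropy(probs_t1))
print(probs_t2, entropy(probs_t2))
```

With the higher temperature, the top token keeps less probability mass and the entropy rises, so the student also sees how the teacher ranks the less likely tokens.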



In [20]:
# Define optimizer for student model
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-5)  # Reduced learning rate for Llama

# Training loop
num_epochs = 10
temperature = 2.0  # Increased temperature for Llama
alpha = 1  # Weight for soft loss

accumulation_steps = 4  # Gradient accumulation for larger effective batch size
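
Gradient accumulation relies on the fact that gradients add up in `.grad` across `backward()` calls until they are reset. The sketch below checks this on a toy linear model with a mean-squared-error loss (both are assumptions made for the example): four micro-batches, each with its loss divided by `accumulation_steps`, produce the same gradient as one full batch.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x, y = torch.randn(8, 3), torch.randn(8, 1)
model_full = nn.Linear(3, 1)
model_accum = copy.deepcopy(model_full)  # identical initial weights
accumulation_steps = 4

# One backward pass over the full batch of 8 examples
F.mse_loss(model_full(x), y).backward()

# Four backward passes over micro-batches of 2, each loss scaled by 1/4
for x_mb, y_mb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    loss = F.mse_loss(model_accum(x_mb), y_mb) / accumulation_steps
    loss.backward()  # gradients accumulate in .grad

grads_match = torch.allclose(model_full.weight.grad, model_accum.weight.grad, atol=1e-6)
print(grads_match)  # True
```

Note that when the number of batches is not a multiple of `accumulation_steps` (for example 125 batches with 4 steps), the final group contains fewer micro-batches while each loss is still divided by `accumulation_steps`, so that last update is proportionally smaller.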


**Note:**
In this case we do **pure knowledge distillation**:
- We are **not** training the model **from scratch**, but **only** training it **to mimic** the teacher’s behavior.
  - The `alpha` parameter is therefore **not used** in the training loop.
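
For reference, this is how `alpha` is typically used when the distillation loss is mixed with a ground-truth loss. It is a sketch, not the notebook's method: `distillation_loss` is a hypothetical function, the toy tensors have shape `(examples, vocab)`, and the `temperature ** 2` scaling is a common convention that the notebook's loop does not apply.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    """alpha * soft loss (KL to teacher) + (1 - alpha) * hard loss (cross-entropy)."""
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2  # assumed convention to keep gradient scale comparable
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard

torch.manual_seed(0)
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))

pure_kd = distillation_loss(student_logits, teacher_logits, labels, alpha=1.0)
mixed = distillation_loss(student_logits, teacher_logits, labels, alpha=0.5)
print(pure_kd.item(), mixed.item())
```

With `alpha=1.0` the cross-entropy term vanishes and only the teacher-matching term remains, which corresponds to the pure distillation setting described in the note. For sequence models, the `(batch, seq_len, vocab)` logits would first need to be flattened and aligned with the labels.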

In [21]:

for epoch in range(num_epochs):
    ### 1 - Model Preparation ###
    # initializes each training epoch,
    # putting the student model in training mode
    student_model.train()
    # Initializes total_loss to track the cumulative loss for the epoch.
    total_loss = 0

    for batch_idx, batch in enumerate(dataloader):
        ### 2 - Data processing ###
        # Moves the batch data to the appropriate device (CPU/GPU) for processing.
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        ### 3 - Teacher Model Inference ###
        # Disables gradient computation to save memory and speed up inference.
        with torch.no_grad():
            teacher_outputs = teacher_model(
                input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )
            # Applies temperature scaling to soften the teacher's predictions
            teacher_logits = teacher_outputs.logits / temperature

        ### 4 - Student Model Inference. ###
        # The student model generates logits for the same input data.
        student_outputs = student_model(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        student_logits = student_outputs.logits

        ### 5 - Compute loss ###
        # Converts logits to probabilities using softmax
        teacher_probs = F.softmax(teacher_logits, dim=-1)
        # Computes the KL Divergence between the teacher's probabilities and the student's log probabilities.
        # KL Divergence measures how well the student's predictions match the teacher's.
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        # Note: F.kl_div expects log-probabilities as its input (student) and,
        # by default, probabilities as its target (teacher); working in log space is more stable.
        loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        # The loss is divided by accumulation_steps to balance gradient updates across accumulated batches.
        loss = loss / accumulation_steps

        ### 6- Backward pass ###
        loss.backward()

        ### 7 - Optimization Gradient Accumulation ###
        # Accumulates gradients over multiple batches
        # Updates model parameters when enough gradients are accumulated
        # Resets gradients after update
        if ((batch_idx + 1) % accumulation_steps == 0) or (batch_idx + 1 == len(dataloader)):
            optimizer.step()
            optimizer.zero_grad()

        ### 8 - Loss Tracking ###
        # Scales the loss back up by multiplying it with accumulation_steps to reflect the actual batch contribution.
        total_loss += loss.item() * accumulation_steps

        # if (batch_idx + 1) % 100 == 0:
        #    print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}, Loss: {loss.item():.4f}")

    ### 9 - Epoch-Level Reporting
    # Computes the average loss for the epoch by dividing the total loss by the number of batches.
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")


Epoch 1/10, Average Loss: 30.1915
Epoch 2/10, Average Loss: 8.3191
Epoch 3/10, Average Loss: 6.3736
Epoch 4/10, Average Loss: 5.4338
Epoch 5/10, Average Loss: 4.8372
Epoch 6/10, Average Loss: 4.4554
Epoch 7/10, Average Loss: 4.1792
Epoch 8/10, Average Loss: 3.9438
Epoch 9/10, Average Loss: 3.7438
Epoch 10/10, Average Loss: 3.5864
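
A note on the scale of these numbers: with `reduction='batchmean'`, `F.kl_div` divides the summed divergence by the size of the first dimension only, so on `(batch, seq_len, vocab)` logits the reported value is a sum over all 128 positions, padding included. The sketch below shows one possible variant that averages over non-padding tokens instead. `masked_kd_loss` is a hypothetical helper that is not part of the original notebook, and its values would not be comparable with the losses printed above.

```python
import torch
import torch.nn.functional as F

def masked_kd_loss(student_logits, teacher_logits, attention_mask, temperature=2.0):
    """KL(teacher || student), averaged over non-padding token positions."""
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # Per-position KL: sum over the vocabulary dimension -> (batch, seq_len)
    per_token = F.kl_div(student_log_probs, teacher_probs, reduction="none").sum(dim=-1)
    mask = attention_mask.to(per_token.dtype)
    return (per_token * mask).sum() / mask.sum().clamp(min=1)

torch.manual_seed(0)
student_logits = torch.randn(2, 4, 10)  # toy (batch, seq_len, vocab)
teacher_logits = torch.randn(2, 4, 10)
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

loss = masked_kd_loss(student_logits, teacher_logits, attention_mask)
print(loss.item())
```

In the training loop, this could replace the `F.kl_div(..., reduction='batchmean')` call, passing the batch's `attention_mask`.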


# Store the Model
At the end of the training loop, we have a model that can be stored locally or uploaded to Hugging Face.

In [22]:
student_model_name = "pruned_distilgpt2_kd_gem"

In [23]:
# Save the fine-tuned student model
student_model.save_pretrained(student_model_name)
tokenizer.save_pretrained(student_model_name)

('pruned_distilgpt2_kd_gem/tokenizer_config.json',
 'pruned_distilgpt2_kd_gem/special_tokens_map.json',
 'pruned_distilgpt2_kd_gem/tokenizer.json')

In [24]:
student_model.push_to_hub(student_model_name,
                  private=False,
                  use_temp_dir=False)

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  ...distilgpt2_kd_gem/model.safetensors:   0%|          |  553kB / 3.66GB            

CommitInfo(commit_url='https://huggingface.co/Saralatifi/pruned_distilgpt2_kd_gem/commit/2272eb9de53b4addafddd337a2c6eb6750539069', commit_message='Upload LlamaForCausalLM', commit_description='', oid='2272eb9de53b4addafddd337a2c6eb6750539069', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Saralatifi/pruned_distilgpt2_kd_gem', endpoint='https://huggingface.co', repo_type='model', repo_id='Saralatifi/pruned_distilgpt2_kd_gem'), pr_revision=None, pr_num=None)

In [25]:
tokenizer.push_to_hub(student_model_name,
                      private=False,
                      use_temp_dir=False)

README.md: 0.00B [00:00, ?B/s]

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  ...ed_distilgpt2_kd_gem/tokenizer.json: 100%|##########| 17.2MB / 17.2MB            

CommitInfo(commit_url='https://huggingface.co/Saralatifi/pruned_distilgpt2_kd_gem/commit/e9a053a26acfab8d1cbe46cdcac04a3e4746f320', commit_message='Upload tokenizer', commit_description='', oid='e9a053a26acfab8d1cbe46cdcac04a3e4746f320', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Saralatifi/pruned_distilgpt2_kd_gem', endpoint='https://huggingface.co', repo_type='model', repo_id='Saralatifi/pruned_distilgpt2_kd_gem'), pr_revision=None, pr_num=None)

# When to use KD versus other forms of fine-tuning?
There are other efficient ways to introduce knowledge into a model, such as LoRA and QLoRA. Compared with **KD**, they serve **different purposes**.

**KD** helps us imitate a model: we might have a model that has already been fine-tuned with our data and has then gone through a pruning process. To **recover the lost capacity**, the best approach is to perform a KD process from the base model.
- In general, we could use **LoRA** or **QLoRA** to improve the model's performance, and we would benefit from the reduction in trainable weights that these two techniques bring.
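
To give a sense of the reduction in trainable weights, the arithmetic below compares a full weight matrix with its LoRA factors. The shape is the teacher's `up_proj` layer from the printout above; the rank of 8 is an assumption chosen for illustration, and `lora_trainable_params` is a hypothetical helper.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the two low-rank factors A (rank x d_in) and B (d_out x rank)."""
    return rank * (d_in + d_out)

d_in, d_out = 2048, 8192                 # teacher up_proj: Linear(2048 -> 8192)
full = d_in * d_out                      # weights updated by full fine-tuning
lora = lora_trainable_params(d_in, d_out, rank=8)

print(full)                   # 16777216
print(lora)                   # 81920
print(f"{lora / full:.2%}")   # 0.49%
```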