# Knowledge Distillation ⚗️🧑‍🏫👨‍🎓

In NLP, the quest for efficient yet powerful models is an ongoing challenge. This notebook explores task-specific knowledge distillation, a **compression** technique that enables us to transfer the knowledge of a large, specialized teacher model to a smaller, more efficient student model, much like a seasoned expert mentoring a promising apprentice.

In our case, the teacher is a `bert-base` model (110M params), fine-tuned for sentiment analysis on the IMDB dataset, where it has learned to capture the subtle nuances of positive and negative reviews. Our goal is to distill this knowledge into a more lightweight `distilbert-base` model (67M params, trimmed further to about 52.8M below), training it to mimic the teacher's outputs while maintaining strong performance.

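This approach reduces computational overhead, making it easier to deploy powerful NLP models in resource-constrained environments at the cost of a modest drop in accuracy (in the run below, the student keeps about 48% of the teacher's parameters and scores about 5 points lower on the IMDB test set).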

# Compressing the fine-tuned BERT into DistilBERT

Install dependencies and import libraries

In [None]:
!pip install -q -U transformers
!pip install -q -U datasets
!pip install -q -U accelerate
!pip install -q -U bitsandbytes

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DistilBertForSequenceClassification,
    get_scheduler,
    DistilBertConfig
)
from datasets import load_dataset, DatasetDict

## 1- Load IMDB Dataset

In [None]:
dataset = load_dataset("stanfordnlp/imdb")

In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

## 2- Load models

Load the teacher model: a BERT model fine-tuned for sentiment analysis of movie reviews (110M params).

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Load fine-tuned teacher model (BERT fine-tuned on IMDB)
teacher_model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb").to(device)

In [None]:
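# distilbert-base-uncased uses the same uncased WordPiece vocabulary as bert-base-uncased, so this tokenizer serves both teacher and student
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")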

I use the same sequence length (128) as the teacher model. If you use a different sequence length for distillation, the student model might struggle to align its learned representations with the teacher's, leading to suboptimal performance.

In [None]:
# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format("torch", columns=["input_ids", "attention_mask", "label"])
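
What `padding="max_length", truncation=True, max_length=128` does to each review can be sketched in plain Python. This is a simplified illustration, not the tokenizer's actual code: it assumes a review has already been split into WordPiece ids, and it uses the `bert-base-uncased` special-token ids ([CLS]=101, [SEP]=102, [PAD]=0). `pad_and_truncate` is a hypothetical helper.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased special-token ids

def pad_and_truncate(token_ids, max_length=128):
    """Hypothetical helper mimicking padding="max_length" + truncation=True for a single text."""
    # Reserve two positions for [CLS] and [SEP]; longer reviews lose their tail.
    body = token_ids[: max_length - 2]
    input_ids = [CLS_ID] + body + [SEP_ID]
    # Real tokens get mask 1, padding gets mask 0 so attention ignores it.
    attention_mask = [1] * len(input_ids)
    n_pad = max_length - len(input_ids)
    input_ids += [PAD_ID] * n_pad
    attention_mask += [0] * n_pad
    return input_ids, attention_mask

# A short review is padded up to 128, a long one is cut down to exactly 128.
short_ids, short_mask = pad_and_truncate([2023, 3185, 2001, 2307])  # arbitrary example ids
long_ids, long_mask = pad_and_truncate(list(range(1000, 1500)))
print(len(short_ids), sum(short_mask))  # 128 6
print(len(long_ids), sum(long_mask))    # 128 128
```

Every example therefore has the same shape, which is what lets `set_format("torch")` and the `DataLoader` stack batches directly; the cost is that anything past the first 126 WordPiece tokens of a review is never seen by either model.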

In [None]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 50000
    })
})

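Load the student model: DistilBERT, 67M params (12 attention heads and 6 layers). For further compression, I keep only the first 4 of its 6 layers, for about 52.8M params. I also set 8 attention heads instead of 12; this does not change the parameter count, because the 768-dimensional attention projections are simply split into 8 heads of 96 dimensions instead of 12 heads of 64, so the pretrained attention weights are regrouped rather than pruned.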

In [None]:
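# Load student model: distilbert-base-uncased with only 4 transformer layers (the checkpoint's layers 0-3 are loaded, layers 4-5 are dropped).
# n_heads=8 splits the 768-dim attention into 8 heads of 96 dims instead of 12 of 64; it does not remove any parameters.
# The other DistilBertConfig defaults (dim=768, hidden_dim=3072, vocab_size=30522) match distilbert-base-uncased.
dbert_config = DistilBertConfig(n_heads=8, n_layers=4, num_labels=2)
student_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", config=dbert_config).to(device)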

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

def count_trainable_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
print_trainable_parameters(teacher_model)

trainable params: 109483778 || all params: 109483778 || trainable%: 100.0


In [None]:
print_trainable_parameters(student_model)

trainable params: 52779266 || all params: 52779266 || trainable%: 100.0
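
Both printed counts can be reproduced by hand, which also confirms that the reduction comes entirely from dropping two layers: `n_heads` never enters the formula, because each attention projection is `dim × dim` however it is split into heads. The sketch below assumes the standard BERT/DistilBERT layout (768-dim hidden size, 3072-dim feed-forward, 30,522-token vocabulary, 512 positions) and a 2-label classification head; the function names are hypothetical.

```python
def encoder_layer_params(dim=768, hidden_dim=3072):
    attention = 4 * (dim * dim + dim)  # q, k, v and output projections (weights + biases)
    ffn = (dim * hidden_dim + hidden_dim) + (hidden_dim * dim + dim)  # two feed-forward linears
    layer_norms = 2 * 2 * dim  # two LayerNorms, each with weight and bias
    return attention + ffn + layer_norms

def distilbert_cls_params(n_layers, n_heads, dim=768, vocab=30522, max_pos=512, num_labels=2):
    assert dim % n_heads == 0, "dim must split evenly across heads"  # the only constraint on n_heads
    embeddings = vocab * dim + max_pos * dim + 2 * dim  # word + position embeddings + LayerNorm
    head = (dim * dim + dim) + (dim * num_labels + num_labels)  # pre_classifier + classifier
    return embeddings + n_layers * encoder_layer_params(dim) + head

def bert_cls_params(n_layers=12, dim=768, vocab=30522, max_pos=512, type_vocab=2, num_labels=2):
    embeddings = (vocab + max_pos + type_vocab) * dim + 2 * dim  # BERT adds token-type embeddings
    pooler = dim * dim + dim
    classifier = dim * num_labels + num_labels
    return embeddings + n_layers * encoder_layer_params(dim) + pooler + classifier

print(bert_cls_params())             # 109483778, the teacher
print(distilbert_cls_params(6, 12))  # 66955010, full DistilBERT with a classification head
print(distilbert_cls_params(4, 8))   # 52779266, the student used here
print(distilbert_cls_params(4, 12))  # 52779266, identical: changing n_heads removes nothing
```

Each encoder layer costs about 7.09M parameters while the embeddings alone cost about 23.8M, so removing more layers gives diminishing returns in relative size.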


## 3- Distillation with soft targets

In [None]:
# Create dataloaders
train_dataloader = DataLoader(tokenized_datasets["train"], batch_size=16, shuffle=True)
test_dataloader = DataLoader(tokenized_datasets["test"], batch_size=16)

In [None]:
# Define distillation loss
def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):

    # Soften the teacher's logits with temperature
    soft_teacher = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)
    soft_student = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1)

    # KL divergence loss (distillation loss)
    soft_loss = torch.nn.functional.kl_div(soft_student, soft_teacher, reduction="batchmean") * (temperature**2)

    # Cross-entropy loss (task-specific loss)
    hard_loss = torch.nn.functional.cross_entropy(student_logits, labels)

    # Combined loss
    combined_loss=alpha * soft_loss + (1 - alpha) * hard_loss

    return combined_loss
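
In formula form, the function computes $\mathcal{L} = \alpha\, T^2\, \mathrm{KL}\big(\mathrm{softmax}(z_t/T)\,\|\,\mathrm{softmax}(z_s/T)\big) + (1-\alpha)\,\mathrm{CE}\big(y, \mathrm{softmax}(z_s)\big)$, where $z_t$ and $z_s$ are the teacher and student logits. The $T^2$ factor, suggested by Hinton et al. (2015), compensates for the soft-target gradients shrinking roughly as $1/T^2$. The plain-Python sketch below recomputes this for a single example, without PyTorch, to show what temperature does; `softmax` and `kd_loss_single` are hypothetical helpers, and with a batch of one, `reduction="batchmean"` reduces to this per-example sum.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss_single(student_logits, teacher_logits, label, temperature=2.0, alpha=0.5):
    """Hypothetical single-example version of distillation_loss."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # KL(teacher || student) on the softened distributions, scaled by T^2
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_teacher, p_student))
    soft_loss = kl * temperature ** 2
    # Ordinary cross-entropy against the true label, at temperature 1
    hard_loss = -math.log(softmax(student_logits)[label])
    return alpha * soft_loss + (1 - alpha) * hard_loss

teacher = [3.0, -1.0]  # illustrative logits: the teacher strongly prefers class 0
print([round(p, 3) for p in softmax(teacher, 1.0)])  # [0.982, 0.018], close to one-hot
print([round(p, 3) for p in softmax(teacher, 2.0)])  # [0.881, 0.119], softer distribution
print(kd_loss_single([3.0, -1.0], teacher, label=0))  # student matches teacher: only the hard term remains
print(kd_loss_single([-1.0, 3.0], teacher, label=0))  # student disagrees: both terms grow
```

Raising $T$ exposes how much probability the teacher assigns to the non-predicted class, which is the extra signal the student learns from beyond the hard label.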

The teacher model was fine-tuned for 4 epochs, but distillation often needs fewer epochs, because the teacher's soft targets carry more information per example (how much probability it gives to each class) than hard labels alone. So I use 3 epochs.
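
With 25,000 training reviews and a batch size of 16, each epoch has 1,563 batches (the last one holds only 8 reviews), so 3 epochs give 4,689 optimizer steps, and the `"linear"` schedule with `num_warmup_steps=0` decays the learning rate from 5e-5 to 0 over those steps. A plain-Python sketch of that arithmetic follows; `linear_lr` is a hypothetical helper meant to mirror the warmup-then-linear-decay rule, not the transformers implementation itself.

```python
import math

num_examples, batch_size, num_epochs = 25_000, 16, 3
steps_per_epoch = math.ceil(num_examples / batch_size)  # len(train_dataloader), since drop_last is False
num_training_steps = steps_per_epoch * num_epochs

def linear_lr(step, base_lr=5e-5, num_warmup_steps=0, num_training_steps=num_training_steps):
    """Hypothetical helper: linear warmup to base_lr, then linear decay to 0."""
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    remaining = max(0, num_training_steps - step)
    return base_lr * remaining / max(1, num_training_steps - num_warmup_steps)

print(steps_per_epoch, num_training_steps)  # 1563 4689
for step in (0, steps_per_epoch, 2 * steps_per_epoch, num_training_steps):
    print(step, f"{linear_lr(step):.2e}")
```

Because the schedule is tied to the total step count, the learning rate at the start of epochs 2 and 3 is two thirds and one third of 5e-5; changing `num_epochs` stretches or compresses the whole decay rather than just adding steps at the end.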

In [None]:
# Training loop
num_epochs = 3

# Define optimizer and learning rate scheduler
optimizer = AdamW(student_model.parameters(), lr=5e-5, weight_decay=0.01)  # decoupled weight decay (strength 0.01): shrinks weights toward zero each step; with Adam this is not equivalent to an L2 penalty in the loss
num_training_steps = len(train_dataloader) * num_epochs
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
max_grad_norm=1.0

for epoch in range(num_epochs):

    student_model.train()
    teacher_model.eval()

    total_loss = 0
    total_teacher_correct=0
    total_student_correct=0
    total_samples = 0

    #for batch in train_dataloader:
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1} [Train]", unit="batch", leave= False):

        # Move batch to device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        # Disable gradient calculation for teacher model
        with torch.no_grad():
            # Forward pass through teacher model
            teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

            # Eval on the fly per batch : Compute teacher accuracy
            teacher_preds = torch.argmax(teacher_logits, dim=-1)
            teacher_correct = (teacher_preds == labels).sum().item()
            total_teacher_correct += teacher_correct

        # Forward pass through student model
        student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits

        #Eval on the fly: Compute student accuracy
        student_preds = torch.argmax(student_logits, dim=-1)
        student_correct = (student_preds == labels).sum().item()
        total_student_correct += student_correct

        # Compute distillation loss
        loss = distillation_loss(student_logits, teacher_logits, labels)

        # Backpropagation with gradient clipping
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_grad_norm)
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        # number of samples in batch
        total_samples += labels.size(0)

    # Calculate the accuracy for this epoch
    teacher_accuracy = 100 * total_teacher_correct / total_samples
    student_accuracy = 100 * total_student_correct / total_samples

    #print(f"Epoch {epoch + 1} - Average Train Loss: {total_loss / len(train_dataloader):.4f}, "
    #f"Teacher Train Accuracy: {teacher_accuracy:.2f}%,"
    #f"Student Train Accuracy: {student_accuracy:.2f}%")


    # Set models to evaluation mode
    student_model.eval()
    teacher_model.eval()

    # Validation loop
    test_loss = 0
    test_teacher_correct = 0
    test_student_correct = 0
    test_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc=f"Epoch {epoch + 1} [Eval]", unit="batch", leave= False):

            # Move batch to device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            # Forward pass through teacher model
            teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

            # Forward pass through student model
            student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits

            # Compute test loss
            loss = distillation_loss(student_logits, teacher_logits, labels)
            test_loss += loss.item()

            # Compute teacher accuracy
            teacher_preds = torch.argmax(teacher_logits, dim=-1)
            test_teacher_correct += (teacher_preds == labels).sum().item()

            # Compute student accuracy
            student_preds = torch.argmax(student_logits, dim=-1)
            test_student_correct += (student_preds == labels).sum().item()

            test_samples += labels.size(0)

    # Calculate val accuracy
    test_teacher_accuracy = 100 * test_teacher_correct / test_samples
    test_student_accuracy = 100 * test_student_correct / test_samples

    # Print epoch metrics
    print(
    f"Epoch {epoch + 1} - "
    f"dstl_loss={total_loss / len(train_dataloader):.4f}, val_dstl_loss={test_loss / len(test_dataloader):.4f} |"
    f"tch_acc={teacher_accuracy:.2f}%, val_tch_acc={test_teacher_accuracy:.2f}% |"
    f"std_acc={student_accuracy:.2f}%, val_std_acc={test_student_accuracy:.2f}%",
    flush=True
    )

Epoch 1 - dstl_loss=0.7593, val_dstl_loss=0.6479 |tch_acc=99.89%, val_tch_acc=89.09% |std_acc=85.18%, val_std_acc=84.10%
Epoch 2 - dstl_loss=0.4161, val_dstl_loss=0.8187 |tch_acc=99.89%, val_tch_acc=89.09% |std_acc=93.24%, val_std_acc=84.26%
Epoch 3 - dstl_loss=0.1989, val_dstl_loss=1.1808 |tch_acc=99.89%, val_tch_acc=89.09% |std_acc=97.31%, val_std_acc=84.03%
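
The logged metrics can be summarized with a little arithmetic. The sketch below parses the three epoch lines exactly as printed above (copied in as strings, so it runs without retraining) and reports the best epoch, the share of the teacher's test accuracy the final student retains, and the share of parameters it keeps. It shows that `val_std_acc` peaks at epoch 2 (84.26%) while `val_dstl_loss` climbs from 0.6479 to 1.1808 and training accuracy keeps rising, which suggests the student starts overfitting after the first epoch; the model pushed below is the epoch-3 one. `parse_epoch_line` is a hypothetical helper.

```python
import re

log_lines = [
    "Epoch 1 - dstl_loss=0.7593, val_dstl_loss=0.6479 |tch_acc=99.89%, val_tch_acc=89.09% |std_acc=85.18%, val_std_acc=84.10%",
    "Epoch 2 - dstl_loss=0.4161, val_dstl_loss=0.8187 |tch_acc=99.89%, val_tch_acc=89.09% |std_acc=93.24%, val_std_acc=84.26%",
    "Epoch 3 - dstl_loss=0.1989, val_dstl_loss=1.1808 |tch_acc=99.89%, val_tch_acc=89.09% |std_acc=97.31%, val_std_acc=84.03%",
]

def parse_epoch_line(line):
    """Hypothetical helper: turn the 'name=value' pairs of one log line into a dict of floats."""
    metrics = {name: float(value) for name, value in re.findall(r"(\w+)=([\d.]+)", line)}
    metrics["epoch"] = int(re.match(r"Epoch (\d+)", line).group(1))
    return metrics

history = [parse_epoch_line(line) for line in log_lines]
best = max(history, key=lambda m: m["val_std_acc"])
final = history[-1]

print("best epoch by val_std_acc:", best["epoch"], best["val_std_acc"])  # 2 84.26
print(f"accuracy retained: {final['val_std_acc'] / final['val_tch_acc']:.1%}")  # 94.3%
print(f"parameters kept: {52_779_266 / 109_483_778:.1%}")  # 48.2%
```

Keeping a copy of the student's weights whenever `val_std_acc` improves, and pushing that copy instead of the last epoch, would be a reasonable refinement; note that here the test split doubles as the validation set, so a held-out split would give a fairer final number.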




## Push to HF hub

In [None]:
model_id = "zakariajaadi/distilbert-base-uncased-imdb"
student_model.push_to_hub(model_id)

CommitInfo(commit_url='https://huggingface.co/zakariajaadi/distilbert-base-uncased-imdb/commit/95eb74b27e157b4106434a7f762f1c0360306001', commit_message='Upload DistilBertForSequenceClassification', commit_description='', oid='95eb74b27e157b4106434a7f762f1c0360306001', pr_url=None, repo_url=RepoUrl('https://huggingface.co/zakariajaadi/distilbert-base-uncased-imdb', endpoint='https://huggingface.co', repo_type='model', repo_id='zakariajaadi/distilbert-base-uncased-imdb'), pr_revision=None, pr_num=None)