### Overview
- This notebook shows how to perform knowledge distillation from a fine-tuned BERT classifier to a smaller DistilBERT-based student, using a KL-divergence loss on the teacher's outputs combined with cross-entropy on the true labels.
- The IMDB movie-review sentiment dataset is used for the demo.
- The steps are:
    0. Load the dataset
    1. Train a teacher BERT model on this dataset
    2. Train a smaller student model using the KL-divergence loss
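
For reference, the objective used in step 2 can be written as below. This is a sketch read off the `knowledge_distillation_loss` function defined later in the notebook; the symbols $z_s$ and $z_t$ (student and teacher logits) and $y$ (true label) are names introduced here, while $T$ (temperature) and $\alpha$ (mixing weight) are the notebook's own hyperparameters.

$$
\mathcal{L} = \alpha \, T^{2} \, \mathrm{KL}\!\left(\mathrm{softmax}(z_t / T) \,\|\, \mathrm{softmax}(z_s / T)\right) + (1 - \alpha)\, \mathrm{CE}(z_s, y)
$$

The notebook sets $T = 2.0$ and $\alpha = 0.5$, so both terms are weighted equally.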

In [1]:
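import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"  # machine-specific GPU index; change or remove for your setup
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence tokenizer fork warnings when DataLoader workers are used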

In [2]:
# Warning control
import warnings
warnings.filterwarnings('ignore')

In [3]:
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
import evaluate
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd
import torch

from transformers import DistilBertForSequenceClassification, DistilBertConfig
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


### 0. Load dataset

In [4]:
# Load the data
dataset_dict = load_dataset("imdb")
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
dataset_dict

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

In [5]:
# Split train into train & dev
train_dataset = dataset_dict['train']

train_df = pd.DataFrame(train_dataset)
X = train_df.drop('label', axis=1) 
y = train_df['label']
X_train, X_dev, y_train, y_dev = train_test_split(
    X, y, test_size=0.1, random_state=42, stratify=y
)

train_dataset = Dataset.from_pandas(pd.concat([X_train, y_train], axis=1))
dev_dataset = Dataset.from_pandas(pd.concat([X_dev, y_dev], axis=1))
test_dataset = dataset_dict['test']

train_dataset, dev_dataset, test_dataset

(Dataset({
     features: ['text', 'label', '__index_level_0__'],
     num_rows: 22500
 }),
 Dataset({
     features: ['text', 'label', '__index_level_0__'],
     num_rows: 2500
 }),
 Dataset({
     features: ['text', 'label'],
     num_rows: 25000
 }))

In [6]:
# View examples
train_dataset[:5]

{'text': ['"Algie, the Miner" is one bad and unfunny silent comedy. The timing of the slapstick is completely off. This is the kind of humor with certain sequences that make you wonder if they\'re supposed to be funny or not. However, the actual quality of the film is irrelevant. This is mandatory viewing for film buffs mainly because its one of the earliest examples of gay cinema. The main character of Algie is an effeminate guy, acting much like the stereotypical "pansy" common in many early films. The film has the homophobic attitude common of the time. "Algie, the Miner" is pretty awful, but fascinating from a historical viewpoint. (3/10)',
  'This is a complete Hoax...<br /><br />The movie clearly has been shot in north western Indian state of Rajasthan. Look at the chase scene - the vehicles are Indian; the writing all over is Hindi - language used in India. The drive through is on typical Jaipur streets. Also the palace is in Amer - about 10 miles from Jaipur, Rajasthan. The fil

In [7]:
# Optionally keep only a few examples to quickly verify that the code runs (set to True to enable)
if False:
    train_dataset = train_dataset.select(range(100))
    dev_dataset = dev_dataset.select(range(50))
    test_dataset = test_dataset.select(range(50))

### 1. Train teacher model

In [8]:
# Load the model
model_path = "google-bert/bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_path, 
                                                           num_labels=2, 
                                                           id2label=id2label, 
                                                           label2id=label2id,)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# Num of model params
num_params = sum(p.numel() for p in model.parameters())
print(f"Model's number of parameters: {num_params:,}")

Model's number of parameters: 109,483,778


In [10]:
# Freeze the encoder and pooler. The embeddings and the classifier head remain trainable.
for name, param in model.named_parameters():
    if ("encoder" in name) or ("pooler" in name):
        param.requires_grad = False

# Print each parameter name and whether it is trainable
for name, param in model.named_parameters():
    print(name, param.requires_grad)

bert.embeddings.word_embeddings.weight True
bert.embeddings.position_embeddings.weight True
bert.embeddings.token_type_embeddings.weight True
bert.embeddings.LayerNorm.weight True
bert.embeddings.LayerNorm.bias True
bert.encoder.layer.0.attention.self.query.weight False
bert.encoder.layer.0.attention.self.query.bias False
bert.encoder.layer.0.attention.self.key.weight False
bert.encoder.layer.0.attention.self.key.bias False
bert.encoder.layer.0.attention.self.value.weight False
bert.encoder.layer.0.attention.self.value.bias False
bert.encoder.layer.0.attention.output.dense.weight False
bert.encoder.layer.0.attention.output.dense.bias False
bert.encoder.layer.0.attention.output.LayerNorm.weight False
bert.encoder.layer.0.attention.output.LayerNorm.bias False
bert.encoder.layer.0.intermediate.dense.weight False
bert.encoder.layer.0.intermediate.dense.bias False
bert.encoder.layer.0.output.dense.weight False
bert.encoder.layer.0.output.dense.bias False
bert.encoder.layer.0.output.LayerNor
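
The output above shows that the substring rule freezes only the encoder and pooler, so the embedding matrices are still updated during training. The sketch below replays that rule on a handful of parameter names without loading the model. `is_frozen_head_only` is a hypothetical stricter variant, not something the notebook uses, included only to show what a true classifier-only setup would select.

```python
# A few parameter names as they appear in BertForSequenceClassification
names = [
    "bert.embeddings.word_embeddings.weight",
    "bert.encoder.layer.0.attention.self.query.weight",
    "bert.pooler.dense.weight",
    "classifier.weight",
]

def is_frozen(name):
    """Same substring rule as the freezing loop above."""
    return ("encoder" in name) or ("pooler" in name)

def is_frozen_head_only(name):
    """Hypothetical stricter rule: freeze everything except the classifier head."""
    return not name.startswith("classifier")

trainable = [n for n in names if not is_frozen(n)]
trainable_head_only = [n for n in names if not is_frozen_head_only(n)]

print(trainable)
print(trainable_head_only)
```

On the real model, `sum(p.numel() for p in model.parameters() if p.requires_grad)` gives the number of parameters that are actually trained.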

In [11]:
# Tokenizer & Preprocess text
tokenizer = AutoTokenizer.from_pretrained(model_path)

def preprocess_function(examples):
    return tokenizer(examples["text"], padding='max_length', truncation=True)

train_tokenized_data = train_dataset.map(preprocess_function, batched=True)
dev_tokenized_data = dev_dataset.map(preprocess_function, batched=True)
test_tokenized_data = test_dataset.map(preprocess_function, batched=True)

train_tokenized_data, len(train_tokenized_data['input_ids'])

Map: 100%|██████████| 22500/22500 [00:15<00:00, 1417.61 examples/s]
Map: 100%|██████████| 2500/2500 [00:01<00:00, 1474.30 examples/s]
Map: 100%|██████████| 25000/25000 [00:17<00:00, 1469.24 examples/s]


(Dataset({
     features: ['text', 'label', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 22500
 }),
 22500)

In [12]:
# Compute metrics function
eval_acc = evaluate.load("accuracy")
eval_roc_auc = evaluate.load("roc_auc")
def compute_metrics(eval_pred):
    predictions, labels = eval_pred    # logits of shape (num_examples, num_classes), labels of shape (num_examples,)

    # Softmax over the class dimension; subtracting the row max avoids overflow in np.exp
    shifted = predictions - predictions.max(axis=-1, keepdims=True)
    probabilities = np.exp(shifted) / np.exp(shifted).sum(-1, keepdims=True)
    positive_class_probs = probabilities[:, 1]
    roc_auc = eval_roc_auc.compute(prediction_scores=positive_class_probs, references=labels)['roc_auc']

    preds = np.argmax(predictions, axis=1)
    # compute() returns a dict, so accuracy shows up nested as {'accuracy': ...} in the logs
    accuracy = eval_acc.compute(predictions=preds, references=labels)

    return {"accuracy": accuracy, "roc_auc": roc_auc}


In [13]:
# Create the data collator: pads the examples in each batch to a common length
# and returns them as a single batch of tensors
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
batch_size=64
num_epochs=3

training_args = TrainingArguments(
    output_dir="output_dir/bert_imdb_teacher_model",
    learning_rate=2e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    dataloader_num_workers=8
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    args=training_args,
    train_dataset=train_tokenized_data,
    eval_dataset=dev_tokenized_data,
    compute_metrics=compute_metrics,
)

trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Roc Auc |
|---|---|---|---|---|
| 1 | 0.4338 | 0.361541 | {'accuracy': 0.856} | 0.953689 |
| 2 | 0.1987 | 0.344137 | {'accuracy': 0.8712} | 0.958824 |
| 3 | 0.1111 | 0.385555 | {'accuracy': 0.8772} | 0.959486 |


TrainOutput(global_step=1056, training_loss=0.24787978692488236, metrics={'train_runtime': 386.9079, 'train_samples_per_second': 174.46, 'train_steps_per_second': 2.729, 'total_flos': 1.77599962368e+16, 'train_loss': 0.24787978692488236, 'epoch': 3.0})
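
The `global_step=1056` in the output follows directly from the dataset size and the training arguments. The short check below reproduces it, assuming a single GPU and no gradient accumulation (neither is set in the notebook).

```python
import math

train_examples = 22_500   # size of the training split
batch_size = 64
num_epochs = 3

# The last batch of each epoch is partial, hence the ceiling
steps_per_epoch = math.ceil(train_examples / batch_size)
total_steps = steps_per_epoch * num_epochs

print(steps_per_epoch, total_steps)
```

With `save_strategy="epoch"`, a checkpoint is written at the end of each epoch and named after the global step, which is where `checkpoint-1056` in section 2 comes from. Note also that `load_best_model_at_end=True` selects the best checkpoint by validation loss unless another metric is configured, which in the table above would be epoch 2 rather than epoch 3. This may be why the teacher's test accuracy differs slightly between this section and section 2, where the final checkpoint is loaded explicitly.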

In [18]:
# Score on test set

predictions = trainer.predict(test_tokenized_data)
logits, labels = predictions.predictions, predictions.label_ids
metrics = compute_metrics((logits, labels))
print(metrics)

# More eval metrics
preds = np.argmax(logits, axis=1)
accuracy = accuracy_score(labels, preds)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
print(f"Teacher model:") 
print(f"Test set - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

{'accuracy': {'accuracy': 0.86472}, 'roc_auc': 0.9529937472000001}
Teacher model:
Test set - Accuracy: 0.8647, Precision: 0.8120, Recall: 0.9492, F1 Score: 0.8753


In [19]:
# Select a test example
print(test_dataset[0])
print("GT label:", id2label[test_dataset[0]['label']])

{'text': 'I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn\'t match the background, and painfully one-dimensional characters cannot be overcome with a \'sci-fi\' setting. (I\'m sure there are those of you out there who think Babylon 5 is good sci-fi TV. It\'s not. It\'s clichéd and uninspiring.) While US viewers might like emotion and character development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It\'s really difficult to care about the characters here as they are not simply foolish, just missing a spark of life. Their actions and reactions are wooden and predictable, often painful to watch. The makers of Earth KNOW it\'s rubbish as 

In [20]:
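# Run inference on this test example

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
model.eval()  # disable dropout for inference
input_text = test_dataset[0]['text']
# Truncate to the model's maximum input length so that long reviews do not raise an error
inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)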

with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    preds = torch.argmax(logits, dim=-1)
pred = model.config.id2label[preds.item()]
print(f"Model prediction: {pred}")

Model prediction: NEGATIVE


### 2. Train student model

In [66]:
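# Load teacher model and tokenizer
# checkpoint-1056 is the final checkpoint of the teacher run (3 epochs x 352 steps); adjust the path if your run differs
model_path = "output_dir/bert_imdb_teacher_model/checkpoint-1056"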
tokenizer = AutoTokenizer.from_pretrained(model_path)
teacher_model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)

In [67]:
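# Load student model: a DistilBERT reduced to 1 transformer layer and 8 attention heads.
# Pretrained weights that fit this smaller architecture are loaded from distilbert-base-uncased.
small_bert_config = DistilBertConfig(n_heads=8, n_layers=1)
student_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    config=small_bert_config,
).to(device)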

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [68]:
# Num of model params
num_params = sum(p.numel() for p in teacher_model.parameters())
print(f"Teacher model number of parameters: {num_params:,}")

num_params = sum(p.numel() for p in student_model.parameters())
print(f"Student model number of parameters: {num_params:,}")

Teacher model number of parameters: 109,483,778
Student model number of parameters: 31,515,650
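
To see where the student's 31.5M parameters sit, the count can be rebuilt by hand. The sketch below assumes the `DistilBertConfig` default sizes (vocabulary 30522, 512 positions, hidden size 768, feed-forward size 3072) together with the notebook's `n_layers=1`; the number of attention heads does not change the count.

```python
# Sizes assumed from the DistilBertConfig defaults, with n_layers=1 as in the notebook
VOCAB, MAX_POS, DIM, HIDDEN_DIM, N_LAYERS, LABELS = 30522, 512, 768, 3072, 1, 2

def linear(n_in, n_out):
    """Weights plus biases of a dense layer."""
    return n_in * n_out + n_out

layer_norm = 2 * DIM  # scale and shift vectors

embeddings = (VOCAB + MAX_POS) * DIM + layer_norm
transformer_layer = (
    4 * linear(DIM, DIM)        # q_lin, k_lin, v_lin, out_lin
    + layer_norm                # sa_layer_norm
    + linear(DIM, HIDDEN_DIM)   # ffn.lin1
    + linear(HIDDEN_DIM, DIM)   # ffn.lin2
    + layer_norm                # output_layer_norm
)
head = linear(DIM, DIM) + linear(DIM, LABELS)  # pre_classifier + classifier

student_total = embeddings + N_LAYERS * transformer_layer + head
teacher_total = 109_483_778  # teacher count printed above

print(f"student: {student_total:,}")
print(f"embeddings share: {embeddings / student_total:.1%}")
print(f"teacher / student: {teacher_total / student_total:.2f}x")
```

Under these assumptions the total matches the printed student count, and most of the student's parameters turn out to be token embeddings rather than transformer weights.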


In [69]:
# # Only train classifier layer. Freeze remaining params
# for name, param in student_model.named_parameters():
#     if ("distilbert" in name) or ("pre_classifier" in name):
#         param.requires_grad = False

# # Number of trainable params
# num_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
# print(f"Student model number of trainable parameters: {num_params:,}")

# Print each parameter name and whether it is trainable
for name, param in student_model.named_parameters():
    print(name, param.requires_grad)

distilbert.embeddings.word_embeddings.weight True
distilbert.embeddings.position_embeddings.weight True
distilbert.embeddings.LayerNorm.weight True
distilbert.embeddings.LayerNorm.bias True
distilbert.transformer.layer.0.attention.q_lin.weight True
distilbert.transformer.layer.0.attention.q_lin.bias True
distilbert.transformer.layer.0.attention.k_lin.weight True
distilbert.transformer.layer.0.attention.k_lin.bias True
distilbert.transformer.layer.0.attention.v_lin.weight True
distilbert.transformer.layer.0.attention.v_lin.bias True
distilbert.transformer.layer.0.attention.out_lin.weight True
distilbert.transformer.layer.0.attention.out_lin.bias True
distilbert.transformer.layer.0.sa_layer_norm.weight True
distilbert.transformer.layer.0.sa_layer_norm.bias True
distilbert.transformer.layer.0.ffn.lin1.weight True
distilbert.transformer.layer.0.ffn.lin1.bias True
distilbert.transformer.layer.0.ffn.lin2.weight True
distilbert.transformer.layer.0.ffn.lin2.bias True
distilbert.transformer.lay

In [70]:
# Tokenizer & Preprocess text
tokenizer = AutoTokenizer.from_pretrained(model_path)

def preprocess_function(examples):
    return tokenizer(examples["text"], padding='max_length', truncation=True)

train_tokenized_data = train_dataset.map(preprocess_function, batched=True)
dev_tokenized_data = dev_dataset.map(preprocess_function, batched=True)
test_tokenized_data = test_dataset.map(preprocess_function, batched=True)

train_tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
dev_tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
test_tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

train_tokenized_data, len(train_tokenized_data['input_ids'])

Map: 100%|██████████| 22500/22500 [00:15<00:00, 1443.13 examples/s]
Map: 100%|██████████| 2500/2500 [00:01<00:00, 1438.94 examples/s]
Map: 100%|██████████| 25000/25000 [00:17<00:00, 1458.88 examples/s]


(Dataset({
     features: ['text', 'label', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 22500
 }),
 22500)

In [71]:
batch_size = 64

# create training data loader (reshuffled every epoch)
train_dataloader = DataLoader(train_tokenized_data, batch_size=batch_size, shuffle=True, num_workers=8)
# create dev data loader
dev_dataloader = DataLoader(dev_tokenized_data, batch_size=batch_size, num_workers=8)
# create testing data loader
test_dataloader = DataLoader(test_tokenized_data, batch_size=batch_size, num_workers=8)

print(len(train_dataloader), len(dev_dataloader), len(test_dataloader))


352 40 391


In [72]:
# Evaluate function
# Note: this switches the model to eval mode and leaves it there
def evaluate_model(model, dataloader, device):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')
    return accuracy, precision, recall, f1


In [73]:
# Teacher model - dev & test set scores

print(f"Teacher:")

accuracy, precision, recall, f1 = evaluate_model(teacher_model, dev_dataloader, device)
print(f"Dev set - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

accuracy, precision, recall, f1 = evaluate_model(teacher_model, test_dataloader, device)
print(f"Test set - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")


Teacher:
Dev set - Accuracy: 0.8772, Precision: 0.8327, Recall: 0.9440, F1 Score: 0.8849
Test set - Accuracy: 0.8684, Precision: 0.8187, Recall: 0.9464, F1 Score: 0.8780


In [74]:
# Student model - before training - dev & test set scores

print(f"Student (before training):")

accuracy, precision, recall, f1 = evaluate_model(student_model, dev_dataloader, device)
print(f"Dev set - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

accuracy, precision, recall, f1 = evaluate_model(student_model, test_dataloader, device)
print(f"Test set - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")


Student (before training):
Dev set - Accuracy: 0.5248, Precision: 0.5143, Recall: 0.8936, F1 Score: 0.6528
Test set - Accuracy: 0.5151, Precision: 0.5087, Recall: 0.8798, F1 Score: 0.6447
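
The combination of high recall and near-50% precision suggests that the untrained student labels most reviews as POSITIVE. The sketch below reconstructs the dev-set confusion counts from the reported scores, assuming the dev set is exactly balanced (1,250 reviews per class, which is what a stratified 10% split of the balanced IMDB training set would give).

```python
# Reported dev-set scores for the untrained student
recall, precision = 0.8936, 0.5143
n_dev = 2500
n_pos = n_dev // 2          # assumption: exactly balanced dev set
n_neg = n_dev - n_pos

tp = round(recall * n_pos)          # true positives
pred_pos = round(tp / precision)    # all examples predicted POSITIVE
fp = pred_pos - tp                  # false positives
tn = n_neg - fp                     # true negatives
accuracy = (tp + tn) / n_dev

print(f"predicted POSITIVE: {pred_pos / n_dev:.1%}")
print(f"accuracy: {accuracy:.4f}")
```

The reconstructed accuracy agrees with the reported dev accuracy, which is consistent with a classifier that is close to guessing with a strong bias toward one class.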


In [75]:
# Training loop

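# Hyperparameters
batch_size = 64   # already applied when the DataLoaders were created; not used in this cell
lr = 1e-4
num_epochs = 2
T = 2.0       # temperature: softens the teacher and student distributions
alpha = 0.5   # weight of the distillation term; (1 - alpha) weights the cross-entropy term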

# Optimizer
optimizer = optim.Adam(student_model.parameters(), lr=lr)

# Knowledge distillation loss (KL-divergence loss + cross-entropy loss)
def knowledge_distillation_loss(student_logits, teacher_logits, true_labels, T=2.0, alpha=0.5):
    # KL divergence between the temperature-softened teacher and student distributions.
    # kl_div expects log-probabilities as input and probabilities as target;
    # the T**2 factor keeps gradient magnitudes comparable across temperatures.
    teacher_probs = nn.functional.softmax(teacher_logits / T, dim=1)
    student_log_probs = nn.functional.log_softmax(student_logits / T, dim=1)
    distill_loss = nn.functional.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T ** 2)
    # Cross-entropy loss against the true labels (no temperature applied)
    ce_loss = nn.functional.cross_entropy(student_logits, true_labels)
    # Weighted total loss
    loss = alpha * distill_loss + (1.0 - alpha) * ce_loss
    return loss

# Train
teacher_model.eval()  # the teacher is only used for inference
for epoch in range(num_epochs):
    # evaluate_model() leaves the student in eval mode, so switch back to
    # training mode (dropout enabled) at the start of every epoch
    student_model.train()

    for batch_ix, batch in enumerate(train_dataloader):

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        # Teacher model (no gradients needed)
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

        # Student model
        student_outputs = student_model(input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits

        # Loss + Backprop
        loss = knowledge_distillation_loss(student_logits, teacher_logits, labels, T, alpha)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_ix + 1) % 50 == 0:
            print(f"Processed {batch_ix + 1} batches")

    # Note: this is the loss of the last batch only, not an epoch average
    print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")

    # Dev set scores
    accuracy, precision, recall, f1 = evaluate_model(student_model, dev_dataloader, device)
    print(f"Student Dev set - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

Processed 50 batches
Processed 100 batches
Processed 150 batches
Processed 200 batches
Processed 250 batches
Processed 300 batches
Processed 350 batches
Epoch 1 completed with loss: 0.33669546246528625
Student Dev set - Accuracy: 0.8932, Precision: 0.9032, Recall: 0.8808, F1 Score: 0.8919
Processed 50 batches
Processed 100 batches
Processed 150 batches
Processed 200 batches
Processed 250 batches
Processed 300 batches
Processed 350 batches
Epoch 2 completed with loss: 0.15537892282009125
Student Dev set - Accuracy: 0.8880, Precision: 0.8719, Recall: 0.9096, F1 Score: 0.8904
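
To make the loss used in the training loop more concrete, here is a dependency-free sketch of the same computation for a single example. It uses only the standard library so that it runs without a GPU or the training stack; the logit values are made up for illustration, and `softmax` and `kd_loss` are helper names introduced here, not part of the notebook.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / T for z in logits]
    m = max(scaled)                           # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, label, T=2.0, alpha=0.5):
    """Distillation loss for one example: alpha * T^2 * KL + (1 - alpha) * CE."""
    p_teacher = softmax(teacher_logits, T)
    p_student = softmax(student_logits, T)
    # KL(teacher || student) on the softened distributions
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student))
    # Cross-entropy against the true label, without temperature
    ce = -math.log(softmax(student_logits)[label])
    return alpha * (T ** 2) * kl + (1.0 - alpha) * ce

teacher_logits = [-2.0, 3.0]   # hypothetical: teacher is confident in POSITIVE
student_logits = [0.5, 1.0]    # hypothetical: student is unsure

print(softmax(teacher_logits, T=1.0))   # sharp distribution
print(softmax(teacher_logits, T=2.0))   # softer distribution at T = 2
print(kd_loss(student_logits, teacher_logits, label=1))
print(kd_loss(teacher_logits, teacher_logits, label=1))  # KL term vanishes when the logits match
```

Raising the temperature moves probability mass toward the less likely class, which is what lets the student learn from the teacher's relative confidence and not only from its top prediction. When the student reproduces the teacher's logits exactly, the KL term is zero and only the cross-entropy term remains.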


In [76]:
# Test set scores
print(f"Student (after training):")

accuracy, precision, recall, f1 = evaluate_model(student_model, test_dataloader, device)
print(f"Test set - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

Student (after training):
Test set - Accuracy: 0.8744, Precision: 0.8647, Recall: 0.8876, F1 Score: 0.8760


### References
1. https://github.com/ShawhinT/YouTube-Blog/blob/main/LLMs/model-compression
2. https://huggingface.co/docs/transformers/en/tasks/sequence_classification

