In [1]:
!pip install torch transformers datasets  --quiet

    sys-platform (=="darwin") ; extra == 'objc'
                 ~^[0m[33m
[0m

In [5]:
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
from torch.nn import functional as F
from sklearn.metrics import accuracy_score

# Seed PyTorch's global RNG before the models are built. The classification
# heads are not part of the pretrained checkpoints and are randomly initialised
# on load, so without a fixed seed every run starts from different head weights.
SEED = 42
torch.manual_seed(SEED)

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load teacher and student models (both get a fresh 2-class classification head)
teacher_model_name = "bert-large-uncased"
student_model_name = "bert-base-uncased"
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_model_name, num_labels=2).to(device)
student_model = AutoModelForSequenceClassification.from_pretrained(student_model_name, num_labels=2).to(device)

# Load tokenizer. It is taken from the student checkpoint and its output is fed
# to both models.
# NOTE(review): assumes teacher and student share the same vocabulary - confirm.
tokenizer = AutoTokenizer.from_pretrained(student_model_name)

# Load SST-2 dataset
dataset = load_dataset("glue", "sst2")

# Data collator function
def data_collator(features):
    """Collate a list of per-example dicts into one dict of batched tensors.

    Args:
        features: non-empty sequence of dicts. The keys of the first dict
            decide which fields are collated; every other dict must provide
            the same keys.

    Returns:
        dict mapping each key to ``torch.tensor`` built from that key's values
        across all examples, in the order the examples were given.
    """
    batch = {}
    for field in features[0].keys():
        # Gather this field from every example, then convert in one go.
        column = [example[field] for example in features]
        batch[field] = torch.tensor(column)
    return batch

# Distillation loss function
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
    """Blend hard-label cross-entropy with a temperature-softened KL term.

    Args:
        student_logits: raw student scores; the last axis is the class axis.
            (assumed shape ``(batch, num_classes)`` - TODO confirm with callers)
        teacher_logits: raw teacher scores with the same layout.
        labels: integer class targets for the cross-entropy term.
        alpha: weight of the cross-entropy term; the KL term gets ``1 - alpha``.
        temperature: divisor applied to both sets of logits before the softmax.

    Returns:
        Scalar tensor ``alpha * CE + (1 - alpha) * KL * temperature**2``.
    """
    # Soft targets: compare the two temperature-scaled distributions.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # The KL term is multiplied by T^2 as part of the loss definition.
    soft_term = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (temperature ** 2)

    # Hard targets: ordinary cross-entropy on the unscaled student logits.
    hard_term = F.cross_entropy(student_logits, labels)

    return alpha * hard_term + (1 - alpha) * soft_term

# Modify the preprocess_function
def preprocess_function(examples):
    """Tokenize a batch of examples and attach their labels.

    Args:
        examples: mapping that provides a ``'sentence'`` entry (text to
            tokenize) and a ``"label"`` entry (copied through unchanged).

    Returns:
        The tokenizer's output with an extra ``"labels"`` entry. Sequences are
        truncated and padded to a fixed length of 128.
    """
    encoded = tokenizer(
        examples['sentence'],
        max_length=128,
        padding="max_length",
        truncation=True,
    )
    # Store the targets under "labels" next to the tokenizer outputs.
    encoded["labels"] = examples["label"]
    return encoded

# Tokenize every split in batches; the raw 'label', 'sentence' and 'idx'
# columns are dropped so only the tokenizer outputs plus "labels" remain.
encoded_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=['label', 'sentence', 'idx'],
)

# Sanity check: look at the first training example after preprocessing.
first_train_example = encoded_dataset["train"][0]
print("Sample input:", first_train_example)
print("Keys in sample:", first_train_example.keys())

# Modify the DistillationTrainer class
class DistillationTrainer(Trainer):
    """Trainer whose loss distils a teacher model's logits into the student.

    The student is the regular ``model`` handled by ``Trainer``. The teacher is
    only run for inference: its forward pass happens under ``torch.no_grad()``
    and it is switched to eval mode before being used.

    Args:
        teacher_model: model called with the same inputs as the student; its
            output must expose ``.logits``. It has to live on the device the
            inputs are moved to, because this class does not move it.
        alpha: weight of the hard-label cross-entropy term in ``distillation_loss``.
        temperature: softmax temperature passed to ``distillation_loss``.
        *args, **kwargs: forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, *args, teacher_model=None, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the distillation loss for one batch.

        Args:
            model: the student (possibly wrapped by the Trainer).
            inputs: dict of tensors that includes a ``"labels"`` entry.
            return_outputs: when true, also return the student's outputs with
                the loss stored under ``"loss"``.
            num_items_in_batch: accepted for compatibility with Trainer
                versions that pass it to ``compute_loss``; not used here.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs`` is set.

        Raises:
            ValueError: if the trainer was built without a ``teacher_model``.
        """
        if self.teacher_model is None:
            raise ValueError(
                "DistillationTrainer needs a teacher_model to compute the distillation loss."
            )

        # Wrapped models (e.g. nn.DataParallel) do not expose `.device`;
        # fall back to the device the Trainer itself was configured with.
        target_device = getattr(model, "device", None)
        if target_device is None:
            target_device = self.args.device
        inputs = {k: v.to(target_device) for k, v in inputs.items()}

        # Labels are consumed by the distillation loss, not by the models.
        labels = inputs.pop("labels")

        # Student forward pass (gradients flow through this one).
        outputs = model(**inputs)
        student_logits = outputs.logits

        # Keep the teacher deterministic: dropout must be off for soft targets.
        if self.teacher_model.training:
            self.teacher_model.eval()

        # Teacher forward pass, inference only.
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
            teacher_logits = teacher_outputs.logits

        loss = distillation_loss(student_logits, teacher_logits, labels, self.alpha, self.temperature)

        if return_outputs:
            outputs["loss"] = loss
            return (loss, outputs)
        return loss

# Compute metrics function
def compute_metrics(eval_pred):
    """Turn evaluation predictions into an accuracy score.

    Args:
        eval_pred: pair ``(logits, labels)``; the last axis of ``logits`` is
            reduced with ``argmax`` to obtain the predicted class.

    Returns:
        dict with a single ``"accuracy"`` entry.
    """
    logits, labels = eval_pred
    # Predicted class = index of the largest logit along the last axis.
    predicted_classes = np.argmax(logits, axis=-1)
    accuracy = accuracy_score(labels, predicted_classes)
    return {"accuracy": accuracy}

# --- Teacher fine-tuning -----------------------------------------------------
# Hyper-parameters used to fine-tune the teacher on SST-2.
teacher_training_args = TrainingArguments(
    # Output locations
    output_dir="./teacher_model_sst2",
    logging_dir="./teacher_logs_sst2",
    # Checkpointing
    save_steps=1000,
    save_total_limit=1,
    # Schedule: three epochs, evaluating at the end of each one
    num_train_epochs=3,
    evaluation_strategy="epoch",
    # Batch sizes
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    # Optimisation
    learning_rate=2e-5,
    weight_decay=0.01,
)

# Plain Trainer: the teacher is trained with the model's standard loss.
teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_training_args,
    compute_metrics=compute_metrics,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
)

# Baseline: score the teacher before any fine-tuning.
teacher_results = teacher_trainer.evaluate()
print(f"Pre Fine-tuned Teacher Model Performance: {teacher_results}")

# Fine-tune.
print("Fine-tuning the teacher model...")
teacher_trainer.train()

# Score the fine-tuned teacher; this overwrites the baseline metrics and is
# the value reported in the final comparison.
teacher_results = teacher_trainer.evaluate()
print(f"Fine-tuned Teacher Model Performance: {teacher_results}")

# Persist the fine-tuned teacher weights.
teacher_model.save_pretrained("./fine_tuned_teacher_model_sst2")

# --- Distillation ------------------------------------------------------------
# Hyper-parameters used while distilling the teacher into the student.
training_args = TrainingArguments(
    # Output locations
    output_dir="./distilled_model_sst2",
    logging_dir="./logs_sst2",
    # Checkpointing
    save_steps=1000,
    save_total_limit=2,
    # Schedule: three epochs, evaluating at the end of each one
    num_train_epochs=3,
    evaluation_strategy="epoch",
    # Batch sizes and data loading
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    dataloader_pin_memory=True,
    # Optimisation
    learning_rate=5e-5,
    weight_decay=0.01,
)

# The student is trained by DistillationTrainer against the teacher model
# that was fine-tuned above.
distillation_trainer = DistillationTrainer(
    model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Baseline: score the student before distillation.
# NOTE(review): eval_loss reported by this trainer presumably comes from
# DistillationTrainer.compute_loss (the blended loss), so it is not directly
# comparable with the teacher's eval_loss - confirm.
pre_distillation_results = distillation_trainer.evaluate()
print(f"Pre-distillation Student Model Performance: {pre_distillation_results}")

# Distil.
print("Starting distillation...")
distillation_trainer.train()

# Score the distilled student.
post_distillation_results = distillation_trainer.evaluate()
print(f"Post-distillation Student Model Performance: {post_distillation_results}")

# Persist the student together with its tokenizer in the same directory.
student_model.save_pretrained("./distilled_student_model_sst2")
tokenizer.save_pretrained("./distilled_student_model_sst2")

# Side-by-side summary of the three evaluations.
comparison_lines = [
    "\nPerformance Comparison:",
    f"Fine-tuned Teacher Model: {teacher_results}",
    f"Pre-distillation Student Model: {pre_distillation_results}",
    f"Post-distillation Student Model: {post_distillation_results}",
]
for comparison_line in comparison_lines:
    print(comparison_line)

Using device: cuda


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-large-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

Sample input: {'input_ids': [101, 5342, 2047, 3595, 8496, 2013, 1996, 18643, 3197, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

Pre Fine-tuned Teacher Model Performance: {'eval_loss': 0.7186707854270935, 'eval_accuracy': 0.43004587155963303, 'eval_runtime': 4.3114, 'eval_samples_per_second': 202.257, 'eval_steps_per_second': 3.247}
Fine-tuning the teacher model...


Epoch,Training Loss,Validation Loss,Accuracy
1,0.1618,0.204711,0.932339
2,0.1092,0.267002,0.916284
3,0.0604,0.293599,0.919725


Fine-tuned Teacher Model Performance: {'eval_loss': 0.2935987710952759, 'eval_accuracy': 0.9197247706422018, 'eval_runtime': 4.4775, 'eval_samples_per_second': 194.752, 'eval_steps_per_second': 3.127, 'epoch': 3.0}


Pre-distillation Student Model Performance: {'eval_loss': 1.4624680280685425, 'eval_accuracy': 0.4908256880733945, 'eval_runtime': 6.188, 'eval_samples_per_second': 140.919, 'eval_steps_per_second': 2.262}
Starting distillation...


Epoch,Training Loss,Validation Loss,Accuracy
1,0.2603,0.296434,0.912844
2,0.1327,0.323626,0.915138
3,0.0636,0.318399,0.924312


Post-distillation Student Model Performance: {'eval_loss': 0.31839948892593384, 'eval_accuracy': 0.9243119266055045, 'eval_runtime': 5.8479, 'eval_samples_per_second': 149.113, 'eval_steps_per_second': 2.394, 'epoch': 3.0}

Performance Comparison:
Fine-tuned Teacher Model: {'eval_loss': 0.2935987710952759, 'eval_accuracy': 0.9197247706422018, 'eval_runtime': 4.4775, 'eval_samples_per_second': 194.752, 'eval_steps_per_second': 3.127, 'epoch': 3.0}
Pre-distillation Student Model: {'eval_loss': 1.4624680280685425, 'eval_accuracy': 0.4908256880733945, 'eval_runtime': 6.188, 'eval_samples_per_second': 140.919, 'eval_steps_per_second': 2.262}
Post-distillation Student Model: {'eval_loss': 0.31839948892593384, 'eval_accuracy': 0.9243119266055045, 'eval_runtime': 5.8479, 'eval_samples_per_second': 149.113, 'eval_steps_per_second': 2.394, 'epoch': 3.0}
