In [None]:
!pip install transformers datasets peft bitsandbytes

import torch
import torch.nn.functional as F
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

teacher_model_name = "meta-llama/Llama-3.1-8B"
student_model_name = "meta-llama/Llama-3.2-1B"

# One tokenizer is used for both models.
# NOTE(review): the teacher is fed ids from the student's tokenizer; this
# presumes both checkpoints share the same vocabulary -- confirm.
tokenizer = AutoTokenizer.from_pretrained(student_model_name, use_fast=True)
# Llama tokenizers ship without a pad token, and preprocessing below pads to
# a fixed length, which raises without one. Reuse EOS as the pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# The ETHICS dataset configs are named per subset ("commonsense",
# "deontology", "justice", ...); the train/validation/test splits live
# inside each config, so there is no "commonsense_train" config.
dataset = load_dataset("hendrycks/ethics", "commonsense")
print("Dataset columns:", dataset["train"].column_names)

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of examples into fixed-length (128-token) inputs.

    The text column is looked up under several candidate names, in priority
    order, so the same function works across dataset variants.

    Args:
        examples: batched mapping of column name -> list of values, as passed
            by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer's batch encoding, truncated and padded to 128 tokens.

    Raises:
        ValueError: if none of the candidate text columns is present.
    """
    # "input" is presumably the text column of the ETHICS "commonsense"
    # config (check the column list printed when the dataset is loaded);
    # it is probed last so the original priority order is preserved.
    for column in ("scenario", "prompt", "question", "text", "input"):
        if column in examples:
            texts = examples[column]
            break
    else:
        raise ValueError("No valid text field found in the dataset!")
    return tokenizer(texts, truncation=True, padding="max_length", max_length=128)

# Tokenize every split. The original columns are kept, but set_format below
# restricts what is returned when rows are indexed.
tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
)

# Trainer feeds columns to forward() by name, and the classification models
# take their targets as ``labels``.
has_singular_label = "label" in tokenized_datasets["train"].column_names
if has_singular_label:
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

# Return only the model inputs and targets, as torch tensors.
tokenized_datasets.set_format(
    columns=["input_ids", "attention_mask", "labels"],
    type="torch",
)

# Frozen 8-bit teacher whose logits serve as soft targets.
# NOTE(review): a base (non-classification) checkpoint has no trained
# ``score`` head, so transformers initializes it randomly. Unless a
# classification-fine-tuned teacher is substituted, its logits presumably
# carry no task signal to distill -- confirm.
teacher = AutoModelForSequenceClassification.from_pretrained(
    teacher_model_name,
    # Passing ``load_in_8bit=`` directly to from_pretrained is deprecated in
    # favour of a BitsAndBytesConfig.
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
# Llama sequence classifiers use config.pad_token_id to find each row's last
# real token and refuse batches larger than 1 when it is unset.
teacher.config.pad_token_id = (
    tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
)
teacher.eval()
for param in teacher.parameters():
    param.requires_grad = False

# 8-bit student, trained through LoRA adapters on the attention projections.
student = AutoModelForSequenceClassification.from_pretrained(
    student_model_name,
    # Passing ``load_in_8bit=`` directly to from_pretrained is deprecated in
    # favour of a BitsAndBytesConfig.
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
# Llama sequence classifiers use config.pad_token_id to find each row's last
# real token and refuse batches larger than 1 when it is unset.
student.config.pad_token_id = (
    tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
)
# Freeze the quantized base and upcast its remaining half-precision weights
# (including the freshly initialized ``score`` head) to fp32. Otherwise the
# trainable copy of ``score`` made by the SEQ_CLS PEFT wrapper stays fp16,
# and fp16 mixed-precision training cannot unscale its gradients.
student = prepare_model_for_kbit_training(student, use_gradient_checkpointing=False)

lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
)
student = get_peft_model(student, lora_config)
print("Trainable parameters in student:")
student.print_trainable_parameters()

In [None]:
class DistillationTrainer(Trainer):
    """Trainer that distills a frozen teacher classifier into the trained model.

    The training loss is
    ``alpha * KL(teacher || student) * T**2 + (1 - alpha) * CrossEntropy``,
    where both distributions are softened by the temperature ``T``.

    Args:
        teacher: frozen model whose ``.logits`` serve as soft targets. It is
            called with the same inputs as the student, minus ``labels``.
        alpha: weight of the distillation (KL) term; ``1 - alpha`` weights the
            hard-label cross-entropy.
        temperature: softmax temperature applied to both sets of logits.
        *args, **kwargs: forwarded unchanged to ``Trainer``.
    """

    def __init__(self, teacher, alpha=0.5, temperature=2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher
        self.alpha = alpha
        self.temperature = temperature
        # Trainer only switches train/eval mode on ``self.model``; keep the
        # teacher permanently in eval mode (no dropout).
        self.teacher.eval()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the blended distillation loss, plus student outputs if requested.

        ``num_items_in_batch`` is accepted because recent transformers
        releases pass it to ``compute_loss``. It is unused here because both
        loss terms are already per-example means.
        """
        labels = inputs.get("labels")
        # The CE term is computed explicitly below, so neither model needs
        # the labels; passing them would only trigger a redundant loss.
        model_inputs = {key: value for key, value in inputs.items() if key != "labels"}

        outputs = model(**model_inputs)
        student_logits = outputs.logits

        with torch.no_grad():
            teacher_logits = self.teacher(**model_inputs).logits
        # The quantized teacher may return logits in a different dtype or on
        # a different device than the student; align before combining.
        teacher_logits = teacher_logits.to(device=student_logits.device, dtype=student_logits.dtype)

        temperature = self.temperature
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
        # KL(teacher || student) with a log-space target. The T**2 factor
        # keeps gradient magnitudes comparable across temperatures.
        kl_loss = F.kl_div(
            student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True
        ) * (temperature ** 2)

        if labels is None:
            # Without hard labels, only the distillation signal is available.
            loss = kl_loss
        else:
            ce_loss = F.cross_entropy(student_logits, labels)
            loss = self.alpha * kl_loss + (1 - self.alpha) * ce_loss
        return (loss, outputs) if return_outputs else loss

training_args = TrainingArguments(
    output_dir="./distilled_llama3",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # ``evaluation_strategy`` was renamed to ``eval_strategy`` (the old name
    # is removed in current transformers). Only evaluate when a validation
    # split exists; Trainer errors out if asked to evaluate without one.
    eval_strategy="epoch" if "validation" in tokenized_datasets else "no",
    save_strategy="epoch",
    logging_steps=10,
    learning_rate=2e-5,
    fp16=True,
    report_to="none",
    # PeftModel wraps the base model's forward(), so name the label column
    # explicitly rather than relying on signature inspection.
    label_names=["labels"],
)

# DatasetDict subclasses dict, so .get() returns None when there is no
# validation split, leaving the trainer without an eval dataset.
trainer = DistillationTrainer(
    model=student,
    teacher=teacher,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets.get("validation"),
)

In [None]:
trainer.train()

final_dir = "./distilled_llama3_final"
# The student is PEFT-wrapped, so this presumably writes the LoRA adapter
# (plus the trained score head) rather than full model weights.
trainer.save_model(final_dir)
# Save the tokenizer next to it so its configuration (e.g. the pad token
# used for training) is reproduced when the adapter is reloaded.
tokenizer.save_pretrained(final_dir)
