# Distill a Finetuned ESM2 150M model into a ESM2 8M model

In [2]:
pwd

'/home/sdowell/scratch/Thesis/distillation'

In [3]:
# Import packages
import sys
import pLM_KD
import torch

2025-05-11 16:54:21.538442: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-11 16:54:21.730983: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-11 16:54:21.732418: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-11 16:54:22.014917: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Load teacher and student models

In [4]:
# Configuration dictionary (or loaded from YML)
config_dict = {
    "train_path": "../BenchmarkingFinetuning/dataset_splits/finetuning_dataset/train.fasta",
    "eval_path": "../BenchmarkingFinetuning/dataset_splits/finetuning_dataset/valid.fasta",
    "base_model": "facebook/esm2_t6_8M_UR50D",  # Student model
    "teacher_model_path": "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_150m_ecoli_finetuning_1/checkpoint-19000",
    "student_model_path": "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_8m_ecoli_finetuning_2/checkpoint-11500",
    "wandb_project": "esm2_knowledge_distillation",
    "training_args": {
        "output_dir": "distilled_esm2_model",
        "per_device_train_batch_size": 32,  
        "per_device_eval_batch_size": 32,
        "num_train_epochs": 100,
        "learning_rate": 1e-4,
        "alpha": 0.5,
        "temperature": 2.0,
        "fp16": True
    }
}

# Create a DistillationConfig instance
config = pLM_KD.DistillationConfig(**config_dict)

# Choose the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============================================
# Load teacher model using PEFT wrapper (LoRA)
# ============================================
from peft import PeftModel

# Use the base model id for teacher (should match the one used during fine-tuning)
teacher_base_model_id = "facebook/esm2_t30_150M_UR50D"
# Load the base teacher model
teacher_base = pLM_KD.EsmForMaskedLM.from_pretrained(teacher_base_model_id).to(device)
# Wrap the base model with LoRA adapters using the previously fine-tuned checkpoint
teacher_model = PeftModel.from_pretrained(teacher_base, config.teacher_model_path).to(device)
teacher_model.eval()  # Set teacher to evaluation mode
print(f"Teacher model device: {next(teacher_model.parameters()).device}")  # Debug print

# Load tokenizer using the teacher base model identifier (assumes the tokenizer is shared)
tokenizer = pLM_KD.EsmTokenizer.from_pretrained(teacher_base_model_id)

# ============================================
# Load student model
# ============================================
student_model = pLM_KD.EsmForMaskedLM.from_pretrained(config.base_model).to(device)
# Wrap the student model with LoRA adpaters
student_model = PeftModel.from_pretrained(student_model, config.student_model_path).to(device)
student_model.train()
print(f"Student model device: {next(student_model.parameters()).device}")  # Debug print

# Set max length
max_length = min(
    getattr(teacher_model.config, "max_position_embeddings", 1024),
    getattr(student_model.config, "max_position_embeddings", 1024)
)

# Ensure max_length is a multiple of pad_to_multiple_of
pad_to_multiple_of = 8 if config.training_args.fp16 else None
if pad_to_multiple_of:
    max_length = (max_length // pad_to_multiple_of) * pad_to_multiple_of
print(f"Using max_length: {max_length}, which is divisible by {pad_to_multiple_of}")

tokenizer.model_max_length = max_length

# Load sequences
train_sequences = [seq.sequence for seq in pLM_KD.read_fasta(config.train_path)]
eval_sequences = [seq.sequence for seq in pLM_KD.read_fasta(config.eval_path)]
print(f"Loaded {len(train_sequences)} training and {len(eval_sequences)} evaluation sequences")

# Create datasets
train_dataset = pLM_KD.SequenceDataset(train_sequences)
eval_dataset = pLM_KD.SequenceDataset(eval_sequences)

# Initialize data collator
data_collator = pLM_KD.HybridDataCollator(
    tokenizer=tokenizer,
    model_type="mlm",
    mlm_probability=0.15,
    max_length=max_length,
    pad_to_multiple_of=8 if config.training_args.fp16 else None
)

# Initialize distillation trainer
trainer = pLM_KD.DistillationTrainer(
    teacher_model=teacher_model,
    alpha=config.training_args.alpha,
    temperature=config.training_args.temperature,
    model=student_model,
    args=config.training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

# Train the model
#train_result = trainer.train()

# Save the model
#trainer.save_model()

# Evaluate the model
#eval_metrics = trainer.evaluate()
#print("Evaluation metrics:", eval_metrics)


[34m[1mwandb[0m: Currently logged in as: [33msdowell[0m ([33msdowell1[0m). Use [1m`wandb login --relogin`[0m to force relogin


2025-05-11 16:55:19,130 INFO: Distillation configuration saved to distilled_esm2_model/distillation_config.yaml


Using device: cuda
Teacher model device: cuda:0
Student model device: cuda:0
Using max_length: 1024, which is divisible by 8
Loaded 7489 training and 1404 evaluation sequences


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


# Debugging

In [None]:
import torch
import torch.nn.functional as F

# Create a small DataLoader for the evaluation set using your data collator
from torch.utils.data import DataLoader

# For example, use batch_size=8 (or any small batch size)
eval_loader = DataLoader(eval_dataset, batch_size=8, collate_fn=data_collator)

# Get one batch from the evaluation data
batch = next(iter(eval_loader))

# Ensure the batch is on the correct device
for key, value in batch.items():
    if isinstance(value, torch.Tensor):
        batch[key] = value.to(device)

# Set models to evaluation mode (if not already)
teacher_model.eval()
student_model.eval()

# Forward pass through both models (teacher: no grad needed)
with torch.no_grad():
    teacher_outputs = teacher_model(**batch)
    student_outputs = student_model(**batch)

# Extract logits
teacher_logits = teacher_outputs.logits  # shape: [batch_size, seq_length, vocab_size]
student_logits = student_outputs.logits

# Retrieve the temperature from your configuration
T = config.training_args.temperature

# Compute softened probability distributions:
# Teacher uses softmax, and student uses log_softmax for KL divergence stability
teacher_probs = F.softmax(teacher_logits / T, dim=-1)
student_log_probs = F.log_softmax(student_logits / T, dim=-1)

# Compute the KL divergence with batch mean reduction and scale by T^2
kl_div = F.kl_div(student_log_probs, teacher_probs, reduction='mean') * (T ** 2)

print(f"KL divergence on one evaluation batch: {kl_div.item():.4f}")


In [21]:
student_log_probs[0][0]

tensor([-0.7156, -4.9638, -4.8050, -4.9654, -3.8123, -3.6673, -3.7609, -3.6731,
        -3.7292, -3.8123, -3.5947, -3.6705, -3.8625, -3.8000, -3.9839, -3.7046,
        -3.7173, -3.9350, -3.8993, -3.7534, -3.8414, -3.9192, -4.2384, -3.9956,
        -3.6264, -5.9818, -6.0206, -6.1081, -6.7169, -6.8043, -6.7890, -6.8330,
        -4.9627], device='cuda:0')

In [22]:
teacher_probs[0][0]

tensor([0.3348, 0.0121, 0.0263, 0.0121, 0.0243, 0.0304, 0.0265, 0.0598, 0.0569,
        0.0315, 0.0268, 0.0361, 0.0324, 0.0194, 0.0456, 0.0223, 0.0223, 0.0166,
        0.0273, 0.0162, 0.0260, 0.0109, 0.0132, 0.0221, 0.0101, 0.0082, 0.0053,
        0.0046, 0.0020, 0.0019, 0.0022, 0.0014, 0.0125], device='cuda:0')

In [9]:
from peft import get_peft_model_state_dict

# This prints only LoRA-adapted parameters
print("Trainable LoRA parameters in student model:", len(get_peft_model_state_dict(student_model)))

# Optional: print all trainable params by name
for name, param in student_model.named_parameters():
    if param.requires_grad:
        print(f"{name}: requires_grad={param.requires_grad}")
print("")
for name, _ in student_model.named_modules():
    if "attention" in name or "key" in name or "value" in name:
        print(name)


Trainable LoRA parameters in student model: 30
base_model.model.lm_head.modules_to_save.default.bias: requires_grad=True
base_model.model.lm_head.modules_to_save.default.dense.weight: requires_grad=True
base_model.model.lm_head.modules_to_save.default.dense.bias: requires_grad=True
base_model.model.lm_head.modules_to_save.default.layer_norm.weight: requires_grad=True
base_model.model.lm_head.modules_to_save.default.layer_norm.bias: requires_grad=True
base_model.model.lm_head.modules_to_save.default.decoder.weight: requires_grad=True

base_model.model.esm.encoder.layer.0.attention
base_model.model.esm.encoder.layer.0.attention.self
base_model.model.esm.encoder.layer.0.attention.self.query
base_model.model.esm.encoder.layer.0.attention.self.key
base_model.model.esm.encoder.layer.0.attention.self.key.base_layer
base_model.model.esm.encoder.layer.0.attention.self.key.lora_dropout
base_model.model.esm.encoder.layer.0.attention.self.key.lora_dropout.default
base_model.model.esm.encoder.layer