In [1]:
pwd

'/home/sdowell/scratch/Thesis/distillation'

In [2]:
from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments
from peft import PeftModel
import torch
import torch.nn.functional as F

# —————————————————————————————
# 1. Paths to your LoRA adapters
teacher_lora_path = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_150m_ecoli_finetuning_1/checkpoint-19000"
student_lora_path = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_8m_ecoli_finetuning_2/checkpoint-11500"
# —————————————————————————————

# 2. Tokenizer (same for both)
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D", do_lower_case=False)

# 3. Load base models
base_teacher = AutoModelForMaskedLM.from_pretrained(
    "facebook/esm2_t30_150M_UR50D", output_hidden_states=True, return_dict=True
)
base_student = AutoModelForMaskedLM.from_pretrained(
    "facebook/esm2_t6_8M_UR50D", output_hidden_states=True, return_dict=True
)

# 4. Attach your fine-tuned LoRA adapters
teacher = PeftModel.from_pretrained(base_teacher, teacher_lora_path)
student = PeftModel.from_pretrained(base_student, student_lora_path)

# 5. Freeze everything except student’s LoRA
for p in teacher.parameters():            
    p.requires_grad = False
for name, p in student.named_parameters():
    if "lora_" not in name:
        p.requires_grad = False

teacher.eval()

# —————————————————————————————
# 6. Distillation loss
def distill_loss(student_logits, teacher_logits, T=2.0):
    """KL(student||teacher) on softened logits."""
    s = student_logits / T
    t = teacher_logits / T
    kl = F.kl_div(
        F.log_softmax(s, dim=-1),
        F.softmax(t, dim=-1),
        reduction="batchmean"
    )
    return kl * (T * T)

# 7. Hook into Trainer
def compute_loss(model, inputs, return_outputs=False):
    # forward teacher
    with torch.no_grad():
        t_out = teacher(**inputs)
    # forward student
    s_out = model(**inputs)

    loss = distill_loss(s_out.logits, t_out.logits, T=2.0)
    # optional: combine with CE on masked labels
    if inputs.get("labels") is not None:
        ce = F.cross_entropy(
            s_out.logits.view(-1, s_out.logits.size(-1)),
            inputs["labels"].view(-1)
        )
        alpha = 0.5
        loss = alpha * ce + (1 - alpha) * loss

    return (loss, s_out) if return_outputs else loss
# —————————————————————————————

# 8. Your datasets
train_dataset = "../BenchmarkingFinetuning/dataset_splits/finetuning_dataset/train.fasta"
eval_dataset  = "../BenchmarkingFinetuning/dataset_splits/finetuning_dataset/valid.fasta"

# 9. Training arguments
training_args = TrainingArguments(
    output_dir="distilled-esm2-8M",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    learning_rate=1e-4,
    num_train_epochs=100,
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    eval_steps=1,
    save_steps=1,
    logging_steps=1,
    fp16=True,
)

optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)
student.train()

for epoch in range(training_args.num_train_epochs):
    for batch in train_dataloader:
        batch = {k:v.to(device) for k,v in batch.items()}
        with torch.no_grad():
            teacher_out = teacher(**batch)
        student_out = student(**batch)

        loss = distill_loss(student_out.logits, teacher_out.logits, T=2.0)
        if "labels" in batch:
            ce = F.cross_entropy(
                student_out.logits.view(-1, student_out.logits.size(-1)),
                batch["labels"].view(-1)
            )
            loss = 0.5 * ce + 0.5 * loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # log to W&B, checkpoint, etc.



trainer.train()

# 12. Save back only the student LoRA weights
student.save_pretrained("distilled-esm2-8M-lora")


2025-05-13 15:53:36.202231: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-13 15:53:36.346787: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-13 15:53:36.347584: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-13 15:53:36.597387: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  trainer = Trainer(


TypeError: Trainer.__init__() got an unexpected keyword argument 'compute_loss'

In [8]:
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    get_linear_schedule_with_warmup
)
from peft import PeftModel
import wandb
from datasets import load_dataset

# ─── Hyper-parameters ─────────────────────────────────────
teacher_lora_path = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_150m_ecoli_finetuning_1/checkpoint-19000"
student_lora_path = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_8m_ecoli_finetuning_2/checkpoint-11500"
train_fasta         = "../BenchmarkingFinetuning/dataset_splits/finetuning_dataset/train.fasta"
valid_fasta         = "../BenchmarkingFinetuning/dataset_splits/finetuning_dataset/valid.fasta"

output_dir       = "distilled-esm2-8M"
epochs           = 3
train_bs         = 16
eval_bs          = 32
lr               = 1e-4
mlm_prob         = 0.15
T                = 2.0
logging_steps    = 100
eval_steps       = 500
save_steps       = 500

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─── Models & Tokenizer ────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/esm2_t30_150M_UR50D", do_lower_case=False
)
base_teacher = AutoModelForMaskedLM.from_pretrained(
    "facebook/esm2_t30_150M_UR50D", output_hidden_states=True, return_dict=True
)
base_student = AutoModelForMaskedLM.from_pretrained(
    "facebook/esm2_t6_8M_UR50D", output_hidden_states=True, return_dict=True
)
teacher = PeftModel.from_pretrained(base_teacher, teacher_lora_path)
student = PeftModel.from_pretrained(base_student, student_lora_path)

# ─── Freeze parameters ────────────────────────────────────
# Freeze all teacher params
teacher.eval()
for param in teacher.parameters():
    param.requires_grad = False
# Freeze student base model, leave only LoRA adapter params trainable
for name, param in student.named_parameters():
    if 'lora_' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

teacher.to(device)
student.to(device).train()

# ─── Data ─────────────────────────────────────────────────
ds = load_dataset(
    "text", data_files={"train": train_fasta, "validation": valid_fasta}
)
def tokenize_fn(examples):
    return tokenizer(
        examples["text"], truncation=True, max_length=1024
    )
ds = ds.map(tokenize_fn, batched=True, remove_columns=["text"])
ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
train_ds = ds["train"]
eval_ds  = ds["validation"]

collator = DataCollatorForLanguageModeling(
    tokenizer, mlm=True, mlm_probability=mlm_prob
)
train_loader = DataLoader(
    train_ds, batch_size=train_bs, shuffle=True, collate_fn=collator
)
eval_loader = DataLoader(
    eval_ds, batch_size=eval_bs, shuffle=False, collate_fn=collator
)

# ─── Optimizer & Scheduler ────────────────────────────────
optimizer   = torch.optim.AdamW(student.parameters(), lr=lr)
total_steps = len(train_loader) * epochs
scheduler   = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

# ─── Mixed-precision setup ─────────────────────────────────
scaler = torch.amp.GradScaler()

# ─── W&B setup ─────────────────────────────────────────────
wandb.login()
wandb.init(
    project="esm2-distill",
    name="manual-loop",
    config={
        "epochs": epochs,
        "train_bs": train_bs,
        "eval_bs": eval_bs,
        "lr": lr,
        "mlm_prob": mlm_prob,
        "temperature": T,
    }
)

# ─── Distillation loss ────────────────────────────────────
def distill_loss(s_logits, t_logits, T):
    s = s_logits / T
    t = t_logits / T
    kld = F.kl_div(
        F.log_softmax(s, dim=-1),
        F.softmax(t, dim=-1),
        reduction="batchmean"
    )
    return kld * (T * T)

# ─── Training loop ────────────────────────────────────────
global_step = 0
for epoch in range(1, epochs + 1):
    # Training batches
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}

        # Teacher forward
        with torch.no_grad():
            t_out = teacher(**batch)

        optimizer.zero_grad()
        # Student forward + compute loss
        with torch.amp.autocast(device_type="cuda"):
            s_out = student(**batch)
            loss_kd = distill_loss(s_out.logits, t_out.logits, T)
            if batch.get("labels") is not None:
                ce = F.cross_entropy(
                    s_out.logits.view(-1, s_out.logits.size(-1)),
                    batch["labels"].view(-1)
                )
                loss = 0.5 * ce + 0.5 * loss_kd
            else:
                loss = loss_kd

        # Backward + update
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        global_step += 1

        # Log training metrics
        if global_step % logging_steps == 0:
            wandb.log({
                "train/loss": loss.item(),
                "train/lr": scheduler.get_last_lr()[0],
                "step": global_step,
                "epoch": epoch + global_step / total_steps
            })

        # Step-level evaluation
        if global_step % eval_steps == 0:
            student.eval()
            total_eval_loss, eval_batches = 0.0, 0
            for ev_batch in eval_loader:
                ev_batch = {k: v.to(device) for k, v in ev_batch.items()}
                with torch.no_grad():
                    t_eval = teacher(**ev_batch)
                    s_eval = student(**ev_batch)
                    eval_loss = distill_loss(s_eval.logits, t_eval.logits, T)
                total_eval_loss += eval_loss.item()
                eval_batches += 1
            avg_eval_loss = total_eval_loss / eval_batches
            wandb.log({"eval/loss": avg_eval_loss, "step": global_step})
            student.train()

        # Checkpointing
        if global_step % save_steps == 0:
            ckpt_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
            os.makedirs(ckpt_dir, exist_ok=True)
            student.save_pretrained(ckpt_dir)
            wandb.save(f"{ckpt_dir}/*")

    # End-of-epoch validation
    student.eval()
    total_val_loss, val_batches = 0.0, 0
    for ev_batch in eval_loader:
        ev_batch = {k: v.to(device) for k, v in ev_batch.items()}
        with torch.no_grad():
            t_eval = teacher(**ev_batch)
            s_eval = student(**ev_batch)
            val_loss = distill_loss(s_eval.logits, t_eval.logits, T)
        total_val_loss += val_loss.item()
        val_batches += 1
    avg_val_loss = total_val_loss / val_batches
    wandb.log({"validation/loss": avg_val_loss, "epoch": epoch})
    print(f"Finished epoch {epoch}/{epochs} - val_loss: {avg_val_loss:.4f}")
    student.train()

# ─── Final save ────────────────────────────────────────────
student.save_pretrained(output_dir)
wandb.finish()




VBox(children=(Label(value='0.027 MB of 0.027 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112385411332878, max=1.0…



KeyboardInterrupt: 

In [1]:
import os
import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    get_linear_schedule_with_warmup
)
from peft import PeftModel
import wandb
from datasets import load_dataset
from tqdm import tqdm

# ─── Hyper-parameters ─────────────────────────────────────
teacher_lora_path = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_150m_ecoli_finetuning_1/checkpoint-19000"
student_lora_path = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_8m_ecoli_finetuning_2/checkpoint-11500"
train_fasta         = "../BenchmarkingFinetuning/dataset_splits/finetuning_dataset/train.fasta"
valid_fasta         = "../BenchmarkingFinetuning/dataset_splits/finetuning_dataset/valid.fasta"

output_dir       = "distilled-esm2-8M"
epochs           = 50  # Increased for more epochs
train_bs         = 16
eval_bs          = 32
lr               = 1e-4
mlm_prob         = 0.15
T                = 2.0
save_epochs      = 10  # Save every 10 epochs

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# ─── Models & Tokenizer ────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/esm2_t30_150M_UR50D", do_lower_case=False
)
base_teacher = AutoModelForMaskedLM.from_pretrained(
    "facebook/esm2_t30_150M_UR50D", output_hidden_states=True, return_dict=True
)
base_student = AutoModelForMaskedLM.from_pretrained(
    "facebook/esm2_t6_8M_UR50D", output_hidden_states=True, return_dict=True
)
teacher = PeftModel.from_pretrained(base_teacher, teacher_lora_path)
student = PeftModel.from_pretrained(base_student, student_lora_path)

# Properly handle parameter freezing
teacher.eval()
for p in teacher.parameters():
    p.requires_grad = False

# For student: freeze only the base model, keep adapters trainable
for name, param in student.named_parameters():
    if "lora" not in name.lower():  # Keep LoRA parameters trainable
        param.requires_grad = False
    else:
        param.requires_grad = True  # Explicitly set LoRA parameters as trainable

# Verify we have trainable parameters
trainable_params = sum(p.numel() for p in student.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in student.parameters())
print(f"Trainable params: {trainable_params:,} / {all_params:,} ({100 * trainable_params / all_params:.2f}%)")

teacher.to(device)
student.to(device)
student.train()

# ─── Data ─────────────────────────────────────────────────
ds = load_dataset(
    "text", data_files={"train": train_fasta, "validation": valid_fasta}
)
def tokenize_fn(examples):
    return tokenizer(
        examples["text"], truncation=True, max_length=1024
    )
ds = ds.map(tokenize_fn, batched=True, remove_columns=["text"])
ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
train_ds = ds["train"]
eval_ds  = ds["validation"]

collator = DataCollatorForLanguageModeling(
    tokenizer, mlm=True, mlm_probability=mlm_prob
)
train_loader = DataLoader(
    train_ds, batch_size=train_bs, shuffle=True, collate_fn=collator
)
eval_loader = DataLoader(
    eval_ds, batch_size=eval_bs, shuffle=False, collate_fn=collator
)

# ─── Optimizer & Scheduler ────────────────────────────────
# Only optimize parameters that require gradients
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, student.parameters()), 
    lr=lr
)
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

# ─── Mixed-precision setup ─────────────────────────────────
scaler = torch.cuda.amp.GradScaler()

# ─── W&B setup ─────────────────────────────────────────────
wandb.login()
wandb.init(
    project="esm2-distill",
    name="epoch-based-logging",
    config={
        "epochs": epochs,
        "train_bs": train_bs,
        "eval_bs": eval_bs,
        "lr": lr,
        "mlm_prob": mlm_prob,
        "temperature": T,
    }
)

# ─── Distillation loss ────────────────────────────────────
def distill_loss(s_logits, t_logits, T):
    s = s_logits / T
    t = t_logits / T
    kld = F.kl_div(
        F.log_softmax(s, dim=-1),
        F.softmax(t, dim=-1),
        reduction="batchmean"
    )
    return kld * (T * T)

# ─── Training loop ────────────────────────────────────────
start_time = time.time()

# Suppress all tqdm output except final result
import sys
from contextlib import contextmanager

@contextmanager
def suppress_stdout():
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout

for epoch in range(1, epochs + 1):
    epoch_train_loss = 0.0
    epoch_train_steps = 0
    epoch_start_time = time.time()
    
    # Simple progress indicator without tqdm
    print(f"\nEpoch {epoch}/{epochs}")
    print("Training...", end='', flush=True)
    
    for i, batch in enumerate(train_loader):
        batch = {k: v.to(device) for k, v in batch.items()}

        optimizer.zero_grad()
        
        # Teacher forward (no gradient needed)
        with torch.no_grad():
            t_out = teacher(**batch)

        # Student forward + compute loss
        with torch.cuda.amp.autocast():
            s_out = student(**batch)
            loss_kd = distill_loss(s_out.logits, t_out.logits, T)
            
            if batch.get("labels") is not None:
                ce = F.cross_entropy(
                    s_out.logits.view(-1, s_out.logits.size(-1)),
                    batch["labels"].view(-1)
                )
                loss = 0.5 * ce + 0.5 * loss_kd
            else:
                loss = loss_kd

        # Backward + update
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Track epoch training loss
        epoch_train_loss += loss.item()
        epoch_train_steps += 1
        
        # Print progress dots every 100 batches
        if (i + 1) % 100 == 0:
            print('.', end='', flush=True)
    
    print(" Done!")
    
    # End-of-epoch: calculate average training loss
    avg_train_loss = epoch_train_loss / epoch_train_steps
    
    # End-of-epoch validation
    print("Validating...", end='', flush=True)
    student.eval()
    total_val_loss, val_batches = 0.0, 0
    
    for i, ev_batch in enumerate(eval_loader):
        ev_batch = {k: v.to(device) for k, v in ev_batch.items()}
        with torch.no_grad():
            t_eval = teacher(**ev_batch)
            s_eval = student(**ev_batch)
            val_loss = distill_loss(s_eval.logits, t_eval.logits, T)
        total_val_loss += val_loss.item()
        val_batches += 1
        
        # Print progress dots every 50 batches
        if (i + 1) % 50 == 0:
            print('.', end='', flush=True)
    
    print(" Done!")
    avg_val_loss = total_val_loss / val_batches
    
    # Calculate times
    epoch_time = time.time() - epoch_start_time
    total_time = time.time() - start_time
    eta_seconds = (total_time / epoch) * (epochs - epoch)
    eta_str = f"{eta_seconds/3600:.1f}h" if eta_seconds > 3600 else f"{eta_seconds/60:.0f}m"
    
    # Log to wandb
    wandb.log({
        "train_loss": avg_train_loss,
        "validation_loss": avg_val_loss,
        "learning_rate": scheduler.get_last_lr()[0],
        "epoch": epoch
    })
    
    # Print epoch summary
    print(f"Completed in {epoch_time/60:.1f}m | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | ETA: {eta_str}")
    
    student.train()
    
    # Save checkpoints at specified intervals
    if epoch % save_epochs == 0:
        ckpt_dir = os.path.join(output_dir, f"checkpoint-epoch-{epoch}")
        os.makedirs(ckpt_dir, exist_ok=True)
        student.save_pretrained(ckpt_dir)
        wandb.save(f"{ckpt_dir}/*")
        print(f"✓ Checkpoint saved at epoch {epoch}")
    
    print("-" * 60)

# ─── Final save ────────────────────────────────────────────
print("\nSaving final model...")
student.save_pretrained(output_dir)
wandb.finish()
total_time = time.time() - start_time
print(f"Training completed! Total time: {total_time/3600:.1f} hours")

2025-05-13 17:07:05.580267: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-13 17:07:05.602981: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-13 17:07:05.603028: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-13 17:07:05.619235: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Using device: cuda
GPU Name: NVIDIA A100-SXM4-80GB
GPU Memory: 84.99 GB
Trainable params: 61,440 / 8,016,187 (0.77%)


  scaler = torch.cuda.amp.GradScaler()
[34m[1mwandb[0m: Currently logged in as: [33msdowell[0m ([33msdowell1[0m). Use [1m`wandb login --relogin`[0m to force relogin



Epoch 1/50
Training...

  with torch.cuda.amp.autocast():


............................................................. Done!
Validating.............. Done!
Completed in 7.0m | Train Loss: 32.7651 | Val Loss: 30.2341 | ETA: 5.7h
------------------------------------------------------------

Epoch 2/50
Training................................................................ Done!
Validating.............. Done!
Completed in 6.9m | Train Loss: 13.3446 | Val Loss: 19.9945 | ETA: 5.5h
------------------------------------------------------------

Epoch 3/50
Training................................................................ Done!
Validating.............. Done!
Completed in 6.9m | Train Loss: 10.3499 | Val Loss: 16.5507 | ETA: 5.4h
------------------------------------------------------------

Epoch 4/50
Training................................................................ Done!
Validating.............. Done!
Completed in 7.1m | Train Loss: 8.9819 | Val Loss: 14.4167 | ETA: 5.3h
------------------------------------------------------------

Epo



✓ Checkpoint saved at epoch 10
------------------------------------------------------------

Epoch 11/50
Training................................................................ Done!
Validating.............. Done!
Completed in 7.0m | Train Loss: 6.5618 | Val Loss: 10.2540 | ETA: 4.5h
------------------------------------------------------------

Epoch 12/50
Training................................................................ Done!
Validating.............. Done!
Completed in 7.1m | Train Loss: 6.4931 | Val Loss: 10.1126 | ETA: 4.4h
------------------------------------------------------------

Epoch 13/50
Training................................................................ Done!
Validating.............. Done!
Completed in 7.1m | Train Loss: 6.3922 | Val Loss: 10.0256 | ETA: 4.3h
------------------------------------------------------------

Epoch 14/50
Training................................................................ Done!
Validating.............. Done!
Completed in 6.9m | 



✓ Checkpoint saved at epoch 20
------------------------------------------------------------

Epoch 21/50
Training................................................................ Done!
Validating.............. Done!
Completed in 7.1m | Train Loss: 6.0374 | Val Loss: 9.2900 | ETA: 3.4h
------------------------------------------------------------

Epoch 22/50
Training................................................................ Done!
Validating.............. Done!
Completed in 6.9m | Train Loss: 5.9915 | Val Loss: 9.1626 | ETA: 3.3h
------------------------------------------------------------

Epoch 23/50
Training................................................................ Done!
Validating.............. Done!
Completed in 7.1m | Train Loss: 5.9488 | Val Loss: 9.0885 | ETA: 3.1h
------------------------------------------------------------

Epoch 24/50
Training................................................................ Done!
Validating.............. Done!
Completed in 7.3m | Tra



✓ Checkpoint saved at epoch 30
------------------------------------------------------------

Epoch 31/50
Training................................................................ Done!
Validating.............. Done!
Completed in 6.9m | Train Loss: 5.8138 | Val Loss: 8.8721 | ETA: 2.2h
------------------------------------------------------------

Epoch 32/50
Training................................................................ Done!
Validating.............. Done!
Completed in 6.9m | Train Loss: 5.8115 | Val Loss: 8.8764 | ETA: 2.1h
------------------------------------------------------------

Epoch 33/50
Training................................................................ Done!
Validating.............. Done!
Completed in 7.1m | Train Loss: 5.7760 | Val Loss: 8.7723 | ETA: 2.0h
------------------------------------------------------------

Epoch 34/50
Training................................................................ Done!
Validating.............. Done!
Completed in 7.0m | Tra



✓ Checkpoint saved at epoch 40
------------------------------------------------------------

Epoch 41/50
Training................................................................ Done!
Validating.............. Done!
Completed in 7.0m | Train Loss: 5.6992 | Val Loss: 8.6905 | ETA: 1.0h
------------------------------------------------------------

Epoch 42/50
Training................................................................ Done!
Validating.............. Done!
Completed in 6.8m | Train Loss: 5.6881 | Val Loss: 8.5810 | ETA: 56m
------------------------------------------------------------

Epoch 43/50
Training................................................................ Done!
Validating.............. Done!
Completed in 7.2m | Train Loss: 5.6782 | Val Loss: 8.5956 | ETA: 49m
------------------------------------------------------------

Epoch 44/50
Training................................................................ Done!
Validating.............. Done!
Completed in 6.9m | Train



✓ Checkpoint saved at epoch 50
------------------------------------------------------------

Saving final model...


VBox(children=(Label(value='3.419 MB of 3.419 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
learning_rate,▂▄▅▇████▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▁▁▁
train_loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,50.0
learning_rate,0.0
train_loss,5.63245
validation_loss,8.5515


Training completed! Total time: 5.8 hours


In [None]:
import os
import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    AutoConfig,
    DataCollatorForLanguageModeling,
    get_linear_schedule_with_warmup
)
from peft import PeftModel
import wandb
from datasets import load_dataset

# ─── Hyper-parameters ─────────────────────────────────────
teacher_lora_path = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_150m_ecoli_finetuning_1/checkpoint-19000"
student_lora_path = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_8m_ecoli_finetuning_2/checkpoint-11500"
train_fasta = "../BenchmarkingFinetuning/dataset_splits/finetuning_dataset/train.fasta"
valid_fasta = "../BenchmarkingFinetuning/dataset_splits/finetuning_dataset/valid.fasta"

output_dir = "distilled-esm2-8M"
epochs = 50
train_bs = 16
eval_bs = 32
lr = 1e-4
mlm_prob = 0.15
T = 2.0
save_epochs = 10

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# ─── Models & Tokenizer ────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D", do_lower_case=False)

teacher_config = AutoConfig.from_pretrained("facebook/esm2_t30_150M_UR50D")
teacher_config.output_hidden_states = True
teacher_config.return_dict = True
base_teacher = AutoModelForMaskedLM.from_pretrained(
    "facebook/esm2_t30_150M_UR50D", config=teacher_config
)

student_config = AutoConfig.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    attention_probs_dropout_prob=0.2,
    hidden_dropout_prob=0.2
)
student_config.output_hidden_states = True
student_config.return_dict = True
base_student = AutoModelForMaskedLM.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    config=student_config
)

teacher = PeftModel.from_pretrained(base_teacher, teacher_lora_path)
student = PeftModel.from_pretrained(base_student, student_lora_path)

teacher.eval()
for p in teacher.parameters():
    p.requires_grad = False

for name, param in student.named_parameters():
    param.requires_grad = "lora" in name.lower()

trainable_params = sum(p.numel() for p in student.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in student.parameters())
print(f"Trainable params: {trainable_params:,} / {all_params:,} ({100 * trainable_params / all_params:.2f}%)")

teacher.to(device)
student.to(device)
student.train()

# ─── Data ─────────────────────────────────────────────────
ds = load_dataset("text", data_files={"train": train_fasta, "validation": valid_fasta})

def tokenize_fn(examples):
    return tokenizer(examples["text"], truncation=True, max_length=1024)

ds = ds.map(tokenize_fn, batched=True, remove_columns=["text"])
ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
train_ds = ds["train"]
eval_ds = ds["validation"]

collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=mlm_prob)
train_loader = DataLoader(train_ds, batch_size=train_bs, shuffle=True, collate_fn=collator)
eval_loader = DataLoader(eval_ds, batch_size=eval_bs, shuffle=False, collate_fn=collator)

# ─── Optimizer & Scheduler ────────────────────────────────
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, student.parameters()),
    lr=lr,
    weight_decay=0.01
)
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

# ─── Mixed-precision setup ─────────────────────────────────
scaler = torch.cuda.amp.GradScaler()

# ─── W&B setup ─────────────────────────────────────────────
wandb.login()
wandb.init(
    project="esm2-distill",
    name="epoch-based-logging",
    config={"epochs": epochs, "train_bs": train_bs, "eval_bs": eval_bs, "lr": lr, "mlm_prob": mlm_prob, "temperature": T}
)

# ─── Distillation loss ────────────────────────────────────
def distill_loss(s_logits, t_logits, T):
    s = s_logits / T
    t = t_logits / T
    kld = F.kl_div(F.log_softmax(s, dim=-1), F.softmax(t, dim=-1), reduction="batchmean")
    return kld * (T * T)

# ─── Training loop ────────────────────────────────────────
start_time = time.time()
best_val_loss = float("inf")
patience = 5
patience_counter = 0

for epoch in range(1, epochs + 1):
    epoch_train_loss = 0.0
    epoch_train_steps = 0
    print(f"\nEpoch {epoch}/{epochs}\nTraining...", end='', flush=True)

    for i, batch in enumerate(train_loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()

        with torch.no_grad():
            t_out = teacher(**batch)

        with torch.cuda.amp.autocast():
            s_out = student(**batch)
            loss_kd = distill_loss(s_out.logits, t_out.logits, T)
            ce = F.cross_entropy(
                s_out.logits.view(-1, s_out.logits.size(-1)),
                batch["labels"].view(-1)
            ) if batch.get("labels") is not None else 0.0
            loss = 0.5 * ce + 0.5 * loss_kd

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        epoch_train_loss += loss.item()
        epoch_train_steps += 1
        if (i + 1) % 100 == 0:
            print('.', end='', flush=True)

    print(" Done!")
    avg_train_loss = epoch_train_loss / epoch_train_steps

    # Validation
    print("Validating...", end='', flush=True)
    student.eval()
    total_val_loss, val_batches = 0.0, 0

    for i, ev_batch in enumerate(eval_loader):
        ev_batch = {k: v.to(device) for k, v in ev_batch.items()}
        with torch.no_grad():
            t_eval = teacher(**ev_batch)
            s_eval = student(**ev_batch)
            val_loss = distill_loss(s_eval.logits, t_eval.logits, T)
        total_val_loss += val_loss.item()
        val_batches += 1
        if (i + 1) % 50 == 0:
            print('.', end='', flush=True)

    print(" Done!")
    avg_val_loss = total_val_loss / val_batches

    wandb.log({"train_loss": avg_train_loss, "validation_loss": avg_val_loss, "learning_rate": scheduler.get_last_lr()[0], "epoch": epoch})
    print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # Early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        ckpt_dir = os.path.join(output_dir, "best-checkpoint")
        os.makedirs(ckpt_dir, exist_ok=True)
        student.save_pretrained(ckpt_dir)
        wandb.save(f"{ckpt_dir}/*")
        print("✓ Best checkpoint saved")
    else:
        patience_counter += 1
        print(f"⚠️  No improvement. Patience: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("⛔ Early stopping triggered.")
            break

    if epoch % save_epochs == 0:
        ckpt_dir = os.path.join(output_dir, f"checkpoint-epoch-{epoch}")
        os.makedirs(ckpt_dir, exist_ok=True)
        student.save_pretrained(ckpt_dir)
        wandb.save(f"{ckpt_dir}/*")
        print(f"✓ Checkpoint saved at epoch {epoch}")
    student.train()

print("\nSaving final model...")
student.save_pretrained(output_dir)
wandb.finish()
print("Training completed!")


Using device: cuda
GPU Name: NVIDIA A100-SXM4-80GB
GPU Memory: 84.99 GB
Trainable params: 61,440 / 8,016,187 (0.77%)


Map:   0%|          | 0/97597 [00:00<?, ? examples/s]

Map:   0%|          | 0/18263 [00:00<?, ? examples/s]

  scaler = torch.cuda.amp.GradScaler()



Epoch 1/50
Training...

  with torch.cuda.amp.autocast():


............................................................. Done!
Validating.............. Done!
Train Loss: 87.3363 | Val Loss: 64.1484




✓ Best checkpoint saved

Epoch 2/50
Training................................................................ Done!
Validating.............. Done!
Train Loss: 35.6082 | Val Loss: 46.0822
✓ Best checkpoint saved

Epoch 3/50
Training................................................................ Done!
Validating.............. Done!
Train Loss: 28.2499 | Val Loss: 41.5601
✓ Best checkpoint saved

Epoch 4/50
Training................................................................ Done!
Validating.............. Done!
Train Loss: 25.1349 | Val Loss: 39.7372
✓ Best checkpoint saved

Epoch 5/50
Training............