In [1]:
import os
import pandas as pd
import numpy as np
import torch
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from scipy.stats import pearsonr
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    EarlyStoppingCallback
)
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    PeftModel # Added for loading later
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Configuration ---
MODEL_NAME = "facebook/esm2_t12_35M_UR50D"  # Or choose a larger ESM model
DATA_FILE = "../protein_embeddings.parquet"      # <<< CHANGE THIS to your data file path
SEQUENCE_COL = "Sequence"                # <<< CHANGE THIS to your sequence column name
TARGET_COL = "DMS_score"                 # <<< CHANGE THIS to your DMS score column name
OUTPUT_DIR = "./esm2_lora_finetuned_dms"
LOGGING_DIR = "./logs"
TEST_SPLIT_SIZE = 0.15                   # Proportion of data for validation
RANDOM_SEED = 42


In [3]:
# LoRA Configuration
LORA_R = 8                               # LoRA rank (try 8, 16, 32)
LORA_ALPHA = 16                          # LoRA alpha (often 2*r)
LORA_DROPOUT = 0.1                       # LoRA dropout
# --- Inspect your chosen model's architecture (`print(model)`) to find suitable target modules ---
# For ESM models, attention query and value layers are common targets.
# The exact names might vary slightly based on the specific ESM version in transformers.
# Example for many ESM models:
LORA_TARGET_MODULES = [
    # "esm.encoder.layer.*.attention.self.query",
    # "esm.encoder.layer.*.attention.self.value",
    # Add other layers if needed, e.g., intermediate or output layers, but start simple.
]
for i in range(12):  # Assuming 12 layers (0-11)
    LORA_TARGET_MODULES.append(f'esm.encoder.layer.{i}.attention.self.value')
    LORA_TARGET_MODULES.append(f'esm.encoder.layer.{i}.attention.self.query')

print(LORA_TARGET_MODULES)

['esm.encoder.layer.0.attention.self.value', 'esm.encoder.layer.0.attention.self.query', 'esm.encoder.layer.1.attention.self.value', 'esm.encoder.layer.1.attention.self.query', 'esm.encoder.layer.2.attention.self.value', 'esm.encoder.layer.2.attention.self.query', 'esm.encoder.layer.3.attention.self.value', 'esm.encoder.layer.3.attention.self.query', 'esm.encoder.layer.4.attention.self.value', 'esm.encoder.layer.4.attention.self.query', 'esm.encoder.layer.5.attention.self.value', 'esm.encoder.layer.5.attention.self.query', 'esm.encoder.layer.6.attention.self.value', 'esm.encoder.layer.6.attention.self.query', 'esm.encoder.layer.7.attention.self.value', 'esm.encoder.layer.7.attention.self.query', 'esm.encoder.layer.8.attention.self.value', 'esm.encoder.layer.8.attention.self.query', 'esm.encoder.layer.9.attention.self.value', 'esm.encoder.layer.9.attention.self.query', 'esm.encoder.layer.10.attention.self.value', 'esm.encoder.layer.10.attention.self.query', 'esm.encoder.layer.11.attenti

In [4]:
# Training Configuration
LEARNING_RATE = 1e-4 # Adjust as needed (can be higher than full fine-tuning)
BATCH_SIZE = 8      # Adjust based on GPU memory
NUM_EPOCHS = 10     # Adjust based on convergence (use early stopping)
WEIGHT_DECAY = 0.01
EVAL_STEPS = 50    # Evaluate every N steps
SAVE_STEPS = 100   # Save checkpoint every N steps
LOGGING_STEPS = 10
EARLY_STOPPING_PATIENCE = 3 # Stop if validation metric doesn't improve for N evaluations



In [5]:
# --- Set Device ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# Dataset

In [6]:
# --- Load and Prepare Data ---
print(f"Loading data from {DATA_FILE}...")
df = pd.read_parquet(DATA_FILE)

Loading data from ../protein_embeddings.parquet...


In [7]:
# Ensure target is numeric
df[TARGET_COL] = pd.to_numeric(df[TARGET_COL])


In [8]:

# Basic validation
if SEQUENCE_COL not in df.columns or TARGET_COL not in df.columns:
    raise ValueError(f"Columns '{SEQUENCE_COL}' or '{TARGET_COL}' not found in {DATA_FILE}")

print(f"Data loaded: {len(df)} samples")

Data loaded: 1140 samples


In [9]:
# Split data
train_df, val_df = train_test_split(
    df, test_size=TEST_SPLIT_SIZE, random_state=RANDOM_SEED
)

In [10]:
print(f"Training samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")

# Convert pandas DataFrames to Hugging Face Datasets
train_dataset = Dataset.from_pandas(train_df[[SEQUENCE_COL, TARGET_COL]].rename(columns={TARGET_COL: 'label', SEQUENCE_COL: 'text'}))
val_dataset = Dataset.from_pandas(val_df[[SEQUENCE_COL, TARGET_COL]].rename(columns={TARGET_COL: 'label', SEQUENCE_COL: 'text'}))


Training samples: 969
Validation samples: 171


# Tokenizer and Models


In [11]:
print(f"Loading tokenizer for {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print(f"Loading base model {MODEL_NAME} for regression...")
# We use AutoModelForSequenceClassification but configure it for regression
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=1,  # Single output value for regression
    problem_type="regression", # Ensures appropriate loss function (MSE)
)



Loading tokenizer for facebook/esm2_t12_35M_UR50D...
Loading base model facebook/esm2_t12_35M_UR50D for regression...


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
for name, module in model.named_modules():
    print(name)


esm
esm.embeddings
esm.embeddings.word_embeddings
esm.embeddings.dropout
esm.embeddings.position_embeddings
esm.encoder
esm.encoder.layer
esm.encoder.layer.0
esm.encoder.layer.0.attention
esm.encoder.layer.0.attention.self
esm.encoder.layer.0.attention.self.query
esm.encoder.layer.0.attention.self.key
esm.encoder.layer.0.attention.self.value
esm.encoder.layer.0.attention.self.dropout
esm.encoder.layer.0.attention.self.rotary_embeddings
esm.encoder.layer.0.attention.output
esm.encoder.layer.0.attention.output.dense
esm.encoder.layer.0.attention.output.dropout
esm.encoder.layer.0.attention.LayerNorm
esm.encoder.layer.0.intermediate
esm.encoder.layer.0.intermediate.dense
esm.encoder.layer.0.output
esm.encoder.layer.0.output.dense
esm.encoder.layer.0.output.dropout
esm.encoder.layer.0.LayerNorm
esm.encoder.layer.1
esm.encoder.layer.1.attention
esm.encoder.layer.1.attention.self
esm.encoder.layer.1.attention.self.query
esm.encoder.layer.1.attention.self.key
esm.encoder.layer.1.attention.se

In [13]:
def tokenize_function(examples):
    # Ensure sequences are strings
    sequences = [str(seq) for seq in examples["text"]]
    # Truncation is important for models with fixed input size like ESM
    return tokenizer(sequences, padding="max_length", truncation=True, max_length=1024) # Adjust max_length if needed, but ESM models often have a limit

print("Tokenizing datasets...")
tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_val_dataset = val_dataset.map(tokenize_function, batched=True)

# Remove unnecessary columns
tokenized_train_dataset = tokenized_train_dataset.remove_columns(["text", "__index_level_0__"])
tokenized_val_dataset = tokenized_val_dataset.remove_columns(["text", "__index_level_0__"])
tokenized_train_dataset.set_format("torch")
tokenized_val_dataset.set_format("torch")

Tokenizing datasets...


Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 969/969 [00:01<00:00, 493.93 examples/s]
Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 171/171 [00:00<00:00, 486.78 examples/s]


# PEFT

In [14]:
print("Configuring LoRA...")
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS, # Use SEQ_CLS even for regression with this model head
    inference_mode=False,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    # bias="none" # or "all" or "lora_only". Check peft docs. 'none' often works well.
)


Configuring LoRA...


In [15]:

# Wrap the base model with PEFT adapter
model = get_peft_model(model, peft_config)
print("LoRA model configured:")
model.print_trainable_parameters() # Should show a very small % of trainable parameters

model = model.to(device) # Move PEFT model to GPU if available
print(model.device)

LoRA model configured:
trainable params: 415,681 || all params: 34,409,043 || trainable%: 1.2081
cuda:0


# Configure Training

In [34]:
# --- Define Metrics ---
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Logits are the direct output for regression
    predictions = logits.squeeze(-1) # Remove the last dimension if necessary
    # labels = labels.squeeze(-1)

    # Calculate Pearson Correlation
    pearson_corr, p_value = pearsonr(predictions, labels)

    # Calculate Mean Squared Error (optional but good to track)
    mse = ((predictions - labels) ** 2).mean().item()

    return {
        "pearsonr": pearson_corr,
        "mse": mse,
    }

In [35]:
print("Setting up Training Arguments...")
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=WEIGHT_DECAY,
    evaluation_strategy="steps", # Evaluate periodically
    eval_steps=EVAL_STEPS,
    save_strategy="steps",       # Save periodically
    save_steps=SAVE_STEPS,
    logging_dir=LOGGING_DIR,
    logging_steps=LOGGING_STEPS,
    load_best_model_at_end=True, # Load the best checkpoint at the end of training
    metric_for_best_model="pearsonr", # Use Pearson correlation to determine the best model
    greater_is_better=True,      # Higher Pearson correlation is better
    save_total_limit=2,          # Only keep the best and the latest checkpoint
    fp16=torch.cuda.is_available(), # Enable mixed precision if GPU is available
    report_to="tensorboard",     # Log metrics for visualization
    seed=RANDOM_SEED,
    # dataloader_num_workers=4, # Optional: Speed up data loading if I/O is bottleneck
)

Setting up Training Arguments...




In [36]:
# Data collator handles padding within batches dynamically
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [37]:
import torch
from transformers import Trainer
from typing import Dict, Union, Any, Optional, Tuple, List
import logging # Using logging is better for debug messages

# Configure basic logging (optional, but helpful for debugging)
# logging.basicConfig(level=logging.DEBUG) # Uncomment this line to enable debug prints
logger = logging.getLogger(__name__)

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Compute loss, explicitly handling potential dict-like loss output
        and unexpected keyword arguments like 'num_items_in_batch'.
        """
        # Make a copy to avoid modifying the original input dict if labels are popped
        inputs_copy = inputs.copy()
        labels = inputs_copy.pop("labels", None)

        # Ensure labels are on the correct device if model is on GPU and labels exist
        if labels is not None and hasattr(model, 'device'):
            labels = labels.to(model.device)

        # logger.debug(f"Inputs passed to model: {list(inputs_copy.keys())}")
        # logger.debug(f"Labels present: {labels is not None}")
        if labels is not None:
            # logger.debug(f"Labels tensor device: {labels.device}")
            pass

        # Forward pass using the modified inputs (without labels if they were popped)
        outputs = model(**inputs_copy)
        # logger.debug(f"Model output type: {type(outputs)}")
        # logger.debug(f"Model outputs: {outputs}") # Caution: Can be very verbose

        loss = None # Initialize loss

        # Check if the model computed the loss automatically
        if hasattr(outputs, "loss") and outputs.loss is not None:
            loss_val = outputs.loss # Get the value from the outputs object
            # logger.debug(f"Type of outputs.loss: {type(loss_val)}")
            # logger.debug(f"Value of outputs.loss: {loss_val}")

            # --- Critical Check ---
            # Check if the extracted loss is actually a tensor
            if isinstance(loss_val, torch.Tensor):
                loss = loss_val
            # --- Handle if loss_val is dict ---
            elif isinstance(loss_val, dict):
                # If outputs.loss is unexpectedly a dict, try common keys
                if 'loss' in loss_val and isinstance(loss_val['loss'], torch.Tensor):
                     loss = loss_val['loss']
                     logger.warning("outputs.loss was a dict, extracted loss tensor from key 'loss'. Check model output structure.")
                # Add other potential keys if you inspect outputs and find loss elsewhere
                # elif 'some_other_key' in loss_val ...
                else:
                    # If we can't find a tensor in the dict, raise an error
                    raise TypeError(f"outputs.loss is a dictionary but does not contain a recognizable loss tensor under expected keys: {loss_val}")
            else:
                 # Handle other unexpected types
                 raise TypeError(f"outputs.loss has an unexpected type: {type(loss_val)}. Expected torch.Tensor or dict containing tensor.")

        # If we successfully extracted a valid loss tensor, ensure it's a scalar
        if loss is not None:
            # logger.debug(f"Extracted Loss Tensor: {loss}, Shape: {loss.shape}, Dim: {loss.dim()}")
            # Ensure it's a tensor before calling .dim() - redundant now due to checks above, but safe
            if isinstance(loss, torch.Tensor):
                if loss.dim() != 0:
                    # This often happens in multi-GPU (DistributedDataParallel) scenarios
                    # where the loss might be calculated per device or per sample initially.
                    # Averaging across the batch dimension is usually correct.
                    logger.warning(f"Loss tensor has dim {loss.dim()} (expected 0). Averaging loss.")
                    loss = loss.mean()
                    # logger.debug(f"Averaged Loss: {loss}")
            else:
                 # This should not happen due to the checks above, but as a safeguard:
                 raise TypeError(f"Variable 'loss' was expected to be a Tensor but got {type(loss)} before dimension check.")

        # Handle cases where loss couldn't be determined (e.g., no labels provided during eval)
        elif not return_outputs:
             # During training, loss should always be computable if labels are provided.
             # If loss is None here during training, something is wrong upstream.
             # During evaluation without labels, loss will be None, which is okay.
             logger.warning("compute_loss: Loss is None (e.g., no labels provided or model didn't return loss), and return_outputs=False.")
             # What should happen here? The default Trainer might expect a Tensor.
             # Returning None might cause issues downstream in the training loop.
             # Let's raise an error if loss is required (i.e., labels were present initially)
             if labels is not None:
                 raise ValueError("Labels were provided, but loss could not be computed or extracted from model outputs.")
             # If no labels provided (eval mode), returning None for loss might be acceptable depending on Trainer internals
             # or how evaluation loop handles it. Let's stick to returning None in this specific eval case.

        # else: loss is None, but return_outputs is True (likely evaluation), which is fine.

        # logger.debug(f"Final computed loss value: {loss}")
        return (loss, outputs) if return_outputs else loss

In [38]:
# --- Initialize Trainer ---

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    tokenizer=tokenizer,
    # data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)]
)

  trainer = Trainer(


In [39]:
# --- Train the Model ---
print("Starting training...")
train_result = trainer.train()

# Save training metrics
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

print("Training finished.")

Starting training...


Step,Training Loss,Validation Loss


***** train metrics *****
  epoch                    =      4.918
  total_flos               =   925679GF
  train_loss               =     0.0429
  train_runtime            = 0:03:17.10
  train_samples_per_second =     49.162
  train_steps_per_second   =       6.19
Training finished.


In [40]:
# --- Evaluate the Best Model ---
print("Evaluating the best model on the validation set...")
eval_results = trainer.evaluate(eval_dataset=tokenized_val_dataset)

print("Validation Results:")
print(eval_results)
trainer.log_metrics("eval", eval_results)
trainer.save_metrics("eval", eval_results)

Evaluating the best model on the validation set...


Validation Results:
{'eval_loss': 0.048407647758722305, 'eval_pearsonr': 0.12966124713420868, 'eval_mse': 0.048407647758722305, 'eval_runtime': 2.3744, 'eval_samples_per_second': 72.019, 'eval_steps_per_second': 9.266, 'epoch': 4.918032786885246}
***** eval metrics *****
  epoch                   =      4.918
  eval_loss               =     0.0484
  eval_mse                =     0.0484
  eval_pearsonr           =     0.1297
  eval_runtime            = 0:00:02.37
  eval_samples_per_second =     72.019
  eval_steps_per_second   =      9.266


In [41]:
# --- Save the final PEFT adapter ---
# The best model checkpoint is already saved by `load_best_model_at_end=True`
# You can optionally save the adapter explicitly if needed
final_adapter_path = os.path.join(OUTPUT_DIR, "final_adapter")
model.save_pretrained(final_adapter_path)
tokenizer.save_pretrained(final_adapter_path)
print(f"Final PEFT adapter saved to {final_adapter_path}")

Final PEFT adapter saved to ./esm2_lora_finetuned_dms/final_adapter


# Loading

In [42]:
print("\n--- Example Inference ---")
# Load the base model again
base_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=1,
    problem_type="regression",
)

# Load the PEFT adapter weights on top of the base model
inference_model = PeftModel.from_pretrained(base_model, final_adapter_path)
inference_model = inference_model.to(device)
inference_model.eval() # Set to evaluation mode

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



--- Example Inference ---


PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): EsmForSequenceClassification(
      (esm): EsmModel(
        (embeddings): EsmEmbeddings(
          (word_embeddings): Embedding(33, 480, padding_idx=1)
          (dropout): Dropout(p=0.0, inplace=False)
          (position_embeddings): Embedding(1026, 480, padding_idx=1)
        )
        (encoder): EsmEncoder(
          (layer): ModuleList(
            (0-11): 12 x EsmLayer(
              (attention): EsmAttention(
                (self): EsmSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=480, out_features=480, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=480, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
          

In [43]:

# Example new sequence (replace with actual sequences you want to predict)
new_sequence = "MVNLARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLREKMRRRLESGDKWFSLEFFPPRTAEGAVNLISRFDRMAAGGPLYIDVTWHPAGDPGSDKETSSMMIASTAVNYCGLETILHMTCCRQRLEEITGHLHKAKQLGLKNIMALRGDPIGDQWEEEEGGFNYAVDLVKHIRSEFGDYFDICVAGYPKGHPEAGSFEADLKHLKEKVSAGADFIITQLFFEADTFFRFVKACTDMGITCPIVPGIFPIQGYHSLRQLVKLSKLEVPQEIKDVIEPIKDNDAAIRNYGIELAVSLCQELLASGLVPGLHFYTLNREMATTEVLKRLGMWTEDPRRPLPWALSAHPKRREEDVRPIFWASRPKSYIYRTQEWDEFPNGRWGNSSSPAFGELKDYYLFYLKSKSPKEELLKMWGEELTSEESVFEVFVLYLSGEPNRNGHKVTCLPWNDEPLAAETSLLKEELLRVNRQGILTINSQPNINGKPSSDPIVGWGPSGGYVFQKAYLEFFTSRETAEALLQVLKKYELRVNYHLVNVKGENITNAPELQPNAVTWGIFPGREIIQPTVVDPVSFMFWKDEAFALWIERWGKLYEEESPSRTIIQYIHDNYFLVNLVDNDFPLDNCLWQVVEDTLELLNRPTQNARETEAP" # A protein sequence string (use one from your data or a hypothetical one)
print(f"Predicting DMS score for: {new_sequence[:30]}...") # Print start of sequence

# Tokenize the new sequence
inputs = tokenizer(new_sequence, return_tensors="pt", padding=True, truncation=True, max_length=1024)
inputs = {k: v.to(device) for k, v in inputs.items()} # Move inputs to device


Predicting DMS score for: MVNLARGNSSLNPCLEGSASSGSESSKDSS...


In [44]:

# Make prediction
with torch.no_grad():
    outputs = inference_model(**inputs)
    predicted_dms_score = outputs.logits.item() # Get the single regression value

print(f"Predicted DMS Score: {predicted_dms_score:.4f}")

Predicted DMS Score: 0.2048


# Eval