In [1]:
# Import packages
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
from peft import PeftModel, PeftConfig
from autoamp.evolveFinetune import *
import torch
from tqdm import tqdm
import math
from Bio import SeqIO 
import json
import warnings
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from transformers import PreTrainedTokenizer

# Example inputs
base_model_name = "facebook/esm2_t30_150M_UR50D" 
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
adapter_checkpoint = "/home/sdowell/scratch/Thesis/ADP1/runs/esm2_dgoa_finetune_1/checkpoint-3000"

# Load models
model_pretrained = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_with_adapter = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_finetuned = PeftModel.from_pretrained(model_with_adapter, adapter_checkpoint)

def compute_mlm_loss(model, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Computes the MLM (masked language model) cross entropy loss for a given sequence.
    
    Args:
        sequence: Protein sequence as a string.
        mask_prob: The probability of masking a token.
        device: torch.device to run the computation.
    
    Returns:
        loss: The MLM cross entropy loss.
    """
    model.to(device)
    model.train()  # or model.eval() if you don't want dropout, etc.
    
    # Tokenize the sequence
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    
    # Create labels as a copy of input_ids.
    labels = input_ids.clone()
    
    # Create a mask for positions to replace.
    # Generate random values in [0, 1) for each token.
    probability_matrix = torch.rand(input_ids.shape).to(device)
    # Create a boolean mask for tokens to mask.
    mask = probability_matrix < mask_prob
    
    # For positions NOT selected for masking, set the corresponding label to -100 so they are ignored.
    labels[~mask] = -100
    
    # Replace the selected input positions with the mask token.
    mask_token_id = tokenizer.mask_token_id
    input_ids[mask] = mask_token_id
    
    # Forward pass: the model automatically computes the loss when labels are provided.
    outputs = model(input_ids, labels=labels)
    loss = outputs.loss
    return loss

# Example usage:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sequence = "MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKALIGAGTVLKPEQVDALARMGCQLIVTPNIHSEVIRRAVGYGMTVCPGCATATEAFTALEAGAQALKIFPSSAFGPQYIKALKAVLPSDIAVFAVGGVTPENLAQWIDAGCAGAGLGSDLYRAGQSVERTAQQAAAFVKAYREAVQ"
loss = compute_mlm_loss(model_pretrained, sequence, mask_prob=0.15, device=device)
print(f"Pretrained MLM Cross Entropy Loss: {loss.item():.4f}")
loss = compute_mlm_loss(model_finetuned, sequence, mask_prob=0.15, device=device)
print(f"Finetuned MLM Cross Entropy Loss: {loss.item():.4f}")



2025-04-27 16:30:43.481496: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-27 16:30:43.498205: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-27 16:30:43.498232: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-27 16:30:43.509811: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Pretrained MLM Cross Entropy Loss: 1.3430
Finetuned MLM Cross Entropy Loss: 0.0133


In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import numpy as np

def compute_token_prediction_accuracy(model, tokenizer, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Computes token prediction accuracy for a masked language model.
    
    Args:
        model: The masked language model.
        tokenizer: The corresponding tokenizer.
        sequence: A string representing the protein (or other) sequence.
        mask_prob: The probability with which to mask tokens (e.g., 0.15 for 15%).
        device: torch.device to run the computations.
    
    Returns:
        A tuple (accuracy, correct, total_masked) where:
          - accuracy is the fraction of masked tokens correctly predicted.
          - correct is the number of correct predictions.
          - total_masked is the total number of masked tokens.
    """
    # Tokenize the input sequence.
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    # Create a copy of the original tokens to serve as labels.
    labels = input_ids.clone()
    
    # Create a random mask for tokens according to mask_prob.
    probability_matrix = torch.rand(input_ids.shape).to(device)
    mask_positions = probability_matrix < mask_prob
    
    # Replace the selected token positions in input_ids with the mask token ID.
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("The tokenizer does not have a mask token.")
    
    input_ids[mask_positions] = mask_token_id
    
    # Forward pass through the model.
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # shape: [batch_size, sequence_length, vocab_size]
    
    # Get predictions (top candidate from the logits) using argmax.
    predictions = torch.argmax(logits, dim=-1)
    
    # Consider only the masked positions.
    masked_labels = labels[mask_positions]
    masked_predictions = predictions[mask_positions]
    
    # Calculate the number of correct predictions.
    correct = (masked_predictions == masked_labels).sum().item()
    total_masked = mask_positions.sum().item()
    
    accuracy = correct / total_masked if total_masked > 0 else 0.0
    return accuracy, correct, total_masked

sequence = (
    'MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKA'
    'LIGAGTVLKPEQVDALARMGCQLIVTPNIHSEVIRRAVGYGMTVCPGCATATEAFTALEA'
    'GAQALKIFPSSAFGPQYIKALKAVLPSDIAVFAVGGVTPENLAQWIDAGCAGAGLGSDLY'
    'RAGQSVERTAQQAAAFVKAYREAVQ'
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
accuracy, correct, total_masked = compute_token_prediction_accuracy(
    model_pretrained, tokenizer, sequence, mask_prob=0.15, device=device
)
print(f"pretrained Token Prediction Accuracy: {accuracy:.4f} ({correct}/{total_masked})")
accuracy, correct, total_masked = compute_token_prediction_accuracy(
    model_finetuned, tokenizer, sequence, mask_prob=0.15, device=device
)
print(f"finetuned Token Prediction Accuracy: {accuracy:.4f} ({correct}/{total_masked})")


pretrained Token Prediction Accuracy: 0.3929 (11/28)
finetuned Token Prediction Accuracy: 1.0000 (24/24)


In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from Bio import SeqIO
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
from peft import PeftModel, PeftConfig
from autoamp.evolveFinetune import *
import torch
from tqdm import tqdm
import math
from Bio import SeqIO 
import json
import warnings
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from transformers import PreTrainedTokenizer


base_model_name = "facebook/esm2_t30_150M_UR50D" 
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
adapter_checkpoint = "/home/sdowell/scratch/Thesis/ADP1/runs/esm2_dgoa_finetune_1/checkpoint-3000"

# Load models
model_pretrained = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_with_adapter = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_finetuned = PeftModel.from_pretrained(model_with_adapter, adapter_checkpoint)

def compute_token_prediction_accuracy(model, tokenizer, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Computes token prediction accuracy for a given sequence from a masked language model.
    
    Args:
        model: The masked language model.
        tokenizer: The corresponding tokenizer.
        sequence: A string representing the protein (or other) sequence.
        mask_prob: The probability of masking a token (default is 15%).
        device: Torch device on which to run computations.
    
    Returns:
        A tuple (correct, total_masked) where:
          - correct: number of masked tokens correctly predicted.
          - total_masked: total number of masked tokens.
    """
    # Tokenize input sequence.
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    # Create a copy of input_ids for ground-truth labels.
    labels = input_ids.clone()
    
    # Create a random mask for tokens.
    probability_matrix = torch.rand(input_ids.shape).to(device)
    mask_positions = probability_matrix < mask_prob
    
    # Replace tokens at masked positions with the mask token.
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("The tokenizer does not have a mask token.")
    input_ids[mask_positions] = mask_token_id
    
    # Forward pass through the model.
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # shape: [batch, seq_length, vocab_size]
    
    # Get predicted token IDs.
    predictions = torch.argmax(logits, dim=-1)
    
    # Evaluate only on masked positions.
    masked_labels = labels[mask_positions]
    masked_predictions = predictions[mask_positions]
    
    correct = (masked_predictions == masked_labels).sum().item()
    total_masked = mask_positions.sum().item()
    
    return correct, total_masked

# ----- Main script to process FASTA file with a progress bar -----

fasta_file = "/home/sdowell/scratch/Thesis/ADP1/finetuning_data/test/dgoa_mutants_test.fasta"

# Set device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

total_correct = 0
total_masked = 0
n_sequences = 0

# Count number of records in FASTA for progress bar (optional, if file is large)
records = list(SeqIO.parse(fasta_file, "fasta"))
n_records = len(records)

# Iterate over sequences with a progress bar.
for record in tqdm(records, desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    correct, masked = compute_token_prediction_accuracy(model_pretrained, tokenizer, seq, mask_prob=0.15, device=device)
    total_correct += correct
    total_masked += masked
    n_sequences += 1

# Compute overall accuracy.
overall_accuracy = total_correct / total_masked if total_masked > 0 else 0.0

print(f"\nProcessed {n_sequences} sequences.")
print(f"pretrained Overall Token Prediction Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_masked})")



total_correct = 0
total_masked = 0
n_sequences = 0

# Count number of records in FASTA for progress bar (optional, if file is large)
records = list(SeqIO.parse(fasta_file, "fasta"))
n_records = len(records)

# Iterate over sequences with a progress bar.
for record in tqdm(records, desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    correct, masked = compute_token_prediction_accuracy(model_finetuned, tokenizer, seq, mask_prob=0.15, device=device)
    total_correct += correct
    total_masked += masked
    n_sequences += 1

# Compute overall accuracy.
overall_accuracy = total_correct / total_masked if total_masked > 0 else 0.0

print(f"\nProcessed {n_sequences} sequences.")
print(f"finetuned Overall Token Prediction Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_masked})")


Processing sequences: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1953/1953 [00:51<00:00, 37.65it/s]



Processed 1953 sequences.
pretrained Overall Token Prediction Accuracy: 0.5504 (32944/59852)


Processing sequences: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1953/1953 [01:07<00:00, 28.79it/s]


Processed 1953 sequences.
finetuned Overall Token Prediction Accuracy: 0.9837 (58985/59963)





# ESM-2 Pretrained Recall, Precision, F1

In [4]:
def compute_token_predictions_for_metrics(model, tokenizer, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Returns token-level predictions and labels for masked positions, to support precision/recall/F1.
    
    Returns:
        Two lists: true token IDs and predicted token IDs at masked positions.
    """
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    labels = input_ids.clone()
    probability_matrix = torch.rand(input_ids.shape).to(device)
    mask_positions = probability_matrix < mask_prob

    # Ensure mask token is defined
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("Tokenizer does not have a mask token.")
    input_ids[mask_positions] = mask_token_id

    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    predictions = torch.argmax(logits, dim=-1)

    true_tokens = labels[mask_positions].tolist()
    predicted_tokens = predictions[mask_positions].tolist()

    return true_tokens, predicted_tokens

from sklearn.metrics import classification_report

all_trues = []
all_preds = []

# Use same file path and device setup as before
for record in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    trues, preds = compute_token_predictions_for_metrics(model_pretrained, tokenizer, seq, mask_prob=0.15, device=device)
    all_trues.extend(trues)
    all_preds.extend(preds)

# Optionally convert to token strings for readability
id2token = tokenizer.convert_ids_to_tokens
true_tokens = [id2token(t) for t in all_trues]
pred_tokens = [id2token(p) for p in all_preds]

# Print precision, recall, F1 for each amino acid
print(classification_report(true_tokens, pred_tokens, zero_division=0, digits=10))


Processing sequences: 1953it [00:51, 38.26it/s]


              precision    recall  f1-score   support

       <cls>  1.0000000000 1.0000000000 1.0000000000       292
       <eos>  1.0000000000 1.0000000000 1.0000000000       300
           A  0.5594525235 0.6846733668 0.6157612278      9552
           C  0.0000000000 0.0000000000 0.0000000000      1150
           D  0.3456896552 0.3436161097 0.3446497636      2334
           E  0.5600841441 0.7114228457 0.6267470943      2994
           F  0.3869565217 0.2918032787 0.3327102804      1830
           G  0.9061837796 0.8583941606 0.8816418330      5480
           H  0.0372670807 0.0109489051 0.0169252468       548
           I  0.5784687592 0.4726749760 0.5202479884      4172
           K  0.3698347107 0.3231046931 0.3448940270      2216
           L  0.4502942580 0.6065995286 0.5168888077      4667
           M  0.8897338403 0.4952380952 0.6363018355       945
           N  0.6953748006 0.6733590734 0.6841898784      1295
           P  0.6525709748 0.8504527814 0.7384857335      3865


# ESM-2 fine-tuned Recall, Precision, F1

In [5]:

all_trues = []
all_preds = []

# Use same file path and device setup as before
for record in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    trues, preds = compute_token_predictions_for_metrics(model_finetuned, tokenizer, seq, mask_prob=0.15, device=device)
    all_trues.extend(trues)
    all_preds.extend(preds)

# Optionally convert to token strings for readability
id2token = tokenizer.convert_ids_to_tokens
true_tokens = [id2token(t) for t in all_trues]
pred_tokens = [id2token(p) for p in all_preds]

# Print precision, recall, F1 for each amino acid
print(classification_report(true_tokens, pred_tokens, zero_division=0, digits=10))


Processing sequences: 1953it [01:07, 28.93it/s]


              precision    recall  f1-score   support

       <cls>  1.0000000000 1.0000000000 1.0000000000       279
       <eos>  1.0000000000 0.9894366197 0.9946902655       284
       <unk>  0.0000000000 0.0000000000 0.0000000000         1
           A  0.9862832758 0.9901060792 0.9881909804      9804
           C  0.9945945946 0.9901345291 0.9923595506      1115
           D  0.9664000000 0.9829129373 0.9745865268      2458
           E  0.9854088904 0.9837398374 0.9845736566      2952
           F  0.9704109589 0.9435269046 0.9567801189      1877
           G  0.9931494502 0.9965629522 0.9948532731      5528
           H  0.9877551020 0.8673835125 0.9236641221       558
           I  0.9909274194 0.9764092376 0.9836147592      4027
           K  0.9808917197 0.9800000000 0.9804456571      2200
           L  0.9957654033 0.9917756221 0.9937665082      4742
           M  0.9859813084 0.9836829837 0.9848308051       858
           N  0.9825479930 0.9615713066 0.9719464825      1171


In [6]:
import torch
from tqdm import tqdm
from sklearn.metrics import classification_report
from Bio import SeqIO

# ---------- 1. Helper functions ----------

def prepare_masked_input(tokenizer, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Prepares masked input_ids and labels for consistent evaluation.
    
    Returns:
        masked input_ids, ground truth labels, attention_mask, mask_positions
    """
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    labels = input_ids.clone()
    probability_matrix = torch.rand(input_ids.shape).to(device)
    mask_positions = probability_matrix < mask_prob

    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("Tokenizer does not have a mask token.")
    
    masked_input_ids = input_ids.clone()
    masked_input_ids[mask_positions] = mask_token_id

    return masked_input_ids, labels, attention_mask, mask_positions

def compute_predictions(model, input_ids, attention_mask):
    """
    Computes model predictions.
    
    Returns:
        predictions tensor
    """
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    return predictions

# ---------- 2. Evaluation loop ----------

def evaluate_models(fasta_file, model_pretrained, model_finetuned, tokenizer, mask_prob=0.15, device=torch.device("cpu")):
    """
    Evaluates both models on the same masked inputs.
    """

    all_true_tokens = []

    all_preds_pretrained = []
    all_preds_finetuned = []

    for record in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
        seq = str(record.seq).strip()
        if not seq:
            continue

        # Step 1: Prepare masked inputs
        masked_input_ids, labels, attention_mask, mask_positions = prepare_masked_input(tokenizer, seq, mask_prob=mask_prob, device=device)

        # Step 2: Predictions
        preds_pretrained = compute_predictions(model_pretrained, masked_input_ids, attention_mask)
        preds_finetuned = compute_predictions(model_finetuned, masked_input_ids, attention_mask)

        # Step 3: Select only masked positions
        true_at_mask = labels[mask_positions].tolist()
        preds_pretrained_at_mask = preds_pretrained[mask_positions].tolist()
        preds_finetuned_at_mask = preds_finetuned[mask_positions].tolist()

        all_true_tokens.extend(true_at_mask)
        all_preds_pretrained.extend(preds_pretrained_at_mask)
        all_preds_finetuned.extend(preds_finetuned_at_mask)

    # ---------- 3. Generate reports ----------

    id2token = tokenizer.convert_ids_to_tokens

    true_tokens = [id2token(t) for t in all_true_tokens]
    preds_pretrained_tokens = [id2token(p) for p in all_preds_pretrained]
    preds_finetuned_tokens = [id2token(p) for p in all_preds_finetuned]

    print("=== Pretrained Model Report ===")
    print(classification_report(true_tokens, preds_pretrained_tokens, zero_division=0, digits=10))

    print("\n=== Finetuned Model Report ===")
    print(classification_report(true_tokens, preds_finetuned_tokens, zero_division=0, digits=10))

# ---------- 4. Usage example ----------

# Assuming you already have these:
# model_pretrained, model_finetuned, tokenizer, fasta_file, device

evaluate_models(fasta_file, model_pretrained, model_finetuned, tokenizer, mask_prob=0.15, device=device)


Processing sequences: 1953it [01:43, 18.81it/s]


=== Pretrained Model Report ===
              precision    recall  f1-score   support

       <cls>  1.0000000000 1.0000000000 1.0000000000       298
       <eos>  1.0000000000 1.0000000000 1.0000000000       282
           A  0.5714886106 0.6824934752 0.6220778627      9962
           C  0.0000000000 0.0000000000 0.0000000000      1146
           D  0.3593489149 0.3420738975 0.3504986770      2517
           E  0.5663481953 0.7207024654 0.6342695794      2961
           F  0.3711126469 0.2978369384 0.3304615385      1803
           G  0.9025263952 0.8545162442 0.8778653952      5602
           H  0.0264900662 0.0076335878 0.0118518519       524
           I  0.5746632273 0.4868868383 0.5271460497      4118
           K  0.3732503888 0.3208556150 0.3450754853      2244
           L  0.4626911787 0.6148634777 0.5280324401      4871
           M  0.8865784499 0.5016042781 0.6407103825       935
           N  0.6792777300 0.6616415410 0.6703436572      1194
           P  0.6391454965 0.83