In [13]:
# Import packages
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
from peft import PeftModel, PeftConfig
from autoamp.evolveFinetune import *
import torch
from tqdm import tqdm
import math
from Bio import SeqIO 
import json
import warnings
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from transformers import PreTrainedTokenizer

# Example inputs
base_model_name = "facebook/esm2_t30_150M_UR50D" 
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
adapter_checkpoint = "/home/sdowell/scratch/Thesis/ADP1/runs/esm2_dgoa_finetune_1/checkpoint-3000"

# Load models
model_pretrained = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_with_adapter = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_finetuned = PeftModel.from_pretrained(model_with_adapter, adapter_checkpoint)

def compute_mlm_loss(model, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Computes the MLM (masked language model) cross entropy loss for a given sequence.
    
    Args:
        sequence: Protein sequence as a string.
        mask_prob: The probability of masking a token.
        device: torch.device to run the computation.
    
    Returns:
        loss: The MLM cross entropy loss.
    """
    model.to(device)
    model.train()  # or model.eval() if you don't want dropout, etc.
    
    # Tokenize the sequence
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    
    # Create labels as a copy of input_ids.
    labels = input_ids.clone()
    
    # Create a mask for positions to replace.
    # Generate random values in [0, 1) for each token.
    probability_matrix = torch.rand(input_ids.shape).to(device)
    # Create a boolean mask for tokens to mask.
    mask = probability_matrix < mask_prob
    
    # For positions NOT selected for masking, set the corresponding label to -100 so they are ignored.
    labels[~mask] = -100
    
    # Replace the selected input positions with the mask token.
    mask_token_id = tokenizer.mask_token_id
    input_ids[mask] = mask_token_id
    
    # Forward pass: the model automatically computes the loss when labels are provided.
    outputs = model(input_ids, labels=labels)
    loss = outputs.loss
    return loss

# Example usage:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sequence = "MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKALIGAGTVLKPEQVDALARMGCQLIVTPNIHSEVIRRAVGYGMTVCPGCATATEAFTALEAGAQALKIFPSSAFGPQYIKALKAVLPSDIAVFAVGGVTPENLAQWIDAGCAGAGLGSDLYRAGQSVERTAQQAAAFVKAYREAVQ"
loss = compute_mlm_loss(model_pretrained, sequence, mask_prob=0.15, device=device)
print(f"Pretrained MLM Cross Entropy Loss: {loss.item():.4f}")
loss = compute_mlm_loss(model_finetuned, sequence, mask_prob=0.15, device=device)
print(f"Finetuned MLM Cross Entropy Loss: {loss.item():.4f}")



Pretrained MLM Cross Entropy Loss: 1.5282
Finetuned MLM Cross Entropy Loss: 0.0510


In [14]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import numpy as np

def compute_token_prediction_accuracy(model, tokenizer, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Computes token prediction accuracy for a masked language model.
    
    Args:
        model: The masked language model.
        tokenizer: The corresponding tokenizer.
        sequence: A string representing the protein (or other) sequence.
        mask_prob: The probability with which to mask tokens (e.g., 0.15 for 15%).
        device: torch.device to run the computations.
    
    Returns:
        A tuple (accuracy, correct, total_masked) where:
          - accuracy is the fraction of masked tokens correctly predicted.
          - correct is the number of correct predictions.
          - total_masked is the total number of masked tokens.
    """
    # Tokenize the input sequence.
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    # Create a copy of the original tokens to serve as labels.
    labels = input_ids.clone()
    
    # Create a random mask for tokens according to mask_prob.
    probability_matrix = torch.rand(input_ids.shape).to(device)
    mask_positions = probability_matrix < mask_prob
    
    # Replace the selected token positions in input_ids with the mask token ID.
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("The tokenizer does not have a mask token.")
    
    input_ids[mask_positions] = mask_token_id
    
    # Forward pass through the model.
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # shape: [batch_size, sequence_length, vocab_size]
    
    # Get predictions (top candidate from the logits) using argmax.
    predictions = torch.argmax(logits, dim=-1)
    
    # Consider only the masked positions.
    masked_labels = labels[mask_positions]
    masked_predictions = predictions[mask_positions]
    
    # Calculate the number of correct predictions.
    correct = (masked_predictions == masked_labels).sum().item()
    total_masked = mask_positions.sum().item()
    
    accuracy = correct / total_masked if total_masked > 0 else 0.0
    return accuracy, correct, total_masked

sequence = (
    'MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKA'
    'LIGAGTVLKPEQVDALARMGCQLIVTPNIHSEVIRRAVGYGMTVCPGCATATEAFTALEA'
    'GAQALKIFPSSAFGPQYIKALKAVLPSDIAVFAVGGVTPENLAQWIDAGCAGAGLGSDLY'
    'RAGQSVERTAQQAAAFVKAYREAVQ'
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
accuracy, correct, total_masked = compute_token_prediction_accuracy(
    model_pretrained, tokenizer, sequence, mask_prob=0.15, device=device
)
print(f"pretrained Token Prediction Accuracy: {accuracy:.4f} ({correct}/{total_masked})")
accuracy, correct, total_masked = compute_token_prediction_accuracy(
    model_finetuned, tokenizer, sequence, mask_prob=0.15, device=device
)
print(f"finetuned Token Prediction Accuracy: {accuracy:.4f} ({correct}/{total_masked})")


pretrained Token Prediction Accuracy: 0.6207 (18/29)
finetuned Token Prediction Accuracy: 1.0000 (28/28)


In [15]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from Bio import SeqIO
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
from peft import PeftModel, PeftConfig
from autoamp.evolveFinetune import *
import torch
from tqdm import tqdm
import math
from Bio import SeqIO 
import json
import warnings
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from transformers import PreTrainedTokenizer


base_model_name = "facebook/esm2_t30_150M_UR50D" 
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
adapter_checkpoint = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_150m_ecoli_finetuning_1/checkpoint-19000"

# Load models
model_pretrained = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_with_adapter = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_finetuned = PeftModel.from_pretrained(model_with_adapter, adapter_checkpoint)

def compute_token_prediction_accuracy(model, tokenizer, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Computes token prediction accuracy for a given sequence from a masked language model.
    
    Args:
        model: The masked language model.
        tokenizer: The corresponding tokenizer.
        sequence: A string representing the protein (or other) sequence.
        mask_prob: The probability of masking a token (default is 15%).
        device: Torch device on which to run computations.
    
    Returns:
        A tuple (correct, total_masked) where:
          - correct: number of masked tokens correctly predicted.
          - total_masked: total number of masked tokens.
    """
    # Tokenize input sequence.
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    # Create a copy of input_ids for ground-truth labels.
    labels = input_ids.clone()
    
    # Create a random mask for tokens.
    probability_matrix = torch.rand(input_ids.shape).to(device)
    mask_positions = probability_matrix < mask_prob
    
    # Replace tokens at masked positions with the mask token.
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("The tokenizer does not have a mask token.")
    input_ids[mask_positions] = mask_token_id
    
    # Forward pass through the model.
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # shape: [batch, seq_length, vocab_size]
    
    # Get predicted token IDs.
    predictions = torch.argmax(logits, dim=-1)
    
    # Evaluate only on masked positions.
    masked_labels = labels[mask_positions]
    masked_predictions = predictions[mask_positions]
    
    correct = (masked_predictions == masked_labels).sum().item()
    total_masked = mask_positions.sum().item()
    
    return correct, total_masked

# ----- Main script to process FASTA file with a progress bar -----

fasta_file = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/dataset_splits/finetuning_dataset/test.fasta"

# Set device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

total_correct = 0
total_masked = 0
n_sequences = 0

# Count number of records in FASTA for progress bar (optional, if file is large)
records = list(SeqIO.parse(fasta_file, "fasta"))
n_records = len(records)

# Iterate over sequences with a progress bar.
for record in tqdm(records, desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    correct, masked = compute_token_prediction_accuracy(model_pretrained, tokenizer, seq, mask_prob=0.15, device=device)
    total_correct += correct
    total_masked += masked
    n_sequences += 1

# Compute overall accuracy.
overall_accuracy = total_correct / total_masked if total_masked > 0 else 0.0

print(f"\nProcessed {n_sequences} sequences.")
print(f"pretrained Overall Token Prediction Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_masked})")



total_correct = 0
total_masked = 0
n_sequences = 0

# Count number of records in FASTA for progress bar (optional, if file is large)
records = list(SeqIO.parse(fasta_file, "fasta"))
n_records = len(records)

# Iterate over sequences with a progress bar.
for record in tqdm(records, desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    correct, masked = compute_token_prediction_accuracy(model_finetuned, tokenizer, seq, mask_prob=0.15, device=device)
    total_correct += correct
    total_masked += masked
    n_sequences += 1

# Compute overall accuracy.
overall_accuracy = total_correct / total_masked if total_masked > 0 else 0.0

print(f"\nProcessed {n_sequences} sequences.")
print(f"finetuned Overall Token Prediction Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_masked})")


Processing sequences: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:20<00:00, 23.34it/s]



Processed 469 sequences.
pretrained Overall Token Prediction Accuracy: 0.4311 (20685/47977)


Processing sequences: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:22<00:00, 20.46it/s]


Processed 469 sequences.
finetuned Overall Token Prediction Accuracy: 0.9964 (47962/48137)





# ESM-2 Pretrained Recall, Precision, F1

In [16]:
def compute_token_predictions_for_metrics(model, tokenizer, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Returns token-level predictions and labels for masked positions, to support precision/recall/F1.
    
    Returns:
        Two lists: true token IDs and predicted token IDs at masked positions.
    """
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    labels = input_ids.clone()
    probability_matrix = torch.rand(input_ids.shape).to(device)
    mask_positions = probability_matrix < mask_prob

    # Ensure mask token is defined
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("Tokenizer does not have a mask token.")
    input_ids[mask_positions] = mask_token_id

    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    predictions = torch.argmax(logits, dim=-1)

    true_tokens = labels[mask_positions].tolist()
    predicted_tokens = predictions[mask_positions].tolist()

    return true_tokens, predicted_tokens

from sklearn.metrics import classification_report

all_trues = []
all_preds = []

# Use same file path and device setup as before
for record in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    trues, preds = compute_token_predictions_for_metrics(model_pretrained, tokenizer, seq, mask_prob=0.15, device=device)
    all_trues.extend(trues)
    all_preds.extend(preds)

# Optionally convert to token strings for readability
id2token = tokenizer.convert_ids_to_tokens
true_tokens = [id2token(t) for t in all_trues]
pred_tokens = [id2token(p) for p in all_preds]

# Print precision, recall, F1 for each amino acid
print(classification_report(true_tokens, pred_tokens, zero_division=0, digits=10))


Processing sequences: 469it [00:19, 23.56it/s]


              precision    recall  f1-score   support

       <cls>  1.0000000000 1.0000000000 1.0000000000        75
       <eos>  1.0000000000 1.0000000000 1.0000000000        72
           A  0.4294642857 0.4363801315 0.4328945888      4409
           C  0.6569536424 0.7304860088 0.6917712692       679
           D  0.3701272056 0.3741186230 0.3721122112      2411
           E  0.3464970995 0.4730429485 0.4000000000      3283
           F  0.5200789993 0.4460756635 0.4802431611      1771
           G  0.5300938708 0.5938756573 0.5601750547      3233
           H  0.7000000000 0.3504736130 0.4670874662       739
           I  0.5624235006 0.4106344951 0.4746900826      2238
           K  0.2740207272 0.5126519882 0.3571428571      3043
           L  0.4102945150 0.6418005072 0.5005768914      4732
           M  0.7654320988 0.2809667674 0.4110497238      1324
           N  0.4622356495 0.1673045380 0.2456844641      1829
           P  0.4087231797 0.5222471910 0.4585635359      2225


# ESM-2 fine-tuned Recall, Precision, F1

In [17]:

all_trues = []
all_preds = []

# Use same file path and device setup as before
for record in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    trues, preds = compute_token_predictions_for_metrics(model_finetuned, tokenizer, seq, mask_prob=0.15, device=device)
    all_trues.extend(trues)
    all_preds.extend(preds)

# Optionally convert to token strings for readability
id2token = tokenizer.convert_ids_to_tokens
true_tokens = [id2token(t) for t in all_trues]
pred_tokens = [id2token(p) for p in all_preds]

# Print precision, recall, F1 for each amino acid
print(classification_report(true_tokens, pred_tokens, zero_division=0, digits=10))


Processing sequences: 469it [00:22, 20.61it/s]


              precision    recall  f1-score   support

       <cls>  1.0000000000 1.0000000000 1.0000000000        79
       <eos>  1.0000000000 0.8607594937 0.9251700680        79
           A  0.9929223744 0.9970197157 0.9949668268      4362
           C  1.0000000000 0.9958391123 0.9979152189       721
           D  0.9962732919 0.9889025894 0.9925742574      2433
           E  0.9959814529 0.9978321462 0.9969059406      3229
           F  0.9994400896 0.9983221477 0.9988808058      1788
           G  0.9972367209 0.9963190184 0.9967776584      3260
           H  0.9924050633 0.9936628644 0.9930335655       789
           I  0.9978308026 0.9956709957 0.9967497291      2310
           K  0.9980525803 0.9977287476 0.9978906377      3082
           L  0.9968500630 0.9985275557 0.9976881042      4754
           M  0.9954751131 0.9969788520 0.9962264151      1324
           N  0.9879448909 0.9907887162 0.9893647600      1737
           P  0.9949541284 0.9958677686 0.9954107389      2178
