In [8]:
cd /home/sdowell/scratch/Thesis/ADP1

/home/sdowell/scratch/Thesis/ADP1


In [9]:
# Import packages
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
from peft import PeftModel, PeftConfig
from autoamp.evolveFinetune import *
import torch
from tqdm import tqdm
import math
from Bio import SeqIO 
import json
import warnings
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from transformers import PreTrainedTokenizer

# Example inputs
base_model_name = "facebook/esm2_t30_150M_UR50D" 
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
adapter_checkpoint = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_150m_ecoli_finetuning_1/checkpoint-19000"

# Load models
model_pretrained = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_with_adapter = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_finetuned = PeftModel.from_pretrained(model_with_adapter, adapter_checkpoint)

def compute_mlm_loss(model, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Computes the MLM (masked language model) cross entropy loss for a given sequence.
    
    Args:
        sequence: Protein sequence as a string.
        mask_prob: The probability of masking a token.
        device: torch.device to run the computation.
    
    Returns:
        loss: The MLM cross entropy loss.
    """
    model.to(device)
    model.train()  # or model.eval() if you don't want dropout, etc.
    
    # Tokenize the sequence
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    
    # Create labels as a copy of input_ids.
    labels = input_ids.clone()
    
    # Create a mask for positions to replace.
    # Generate random values in [0, 1) for each token.
    probability_matrix = torch.rand(input_ids.shape).to(device)
    # Create a boolean mask for tokens to mask.
    mask = probability_matrix < mask_prob
    
    # For positions NOT selected for masking, set the corresponding label to -100 so they are ignored.
    labels[~mask] = -100
    
    # Replace the selected input positions with the mask token.
    mask_token_id = tokenizer.mask_token_id
    input_ids[mask] = mask_token_id
    
    # Forward pass: the model automatically computes the loss when labels are provided.
    outputs = model(input_ids, labels=labels)
    loss = outputs.loss
    return loss

# Example usage:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sequence = "MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKALIGAGTVLKPEQVDALARMGCQLIVTPNIHSEVIRRAVGYGMTVCPGCATATEAFTALEAGAQALKIFPSSAFGPQYIKALKAVLPSDIAVFAVGGVTPENLAQWIDAGCAGAGLGSDLYRAGQSVERTAQQAAAFVKAYREAVQ"
loss = compute_mlm_loss(model_pretrained, sequence, mask_prob=0.15, device=device)
print(f"Pretrained MLM Cross Entropy Loss: {loss.item():.4f}")
loss = compute_mlm_loss(model_finetuned, sequence, mask_prob=0.15, device=device)
print(f"Finetuned MLM Cross Entropy Loss: {loss.item():.4f}")



Pretrained MLM Cross Entropy Loss: 1.6485
Finetuned MLM Cross Entropy Loss: 0.0446


In [10]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import numpy as np

def compute_token_prediction_accuracy(model, tokenizer, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Computes token prediction accuracy for a masked language model.
    
    Args:
        model: The masked language model.
        tokenizer: The corresponding tokenizer.
        sequence: A string representing the protein (or other) sequence.
        mask_prob: The probability with which to mask tokens (e.g., 0.15 for 15%).
        device: torch.device to run the computations.
    
    Returns:
        A tuple (accuracy, correct, total_masked) where:
          - accuracy is the fraction of masked tokens correctly predicted.
          - correct is the number of correct predictions.
          - total_masked is the total number of masked tokens.
    """
    # Tokenize the input sequence.
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    # Create a copy of the original tokens to serve as labels.
    labels = input_ids.clone()
    
    # Create a random mask for tokens according to mask_prob.
    probability_matrix = torch.rand(input_ids.shape).to(device)
    mask_positions = probability_matrix < mask_prob
    
    # Replace the selected token positions in input_ids with the mask token ID.
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("The tokenizer does not have a mask token.")
    
    input_ids[mask_positions] = mask_token_id
    
    # Forward pass through the model.
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # shape: [batch_size, sequence_length, vocab_size]
    
    # Get predictions (top candidate from the logits) using argmax.
    predictions = torch.argmax(logits, dim=-1)
    
    # Consider only the masked positions.
    masked_labels = labels[mask_positions]
    masked_predictions = predictions[mask_positions]
    
    # Calculate the number of correct predictions.
    correct = (masked_predictions == masked_labels).sum().item()
    total_masked = mask_positions.sum().item()
    
    accuracy = correct / total_masked if total_masked > 0 else 0.0
    return accuracy, correct, total_masked

sequence = (
    'MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKA'
    'LIGAGTVLKPEQVDALARMGCQLIVTPNIHSEVIRRAVGYGMTVCPGCATATEAFTALEA'
    'GAQALKIFPSSAFGPQYIKALKAVLPSDIAVFAVGGVTPENLAQWIDAGCAGAGLGSDLY'
    'RAGQSVERTAQQAAAFVKAYREAVQ'
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
accuracy, correct, total_masked = compute_token_prediction_accuracy(
    model_pretrained, tokenizer, sequence, mask_prob=0.15, device=device
)
print(f"pretrained Token Prediction Accuracy: {accuracy:.4f} ({correct}/{total_masked})")
accuracy, correct, total_masked = compute_token_prediction_accuracy(
    model_finetuned, tokenizer, sequence, mask_prob=0.15, device=device
)
print(f"finetuned Token Prediction Accuracy: {accuracy:.4f} ({correct}/{total_masked})")


pretrained Token Prediction Accuracy: 0.5000 (10/20)
finetuned Token Prediction Accuracy: 1.0000 (27/27)


In [11]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from Bio import SeqIO
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
from peft import PeftModel, PeftConfig
from autoamp.evolveFinetune import *
import torch
from tqdm import tqdm
import math
from Bio import SeqIO 
import json
import warnings
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from transformers import PreTrainedTokenizer


base_model_name = "facebook/esm2_t30_150M_UR50D" 
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
adapter_checkpoint = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_150m_ecoli_finetuning_1/checkpoint-19000"

# Load models
model_pretrained = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_with_adapter = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_finetuned = PeftModel.from_pretrained(model_with_adapter, adapter_checkpoint)

def compute_token_prediction_accuracy(model, tokenizer, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Computes token prediction accuracy for a given sequence from a masked language model.
    
    Args:
        model: The masked language model.
        tokenizer: The corresponding tokenizer.
        sequence: A string representing the protein (or other) sequence.
        mask_prob: The probability of masking a token (default is 15%).
        device: Torch device on which to run computations.
    
    Returns:
        A tuple (correct, total_masked) where:
          - correct: number of masked tokens correctly predicted.
          - total_masked: total number of masked tokens.
    """
    # Tokenize input sequence.
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    # Create a copy of input_ids for ground-truth labels.
    labels = input_ids.clone()
    
    # Create a random mask for tokens.
    probability_matrix = torch.rand(input_ids.shape).to(device)
    mask_positions = probability_matrix < mask_prob
    
    # Replace tokens at masked positions with the mask token.
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("The tokenizer does not have a mask token.")
    input_ids[mask_positions] = mask_token_id
    
    # Forward pass through the model.
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # shape: [batch, seq_length, vocab_size]
    
    # Get predicted token IDs.
    predictions = torch.argmax(logits, dim=-1)
    
    # Evaluate only on masked positions.
    masked_labels = labels[mask_positions]
    masked_predictions = predictions[mask_positions]
    
    correct = (masked_predictions == masked_labels).sum().item()
    total_masked = mask_positions.sum().item()
    
    return correct, total_masked

# ----- Main script to process FASTA file with a progress bar -----

fasta_file = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/dataset_splits/finetuning_dataset/test.fasta"

# Set device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

total_correct = 0
total_masked = 0
n_sequences = 0

# Count number of records in FASTA for progress bar (optional, if file is large)
records = list(SeqIO.parse(fasta_file, "fasta"))
n_records = len(records)

# Iterate over sequences with a progress bar.
for record in tqdm(records, desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    correct, masked = compute_token_prediction_accuracy(model_pretrained, tokenizer, seq, mask_prob=0.15, device=device)
    total_correct += correct
    total_masked += masked
    n_sequences += 1

# Compute overall accuracy.
overall_accuracy = total_correct / total_masked if total_masked > 0 else 0.0

print(f"\nProcessed {n_sequences} sequences.")
print(f"pretrained Overall Token Prediction Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_masked})")



total_correct = 0
total_masked = 0
n_sequences = 0

# Count number of records in FASTA for progress bar (optional, if file is large)
records = list(SeqIO.parse(fasta_file, "fasta"))
n_records = len(records)

# Iterate over sequences with a progress bar.
for record in tqdm(records, desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    correct, masked = compute_token_prediction_accuracy(model_finetuned, tokenizer, seq, mask_prob=0.15, device=device)
    total_correct += correct
    total_masked += masked
    n_sequences += 1

# Compute overall accuracy.
overall_accuracy = total_correct / total_masked if total_masked > 0 else 0.0

print(f"\nProcessed {n_sequences} sequences.")
print(f"finetuned Overall Token Prediction Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_masked})")


Processing sequences: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1953/1953 [00:54<00:00, 35.61it/s]



Processed 1953 sequences.
pretrained Overall Token Prediction Accuracy: 0.5505 (33014/59970)


Processing sequences: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1953/1953 [01:12<00:00, 26.83it/s]


Processed 1953 sequences.
finetuned Overall Token Prediction Accuracy: 0.9834 (58786/59779)





# ESM-2 Pretrained Recall, Precision, F1

In [12]:
def compute_token_predictions_for_metrics(model, tokenizer, sequence, mask_prob=0.15, device=torch.device("cpu")):
    """
    Returns token-level predictions and labels for masked positions, to support precision/recall/F1.
    
    Returns:
        Two lists: true token IDs and predicted token IDs at masked positions.
    """
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    labels = input_ids.clone()
    probability_matrix = torch.rand(input_ids.shape).to(device)
    mask_positions = probability_matrix < mask_prob

    # Ensure mask token is defined
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("Tokenizer does not have a mask token.")
    input_ids[mask_positions] = mask_token_id

    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    predictions = torch.argmax(logits, dim=-1)

    true_tokens = labels[mask_positions].tolist()
    predicted_tokens = predictions[mask_positions].tolist()

    return true_tokens, predicted_tokens

from sklearn.metrics import classification_report

all_trues = []
all_preds = []

# Use same file path and device setup as before
for record in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    trues, preds = compute_token_predictions_for_metrics(model_pretrained, tokenizer, seq, mask_prob=0.15, device=device)
    all_trues.extend(trues)
    all_preds.extend(preds)

# Optionally convert to token strings for readability
id2token = tokenizer.convert_ids_to_tokens
true_tokens = [id2token(t) for t in all_trues]
pred_tokens = [id2token(p) for p in all_preds]

# Print precision, recall, F1 for each amino acid
print(classification_report(true_tokens, pred_tokens, zero_division=0, digits=10))


Processing sequences: 1953it [00:55, 35.39it/s]


              precision    recall  f1-score   support

       <cls>  1.0000000000 1.0000000000 1.0000000000       318
       <eos>  1.0000000000 1.0000000000 1.0000000000       271
           A  0.5720624627 0.6826188779 0.6224698236      9821
           C  0.0000000000 0.0000000000 0.0000000000      1119
           D  0.3596673597 0.3502024291 0.3548717949      2470
           E  0.5634062747 0.7229364005 0.6332790043      2956
           F  0.3867791842 0.2958579882 0.3352636391      1859
           G  0.9094237811 0.8600328048 0.8840389659      5487
           H  0.0250000000 0.0073260073 0.0113314448       546
           I  0.5677771396 0.4785766158 0.5193747537      4131
           K  0.3681233933 0.3232505643 0.3442307692      2215
           L  0.4487759239 0.6179944170 0.5199638663      4657
           M  0.9002036660 0.4799131379 0.6260623229       921
           N  0.6700167504 0.6768189509 0.6734006734      1182
           P  0.6564500485 0.8521782926 0.7416173570      3971


# ESM-2 fine-tuned Recall, Precision, F1

In [13]:

all_trues = []
all_preds = []

# Use same file path and device setup as before
for record in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    trues, preds = compute_token_predictions_for_metrics(model_finetuned, tokenizer, seq, mask_prob=0.15, device=device)
    all_trues.extend(trues)
    all_preds.extend(preds)

# Optionally convert to token strings for readability
id2token = tokenizer.convert_ids_to_tokens
true_tokens = [id2token(t) for t in all_trues]
pred_tokens = [id2token(p) for p in all_preds]

# Print precision, recall, F1 for each amino acid
print(classification_report(true_tokens, pred_tokens, zero_division=0, digits=10))


Processing sequences: 1953it [01:16, 25.59it/s]


              precision    recall  f1-score   support

       <cls>  1.0000000000 1.0000000000 1.0000000000       286
       <eos>  1.0000000000 0.9667774086 0.9831081081       301
           A  0.9853016754 0.9881455520 0.9867215646      9701
           C  0.9948186528 0.9956784788 0.9952483801      1157
           D  0.9659591510 0.9824847251 0.9741518578      2455
           E  0.9846461949 0.9836612204 0.9841534612      2999
           F  0.9635274905 0.9440000000 0.9536637931      1875
           G  0.9931072012 0.9969045885 0.9950022717      5492
           H  0.9937888199 0.8695652174 0.9275362319       552
           I  0.9922923918 0.9760332600 0.9840956725      4089
           K  0.9804010939 0.9804010939 0.9804010939      2194
           L  0.9963814389 0.9948990436 0.9956396895      4705
           M  0.9893617021 0.9779179811 0.9836065574       951
           N  0.9872122762 0.9460784314 0.9662077597      1224
           P  0.9892583120 0.9933230611 0.9912865197      3894
