In [14]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from Bio import SeqIO
from tqdm import tqdm

def compute_next_token_prediction_accuracy(model, tokenizer, sequence, device=torch.device("cpu")):
    """
    Computes next token prediction accuracy for a causal language model.
    
    For a given sequence, this function computes the probability distribution for the next token
    at every position (except the first token, which has no preceding context) and compares the 
    predicted token (argmax of the model's logits) with the actual token.
    
    Args:
        model: The autoregressive language model (e.g., progen2).
        tokenizer: The corresponding tokenizer.
        sequence: A string representing the protein (or other) sequence.
        device: Torch device on which to run computations.
    
    Returns:
        A tuple (correct, total) where:
          - correct: number of next-token predictions that exactly match the target token.
          - total: total number of predictions made (i.e., sequence length minus one).
    """
    # Tokenize the sequence.
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    # Forward pass through the model.
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # shape: [batch, sequence_length, vocab_size]
    
    # Shift logits to compare predictions at positions 0 ... (L-2) with target tokens at 1 ... (L-1)
    logits_shifted = logits[:, :-1, :]
    predictions = torch.argmax(logits_shifted, dim=-1)
    targets = input_ids[:, 1:]
    
    # Count correct predictions.
    correct = (predictions == targets).sum().item()
    total = targets.numel()  # Total number of predictions
    
    return correct, total

# ----- Load model and tokenizer -----
base_model_name = "hugohrban/progen2-small"
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
model_pretrained = AutoModelForCausalLM.from_pretrained(base_model_name, trust_remote_code=True)

adapter_checkpoint = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/progen2_151m_ecoli_finetuning_1"
model_with_adapter = AutoModelForCausalLM.from_pretrained(base_model_name, trust_remote_code=True)
model_finetuned = PeftModel.from_pretrained(model_with_adapter, adapter_checkpoint)

# ----- Set device -----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----- Process the entire test set from a FASTA file -----
fasta_file = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/dataset_splits/finetuning_dataset/test.fasta"

# Process pretrained model.
total_correct = 0
total_predictions = 0
n_sequences = 0

for record in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    correct, predictions = compute_next_token_prediction_accuracy(model_pretrained, tokenizer, seq, device)
    total_correct += correct
    total_predictions += predictions
    n_sequences += 1

overall_accuracy = total_correct / total_predictions if total_predictions > 0 else 0.0
print(f"\nProcessed {n_sequences} sequences.")
print(f"Pretrained model Overall Next Token Prediction Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_predictions})")

# Reinitialize the counters for the finetuned model.
total_correct = 0
total_predictions = 0
n_sequences = 0

for record in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    correct, predictions = compute_next_token_prediction_accuracy(model_finetuned, tokenizer, seq, device)
    total_correct += correct
    total_predictions += predictions
    n_sequences += 1

overall_accuracy = total_correct / total_predictions if total_predictions > 0 else 0.0
print(f"\nProcessed {n_sequences} sequences.")
print(f"Finetuned model Overall Next Token Prediction Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_predictions})")


Processing sequences: 469it [00:13, 34.26it/s]



Processed 469 sequences.
Pretrained model Overall Next Token Prediction Accuracy: 0.6015 (191120/317715)


Processing sequences: 469it [00:14, 32.95it/s]


Processed 469 sequences.
Finetuned model Overall Next Token Prediction Accuracy: 0.9952 (316205/317715)





# ProGen2 Pre-trained Recall, Precision, F1

In [17]:
from collections import Counter

def compute_token_level_metrics(model, tokenizer, sequence, device=torch.device("cpu")):
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    logits_shifted = logits[:, :-1, :]
    predictions = torch.argmax(logits_shifted, dim=-1)
    targets = input_ids[:, 1:]

    preds = predictions.view(-1).tolist()
    trues = targets.view(-1).tolist()

    return trues, preds

from sklearn.metrics import classification_report

all_trues = []
all_preds = []

for record in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    trues, preds = compute_token_level_metrics(model_pretrained, tokenizer, seq, device)
    all_trues.extend(trues)
    all_preds.extend(preds)

# Convert token IDs back to tokens (optional but makes report readable)
id2token = tokenizer.convert_ids_to_tokens

true_tokens = [id2token(t) for t in all_trues]
pred_tokens = [id2token(p) for p in all_preds]

print(classification_report(true_tokens, pred_tokens, zero_division=0, digits=10))


Processing sequences: 469it [00:13, 35.77it/s]


              precision    recall  f1-score   support

           2  0.00000000 0.00000000 0.00000000         0
           A  0.52735066 0.61495980 0.56779565     28610
           C  0.99314317 0.78820200 0.87888350      4594
           D  0.51745205 0.59601028 0.55395990     16342
           E  0.53813590 0.68990893 0.60464366     21742
           F  0.71216475 0.62726736 0.66702552     11853
           G  0.65625135 0.69199835 0.67365096     21883
           H  0.72688172 0.46510714 0.56725006      5087
           I  0.88326271 0.54855263 0.67678571     15200
           K  0.49651198 0.55596526 0.52455940     20611
           L  0.45155999 0.80028749 0.57735119     31305
           M  0.79707856 0.48815280 0.60548808      8272
           N  0.79842715 0.33700210 0.47395577     11448
           P  0.67230225 0.73629679 0.70284584     14960
           Q  0.71782865 0.33721688 0.45886918     12274
           R  0.67363936 0.61669400 0.64391012     22559
           S  0.54737634 0.392228

# ProGen2 fine-tuned Recall, Precisio, F1

In [18]:
from sklearn.metrics import classification_report

all_trues = []
all_preds = []

for record in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
    seq = str(record.seq).strip()
    if not seq:
        continue
    trues, preds = compute_token_level_metrics(model_finetuned, tokenizer, seq, device)
    all_trues.extend(trues)
    all_preds.extend(preds)

# Convert token IDs back to tokens (optional but makes report readable)
id2token = tokenizer.convert_ids_to_tokens

true_tokens = [id2token(t) for t in all_trues]
pred_tokens = [id2token(p) for p in all_preds]

print(classification_report(true_tokens, pred_tokens, zero_division=0, digits=10))


Processing sequences: 469it [00:13, 35.01it/s]


              precision    recall  f1-score   support

           A  0.99494756 0.99804264 0.99649270     28610
           C  0.99912358 0.99259904 0.99585062      4594
           D  0.99485609 0.99412557 0.99449070     16342
           E  0.99763396 0.98905344 0.99332517     21742
           F  0.99847999 0.99755336 0.99801646     11853
           G  0.98684270 0.99739524 0.99209091     21883
           H  0.99448493 0.99252998 0.99350649      5087
           I  0.99519674 0.99506579 0.99513126     15200
           K  0.99767386 0.99883557 0.99825438     20611
           L  0.99773206 0.99776394 0.99774800     31305
           M  0.99564797 0.99564797 0.99564797      8272
           N  0.99339323 0.98506289 0.98921053     11448
           P  0.99599439 0.99725936 0.99662647     14960
           Q  0.99431357 0.99722992 0.99576961     12274
           R  0.99658658 0.99654240 0.99656449     22559
           S  0.99021713 0.99389295 0.99205163     16702
           T  0.99360265 0.992838