In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from Bio import SeqIO
from tqdm import tqdm
from peft import PeftModel, PeftConfig

def compute_next_token_prediction_accuracy(model, tokenizer, sequence, device=torch.device("cpu")):
    """
    Computes next token prediction accuracy for a causal language model.

    For a given sequence, the logits at position t are taken as the model's prediction for
    the token at position t + 1. The predicted token (argmax of the logits) is compared with
    the actual token. The first token is never scored because it has no preceding context.

    Only positions whose context token AND target token are real tokens
    (``attention_mask == 1``) are scored. For a single unpadded sequence this is every
    position 1 .. L-1, so ``total == L - 1``. If the tokenizer returns padded or batched
    encodings, pad positions are excluded instead of being counted as predictions.

    Side effects: moves ``model`` to ``device`` and puts it in eval mode.

    Args:
        model: The autoregressive language model (e.g., progen2). Its forward call must
            accept ``(input_ids, attention_mask=...)`` and return an object with ``.logits``.
        tokenizer: The corresponding tokenizer. Its call result must expose
            ``input_ids`` and ``attention_mask`` tensors.
        sequence: A string representing the protein (or other) sequence.
        device: Torch device on which to run computations.

    Returns:
        A tuple (correct, total) where:
          - correct: number of scored next-token predictions that exactly match the target token.
          - total: number of scored predictions (sequence length minus one for an unpadded sequence).
    """
    # Tokenize the sequence.
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    # Forward pass through the model (no gradients needed for evaluation).
    model.to(device)
    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits  # [batch, seq_len, vocab]

    # Logits at positions 0 .. L-2 predict the tokens at positions 1 .. L-1.
    predictions = logits[:, :-1, :].argmax(dim=-1)
    targets = input_ids[:, 1:]

    # A prediction is scored only when both its context token and its target are real
    # tokens. This handles right- and left-padding; for unpadded input it is all True.
    scored = attention_mask[:, :-1].bool() & attention_mask[:, 1:].bool()

    correct = ((predictions == targets) & scored).sum().item()
    total = scored.sum().item()

    return correct, total

# ----- Load model and tokenizer -----
base_model_name = "hugohrban/progen2-small"
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
model_pretrained = AutoModelForCausalLM.from_pretrained(base_model_name, trust_remote_code=True)

# The adapter is attached to its own, separately loaded copy of the base model so that
# `model_pretrained` stays an unmodified baseline for comparison.
# NOTE(review): machine-specific absolute path; consider deriving it from a configurable root.
adapter_checkpoint = "/home/sdowell/scratch/Thesis/ADP1/runs/progen2_dgoa_finetune_1/checkpoint-3000"
model_with_adapter = AutoModelForCausalLM.from_pretrained(base_model_name, trust_remote_code=True)
model_finetuned = PeftModel.from_pretrained(model_with_adapter, adapter_checkpoint)

# ----- Set device -----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----- Process the entire test set from a FASTA file -----
# NOTE(review): machine-specific absolute path; consider deriving it from a configurable root.
fasta_file = "/home/sdowell/scratch/Thesis/ADP1/finetuning_data/test/dgoa_mutants_test.fasta"


def _iter_fasta_sequences(fasta_path):
    """Yield every non-empty, whitespace-stripped sequence string in a FASTA file."""
    for record in SeqIO.parse(fasta_path, "fasta"):
        seq = str(record.seq).strip()
        if seq:
            yield seq


def evaluate_next_token_accuracy(model, label, fasta_path, tokenizer, device):
    """
    Print and return the aggregate next-token accuracy of `model` over a FASTA test set.

    Accuracy is micro-averaged: total correct predictions divided by total predictions
    pooled across all sequences, so longer sequences carry more weight.

    Args:
        model: Causal LM to evaluate.
        label: Name used in the printed summary line (e.g. "Pretrained").
        fasta_path: Path to the FASTA file; empty records are skipped.
        tokenizer: Tokenizer matching `model`.
        device: Torch device to run on.

    Returns:
        (total_correct, total_predictions, n_sequences, overall_accuracy); accuracy is
        0.0 when no predictions were made.
    """
    total_correct = 0
    total_predictions = 0
    n_sequences = 0

    for seq in tqdm(_iter_fasta_sequences(fasta_path), desc="Processing sequences"):
        correct, predictions = compute_next_token_prediction_accuracy(model, tokenizer, seq, device)
        total_correct += correct
        total_predictions += predictions
        n_sequences += 1

    overall_accuracy = total_correct / total_predictions if total_predictions > 0 else 0.0
    print(f"\nProcessed {n_sequences} sequences.")
    print(f"{label} model Overall Next Token Prediction Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_predictions})")
    return total_correct, total_predictions, n_sequences, overall_accuracy


pretrained_stats = evaluate_next_token_accuracy(model_pretrained, "Pretrained", fasta_file, tokenizer, device)
finetuned_stats = evaluate_next_token_accuracy(model_finetuned, "Finetuned", fasta_file, tokenizer, device)

# Keep the original module-level names (which ended up holding the finetuned run's results).
total_correct, total_predictions, n_sequences, overall_accuracy = finetuned_stats


Processing sequences: 1953it [00:31, 61.82it/s]



Processed 1953 sequences.
Pretrained model Overall Next Token Prediction Accuracy: 0.4810 (189516/394025)


Processing sequences: 1953it [00:32, 59.21it/s]


Processed 1953 sequences.
Finetuned model Overall Next Token Prediction Accuracy: 0.9813 (386659/394025)





# ProGen2 Pre-trained Recall, Precision, F1

In [4]:
from collections import Counter

def compute_token_level_metrics(model, tokenizer, sequence, device=torch.device("cpu")):
    """
    Return aligned (true, predicted) next-token IDs for one sequence.

    The logits at position t are the model's prediction for the token at t + 1. The
    prediction is the argmax of those logits. Only positions whose context token and
    target token are both real (``attention_mask == 1``) are returned. For a single
    unpadded sequence, that is every position 1 .. L-1 in order.

    Side effects: moves ``model`` to ``device`` and puts it in eval mode. This keeps the
    function independent of whatever an earlier cell did to the model.

    Args:
        model: Causal LM whose forward call accepts ``(input_ids, attention_mask=...)``
            and returns an object with ``.logits``.
        tokenizer: Tokenizer whose call result exposes ``input_ids`` and ``attention_mask``.
        sequence: The sequence string to score.
        device: Torch device on which to run computations.

    Returns:
        (trues, preds): two equal-length lists of token IDs, suitable for
        sklearn's per-class metrics.
    """
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    model.to(device)
    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits  # [batch, seq_len, vocab]

    predictions = logits[:, :-1, :].argmax(dim=-1)
    targets = input_ids[:, 1:]

    # Boolean-mask selection flattens in row-major order and, unlike `.view(-1)`, also
    # works on the non-contiguous `input_ids[:, 1:]` slice of a multi-row batch.
    scored = attention_mask[:, :-1].bool() & attention_mask[:, 1:].bool()

    trues = targets[scored].tolist()
    preds = predictions[scored].tolist()

    return trues, preds

from sklearn.metrics import classification_report

# Pool aligned (true, predicted) next-token IDs over the whole test set for the
# pretrained model, then report precision / recall / F1 per token class.
all_trues, all_preds = [], []
for fasta_record in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
    protein = str(fasta_record.seq).strip()
    if protein:
        seq_trues, seq_preds = compute_token_level_metrics(model_pretrained, tokenizer, protein, device)
        all_trues += seq_trues
        all_preds += seq_preds

# Label the report with token strings (residue letters) instead of raw vocabulary IDs.
id2token = tokenizer.convert_ids_to_tokens
true_tokens = list(map(id2token, all_trues))
pred_tokens = list(map(id2token, all_preds))

report = classification_report(true_tokens, pred_tokens, zero_division=0, digits=10)
print(report)


Processing sequences: 1953it [00:28, 67.77it/s]


              precision    recall  f1-score   support

           2  0.0000000000 0.0000000000 0.0000000000         0
           A  0.4206917287 0.6628715519 0.5147176618     65435
           C  0.0033149171 0.0003905741 0.0006988120      7681
           D  0.2133455112 0.2425334392 0.2270050877     16373
           E  0.4952723399 0.5689154046 0.5295457782     19611
           F  0.3915498639 0.3122386619 0.3474254015     12436
           G  0.8263081210 0.8742663962 0.8496110177     36975
           H  0.0373134328 0.0041701418 0.0075018755      3597
           I  0.5933415754 0.2468677665 0.3486676526     27217
           K  0.5085600734 0.2262341901 0.3131588855     14706
           L  0.3394532064 0.6190370605 0.4384689596     31570
           M  0.0720000000 0.0485118350 0.0579669560      4267
           N  0.8333583998 0.6924518870 0.7563988806      8002
           P  0.6292978709 0.6924709986 0.6593747744     26378
           Q  0.2712440517 0.0362068966 0.0638859979     22040


# ProGen2 fine-tuned Recall, Precision, F1

In [5]:
from sklearn.metrics import classification_report

# Same per-token evaluation as for the pretrained model, now with the LoRA-finetuned model.
all_trues = []
all_preds = []
for fasta_entry in tqdm(SeqIO.parse(fasta_file, "fasta"), desc="Processing sequences"):
    residues = str(fasta_entry.seq).strip()
    if not residues:
        continue  # skip empty records
    entry_trues, entry_preds = compute_token_level_metrics(model_finetuned, tokenizer, residues, device)
    all_trues += entry_trues
    all_preds += entry_preds

# Readable labels: map vocabulary IDs back to their token strings.
id2token = tokenizer.convert_ids_to_tokens
true_tokens = list(map(id2token, all_trues))
pred_tokens = list(map(id2token, all_preds))

print(
    classification_report(
        true_tokens,
        pred_tokens,
        zero_division=0,
        digits=10,
    )
)


Processing sequences: 1953it [00:29, 65.70it/s]


              precision    recall  f1-score   support

           A  0.9842795954 0.9903415603 0.9873012729     65435
           C  0.9963432154 0.9932300482 0.9947841961      7681
           D  0.9522445081 0.9742869358 0.9631396226     16373
           E  0.9834345952 0.9656825251 0.9744777195     19611
           F  0.9583992407 0.9744290769 0.9663476874     12436
           G  0.9935012404 0.9964300203 0.9949634751     36975
           H  0.9904640814 0.8662774534 0.9242177073      3597
           I  0.9902736014 0.9800859757 0.9851534513     27217
           K  0.9321231255 0.9636882905 0.9476429288     14706
           L  0.9958693442 0.9927779538 0.9943212462     31570
           M  0.9881460408 0.9767986876 0.9824395993      4267
           N  0.9886540916 0.8711572107 0.9261941141      8002
           P  0.9885978776 0.9959435894 0.9922571385     26378
           Q  0.9651863914 0.9950090744 0.9798708697     22040
           R  0.9901232788 0.9890537232 0.9895882120     13886
