In [13]:
# Import packages
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
from peft import PeftModel, PeftConfig
from autoamp.evolveFinetune import *
import torch
from tqdm import tqdm
import math
from Bio import SeqIO 
import json
import warnings
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from transformers import PreTrainedTokenizer

# Example inputs
base_model_name = "facebook/esm2_t30_150M_UR50D" 
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
adapter_checkpoint = "/home/sdowell/scratch/Thesis/BenchmarkingFinetuning/runs/esm_150m_ecoli_finetuning_1/checkpoint-19000"

# Load models
model_pretrained = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_with_adapter = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_finetuned = PeftModel.from_pretrained(model_with_adapter, adapter_checkpoint)

# The protein sequence (concatenated)
topA_seq = (
"MGKALVIVESPAKAKTINKYLGSDYVVKSSVGHIRDLPTSGSAAKKSADSTSTKTAKKPK" 
    "KDERGALVNRMGVDPWHNWEAHYEVLPGKEKVVSELKQLAEKADHIYLATDLDREGEAIA" 
    "WHLREVIGGDDARYSRVVFNEITKNAIRQAFNKPGELNIDRVNAQQARRFMDRVVGYMVS" 
    "PLLWKKIARGLSAGRVQSVAVRLVVEREREIKAFVPEEFWEVDASTTTPSGEALALQVTH" 
    "QNDKPFRPVNKEQTQAAVSLLEKARYSVLEREDKPTTSKPGAPFITSTLQQAASTRLGFG" 
    "VKKTMMMAQRLYEAGYITYMRTDSTNLSQDAVNMVRGYISDNFGKKYLPESPNQYASKEN" 
    "SQEAHEAIRPSDVNVMAESLKDMEADAQKLYQLIWRQFVACQMTPAKYDSTTLTVGAGDF" 
    "RLKARGRILRFDGWTKVMPALRKGDEDRILPAVNKGDALTLVELTPAQHFTKPPARFSEA" 
    "SLVKELEKRGIGRPSTYASIISTIQDRGYVRVENRRFYAEKMGEIVTDRLEENFRELMNY" 
    "DFTAQMENSLDQVANHEAEWKAVLDHFFSDFTQQLDKAEKDPEEGGMRPNQMVLTSIDCP" 
    "TCGRKMGIRTASTGVFLGCSGYALPPKERCKTTINLVPENEVLNVLEGEDAETNALRAKR" 
    "RCPKCGTAMDSYLIDPKRKLHVCGNNPTCDGYEIEEGEFRIKGYDGPIVECEKCGSEMHL" 
    "KMGRFGKYMACTNEECKNTRKILRNGEVAPPKEDPVPLPELPCEKSDAYFVLRDGAAGVF" 
    "LAANTFPKSRETRAPLVEELYRFRDRLPEKLRYLADAPQQDPEGNKTMVRFSRKTKQQYV" 
    "SSEKDGKATGWSAFYVDGKWVEGKK" 
)

spoT_seq = ("MYLFESLNQLIQTYLPEDQIKRLRQAYLVARDAHEGQTRSSGEPYITHPVAVACILAEMK"
                    "LDYETLMAALLHDVIEDTPATYQDMEQLFGKSVAELVEGVSKLDKLKFRDKKEAQAENFR"
                    "KMIMAMVQDIRVILIKLADRTHNMRTLGSLRPDKRRRIARETLEIYSPLAHRLGIHHIKT"
                    "ELEELGFEALYPNRYRVIKEVVKAARGNRKEMIQKILSEIEGRLQEAGIPCRVSGREKHL"
                    "YSIYCKMVLKEQRFHSIMDIYAFRVIVNDSDTCYRVLGQMHSLYKPRPGRVKDYIAIPKA"
                    "NGYQSLHTSMIGPHGVPVEVQIRTEDMDQMAEMGVAAHWAYKEHGETSTTAQIRAQRWMQ"
                    "SLLELQQSAGSSFEFIESVKSDLFPDEIYVFTPEGRIVELPAGATPVDFAYAVHTDIGHA"
                    "CVGARVDRQPYPLSQPLTSGQTVEIITAPGARPNAAWLNFVVSSKARAKIRQLLKNLKRD"
                    "DSVSLGRRLLNHALGGSRKLNEIPQENIQRELDRMKLATLDDLLAEIGLGNAMSVVVAKN"
                    "LQHGDASIPPATQSHGHLPIKGADGVLITFAKCCRPIPGDPIIAHVSPGKGLVIHHESCR"
                    "NIRGYQKEPEKFMAVEWDKETAQEFITEIKVEMFNHQGALANLTAAINTTTSNIQSLNTE"
                    "EKDGRVYSAFIRLTARDRVHLANIMRKIRVMPDVIKVTRNRN")

yeiB_seq = ("MERNVTLDFVRGVAILGILLLNISAFGLPKAAYLNPAWYGAITPRDAWTWAFLDLIGQVK"
"FLTLFALLFGAGLQMLLPRGRRWIQSRLTLLVLLGFIHGLLFWDGDILLAYGLVGLICWR"
"LVRDAPSVKSLFNTGVMLYLVGLGVLLLLGLISDSQTSRAWTPDASAILYEKYWKLHGGV"
"EAISNRADGVGNSLLALGAQYGWQLAGMMLIGAALMRSGWLKGQFSLRHYRRTGFVLVAI"
"GVTINLPAIALQWQLDWAYRWCAFLLQMPRELSAPFQAIGYASLFYGFWPQLSRFKLVLA"
"IACVGRMALTNYLLQTLICTTLFYHLGLFMHFDRLELLAFVIPVWLANILFSVIWLRYFR"
"QGPVEWLWRQLTLRAAGPAISKTSR")

# List of mutations provided as strings
gene_mutation = {"topA":"H33Y", "spoT":"K662I", "yeiB":"L143I"}

mutations = [gene_mutation["spoT"]]

# Function to parse a mutation string
def parse_mutation(mutation_str):
    wt = mutation_str[0]  # wild-type residue
    mutant = mutation_str[-1]  # mutant residue
    pos = int(mutation_str[1:-1])  # position as provided (assumed 1-indexed)
    return wt, pos, mutant

def compute_mutation_llr(model, tokenizer, sequence, mutation, device):
    """
    Computes the log-likelihood ratio for the mutation at a specified position.
    
    Args:
       model: The masked language model (either pretrained or finetuned)
       tokenizer: Corresponding tokenizer
       sequence: Protein sequence string
       mutation: Mutation string, e.g. "F33I"
       device: torch.device instance
       
    Returns:
       A tuple: (llr, log_prob_wildtype, log_prob_mutant)
    """
    wt, pos, mutant = parse_mutation(mutation)
    
    # Adjust position from 1-indexed to 0-indexed.
    seq_index = pos - 1
    
    # Optional check: if there are special tokens added, you might need to add an offset.
    # For example, if the tokenizer prepends a BOS token:
    # extra_offset = 1 if tokenizer.cls_token_id is not None else 0
    # token_index = seq_index + extra_offset
    # For simplicity, we assume a direct 1:1 mapping.
    token_index = seq_index

    # Make sure the wild-type residue in the sequence matches what you expect:
    if sequence[seq_index] != wt:
        print(f"Warning: at position {pos}, expected wild-type '{wt}', found '{sequence[seq_index]}'.")

    # Tokenize sequence. (This returns a batch of one.)
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    # Create a masked input copy.
    masked_input_ids = input_ids.clone()
    mask_token_id = tokenizer.mask_token_id

    # Replace the token at token_index with the mask token.
    # Note: adjust token_index if your tokenizer adds extra tokens.
    masked_input_ids[0, token_index] = mask_token_id
    
    # Forward pass through the model.
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(masked_input_ids, attention_mask=attention_mask)
        # outputs.logits: shape [batch, seq_length, vocab_size]
        logits = outputs.logits
    
    # Get logits for the masked position
    masked_logits = logits[0, token_index]
    
    # Compute probabilities over the vocabulary
    probs = torch.nn.functional.softmax(masked_logits, dim=-1)

    # Get the token ids for the wild-type and mutant residues
    wt_token_id = tokenizer.convert_tokens_to_ids(wt)
    mutant_token_id = tokenizer.convert_tokens_to_ids(mutant)
    
    # Extract probabilities. (A small epsilon is added to avoid log(0).)
    eps = 1e-10
    prob_wt = probs[wt_token_id].item()
    prob_mutant = probs[mutant_token_id].item()
    log_prob_wt = math.log(prob_wt + eps)
    log_prob_mutant = math.log(prob_mutant + eps)
    
    # Calculate log-likelihood ratio
    llr = log_prob_mutant - log_prob_wt
    return llr, log_prob_wt, log_prob_mutant

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Evaluate for each mutation on both models.
results = {}
for mutation in mutations:
    wt, pos, mutant = parse_mutation(mutation)
    llr_pretrained, log_wt_pretrained, log_mut_pretrained = compute_mutation_llr(
        model_pretrained, tokenizer, spoT_seq, mutation, device
    )
    llr_finetuned, log_wt_finetuned, log_mut_finetuned = compute_mutation_llr(
        model_finetuned, tokenizer, spoT_seq, mutation, device
    )
    
    results[mutation] = {
        "pretrained": {
            "llr": llr_pretrained,
            "log_prob_wt": log_wt_pretrained,
            "log_prob_mutant": log_mut_pretrained
        },
        "finetuned": {
            "llr": llr_finetuned,
            "log_prob_wt": log_wt_finetuned,
            "log_prob_mutant": log_mut_finetuned
        }
    }

# Print the results.
for mutation, scores in results.items():
    print(f"Mutation {mutation}:")
    print(f"  Pretrained -> LLR: {scores['pretrained']['llr']:.4f}, "
          f"log_prob(wt): {scores['pretrained']['log_prob_wt']:.4f}, "
          f"log_prob(mut): {scores['pretrained']['log_prob_mutant']:.4f}")
    print(f"  Finetuned  -> LLR: {scores['finetuned']['llr']:.4f}, "
          f"log_prob(wt): {scores['finetuned']['log_prob_wt']:.4f}, "
          f"log_prob(mut): {scores['finetuned']['log_prob_mutant']:.4f}")


Mutation K662I:
  Pretrained -> LLR: -0.7654, log_prob(wt): -3.1817, log_prob(mut): -3.9471
  Finetuned  -> LLR: -5.5716, log_prob(wt): -8.6611, log_prob(mut): -14.2327


In [11]:
import torch
import math
import pandas as pd
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Example configuration
base_model_name = "facebook/esm2_t30_150M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Load one of your models; for instance, the pretrained model:
model = AutoModelForMaskedLM.from_pretrained(base_model_name)

def compute_llrs_at_position(model, tokenizer, sequence, position, device):
    """
    Computes the log-likelihood ratios (LLRs) for all possible amino acid substitutions
    at a given position and returns a DataFrame including the token string corresponding
    to each candidate token.
    
    Args:
        model: The masked language model.
        tokenizer: The corresponding tokenizer.
        sequence: The protein sequence (string).
        position: The position to analyze (1-indexed).
        device: torch.device to run the model.
        
    Returns:
        A pandas DataFrame with columns:
          - 'Residue': the candidate amino acid character,
          - 'Token_ID': the candidate token ID,
          - 'Token_Str': the decoded token string,
          - 'Log_Prob': the log probability of that candidate at the masked position,
          - 'LLR': log(prob(candidate)) - log(prob(wild_type))
    """
    # List of the 20 standard amino acids
    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    
    # Adjust position from 1-indexed to 0-indexed.
    idx = position - 1
    
    # Confirm the wild-type residue from the provided sequence.
    wild_type = sequence[idx]
    
    # Tokenize the full sequence (assumes one token per residue)
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    # Create a copy of input_ids and replace the target position with the mask token.
    masked_input_ids = input_ids.clone()
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("The tokenizer does not have a mask token!")
    
    masked_input_ids[0, idx] = mask_token_id
    
    # Run the model in evaluation mode.
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(masked_input_ids, attention_mask=attention_mask)
        # logits shape: [batch_size, sequence_length, vocab_size]
        logits = outputs.logits

    # Extract the logits for our masked position.
    position_logits = logits[0, idx]
    # Convert logits to a probability distribution (using softmax)
    probs = torch.nn.functional.softmax(position_logits, dim=-1)
    
    # Small epsilon to prevent taking log(0)
    eps = 1e-10

    # Get the token ID for the wild-type residue.
    wt_token_id = tokenizer.convert_tokens_to_ids(wild_type)
    if wt_token_id is None:
        raise ValueError(f"Wildtype residue '{wild_type}' could not be converted to a token ID.")
    prob_wt = probs[wt_token_id].item()
    log_prob_wt = math.log(prob_wt + eps)
    
    # Prepare a list to gather results for each candidate amino acid.
    results = []
    for aa in amino_acids:
        # Get the token id for the candidate amino acid.
        aa_token_id = tokenizer.convert_tokens_to_ids(aa)
        if aa_token_id is None:
            # If this amino acid isn't in the vocabulary, skip it.
            continue

        # Convert the token id to the decoded token string using decode.
        # tokenizer.decode expects a list of token ids.
        token_str = tokenizer.decode([aa_token_id]).strip()
        
        prob_aa = probs[aa_token_id].item()
        log_prob_aa = math.log(prob_aa + eps)
        # Compute LLR: difference between the candidate's and wild-type's log probabilities.
        llr = log_prob_aa - log_prob_wt
        
        results.append({
            "Residue": aa,
            "Token_ID": aa_token_id,
            "Token_Str": token_str,
            "Log_Prob": log_prob_aa,
            "LLR": llr
        })
    
    # Convert the results list into a pandas DataFrame.
    df = pd.DataFrame(results)
    return df

# Example usage:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
position_to_check = 33  # for example, position 33 (1-indexed)
protein_sequence = (
    'MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKA'
    'LIGAGTVLKPEQVDALARMGCQLIVTPNIHSEVIRRAVGYGMTVCPGCATATEAFTALEA'
    'GAQALKIFPSSAFGPQYIKALKAVLPSDIAVFAVGGVTPENLAQWIDAGCAGAGLGSDLY'
    'RAGQSVERTAQQAAAFVKAYREAVQ'
)

df_llrs = compute_llrs_at_position(model, tokenizer, protein_sequence, position=position_to_check, device=device)
print(f"LLR results at position {position_to_check} (wildtype: {protein_sequence[position_to_check - 1]}):")
print(df_llrs)


LLR results at position 33 (wildtype: F):
   Residue  Token_ID Token_Str   Log_Prob        LLR
0        A         5         A  -7.392007   5.114078
1        C        23         C  -9.025281   3.480805
2        D        13         D  -6.565214   5.940872
3        E         9         E  -7.239599   5.266487
4        F        18         F -12.506085   0.000000
5        G         6         G  -0.004772  12.501314
6        H        21         H  -8.125887   4.380199
7        I        12         I -13.383285  -0.877199
8        K        15         K  -8.696727   3.809359
9        L         4         L -11.507031   0.999054
10       M        20         M -12.534882  -0.028796
11       N        17         N  -8.653258   3.852827
12       P        14         P -10.269735   2.236351
13       Q        16         Q  -7.690768   4.815317
14       R        10         R  -8.078082   4.428003
15       S         8         S  -8.037423   4.468662
16       T        11         T -10.585841   1.920244
17  

In [13]:
# Example of how to use the function:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Choose a position to analyze (e.g., position 33, 1-indexed)
position_to_check = 33
df_llrs = compute_llrs_at_position(model_finetuned, tokenizer, 
                                   sequence=(
                                       'MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKA'
                                       'LIGAGTVLKPEQVDALARMGCQLIVTPNIHSEVIRRAVGYGMTVCPGCATATEAFTALEA'
                                       'GAQALKIFPSSAFGPQYIKALKAVLPSDIAVFAVGGVTPENLAQWIDAGCAGAGLGSDLY'
                                       'RAGQSVERTAQQAAAFVKAYREAVQ'
                                   ),
                                   position=position_to_check,
                                   device=device)
print(f"LLR results at position {position_to_check} (wildtype: {('MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKA')[position_to_check - 1]}):")
print(df_llrs)


LLR results at position 33 (wildtype: F):
   Residue  Token_ID Token_Str   Log_Prob        LLR
0        A         5         A -10.526323   2.403356
1        C        23         C  -9.709476   3.220203
2        D        13         D  -8.404804   4.524875
3        E         9         E  -6.972060   5.957619
4        F        18         F -12.929679   0.000000
5        G         6         G  -0.003592  12.926087
6        H        21         H -15.173178  -2.243499
7        I        12         I -12.218802   0.710876
8        K        15         K -11.628019   1.301660
9        L         4         L -11.770803   1.158876
10       M        20         M -12.972339  -0.042660
11       N        17         N -16.749253  -3.819575
12       P        14         P -15.240294  -2.310615
13       Q        16         Q -14.210980  -1.281301
14       R        10         R  -6.639997   6.289682
15       S         8         S  -8.042963   4.886716
16       T        11         T -13.673189  -0.743511
17  

In [9]:
import torch
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForMaskedLM

def compute_llrs_at_position(model, tokenizer, sequence, position, device):
    """
    Computes the log-likelihood ratios (LLRs) for all possible amino acid substitutions
    at a given position and returns a DataFrame.
    
    Args:
        model: The masked language model.
        tokenizer: The corresponding tokenizer.
        sequence: Protein sequence (string).
        position: The position to analyze (1-indexed).
        device: torch.device to run the model.
    
    Returns:
        A pandas DataFrame with columns:
          - 'Residue': candidate amino acid,
          - 'Token_ID': the candidate token id,
          - 'Token_Str': the decoded token string,
          - 'Log_Prob': log probability for candidate at masked position,
          - 'LLR': log(prob(candidate)) - log(prob(wild_type)).
    """
    # List of 20 standard amino acids.
    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    
    # Adjust position from 1-indexed to 0-indexed.
    idx = position - 1
    wild_type = sequence[idx]
    
    # Tokenize the full sequence (assumes one token per residue).
    encoded = tokenizer(sequence, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)
    
    # Create a copy and mask the target position.
    masked_input_ids = input_ids.clone()
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("The tokenizer does not have a mask token!")
    masked_input_ids[0, idx] = mask_token_id
    
    # Forward pass through the model.
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(masked_input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # shape: [batch, seq_length, vocab_size]
    
    # Get logits at the masked position and convert to probabilities.
    position_logits = logits[0, idx]
    probs = torch.nn.functional.softmax(position_logits, dim=-1)
    eps = 1e-10  # small epsilon to avoid log(0)
    
    # Get wild-type probability and its log probability.
    wt_token_id = tokenizer.convert_tokens_to_ids(wild_type)
    if wt_token_id is None:
        raise ValueError(f"Wildtype residue '{wild_type}' could not be converted to a token ID.")
    prob_wt = probs[wt_token_id].item()
    log_prob_wt = math.log(prob_wt + eps)
    
    results = []
    for aa in amino_acids:
        aa_token_id = tokenizer.convert_tokens_to_ids(aa)
        if aa_token_id is None:
            continue
        
        # Get the token string from the token id.
        token_str = tokenizer.decode([aa_token_id]).strip()
        prob_aa = probs[aa_token_id].item()
        log_prob_aa = math.log(prob_aa + eps)
        llr = log_prob_aa - log_prob_wt
        results.append({
            "Residue": aa,
            "Token_ID": aa_token_id,
            "Token_Str": token_str,
            "Log_Prob": log_prob_aa,
            "LLR": llr
        })
    
    df = pd.DataFrame(results)
    return df

def compute_llr_matrix(model, tokenizer, sequence, device):
    """
    Computes the LLR for all 20 amino acids at every position in the sequence.
    
    Args:
        model: The masked language model.
        tokenizer: The corresponding tokenizer.
        sequence: The protein sequence (string).
        device: torch.device to run the model.
    
    Returns:
        A numpy array of shape (20, L) containing the LLR values,
        and a list of positions (1-indexed) corresponding to the columns.
        The rows correspond to the 20 standard amino acids in the order "ACDEFGHIKLMNPQRSTVWY".
    """
    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    n_positions = len(sequence)
    llr_matrix = np.zeros((len(amino_acids), n_positions))
    
    # Loop over every position in the sequence.
    for pos in range(1, n_positions + 1):
        df_llrs = compute_llrs_at_position(model, tokenizer, sequence, position=pos, device=device)
        # Reorder rows to match the fixed order of amino_acids.
        df_llrs = df_llrs.set_index("Residue").loc[amino_acids]
        llr_matrix[:, pos - 1] = df_llrs["LLR"].values
        
    return llr_matrix, list(range(1, n_positions + 1))

def plot_llr_heatmap(llr_matrix, sequence, save_path=None):
    """
    Plots a heatmap of LLR values across the entire sequence and optionally saves the figure.
    
    Args:
        llr_matrix: A (20 x L) numpy array where L is the sequence length.
        sequence: The protein sequence (string). Used to annotate the x-axis.
        save_path: File path to save the figure (e.g. "llr_heatmap.png"). If None, the figure isn't saved.
    """
    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    n_positions = len(sequence)
    
    fig, ax = plt.subplots(figsize=(max(10, n_positions/5), 8))
    cax = ax.imshow(llr_matrix, cmap='viridis', aspect='auto')
    
    # Set y-axis ticks: one for each amino acid.
    ax.set_yticks(np.arange(len(amino_acids)))
    ax.set_yticklabels(amino_acids, fontsize=12)
    
    # Set x-axis ticks (display a subset if the sequence is long)
    pos_ticks = np.arange(0, n_positions, max(1, n_positions // 20))
    ax.set_xticks(pos_ticks)
    ax.set_xticklabels([str(p) for p in (np.array(pos_ticks) + 1)], rotation=45, fontsize=10)
    
    ax.set_xlabel("Position in Sequence", fontsize=14)
    ax.set_ylabel("Residue", fontsize=14)
    ax.set_title("LLR Heatmap Across Entire Sequence", fontsize=16)
    fig.colorbar(cax, ax=ax, label="LLR")
    plt.tight_layout()
    
    if save_path is not None:
        plt.savefig(save_path)
        print(f"Heatmap figure saved to {save_path}")
    
    plt.show()

# Example usage:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
protein_sequence = (
    'MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKA'
    'LIGAGTVLKPEQVDALARMGCQLIVTPNIHSEVIRRAVGYGMTVCPGCATATEAFTALEA'
    'GAQALKIFPSSAFGPQYIKALKAVLPSDIAVFAVGGVTPENLAQWIDAGCAGAGLGSDLY'
    'RAGQSVERTAQQAAAFVKAYREAVQ'
)

# The protein sequence (concatenated)
topA_seq = (
"MGKALVIVESPAKAKTINKYLGSDYVVKSSVGHIRDLPTSGSAAKKSADSTSTKTAKKPK" 
    "KDERGALVNRMGVDPWHNWEAHYEVLPGKEKVVSELKQLAEKADHIYLATDLDREGEAIA" 
    "WHLREVIGGDDARYSRVVFNEITKNAIRQAFNKPGELNIDRVNAQQARRFMDRVVGYMVS" 
    "PLLWKKIARGLSAGRVQSVAVRLVVEREREIKAFVPEEFWEVDASTTTPSGEALALQVTH" 
    "QNDKPFRPVNKEQTQAAVSLLEKARYSVLEREDKPTTSKPGAPFITSTLQQAASTRLGFG" 
    "VKKTMMMAQRLYEAGYITYMRTDSTNLSQDAVNMVRGYISDNFGKKYLPESPNQYASKEN" 
    "SQEAHEAIRPSDVNVMAESLKDMEADAQKLYQLIWRQFVACQMTPAKYDSTTLTVGAGDF" 
    "RLKARGRILRFDGWTKVMPALRKGDEDRILPAVNKGDALTLVELTPAQHFTKPPARFSEA" 
    "SLVKELEKRGIGRPSTYASIISTIQDRGYVRVENRRFYAEKMGEIVTDRLEENFRELMNY" 
    "DFTAQMENSLDQVANHEAEWKAVLDHFFSDFTQQLDKAEKDPEEGGMRPNQMVLTSIDCP" 
    "TCGRKMGIRTASTGVFLGCSGYALPPKERCKTTINLVPENEVLNVLEGEDAETNALRAKR" 
    "RCPKCGTAMDSYLIDPKRKLHVCGNNPTCDGYEIEEGEFRIKGYDGPIVECEKCGSEMHL" 
    "KMGRFGKYMACTNEECKNTRKILRNGEVAPPKEDPVPLPELPCEKSDAYFVLRDGAAGVF" 
    "LAANTFPKSRETRAPLVEELYRFRDRLPEKLRYLADAPQQDPEGNKTMVRFSRKTKQQYV" 
    "SSEKDGKATGWSAFYVDGKWVEGKK" 
)

spoT_seq = ("MYLFESLNQLIQTYLPEDQIKRLRQAYLVARDAHEGQTRSSGEPYITHPVAVACILAEMK"
                    "LDYETLMAALLHDVIEDTPATYQDMEQLFGKSVAELVEGVSKLDKLKFRDKKEAQAENFR"
                    "KMIMAMVQDIRVILIKLADRTHNMRTLGSLRPDKRRRIARETLEIYSPLAHRLGIHHIKT"
                    "ELEELGFEALYPNRYRVIKEVVKAARGNRKEMIQKILSEIEGRLQEAGIPCRVSGREKHL"
                    "YSIYCKMVLKEQRFHSIMDIYAFRVIVNDSDTCYRVLGQMHSLYKPRPGRVKDYIAIPKA"
                    "NGYQSLHTSMIGPHGVPVEVQIRTEDMDQMAEMGVAAHWAYKEHGETSTTAQIRAQRWMQ"
                    "SLLELQQSAGSSFEFIESVKSDLFPDEIYVFTPEGRIVELPAGATPVDFAYAVHTDIGHA"
                    "CVGARVDRQPYPLSQPLTSGQTVEIITAPGARPNAAWLNFVVSSKARAKIRQLLKNLKRD"
                    "DSVSLGRRLLNHALGGSRKLNEIPQENIQRELDRMKLATLDDLLAEIGLGNAMSVVVAKN"
                    "LQHGDASIPPATQSHGHLPIKGADGVLITFAKCCRPIPGDPIIAHVSPGKGLVIHHESCR"
                    "NIRGYQKEPEKFMAVEWDKETAQEFITEIKVEMFNHQGALANLTAAINTTTSNIQSLNTE"
                    "EKDGRVYSAFIRLTARDRVHLANIMRKIRVMPDVIKVTRNRN")

yeiB_seq = ("MERNVTLDFVRGVAILGILLLNISAFGLPKAAYLNPAWYGAITPRDAWTWAFLDLIGQVK"
"FLTLFALLFGAGLQMLLPRGRRWIQSRLTLLVLLGFIHGLLFWDGDILLAYGLVGLICWR"
"LVRDAPSVKSLFNTGVMLYLVGLGVLLLLGLISDSQTSRAWTPDASAILYEKYWKLHGGV"
"EAISNRADGVGNSLLALGAQYGWQLAGMMLIGAALMRSGWLKGQFSLRHYRRTGFVLVAI"
"GVTINLPAIALQWQLDWAYRWCAFLLQMPRELSAPFQAIGYASLFYGFWPQLSRFKLVLA"
"IACVGRMALTNYLLQTLICTTLFYHLGLFMHFDRLELLAFVIPVWLANILFSVIWLRYFR"
"QGPVEWLWRQLTLRAAGPAISKTSR")

# List of mutations provided as strings
gene_mutation = {"topA":"H33Y", "spoT":"K662I", "yeiB":"L143I"}

mutations = [gene_mutation["yeiB"]]

# Example configuration
base_model_name = "facebook/esm2_t30_150M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Load your model; here we use the pretrained model as an example.
model = AutoModelForMaskedLM.from_pretrained(base_model_name)
    
# Compute the full LLR matrix across the sequence.
llr_matrix, positions = compute_llr_matrix(model, tokenizer, spoT_seq, device)

# Create a DataFrame for the heatmap values.
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
df_heatmap = pd.DataFrame(llr_matrix, index=amino_acids, columns=positions)

# Save the heatmap values to a CSV file.
csv_filename = "pretrained_spoT_esm2_ecoli_llr_heatmap.csv"
df_heatmap.to_csv(csv_filename)
print(f"LLR heatmap values saved to {csv_filename}")

# Plot the heatmap and save the figure.
#figure_filename = "pretrained_esm2_ecoli_llr_heatmap.png"
#plot_llr_heatmap(llr_matrix, protein_sequence, save_path=figure_filename)


LLR heatmap values saved to pretrained_spoT_esm2_ecoli_llr_heatmap.csv


In [10]:
# Compute the full LLR matrix across the sequence.
llr_matrix, positions = compute_llr_matrix(model_finetuned, tokenizer, spoT_seq, device)

# Create a DataFrame for the heatmap values.
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
df_heatmap = pd.DataFrame(llr_matrix, index=amino_acids, columns=positions)

# Save the heatmap values to a CSV file.
csv_filename = "finetuned_spoT_esm2_ecoli_llr_heatmap.csv"
df_heatmap.to_csv(csv_filename)
print(f"LLR heatmap values saved to {csv_filename}")

# Plot the heatmap and save the figure.
#figure_filename = "finetuned_dgoa_llr_heatmap.png"
#plot_llr_heatmap(llr_matrix, protein_sequence, save_path=figure_filename)

LLR heatmap values saved to finetuned_spoT_esm2_ecoli_llr_heatmap.csv


In [8]:
import pandas as pd

# List of mutations as strings.
gene_mutation = {"topA":"H33Y", "spoT":"K662I", "yeiB":"L143I"}

def parse_mutation(mutation_str):
    """
    Parses a mutation string such as "F33I" into:
      - wt: the wild-type residue (for reference)
      - pos: the position (1-indexed)
      - mut: the mutant residue.
    """
    wt = mutation_str[0]      # For example, 'F'
    mutant = mutation_str[-1]  # For example, 'I'
    pos = int(mutation_str[1:-1])
    return wt, pos, mutant

# Load CSV files.
pretrained_csv = "pretrained_topA_esm2_ecoli_llr_heatmap.csv"
finetuned_csv = "finetuned_topA_esm2_ecoli_llr_heatmap.csv"

# Read CSV with the first column as the index (the amino acid letters)
pretrained_df = pd.read_csv(pretrained_csv, index_col=0)
finetuned_df = pd.read_csv(finetuned_csv, index_col=0)

# Ensure column headers (positions) are strings.
pretrained_df.columns = pretrained_df.columns.astype(str)
finetuned_df.columns = finetuned_df.columns.astype(str)

print("LLR values and delta for each mutation:\n")

mutations = [gene_mutation["topA"]]

for mutation in mutations:
    wt, pos, mut = parse_mutation(mutation)
    pos_str = str(pos)
    
    try:
        # Look up LLR values for the mutant residue at the given position.
        pretrained_llr = pretrained_df.loc[mut, pos_str]
        finetuned_llr = finetuned_df.loc[mut, pos_str]
    except KeyError as e:
        print(f"Error: Could not find data for mutation {mutation}: {e}")
        continue
        
    # Compute the delta: finetuned LLR minus pretrained LLR.
    delta = finetuned_llr - pretrained_llr
        
    # Print the results.
    print(f"Mutation {mutation}:")
    print(f"  Pretrained LLR: {pretrained_llr}")
    print(f"  Finetuned LLR:  {finetuned_llr}")
    print(f"  Delta (Finetuned - Pretrained): {delta}\n")


LLR values and delta for each mutation:

Mutation H33Y:
  Pretrained LLR: -0.7890436046188132
  Finetuned LLR:  2.149750789218812
  Delta (Finetuned - Pretrained): 2.9387943938376253



In [11]:
mutations = [gene_mutation["spoT"]]

for mutation in mutations:
    wt, pos, mut = parse_mutation(mutation)
    pos_str = str(pos)
    
    try:
        # Look up LLR values for the mutant residue at the given position.
        pretrained_llr = pretrained_df.loc[mut, pos_str]
        finetuned_llr = finetuned_df.loc[mut, pos_str]
    except KeyError as e:
        print(f"Error: Could not find data for mutation {mutation}: {e}")
        continue
        
    # Compute the delta: finetuned LLR minus pretrained LLR.
    delta = finetuned_llr - pretrained_llr
        
    # Print the results.
    print(f"Mutation {mutation}:")
    print(f"  Pretrained LLR: {pretrained_llr}")
    print(f"  Finetuned LLR:  {finetuned_llr}")
    print(f"  Delta (Finetuned - Pretrained): {delta}\n")


Mutation K662I:
  Pretrained LLR: 3.547046870046472
  Finetuned LLR:  -4.420149542104095
  Delta (Finetuned - Pretrained): -7.967196412150567

