In [None]:
import torch
import sys
sys.path.append("/novo/users/cpjb/rp0689_lm_finetune/env/pLM_FT/lib/python3.10/site-packages")
import esm
import warnings
import random
import numpy as np 

In [None]:
def batch(iterable, n=1):
    """
    Lazily split a sized, sliceable sequence into consecutive chunks.

    Args:
        iterable: Any object supporting ``len()`` and slice indexing
            (list, str, NumPy array, ...).
        n (int): Maximum number of items per chunk.

    Yields:
        Slices of ``iterable`` holding ``n`` items each; the final slice may
        be shorter when the length is not a multiple of ``n``.
    """
    total = len(iterable)
    for start in range(0, total, n):
        stop = start + n
        # Clamp the final chunk to the end of the sequence.
        if stop > total:
            stop = total
        yield iterable[start:stop]

class ESMPseudoPerplexity:
    """
    Calculate the pseudo-perplexity of protein sequences using an ESM-2 model.

    The pre-trained model is loaded once at construction time, so scoring many
    sequences with the same instance is cheap compared with reloading the
    weights for every call.

    Attributes:
        mask_str_token: The string spliced into a sequence to mask a residue.
        model: The loaded ESM-2 model (in eval mode, on ``device``).
        alphabet: The alphabet used by the ESM-2 model.
        batch_converter: A utility to convert sequences to batched tensors.
        device: The computing device (CUDA or CPU) where the model is located.

    USAGE:

        # In a new script or notebook:
        # from ESM2_utils import ESMPseudoPerplexity

        # # Initialize the calculator (loads the model)
        # calculator = ESMPseudoPerplexity()

        # # Calculate the pseudo-perplexity for any sequence
        # my_sequence = "MKALIVLGLVLLSVTVQGKVQ"
        # score = calculator.calculate_pll_score(my_sequence)

        # print(f"The pseudo-perplexity score for '{my_sequence}' is: {score:.4f}")
    """
    def __init__(self, model_name: str = 'esm2_t33_650M_UR50D'):
        """
        Initializes the model, alphabet, and batch converter.

        Args:
            model_name (str): The name of the pre-trained ESM-2 model to load.
        """
        self.mask_str_token = "<mask>"
        self.model, self.alphabet = esm.pretrained.load_model_and_alphabet(model_name)
        self.batch_converter = self.alphabet.get_batch_converter()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()  # Disable dropout for deterministic results

        print(f"ESM-2 model '{model_name}' loaded on {self.device}.")

    def _generate_masked_sequences(self, sequence: str, mask_length: int):
        """
        Builds masked copies of ``sequence`` so that every position is masked
        exactly once.

        The positions are shuffled with a fixed seed and cut into consecutive
        groups of ``mask_length`` (the last group may be smaller). Each group
        produces one copy of the sequence with those positions replaced by
        the mask token.

        Args:
            sequence (str): The input protein sequence.
            mask_length (int): The number of positions masked per copy.

        Returns:
            list: ``(masked_index, masked_sequence)`` tuples, where
            ``masked_index`` is a NumPy array of the masked positions and
            ``masked_sequence`` is the sequence string with the mask token
            inserted at those positions.

        Raises:
            ValueError: If ``mask_length`` is smaller than 1.
        """
        if mask_length < 1:
            raise ValueError("mask_length must be a positive integer.")

        # A private RandomState(0) yields the same permutation as
        # np.random.seed(0) followed by np.random.shuffle, but leaves the
        # caller's global NumPy RNG state untouched.
        rng = np.random.RandomState(0)
        seq_indexes = np.arange(len(sequence))
        rng.shuffle(seq_indexes)

        residues = list(sequence)
        all_sequences = []
        for start in range(0, len(seq_indexes), mask_length):
            masked_index = seq_indexes[start:start + mask_length]
            seq_copy = residues.copy()
            for index in masked_index:
                seq_copy[index] = self.mask_str_token

            all_sequences.append((masked_index, "".join(seq_copy)))

        return all_sequences

    @torch.no_grad()
    def calculate_pll_score(self, sequence: str, mask_length: int = 1,
                            batch_size: int = 32) -> float:
        """
        Calculates the pseudo-perplexity of ``sequence``.

        Every residue is masked exactly once (``mask_length`` randomly chosen
        positions per masked copy), the model's log-probability of the true
        residue is read at each masked position, and the result is
        ``exp(-mean log-probability)``.

        When mask_length=1, the summed log-probability is the standard
        pseudo-log-likelihood (pLL).

        Args:
            sequence (str): The input protein sequence.
            mask_length (int): The number of positions masked per copy.
            batch_size (int): The number of masked sequences to process in
                each forward pass. Chunking only bounds memory use; every
                masked copy is still scored.

        Returns:
            float: The pseudo-perplexity (lower means the model finds the
            sequence more likely).

        Raises:
            ValueError: If ``sequence`` is not a non-empty string, or if
                ``mask_length`` or ``batch_size`` is smaller than 1.
        """
        if not isinstance(sequence, str) or not sequence:
            raise ValueError("Input sequence must be a non-empty string.")
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        # 1. Generate the masked copies (each position masked exactly once)
        masked_data = self._generate_masked_sequences(sequence, mask_length)

        # Vocabulary index of the true residue at every position
        target_ids = [self.alphabet.get_idx(residue) for residue in sequence]

        log_likelihood = 0.0
        for start in range(0, len(masked_data), batch_size):
            chunk = masked_data[start:start + batch_size]

            # 2. Convert the chunk to a token batch
            esm_input = [
                (start + offset, masked_seq)
                for offset, (_, masked_seq) in enumerate(chunk)
            ]
            _, _, batch_tokens = self.batch_converter(esm_input)
            batch_tokens = batch_tokens.to(self.device)

            # 3. Pass through the model; only the logits are needed, so no
            #    hidden representations are requested.
            results = self.model(batch_tokens, repr_layers=[], return_contacts=False)
            log_probs = torch.nn.functional.log_softmax(results["logits"], dim=-1)

            # 4. Gather the log-probability of the true residue at each
            #    masked position.
            rows, cols, targets = [], [], []
            for row, (masked_index, _) in enumerate(chunk):
                for position in masked_index:
                    position = int(position)
                    rows.append(row)
                    # +1 skips the single start token the batch converter is
                    # assumed to prepend (NOTE(review): holds for ESM-2
                    # alphabets; confirm before using other model families).
                    cols.append(position + 1)
                    targets.append(target_ids[position])

            index_device = log_probs.device
            picked = log_probs[
                torch.tensor(rows, device=index_device),
                torch.tensor(cols, device=index_device),
                torch.tensor(targets, device=index_device),
            ]
            log_likelihood += picked.double().sum().item()

        # Average log likelihood per residue
        avg_log_likelihood = log_likelihood / len(sequence)

        # Pseudo-perplexity
        return float(np.exp(-avg_log_likelihood))

In [None]:
# This block demonstrates how to use the class.
# It will only run when the script is executed directly.

# Example natural protein for the demo: human ubiquitin (76 residues).
human_protein_sequence = (
    "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQ"
    "QRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
)

if __name__ == '__main__':
    try:
        # 1. Initialize the calculator once. The model is loaded here.
        pll_calculator = ESMPseudoPerplexity()
        # 2. Provide a sequence to calculate its score
        print("\nCalculating pLL score for a natural protein...")
        natural_protein_score = pll_calculator.calculate_pll_score(human_protein_sequence)

        print(f"-> Pseudo-Perplexity for natural protein: {natural_protein_score:.4f}")
        # Example with a non-natural (poly-Alanine) sequence
        non_natural_sequence = "A" * 80
        print("\nCalculating pLL score for a non-natural sequence...")
        non_natural_score = pll_calculator.calculate_pll_score(non_natural_sequence)
        print(f"-> Pseudo-Perplexity for non-natural protein: {non_natural_score:.4f}")

        # Example of scoring multiple sequences efficiently: the same
        # calculator (and therefore the same loaded model) is reused.
        print("\nCalculating pLL for a list of short peptides...")
        sequences_to_score = ["MKYKL", "VLLLE", "AGVTV"]
        for seq in sequences_to_score:
            score = pll_calculator.calculate_pll_score(seq)
            print(f"- pLL for '{seq}': {score:.4f}")

    except ImportError:
        warnings.warn("Please install the 'fair-esm' library to run this script. Use: pip install fair-esm")
    except Exception as e:
        # Top-level boundary of the demo: report the failure instead of
        # raising a traceback.
        print(f"An unexpected error occurred: {e}")