## Sequence Entropy based on ESM2

In [1]:
# Check for GPU and set device
# Python 3.14.0

import esm
import torch 
import torch.nn.functional as F
import numpy as np

# Prefer a CUDA GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load the pretrained 150M-parameter ESM-2 checkpoint together with its alphabet,
# move it to `device` and put it in evaluation mode for inference.
model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
model = model.to(device).eval()

# Tokenizer for [(label, sequence), ...] batches; its last returned element is the token tensor
batch_converter = alphabet.get_batch_converter()

# The 20 standard amino acids (one-letter codes) and their indices in the model vocabulary
aa_list = [residue for residue in "ACDEFGHIKLMNPQRSTVWY"]
aa_to_idx = dict(zip(aa_list, map(alphabet.get_idx, aa_list)))



Using device: cuda


In [2]:
def rectify_logits(logits: torch.Tensor, epsilon: float = 1e-12) -> torch.Tensor:
    """
    Computes numerically stable log-probabilities from logits using epsilon smoothing.

    The softmax is taken over the last dimension in float64. ``epsilon`` is added to
    every probability and the result is renormalised before the log is taken. No entry
    is therefore ever ``log(0) = -inf``, even for ``-inf`` logits.

    Args:
        logits: Tensor of shape (..., vocab_size); typically a single (vocab_size,) row.
        epsilon: Amount added to every probability to prevent log(0).

    Returns:
        Tensor with the same shape as ``logits`` containing log-probabilities,
        cast back to ``logits.dtype``.
    """
    # Use float64 to maximize precision during softmax accumulation
    p = logits.to(dtype=torch.float64).softmax(dim=-1)

    # Apply smoothing and renormalize so each distribution along the last dim sums to 1.
    # `keepdim` is torch's documented keyword (`keepdims` is only a NumPy-compat alias).
    p = p + epsilon
    p = p / p.sum(dim=-1, keepdim=True)

    return torch.log(p).to(dtype=logits.dtype)


In [3]:
def compute_masked_entropy(seq: str) -> list[float]:
    """
    Computes per-position Shannon entropy over the 20 standard amino acids using MLM masking.

    Each position i is replaced by ``<mask>`` in turn and the sequence is run through
    the model. The log-probabilities at the masked position are restricted to the 20
    standard amino acids (``aa_list``). A softmax renormalises them into a 20-way
    conditional distribution, and the Shannon entropy of that distribution is recorded.

    Uses the module-level ``model``, ``alphabet``, ``batch_converter``, ``device``,
    ``aa_list`` and ``rectify_logits`` defined earlier in the notebook. Runs one
    forward pass per position.

    Args:
        seq: Wild-type protein sequence (string of length L).

    Returns:
        entropies: list[float] of length L, Shannon entropy (in bits) at each position.
        An empty ``seq`` yields an empty list.

    Raises:
        ValueError: If a masked sequence does not tokenize to exactly one ``<mask>``
            token (e.g. ``seq`` already contains the literal text ``<mask>``).
    """
    L = len(seq)
    if L == 0:
        return []

    # Vocab indices of the 20 standard AAs, in aa_list order. They must come from the
    # model alphabet (not a simple 0..19 mapping). The index tensor is built once on
    # `device`, so each position needs only a single gather.
    aa_vocab_idx = torch.tensor(
        [alphabet.get_idx(aa) for aa in aa_list], dtype=torch.long, device=device
    )

    per_pos_log_probs = []  # L tensors of shape (20,)
    with torch.no_grad():
        for i in range(L):
            # Mask current position
            masked_seq = seq[:i] + "<mask>" + seq[i + 1:]
            _, _, tokens = batch_converter([("seq", masked_seq)])
            tokens = tokens.to(device)

            # logits shape: (batch=1, num_tokens, vocab_size).
            # Slicing [1:-1] removes BOS and EOS tokens to align with seq.
            logits = model(tokens)["logits"][0, 1:-1]

            # Locate mask token to extract specific predictions
            mask_positions = (tokens[0] == alphabet.mask_idx).nonzero(as_tuple=False).squeeze(-1)
            if mask_positions.numel() != 1:
                raise ValueError(
                    f"Expected exactly one <mask> token, but found {mask_positions.numel()} at position i={i}."
                )

            # tokens include BOS, so the mask index needs a -1 offset to align with sliced logits
            pos_in_logits = mask_positions.item() - 1

            # Vocab log-probabilities at the masked position, shape (vocab_size,)
            log_probs_vocab = rectify_logits(logits[pos_in_logits])

            # One gather instead of 20 per-element .item() calls (each a device->host sync)
            per_pos_log_probs.append(log_probs_vocab.index_select(0, aa_vocab_idx))

    # (L, 20) matrix of 20-AA log-probabilities, float32 on the CPU
    log_prob_matrix = torch.stack(per_pos_log_probs).to(device="cpu", dtype=torch.float32)

    # Softmax over the 20 log-probs renormalises them: p(aa) / sum over the 20 AAs of p(aa')
    probs_20aa = F.softmax(log_prob_matrix, dim=1).numpy()

    # Compute Shannon entropy per position (in bits); eps guards log2(0)
    eps = 1e-12
    entropy_per_pos = -np.sum(probs_20aa * np.log2(probs_20aa + eps), axis=1)

    return entropy_per_pos.tolist()

# Mean masked-marginal entropy for the protein with accession Q8L9G7:
# one entropy value per residue, then averaged over the whole sequence.
aa_seq = "MTVAAGIGYALVALGPSLSLFVSVISRKPFLILTVLSSTLLWLVSLIILSGLWRPFLPLKANVWWPYALLVITSVCFQEGLRFLFWKVYKRLEDVLDSFADRISRPRLFLTDKLQIALAGGLGHGVAHAVFFCLSLLTPAFGPATFYVERCSKVPFFLISAIIALAFVTIHTFSMVIAFEGYAKGNKVDQIIVPVIHLTAGMLTLVNFASEGCVIGVPLLYLVASLTLVHCGKMVWQRLLESRNQSSASR"
entropy_per_pos = compute_masked_entropy(aa_seq)

# Rounded to 6 decimals and kept as a module-level value for later cells
mean_entropy = round(np.mean(entropy_per_pos), 6)
print(f"the sequence entropy is {mean_entropy:.6f}")

the sequence entropy is 3.564803
