# Global embedding analysis of TP53 missense mutations

**Goal:**  
To test whether global protein embeddings can distinguish functional vs non-functional TP53 missense mutations observed in tumor data.

**Key result:**  
Functional and non-functional mutations largely overlap in the global (mean-pooled) embedding space, which is consistent with most missense mutations not causing large-scale structural changes.

**Interpretation:**  
Loss of TP53 function is driven primarily by local structural disruption rather than global unfolding.


### Generating all possible mutations
The wild-type TP53 sequence was taken from the UniProt database. A helper function then generates every possible missense mutant, i.e. every single amino-acid substitution.

In [None]:
TP53='MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD'

In [None]:
len(TP53)

393

In [None]:
from typing import Dict

# Standard 20 amino acids
AMINO_ACIDS = [
    "A", "C", "D", "E", "F",
    "G", "H", "I", "K", "L",
    "M", "N", "P", "Q", "R",
    "S", "T", "V", "W", "Y"
]


def generate_single_mutations(
    protein_sequence: str
) -> Dict[str, str]:
    """
    Generate all possible single amino-acid substitutions
    for a protein sequence.

    Parameters
    ----------
    protein_sequence : str
        Wild-type protein sequence (e.g. TP53)

    Returns
    -------
    mutations : dict
        Dictionary where:
        key   = mutation code (e.g. 'R175H')
        value = mutated protein sequence
    """

    mutations = {}
    seq_len = len(protein_sequence)

    # Loop over each position in the sequence
    for i in range(seq_len):
        wt_aa = protein_sequence[i]  # wild-type amino acid
        pos = i + 1                  # 1-based indexing (biological convention)

        # Try all possible amino-acid substitutions
        for mut_aa in AMINO_ACIDS:
            # Skip if mutation is same as wild-type
            if mut_aa == wt_aa:
                continue

            # Construct mutation code (e.g. R175H)
            mutation_code = f"{wt_aa}{pos}{mut_aa}"

            # Create mutated sequence
            mutated_sequence = (
                protein_sequence[:i]
                + mut_aa
                + protein_sequence[i + 1:]
            )

            mutations[mutation_code] = mutated_sequence

    return mutations


In [None]:
mutations = generate_single_mutations(TP53)

print(len(mutations))       # 7467
print(mutations["R175H"])   # mutated TP53 sequence

7467
MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRHCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD
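
As a sanity check, the expected count is 393 positions × 19 alternative residues = 7467, and each mutant should differ from the wild type at exactly one position. The sketch below shows one way to verify this; `check_substitution` is a hypothetical helper that is not part of the original notebook, and the five-residue toy sequence only stands in for TP53.

```python
def check_substitution(wild_type: str, mutant: str, code: str) -> bool:
    """Return True if `mutant` is `wild_type` with only the substitution `code` applied."""
    wt_aa, pos, mut_aa = code[0], int(code[1:-1]), code[-1]
    idx = pos - 1  # mutation codes use 1-based positions

    # Lengths must match and the position must exist
    if len(mutant) != len(wild_type) or not 0 <= idx < len(wild_type):
        return False
    # The code must agree with both sequences at the mutated position
    if wild_type[idx] != wt_aa or mutant[idx] != mut_aa:
        return False
    # Every other position must be unchanged
    return wild_type[:idx] == mutant[:idx] and wild_type[idx + 1:] == mutant[idx + 1:]


toy_wt = "MEEPQ"  # toy sequence, not TP53
ok = check_substitution(toy_wt, "MEAPQ", "E3A")    # one change, matches the code
bad = check_substitution(toy_wt, "MEAPA", "E3A")   # a second, unexpected change
expected_count = 393 * 19                          # positions x alternative residues
print(ok, bad, expected_count)
```

In the notebook, the same helper could be applied to every entry with `all(check_substitution(TP53, seq, code) for code, seq in mutations.items())`.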


### Preparing the input
To compute an embedding for each protein, the sequences first have to be put into the format that ESM expects.

The output is a list of tuples, where each tuple is

 (mutation_name, sequence)



In [None]:
from typing import Dict, List, Tuple


def mutation_dict_to_esm_input(
    mutation_dict: Dict[str, str],
    protein_name: str = "TP53"
) -> List[Tuple[str, str]]:
    """
    Convert a mutation dictionary into ESM batch input format.

    Parameters
    ----------
    mutation_dict : dict
        Key   = mutation code (e.g. 'R175H')
        Value = mutated protein sequence

    protein_name : str
        Prefix for sequence names (e.g. 'TP53')

    Returns
    -------
    esm_sequences : list of tuples
        Format: [(name, sequence), ...]
        Example: [('TP53_R175H', 'MEEPQSDPSV...')]
    """

    esm_sequences = []

    for mutation_code, sequence in mutation_dict.items():
        # Create a unique sequence identifier for ESM
        seq_name = f"{protein_name}_{mutation_code}"

        # Append in ESM-required format
        esm_sequences.append((seq_name, sequence))

    return esm_sequences


In [None]:
mutations = generate_single_mutations(TP53)

esm_input = mutation_dict_to_esm_input(
    mutations,
    protein_name="TP53"
)

print(esm_input[0])
# ('TP53_M1A', 'MEEPQSDPSV...')


('TP53_M1A', 'AEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD')


In [None]:
wt_input = [("TP53_WT", TP53)]


## Set up the model and batch converter

In [None]:
!pip install fair-esm

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
Installing collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [None]:
import torch

In [None]:
device="cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [None]:
import torch

def extract_esm_embeddings_batch(
    model,
    batch_converter,
    sequences,
    repr_layer=5,
    batch_size=4,
    return_per_residue=True,
    device='cuda'
):
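    """
    sequences: list of (name, sequence) tuples
    repr_layer: index of the layer whose output is extracted; the default 5 is the
        second-to-last layer of the 6-layer esm2_t6_8M_UR50D (6 would be the final layer)
    returns:
        global_embeddings: torch.Tensor (N, D)
        residue_embeddings: list of torch.Tensor (L_i, D) if return_per_residue
    """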

    #device = next(model.parameters()).device
    global_embeddings = []
    residue_embeddings = []

    for i in range(0, len(sequences), batch_size):
        batch = sequences[i:i + batch_size]

        labels, strs, tokens = batch_converter(batch)
        tokens = tokens.to(device)

        with torch.no_grad():
            outputs = model(tokens, repr_layers=[repr_layer])
            reps = outputs["representations"][repr_layer]  # (B, L, D)

            # Mask out padding tokens only; the BOS (<cls>) and EOS tokens stay in the mask
            mask = tokens != model.padding_idx  # (B, L)

            # ----- GLOBAL MEAN POOLING -----
            # Note: the mean therefore also averages over the BOS and EOS positions
            mask_expanded = mask.unsqueeze(-1)  # (B, L, 1)
            pooled = (reps * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
            global_embeddings.append(pooled.cpu())

            # ----- PER-RESIDUE EMBEDDINGS -----
            if return_per_residue:
                for b in range(reps.size(0)):
                    # Non-padding tokens minus BOS and EOS = number of residues
                    seq_len = mask[b].sum().item() - 2
                    # Start at index 1 to skip the BOS (CLS) token at index 0
                    residue_embeddings.append(
                        reps[b, 1:seq_len + 1].cpu()
                    )

    global_embeddings = torch.cat(global_embeddings, dim=0)

    if return_per_residue:
        return global_embeddings, residue_embeddings
    else:
        return global_embeddings
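
The pooling step above averages over every non-padding token, which includes the BOS and EOS positions. If a residue-only average is preferred, the mask can exclude those special tokens as well. The following is a minimal NumPy sketch of that variant, not the notebook's method: `mean_pool_residues` is a hypothetical name, and the token ids `0`, `1`, `2` are used here as illustrative stand-ins for the BOS, padding and EOS ids (in the notebook they would be read from the alphabet object rather than hard-coded).

```python
import numpy as np


def mean_pool_residues(reps, tokens, special_ids):
    """
    Mean-pool per-token representations over real residues only.

    reps        : array (B, L, D) of per-token representations
    tokens      : int array (B, L) of token ids
    special_ids : ids to exclude from the average (e.g. BOS, EOS, padding)
    """
    keep = ~np.isin(tokens, list(special_ids))        # (B, L), True for residues
    weights = keep[..., None].astype(reps.dtype)      # (B, L, 1)
    return (reps * weights).sum(axis=1) / weights.sum(axis=1)


# Toy batch: two sequences, the second one is shorter and padded
BOS, PAD, EOS = 0, 1, 2                               # illustrative ids (assumption)
tokens = np.array([[BOS, 5, 6, EOS],
                   [BOS, 7, EOS, PAD]])
reps = np.arange(2 * 4 * 2, dtype=float).reshape(2, 4, 2)

pooled = mean_pool_residues(reps, tokens, special_ids=(BOS, PAD, EOS))
print(pooled)
```

The same masking logic carries over to PyTorch tensors by building the boolean mask from comparisons against each special id.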


In [None]:
import torch
import esm

model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # inference mode (disables dropout)

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t6_8M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t6_8M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t6_8M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t6_8M_UR50D-contact-regression.pt


ESM2(
  (embed_tokens): Embedding(33, 320, padding_idx=1)
  (layers): ModuleList(
    (0-5): 6 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=320, out_features=320, bias=True)
        (v_proj): Linear(in_features=320, out_features=320, bias=True)
        (q_proj): Linear(in_features=320, out_features=320, bias=True)
        (out_proj): Linear(in_features=320, out_features=320, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=320, out_features=1280, bias=True)
      (fc2): Linear(in_features=1280, out_features=320, bias=True)
      (final_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=120, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((320,), eps=1e-05, elementwis

### Computing global and local embeddings for wild-type TP53
The input is the wild-type TP53 sequence; the output is its global (mean-pooled) embedding and its local (per-residue) embeddings.

In [None]:
model.to(device)


In [None]:
wt_global_embedding, wt_local_embedding = extract_esm_embeddings_batch(
    model=model,
    batch_converter=batch_converter,
    sequences=wt_input,
    batch_size=1,
    return_per_residue=True,
    device=device
)  # global: tensor of shape (1, 320); local: list with one (393, 320) tensor


In [None]:
wt_global_embedding.shape

torch.Size([1, 320])

In [None]:
wt_local_embedding[0].shape

torch.Size([393, 320])

### Computing global and local embeddings for each mutant sequence

The input is every possible missense mutant; the output is the global and local embeddings of each mutant sequence.

In [None]:
esm_input[:1]

[('TP53_M1A',
  'AEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD')]

In [None]:
mutant_global_embeddings, mutant_local_embeddings = extract_esm_embeddings_batch(
    model=model,
    batch_converter=batch_converter,
    sequences=esm_input,
    batch_size=4,           # adjust based on GPU memory
    return_per_residue=True,
    device=device
)


In [None]:
print(len(mutant_local_embeddings))
print(mutant_global_embeddings.shape)

7467
torch.Size([7467, 320])


In [None]:
len(TP53)

393

In [None]:
mutant_global_embeddings.shape

torch.Size([7467, 320])

In [None]:
mutant_local_embeddings[0].shape

torch.Size([393, 320])

### Building a DataFrame for unsupervised learning

Here we collect the global embeddings of the wild-type and mutant sequences into a single DataFrame, so that we can run PCA on them later.

In [None]:
wt_embedding_np = wt_global_embedding.detach().cpu().numpy()


In [None]:
import pandas as pd
import numpy as np

# ----------------------------
# 1. Stack embeddings together
# ----------------------------

# Combine WT + mutants
all_embeddings = np.vstack([
    wt_embedding_np,                                  # shape (1, 320)
    mutant_global_embeddings.detach().cpu().numpy()   # shape (N, 320)
])

# ----------------------------
# 2. Build mutation labels
# ----------------------------

all_mutation_names = [
    "WT_TP53"
] + list(mutations.keys())  # dict order matches esm_input, and hence the embedding order

# ----------------------------
# 3. Create column names
# ----------------------------

embedding_dim = all_embeddings.shape[1]
embedding_columns = [f"emb_{i}" for i in range(embedding_dim)]

# ----------------------------
# 4. Create DataFrame
# ----------------------------

df_embeddings = pd.DataFrame(
    all_embeddings,
    columns=embedding_columns
)

# Insert mutation name as first column
df_embeddings.insert(0, "mutation", all_mutation_names)

# ----------------------------
# 5. Inspect
# ----------------------------

print(df_embeddings.head())
print(df_embeddings.shape)


  mutation     emb_0     emb_1     emb_2     emb_3     emb_4     emb_5  \
0  WT_TP53 -0.314137  0.043923  0.284459  0.099692  0.609362  0.080997   
1      M1A -0.333014  0.001066  0.345158  0.118458  0.593779  0.114718   
2      M1C -0.342941  0.007248  0.355775  0.145803  0.610450  0.123007   
3      M1D -0.349876 -0.002811  0.349717  0.114419  0.600457  0.117712   
4      M1E -0.340403 -0.020392  0.370213  0.116997  0.610020  0.116472   

      emb_6     emb_7     emb_8  ...   emb_310   emb_311   emb_312   emb_313  \
0  0.173163  0.167333  1.102550  ...  0.058029 -0.203137  0.406664  1.190339   
1  0.185727  0.176038  1.103571  ...  0.077118 -0.172380  0.426602  1.178421   
2  0.170419  0.180405  1.134820  ...  0.080565 -0.177388  0.435609  1.171222   
3  0.192346  0.174532  1.101928  ...  0.074074 -0.162665  0.419469  1.174409   
4  0.196395  0.165416  1.102782  ...  0.090378 -0.153890  0.418086  1.171690   

    emb_314   emb_315   emb_316   emb_317   emb_318   emb_319  
0 -0.18194

### PCA
To visualize the embeddings, we reduce their dimensionality from 320 to 2 using PCA.

In [None]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# -----------------------------------
# 1. Prepare features
# -----------------------------------

X = df_embeddings.drop(columns=["mutation"]).values

# Standardize (important for PCA)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# -----------------------------------
# 2. PCA to 2D
# -----------------------------------

pca = PCA(n_components=2, random_state=42)
X_pca = pca.fit_transform(X_scaled)
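
Before reading structure into a 2D projection, it helps to know how much of the total variance the two components retain. scikit-learn exposes this on a fitted model as `pca.explained_variance_ratio_`. The sketch below computes the same quantity with NumPy's SVD on random stand-in data, so the numbers it prints say nothing about the TP53 embeddings; `explained_variance_ratio` and `X_demo` are hypothetical names.

```python
import numpy as np


def explained_variance_ratio(X, n_components=2):
    """Fraction of total variance captured by each of the leading principal components."""
    X_centered = X - X.mean(axis=0)
    # Singular values of the centered data; their squares are proportional
    # to the variance along each principal axis
    singular_values = np.linalg.svd(X_centered, compute_uv=False)
    variances = singular_values ** 2
    return variances[:n_components] / variances.sum()


# Random stand-in data with one deliberately dominant direction
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 10))
X_demo[:, 0] *= 5.0

ratio = explained_variance_ratio(X_demo, n_components=2)
print(ratio, ratio.sum())
```

If the two components together explain only a small share of the variance, apparent overlap or separation in the scatter plot should be interpreted with caution.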



### Getting known variant data from real tumor patients

In [None]:
import pandas as pd


In [None]:
tumor_variant=pd.read_csv("TumorVariantDownload_r21.csv")

In [None]:
tumor_variant.head()

Unnamed: 0,Mutation_ID,MUT_ID,hg18_Chr17_coordinates,hg19_Chr17_coordinates,hg38_Chr17_coordinates,ExonIntron,Codon_number,Description,c_description,g_description,...,SpliceAI_DP_AG,SpliceAI_DP_AL,SpliceAI_DP_DG,SpliceAI_DP_DL,Ref_ID,Cross_Ref_ID,PubMed,Exclude_analysis,WGS_WXS,PubMedLink
0,1,1590,7519227,7578502,7675184,5-exon,143,T>C,c.428T>C,g.7578502A>G,...,884.0,-230.0,1010.0,-330.0,1,,2649981.0,False,No,https://www.ncbi.nlm.nih.gov/pubmed/0002649981
1,2,2143,7519131,7578406,7675088,5-exon,175,G>A,c.524G>A,g.7578406C>T,...,-117.0,-134.0,11.0,248.0,1,,2649981.0,False,No,https://www.ncbi.nlm.nih.gov/pubmed/0002649981
2,3,1407,7519261,7578536,7675218,5-exon,132,A>C,c.394A>C,g.7578536T>G,...,-3.0,18.0,-364.0,118.0,2,,1694291.0,False,No,https://www.ncbi.nlm.nih.gov/pubmed/0001694291
3,4,3932,7517810,7577085,7673767,8-exon,285,G>A,c.853G>A,g.7577085C>T,...,-428.0,46.0,-560.0,-66.0,2,,1694291.0,False,No,https://www.ncbi.nlm.nih.gov/pubmed/0001694291
4,5,3326,7518259,7577534,7674216,7-exon,249,G>C,c.747G>C,g.7577534C>G,...,-877.0,27.0,1669.0,-993.0,2,,1694291.0,False,No,https://www.ncbi.nlm.nih.gov/pubmed/0001694291


In [None]:
def clean_tumor_df(df_tumor):
    """
    Cleans the tumor dataset so ProtDescription matches AAchange in functional dataset
    """
    df = df_tumor.copy()

    # Remove whitespace
    df['ProtDescription'] = df['ProtDescription'].str.strip()

    # Remove 'p.' prefix if present
    df['ProtDescription'] = df['ProtDescription'].apply(lambda x: x.split('.')[-1] if isinstance(x, str) else x)

    return df


In [None]:
df_tumor_clean = clean_tumor_df(tumor_variant)


In [None]:
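import re

def get_driver_missense_mutations(driver_list_set):
    """Keep only codes of the form <letter><position><letter>, e.g. 'R175H'."""
    # Drops frameshifts, deletions, etc.; residues are not checked against the TP53 sequence
    driver_missense_mutations = [m for m in driver_list_set if re.match(r'^[A-Z]\d+[A-Z]$', m)]

    return driver_missense_mutations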
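
The regular expression only checks the shape of the code, so it would also accept entries such as a synonymous change or a code whose reference residue does not match the sequence. A stricter filter could additionally validate each code against the wild-type sequence. The following is a sketch under that assumption; `filter_valid_missense` is a hypothetical helper, and the toy sequence and codes are made up for illustration.

```python
import re

AA20 = set("ACDEFGHIKLMNPQRSTVWY")                 # standard 20 amino acids
MISSENSE_RE = re.compile(r"^([A-Z])(\d+)([A-Z])$")


def filter_valid_missense(codes, wild_type):
    """Keep codes that are true single-residue substitutions of `wild_type`."""
    valid = []
    for code in codes:
        match = MISSENSE_RE.match(code)
        if match is None:
            continue                               # not of the form A123B
        wt_aa, pos, mut_aa = match.group(1), int(match.group(2)), match.group(3)
        if wt_aa not in AA20 or mut_aa not in AA20:
            continue                               # non-standard residue letter
        if wt_aa == mut_aa:
            continue                               # synonymous, not missense
        if not 1 <= pos <= len(wild_type):
            continue                               # position outside the sequence
        if wild_type[pos - 1] != wt_aa:
            continue                               # reference residue mismatch
        valid.append(code)
    return valid


toy_wt = "MEEPQ"  # toy sequence, not TP53
toy_codes = ["E2K", "E2E", "P4X", "Q9A", "K3R", "R175fs", "M1A"]
kept = filter_valid_missense(toy_codes, toy_wt)
print(kept)
```

Codes that pass this filter are guaranteed to be keys of the `mutations` dictionary when `wild_type` is the same sequence used to generate it.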


## PCA on the functional and non-functional datasets

In [None]:
# Separate functional and non-functional variants
df_functional = df_tumor_clean[df_tumor_clean['StructureFunctionClass'] == 'functional']
df_nonfunctional = df_tumor_clean[df_tumor_clean['StructureFunctionClass'] == 'non-functional']


In [None]:
functional_mutations = df_functional['ProtDescription'].dropna().tolist()
nonfunctional_mutations = df_nonfunctional['ProtDescription'].dropna().tolist()


In [None]:
functional_mutations[:10]

['M246V',
 'M246V',
 'R175S',
 'A138V',
 'A129D',
 'N239D',
 'Q144P',
 'P177R',
 'S260A',
 'D259G']

In [None]:
nonfunctional_mutations[:10]

['V143A',
 'R175H',
 'K132Q',
 'E285K',
 'R249S',
 'R280K',
 'R249S',
 'R249S',
 'R249S',
 'V157F']

In [None]:
functional_mutations_set = set(functional_mutations)
nonfunctional_mutations_set = set(nonfunctional_mutations)

In [None]:
common_set=functional_mutations_set.intersection(nonfunctional_mutations_set)
common_set

set()

The intersection is empty: no mutation is labelled both functional and non-functional in this dataset. Each variant therefore carries a single class, regardless of how many patients it was observed in.
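
The same consistency check can be expressed directly on (mutation, label) pairs, which also reports which labels clash when a conflict exists. This is a minimal sketch with made-up toy records; `find_label_conflicts` is a hypothetical helper, and `X1Y` is an invented code used only to trigger a conflict.

```python
from collections import defaultdict


def find_label_conflicts(records):
    """Return {mutation: labels} for mutations that appear with more than one label."""
    labels_by_mutation = defaultdict(set)
    for mutation, label in records:
        labels_by_mutation[mutation].add(label)
    return {m: labels for m, labels in labels_by_mutation.items() if len(labels) > 1}


# Toy records: repeated observations are fine, contradictory labels are not
toy_records = [
    ("R249S", "non-functional"),
    ("R249S", "non-functional"),
    ("M246V", "functional"),
    ("X1Y", "functional"),        # invented code with contradictory labels
    ("X1Y", "non-functional"),
]
conflicts = find_label_conflicts(toy_records)
print(conflicts)
```

In the notebook, the pairs could come from `zip(df_tumor_clean['ProtDescription'], df_tumor_clean['StructureFunctionClass'])`; an empty result corresponds to the empty intersection seen above.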

In [None]:
functional_missense = get_driver_missense_mutations(functional_mutations_set)
nonfunctional_missense = get_driver_missense_mutations(nonfunctional_mutations_set)

print("Functional missense mutations:", len(functional_missense))
print("Non-functional missense mutations:", len(nonfunctional_missense))


Functional missense mutations: 419
Non-functional missense mutations: 554


Above, we separated the tumor dataset into functional and non-functional mutations.
Next, we define two more groups according to whether a mutation lies inside the DNA-binding domain or not.

In [None]:
# Split mutations by position: inside vs. outside the DNA-binding domain
mutation_list = list(mutations.keys())
position_dna_binding = []
non_position_dna_binding = []
for mutation in mutation_list:
    pos = int(mutation[1:-1])   # e.g. 'R175H' -> 175
    if 102 <= pos <= 292:       # residues 102-292 are treated as the DNA-binding domain
        position_dna_binding.append(mutation)
    else:
        non_position_dna_binding.append(mutation)
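
The domain split is independent of the functional class, so one option is to store the two as separate fields rather than as competing values of a single label. The sketch below illustrates that idea; `in_dna_binding_domain` and `annotate` are hypothetical helpers, and the label sets passed in are small illustrative stand-ins for the tumor annotation.

```python
DBD_START, DBD_END = 102, 292  # same boundaries as the position check above


def in_dna_binding_domain(mutation_code: str) -> bool:
    """True if the 1-based position in a code like 'R175H' lies inside the domain."""
    position = int(mutation_code[1:-1])
    return DBD_START <= position <= DBD_END


def annotate(mutation_codes, functional, nonfunctional):
    """Return one record per mutation with separate 'type' and 'domain' fields."""
    records = []
    for code in mutation_codes:
        if code in nonfunctional:
            mutation_type = "Non-functional missense"
        elif code in functional:
            mutation_type = "Functional missense"
        else:
            mutation_type = "Other mutations"
        domain = "DNA-binding" if in_dna_binding_domain(code) else "Outside DNA-binding"
        records.append({"mutation": code, "type": mutation_type, "domain": domain})
    return records


# Toy input; the label sets are illustrative, not the full annotation
records = annotate(
    ["M1A", "R175H", "M246V"],
    functional={"M246V"},
    nonfunctional={"R175H"},
)
for record in records:
    print(record)
```

With two columns, a scatter plot can encode the functional class as color and the domain as marker symbol (for example through the `symbol` argument of `plotly.express.scatter`), so neither label hides the other.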

In [57]:
import plotly.express as px
import pandas as pd

# -----------------------------------
# Prepare plotting DataFrame
# -----------------------------------

plot_df = pd.DataFrame({
    "PC1": X_pca[:, 0],
    "PC2": X_pca[:, 1],
    "mutation": df_embeddings["mutation"]
})

# Convert mutation lists to sets for fast lookup
functional_set = set(functional_missense)
nonfunctional_set = set(nonfunctional_missense)

# Default label
plot_df["type"] = "Other mutations"

# Assign classes (later assignments overwrite earlier ones)
plot_df.loc[plot_df["mutation"].isin(functional_set), "type"] = "Functional missense"
plot_df.loc[plot_df["mutation"].isin(nonfunctional_set), "type"] = "Non-functional missense"
plot_df.loc[plot_df["mutation"] == "WT_TP53", "type"] = "WT TP53"
# Assigned last: every mutation outside the DNA-binding domain gets this label,
# including any that were marked functional or non-functional above
plot_df.loc[plot_df["mutation"].isin(non_position_dna_binding), "type"] = "Non-DNA binding position"

# -----------------------------------
# Plot
# -----------------------------------

fig = px.scatter(
    plot_df,
    x="PC1",
    y="PC2",
    color="type",
    hover_name="mutation",
    color_discrete_map={
        "Other mutations": "lightgray",
        "Functional missense": "green",
        "Non-functional missense": "orange",
        "WT TP53": "red",
        "Non-DNA binding position": "blue"
    },
    title="PCA of TP53 Mutant ESM Embeddings (Tumor Structure–Function Annotation)"
)

# Make WT bigger and star-shaped
fig.update_traces(
    selector=dict(name="WT TP53"),
    marker=dict(size=16, symbol="star")
)

fig.update_layout(
    width=900,
    height=650,
    legend_title_text="Mutation class",
    xaxis_title="PC1",
    yaxis_title="PC2"
)

fig.show()


### Final key insights

Key insights from the TP53 mutation PCA analysis:

Initial intuition: One would expect non-functional mutations to lie far from wild-type TP53 in embedding space and functional mutations to lie close to it.

Observation: Most mutations found in the cancer dataset, whether functional or non-functional, cluster close to wild-type TP53.

PC1 vs PC2 hypothesis (domain-specific variation):

Mutations in the DNA-binding domain vary mainly along PC1, with little change in PC2.

Mutations outside the DNA-binding domain vary more along PC2, with PC1 relatively stable.
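
The visual impression that both classes sit close to the wild type could be backed by a number, for example the Euclidean distance from each mutant's global embedding to the wild-type embedding, summarized per class. The sketch below shows the computation on random stand-in vectors, so its output carries no information about TP53; `distance_to_reference` and `summarize_by_class` are hypothetical helpers.

```python
import numpy as np


def distance_to_reference(embeddings, reference):
    """Euclidean distance from each row of `embeddings` (N, D) to `reference` (D,)."""
    return np.linalg.norm(embeddings - reference, axis=1)


def summarize_by_class(distances, labels):
    """Median distance per class label."""
    labels = np.asarray(labels)
    return {
        str(name): float(np.median(distances[labels == name]))
        for name in np.unique(labels)
    }


# Random stand-in data: a reference vector and six nearby points
rng = np.random.default_rng(0)
wt_vector = rng.normal(size=8)
mutant_vectors = wt_vector + rng.normal(scale=0.1, size=(6, 8))
labels = ["functional"] * 3 + ["non-functional"] * 3

distances = distance_to_reference(mutant_vectors, wt_vector)
summary = summarize_by_class(distances, labels)
print(summary)
```

Applied to the notebook's data, `wt_embedding_np[0]` and the mutant rows of `all_embeddings` would take the place of the random vectors; similar medians for the two classes would support the overlap seen in the PCA plot, using all 320 dimensions rather than two.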
