<a href="https://colab.research.google.com/github/shilpasy/Variant-Prioritization-miniproject-with-AlphaMissense-cBioPortal-and-ESM-models/blob/main/3_ESM_StructuralFeatureAugmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install fair-esm biopython

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Collecting biopython
  Downloading biopython-1.85-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading biopython-1.85-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m48.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm, biopython
Successfully installed biopython-1.85 fair-esm-2.0.0


In [None]:
from google.colab import drive
import pandas as pd
import gzip
import os

# Mount Google Drive at /content/drive; the input CSV and the output CSV used
# later in this notebook both live under /content/drive/MyDrive.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from esm import pretrained, FastaBatchedDataset, ProteinBertModel

In [None]:
# ESM2 embeddings for TP53 variants

import math
import re
from typing import Optional

import esm
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

# all paths:
CSV_PATH = "/content/drive/MyDrive/AlphaMissense_ex/TP53_cbioportal_mutations_annotated_with_AlphaMissense.csv"  # merged file
OUTPUT_CSV = "/content/drive/MyDrive/AlphaMissense_ex/TP53_ESM2_embeddings.csv"
USE_CLASSES = ["pathogenic"]  # focusing only on pathogenic class for now
BATCH_SIZE = 64  # testing with 64

# Canonical human TP53 (UniProt P04637, isoform 1): 393 residues, written in
# rows of 60 so that residue N sits at 0-based index N - 1.
# The previous literal was only ~315 residues long and diverged from the
# UniProt record after residue 121 (it contained a duplicated segment), so every
# mutation beyond that point was introduced into the wrong sequence context.
# Swap in any other sequence of interest if needed, and update the length guard.
tp53_seq = (
    "MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGP"
    "DEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAK"
    "SVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHE"
    "RCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNS"
    "SCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELP"
    "PGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPG"
    "GSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD"
)
SEQ_LEN = len(tp53_seq)

# Fail fast if the reference is ever truncated or mangled again: a wrong
# reference silently corrupts every downstream embedding.
if SEQ_LEN != 393:
    raise ValueError(f"tp53_seq has {SEQ_LEN} residues; UniProt P04637 has 393")

df = pd.read_csv(CSV_PATH)

# a. ensure every column used below exists *before* any of them is accessed,
#    so a malformed file fails with one clear message instead of a KeyError
required_cols = ["Variant", "WT_AA", "Position", "Mut_AA", "AlphaMissense_Class"]
missing = [c for c in required_cols if c not in df.columns]
if missing:
    raise ValueError(f"Missing required columns: {missing}")

# b. filter by AlphaMissense class (pathogenic here)
df = df[df["AlphaMissense_Class"].isin(USE_CLASSES)].copy()
rows_after_class_filter = int(df.shape[0])

# c. coerce Position to numeric (the old .isdigit() check dropped 175.0 etc.);
#    anything unparseable becomes NaN and is removed in step e
df["Position"] = pd.to_numeric(df["Position"], errors="coerce")

# d. validate Mut_AA: single AA letter in the 20 standard set
aa_re = re.compile(r"^[ACDEFGHIKLMNPQRSTVWY]$")
df = df[df["Mut_AA"].astype(str).str.fullmatch(aa_re)].copy()

# e. keep whole-number positions that fall inside the reference sequence
#    (NaN % 1 is NaN, so missing positions fail the comparison as well;
#    fractional positions are rejected rather than silently truncated)
df = df[df["Position"].notna() & (df["Position"] % 1 == 0)].copy()
df["Position"] = df["Position"].astype(int)
df = df[df["Position"].between(1, SEQ_LEN)].copy()

# f. drop variants whose annotated wild-type residue disagrees with the
#    reference sequence: mutating a position whose WT does not match would
#    embed a variant that does not exist on this sequence. Rows whose WT_AA is
#    not a standard single-letter code cannot be checked and are kept.
ref_aa = df["Position"].map(lambda pos: tp53_seq[pos - 1])
wt_aa = df["WT_AA"].astype(str)
wt_mismatch = (wt_aa.str.fullmatch(aa_re) & (wt_aa != ref_aa)).astype(bool)
if wt_mismatch.any():
    mismatch_examples = df.loc[wt_mismatch, "Variant"].astype(str).head(5).tolist()
    print(
        f"WARNING: dropping {int(wt_mismatch.sum())} row(s) whose WT_AA does not "
        f"match tp53_seq at Position, e.g. {mismatch_examples}"
    )
df = df[~wt_mismatch].copy()

# g. deduplicate variants (Position, Mut_AA)
df = df.drop_duplicates(subset=["Position", "Mut_AA"]).copy()

print("Counts after cleaning:")
print({
    "rows_after_class_filter": rows_after_class_filter,
    "rows_after_cleaning": int(df.shape[0]),
    "unique_variants": int(df[["Position", "Mut_AA"]].drop_duplicates().shape[0])
})

Counts after cleaning:
{'rows_after_class_filter': 37, 'unique_variants': 37}


In [None]:
# ++++ ============ BUILD MUTATED SEQUENCES ====================================================
def mutate_seq(seq: str, pos: int, new_aa: str) -> Optional[str]:
    """Return a copy of ``seq`` with the residue at 1-based ``pos`` replaced.

    Args:
        seq: Reference sequence.
        pos: 1-based position of the residue to substitute.
        new_aa: Replacement text inserted in place of that residue.

    Returns:
        The substituted sequence, or ``None`` when ``pos`` lies outside
        ``1..len(seq)`` (which includes every position of an empty sequence).
    """
    index = pos - 1  # convert to 0-based
    # Reject out-of-range positions explicitly; a negative index would
    # otherwise be interpreted by Python as counting from the end.
    if not 0 <= index < len(seq):
        return None
    return seq[:index] + new_aa + seq[index + 1:]

# Build one mutated sequence per variant from the Position / Mut_AA columns.
# A column-wise comprehension is used instead of df.apply(axis=1): on an empty
# frame apply() hands back a DataFrame rather than a Series, which makes the
# column assignment fail, and the comprehension also avoids constructing a
# Series object for every row.
df["Mutated_Seq"] = [
    mutate_seq(tp53_seq, int(pos), mut_aa)
    for pos, mut_aa in zip(df["Position"], df["Mut_AA"])
]
# mutate_seq returns None for out-of-range positions; drop those rows.
df = df[df["Mutated_Seq"].notna()].copy()

# Label for a variant row; rebuilt from its parts when Variant is unusable.
def make_label(row):
    """Return a ``p.``-prefixed label for one variant row.

    ``row`` must support ``.get`` and ``[]`` lookups. A non-empty string in
    ``Variant`` is used as the label, gaining a ``p.`` prefix if it lacks one.
    Otherwise the label is assembled from ``WT_AA``, ``Position`` and
    ``Mut_AA``.
    """
    variant = row.get("Variant")
    if not isinstance(variant, str) or variant == "":
        return "p.{}{}{}".format(row["WT_AA"], row["Position"], row["Mut_AA"])
    if variant.startswith("p."):
        return variant
    return "p." + variant

# Attach a display label to every variant row.
df["Label"] = df.apply(make_label, axis=1)

# Quick look at what is about to be embedded.
print("Example variants to embed:", df["Label"].iloc[:5].tolist())
print("Total variants to embed:", len(df))

# ++++ ====== LOAD ESM2 MODEL =========================================================
# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"using device: {device}")

# alternative: switch to esm2_t6_8M_UR50D() for speed if needed and if GPU isn't available
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model = model.to(device)
model = model.eval()  # inference mode
batch_converter = alphabet.get_batch_converter()

# ++++ ====== EMBEDDING IN BATCHES ===============================================================================
# The wild-type sequence goes first, followed by every mutated sequence.
labels = ["WT", *df["Label"].tolist()]
seqs = [tp53_seq, *df["Mutated_Seq"].tolist()]

emb_rows = []
total = len(seqs)
num_batches = math.ceil(total / BATCH_SIZE)

with torch.no_grad():
    pbar = tqdm(range(num_batches), desc="Embedding batches")
    for batch_idx in pbar:
        lo = batch_idx * BATCH_SIZE
        hi = min(lo + BATCH_SIZE, total)
        data_batch = [(labels[k], seqs[k]) for k in range(lo, hi)]
        batch_labels, batch_strs, batch_tokens = batch_converter(data_batch)

        output = model(batch_tokens.to(device), repr_layers=[33], return_contacts=False)
        reps = output["representations"][33]  # layer-33 representations, one row of token vectors per sequence

        for row_idx, (label, seq_str) in enumerate(zip(batch_labels, batch_strs)):
            n_res = len(seq_str)
            # Mean-pool over token slots 1..n_res only, i.e. skip the special
            # token in slot 0 and anything after the last residue.
            pooled = reps[row_idx, 1:n_res + 1].mean(0).detach().cpu().numpy()
            record = {"Variant": label}
            record.update((f"ESM2_{j}", value) for j, value in enumerate(pooled))
            emb_rows.append(record)

# One row per sequence: the label plus one column per embedding dimension.
emb_df = pd.DataFrame(emb_rows)

# QC
print("Embedding dataframe shape:", emb_df.shape)
print("First rows:", emb_df.iloc[:3])

emb_df.to_csv(OUTPUT_CSV, index=False)
print(f"✅ Saved embeddings to {OUTPUT_CSV}")

Example variants to embed: ['p.F113S', 'p.R273H', 'p.R273C', 'p.F109S', 'p.R273P']
Total variants to embed: 37
using device: cuda
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


Embedding batches: 100%|██████████| 1/1 [00:05<00:00,  5.51s/it]


Embedding dataframe shape: (38, 1281)
First rows:    Variant    ESM2_0    ESM2_1    ESM2_2    ESM2_3    ESM2_4    ESM2_5  \
0       WT  0.117823  0.033276  0.011216  0.063452 -0.001315 -0.046198   
1  p.F113S  0.119970  0.035833  0.012323  0.062180 -0.001641 -0.047405   
2  p.R273H  0.114494  0.033331  0.010980  0.056251 -0.003338 -0.047012   

     ESM2_6    ESM2_7    ESM2_8  ...  ESM2_1270  ESM2_1271  ESM2_1272  \
0  0.118452  0.017319 -0.026738  ...   0.003391   0.035751  -0.125065   
1  0.119121  0.017385 -0.027567  ...   0.002429   0.037539  -0.127980   
2  0.114071  0.025882 -0.029230  ...   0.007250   0.039530  -0.123865   

   ESM2_1273  ESM2_1274  ESM2_1275  ESM2_1276  ESM2_1277  ESM2_1278  ESM2_1279  
0   0.108309   0.056459  -0.070490   0.077801  -0.194841   0.017830   0.121030  
1   0.109221   0.055262  -0.068929   0.071980  -0.197760   0.013072   0.119775  
2   0.111462   0.047580  -0.060999   0.080080  -0.191880   0.017989   0.114437  

[3 rows x 1281 columns]
✅ Saved emb

In [None]:
# Preview only the first rows; the frame has one column per embedding
# dimension, so rendering it whole adds bulk without adding information.
emb_df.head(10)

Unnamed: 0,Variant,ESM2_0,ESM2_1,ESM2_2,ESM2_3,ESM2_4,ESM2_5,ESM2_6,ESM2_7,ESM2_8,...,ESM2_1270,ESM2_1271,ESM2_1272,ESM2_1273,ESM2_1274,ESM2_1275,ESM2_1276,ESM2_1277,ESM2_1278,ESM2_1279
0,WT,0.117823,0.033276,0.011216,0.063452,-0.001315,-0.046198,0.118452,0.017319,-0.026738,...,0.003391,0.035751,-0.125065,0.108309,0.056459,-0.07049,0.077801,-0.194841,0.01783,0.12103
1,p.F113S,0.11997,0.035833,0.012323,0.06218,-0.001641,-0.047405,0.119121,0.017385,-0.027567,...,0.002429,0.037539,-0.12798,0.109221,0.055262,-0.068929,0.07198,-0.19776,0.013072,0.119775
2,p.R273H,0.114494,0.033331,0.01098,0.056251,-0.003338,-0.047012,0.114071,0.025882,-0.02923,...,0.00725,0.03953,-0.123865,0.111462,0.04758,-0.060999,0.08008,-0.19188,0.017989,0.114437
3,p.R273C,0.119631,0.038121,0.017878,0.061448,-0.00541,-0.041818,0.11617,0.029576,-0.027555,...,0.0045,0.03745,-0.116961,0.119463,0.050261,-0.060435,0.075729,-0.194225,0.013341,0.119889
4,p.F109S,0.123277,0.037602,0.012639,0.062221,-0.004402,-0.044994,0.118047,0.020354,-0.028258,...,0.004092,0.036882,-0.129564,0.109863,0.055471,-0.064717,0.067949,-0.1952,0.007696,0.124223
5,p.R273P,0.114667,0.035666,0.007555,0.054972,0.00134,-0.044975,0.114528,0.03823,-0.032562,...,0.013517,0.044264,-0.11054,0.113729,0.05221,-0.066808,0.087389,-0.191409,0.02352,0.11464
6,p.R273S,0.116318,0.036407,0.01424,0.060188,-0.000339,-0.042378,0.117059,0.028503,-0.030541,...,0.013499,0.042539,-0.113416,0.114016,0.048483,-0.066858,0.086593,-0.188832,0.019653,0.116575
7,p.F113V,0.118094,0.033001,0.011668,0.060936,-0.002434,-0.047579,0.117826,0.014895,-0.029416,...,0.004203,0.035511,-0.125446,0.107346,0.055783,-0.070857,0.077019,-0.196976,0.019702,0.123557
8,p.G105D,0.119236,0.035956,0.01222,0.065286,0.000161,-0.044589,0.117668,0.018728,-0.030071,...,0.005944,0.036222,-0.125381,0.107752,0.056022,-0.072128,0.076622,-0.196094,0.017385,0.121666
9,p.R110P,0.122062,0.036119,0.010729,0.062678,-0.00185,-0.044185,0.117687,0.022739,-0.025019,...,0.003435,0.040866,-0.127251,0.109604,0.056325,-0.066895,0.075945,-0.196724,0.011202,0.122828


In [None]:
# Reload the merged CSV and look at the class balance before any filtering.
# NOTE(review): this rebinds `df`, replacing the cleaned frame (with the
# Mutated_Seq and Label columns) built earlier in the notebook; re-run the
# cleaning cell before re-running the embedding cell.
df = pd.read_csv(CSV_PATH)
df["AlphaMissense_Class"].value_counts()

Unnamed: 0_level_0,count
AlphaMissense_Class,Unnamed: 1_level_1
pathogenic,403
benign,39
ambiguous,1


In [None]:
# Keep only rows whose AlphaMissense class is listed in USE_CLASSES.
df = df[df["AlphaMissense_Class"].isin(USE_CLASSES)].copy()

In [None]:
# Inspect the Position column of the class-filtered frame; in the saved output
# most of the displayed entries are blank (missing).
df["Position"]

Unnamed: 0,Position
0,
1,
2,
3,
4,
...,...
3149,
3150,59.0
3151,
3152,


In [None]:
# Check which classes remain after the USE_CLASSES filter.
df["AlphaMissense_Class"].value_counts()

Unnamed: 0_level_0,count
AlphaMissense_Class,Unnamed: 1_level_1
pathogenic,403
