In [None]:
from transformers import AutoTokenizer, EsmModel
import torch
import pandas as pd
import esm
from tqdm import tqdm
from joblib import Memory
from functools import lru_cache

# On-disk joblib cache directory for memoising expensive calls.
# NOTE(review): `memory` is not referenced in the visible cells — confirm it is still needed.
memory = Memory(location='../data/colab_cache', verbose=False)

# Register `progress_apply` on pandas objects so long applies show a progress bar.
tqdm.pandas()

In [None]:
# Variant training table: one row per VARIANTKEY with LABEL, gene, AA_POSITION
# and the REF/ALT protein sequences (tab-separated, gzip-compressed).
df = pd.read_csv(
    "../data/training_set_20230316.tsv.gz",
    sep="\t",
)
df.head()

Unnamed: 0,VARIANTKEY,LABEL,ENSG,GENE_SYMBOL,AA_POSITION,PROTEIN_REF,PROTEIN_ALT
0,1-100196274-A-C,LOF,ENSG00000137992,DBT,477,MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...,MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...
1,1-100196286-T-C,NEUTRAL,ENSG00000137992,DBT,473,MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...,MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...
2,1-100196349-T-C,LOF,ENSG00000137992,DBT,452,MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...,MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...
3,1-100206470-G-A,LOF,ENSG00000137992,DBT,395,MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...,MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...
4,1-100206621-C-T,LOF,ENSG00000137992,DBT,345,MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...,MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...


In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [None]:
# Pre-trained ESM-1v checkpoint from fair-esm; `alphabet` provides the batch
# converter used by the ESM-1v embedding helpers.
esm_1v, alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1()
# eval() and to() both return the module itself, so this leaves `esm_1v` in
# inference mode on `device` and echoes the module repr as the cell output.
esm_1v.eval().to(device)



ProteinBertModel(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (embed_positions): LearnedPositionalEmbedding(1026, 1280, padding_idx=1)


In [None]:
# The model repr shows a positional table of 1026 entries with padding_idx=1
# (1024 usable positions); the batch converter's two special tokens (stripped
# by the `1:-1` slice below) leave room for at most 1022 residues.
ESM1V_MAX_RESIDUES = 1022


def get_esm1v_window_embedding(protein_sequence, aa_position):
    """Embed the ESM-1v-sized window of `protein_sequence` around `aa_position`.

    Args:
        protein_sequence: Amino-acid sequence, one letter per residue.
        aa_position: Variant position, used as the centre of the window.

    Returns:
        1-D numpy array: layer-33 ESM-1v token representations averaged over
        the residues of the window.
    """
    window = get_protein_window(protein_sequence, aa_position, ESM1V_MAX_RESIDUES)
    return get_esm1v_embedding(window, aa_position)


def get_protein_window(protein_sequence, aa_position, window_size):
    """Return a slice of at most `window_size` residues that contains `aa_position`.

    Sequences that already fit are returned unchanged (no context is thrown
    away).  Longer sequences are cut to exactly `window_size` residues centred
    on `aa_position`, with the window shifted inward at either terminus so it
    never runs off the ends.

    Args:
        protein_sequence: Amino-acid sequence.
        aa_position: Position to centre the window on.
        window_size: Maximum number of residues to return.

    Raises:
        ValueError: if `window_size` is not positive.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")
    seq_len = len(protein_sequence)
    if seq_len <= window_size:
        return protein_sequence
    # NOTE(review): aa_position is used as-is for the centre. Whether it is
    # 0- or 1-based only shifts the window by one residue; the clamp below
    # keeps the position inside the window either way.
    start = int(aa_position) - window_size // 2
    # Clamp into [0, seq_len - window_size] so the window always has full length.
    start = max(0, min(start, seq_len - window_size))
    return protein_sequence[start:start + window_size]


@lru_cache(maxsize=10000)
def _esm1v_mean_embedding(protein_sequence):
    """Run ESM-1v on one sequence and mean-pool layer 33 over its residues.

    Cached per sequence because the same REF sequence recurs across many
    variant rows. The cached numpy array is shared between callers, so do not
    mutate it in place.
    """
    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("protein", protein_sequence)])
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
        results = esm_1v(batch_tokens, repr_layers=[33])
    token_representations = results["representations"][33]
    # Drop the leading and trailing special tokens so only residues are averaged.
    return token_representations[0, 1:-1].mean(0).cpu().numpy()


def get_esm1v_embedding(protein_sequence, aa_position):
    """Mean-pooled layer-33 ESM-1v embedding of `protein_sequence`.

    `aa_position` is kept for interface compatibility; it does not influence
    the result (the whole given sequence is pooled).
    """
    return _esm1v_mean_embedding(protein_sequence)

# Window-based ESM-1v embeddings for the REF and ALT protein sequences.
df["REF_EMBEDDING_ESM1v"] = df.progress_apply(
    lambda row: get_esm1v_window_embedding(row["PROTEIN_REF"], row["AA_POSITION"]),
    axis=1,
)
df["ALT_EMBEDDING_ESM1v"] = df.progress_apply(
    lambda row: get_esm1v_window_embedding(row["PROTEIN_ALT"], row["AA_POSITION"]),
    axis=1,
)

# Invariant: every REF/ALT embedding has the same width.
esm1v_widths = set(df["REF_EMBEDDING_ESM1v"].map(len)) | set(df["ALT_EMBEDDING_ESM1v"].map(len))
assert len(esm1v_widths) <= 1, f"inconsistent ESM-1v embedding widths: {esm1v_widths}"

# Rich display of the updated frame.
df.head()

100%|██████████| 112437/112437 [2:53:05<00:00, 10.83it/s]
100%|██████████| 112437/112437 [2:53:11<00:00, 10.82it/s]

        VARIANTKEY    LABEL             ENSG GENE_SYMBOL  AA_POSITION  \
0  1-100196274-A-C      LOF  ENSG00000137992         DBT          477   
1  1-100196286-T-C  NEUTRAL  ENSG00000137992         DBT          473   
2  1-100196349-T-C      LOF  ENSG00000137992         DBT          452   
3  1-100206470-G-A      LOF  ENSG00000137992         DBT          395   
4  1-100206621-C-T      LOF  ENSG00000137992         DBT          345   

                                         PROTEIN_REF  \
0  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   
1  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   
2  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   
3  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   
4  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   

                                         PROTEIN_ALT  \
0  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   
1  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   
2  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   





In [None]:
# Snapshot the frame once the ESM-1v embedding columns have been added.
pd.to_pickle(df, "../data/esm_v1_embeddings.pkl")

In [None]:
# Hugging Face ESM-2 checkpoint (33 layers, 650M parameters).
MODEL = "facebook/esm2_t33_650M_UR50D"
esm2_tokenizer = AutoTokenizer.from_pretrained(MODEL)
# Only hidden states are used downstream, so skip the pooler head: the
# checkpoint has no weights for it and it would otherwise be randomly initialised.
esm2 = EsmModel.from_pretrained(MODEL, add_pooling_layer=False)
# Inference only: make eval mode explicit and move to the selected device;
# the trailing expression echoes the module repr as the cell output.
esm2.eval().to(device)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 1280, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
    (position_embeddings): Embedding(1026, 1280, padding_idx=1)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-32): 33 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=1280, out_features=1280, bias=True)
            (key): Linear(in_features=1280, out_features=1280, bias=True)
            (value): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmInter

In [None]:
# Index into `hidden_states` used for the ESM-2 embedding.
# NOTE(review): hidden_states[0] is the embedding-layer output, so
# num_hidden_layers - 1 selects the penultimate transformer layer, whereas the
# ESM-1v cell pools its final layer (33). Confirm this choice is intended.
layer = esm2.config.num_hidden_layers - 1

# ESM-2 accepts at most 1024 tokens; the tokenizer adds one special token at
# each end (as EsmTokenizer does), leaving room for 1022 residues.
ESM2_MAX_RESIDUES = 1022


@lru_cache(maxsize=10000)
def get_esm2_embedding(sequence):
    """Mean-pooled ESM-2 embedding of `sequence` from `hidden_states[layer]`.

    Special tokens are excluded from the mean so only residues contribute,
    matching the ESM-1v helper. Results are cached per sequence; the cached
    numpy array is shared between callers, so do not mutate it in place.
    """
    encoded_input = esm2_tokenizer(
        [sequence],
        return_tensors="pt",
        padding=True,
        truncation=True,  # safety net only; callers pass windows that already fit
        max_length=ESM2_MAX_RESIDUES + 2,
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    with torch.no_grad():
        # Request hidden states per call instead of mutating the shared model config.
        model_output = esm2(**encoded_input, output_hidden_states=True)
    # Drop the leading and trailing special tokens before averaging.
    residue_states = model_output.hidden_states[layer][0, 1:-1]
    return residue_states.mean(dim=0).cpu().numpy()


# Embed the window around the variant (get_protein_window is defined in the
# ESM-1v cell). Plain tokenizer truncation would tail-crop long proteins, so a
# variant past residue ~1022 would vanish and REF/ALT embeddings would be identical.
df["REF_EMBEDDING_ESM2"] = df.progress_apply(
    lambda row: get_esm2_embedding(
        get_protein_window(row["PROTEIN_REF"], row["AA_POSITION"], ESM2_MAX_RESIDUES)
    ),
    axis=1,
)
df["ALT_EMBEDDING_ESM2"] = df.progress_apply(
    lambda row: get_esm2_embedding(
        get_protein_window(row["PROTEIN_ALT"], row["AA_POSITION"], ESM2_MAX_RESIDUES)
    ),
    axis=1,
)

# Rich display of the updated frame.
df.head()

100%|██████████| 112437/112437 [04:24<00:00, 424.67it/s]
100%|██████████| 112437/112437 [3:17:29<00:00,  9.49it/s]

        VARIANTKEY    LABEL             ENSG GENE_SYMBOL  AA_POSITION  \
0  1-100196274-A-C      LOF  ENSG00000137992         DBT          477   
1  1-100196286-T-C  NEUTRAL  ENSG00000137992         DBT          473   
2  1-100196349-T-C      LOF  ENSG00000137992         DBT          452   
3  1-100206470-G-A      LOF  ENSG00000137992         DBT          395   
4  1-100206621-C-T      LOF  ENSG00000137992         DBT          345   

                                         PROTEIN_REF  \
0  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   
1  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   
2  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   
3  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   
4  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   

                                         PROTEIN_ALT  \
0  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   
1  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   
2  MAAVRMLRTWSRNAGKLICVRYFQTCGNVHVLKPNYVCFFGYPSFK...   





In [None]:
# Spot-check one variant: ALT sequence length versus ESM-2 embedding width.
print(len(df.at[42165, "PROTEIN_ALT"]))
print(len(df.at[42165, "ALT_EMBEDDING_ESM2"]))

411
1280


In [None]:
# Snapshot the frame once the ESM-2 embedding columns have been added.
pd.to_pickle(df, "../data/esm2_embeddings.pkl")