In [1]:
import pickle
from pathlib import Path
from typing import Any, Callable

import esm
import pandas as pd

In [None]:
# Row window [START_IDX, END_IDX) of the dataset to embed.
START_IDX, END_IDX = 0, 2000
# ESM-2 depth; selects the checkpoint from ESM_CALLABLES (33 -> esm2_t33_650M_UR50D).
MODEL_N_LAYERS = 33
# Number of sequences sent through the model per forward pass.
BATCH_SIZE = 4

In [2]:
# Load the ProteinCLIP embedding dataset and show its dimensions.
dataset_path = "../proteinclip/protclip_embed_dataset.parquet"
df = pd.read_parquet(dataset_path)
df.shape

(10000, 9)

In [5]:
# Only embed the configured row window [START_IDX, END_IDX).
# NOTE: this rebinds `df`; when START_IDX > 0, re-running this cell without
# re-reading the parquet file would slice the already-sliced frame again.
df = df.iloc[START_IDX:END_IDX]
df.head()

Unnamed: 0,organism,organism_id,name,evidence,function,id,textual_embedding,sequence,proteinclip_embed
0,Homo sapiens (Human),9606,EPHA7,1,Receptor tyrosine kinase which binds promiscuo...,Q15375,"[-0.015253728, 0.016237658, -0.016555615, 0.02...",MVFQTRYPSWIILCYIWLLRFAHTGEAQAAKEVLLLDSKAQQTELE...,"[0.16681893169879913, 0.105362169444561, 0.011..."
1,Homo sapiens (Human),9606,ANXA8,1,This protein is an anticoagulant protein that ...,P13928,"[-0.008352073, 0.00474287, 0.006541474, -0.002...",MAWWKSWIEQEGVTVKSSSHFNPDPDAETLYKAMKGIGTNEQAIID...,"[-0.10789338499307632, 0.06770769506692886, -0..."
2,Homo sapiens (Human),9606,DPY19L2P1,2,Probable C-mannosyltransferase that mediates C...,Q6NXN4,"[-0.00039709447, -0.02393247, -0.014100584, 0....",MKKQGVNPKPLQSSRPSPSKRPYGASPARELEVEKSALGGGKLPGG...,"[0.026129087433218956, 0.1801730990409851, 0.0..."
3,Homo sapiens (Human),9606,NR1D1,1,Transcriptional repressor which coordinates ci...,P20393,"[0.008172105, -0.0116752, -0.016805198, -0.005...",MTTLDSNNNTGGVITYIGSSGSSPSRTSPESLYSDNSNGSFQSLTQ...,"[-0.016533901914954185, -0.0384892001748085, 0..."
4,Homo sapiens (Human),9606,SLC15A2,1,Proton-coupled amino-acid transporter that tra...,Q16348,"[0.0031013805, -0.0019497981, -0.003831747, 0....",MNPFQKNESKETLFSPVSIEEVPPRPPSPPKKPSPTICGSNYPLSI...,"[0.08691893517971039, 0.07942728698253632, -0...."


In [7]:
# (id, sequence) tuples: the (label, sequence) format consumed by the ESM batch converter below.
id_sequence_pairs: list[tuple[str, str]] = list(zip(df["id"], df["sequence"]))


In [26]:
# ESM-2 depth (number of transformer layers) -> zero-argument loader that
# returns a (model, alphabet) pair.
# NOTE: `esm` must already be imported (top import cell); this cell runs
# before the embedding cell on a fresh kernel.
ESM_CALLABLES: dict[int, Callable[[], tuple[Any, Any]]] = {
    48: esm.pretrained.esm2_t48_15B_UR50D,
    36: esm.pretrained.esm2_t36_3B_UR50D,
    33: esm.pretrained.esm2_t33_650M_UR50D,
    30: esm.pretrained.esm2_t30_150M_UR50D,
    12: esm.pretrained.esm2_t12_35M_UR50D,
    6: esm.pretrained.esm2_t6_8M_UR50D,
}


def get_model(model_size: int) -> tuple[Any, Any]:
    """Load a pretrained ESM-2 model and its alphabet.

    Args:
        model_size: Number of transformer layers; must be a key of ESM_CALLABLES.

    Returns:
        ``(model, alphabet)`` with the model switched to eval mode.

    Raises:
        KeyError: If ``model_size`` is not one of the supported depths.
    """
    try:
        loader = ESM_CALLABLES[model_size]
    except KeyError:
        raise KeyError(
            f"Unsupported ESM-2 model size {model_size!r}; "
            f"choose one of {sorted(ESM_CALLABLES)}"
        ) from None
    model, alphabet = loader()
    model.eval()  # inference mode for nn.Module (e.g. disables dropout)
    return model, alphabet

In [31]:

import numpy as np
from torch.nn import functional as F
from tqdm.auto import tqdm

import torch
from torch import nn
import esm
# Prefer the GPU but fall back to CPU so the notebook still runs without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ESM-2 checkpoints are keyed by depth (see ESM_CALLABLES), so the last layer's
# index is MODEL_N_LAYERS; read representations from it instead of a literal 33.
REPR_LAYER = MODEL_N_LAYERS

batch_size = BATCH_SIZE
# Offset of the first pair in each batch.
indices = np.arange(0, len(id_sequence_pairs), batch_size)

model, alphabet = get_model(MODEL_N_LAYERS)
batch_converter = alphabet.get_batch_converter()
model = model.to(DEVICE)

labels = []
sequence_representations = []
for start_idx in tqdm(indices, desc="embedding"):
    batch_labels, _batch_strs, batch_tokens = batch_converter(
        id_sequence_pairs[start_idx : start_idx + batch_size]
    )
    labels.extend(batch_labels)

    # Count of non-padding tokens per sequence in the batch.
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
    with torch.no_grad():
        results = model(
            batch_tokens.to(DEVICE), repr_layers=[REPR_LAYER], return_contacts=False
        )
    token_representations = results["representations"][REPR_LAYER]
    for i, tokens_len in enumerate(batch_lens):
        # Skip the first and last token (per the ESM README, the special tokens the
        # batch converter adds at each end) and mean-pool over residues on-device,
        # so only the pooled vector is copied back to the host.
        rep = token_representations[i, 1 : tokens_len - 1].mean(0).cpu().numpy()
        sequence_representations.append(rep)

assert len(labels) == len(sequence_representations) == len(id_sequence_pairs)


  0%|          | 0/500 [00:00<?, ?it/s]

In [37]:
# Map each sequence id to its mean-pooled embedding.
output = dict(zip(labels, sequence_representations))
# A shorter dict means duplicate ids silently overwrote earlier embeddings.
assert len(output) == len(labels) == len(sequence_representations), (
    "duplicate ids or mismatched label/embedding counts"
)

In [39]:
# Ids, embeddings and dict entries should all have the same count.
tuple(len(collection) for collection in (labels, sequence_representations, output))

(2000, 2000, 2000)

In [40]:
import pickle
# Encode model depth and row window in the file name so runs with another
# configuration don't overwrite or mislabel this one.
# NOTE: only unpickle this file from a trusted source.
embedding_path = Path(
    f"../data/esm2/sequence_embedding_t{MODEL_N_LAYERS}_{START_IDX}_{END_IDX}.pkl"
)
embedding_path.parent.mkdir(parents=True, exist_ok=True)
with open(embedding_path, "wb") as f:
    pickle.dump(output, f)

In [41]:
# Peek at one stored embedding without materialising the full key list.
next(iter(output.values()))

array([ 0.00049137, -0.04332284, -0.01638226, ..., -0.10839974,
        0.02102051,  0.13916047], shape=(1280,), dtype=float32)