In [1]:
import pandas as pd
import pandas as pd
import esm
import numpy as np
import torch
from tqdm import tqdm

# Precomputed 320-dimensional ESM-2 embeddings: columns are 'UniProt_id', '1'..'320', 'label'.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
ESM2_EMBEDDINGS_CSV = '/home/musong/Desktop/paper/data/drugfinder/esm2_320_dimensions_with_labels.csv'
df = pd.read_csv(ESM2_EMBEDDINGS_CSV)

df.columns

Index(['UniProt_id', '1', '2', '3', '4', '5', '6', '7', '8', '9',
       ...
       '312', '313', '314', '315', '316', '317', '318', '319', '320', 'label'],
      dtype='object', length=322)

In [2]:
def esm_embeddings(peptide_sequence_list):
    model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    model.eval()  # disables dropout for deterministic results
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    batch_labels, batch_strs, batch_tokens = batch_converter(peptide_sequence_list)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
    ## batch tokens are the embedding results of the whole data set
    batch_tokens = batch_tokens.to(device)
    # Extract per-residue representations (on CPU)
    with torch.no_grad():
        # Here we export the last layer of the EMS model output as the representation of the peptides
        # model'esm2_t6_8M_UR50D' only has 6 layers, and therefore repr_layers parameters is equal to 6
        results = model(batch_tokens, repr_layers=[6], return_contacts=True)  
    token_representations = results["representations"][6]

    # Generate per-sequence representations via averaging
    # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
    sequence_representations = []
    for i, tokens_len in enumerate(batch_lens):
        sequence_representations.append((peptide_sequence_list[i][0], token_representations[i, 1 : tokens_len - 1].mean(0)))
    return sequence_representations

In [3]:
# Sanity check: embed a single protein and inspect the resulting vector.
a = 'P50579'  # sequence label; presumably a UniProt accession like the CSV's 'UniProt_id' column
b = """
MAGVEEVAASGSHLNGDLDPDDREEGAASTAEEAAKKKRRKKKKSKGPSAAGEQEPDKESGASVDEVARQLERSALEDKERDEDDEDGDGDGDGATGKKKKKKKKKRGPKVQTDPPSVPICDLYPNGVFPKGQECEYPPTQDGRTAAWRTTSEEKKALDQASEEIWNDFREAAEAHRQVRKYVMSWIKPGMTMIEICEKLEDCSRKLIKENGLNAGLAFPTGCSLNNCAAHYTPNAGDTTVLQYDDICKIDFGTHISGRIIDCAFTVTFNPKYDTLLKAVKDATNTGIKCAGIDVRLCDVGEAIQEVMESYEVEIDGKTYQVKPIRNLNGHSIGQYRIHAGKTVPIVKGGEATRMEEGEVYAIETFGSTGKGVVHDDMECSHYMKNFDVGHVPIRLPRTKHLLNVINENFGTLAFCRRWLDRLGESKYLMALKNLCDLGIVDPYPPLCDIKGSYTAQFEHTILLRPTCKEVVSRGDDY
"""
# The triple-quoted literal starts and ends with a newline. Strip it so only
# residue letters reach the ESM tokenizer, instead of relying on its whitespace handling.
esm_embeddings([(a, b.strip())])

[('P50579',
  tensor([ 2.0595e-02, -5.2397e-02,  7.2424e-02,  2.4903e-01,  1.3683e-01,
          -9.5064e-03,  7.8911e-02, -1.1505e-01, -7.1237e-02, -9.5962e-02,
           1.1693e-03,  1.7169e-02, -1.6200e-01,  8.7022e-02,  1.3806e-01,
          -1.6597e-01, -1.1243e-02, -1.6501e-01,  4.9552e-02,  1.8117e-01,
           1.8462e-02, -2.2816e-02, -4.0338e-02, -1.8552e-01,  1.2977e-01,
           1.6385e-01, -1.7089e-02,  4.1195e-02, -2.1177e-02,  8.2126e-02,
          -3.0474e-02,  1.2641e-01,  3.9872e-02, -2.6634e-01, -2.6091e-02,
          -7.0616e-02,  2.6442e-02,  3.9284e-02,  1.6439e-01,  1.3452e-01,
           7.6915e-02,  2.6355e-01, -1.4124e-01,  3.5976e-02, -6.2721e-02,
           7.8290e-02, -1.1267e+00,  9.4312e-02, -2.9280e-01, -7.1302e-02,
           4.6619e-02, -3.7156e-02,  4.1510e-02,  6.6509e-02, -4.5167e-02,
          -1.5195e-01,  2.0540e-01, -3.1420e-02,  4.3970e-01, -2.0260e-02,
          -4.1587e-03,  9.1296e-03, -4.5462e+00, -1.4220e-01, -1.8296e-02,
           3.