In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

In [None]:
# Read the training and test peptide tables from disk.
df_detect_peptide_train = pd.read_csv('../data/df_detect_peptide_train.csv')
test = pd.read_csv('../data/df_detect_peptide_test.csv')

# Hold out 20% of the training table as a validation set; the fixed
# random_state keeps the split identical between runs.
train, val = train_test_split(
    df_detect_peptide_train,
    test_size=0.2,
    random_state=7,
)

In [39]:
# Stack the three splits into a single table with a fresh 0..N-1 index.
df = pd.concat([train, val, test], ignore_index=True)

# Rows keep their concatenation order (train, then val, then test), so each
# split corresponds to one contiguous slice of the new index.
n_train, n_val = len(train), len(val)
train_idx = df.index[:n_train]
val_idx = df.index[n_train:n_train + n_val]
test_idx = df.index[n_train + n_val:]

In [5]:
# Protein table read from the UniProt data directory.
df_prot = pd.read_csv('../data/uniprot/df_uni.csv')

# Lookup from the PROTEIN column to the SEQUENCE column
# (for a repeated PROTEIN value the last row wins).
prot2seq = dict(zip(df_prot['PROTEIN'], df_prot['SEQUENCE']))

# protein embedding

In [7]:
# Write one FASTA record per distinct protein in `df`: the header is the
# protein identifier and the body is its sequence from `prot2seq`.
# The `with` block closes (and therefore flushes) the file, so the FASTA is
# complete on disk before any other program reads it.
with open('../data/uniprot/targetProt.fasta', 'w') as w:
    for p in df.protein.unique():
        w.write('>'+p+'\n')
        w.write(prot2seq[p]+'\n')

In [8]:
# Restrict the protein -> sequence lookup to the proteins that occur in `df`.
p2seq = {}
for prot in df.protein.unique():
    p2seq[prot] = prot2seq[prot]

In [None]:
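# Shell command, run outside this notebook, that extracts esm1b_t33_650M_UR50S embeddings (layers 0, 32, 33; mean and per-token) for targetProt.fasta:
# python extract.py esm1b_t33_650M_UR50S /home/bis/2021_SJH_detectability/Detectability/data/uniprot/targetProt.fasta /home/bis/2021_SJH_detectability/Detectability/data/ProtStructureEmbedding_emb_esm1b/ --repr_layers 0 32 33 --include mean per_tok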

In [14]:
# Count the proteins in `p2seq` whose sequence is longer than 1025 characters.
n_long = sum(len(seq) > 1025 for seq in p2seq.values())
print('# of portein (lenght > 1025) : ', n_long)

# of portein (lenght > 1025) :  1076


# peptide embedding

In [60]:
# Show the first row to check which columns the combined table has.
df.head(1)

Unnamed: 0,peptide,En,Ec,E1,E2,protein,PEP,ID
0,K.QELNEPPKQSTSFLVLQEILESEEKGDPNK.P,VYKMLQEKQELNEPP,EEKGDPNKPSGFRSV,QELNEPPKQSTSFLV,EILESEEKGDPNKPS,sp|O00151|PDLI1_HUMAN,QELNEPPKQSTSFLVLQEILESEEKGDPNK,0


In [61]:
# Write the En column as FASTA: one record per row of `df`, with the row's
# index label as the header.
# The `with` block closes (and therefore flushes) the file, so the FASTA is
# complete on disk before any other program reads it.
with open('../data/ProtStructureEmbedding_emb_esm1b/fasta/En.fasta', 'w') as w:
    for idx, p in zip(df.index, df.En.values):
        w.write('>'+str(idx)+'\n')
        w.write(p+'\n')

In [62]:
# Write the Ec column as FASTA: one record per row of `df`, with the row's
# index label as the header.
# The `with` block closes (and therefore flushes) the file, so the FASTA is
# complete on disk before any other program reads it.
with open('../data/ProtStructureEmbedding_emb_esm1b/fasta/Ec.fasta', 'w') as w:
    for idx, p in zip(df.index, df.Ec.values):
        w.write('>'+str(idx)+'\n')
        w.write(p+'\n')

In [63]:
# Write the E1 column as FASTA: one record per row of `df`, with the row's
# index label as the header.
# The `with` block closes (and therefore flushes) the file, so the FASTA is
# complete on disk before any other program reads it.
with open('../data/ProtStructureEmbedding_emb_esm1b/fasta/E1.fasta', 'w') as w:
    for idx, p in zip(df.index, df.E1.values):
        w.write('>'+str(idx)+'\n')
        w.write(p+'\n')

In [64]:
# Write the E2 column as FASTA: one record per row of `df`, with the row's
# index label as the header.
# The `with` block closes (and therefore flushes) the file, so the FASTA is
# complete on disk before any other program reads it.
with open('../data/ProtStructureEmbedding_emb_esm1b/fasta/E2.fasta', 'w') as w:
    for idx, p in zip(df.index, df.E2.values):
        w.write('>'+str(idx)+'\n')
        w.write(p+'\n')

In [65]:
# Write the PEP column as FASTA: one record per row of `df`, with the row's
# index label as the header.
# The `with` block closes (and therefore flushes) the file, so the FASTA is
# complete on disk before any other program reads it.
with open('../data/ProtStructureEmbedding_emb_esm1b/fasta/PEP.fasta', 'w') as w:
    for idx, p in zip(df.index, df.PEP.values):
        w.write('>'+str(idx)+'\n')
        w.write(p+'\n')

In [134]:
# Write one small text file per row of `df`: the file is named after the row's
# index label and contains the string form of that row's ID value.
label_path = '/data/211129_SJH_ESM/ProtStructureEmbedding_emb_esm1b/LABEL/'
for idx, lab in zip(df.index, df.ID.values):
    # Close each handle explicitly instead of relying on garbage collection;
    # this loop opens one file per row.
    with open(label_path+str(idx), 'w') as label_file:
        label_file.write(str(lab))

# numpy

In [69]:
import sys
# Append the ESM repository directory to the module search path
# (presumably so that `import esm` below resolves to this checkout — confirm).
PATH_TO_REPO = "/home/bis/2021_AIhub/esm/"
sys.path.append(PATH_TO_REPO)

import torch
import esm

In [70]:
# FASTA file whose record headers give the names of the per-record embedding files.
FASTA_PATH = '/data/211129_SJH_ESM/ProtStructureEmbedding_emb_esm1b/fasta/PEP.fasta'
# Directory holding one `<record id>.pt` embedding file per FASTA record.
EMB_PATH = '/data/211129_SJH_ESM/ProtStructureEmbedding_emb_esm1b/PEP/'  # .pt (63GB -> zeropad:140GB)
# Key used to select one entry from the 'representations' dict of each .pt file.
EMB_LAYER = 33

In [139]:
# Sanity check of the load-and-pad step: walk the FASTA records, print each
# record id, and stop as soon as the record with id '1' is reached.
aa2vec = {}
for header, seq in esm.data.read_fasta(FASTA_PATH):
    idx = header.split('>')[1]  # text after the leading '>' of the header
    print(idx)
    if idx == '1':
        break

    emb_file = EMB_PATH + idx + '.pt'
    token_embs = torch.load(emb_file)['representations'][EMB_LAYER]

    # Prepend all-zero rows (padding on top) up to a length of 30 rows.
    n_missing_rows = 30 - len(token_embs)
    pad_top = torch.nn.ZeroPad2d((0, 0, n_missing_rows, 0))
    aa2vec[idx] = pad_top(token_embs).numpy()

0
1


In [140]:
# Display the dictionary to inspect the padded embedding array(s) it holds.
aa2vec

{'0': array([[ 0.08528218,  0.15698674, -0.00476043, ..., -0.28082493,
         -0.17248774, -0.25759438],
        [ 0.2589074 , -0.02770338,  0.04186126, ..., -0.37686828,
          0.01822023,  0.09063172],
        [ 0.10234036,  0.10289276, -0.14009404, ..., -0.1643274 ,
         -0.02725955,  0.14206922],
        ...,
        [ 0.07254714,  0.11588723, -0.08704348, ..., -0.13625613,
          0.05028485,  0.0671237 ],
        [ 0.22895892,  0.04845764,  0.06295758, ..., -0.14954534,
         -0.0281459 ,  0.10974251],
        [ 0.23911606, -0.11963475,  0.03793335, ..., -0.09492811,
          0.21702647,  0.11332861]], dtype=float32)}

In [141]:
# Show the index of the combined table; its labels are used as the per-record file names.
df.index

RangeIndex(start=0, stop=813388, step=1)

In [146]:
# Convert the per-record ESM `.pt` embedding files (PEP, En, Ec, E1, E2) into
# `.npy` files, one file per record and per sequence type.
vecs_dir = '/data/211129_SJH_ESM/ProtStructureEmbedding_emb_esm1b/'
EMB_LAYER = 33

import time
s = time.time()

# Input (.pt) and output (.npy) directories are the same for every record,
# so they are built once instead of on every iteration.
pep_path = vecs_dir + 'PEP/'
en_path = vecs_dir + 'En/'
ec_path = vecs_dir + 'Ec/'
m1_path = vecs_dir + 'E1/'
m2_path = vecs_dir + 'E2/'

pep_path_save = vecs_dir + 'PEPnpy/'
en_path_save = vecs_dir + 'Ennpy/'
ec_path_save = vecs_dir + 'Ecnpy/'
m1_path_save = vecs_dir + 'E1npy/'
m2_path_save = vecs_dir + 'E2npy/'


def _load_layer(pt_file):
    """Return the EMB_LAYER entry of the 'representations' dict stored in `pt_file`."""
    return torch.load(pt_file)['representations'][EMB_LAYER]


def _window_or_zeros(embed):
    """Return `embed` unchanged, or a (15, 1280) zero array when it has a single row.

    The placeholder is float32 so that it has the same dtype as the loaded
    embeddings; `np.zeros` would otherwise default to float64.
    """
    if len(embed) == 1:
        return np.zeros((15, 1280), dtype=np.float32)
    return embed


for idx, id_num in enumerate(df.index):
    # Progress report every 1000 records; '\r' overwrites the previous report.
    if idx % 1000 == 0:
        print(idx, round(time.time()-s,2), 'sec', end='\r')

    id_name = str(id_num) + '.pt'
    pep_fn = f'{pep_path}{id_name}'  # path to a single embedding file
    en_fn = f'{en_path}{id_name}'
    ec_fn = f'{ec_path}{id_name}'
    m1_fn = f'{m1_path}{id_name}'
    m2_fn = f'{m2_path}{id_name}'

    embs = _load_layer(pep_fn)
    # NOTE(review): a peptide with more than 30 rows would make the top pad
    # size negative — confirm that no peptide exceeds 30 residues.
    pep_zp = torch.nn.ZeroPad2d((0, 0, 30-len(embs), 0))  # zero padding on top
    pep_embed = pep_zp(embs).numpy()
    en_embed = _load_layer(en_fn).numpy()
    ec_embed = _load_layer(ec_fn).numpy()
    # A single-row E1/E2 embedding is replaced by an all-zero block
    # (presumably a placeholder for a missing window — confirm).
    m1_embed = _window_or_zeros(_load_layer(m1_fn).numpy())
    m2_embed = _window_or_zeros(_load_layer(m2_fn).numpy())

    id_name_save = str(id_num) + '.npy'
    np.save(f'{pep_path_save}{id_name_save}', pep_embed)
    np.save(f'{en_path_save}{id_name_save}', en_embed)
    np.save(f'{ec_path_save}{id_name_save}', ec_embed)
    np.save(f'{m1_path_save}{id_name_save}', m1_embed)
    np.save(f'{m2_path_save}{id_name_save}', m2_embed)

1000 7.73 sec

KeyboardInterrupt: 

In [147]:
# Extrapolate 7.73 s per 1000 records to every row of `df`, expressed in hours.
7.73*(len(df)/1000)/3600

1.746524788888889