In [1]:
'''
Author: DengRui
Date: 2023-09-10 13:15:08
LastEditors: DengRui
LastEditTime: 2023-09-10 13:23:06
FilePath: /DeepSub/embedding/esm_embedding_esm2.ipynb
Description:  Embed protein sequences with ESM-2
Copyright (c) 2023 by DengRui, All Rights Reserved. 
'''
import pandas as pd
import numpy as np
import esm
import torch
from tqdm import tqdm
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# GPU setup: expose only CUDA device 0 to this process, then pick CUDA when
# it is available and fall back to the CPU otherwise.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Load the pretrained ESM-2 model `esm2_t33_650M_UR50D` together with its alphabet.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# Inference only: switch to eval mode and move the model onto the selected device.
model = model.eval().to(device)
# Callable that converts the input batch into (labels, strings, token tensor).
batch_converter = alphabet.get_batch_converter()

In [4]:
def get_rep_seq(sequences):
    """Embed a batch of sequences and return one mean-pooled vector per sequence.

    Relies on the module-level ``batch_converter``, ``alphabet``, ``model`` and
    ``device`` objects.

    Args:
        sequences: Batch handed unchanged to ``batch_converter``; the caller in
            this notebook passes a list of ``(label, sequence)`` tuples.

    Returns:
        pandas.DataFrame with one row per input sequence and columns named
        ``f0``, ``f1``, ... -- one column per component of the layer-33
        representation.
    """
    _, _, tokens = batch_converter(sequences)
    tokens = tokens.to(device)
    # Number of non-padding tokens in every row of the batch.
    lengths = (tokens != alphabet.padding_idx).sum(1)

    # Forward pass without gradient tracking; only layer 33 is requested.
    with torch.no_grad():
        output = model(tokens, repr_layers=[33], return_contacts=False)
    layer33 = output["representations"][33]

    # Average positions 1 .. length-2 of each row, i.e. skip the first and the
    # last non-padding token (presumably the BOS/EOS markers -- confirm against
    # the alphabet), then bring the pooled vector back to the host as numpy.
    pooled = [
        layer33[row, 1 : n_tok - 1].mean(0).cpu().detach().numpy()
        for row, n_tok in enumerate(lengths)
    ]

    res = pd.DataFrame(pooled)
    res.columns = [f'f{col}' for col in range(res.shape[1])]
    return res
    

In [6]:

# Load the data.
# NOTE(review): the path ends in ".csv" but is read with read_feather -- confirm
# the on-disk format before renaming anything.
dataset = pd.read_feather('../DATA/Dataset_0724.csv')
# (label, sequence) pairs; the label is the row index, not the uniprot_id.
df_data = list(zip(dataset.uniprot_id.index,dataset.seq))

# Run in batches of `stride` sequences.
stride = 2
# Write a checkpoint each time the number of processed sequences passes a
# multiple of this value.
checkpoint_every = 500
# Ceiling division so that a trailing partial batch is still processed.
num_iterations = (len(df_data) + stride - 1) // stride

# Embed the data. Per-batch frames are collected in a list and concatenated
# only when a result is needed; growing a DataFrame with pd.concat on every
# iteration copies all previous rows each time (quadratic overall).
batch_frames = []

for i in tqdm(range(num_iterations)):
    # Start and end position of the rows handled in this iteration.
    start = i * stride
    end = start + stride

    # Slice out the batch to process.
    current_data = df_data[start:end]

    rep33 = get_rep_seq(sequences=current_data)
    # Put the identifier in the first column, ahead of the f0..fN features.
    rep33.insert(0, 'uniprot_id', dataset[start:end].uniprot_id.tolist())
    batch_frames.append(rep33)

    # Checkpoint when this batch crossed a multiple of `checkpoint_every`.
    # Equivalent to `end % 500 == 0` when stride divides 500, but it also
    # fires for strides that never land exactly on a multiple.
    if end // checkpoint_every > start // checkpoint_every:
        pd.concat(batch_frames, ignore_index=True).to_feather('../DATA/feature_esm2_20230911_checkpoint.feather')

# pd.concat raises on an empty list, so fall back to an empty frame when the
# dataset has no rows.
all_results = pd.concat(batch_frames, ignore_index=True) if batch_frames else pd.DataFrame()
all_results.to_feather('../DATA/feature_esm2_20230911.feather')

100%|██████████| 5241/5241 [14:08<00:00,  6.18it/s] 
