In [1]:
import torch
from transformers import BertModel, BertTokenizer
import gc, time
import pandas as pd
import numpy as np

In [None]:
# Load the ProtBert-BFD tokenizer and encoder weights
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False)
# eval() puts the module in inference mode and returns the module itself,
# so it can be chained directly onto the load
model = BertModel.from_pretrained("Rostlab/prot_bert_bfd").eval()

# Run a garbage-collection pass once loading is done
gc.collect()

In [3]:
# Use the GPU when CUDA is available, otherwise stay on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)

In [None]:
# Read the raw protein sequence table from CSV
df = pd.read_csv(filepath_or_buffer="../data/test_exam/data_seq-raw.csv")

display(df)

In [None]:
# One-letter codes that are allowed to remain in a sequence
valid_aa = 'ACDEFGHIKLMNPQRSTVWY'

# Delete (not substitute) every character that is not listed in valid_aa
df['aa_seq'] = df['aa_seq'].str.replace(pat=f"[^{valid_aa}]", repl="", regex=True)

# Alternative kept for reference: drop whole sequences containing invalid residues
# df = df[df['aa_seq'].apply(lambda seq: set(seq).issubset(set(valid_aa)))]
# df = df.reset_index(drop=True)

display(df)

In [6]:
# Cap every sequence at max_len characters (longer ones lose their tail)
max_len = 1600
df['aa_seq'] = df['aa_seq'].str.slice(stop=max_len)

In [7]:
# Insert a space between consecutive residues, e.g. "MKV" -> "M K V"
seq = [" ".join(residues) for residues in df['aa_seq']]

In [None]:
# Encode all sequences at once, with special tokens and padding enabled
ids = tokenizer.batch_encode_plus(seq, add_special_tokens=True, padding=True)

# Turn the token ids and the attention mask into tensors
input_ids, attention_mask = (
    torch.tensor(ids[key]) for key in ('input_ids', 'attention_mask')
)

print(input_ids.shape, attention_mask.shape)

In [None]:
# Chunk the id and mask tensors along dim 0 into pieces of at most batch_size rows
batch_size = 256

batch_ids = torch.split(input_ids, batch_size, dim=0)
batch_mask = torch.split(attention_mask, batch_size, dim=0)
batch_n = len(batch_ids)  # also read by embed_seq for its progress line

print("Number of samples:", input_ids.shape[0])
print("Number of seq & mask batchs:", batch_n, len(batch_mask))
print("First batch shape:", batch_ids[0].shape)
print("Last batch shape:", batch_ids[-1].shape)

In [10]:
# Set embedding functions
def extract_embed(emb, mask):
    """Mean-pool one sequence's embeddings over its inner positions.

    Args:
        emb: per-position embeddings of a single sequence, indexed by
            position along the first axis.
        mask: attention mask of the same sequence; entries equal to 1 are
            counted as real tokens.

    Returns:
        The mean over the first axis of ``emb[1:n - 1]``, where ``n`` is the
        number of mask entries equal to 1. The first position and the last
        counted position are left out (presumably the special tokens added
        by the tokenizer; assumes padding sits at the end - TODO confirm).
    """
    n_tokens = (mask == 1).sum()
    inner_positions = emb[1:n_tokens - 1]
    return inner_positions.mean(0)

def embed_seq(inp, mask):
    """Embed one batch and return one mean-pooled vector per sequence.

    Args:
        inp: tensor of token ids for the batch.
        mask: attention-mask tensor matching ``inp``.

    Returns:
        A numpy array made by stacking the ``extract_embed`` result of every
        sequence in the batch.

    Side effects:
        Increments the notebook-level ``step`` counter on every call. On
        every 100th call it prints a progress line (using ``batch_n``) and
        resets ``time_step``. ``step`` and ``time_step`` must be initialised
        before the first call.
    """
    global step, time_step

    inp = inp.to(device)
    mask = mask.to(device)

    # Forward pass without gradient tracking
    with torch.no_grad():
        output = model(input_ids=inp, attention_mask=mask)

    # Pooling is done on the CPU with numpy
    hidden = output.last_hidden_state.cpu().numpy()
    mask_np = mask.cpu().numpy()

    pooled = np.stack(
        [extract_embed(seq_emb, seq_mask) for seq_emb, seq_mask in zip(hidden, mask_np)]
    )

    # display process
    step += 1
    if step % 100 == 0:
        print(f"Step: {step}/{batch_n} | Processing time: {time.time() - time_step:.1f} sec")
        time_step = time.time()

    return pooled


In [None]:
# Reset the progress counter and timers read by embed_seq, then embed every batch
step = 0
time_total = time.time()
time_step = time.time()
embed_mean = [
    embed_seq(ids_chunk, mask_chunk)
    for ids_chunk, mask_chunk in zip(batch_ids, batch_mask)
]

print(f"Total processing time: {time.time() - time_total} sec")

In [None]:
# Join the per-batch arrays into a single array along the first axis
emb_mean = np.concatenate(embed_mean, axis=0)

# The two row counts should agree
print(len(df), emb_mean.shape)

In [None]:
# Put the identifying columns of df side by side with the pooled features.
# NOTE: emb_mean is rebound from an array to a DataFrame here, so this cell
# is not meant to be run twice without re-running the previous one.
col_str = ['file_id', 'organism', 'locus_tag', 'ess']

emb_mean = pd.concat((df[col_str], pd.DataFrame(emb_mean)), axis="columns")

display(emb_mean)

In [None]:
# Write the combined table to CSV without the index column
emb_mean.to_csv(path_or_buf="../data/test_exam/data_emb-bert.csv", index=False)