In [1]:
import torch, esm, time
import pandas as pd
import numpy as np

#### Prepare Model

In [None]:
# Load the pretrained ESM-2 embedding model together with its alphabet.
# Layer counts: esm2_t33_650M_UR50D() -> 33, esm2_t6_8M_UR50D() -> 6.
pretrained = esm.pretrained.esm2_t33_650M_UR50D()
model, alphabet = pretrained
batch_converter = alphabet.get_batch_converter()
model.eval()  # evaluation mode: dropout is disabled so results are deterministic

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
model = model.to(device)
model.eval()

#### Prepare Data

In [None]:
# Read the raw protein-sequence table from CSV and show it.
seq_csv_path = "../data/data-seq_raw-ts.csv"
df = pd.read_csv(seq_csv_path)

display(df)

In [None]:
valid_aa = 'ACDEFGHIKLMNPQRSTVWY'

# Strip every character that is not one of the letters in `valid_aa`.
# Rows are kept as they are; only the offending characters are removed
# from each sequence, so the frame's shape does not change here.
# (Alternative, not applied: drop every row whose sequence contains an
# invalid residue and reset the index afterwards.)
invalid_residue_pattern = f"[^{valid_aa}]"
df['aa_seq'] = df['aa_seq'].str.replace(invalid_residue_pattern, "", regex=True)

print(df.shape)

In [6]:
# Cap every sequence at `max_len` residues; shorter sequences are unchanged.
max_len = 1600
df['aa_seq'] = df['aa_seq'].str.slice(stop=max_len)

In [7]:
# Build a per-gene identifier and pair it with its sequence.
df['id'] = df['file_id'] + "-" + df['locus_tag']
seq_ids = df['id'].tolist()
seqs = df['aa_seq'].tolist()
data = list(zip(seq_ids, seqs))

# Tokenize all sequences in one call; the converter returns
# (batch_labels, batch_strs, batch_tokens) and only the tokens are kept.
batch_tokens = batch_converter(data)[-1]

In [None]:
# Cut the token tensor into chunks of at most `batch_size` rows.
batch_size = 32

batch_data = torch.split(batch_tokens, batch_size, dim=0)
batch_n = len(batch_data)

# Summary of the split: the last chunk may hold fewer rows than the others.
first_batch, last_batch = batch_data[0], batch_data[-1]
print("Number of samples:", len(data))
print("Number of batchs:", batch_n)
print("First batch shape:", first_batch.shape)
print("Last batch shape:", last_batch.shape)

In [9]:
def embed_seq(batch_data, repr_layer=None):
    """Embed one batch of tokenized sequences and mean-pool each over its residues.

    Relies on notebook-level names: ``model``, ``alphabet`` and ``device`` for
    the forward pass, and ``step``, ``time_step`` and ``batch_n`` for progress
    reporting (``step`` and ``time_step`` must be initialised before the first
    call).

    Args:
        batch_data: 2-D token tensor, one row per sequence. Each row is expected
            to be laid out as: beginning-of-sequence token, residues, one
            trailing token, then ``alphabet.padding_idx`` padding on the right.
        repr_layer: Layer whose representations are pooled. ``None`` (default)
            selects ``model.num_layers``, so the function follows whichever
            model is loaded instead of assuming a 33-layer model.

    Returns:
        numpy array with one row per sequence, holding the mean of that
        sequence's residue representations. A sequence with no residues yields
        a NaN row (mean over an empty slice).
    """
    if repr_layer is None:
        repr_layer = model.num_layers

    # Number of non-padding tokens in each sequence.
    batch_lens = (batch_data != alphabet.padding_idx).sum(1)

    # A batch may carry columns that are padding for every row (e.g. when the
    # whole dataset was tokenized in a single call). Trimming them avoids
    # running the model over positions that are never pooled.
    # NOTE(review): assumes the model masks padding positions, so trimming does
    # not change the embeddings beyond floating-point noise - confirm on a
    # sample batch if exact reproducibility of earlier outputs matters.
    if batch_lens.numel() > 0:
        batch_data = batch_data[:, :int(batch_lens.max())]

    # Extract per-residue representations
    batch_tokens = batch_data.to(device)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=False)
    token_representations = results["representations"][repr_layer]

    # Generate per-sequence representations via averaging.
    # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1;
    # the last non-padding token is excluded as well.
    # Pooling happens on the model's device so that only the small pooled
    # matrix is copied to host memory, not every per-token representation.
    pooled = [
        token_representations[i, 1:tokens_len - 1].mean(0)
        for i, tokens_len in enumerate(batch_lens.tolist())
    ]
    embed_mean = torch.stack(pooled).cpu().numpy()

    # display process
    global step, time_step
    step += 1
    if step % 100 == 0:
        print(f"Step: {step}/{batch_n} | Processing time: {time.time() - time_step:.1f} sec")
        time_step = time.time()

    return embed_mean


In [None]:
# Run the embedding over every batch.
# `step` and `time_step` are the progress counters that embed_seq updates.
step = 0
time_total = time.time()
time_step = time.time()
embed_mean = [embed_seq(batch) for batch in batch_data]

elapsed_total = time.time() - time_total
print(f"Total processing time: {elapsed_total:.1f} sec")

In [None]:
# Join the per-batch arrays along the sample axis into one feature matrix.
emb_mean = np.concatenate(embed_mean, axis=0)

# The row counts of the table and the feature matrix are expected to agree.
n_rows = len(df)
print(n_rows, emb_mean.shape)

#### Concatenate gene information and embedded features

In [None]:
col_str = ['file_id', 'organism', 'locus_tag', 'ess']

# Rebuild the feature matrix from the per-batch list rather than from
# `emb_mean`: this cell rebinds `emb_mean` to a DataFrame, so reading it here
# would duplicate the metadata columns whenever the cell is run a second time.
emb_features = np.concatenate(embed_mean)

# Convert the pooled features to a dataframe and attach the gene information.
# The feature frame is given df's own index because pd.concat(axis=1) aligns
# rows by index label; without it, a df whose index is not 0..n-1 (e.g. after
# filtering rows) would be joined to the wrong rows or padded with NaN.
# A row-count mismatch now raises here instead of passing silently.
emb_mean = pd.concat(
    [df[col_str], pd.DataFrame(emb_features, index=df.index)],
    axis=1,
)

display(emb_mean)

In [None]:
# Write the combined table to CSV, leaving out the row index.
output_csv_path = "../data/data-emb_gen-esm2-ts.csv"
emb_mean.to_csv(output_csv_path, index=False)