#### Set environment

In [1]:
import torch, os
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence

In [2]:
# Set options
# -- paths --
data_path = "../data/"            # input sequence CSV is read from here; embedding CSV is written here
model_path = "../model/seq_gen/"  # directory holding the trained checkpoint

# -- model --
model_name = 'clstm'  # tag used in both the checkpoint and the output file names
pool_len = 25         # number of pooled steps fed to the LSTM

# -- preprocessing & inference options --
max_len = 1600  # sequences are truncated / zero-padded to this many residues
batch_size = 32

In [3]:
# Prefer a CUDA device when torch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

#### Prepare data

In [None]:
# Load protein sequence data (sequences are in the `aa_seq` column)
raw_csv = data_path + "data-seq_raw-ts.csv"
df = pd.read_csv(raw_csv)

display(df)

In [None]:
# Set vocabulary dictionary: the 20 residue letters below are encoded as 1..20,
# leaving 0 free for padding (it matches padding_idx=0 of the embedding layer).
valid_aa = 'ACDEFGHIKLMNPQRSTVWY'
aa_to_int = dict(zip(valid_aa, range(1, len(valid_aa) + 1)))

print(aa_to_int)

In [None]:
# Remove every character that is not one of the residues in `valid_aa`
# (e.g. 'X', '*', '-', lowercase letters). Rows are cleaned rather than dropped,
# so `df` keeps one row per input protein and stays row-aligned with the
# embeddings concatenated onto it later.
# Missing sequences (NaN) become "" so the integer encoding step does not fail
# on a float; they end up as all-padding rows.
df['aa_seq'] = (
    df['aa_seq']
    .fillna("")
    .str.replace(f"[^{valid_aa}]", "", regex=True)
)

print(df.shape)

In [None]:
# Truncate sequences to max length: keep the first `max_len` residues of each
# protein, in the same row order as `df`.
seq = df['aa_seq'].str.slice(stop=max_len).tolist()

print(len(seq))
print(seq[:2])

In [None]:
# Set integer encoding function
def int_encode_seq(seq, aa_to_int):
    """Map a residue string to its list of integer codes.

    Args:
        seq: iterable of single-character residues (normally a str).
        aa_to_int: mapping from residue character to integer code.

    Returns:
        list of codes, one per residue, in input order.

    Raises:
        KeyError: if a residue has no entry in ``aa_to_int``.
    """
    return [aa_to_int[residue] for residue in seq]

# Encode sequences to integers (output order matches `seq`, hence `df`)
encoded_seq = list(map(lambda protein: int_encode_seq(protein, aa_to_int), seq))

print(len(encoded_seq))
print(encoded_seq[:2])

In [None]:
# Set padding function
def pad_seq(seq, max_len):
    """Right-pad with 0 (or truncate) an integer-encoded sequence to exactly ``max_len`` items.

    0 is the padding code (``aa_to_int`` starts at 1 and the embedding uses
    ``padding_idx=0``). Longer inputs are cut to ``max_len`` so that every row
    has the same length; otherwise stacking rows with ``np.array`` would be ragged.

    Args:
        seq: sliceable sequence of integer codes (list, tuple, ...); not modified.
        max_len: target length, assumed non-negative.

    Returns:
        A new list of length ``max_len``.
    """
    head = list(seq[:max_len])
    # max_len - len(head) is 0 when the input was truncated, so no padding is added.
    return head + [0] * (max_len - len(head))

# Pad sequences and stack them into one integer matrix with one row per protein.
# (The padded lists are built inside the call so no extra copy stays alive.)
encoded_seq = np.array(list(map(lambda codes: pad_seq(codes, max_len), encoded_seq)))

print(encoded_seq.shape)
print(encoded_seq[:2])

In [None]:
# Convert dataset to an int64 tensor of token ids (what nn.Embedding expects).
X_all = torch.as_tensor(encoded_seq, dtype=torch.long)

# Generate dataloader. shuffle=False keeps batches in df row order, which the
# later concat of protein info and embeddings relies on.
data_loader = DataLoader(
    X_all,
    batch_size=batch_size,
    shuffle=False,
)

print(X_all.shape, len(data_loader))

#### Sequence embedding

In [11]:
# Set model architecture
class ClassifierCNN_LSTM(nn.Module):
    """CNN + bidirectional-LSTM protein classifier, used here as an embedding extractor.

    Pipeline: token embedding -> Conv1d + GELU -> adaptive average pooling to
    ``pool_len`` steps -> BiLSTM -> concatenated final hidden states -> BatchNorm.
    The dropout + single-logit linear head (``do``/``fc``) is still constructed so
    checkpoints trained with it load with matching keys, but ``forward`` returns
    the BatchNorm output (the embedding), not the logit.

    Args:
        pool_len: number of pooled steps fed to the LSTM.
        vocab_size: number of token ids, including the padding id 0.
        emb_dim: token embedding width (default 16).
        out_dim: conv channels and final embedding width; must be even because
            each LSTM direction gets ``out_dim // 2`` units (default 1024).
        kernel_size: Conv1d kernel width (default 8).
        dropout: dropout probability of the (unused at inference) head (default 0.5).

    The defaults reproduce the original hard-coded architecture; other values
    yield parameter shapes that will not load a checkpoint trained with it.
    """

    def __init__(self, pool_len, vocab_size, emb_dim=16, out_dim=1024,
                 kernel_size=8, dropout=0.5):
        super(ClassifierCNN_LSTM, self).__init__()
        if out_dim % 2:
            raise ValueError(f"out_dim must be even, got {out_dim}")
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim, padding_idx=0)

        self.conv1d = nn.Conv1d(emb_dim, out_dim, kernel_size=kernel_size)
        self.initialize_weights(self.conv1d)
        self.pool = nn.AdaptiveAvgPool1d(pool_len)
        # Two directions of out_dim // 2 units -> concatenated final state is out_dim wide.
        self.lstm = nn.LSTM(out_dim, out_dim // 2, bidirectional=True, batch_first=True)

        self.bn = nn.BatchNorm1d(out_dim)
        self.do = nn.Dropout(dropout)
        self.fc = nn.Linear(out_dim, 1)

    def initialize_weights(self, layer):
        """Kaiming-normal (fan_in, linear gain) weights and zero bias for ``layer``."""
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='linear')
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Embed a batch of integer-encoded, right-padded sequences.

        Args:
            x: LongTensor of token ids, shape (batch, seq_len), 0 = padding;
               seq_len must be at least the conv kernel size.

        Returns:
            FloatTensor of shape (batch, out_dim): BatchNorm of the concatenated
            final forward/backward LSTM hidden states.
        """
        # (B, L) -> (B, L, E) -> (B, E, L) for Conv1d
        emb = self.embedding(x).permute(0, 2, 1)
        emb = F.gelu(self.conv1d(emb))
        emb = self.pool(emb)         # (B, out_dim, pool_len)
        emb = emb.permute(0, 2, 1)   # (B, pool_len, out_dim)
        # Count pooled steps with any non-zero feature as valid.
        # NOTE(review): with a non-zero conv bias, padded steps rarely pool to exact
        # zeros, so this usually counts every step; kept as-is to match training.
        # clamp(min=1): pack_padded_sequence rejects zero lengths, which occur when
        # every pooled feature of a row is exactly zero (e.g. an all-padding row).
        val_lens = (emb != 0).any(dim=2).sum(dim=1).clamp(min=1).cpu()
        packed = pack_padded_sequence(emb, val_lens, batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed)
        # hidden: (num_directions * num_layers, B, out_dim // 2); the last two rows
        # are the final forward and backward states of the top layer.
        emb = torch.cat((hidden[-2], hidden[-1]), dim=1)
        # Only the embedding is returned, so the dropout + fc head is not evaluated.
        return self.bn(emb)

In [12]:
# Build the network with the encoding's vocabulary size (+1 for padding id 0) and
# load the trained weights.
model = ClassifierCNN_LSTM(pool_len=pool_len, vocab_size=len(aa_to_int) + 1).to(device)
# The checkpoint is handed straight to load_state_dict, so it should only hold
# tensors; weights_only=True refuses to unpickle arbitrary objects from the file.
state_dict = torch.load(model_path + f"seq_gen-{model_name}.pt", map_location=device, weights_only=True)
model.load_state_dict(state_dict)

  model.load_state_dict(torch.load(model_path + f"seq_gen-{model_name}-{str(model_ver)}.pt", map_location=device))


<All keys matched successfully>

In [13]:
# Get embedding of the sequences: inference mode (BatchNorm running stats,
# Dropout off), no autograd graph, one CPU tensor per batch in DataLoader order.
model.eval()
with torch.no_grad():
    embed = [model(batch.to(device)).cpu() for batch in data_loader]

In [14]:
# Convert the per-batch embedding tensors into one table (one row per protein).
# .numpy() makes the tensor -> ndarray conversion explicit instead of relying on
# how the installed pandas version handles a torch tensor passed to DataFrame().
emb_df = pd.DataFrame(torch.cat(embed, dim=0).numpy())

In [15]:
# Concatenate protein info. & embeddings.
col_str = ['file_id', 'organism', 'locus_tag', 'ess']
# Embeddings were produced in df row order (DataLoader shuffle=False), so rows are
# paired by position: both indexes are reset so pd.concat's index alignment cannot
# mis-pair them.
assert len(df) == len(emb_df), f"{len(df)} proteins vs {len(emb_df)} embeddings"
emb_df = pd.concat(
    [
        df[col_str].reset_index(drop=True),
        # Dropping col_str first makes re-running this cell idempotent
        # (no duplicated info columns on a second run).
        emb_df.drop(columns=col_str, errors='ignore').reset_index(drop=True),
    ],
    axis=1,
)

display(emb_df)

Unnamed: 0,file_id,organism,locus_tag,ess,0,1,2,3,4,5,...,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023
0,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0001,0,-1.271414,0.000099,0.076290,-0.000540,-0.002149,-0.002770,...,-0.000128,0.004823,0.000413,0.127032,0.006706,0.013683,-0.020618,-0.049667,0.108772,-0.003629
1,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0002,0,0.080849,0.000191,0.066948,-0.001360,-0.001139,-0.002485,...,-0.000043,0.001776,0.001967,0.126776,0.007055,0.008662,0.113034,-0.035381,0.105291,0.002070
2,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0003,0,0.049742,0.000037,0.063013,-0.003911,-0.002671,-0.000340,...,-0.000037,0.001271,0.002466,0.126992,0.007107,0.007993,0.119393,-0.034128,0.103778,0.002723
3,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0004,0,0.438240,0.000130,0.061992,-0.002766,-0.001883,-0.001800,...,-0.000041,0.001173,0.001617,0.127110,0.007114,0.008159,0.119812,-0.033128,0.105875,0.002058
4,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0005,0,0.647014,0.000091,0.058150,-0.003637,-0.002623,-0.002049,...,-0.000028,0.000526,0.001174,0.127482,0.007209,0.007493,0.129736,-0.029263,0.105677,0.004036
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
283919,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0913,1,0.244533,0.000047,0.060919,-0.003986,-0.002723,-0.000765,...,-0.000052,0.001888,0.002917,0.127124,0.007069,0.007943,0.102165,-0.032948,0.106551,0.001652
283920,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0918,1,0.446268,0.000114,0.060191,-0.002962,-0.002494,-0.002787,...,-0.000028,0.001444,0.000616,0.126796,0.007066,0.008400,0.101383,-0.035993,0.105889,0.002418
283921,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0930,1,-1.313249,0.000093,0.077817,-0.000610,-0.002145,-0.002522,...,-0.000083,0.002701,0.002104,0.127490,0.007041,0.009210,0.034684,-0.041377,0.106090,-0.001276
283922,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0931,1,0.339057,0.000057,0.060354,-0.003875,-0.002801,-0.001344,...,-0.000051,0.000935,0.001868,0.126976,0.007136,0.007776,0.122922,-0.030382,0.105050,0.003196


In [16]:
# Save the embeddings table to data_path (no index column).
emb_out_path = data_path + f"data-emb_gen-{model_name}.csv"
emb_df.to_csv(emb_out_path, index=False)