#### Set environments

In [1]:
import torch, gc
import torch.nn as nn
import pandas as pd
import numpy as np
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score,\
f1_score, precision_score, recall_score, roc_auc_score, average_precision_score

In [None]:
# Set options
# Model variants to evaluate; each name selects a classifier class and the
# checkpoint file "emb-<name>.pt" in the evaluation cell.
embed_ver = ["cnn", "lstm", "clstm"]
# Directory the input CSV (data_seq-raw.csv) is read from.
data_path = "../data/test_exam/"
# Directory holding the trained model checkpoints (*.pt).
model_path = "../models/embed_custom/"
# Root directory for the prediction and evaluation CSV outputs.
result_path = "../results/"

In [None]:
# Metadata columns kept next to the encoded sequence; 'ess' is the label column.
col_str = ['file_id', 'organism', 'locus_tag', 'ess']
# Number of stacked layers passed to the CNN and LSTM classifiers.
layer_num = 2
# Sequences are truncated and zero-padded to this many residues.
max_len = 1600
# Mini-batch size for inference over the test set.
batch_size = 256

In [3]:
# Set data list for test dataset
# Maps a test-set name to the list of file_id values that make up that set.
ts_data = {
    "data1": ["C018"],  # "Escherichia coli K-12 BW25113"
}

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Define function to record performance result
def record_perform(emb_ver, file_id, organ, y_real, y_conf, y_prd):
    """Build a one-row DataFrame of binary-classification metrics.

    Args:
        emb_ver: Name of the model/embedding variant (stored in "embed").
        file_id: Identifier of the evaluated test set (stored in "file").
            AUC metrics are always skipped for "O046".
        organ: Organism name(s) of the test set (stored in "organism").
        y_real: True binary labels (0/1); torch tensor or array-like.
        y_conf: Predicted positive-class scores; torch tensor or array-like.
        y_prd: Predicted binary labels (0/1); torch tensor or array-like.

    Returns:
        pandas.DataFrame with a single row holding the confusion-matrix
        counts (tp, fp, tn, fn), MCC, accuracy, F1, precision, recall,
        NPV, TNR, ROC-AUC and PR-AUC. The two AUC columns are None when
        they are skipped.
    """
    def _as_numpy(values):
        # Accept tensors on any device as well as lists / NumPy arrays.
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    y_real = _as_numpy(y_real)
    y_conf = _as_numpy(y_conf)
    y_prd = _as_numpy(y_prd)

    # roc_auc_score raises when y_real contains a single class, so the AUC
    # metrics are only computed when both classes are present.
    has_both_classes = np.unique(y_real).size > 1
    if file_id != "O046" and has_both_classes:
        auc_roc = [roc_auc_score(y_real, y_conf)]
        auc_pr = [average_precision_score(y_real, y_conf)]
    else:
        auc_roc = None
        auc_pr = None

    # labels=[0, 1] forces a 2x2 matrix even if only one class occurs;
    # otherwise ravel() yields a single value and the unpacking fails.
    tn, fp, fn, tp = confusion_matrix(y_real, y_prd, labels=[0, 1]).ravel()

    result = pd.DataFrame({
        "embed": [emb_ver],
        "file": [file_id],
        "organism": [organ],
        "tp": [tp],
        "fp": [fp],
        "tn": [tn],
        "fn": [fn],
        "mcc": [matthews_corrcoef(y_real, y_prd)],
        "acc": [accuracy_score(y_real, y_prd)],
        "f1": [f1_score(y_real, y_prd)],
        "prc": [precision_score(y_real, y_prd)],
        "rec": [recall_score(y_real, y_prd)],
        # NPV / TNR are precision / recall of the negative class,
        # obtained by flipping the labels.
        "npv": [precision_score(1 - y_real, 1 - y_prd)],
        "tnr": [recall_score(1 - y_real, 1 - y_prd)],
        "auc-roc": auc_roc,
        "auc-pr": auc_pr
    })

    return result


In [6]:
# Set model architecture
class ClassifierCNN(nn.Module):
    """Embedding + stacked Conv1d encoder followed by a linear logit head.

    Token ids are embedded (16-d, id 0 is padding) and passed through
    ``num_layer`` stages of Conv1d(kernel 8) -> GELU -> AdaptiveMaxPool1d,
    then a final adaptive max pool to ``pool_len`` steps. The pooled steps
    are averaged, batch-normalized, passed through dropout and projected
    to one logit per sample.

    Args:
        num_layer: Number of convolutional stages.
        pool_len: Number of steps kept by the final adaptive pooling.
        vocab_size: Number of token ids, including the padding id 0.
        max_len: Input sequence length used to size the intermediate pools.
    """

    def __init__(self, num_layer, pool_len, vocab_size, max_len):
        super().__init__()
        emb_dim, out_dim = 16, 1024
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim, padding_idx=0)

        stages = []
        channels = emb_dim
        for depth in range(1, num_layer + 1):
            conv = nn.Conv1d(channels, out_dim, kernel_size=8)
            self.initialize_weights(conv)
            # Intermediate pools shrink with depth but never below pool_len.
            target_len = max(pool_len, max_len // (8 * depth))
            stages += [conv, nn.GELU(), nn.AdaptiveMaxPool1d(target_len)]
            channels = out_dim
        stages.append(nn.AdaptiveMaxPool1d(pool_len))
        self.emb_block = nn.Sequential(*stages)

        self.bn = nn.BatchNorm1d(out_dim)
        self.do = nn.Dropout(0.5)
        self.fc = nn.Linear(out_dim, 1)

    def initialize_weights(self, layer):
        """Kaiming-normal init (fan_in, linear gain) for weights; zero the bias."""
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='linear')
        bias = layer.bias
        if bias is not None:
            nn.init.zeros_(bias)

    def forward(self, x):
        """Return one logit per row of the integer token batch ``x``."""
        # embedding: (batch, seq) -> (batch, emb_dim, seq) for Conv1d
        feats = self.embedding(x).permute(0, 2, 1)
        feats = self.emb_block(feats)
        # plain mean over the pooled steps (padded steps are not masked out)
        pooled = self.bn(feats.mean(dim=2))
        # classification
        logits = self.fc(self.do(pooled))
        return logits.squeeze(1)

In [7]:
# Set model architecture
class ClassifierLSTM(nn.Module):
    """Bidirectional-LSTM classifier over a fixed-width crop of each sequence.

    Each row of token ids is cropped to a window of ``input_len`` positions
    whose start is ``(n_valid - input_len) // 2`` (clamped at 0), i.e. the
    window is centered on the valid tokens when padding is trailing. The
    crop is embedded (16-d, id 0 is padding) and fed to a bidirectional
    LSTM; the last layer's final forward and backward hidden states are
    concatenated (1024-d), batch-normalized, passed through dropout and
    projected to one logit per sample.

    Args:
        num_layer: Number of stacked LSTM layers.
        input_len: Width of the crop fed to the LSTM.
        vocab_size: Number of token ids, including the padding id 0.
    """

    def __init__(self, num_layer, input_len, vocab_size):
        super(ClassifierLSTM, self).__init__()
        self.input_len = input_len

        emb_dim = 16
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim, padding_idx=0)

        out_dim = 1024
        # hidden size is out_dim // 2 per direction so the concatenation is out_dim wide
        self.lstm = nn.LSTM(emb_dim, out_dim // 2, num_layers=num_layer, bidirectional=True, batch_first=True)

        self.bn = nn.BatchNorm1d(out_dim)
        self.do = nn.Dropout(0.5)
        self.fc = nn.Linear(out_dim, 1)

    def initialize_weights(self, layer):
        """Kaiming-normal init (fan_in, linear gain) for weights; zero the bias.

        Not called by ``__init__`` of this class.
        """
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='linear')
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return one logit per row of the integer token batch ``x``."""
        # extract valid input sequences
        val_lens = (x != 0).sum(dim=1)
        start_idx = torch.clamp((val_lens - self.input_len) // 2, min=0)
        # one host transfer for all offsets instead of one per sample
        starts = start_idx.tolist()
        x = torch.stack([x[i, start:start + self.input_len] for i, start in enumerate(starts)], dim=0)
        # embedding
        emb = self.embedding(x)
        # pack_padded_sequence rejects zero lengths; an all-padding row is
        # processed as a single padding step instead of raising.
        val_lens = (x != 0).sum(dim=1).clamp(min=1).cpu()
        emb = pack_padded_sequence(emb, val_lens, batch_first=True, enforce_sorted=False)
        out, (hidden, cell) = self.lstm(emb)
        # (layers * 2, batch, hidden) -> (layers, direction, batch, hidden)
        hidden = hidden.view(self.lstm.num_layers, 2, x.size(0), self.lstm.hidden_size)
        # last layer: forward state followed by backward state
        emb = torch.cat((hidden[-1, 0, :, :], hidden[-1, 1, :, :]), dim=1)
        emb = self.bn(emb)
        # classification
        x = self.do(emb)
        x = self.fc(x)
        return x.squeeze(1)

In [8]:
# Set model architecture
class ClassifierCNN_LSTM(nn.Module):
    """Conv1d feature extractor followed by a bidirectional LSTM and a logit head.

    Token ids are embedded (16-d, id 0 is padding), convolved (kernel 8,
    1024 channels) with GELU, average-pooled to ``pool_len`` steps and fed
    to a single-layer bidirectional LSTM. The final forward and backward
    hidden states are concatenated, batch-normalized, passed through
    dropout and projected to one logit per sample.

    Args:
        pool_len: Number of steps produced by the adaptive average pooling.
        vocab_size: Number of token ids, including the padding id 0.
    """

    def __init__(self, pool_len, vocab_size):
        super().__init__()
        emb_dim, out_dim = 16, 1024
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim, padding_idx=0)

        self.conv1d = nn.Conv1d(emb_dim, out_dim, kernel_size=8)
        self.initialize_weights(self.conv1d)
        self.pool = nn.AdaptiveAvgPool1d(pool_len)
        self.lstm = nn.LSTM(out_dim, out_dim // 2, bidirectional=True, batch_first=True)

        self.bn = nn.BatchNorm1d(out_dim)
        self.do = nn.Dropout(0.5)
        self.fc = nn.Linear(out_dim, 1)

    def initialize_weights(self, layer):
        """Kaiming-normal init (fan_in, linear gain) for weights; zero the bias."""
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='linear')
        bias = layer.bias
        if bias is not None:
            nn.init.zeros_(bias)

    def forward(self, x):
        """Return one logit per row of the integer token batch ``x``."""
        # embedding -> conv features, back to (batch, steps, channels) for the LSTM
        feats = self.embedding(x).permute(0, 2, 1)
        feats = self.pool(F.gelu(self.conv1d(feats))).permute(0, 2, 1)
        # mark valid feature steps: a step counts if any channel is non-zero
        step_is_valid = (feats != 0).any(dim=2)
        val_lens = step_is_valid.sum(dim=1).cpu()
        packed = pack_padded_sequence(feats, val_lens, batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed)
        # final forward state followed by final backward state
        summary = torch.cat((hidden[-2], hidden[-1]), dim=1)
        summary = self.bn(summary)
        # classification
        logits = self.fc(self.do(summary))
        return logits.squeeze(1)

#### Prepare data

In [9]:
# Load dataset
# Read the raw sequence table from the configured data directory.
raw_csv = data_path + "data_seq-raw.csv"
df = pd.read_csv(raw_csv)
display(df)

Unnamed: 0,file_id,genome_id,organism,locus_tag,protein_id,product,locus,strand,dna_seq,dna_len,aa_seq,aa_len,ess
0,C018,CP009273,Escherichia coli K-12 BW25113,BW25113_0001,AIN30539.1,thr operon leader peptide,190..255,+,ATGAAACGCATTAGCACCACCATTACCACCACCATCACCATTACCA...,66,MKRISTTITTTITITTGNGAG,21,0
1,C018,CP009273,Escherichia coli K-12 BW25113,BW25113_0002,AIN30540.1,Bifunctional aspartokinase/homoserinedehydroge...,337..2799,+,ATGCGAGTGTTGAAGTTCGGCGGTACATCAGTGGCAAATGCAGAAC...,2463,MRVLKFGGTSVANAERFLRVADILESNARQGQVATVLSAPAKITNH...,820,0
2,C018,CP009273,Escherichia coli K-12 BW25113,BW25113_0003,AIN30541.1,homoserine kinase,2801..3733,+,ATGGTTAAAGTTTATGCCCCGGCTTCCAGTGCCAATATGAGCGTCG...,933,MVKVYAPASSANMSVGFDVLGAAVTPVDGALLGDVVTVEAAETFSL...,310,0
3,C018,CP009273,Escherichia coli K-12 BW25113,BW25113_0004,AIN30542.1,L-threonine synthase,3734..5020,+,ATGAAACTCTACAATCTGAAAGATCACAACGAGCAGGTCAGCTTTG...,1287,MKLYNLKDHNEQVSFAQAVTQGLGKNQGLFFPHDLPEFSLTEIDEM...,428,0
4,C018,CP009273,Escherichia coli K-12 BW25113,BW25113_0005,AIN30543.1,DUF2502 family putative periplasmic protein,5234..5530,+,GTGAAAAAGATGCAATCTATCGTACTCGCACTTTCCCTGGTTCTGG...,297,MKKMQSIVLALSLVLVAPMAAQAAEITLVPSVKLQIGDRDNRGYYW...,98,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
4308,C018,CP009273,Escherichia coli K-12 BW25113,BW25113_4702,AIN34525.1,regulatory leader peptide for mgtA,4457248..4457301,+,ATGGAACCTGATCCCACGCCTCTCCCTCGACGGAGATTAAAACTTT...,54,MEPDPTPLPRRRLKLFR,17,0
4309,C018,CP009273,Escherichia coli K-12 BW25113,BW25113_4703,AIN34398.1,putative membrane-bound BasS regulator,4321933..4322022,+,ATGAAAAACCGTGTTTATGAAAGTTTAACTACCGTGTTCAGCGTGC...,90,MKNRVYESLTTVFSVLVVSSFLYIWFATY,29,0
4310,C018,CP009273,Escherichia coli K-12 BW25113,BW25113_4705,AIN31285.1,"Mn(2)-response protein, MntR-repressed",848325..848453,-,ATGAATGAGTTCAAGAGGTGTATGCGCGTGTTTAGTCATTCTCCCT...,129,MNEFKRCMRVFSHSPFKVRLMLLSMLCDMVNNKPQQDKPSDK,42,0
4311,C018,CP009273,Escherichia coli K-12 BW25113,BW25113_4706,AIN32938.1,3-hydroxypropionic acid resistance peptide,2661365..2661430,-,ATGAAGCCGGCATTACGCGATTTCATCGCCATTGTGCAGGAACGTT...,66,MKPALRDFIAIVQERLASVTA,21,0


In [10]:
# Set vocabulary dictionary
valid_aa = 'ACDEFGHIKLMNPQRSTVWY'
# Residue codes start at 1 because 0 is reserved for padding.
aa_to_int = dict(zip(valid_aa, range(1, len(valid_aa) + 1)))

print(aa_to_int)

{'A': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, 'K': 9, 'L': 10, 'M': 11, 'N': 12, 'P': 13, 'Q': 14, 'R': 15, 'S': 16, 'T': 17, 'V': 18, 'W': 19, 'Y': 20}


In [11]:
# Replace invalid residues
# Every character outside the 20-letter alphabet is stripped from the
# sequence; the row itself is kept.
invalid_residue = f"[^{valid_aa}]"
cleaned_aa = df['aa_seq'].str.replace(invalid_residue, "", regex=True)
df['aa_seq'] = cleaned_aa

# Disabled alternative: drop rows that contain any invalid residue
# and reset the index afterwards.

print(df.shape)

(4313, 13)


In [12]:
# Truncate sequences to max length
# Keep at most max_len leading residues of every protein sequence.
truncated = df['aa_seq'].str.slice(stop=max_len)
seq = truncated.to_list()

print(len(seq))
print(seq[:2])

4313
['MKRISTTITTTITITTGNGAG', 'MRVLKFGGTSVANAERFLRVADILESNARQGQVATVLSAPAKITNHLVAMIEKTISGQDALPNISDAERIFAELLTGLAAAQPGFPLAQLKTFVDQEFAQIKHVLHGISLLGQCPDSINAALICRGEKMSIAIMAGVLEARGHNVTVIDPVEKLLAVGHYLESTVDIAESTRRIAASRIPADHMVLMAGFTAGNEKGELVVLGRNGSDYSAAVLAACLRADCCEIWTDVDGVYTCDPRQVPDARLLKSMSYQEAMELSYFGAKVLHPRTITPIAQFQIPCLIKNTGNPQAPGTLIGASRDEDELPVKGISNLNNMAMFSVSGPGMKGMVGMAARVFAAMSRARISVVLITQSSSEYSISFCVPQSDCVRAERAMQEEFYLELKEGLLEPLAVTERLAIISVVGDGMRTLRGISAKFFAALARANINIVAIAQGSSERSISVVVNNDDATTGVRVTHQMLFNTDQVIEVFVIGVGGVGGALLEQLKRQQSWLKNKHIDLRVCGVANSKALLTNVHGLNLENWQEELAQAKEPFNLGRLIRLVKEYHLLNPVIVDCTSSQAVADQYADFLREGFHVVTPNKKANTSSMDYYHQLRYAAEKSRRKFLYDTNVGAGLPVIENLQNLLNAGDELMKFSGILSGSLSYIFGKLDEGMSFSEATTLAREMGYTEPDPRDDLSGMDVARKLLILARETGRELELADIEIEPVLPAEFNAEGDVAAFMANLSQLDDLFAARVAKARDEGKVLRYVGNIDEDGVCRVKIAEVDGNDPLFKVKNGENALAFYSHYYQPLPLVLRGYGAGNDVTAAGVFADLLRTLSWKLGV']


In [13]:
# Set integer encoding function
def int_encode_seq(seq, aa_to_int):
    """Translate a residue sequence into its list of integer codes.

    Args:
        seq: Iterable of residue symbols (e.g. a string).
        aa_to_int: Mapping from residue symbol to integer code.

    Returns:
        List with one code per residue, in order.

    Raises:
        KeyError: If a residue is missing from ``aa_to_int``.
    """
    codes = []
    for residue in seq:
        codes.append(aa_to_int[residue])
    return codes

# Encode sequences to integers
# One list of residue codes per protein, in the same order as `seq`.
encoded_seq = []
for aa_string in seq:
    encoded_seq.append(int_encode_seq(aa_string, aa_to_int))

print(len(encoded_seq))
print(encoded_seq[:2])

4313
[[11, 9, 15, 8, 16, 17, 17, 8, 17, 17, 17, 8, 17, 8, 17, 17, 6, 12, 6, 1, 6], [11, 15, 18, 10, 9, 5, 6, 6, 17, 16, 18, 1, 12, 1, 4, 15, 5, 10, 15, 18, 1, 3, 8, 10, 4, 16, 12, 1, 15, 14, 6, 14, 18, 1, 17, 18, 10, 16, 1, 13, 1, 9, 8, 17, 12, 7, 10, 18, 1, 11, 8, 4, 9, 17, 8, 16, 6, 14, 3, 1, 10, 13, 12, 8, 16, 3, 1, 4, 15, 8, 5, 1, 4, 10, 10, 17, 6, 10, 1, 1, 1, 14, 13, 6, 5, 13, 10, 1, 14, 10, 9, 17, 5, 18, 3, 14, 4, 5, 1, 14, 8, 9, 7, 18, 10, 7, 6, 8, 16, 10, 10, 6, 14, 2, 13, 3, 16, 8, 12, 1, 1, 10, 8, 2, 15, 6, 4, 9, 11, 16, 8, 1, 8, 11, 1, 6, 18, 10, 4, 1, 15, 6, 7, 12, 18, 17, 18, 8, 3, 13, 18, 4, 9, 10, 10, 1, 18, 6, 7, 20, 10, 4, 16, 17, 18, 3, 8, 1, 4, 16, 17, 15, 15, 8, 1, 1, 16, 15, 8, 13, 1, 3, 7, 11, 18, 10, 11, 1, 6, 5, 17, 1, 6, 12, 4, 9, 6, 4, 10, 18, 18, 10, 6, 15, 12, 6, 16, 3, 20, 16, 1, 1, 18, 10, 1, 1, 2, 10, 15, 1, 3, 2, 2, 4, 8, 19, 17, 3, 18, 3, 6, 18, 20, 17, 2, 3, 13, 15, 14, 18, 13, 3, 1, 15, 10, 10, 9, 16, 11, 16, 20, 14, 4, 1, 11, 4, 10, 16, 20, 5, 6, 1,

In [14]:
# Set padding function
def pad_seq(seq, max_len):
    """Right-pad ``seq`` with zeros up to ``max_len`` items.

    A list that is already ``max_len`` items or longer is returned
    with no padding added (it is not truncated).
    """
    n_pad = max_len - len(seq)
    return seq + [0] * n_pad

# Pad sequences
# Zero-pad every encoded sequence to max_len and stack them into a 2-D array.
padded_rows = [pad_seq(tokens, max_len) for tokens in encoded_seq]
encoded_seq = np.array(padded_rows)

print(encoded_seq.shape)
print(encoded_seq[:2])

(4313, 1600)
[[11  9 15 ...  0  0  0]
 [11 15 18 ...  0  0  0]]


In [15]:
# Put the metadata/label columns side by side with one column per sequence position.
seq_frame = pd.DataFrame(encoded_seq)
meta_frame = df[col_str]
data = pd.concat([meta_frame, seq_frame], axis=1)

display(data)

Unnamed: 0,file_id,organism,locus_tag,ess,0,1,2,3,4,5,...,1590,1591,1592,1593,1594,1595,1596,1597,1598,1599
0,C018,Escherichia coli K-12 BW25113,BW25113_0001,0,11,9,15,8,16,17,...,0,0,0,0,0,0,0,0,0,0
1,C018,Escherichia coli K-12 BW25113,BW25113_0002,0,11,15,18,10,9,5,...,0,0,0,0,0,0,0,0,0,0
2,C018,Escherichia coli K-12 BW25113,BW25113_0003,0,11,18,9,18,20,1,...,0,0,0,0,0,0,0,0,0,0
3,C018,Escherichia coli K-12 BW25113,BW25113_0004,0,11,9,10,20,12,10,...,0,0,0,0,0,0,0,0,0,0
4,C018,Escherichia coli K-12 BW25113,BW25113_0005,0,11,9,9,11,14,16,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4308,C018,Escherichia coli K-12 BW25113,BW25113_4702,0,11,4,13,3,13,17,...,0,0,0,0,0,0,0,0,0,0
4309,C018,Escherichia coli K-12 BW25113,BW25113_4703,0,11,9,12,15,18,20,...,0,0,0,0,0,0,0,0,0,0
4310,C018,Escherichia coli K-12 BW25113,BW25113_4705,0,11,12,4,5,9,15,...,0,0,0,0,0,0,0,0,0,0
4311,C018,Escherichia coli K-12 BW25113,BW25113_4706,0,11,9,13,1,10,15,...,0,0,0,0,0,0,0,0,0,0


In [16]:
# Columns of `data` that are not metadata, i.e. the encoded sequence positions.
meta_cols = set(col_str)
col_num = [name for name in data.columns if name not in meta_cols]

In [17]:
# get test datasets
# For every test set: the boolean row mask, the selected rows, and the
# organism name of each file_id that actually occurs in the data.
loc_ts, data_ts, org_ts = {}, {}, {}
for ts_ver, ids in ts_data.items():
    # get test sample locations
    row_mask = data['file_id'].isin(ids)
    # get test samples
    subset = data[row_mask]
    loc_ts[ts_ver] = row_mask
    data_ts[ts_ver] = subset
    # get test organism list
    organisms = []
    for fid in ids:
        names = subset.loc[subset['file_id'] == fid, 'organism'].to_list()
        if names:
            organisms.append(names[0])
    org_ts[ts_ver] = organisms

    print("Test dataset(" + ts_ver + "):", subset.shape)
print("Test organism:", org_ts, len(org_ts))

Test dataset(data1): (4313, 1604)
Test organism: {'data1': ['Escherichia coli K-12 BW25113']} 1


In [18]:
# split info.& inputs & labels of the test datasets
# For every test set keep its metadata frame and label tensor, and wrap the
# encoded sequences and labels in a non-shuffled DataLoader.
info_ts = {}
y_ts = {}
test_loader = {}
# The loop variable is deliberately not named `df`, so the raw DataFrame
# loaded earlier is not overwritten when this cell runs.
for ts_ver, df_ts in data_ts.items():
    info_ts[ts_ver] = df_ts[col_str]
    X_ts = torch.tensor(df_ts[col_num].astype('long').values)
    y_ts[ts_ver] = torch.tensor(df_ts['ess'].astype('long').values)
    print("Splited test dataset(" + ts_ver + "):", X_ts.shape, y_ts[ts_ver].shape)
    # generate dataloader by the test datasets
    dataset_ts = TensorDataset(X_ts, y_ts[ts_ver])
    # use the configured batch size instead of a hard-coded literal
    test_loader[ts_ver] = DataLoader(dataset_ts, batch_size=batch_size, shuffle=False)

Splited test dataset(data1): torch.Size([4313, 1600]) torch.Size([4313])


#### Evaluate model

In [None]:
# Evaluate every model variant on every test set.
# Per variant: load the checkpoint, predict logits batch by batch, save the
# per-sample confidences, and record one metrics row per test set plus one
# "total" row over all test sets.
eval_rows = []

for ver in embed_ver:
    #### Evaluate model ####
    # set model name
    model_name = f"emb-{ver}"
    print(f"\n===== Test model: {model_name} ====")

    # build the architecture matching the checkpoint, then load its weights
    vocab_size = len(aa_to_int) + 1  # +1 for the padding id 0
    if ver == 'cnn':
        model = ClassifierCNN(num_layer=layer_num, pool_len=25, vocab_size=vocab_size, max_len=max_len)
    elif ver == 'lstm':
        model = ClassifierLSTM(num_layer=layer_num, input_len=100, vocab_size=vocab_size)
    else:
        model = ClassifierCNN_LSTM(pool_len=25, vocab_size=vocab_size)
    model = model.to(device)
    model.load_state_dict(torch.load(model_path + model_name + ".pt", map_location=device))

    model.eval()

    ## model evaluations by test dataset ##
    pred_frames = []
    total_result = {key: [] for key in col_str + ['logit', 'conf']}

    for ts_ver, ids in ts_data.items():
        results = {key: [] for key in total_result.keys()}
        with torch.no_grad():
            for X_batch, _ in test_loader[ts_ver]:
                X_batch = X_batch.to(device)
                # prediction; reshape(-1) keeps the output 1-D even for a
                # batch of one, so tolist() always yields a list
                preds = model(X_batch).reshape(-1)
                # gather the result
                results['logit'].extend(preds.cpu().tolist())
        # gather testset info.
        for key in col_str:
            results[key].extend(info_ts[ts_ver][key].tolist())

        gc.collect()

        # convert logits to confidences & classes
        # (torch.sigmoid needs a tensor, not a Python list)
        logits = torch.tensor(results['logit'], dtype=torch.float32)
        prd_conf = torch.sigmoid(logits)
        prd_cls = (prd_conf >= 0.5).int()
        # store the confidences so the prediction table has a filled 'conf' column
        results['conf'] = prd_conf.tolist()

        # gather result of the predicted essentiality
        for key, val in results.items():
            total_result[key].extend(val)
        pred_frames.append(pd.DataFrame({key: results[key] for key in col_str + ['conf']}))

        # get evaluation row by testset
        eval_ts = record_perform(
            emb_ver=ver,
            file_id="+".join(ids),
            organ="+".join(org_ts[ts_ver]),
            y_real=torch.tensor(results['ess']),
            y_conf=prd_conf,
            y_prd=prd_cls,
        )
        eval_rows.append(eval_ts)
        print(f"- Test in {ts_ver} was done.")

    # save the model prediction result
    df_pred = pd.concat(pred_frames, ignore_index=True) if pred_frames else pd.DataFrame()
    df_pred.to_csv(f"{result_path}prd-custom_embed/{model_name}.csv", index=False)

    # convert logits to confidences & classes
    total_logits = torch.tensor(total_result['logit'], dtype=torch.float32)
    prd_conf = torch.sigmoid(total_logits)
    prd_cls = (prd_conf >= 0.5).int()

    # get total mean row
    eval_ts = record_perform(
        emb_ver=ver,
        file_id="total",
        organ="all",
        y_real=torch.tensor(total_result['ess']),
        y_conf=prd_conf,
        y_prd=prd_cls
    )
    eval_rows.append(eval_ts)

# save the model evaluation result
df_eval = pd.concat(eval_rows, ignore_index=True) if eval_rows else pd.DataFrame()
df_eval.to_csv(f"{result_path}eval-custom_embed.csv", index=False)
display("Model performance:", df_eval)