#### Setting environments

In [2]:
import torch, os, gc
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score,\
f1_score, precision_score, recall_score, roc_auc_score, average_precision_score

In [6]:
# Set options
embed_ver = ["clstm", "esm2", "bert", "t5"]
comb_ver = ["emb_gen-clstm", "emb_gen-esm2", "emb_gen-bert", "emb_gen-t5"]

# model options
data_path = "../data/"
model_path = f"../model/emb_gen/"
result_path = f"../result/prd-emb_gen/"
os.makedirs(result_path, exist_ok=True)

col_str = ['file_id', 'organism', 'locus_tag', 'ess']

batch_size = 256

In [4]:
# Set data list for test dataset
ts_data = {
    "data1": ["C018"],  # "Escherichia coli K-12 BW25113"
    "data2": ["C016"],  # "Escherichia coli K-12 MG1655"
    "data3": ["O046"],  # "synthetic bacterium JCVI-Syn3A"
    "data4": ["C048"],  # Bacteroides thetaiotaomicron VPI-5482
    "data5": ["C050"]  # Salmonella enterica subsp. enterica serovar Typhimurium str. 14028S
}

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# Define function to record perfomance result
def record_perform(emb_ver, file_id, organ, y_real, y_conf, y_prd):
    y_real = y_real.cpu().numpy()
    y_conf = y_conf.cpu().numpy()
    y_prd = y_prd.cpu().numpy()
    
    if file_id != "O046":
        auc_roc = [roc_auc_score(y_real, y_conf)]
        auc_pr = [average_precision_score(y_real, y_conf)]
    else:
        auc_roc = None
        auc_pr = None
    
    tn, fp, fn, tp = confusion_matrix(y_real, y_prd).ravel()
    
    result = pd.DataFrame({
        "emb": [emb_ver],
        "file": [file_id],
        "organism": [organ],
        "tp": [tp],
        "fp": [fp],
        "tn": [tn],
        "fn": [fn],
        "mcc": [matthews_corrcoef(y_real, y_prd)],
        "acc": [accuracy_score(y_real, y_prd)],
        "f1": [f1_score(y_real, y_prd)],
        "prc": [precision_score(y_real, y_prd)],
        "rec": [recall_score(y_real, y_prd)],
        "npv": [precision_score(1 - y_real, 1 - y_prd)],
        "tnr": [recall_score(1 - y_real, 1 - y_prd)],
        "auc-roc": auc_roc,
        "auc-pr": auc_pr
    })

    return result


In [3]:
# Set model architecture
class Classifier(nn.Module):
    def __init__(self, input_size, num_layers, unit_decrease):
        super(Classifier, self).__init__()
        layers = [nn.BatchNorm1d(input_size), nn.Dropout(0.5)]
        in_dim = input_size
        out_dim = 1024
        for i in range(num_layers):            
            out_dim = max(2, out_dim // unit_decrease)
            layers.append(nn.Linear(in_dim, out_dim))
            self.initialize_weights(layers[-1])
            layers.append(nn.GELU())
            in_dim = out_dim
        layers.append(nn.Linear(out_dim, 1))
        self.cls_block = nn.Sequential(*layers)
        
    def initialize_weights(self, layer):
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='linear')
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
    
    def forward(self, x):
        return self.cls_block(x)

#### Evaluate model

In [8]:
df_eval = pd.DataFrame()

for e_ver in embed_ver:
    ver = f'emb_gen-{e_ver}'
    print(f"\n>>>> {ver} <<<<")
    
    if ver in comb_ver:            
        # load dataset
        data = pd.read_csv(data_path + f"data-{ver}-ts.csv")
        display("Raw data:", data)

        #### Preprocess for test dataset ####
        col_num = [col for col in data.columns if col not in col_str]

        # get test datasets
        loc_ts = {}
        data_ts = {}
        org_ts = {}
        for ts_ver, ids in ts_data.items():
            # get test sample locations
            loc_ts[ts_ver] = data['file_id'].isin(ids)
            # get test samples
            data_ts[ts_ver] = data[loc_ts[ts_ver]]
            org = []
            # get test organism list
            for i in ids:
                organ = data_ts[ts_ver]['organism'][data_ts[ts_ver]['file_id'] == i].to_list()
                if len(organ) > 0:
                    org.append(organ[0])
            org_ts[ts_ver] = org

            print("Test dataset(" + ts_ver + "):", data_ts[ts_ver].shape)
        print("Test organism:", org_ts, len(org_ts))

        # split info.& inputs & labels of the test datasets
        info_ts = {}
        y_ts = {}
        test_loader = {}
        for ts_ver, df in data_ts.items():
            info_ts[ts_ver] = df[col_str]
            X_ts = torch.tensor(df[col_num].astype('float32').values)
            y_ts[ts_ver] = torch.tensor(df['ess'].astype('float32').values)
            print("Splited test dataset(" + ts_ver + "):", X_ts.shape, y_ts[ts_ver].shape)                    
            # generate dataloader by the test datasets
            dataset_ts = TensorDataset(X_ts, y_ts[ts_ver])
            test_loader[ts_ver] = DataLoader(dataset_ts, batch_size=batch_size, shuffle=False)

        # get the total test dataset
        loc_ts_all = [sum(loc) >= 1 for loc in zip(*loc_ts.values())]
        info_ts_all = data.loc[loc_ts_all, col_str]
        X_ts_all = torch.tensor(data.loc[loc_ts_all, col_num].astype('float32').values)
        y_ts_all = torch.tensor(data.loc[loc_ts_all, 'ess'].astype('float32').values)
        
        print("Splited test dataset(all):", X_ts_all.shape, y_ts_all.shape)

        # generate dataloader of total test dataset
        test_all_dataset = TensorDataset(X_ts_all, y_ts_all)
        test_all_loader = DataLoader(test_all_dataset, batch_size=batch_size, shuffle=False)
        

        #### Evaluate model ####
        # set model name
        model_name = f"{ver}-{fcnn_ver}"
        print(f"\n===== Test model: {model_name} ====")
        # generate model instance
        model = Classifier(
            input_size=X_ts_all.shape[-1],
            num_layers=3,
            unit_decrease=2
        ).to(device)

        # load model weight
        model.load_state_dict(torch.load(model_path + model_name + ".pt", map_location=device))
        model.eval()


        ## model evaluations by test dataset ##
        for ts_ver, ids in ts_data.items():
            all_preds = []
            all_labels = []
            with torch.no_grad():
                for X_batch, y_batch in test_loader[ts_ver]:
                    X_batch = X_batch.to(device)
                    y_batch = y_batch.to(device)
                    preds = model(X_batch).squeeze()
                    all_preds.append(preds.cpu())
                    all_labels.append(y_batch.cpu())
            
            # concatenate results to one tensor
            all_preds = torch.cat(all_preds)
            all_labels = torch.cat(all_labels)

            # convert logits to confidences & classes
            prd_conf = torch.sigmoid(all_preds)
            prd_cls = (prd_conf >= 0.5).int()
            # performances by testset
            perform = record_perform(
                emb_ver=e_ver,
                file_id="+".join(ids),
                organ="+".join(org_ts[ts_ver]),
                y_real=y_ts[ts_ver],
                y_conf=prd_conf,
                y_prd=prd_cls,
            )
            display(perform)
            df_eval = pd.concat([df_eval, perform], ignore_index=True)
        
        
        ## model evaluation on the total test dataset ##
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for X_batch, y_batch in test_all_loader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                preds = model(X_batch).squeeze()
                all_preds.append(preds.cpu())
                all_labels.append(y_batch.cpu())
        
        # concatenate results to one tensor
        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)

        # convert logits to confidences & classes
        prd_conf = torch.sigmoid(all_preds)
        prd_cls = (prd_conf >= 0.5).int()

        # performances on total testset
        perform = record_perform(
            emb_ver=e_ver,
            file_id="+".join([i for ids in ts_data.values() for i in ids]),
            organ="+".join([org for orgs in org_ts.values() for org in orgs]),
            y_real=y_ts_all,
            y_conf=prd_conf,
            y_prd=prd_cls
        )
        display(perform)
        df_eval = pd.concat([df_eval, perform], ignore_index=True)

        # concatenate the protein info. & predicted confidences
        df_prd = pd.DataFrame(prd_conf, columns=["conf"], index=info_ts_all.index)
        df_prd = pd.concat([info_ts_all, df_prd], axis=1)

        # save the model prediction result
        df_prd.to_csv(result_path + f"prd-{model_name}.csv", index=False)

        gc.collect()

# save the model perfomance result
display("Model performance:", df_eval)
df_eval.to_csv(f"../result/eval-emb_gen.csv", index=False)


>>>> emb_gen-clstm <<<<


'Raw data:'

Unnamed: 0,file_id,organism,locus_tag,ess,0,1,2,3,4,5,...,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023
0,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0001,0,-1.271414,0.000099,0.076290,-0.000540,-0.002149,-0.002770,...,-0.000128,0.004823,0.000413,0.127032,0.006706,0.013683,-0.020618,-0.049667,0.108772,-0.003629
1,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0002,0,0.080849,0.000191,0.066948,-0.001360,-0.001139,-0.002485,...,-0.000043,0.001776,0.001967,0.126776,0.007055,0.008662,0.113034,-0.035381,0.105291,0.002070
2,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0003,0,0.049742,0.000037,0.063013,-0.003911,-0.002671,-0.000340,...,-0.000037,0.001271,0.002466,0.126992,0.007107,0.007993,0.119393,-0.034128,0.103778,0.002723
3,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0004,0,0.438240,0.000130,0.061992,-0.002766,-0.001883,-0.001800,...,-0.000041,0.001173,0.001617,0.127110,0.007114,0.008159,0.119812,-0.033128,0.105875,0.002058
4,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0005,0,0.647014,0.000091,0.058150,-0.003637,-0.002623,-0.002049,...,-0.000028,0.000526,0.001174,0.127482,0.007209,0.007493,0.129736,-0.029263,0.105677,0.004036
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19378,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0913,1,0.244533,0.000047,0.060919,-0.003986,-0.002723,-0.000765,...,-0.000052,0.001888,0.002917,0.127124,0.007069,0.007943,0.102165,-0.032948,0.106551,0.001652
19379,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0918,1,0.446268,0.000114,0.060191,-0.002962,-0.002494,-0.002787,...,-0.000028,0.001444,0.000616,0.126796,0.007066,0.008400,0.101383,-0.035993,0.105889,0.002418
19380,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0930,1,-1.313249,0.000093,0.077817,-0.000610,-0.002145,-0.002522,...,-0.000083,0.002701,0.002104,0.127490,0.007041,0.009210,0.034684,-0.041377,0.106090,-0.001276
19381,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0931,1,0.339057,0.000057,0.060354,-0.003875,-0.002801,-0.001344,...,-0.000051,0.000935,0.001868,0.126976,0.007136,0.007776,0.122922,-0.030382,0.105050,0.003196


Test dataset(data1): (4313, 1028)
Test dataset(data2): (4313, 1028)
Test dataset(data3): (458, 1028)
Test dataset(data4): (4825, 1028)
Test dataset(data5): (5474, 1028)
Test organism: {'data1': ['Escherichia coli K-12 BW25113'], 'data2': ['Escherichia coli K-12 MG1655'], 'data3': ['synthetic bacterium JCVI-Syn3A'], 'data4': ['Bacteroides thetaiotaomicron VPI-5482'], 'data5': ['Salmonella enterica subsp. enterica serovar Typhimurium str. 14028S']} 5
Splited test dataset(data1): torch.Size([4313, 1024]) torch.Size([4313])
Splited test dataset(data2): torch.Size([4313, 1024]) torch.Size([4313])
Splited test dataset(data3): torch.Size([458, 1024]) torch.Size([458])
Splited test dataset(data4): torch.Size([4825, 1024]) torch.Size([4825])
Splited test dataset(data5): torch.Size([5474, 1024]) torch.Size([5474])
Splited test dataset(all): torch.Size([19383, 1024]) torch.Size([19383])

===== Test model: emb_gen-clstm-3_2 ====


  model.load_state_dict(torch.load(model_path + model_name + ".pt", map_location=device))


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,clstm,3_2,data1,C018,Escherichia coli K-12 BW25113,233,229,3786,65,0.594411,0.931834,0.613158,0.504329,0.781879,0.983121,0.942964,0.910965,0.586572


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,clstm,3_2,data2,C016,Escherichia coli K-12 MG1655,241,240,3783,49,0.613699,0.932993,0.625162,0.50104,0.831034,0.987213,0.940343,0.947629,0.632228


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,clstm,3_2,data3,O046,synthetic bacterium JCVI-Syn3A,112,0,0,346,0.0,0.244541,0.392982,1.0,0.244541,0.0,0.0,,


  df_eval = pd.concat([df_eval, perform], ignore_index=True)


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,clstm,3_2,data4,C048,Bacteroides thetaiotaomicron VPI-5482,282,107,4393,43,0.776919,0.968912,0.789916,0.724936,0.867692,0.990307,0.976222,0.975421,0.818815


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,clstm,3_2,data5,C050,Salmonella enterica subsp. enterica serovar Ty...,78,493,4868,35,0.278314,0.903544,0.22807,0.136602,0.690265,0.992862,0.90804,0.891886,0.107095


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,clstm,3_2,test_all,C018+C016+O046+C048+C050,Escherichia coli K-12 BW25113+Escherichia coli...,946,1069,16830,538,0.50333,0.917092,0.540726,0.469479,0.637466,0.969023,0.940276,0.883538,0.461305



>>>> emb_gen-esm2 <<<<


'Raw data:'

Unnamed: 0,file_id,organism,locus_tag,ess,0,1,2,3,4,5,...,1270,1271,1272,1273,1274,1275,1276,1277,1278,1279
0,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0001,0,-0.025803,2.423989,-1.412047,-0.427782,2.635565,-0.527033,...,-0.155313,1.499200,-1.622623,1.140054,-0.579188,-0.201200,-0.181841,2.446096,1.391733,-0.146903
1,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0002,0,-0.668316,-0.463429,0.103198,-0.088451,0.004220,0.314487,...,-0.441345,-0.216852,-0.173726,0.456722,-0.253636,0.488008,-0.237100,0.032918,-0.110359,0.293158
2,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0003,0,-0.389958,-0.056299,0.404185,-0.130026,0.101402,0.399542,...,-0.291001,-0.496596,-0.240145,0.931340,0.250903,0.500919,-0.206658,0.055057,0.349885,0.051546
3,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0004,0,-0.132498,-0.263630,0.180769,-0.201304,-0.002199,0.342121,...,0.079341,-0.345971,-0.169741,0.699338,-0.327057,0.669865,-0.155763,0.371676,-0.280873,-0.052240
4,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0005,0,-0.286661,-0.111691,0.556191,-0.063293,-0.274769,0.401593,...,0.157999,-0.030111,0.223549,1.058194,-0.627709,-0.035077,0.126186,0.007433,-0.104928,-0.069931
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19378,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0913,1,-0.013360,-0.466237,-0.574799,0.004326,-0.280262,0.498790,...,-0.052086,-0.375514,-0.349227,1.061728,-0.046756,0.376358,0.299031,0.954350,1.110562,0.166870
19379,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0918,1,-0.644403,-0.233777,0.404386,0.101492,-0.407517,0.367913,...,-0.216774,-0.193204,-0.085630,0.603998,0.047384,0.273780,0.067045,-0.170716,0.290532,0.482953
19380,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0930,1,0.311929,-0.638465,0.265130,0.334715,1.060847,0.648867,...,-0.729781,-0.557699,-0.633800,0.923602,-1.082554,0.885507,-1.106922,-1.015008,-0.201546,-0.682305
19381,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0931,1,-0.985561,0.844931,-0.040269,-0.267429,-0.631806,-0.615640,...,0.835866,-0.616349,-0.738524,-0.103508,0.191622,0.856668,-0.864706,0.502746,0.409977,0.147629


Test dataset(data1): (4313, 1284)
Test dataset(data2): (4313, 1284)
Test dataset(data3): (458, 1284)
Test dataset(data4): (4825, 1284)
Test dataset(data5): (5474, 1284)
Test organism: {'data1': ['Escherichia coli K-12 BW25113'], 'data2': ['Escherichia coli K-12 MG1655'], 'data3': ['synthetic bacterium JCVI-Syn3A'], 'data4': ['Bacteroides thetaiotaomicron VPI-5482'], 'data5': ['Salmonella enterica subsp. enterica serovar Typhimurium str. 14028S']} 5
Splited test dataset(data1): torch.Size([4313, 1280]) torch.Size([4313])
Splited test dataset(data2): torch.Size([4313, 1280]) torch.Size([4313])
Splited test dataset(data3): torch.Size([458, 1280]) torch.Size([458])
Splited test dataset(data4): torch.Size([4825, 1280]) torch.Size([4825])
Splited test dataset(data5): torch.Size([5474, 1280]) torch.Size([5474])
Splited test dataset(all): torch.Size([19383, 1280]) torch.Size([19383])

===== Test model: emb_gen-esm2-3_2 ====


  model.load_state_dict(torch.load(model_path + model_name + ".pt", map_location=device))


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,esm2,3_2,data1,C018,Escherichia coli K-12 BW25113,269,519,3496,29,0.507602,0.872942,0.495396,0.341371,0.902685,0.991773,0.870735,0.942126,0.701041


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,esm2,3_2,data2,C016,Escherichia coli K-12 MG1655,280,517,3506,10,0.540069,0.877811,0.515179,0.351317,0.965517,0.997156,0.871489,0.9755,0.784774


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,esm2,3_2,data3,O046,synthetic bacterium JCVI-Syn3A,259,0,0,199,0.0,0.565502,0.722455,1.0,0.565502,0.0,0.0,,


  df_eval = pd.concat([df_eval, perform], ignore_index=True)


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,esm2,3_2,data4,C048,Bacteroides thetaiotaomicron VPI-5482,292,622,3878,33,0.486274,0.864249,0.471348,0.319475,0.898462,0.991562,0.861778,0.937444,0.613368


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,esm2,3_2,data5,C050,Salmonella enterica subsp. enterica serovar Ty...,65,688,4673,48,0.184478,0.865546,0.150115,0.086321,0.575221,0.989833,0.871666,0.832666,0.062933


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,esm2,3_2,test_all,C018+C016+O046+C048+C050,Escherichia coli K-12 BW25113+Escherichia coli...,1165,2346,15553,319,0.451501,0.862508,0.466466,0.331814,0.78504,0.979902,0.868931,0.894743,0.48336



>>>> emb_gen-bert <<<<


'Raw data:'

Unnamed: 0,file_id,organism,locus_tag,ess,0,1,2,3,4,5,...,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023
0,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0001,0,0.042822,-0.126205,-0.050820,-0.129861,0.068868,-0.006040,...,-0.013456,-0.048067,-0.110174,-0.008105,0.018130,-0.105259,-0.008767,0.005277,0.042083,0.034597
1,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0002,0,-0.058692,-0.037950,-0.004752,0.014216,-0.023279,0.035774,...,0.007268,-0.119998,0.003547,-0.041535,-0.139431,0.001130,-0.025769,-0.031517,0.014989,-0.006627
2,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0003,0,-0.070625,-0.035552,0.038030,0.001899,-0.000246,0.004621,...,-0.001732,-0.167638,-0.040140,-0.074581,-0.179380,-0.021140,-0.023059,-0.046827,0.031335,-0.032964
3,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0004,0,-0.037140,-0.036069,0.029161,0.002461,-0.011851,0.017077,...,0.032899,-0.134339,0.000210,-0.057801,-0.168538,-0.000947,-0.042856,-0.030568,0.011087,-0.040674
4,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0005,0,-0.032631,-0.024123,-0.018862,0.039874,-0.006922,-0.055323,...,0.046949,-0.040834,-0.021078,-0.029033,-0.105917,0.012200,0.013685,-0.042306,-0.009481,-0.002596
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19378,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0913,1,-0.036090,-0.011721,-0.002769,0.041671,0.005540,0.003676,...,-0.005665,-0.128913,0.003729,-0.027293,-0.126274,-0.013166,-0.049804,-0.017025,-0.028457,-0.000095
19379,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0918,1,-0.023218,-0.008702,0.014693,0.022481,0.036286,0.010613,...,0.056558,-0.047529,-0.022225,-0.060193,-0.078188,-0.036799,-0.024342,-0.014599,0.013637,-0.009223
19380,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0930,1,-0.009309,0.004102,0.003444,0.031261,0.026689,0.051066,...,0.051691,-0.034478,0.004822,0.016618,0.070800,0.009008,-0.033036,-0.031658,-0.041180,-0.027903
19381,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0931,1,0.007387,-0.008439,0.020698,0.009457,0.050447,-0.007523,...,0.059217,-0.071523,0.001269,-0.009234,-0.043703,0.003266,-0.027528,-0.026730,-0.012166,0.000853


Test dataset(data1): (4313, 1028)
Test dataset(data2): (4313, 1028)
Test dataset(data3): (458, 1028)
Test dataset(data4): (4825, 1028)
Test dataset(data5): (5474, 1028)
Test organism: {'data1': ['Escherichia coli K-12 BW25113'], 'data2': ['Escherichia coli K-12 MG1655'], 'data3': ['synthetic bacterium JCVI-Syn3A'], 'data4': ['Bacteroides thetaiotaomicron VPI-5482'], 'data5': ['Salmonella enterica subsp. enterica serovar Typhimurium str. 14028S']} 5
Splited test dataset(data1): torch.Size([4313, 1024]) torch.Size([4313])
Splited test dataset(data2): torch.Size([4313, 1024]) torch.Size([4313])
Splited test dataset(data3): torch.Size([458, 1024]) torch.Size([458])
Splited test dataset(data4): torch.Size([4825, 1024]) torch.Size([4825])
Splited test dataset(data5): torch.Size([5474, 1024]) torch.Size([5474])
Splited test dataset(all): torch.Size([19383, 1024]) torch.Size([19383])

===== Test model: emb_gen-bert-3_2 ====


  model.load_state_dict(torch.load(model_path + model_name + ".pt", map_location=device))


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,bert,3_2,data1,C018,Escherichia coli K-12 BW25113,270,430,3585,28,0.549519,0.893809,0.541082,0.385714,0.90604,0.99225,0.892902,0.950126,0.758463


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,bert,3_2,data2,C016,Escherichia coli K-12 MG1655,282,440,3583,8,0.578935,0.896128,0.557312,0.390582,0.972414,0.997772,0.890629,0.980791,0.839553


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,bert,3_2,data3,O046,synthetic bacterium JCVI-Syn3A,283,0,0,175,0.0,0.617904,0.763833,1.0,0.617904,0.0,0.0,,


  df_eval = pd.concat([df_eval, perform], ignore_index=True)


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,bert,3_2,data4,C048,Bacteroides thetaiotaomicron VPI-5482,219,617,3883,106,0.355445,0.850155,0.377261,0.261962,0.673846,0.973427,0.862889,0.851592,0.484393


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,bert,3_2,data5,C050,Salmonella enterica subsp. enterica serovar Ty...,39,730,4631,74,0.085505,0.853124,0.088435,0.050715,0.345133,0.984272,0.863831,0.717901,0.042231


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,bert,3_2,test_all,C018+C016+O046+C048+C050,Escherichia coli K-12 BW25113+Escherichia coli...,1093,2217,15682,391,0.432902,0.865449,0.455987,0.330211,0.736523,0.975673,0.876138,0.880059,0.490736



>>>> emb_gen-t5 <<<<


'Raw data:'

Unnamed: 0,file_id,organism,locus_tag,ess,0,1,2,3,4,5,...,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023
0,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0001,0,0.018239,-0.030508,0.078645,0.082283,0.127463,-0.197618,...,0.027051,0.021510,0.231686,-0.173741,0.079085,-0.236435,-0.114395,0.054386,-0.030780,-0.075774
1,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0002,0,0.043340,0.078926,0.026737,0.050184,0.004291,0.038041,...,-0.023108,-0.003255,-0.010074,-0.010509,0.021057,-0.041092,-0.026289,0.020529,-0.029514,0.032310
2,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0003,0,0.083051,0.047820,0.007012,0.029369,-0.004415,0.054288,...,-0.021937,-0.035736,0.000474,-0.025479,0.003724,-0.044224,-0.078591,-0.024167,-0.032960,-0.003994
3,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0004,0,0.048851,0.071860,0.020825,0.042091,-0.002787,0.009612,...,0.005614,-0.032511,-0.037446,-0.013102,0.030071,0.009861,-0.022703,-0.002080,-0.011043,0.036730
4,C050,Salmonella enterica subsp. enterica serovar Ty...,STM14_0005,0,0.035259,0.042339,0.024604,-0.025343,-0.006228,0.043092,...,-0.020996,0.023913,-0.012189,-0.067514,0.045659,0.010079,-0.014732,-0.013761,-0.029590,0.063286
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19378,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0913,1,0.020563,0.107576,0.027025,0.052937,-0.012762,0.021742,...,0.013594,0.028046,0.010028,-0.060973,0.093713,-0.003019,-0.053488,-0.045007,-0.043997,-0.002393
19379,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0918,1,0.057506,-0.014071,-0.033472,0.067503,-0.016533,0.091051,...,-0.048624,0.031121,0.011467,-0.048133,-0.033131,-0.060800,-0.055844,-0.098644,-0.041724,0.014361
19380,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0930,1,0.003938,-0.011171,0.027588,0.017686,-0.010909,0.041780,...,-0.008057,0.051703,0.032014,-0.148301,0.106521,-0.029680,-0.032345,-0.021705,0.096927,0.099435
19381,O046,synthetic bacterium JCVI-Syn3A,JCVISYN3A_0931,1,0.064550,0.037385,-0.023904,-0.001969,-0.006216,0.048084,...,-0.037994,-0.018542,-0.014736,-0.102566,0.023649,-0.022915,-0.034404,-0.085372,0.044055,0.009911


Test dataset(data1): (4313, 1028)
Test dataset(data2): (4313, 1028)
Test dataset(data3): (458, 1028)
Test dataset(data4): (4825, 1028)
Test dataset(data5): (5474, 1028)
Test organism: {'data1': ['Escherichia coli K-12 BW25113'], 'data2': ['Escherichia coli K-12 MG1655'], 'data3': ['synthetic bacterium JCVI-Syn3A'], 'data4': ['Bacteroides thetaiotaomicron VPI-5482'], 'data5': ['Salmonella enterica subsp. enterica serovar Typhimurium str. 14028S']} 5
Splited test dataset(data1): torch.Size([4313, 1024]) torch.Size([4313])
Splited test dataset(data2): torch.Size([4313, 1024]) torch.Size([4313])
Splited test dataset(data3): torch.Size([458, 1024]) torch.Size([458])
Splited test dataset(data4): torch.Size([4825, 1024]) torch.Size([4825])
Splited test dataset(data5): torch.Size([5474, 1024]) torch.Size([5474])
Splited test dataset(all): torch.Size([19383, 1024]) torch.Size([19383])

===== Test model: emb_gen-t5-3_2 ====


  model.load_state_dict(torch.load(model_path + model_name + ".pt", map_location=device))


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,t5,3_2,data1,C018,Escherichia coli K-12 BW25113,269,372,3643,29,0.577528,0.907025,0.57295,0.419657,0.902685,0.992102,0.907347,0.94759,0.696496


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,t5,3_2,data2,C016,Escherichia coli K-12 MG1655,283,369,3654,7,0.618118,0.912822,0.600849,0.434049,0.975862,0.998088,0.908277,0.981808,0.825065


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,t5,3_2,data3,O046,synthetic bacterium JCVI-Syn3A,305,0,0,153,0.0,0.665939,0.799476,1.0,0.665939,0.0,0.0,,


  df_eval = pd.concat([df_eval, perform], ignore_index=True)


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,t5,3_2,data4,C048,Bacteroides thetaiotaomicron VPI-5482,257,529,3971,68,0.456933,0.876269,0.462646,0.326972,0.790769,0.983164,0.882444,0.918049,0.616563


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,t5,3_2,data5,C050,Salmonella enterica subsp. enterica serovar Ty...,36,651,4710,77,0.084616,0.867008,0.09,0.052402,0.318584,0.983915,0.878567,0.781696,0.047115


Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,t5,3_2,test_all,C018+C016+O046+C048+C050,Escherichia coli K-12 BW25113+Escherichia coli...,1150,1921,15978,334,0.486138,0.883661,0.50494,0.374471,0.774933,0.979524,0.892676,0.900482,0.525443


'Model performance:'

Unnamed: 0,emb,fcnn,testset,file,organism,tp,fp,tn,fn,mcc,acc,f1,prc,rec,npv,tnr,auc-roc,auc-pr
0,clstm,3_2,data1,C018,Escherichia coli K-12 BW25113,233,229,3786,65,0.594411,0.931834,0.613158,0.504329,0.781879,0.983121,0.942964,0.910965,0.586572
1,clstm,3_2,data2,C016,Escherichia coli K-12 MG1655,241,240,3783,49,0.613699,0.932993,0.625162,0.50104,0.831034,0.987213,0.940343,0.947629,0.632228
2,clstm,3_2,data3,O046,synthetic bacterium JCVI-Syn3A,112,0,0,346,0.0,0.244541,0.392982,1.0,0.244541,0.0,0.0,,
3,clstm,3_2,data4,C048,Bacteroides thetaiotaomicron VPI-5482,282,107,4393,43,0.776919,0.968912,0.789916,0.724936,0.867692,0.990307,0.976222,0.975421,0.818815
4,clstm,3_2,data5,C050,Salmonella enterica subsp. enterica serovar Ty...,78,493,4868,35,0.278314,0.903544,0.22807,0.136602,0.690265,0.992862,0.90804,0.891886,0.107095
5,clstm,3_2,test_all,C018+C016+O046+C048+C050,Escherichia coli K-12 BW25113+Escherichia coli...,946,1069,16830,538,0.50333,0.917092,0.540726,0.469479,0.637466,0.969023,0.940276,0.883538,0.461305
6,esm2,3_2,data1,C018,Escherichia coli K-12 BW25113,269,519,3496,29,0.507602,0.872942,0.495396,0.341371,0.902685,0.991773,0.870735,0.942126,0.701041
7,esm2,3_2,data2,C016,Escherichia coli K-12 MG1655,280,517,3506,10,0.540069,0.877811,0.515179,0.351317,0.965517,0.997156,0.871489,0.9755,0.784774
8,esm2,3_2,data3,O046,synthetic bacterium JCVI-Syn3A,259,0,0,199,0.0,0.565502,0.722455,1.0,0.565502,0.0,0.0,,
9,esm2,3_2,data4,C048,Bacteroides thetaiotaomicron VPI-5482,292,622,3878,33,0.486274,0.864249,0.471348,0.319475,0.898462,0.991562,0.861778,0.937444,0.613368
