#### Setting up the environment

In [1]:
import gc
import itertools
import os

import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score,\
f1_score, precision_score, recall_score, roc_auc_score, average_precision_score
from torch.utils.data import TensorDataset, DataLoader

In [None]:
# Options: input/output locations first, then the embedding types whose
# names are used to build data-file and model names further down.
data_path, model_path, result_path = (
    "../data/test_exam",
    "../models/classifier_ensem/",
    "../result/",
)
embed_ver = [
    "clstm",
    "esm2",
    "bert",
    "t5",
]

In [None]:
# Non-feature (info/label) columns shared by every embedding table
col_str = [
    'file_id',
    'organism',
    'locus_tag',
    'ess',
]
# Classifier shape (number of hidden layers, width divisor) and inference batch size
layer_num, unit_decrease = 3, 2
batch_size = 256

In [3]:
# Test sets: "data<N>" -> list of file ids, numbered in the order listed below
ts_data = {
    f"data{idx}": [file_id]
    for idx, file_id in enumerate(
        (
            "C018",  # "Escherichia coli K-12 BW25113"
            "C016",  # "Escherichia coli K-12 MG1655"
            "O046",  # "synthetic bacterium JCVI-Syn3A"
            "C048",  # Bacteroides thetaiotaomicron VPI-5482
            "C050",  # Salmonella enterica subsp. enterica serovar Typhimurium str. 14028S
        ),
        start=1,
    )
}

In [4]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [5]:
# Define function to record performance result
def record_perform(comb_ver, file_id, organ, y_real, y_conf, y_prd, test_ver=None):
    """Build a one-row DataFrame of binary-classification metrics.

    Args:
        comb_ver: label stored in the "comb" column.
        file_id: label stored in the "file" column. When it equals "O046" the
            two AUC columns are left empty (the reason is not stated in this
            notebook).
        organ: label stored in the "organism" column.
        y_real: true binary labels (0/1) as a tensor, array or sequence.
        y_conf: predicted scores, used for the ranking metrics (AUCs).
        y_prd: predicted binary classes (0/1).
        test_ver: optional label; when given it is stored in a "test" column
            placed right after "comb". Omitted columns keep the original layout.

    Returns:
        pandas.DataFrame with a single row holding the confusion-matrix counts
        (tp, fp, tn, fn), mcc, acc, f1, prc, rec, npv, tnr, auc-roc and auc-pr.
        The AUC cells are None when they are skipped.
    """
    def _to_numpy(values):
        # accept tensors (on any device), numpy arrays and plain sequences alike
        return torch.as_tensor(values).detach().cpu().numpy()

    # integer labels keep sklearn's label handling consistent between y_real and y_prd
    y_real = _to_numpy(y_real).astype(int)
    y_conf = _to_numpy(y_conf)
    y_prd = _to_numpy(y_prd).astype(int)

    # roc_auc_score raises when y_real holds a single class, so skip the AUCs then
    has_both_classes = y_real.size > 0 and y_real.min() != y_real.max()
    if file_id != "O046" and has_both_classes:
        auc_roc = roc_auc_score(y_real, y_conf)
        auc_pr = average_precision_score(y_real, y_conf)
    else:
        auc_roc = None
        auc_pr = None

    # pin the label set so the matrix is always 2x2, even if one class is absent
    tn, fp, fn, tp = confusion_matrix(y_real, y_prd, labels=[0, 1]).ravel()

    row = {"comb": comb_ver}
    if test_ver is not None:
        row["test"] = test_ver
    row.update({
        "file": file_id,
        "organism": organ,
        "tp": tp,
        "fp": fp,
        "tn": tn,
        "fn": fn,
        "mcc": matthews_corrcoef(y_real, y_prd),
        "acc": accuracy_score(y_real, y_prd),
        "f1": f1_score(y_real, y_prd),
        "prc": precision_score(y_real, y_prd),
        "rec": recall_score(y_real, y_prd),
        # negative-class precision/recall: flip the labels and reuse the same scorers
        "npv": precision_score(1 - y_real, 1 - y_prd),
        "tnr": recall_score(1 - y_real, 1 - y_prd),
        "auc-roc": auc_roc,
        "auc-pr": auc_pr,
    })

    return pd.DataFrame([row])



In [6]:
# Set model architecture
class Classifier(nn.Module):
    """Feed-forward binary classifier that outputs one logit per sample.

    Layout, stored in ``cls_block`` as a single ``nn.Sequential``:
    BatchNorm1d(input_size) -> Dropout(0.5) -> num_layers x [Linear -> GELU]
    -> Linear(last_width, 1).

    Hidden widths start from 1024 and are integer-divided by ``unit_decrease``
    before each hidden layer, never dropping below 2.

    Args:
        input_size: number of input features per sample.
        num_layers: number of hidden Linear+GELU pairs (0 gives a single
            linear layer after normalisation/dropout).
        unit_decrease: integer divisor applied to the hidden width per layer.
    """
    def __init__(self, input_size, num_layers, unit_decrease):
        super().__init__()
        layers = [nn.BatchNorm1d(input_size), nn.Dropout(0.5)]
        in_dim = input_size
        out_dim = 1024
        for _ in range(num_layers):
            out_dim = max(2, out_dim // unit_decrease)
            hidden = nn.Linear(in_dim, out_dim)
            self.initialize_weights(hidden)
            layers.extend([hidden, nn.GELU()])
            in_dim = out_dim
        # in_dim equals the last hidden width, or input_size when there is no
        # hidden layer; using out_dim here would mis-size the head for num_layers == 0
        layers.append(nn.Linear(in_dim, 1))
        self.cls_block = nn.Sequential(*layers)

    def initialize_weights(self, layer):
        """Kaiming-normal init (fan_in, linear gain) for weights; zero the bias if present."""
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='linear')
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the raw logits of shape (batch, 1) for inputs of shape (batch, input_size)."""
        return self.cls_block(x)

#### Evaluation

In [None]:
# Load one embedding table per embedding type as (name, DataFrame) pairs.
# data_path carries no trailing separator, so join the parts explicitly
# instead of concatenating strings.
dfs = [
    (e_ver, pd.read_csv(os.path.join(data_path, f"data_emb-{e_ver}.csv")))
    for e_ver in embed_ver
]

In [None]:
# Evaluate the saved classifier of every embedding combination (2 or more
# embedding types) on each test set and on all test sets pooled together.
eval_rows = []  # one-row metric frames, concatenated once at the end

# make sure the folder for per-model predictions exists before writing into it
pred_dir = f"{result_path}prd-embed_ensem/"
os.makedirs(pred_dir, exist_ok=True)

for r in range(2, len(dfs) + 1):
    for comb in itertools.combinations(dfs, r):
        comb_ver = "_".join(name for name, _ in comb)
        print(f"\n>>>> {comb_ver} <<<<")

        # merge dataset: join the embedding tables on the shared info columns;
        # clashing feature names from later tables get a "_<embedding>" suffix
        data = comb[0][1]
        for name, frame in comb[1:]:
            data = pd.merge(data, frame, on=col_str, suffixes=("", f"_{name}"))

        # show only a preview; the merged frame can be very wide and long
        display("Raw data:", data.head())

        # every column after the info columns is a model input feature
        input_size = data.shape[1] - len(col_str)

        # get test datasets
        data_ts = {}
        org_ts = {}
        for ts_ver, ids in ts_data.items():
            subset = data[data['file_id'].isin(ids)]
            if subset.empty:
                # metrics cannot be computed on zero samples, so leave this set out
                print("Test dataset(" + ts_ver + "): no samples found, skipped.")
                continue
            data_ts[ts_ver] = subset
            # get test organism list (first organism name seen for each file id)
            org = []
            for i in ids:
                organ = subset.loc[subset['file_id'] == i, 'organism'].to_list()
                if len(organ) > 0:
                    org.append(organ[0])
            org_ts[ts_ver] = org
            print("Test dataset(" + ts_ver + "):", subset.shape)
        print("Test organism:", org_ts, len(org_ts))

        if not data_ts:
            print(f"No test samples for {comb_ver}; combination skipped.")
            continue

        # split inputs & labels of the test datasets
        y_ts = {}
        test_loader = {}
        for ts_ver, subset in data_ts.items():
            X_ts = torch.tensor(subset.iloc[:, len(col_str):].astype('float32').values)
            y_ts[ts_ver] = torch.tensor(subset['ess'].astype('float32').values)
            print("Splited test dataset(" + ts_ver + "):", X_ts.shape, y_ts[ts_ver].shape)
            # generate dataloader by the test datasets (order kept so rows align with the info columns)
            dataset_ts = TensorDataset(X_ts, y_ts[ts_ver])
            test_loader[ts_ver] = DataLoader(dataset_ts, batch_size=batch_size, shuffle=False)

        ## Test model ##
        # set model name
        model_name = f"cls-{comb_ver}"
        print(f"\n===== Test model: {model_name} ====")
        # generate model instance
        model = Classifier(
            input_size=input_size,
            num_layers=layer_num,
            unit_decrease=unit_decrease
        ).to(device)

        # load model weight
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source
        model.load_state_dict(torch.load(model_path + model_name + ".pt", map_location=device))
        model.eval()

        # model evaluations by test dataset
        pred_frames = []
        all_logits = []
        all_labels = []

        for ts_ver, loader in test_loader.items():
            ids = ts_data[ts_ver]
            batch_logits = []
            with torch.no_grad():
                for X_batch, _ in loader:
                    # reshape(-1) keeps a 1-D tensor even when a batch holds a single sample
                    batch_logits.append(model(X_batch.to(device)).reshape(-1).cpu())
            logits = torch.cat(batch_logits)

            gc.collect()

            # convert logits to confidences & classes
            prd_conf = torch.sigmoid(logits)
            prd_cls = (prd_conf >= 0.5).int()

            # gather result of the predicted essentiality
            pred_ts = data_ts[ts_ver][col_str].reset_index(drop=True).assign(conf=prd_conf.tolist())
            pred_frames.append(pred_ts)
            all_logits.append(logits)
            all_labels.append(y_ts[ts_ver])

            # get evaluation row by testset
            eval_ts = record_perform(
                comb_ver=f"{comb_ver}",
                file_id="+".join(ids),
                organ="+".join(org_ts[ts_ver]),
                y_real=y_ts[ts_ver],
                y_conf=prd_conf,
                y_prd=prd_cls,
            )
            # tag the row with the test-set key right after the "comb" column
            eval_ts.insert(1, "test", ts_ver)
            eval_rows.append(eval_ts)
            print(f"- Test in {ts_ver} was done.")

        # save the model prediction result
        df_pred = pd.concat(pred_frames, ignore_index=True)
        df_pred.to_csv(f"{pred_dir}{model_name}.csv", index=False)

        # convert pooled logits to confidences & classes
        prd_conf = torch.sigmoid(torch.cat(all_logits))
        prd_cls = (prd_conf >= 0.5).int()

        # get total row (all test sets pooled)
        eval_ts = record_perform(
            comb_ver=f"{comb_ver}",
            file_id="total",
            organ="all",
            y_real=torch.cat(all_labels),
            y_conf=prd_conf,
            y_prd=prd_cls
        )
        eval_ts.insert(1, "test", "test_all")
        eval_rows.append(eval_ts)

# save the model performance result
df_eval = pd.concat(eval_rows, ignore_index=True) if eval_rows else pd.DataFrame()
df_eval.to_csv(f"{result_path}eval-embed_ensem.csv", index=False)
display("Model performance:", df_eval)