# Ablation experiment of environment factors with XGBoost (logkcat/KM)
## (1) pH temperature
## (2) logP MW
## (3) organism

In [1]:
import pandas as pd
import os
current_dir = os.getcwd()
df_input = pd.read_pickle(f'{current_dir}/../data_process/dataset/df_all_log_transformed.pkl')
df_input.head()

Unnamed: 0,ec,organism,uniprot,substrate,smiles,sequence,type,ph,t,esm2,...,unirep,prott5,prost5,molebert,transsmiles,logkm,logkcat,logkcatkm,logp,mw
0,3.5.5.1,Saccharolobus solfataricus,P95896,trichloroacetonitrile,C(#N)C(Cl)(Cl)Cl,MGIKLPTLEDLREISKQFNLDLEDEELKSFLQLLKLQLESYERLDS...,wild,7.4,70.0,"[0.07309095, -0.085310504, 0.03223636, -0.0094...",...,"[0.004618997, 0.04178079, 0.039039593, -0.0533...","[0.0620778725, 0.0053198822, 0.026737025, -0.0...","[0.0229359791, -0.0089250933, -0.0310355425000...","[-0.0235576797, -0.1898318082, -0.005378013, 0...","[-0.1102231815, -0.2566757202, 0.2837018669, 0...",-21.416413,-4.60517,,1.88018,144.388
1,1.21.99.4,Homo sapiens,,L-thyroxine,C1=C(C=C(C(=C1I)OC2=CC(=C(C(=C2)I)O)I)I)CC(C(=...,MGLPQPGLWLKRLWVLLEVAVHVVVGKVLLILFPDRVKRNILAMGE...,mutant,7.5,37.0,"[0.054237492, -0.04574185, 0.008709021, 0.0387...",...,"[0.009288959, 0.11994403, 0.08886064, -0.00835...","[0.0137849757, 0.0193469338, 0.0360138603, 0.0...","[-0.0300230943, -0.0246387329, -0.0323411487, ...","[0.11776212600000001, 0.24018001560000002, -0....","[-0.049280483300000004, -0.2168657631, 0.30448...",-20.192638,,,4.5573,776.872
2,1.21.99.4,Homo sapiens,,L-thyroxine,C1=C(C=C(C(=C1I)OC2=CC(=C(C(=C2)I)O)I)I)CC(C(=...,MGLPQPGLWLKRLWVLLEVAVHVVVGKVLLILFPDRVKRNILAMGE...,wild,7.5,37.0,"[0.054237492, -0.04574185, 0.008709021, 0.0387...",...,"[0.009288959, 0.11994403, 0.08886064, -0.00835...","[0.0137849757, 0.0193469338, 0.0360138603, 0.0...","[-0.0300230943, -0.0246387329, -0.0323411487, ...","[0.11776212600000001, 0.24018001560000002, -0....","[-0.049280483300000004, -0.2168657631, 0.30448...",-19.658555,,,4.5573,776.872
3,3.5.5.1,Saccharolobus solfataricus,P95896,Cinnamonitrile,C1=CC=C(C=C1)C=CC#N,MGIKLPTLEDLREISKQFNLDLEDEELKSFLQLLKLQLESYERLDS...,mutant,7.4,70.0,"[0.07308519, -0.08452837, 0.029972142, -0.0119...",...,"[0.0046866684999999995, 0.04164683, 0.03819374...","[0.0622685216, 0.0037881299, 0.027366610200000...","[0.022653413900000002, -0.008009411400000001, ...","[0.1382588446, -0.2295262814, -0.0154548008, 0...","[-0.07580477000000001, -0.26762363310000004, 0...",-18.45114,-4.135167,,2.22338,129.162
4,3.5.5.1,Saccharolobus solfataricus,P95896,Malononitrile,C(C#N)C#N,MGIKLPTLEDLREISKQFNLDLEDEELKSFLQLLKLQLESYERLDS...,wild,7.4,70.0,"[0.07309095, -0.085310504, 0.03223636, -0.0094...",...,"[0.004618997, 0.04178079, 0.039039593, -0.0533...","[0.0620778725, 0.0053198822, 0.026737025, -0.0...","[0.0229359791, -0.0089250933, -0.0310355425000...","[0.4168405533, -0.2187459022, -0.1292698383, 0...","[-0.0909005329, -0.2873343229, 0.2866819799, 0...",-16.821293,-1.966113,,0.42366,66.063


In [2]:
import os.path
import json
import pandas as pd
import xgboost
import numpy as np
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from scipy.stats import pearsonr
from hyperopt import fmin, tpe, hp, Trials, space_eval
from copy import deepcopy
import torch
import torch.nn as nn


def return_scores(y_true, y_pred):
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    pcc = pearsonr(y_true, y_pred)[0]

    return rmse, mae, r2, pcc

def return_xgb_x_y(df_filtered):
    y = df_filtered[label_name].values
    mask = ~np.isnan(y)

    # factors
    auxiliary_data = []
    if use_t_ph_embedding:
        ph = df_filtered['ph'].values.reshape(-1, 1)
        t = df_filtered['t'].values.reshape(-1, 1)
        auxiliary_data.append(ph)
        auxiliary_data.append(t)

    if use_mw_logp:
        mw = df_filtered['mw'].values.reshape(-1, 1)
        logp = df_filtered['logp'].values.reshape(-1, 1)
        auxiliary_data.append(mw)
        auxiliary_data.append(logp)

    if use_organism:
        organism = df_filtered['organism'].astype('category').cat.codes.values
        organism_tensor = torch.tensor(organism, dtype=torch.long)
        embedding = nn.Embedding(organism_num_classes, organism_embedding_dim)
        organism_embedded = embedding(organism_tensor).detach().numpy()
        auxiliary_data.append(organism_embedded)

    protein_data = np.array(df_filtered[protein_column].tolist())
    substrate_data = np.array(df_filtered[substrate_column].tolist())
    x = np.hstack([protein_data, substrate_data] + auxiliary_data)

    return x[mask], y[mask]

def search_xgb(params):
    print(params)
    temp_params = deepcopy(params)
    temp_params.update({"device": "cuda", "eval_metric": ["rmse"], "sampling_method": "gradient_based"})
    num_rounds = temp_params.pop('num_rounds')

    val_scores_list = []
    for train_index, val_index in kf.split(df_train_val):
        df_train = df_train_val.iloc[train_index]
        df_val = df_train_val.iloc[val_index]

        train_x, train_y = return_xgb_x_y(df_train)
        val_x, val_y = return_xgb_x_y(df_val)

        # DMatrix
        m_train = xgboost.DMatrix(train_x, label=train_y)
        m_val = xgboost.DMatrix(val_x, label=val_y)
        eval_list = [(m_train, 'train'), (m_val, 'val')]

        # train
        model = xgboost.train(temp_params, m_train, num_rounds, evals=eval_list, verbose_eval=False, early_stopping_rounds=60)

        # val
        val_predicted = model.predict(xgboost.DMatrix(val_x))
        val_scores = return_scores(val_y, val_predicted)
        val_scores_list.append(val_scores)

    val_scores_mean = np.mean(val_scores_list, axis=0)
    print(f"[Val_mean] rmse {val_scores_mean[0]:.3f} mae {val_scores_mean[1]:.3f} r2 {val_scores_mean[2]:.3f} pcc {val_scores_mean[3]:.3f}")

    return val_scores_mean[0]

def search_best_param(max_evals):
    space = {
        "learning_rate": hp.uniform("learning_rate", 0.02, 0.1),
        "max_depth": hp.randint("max_depth", 6, 9),
        "reg_lambda": hp.uniform("reg_lambda", 0, 3),
        "reg_alpha": hp.uniform("reg_alpha", 0, 3),
        "max_delta_step": hp.uniform("max_delta_step", 0, 4),
        "min_child_weight": hp.uniform("min_child_weight", 10, 15),
        "num_rounds": hp.randint("num_rounds", 1500, 3000),
        "subsample": hp.uniform("subsample", 0, 1),
        "eta": hp.uniform("eta", 0.01, 0.2),
    }

    trials = Trials()
    print(f'[Info] Starting parameter search...')
    best_params = fmin(fn=search_xgb, space=space, algo=tpe.suggest, max_evals=max_evals, trials=trials)

    best_params['max_depth'] = int(best_params['max_depth'])
    best_params['num_rounds'] = int(best_params['num_rounds'])

    best_params = space_eval(space, best_params)

    # to json
    with open(params_json_path, 'w') as json_file:
        json.dump(best_params, json_file)

    return best_params

def return_path_name():
    name_parts = []
    if use_t_ph_embedding:
        name_parts.append("t_ph")
    if use_mw_logp:
        name_parts.append("mw_logp")
    if use_organism:
        name_parts.append("organism")

    return "_".join(name_parts) if name_parts else "default"


# main
protein_column = 'prott5'
substrate_column = 'molebert'
label_name = 'logkcatkm'
random_state = 66
search_max_evals = 60

# df_input = pd.read_pickle('../data_process/dataset/folds5_lite/df_lite.pkl')
df_train_val, df_test = train_test_split(df_input, test_size=0.2, random_state=random_state)
kf = KFold(n_splits=5, shuffle=True, random_state=random_state)

# organism str to embedding
organism_embedding_dim = 8
organism_num_classes = df_input['organism'].nunique()
print(f'Number of organism classes: {organism_num_classes}')

# save results
results = []
cv_results = []

for use_t_ph_embedding in [True, False]:
    for use_mw_logp in [True, False]:
        for use_organism in [True, False]:
            val_scores_list = []
            test_scores_list = []
            print(f"use_t_ph_embedding [{use_t_ph_embedding}] use_mw_logp [{use_mw_logp}] use_organism [{use_organism}]")

            # search best params
            file_name = return_path_name()
            if file_name == 'default':
                params_json_path = f'{current_dir}/../embeddings_ablation/model_dict/esm2_molebert.json'
            else: params_json_path = f'{current_dir}/model_dict/xgb_{file_name}_params.json'
            if os.path.exists(params_json_path):
                with open(params_json_path) as json_file:
                    params = json.load(json_file)
            else:
                params = search_best_param(search_max_evals)

            # train
            _params = deepcopy(params)
            _params.update({"device": "cuda", "eval_metric": ["rmse"], "sampling_method": "gradient_based"})
            num_rounds = _params.pop('num_rounds')
            for fold, (train_index, val_index) in enumerate(kf.split(df_train_val), start=1):
                print(f"Fold: {fold}/5")
                df_train = df_train_val.iloc[train_index]
                df_val = df_train_val.iloc[val_index]

                train_x, train_y = return_xgb_x_y(df_train)
                val_x, val_y = return_xgb_x_y(df_val)
                test_x, test_y = return_xgb_x_y(df_test)

                # DMatrix
                m_train = xgboost.DMatrix(train_x, label=train_y)
                m_val = xgboost.DMatrix(val_x, label=val_y)
                eval_list = [(m_train, 'train'), (m_val, 'val')]

                # train
                model = xgboost.train(_params, m_train, num_rounds, evals=eval_list, verbose_eval=1000, early_stopping_rounds=60)

                # val
                val_predicted = model.predict(xgboost.DMatrix(val_x))
                val_scores = return_scores(val_y, val_predicted)
                val_scores_list.append(val_scores)

                # test
                test_predicted = model.predict(xgboost.DMatrix(test_x))
                test_scores = return_scores(test_y, test_predicted)
                test_scores_list.append(test_scores)

                # fold
                cv_results.append([
                    file_name, fold,
                    val_scores[0], val_scores[1], val_scores[2], val_scores[3],
                    test_scores[0], test_scores[1], test_scores[2], test_scores[3]
                ])

            # mean
            val_scores_mean = np.mean(val_scores_list, axis=0)
            test_scores_mean = np.mean(test_scores_list, axis=0)
            print(f"Dimension of x: {train_x.shape[1]}")
            print(f"[Val] rmse {val_scores_mean[0]:.4f} mae {val_scores_mean[1]:.4f} r2 {val_scores_mean[2]:.4f} pcc {val_scores_mean[3]:.4f} "
                  f"[Test] rmse {test_scores_mean[0]:.4f} mae {test_scores_mean[1]:.4f} r2 {test_scores_mean[2]:.4f} pcc {test_scores_mean[3]:.4f}\n")

            # 存入均值结果
            results.append([
                file_name,
                val_scores_mean[0], val_scores_mean[1], val_scores_mean[2], val_scores_mean[3],
                test_scores_mean[0], test_scores_mean[1], test_scores_mean[2], test_scores_mean[3]
            ])

# save
df_results = pd.DataFrame(results, columns=["Combination",
    "Val_RMSE", "Val_MAE", "Val_R2", "Val_PCC",
    "Test_RMSE", "Test_MAE", "Test_R2", "Test_PCC"])
df_results.to_excel(f"{current_dir}/results.xlsx", index=False)

# save cvs
df_cv_results = pd.DataFrame(cv_results, columns=[
    "Combination", "Fold",
    "Val_RMSE", "Val_MAE", "Val_R2", "Val_PCC",
    "Test_RMSE", "Test_MAE", "Test_R2", "Test_PCC"])
df_cv_results.to_excel(f"{current_dir}/cv_results.xlsx", index=False)
print("Results saved to results.xlsx and cv_results.xlsx")

Number of organism classes: 2008
use_t_ph_embedding [True] use_mw_logp [True] use_organism [True]
[Info] Starting parameter search...
  0%|          | 0/60 [00:00<?, ?trial/s, best loss=?]                                                      {'eta': 0.14710244918635948, 'learning_rate': 0.07231338052658877, 'max_delta_step': 1.1856462803706669, 'max_depth': 8, 'min_child_weight': 14.004724953626894, 'num_rounds': 2819, 'reg_alpha': 2.048361534437608, 'reg_lambda': 2.3175586260036702, 'subsample': 0.03708057689933486}
  0%|          | 0/60 [00:00<?, ?trial/s, best loss=?]

  from .autonotebook import tqdm as notebook_tqdm


                                                      [Val_mean] rmse 2.963 mae 2.235 r2 0.491 pcc 0.705
  0%|          | 0/60 [00:43<?, ?trial/s, best loss=?]  2%|▏         | 1/60 [00:43<42:59, 43.71s/trial, best loss: 2.963279061972921]                                                                               {'eta': 0.05424266799891317, 'learning_rate': 0.0865387780176784, 'max_delta_step': 0.2886305043147721, 'max_depth': 7, 'min_child_weight': 10.72576373271597, 'num_rounds': 2674, 'reg_alpha': 0.20486163205431518, 'reg_lambda': 0.9691773140914418, 'subsample': 0.6324017865184192}
  2%|▏         | 1/60 [00:43<42:59, 43.71s/trial, best loss: 2.963279061972921]                                                                               [Val_mean] rmse 2.632 mae 1.884 r2 0.599 pcc 0.775
  2%|▏         | 1/60 [01:52<42:59, 43.71s/trial, best loss: 2.963279061972921]  3%|▎         | 2/60 [01:52<56:42, 58.66s/trial, best loss: 2.6319207172815515]                        

In [3]:
df_results.head()

Unnamed: 0,Combination,Val_RMSE,Val_MAE,Val_R2,Val_PCC,Test_RMSE,Test_MAE,Test_R2,Test_PCC
0,t_ph_mw_logp_organism,2.639673,1.891232,0.596378,0.773096,2.699384,1.926086,0.572101,0.757171
1,t_ph_mw_logp,2.622327,1.877132,0.601666,0.776653,2.687673,1.915656,0.575805,0.760013
2,t_ph_organism,2.63761,1.891495,0.596986,0.773529,2.708555,1.927754,0.569182,0.75556
3,t_ph,2.623002,1.880794,0.601464,0.77672,2.693576,1.919265,0.573942,0.75901
4,mw_logp_organism,2.655514,1.91272,0.591456,0.769622,2.722368,1.947592,0.564789,0.75224


In [4]:
df_cv_results.head()

Unnamed: 0,Combination,Fold,Val_RMSE,Val_MAE,Val_R2,Val_PCC,Test_RMSE,Test_MAE,Test_R2,Test_PCC
0,t_ph_mw_logp_organism,1,2.695422,1.908581,0.592174,0.769704,2.705637,1.93557,0.57013,0.755947
1,t_ph_mw_logp_organism,2,2.639583,1.899736,0.59763,0.773905,2.723892,1.943633,0.564309,0.751876
2,t_ph_mw_logp_organism,3,2.577704,1.841716,0.606644,0.779422,2.681859,1.898447,0.577652,0.761163
3,t_ph_mw_logp_organism,4,2.655584,1.923808,0.582812,0.76414,2.700322,1.921301,0.571817,0.757335
4,t_ph_mw_logp_organism,5,2.630073,1.882321,0.602631,0.778306,2.685209,1.93148,0.576596,0.759531
