In [1]:
import pandas as pd
import numpy as np
import rdkit.Chem as Chem
from stratified_continious_split import scsplit, ContinuousStratifiedKFold
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
from utils import standardize_df, add_features, set_seeds
from sklearn.model_selection import GroupShuffleSplit, GroupKFold
from ray import tune
from ray.tune.search.optuna import OptunaSearch
from feature_pipeline import get_pipeline, get_pipeline_param_space
from utils import calc_scores

# Seed the run through the project helper imported from `utils`.
# NOTE(review): which RNGs `set_seeds` covers (numpy / random / others) is not
# visible in this notebook — confirm against `utils`.
set_seeds(42)

In [2]:
def get_scaffold(smi) -> str:
    """
    Generate the Bemis-Murcko scaffold for a given molecule.

    :param smi: A SMILES string or an RDKit molecule object representing the
                molecule for which to generate the scaffold.
    :return: A SMILES string representing the Bemis-Murcko scaffold of the input
             molecule. If the scaffold is empty, the SMILES of the input
             molecule itself is returned so that the result is always a string.
    """
    # MurckoScaffoldSmiles takes SMILES and Mol objects through different
    # keyword arguments; passing a Mol positionally would be treated as SMILES.
    is_smiles = isinstance(smi, str)
    if is_smiles:
        scaffold = MurckoScaffoldSmiles(smiles=smi)
    else:
        scaffold = MurckoScaffoldSmiles(mol=smi)

    # An empty (or missing) scaffold falls back to the molecule itself.
    if not scaffold:
        return smi if is_smiles else Chem.MolToSmiles(smi)
    return scaffold

In [3]:
def dedupplicate_parasite(df, duplicate_selection_criteria):
    """
    Reduce ``df`` to one row per ``inchi`` value.

    Rows whose ``inchi`` occurs once are kept unchanged (restricted to the
    ``inchi`` and ``par_inhibition_per`` columns). Rows whose ``inchi`` occurs
    more than once are grouped and collapsed with ``duplicate_selection_criteria``.

    :param df: DataFrame containing at least ``inchi`` and ``par_inhibition_per``.
    :param duplicate_selection_criteria: Aggregation spec handed to
        ``GroupBy.agg`` for the repeated structures.
    :return: DataFrame with the single-occurrence rows first, followed by the
        collapsed rows, on a fresh 0..n-1 index.
    """
    key = "inchi"
    by_structure = df.groupby(key)

    # Structures that appear exactly once need no aggregation.
    singles = by_structure.filter(lambda grp: len(grp) == 1)
    singles = singles[[key, "par_inhibition_per"]].reset_index(drop=True)

    # Structures that appear several times are collapsed to one row each.
    repeats = by_structure.filter(lambda grp: len(grp) > 1).reset_index(drop=True)
    collapsed = repeats.groupby(key).agg(duplicate_selection_criteria)
    # After the aggregation the key lives in the index; turn it back into a column.
    collapsed = collapsed.reset_index()

    combined = pd.concat([singles, collapsed])
    return combined.reset_index(drop=True)

In [4]:
# Load the raw table and keep only the structure and the measured response.
derbyshire_df = pd.read_csv('./datasets/Derbyshire_reg_chembl_scores_corrected.csv')
derbyshire_df = derbyshire_df[['Compound SMILES', 'parasite % average']]

derbyshire_df = derbyshire_df.rename(columns={
    'Compound SMILES': 'smiles',
    'parasite % average': 'par_inhibition_per',
})

# `standardize_df` (from utils) is expected to provide the 'inchi' column used
# below; rows without one cannot be de-duplicated and are dropped.
derbyshire_df = standardize_df(derbyshire_df)
derbyshire_df = derbyshire_df.dropna(subset=['inchi']).reset_index(drop=True)

# For repeated structures keep the minimum measured value. The string 'min' is
# used instead of `np.min`: pandas >= 2.1 warns when NumPy callables are passed
# to `agg`, and 'min' is the documented way to keep the current behaviour.
duplicate_selection_criteria = {'par_inhibition_per': 'min'}
derbyshire_df = dedupplicate_parasite(derbyshire_df, duplicate_selection_criteria)
derbyshire_df = add_features(derbyshire_df).dropna(axis=0).reset_index(drop=True)

# One integer cluster id per row, derived from the Bemis-Murcko scaffold of the
# molecule; used later as the grouping key for the splits.
clusters, _ = pd.factorize(
    derbyshire_df['mol']
        .map(Chem.MolToSmiles)
        .map(get_scaffold)
)
clusters = pd.Series(clusters)

# Binary label: 1.0 when the measured value is at most 15.0, otherwise 0.0.
derbyshire_df["inhibit_parasite"] = (derbyshire_df["par_inhibition_per"] <= 15.0).astype(float)

In [5]:
# Hold out a test set such that no scaffold cluster appears on both sides.
splitter = GroupShuffleSplit(n_splits=1, random_state=42)

X = derbyshire_df.drop(columns=['inhibit_parasite', 'par_inhibition_per'])
y = derbyshire_df['inhibit_parasite']
groups = clusters

train_val_idxs, test_idxs = next(splitter.split(X, y, groups))

# `split` yields positional indices, so select with `.iloc`; `.loc` would only
# be correct while the index happens to be a fresh 0..n-1 RangeIndex.
X_train_val = X.iloc[train_val_idxs].reset_index(drop=True)
y_train_val = y.iloc[train_val_idxs].reset_index(drop=True)
groups_train_val = groups.iloc[train_val_idxs].reset_index(drop=True)

X_test = X.iloc[test_idxs].reset_index(drop=True)
y_test = y.iloc[test_idxs].reset_index(drop=True)
groups_test = groups.iloc[test_idxs].reset_index(drop=True)

In [6]:
def objective(config, X_train_val, y_train_val, groups_train_val):
    """
    Evaluate one pipeline configuration with 3-fold grouped cross-validation.

    :param config: Mapping of pipeline parameter names to values, applied to a
                   fresh pipeline via ``set_params``.
    :param X_train_val: Feature DataFrame used for cross-validation.
    :param y_train_val: Label Series aligned row-for-row with ``X_train_val``.
    :param groups_train_val: Group id per row; rows sharing a group id are
                             kept in the same fold.
    :return: Dict mapping each metric produced by ``calc_scores`` to its median
             over the folds.
    """
    pipeline = get_pipeline()
    pipeline = pipeline.set_params(**config)

    kfold = GroupKFold(n_splits=3, shuffle=True, random_state=42)
    tally = []
    for train_idxs, val_idxs in kfold.split(X_train_val, y_train_val, groups_train_val):
        # `split` yields positional indices, so `.iloc` is the correct selector
        # regardless of what index labels the inputs carry.
        X_train, y_train = X_train_val.iloc[train_idxs], y_train_val.iloc[train_idxs]
        X_val, y_val = X_train_val.iloc[val_idxs], y_train_val.iloc[val_idxs]

        pipeline = pipeline.fit(X_train, y_train)
        # Probability of the positive class, thresholded at 0.5 for hard labels.
        y_val_pred_prob = pipeline.predict_proba(X_val)[:, 1]
        y_val_pred = np.where(y_val_pred_prob > 0.5, 1.0, 0.0)

        scores = calc_scores(y_val_pred_prob, y_val_pred, y_val)
        tally.append(scores)

    # One row per fold, one column per metric; report the per-metric median.
    scores_tally = pd.DataFrame.from_records(tally)
    median_scores = scores_tally.median()
    return median_scores.to_dict()

In [7]:
# Bind the training data to the objective so each trial only receives `config`.
trainable = tune.with_parameters(
    objective,
    X_train_val=X_train_val,
    y_train_val=y_train_val,
    groups_train_val=groups_train_val,
)
# Request 4 CPUs per trial.
trainable = tune.with_resources(trainable, {'cpu': 4})

search_space = get_pipeline_param_space()

# 200 Optuna-suggested samples, maximising the "average_precision" value
# reported by the objective.
search_config = tune.TuneConfig(
    search_alg=OptunaSearch(seed=42),
    num_samples=200,
    metric="average_precision",
    mode="max",
)

tuner = tune.Tuner(
    trainable,
    param_space=search_space,
    tune_config=search_config,
)

In [8]:
# Run the hyper-parameter search; `results` is queried for the best trial below.
results = tuner.fit()

0,1
Current time:,2025-01-30 05:02:35
Running for:,00:08:32.75
Memory:,3.5/117.9 GiB

Trial name,status,loc,featureunion__morgan fp__n_bits,...line__correlation threshold__threshold,...ipeline__variance threshold__threshold,lgbmclassifier__cols ample_bytree,lgbmclassifier__max_ depth,lgbmclassifier__min_ child_samples,lgbmclassifier__n_es timators,lgbmclassifier__num_ leaves,lgbmclassifier__reg_ alpha,lgbmclassifier__reg_ lambda,lgbmclassifier__scal e_pos_weight,lgbmclassifier__subs ample,iter,total time (s),accuracy,balanced_accuracy,f1
objective_2eb069e9,TERMINATED,10.128.15.207:4308,512,0.990143,0.143635,0.118526,13,97,251,154,2.55026e-08,0.0115673,88,0.737265,1,2.51629,0.624877,0.638821,0.214876
objective_de8dbb8d,TERMINATED,10.128.15.207:4432,4096,0.860848,0.0958511,0.806658,7,23,612,95,9.47233e-08,1.10921e-06,66,0.510463,1,10.9256,0.901356,0.550277,0.181818
objective_9bcf43f5,TERMINATED,10.128.15.207:4710,2048,0.834105,0.201886,0.209834,16,52,913,175,1.35611e-06,4.82731e-08,32,0.496137,1,9.89661,0.911221,0.529971,0.121951
objective_44c10f2a,TERMINATED,10.128.15.207:4831,4096,0.862342,0.215631,0.929687,19,13,92,229,0.00266643,0.0377131,44,0.63811,1,5.42823,0.91492,0.529358,0.12766
objective_ee567c0e,TERMINATED,10.128.15.207:4928,512,0.85427,0.147169,0.79502,36,23,824,20,9.69325e-08,0.00412457,30,0.988198,1,11.3652,0.908755,0.5263,0.114943
objective_bbe606cf,TERMINATED,10.128.15.207:5028,4096,0.954254,0.232252,0.392665,26,74,892,18,0.000230721,2.07151e-06,74,0.379884,1,9.90356,0.89889,0.524832,0.11236
objective_357ac669,TERMINATED,10.128.15.207:5133,2048,0.942649,0.0798986,0.128286,45,65,533,8,4.56173e-05,9.83529e-06,52,0.197102,1,4.29002,0.755348,0.621486,0.246914
objective_6441a3ab,TERMINATED,10.128.15.207:5233,512,0.882077,0.112323,0.884315,29,81,897,207,1.34446e-07,0.0322021,42,0.670063,1,11.0125,0.905055,0.555587,0.189474
objective_c5bb3f39,TERMINATED,10.128.15.207:5330,4096,0.979218,0.25186,0.47567,47,26,370,3,0.00532235,0.0105953,38,0.559673,1,4.0693,0.409674,0.573827,0.176309
objective_3ac2869a,TERMINATED,10.128.15.207:5425,2048,0.903758,0.130801,0.133198,17,62,98,78,5.7873e-07,3.0251e-05,66,0.356356,1,2.45594,0.676209,0.642327,0.211538


2025-01-30 05:02:35,190	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/rahul_e_dev/ray_results/objective_2025-01-30_04-53-58' in 0.0479s.
2025-01-30 05:02:35,243	INFO tune.py:1041 -- Total run time: 512.83 seconds (512.70 seconds for the tuning loop).


In [9]:
from sklearn.model_selection import TunedThresholdClassifierCV

# Rebuild the pipeline with the best trial's settings, keeping only the keys
# that the pipeline actually exposes as parameters.
best_result = results.get_best_result()
pipeline = get_pipeline()
known_params = pipeline.get_params()
config = {
    name: value
    for name, value in best_result.config.items()
    if name in known_params
}
pipeline.set_params(**config)

# NOTE(review): TunedThresholdClassifierCV refits a clone of its estimator when
# `cv` is not "prefit", so this fit does not appear to influence `th_pipeline`
# — confirm whether a fitted `pipeline` is needed elsewhere.
pipeline = pipeline.fit(X_train_val, y_train_val)

# Tune the decision threshold for F1 over 500 candidate thresholds.
# NOTE(review): this CV splitter is not given `groups_train_val`, unlike the
# grouped splits used during the search — confirm this is intended.
kfold = ContinuousStratifiedKFold(n_splits=3)
th_pipeline = TunedThresholdClassifierCV(pipeline, scoring='f1', cv=kfold, thresholds=500)
th_pipeline.fit(X_train_val, y_train_val)

# Score the held-out test set with the tuned threshold.
y_test_pred_prob = th_pipeline.predict_proba(X_test)[:, 1]
y_test_pred = th_pipeline.predict(X_test)
raw_test_scores = calc_scores(y_test_pred_prob, y_test_pred, y_test.to_numpy())
test_scores = {metric: np.round(value, 3) for metric, value in raw_test_scores.items()}
test_scores



{'accuracy': 0.868,
 'balanced_accuracy': 0.648,
 'f1': 0.341,
 'precision': 0.31,
 'recall': 0.379,
 'roc_auc': 0.741,
 'average_precision': 0.314,
 'specificity': 0.916,
 'sensitivity': 0.379,
 'test_delong_auc': 0.741,
 'lb': 0.671,
 'ub': 0.812}

In [10]:
# Display the metrics recorded for the best trial (cross-validation metrics
# reported by `objective`, alongside the trial's config columns).
best_result.metrics_dataframe

Unnamed: 0,accuracy,balanced_accuracy,f1,precision,recall,roc_auc,average_precision,specificity,sensitivity,test_delong_auc,...,config/lgbmclassifier__reg_lambda,config/lgbmclassifier__num_leaves,config/lgbmclassifier__subsample,config/lgbmclassifier__colsample_bytree,config/lgbmclassifier__min_child_samples,config/lgbmclassifier__n_jobs,config/lgbmclassifier__random_state,config/lgbmclassifier__scale_pos_weight,config/lgbmclassifier__n_estimators,config/lgbmclassifier__max_depth
0,0.918619,0.531587,0.125,0.416667,0.074627,0.727096,0.235004,0.990629,0.074627,0.727096,...,0.000499,131,0.156771,0.194216,14,4,42,68,469,28
