In [1]:
import pandas as pd
import numpy as np
import rdkit.Chem as Chem
from stratified_continious_split import scsplit, ContinuousStratifiedKFold
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
from utils import standardize_df, add_features, set_seeds
from sklearn.model_selection import GroupShuffleSplit, GroupKFold
from ray import tune
from ray.tune.search.optuna import OptunaSearch
from feature_pipeline import get_pipeline, get_pipeline_param_space
from utils import calc_scores

In [13]:
RANDOM_SEED = 2
set_seeds(RANDOM_SEED)

In [14]:
def get_scaffold(smi) -> str:
    """
    Generate the Bemis-Murcko scaffold for a given molecule.

    :param smi: A SMILES string or an RDKit molecule object representing the
                molecule for which to generate the scaffold.
    :return: A SMILES string representing the Bemis-Murcko scaffold of the input
             molecule. If the scaffold cannot be generated, the input SMILES
             string is returned.
    """
    scaffold = MurckoScaffoldSmiles(smi)
    if len(scaffold) == 0:
        scaffold = smi
    return scaffold

In [15]:
def dedupplicate_parasite(df, duplicate_selection_criteria):
    uniques = (
        df.groupby("inchi")
        .filter(lambda x: len(x) == 1)[["inchi", "chembl_model_score", "par_inhibition_per"]]
        .reset_index(drop=True)
    )
    duplicates = df.groupby("inchi").filter(lambda x: len(x) > 1).reset_index(drop=True)
    deduped = (
        # The inchi keys are the index after groupby; do not drop them but reset the index
        duplicates.groupby("inchi").agg(duplicate_selection_criteria).reset_index()
    )

    return pd.concat([uniques, deduped]).reset_index(drop=True)

In [16]:
derbyshire_df = pd.read_csv('./datasets/Derbyshire_reg_chembl_scores_corrected.csv').reset_index(drop=True)
derbyshire_df = derbyshire_df[['Compound SMILES', 'parasite % average', 'chembl_model_score']]

derbyshire_df = derbyshire_df.rename({
    'Compound SMILES': 'smiles',
    'parasite % average': 'par_inhibition_per',
}, axis=1)

derbyshire_df = standardize_df(derbyshire_df)
derbyshire_df = derbyshire_df[~derbyshire_df['inchi'].isna()].reset_index(drop=True)

duplicate_selection_criteria = {'par_inhibition_per': np.min, 'chembl_model_score': np.max}
derbyshire_df = dedupplicate_parasite(derbyshire_df, duplicate_selection_criteria)
derbyshire_df = add_features(derbyshire_df).dropna(axis=0).reset_index(drop=True)

clusters, _ = pd.factorize(
    derbyshire_df['mol']
        .map(Chem.MolToSmiles)
        .map(get_scaffold)
)
clusters = pd.Series(clusters)
# derbyshire_df['cluster'] = factorized

derbyshire_df["inhibit_parasite"] = (derbyshire_df["par_inhibition_per"] <= 15.0).astype(float)

In [17]:
splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED)

X = derbyshire_df.drop(['inhibit_parasite', 'par_inhibition_per'], axis=1)
y = derbyshire_df['inhibit_parasite']
groups = clusters

train_val_idxs, test_idxs = next(splitter.split(X, y, groups))

X_train_val = X.loc[train_val_idxs].reset_index(drop=True)
y_train_val = y.loc[train_val_idxs].reset_index(drop=True)
groups_train_val = groups.loc[train_val_idxs].reset_index(drop=True)

X_test = X.loc[test_idxs].reset_index(drop=True)
y_test = y.loc[test_idxs].reset_index(drop=True)
groups_test = groups.loc[test_idxs].reset_index(drop=True)

In [18]:
def objective(config, X_train_val, y_train_val, groups_train_val):
    pipeline = get_pipeline(RANDOM_SEED)
    pipeline = pipeline.set_params(**config)

    kfold = GroupKFold(n_splits=3, shuffle=True, random_state=RANDOM_SEED)
    tally = []
    for train_idxs, val_idxs in kfold.split(X_train_val, y_train_val, groups_train_val):
        X_train, y_train = X_train_val.loc[train_idxs], y_train_val.loc[train_idxs]
        X_val, y_val = X_train_val.loc[val_idxs], y_train_val.loc[val_idxs]

        pipeline = pipeline.fit(X_train, y_train)
        y_val_pred_prob = pipeline.predict_proba(X_val)[:, 1]
        y_val_pred = np.where(y_val_pred_prob > 0.5, 1.0, 0.0)

        scores = calc_scores(y_val_pred_prob, y_val_pred, y_val)
        tally.append(scores)

    scores_tally = pd.DataFrame.from_records(tally)
    median_scores = scores_tally.median()
    return median_scores.to_dict()

In [19]:
tuner = tune.Tuner(
    tune.with_resources(
        tune.with_parameters(
            objective, 
            X_train_val=X_train_val, 
            y_train_val=y_train_val,
            groups_train_val=groups_train_val
        ),
        {'cpu': 4}
    ),
    param_space=get_pipeline_param_space(RANDOM_SEED),
    tune_config=tune.TuneConfig(
        search_alg=OptunaSearch(seed=RANDOM_SEED),
        num_samples=200,
        metric="average_precision",
        mode="max",
    )
)

In [20]:
results = tuner.fit()

0,1
Current time:,2025-02-03 03:46:42
Running for:,00:07:15.49
Memory:,5.4/117.9 GiB

Trial name,status,loc,featureunion__morgan fp__n_bits,...line__correlation threshold__threshold,...ipeline__variance threshold__threshold,lgbmclassifier__cols ample_bytree,lgbmclassifier__max_ depth,lgbmclassifier__min_ child_samples,lgbmclassifier__n_es timators,lgbmclassifier__num_ leaves,lgbmclassifier__reg_ alpha,lgbmclassifier__reg_ lambda,lgbmclassifier__scal e_pos_weight,lgbmclassifier__subs ample,iter,total time (s),accuracy,balanced_accuracy,f1
objective_e841bc60,TERMINATED,10.128.15.208:217608,512,0.805185,0.158999,0.65902,13,55,537,78,2.70733e-07,0.00021622,38,0.340145,1,7.04906,0.911538,0.561856,0.203704
objective_fb3ced66,TERMINATED,10.128.15.208:217704,1024,0.970795,0.246334,0.637071,20,26,259,26,2.86421e-08,9.92787e-06,36,0.214444,1,5.27004,0.898438,0.59547,0.26087
objective_d7cac4f6,TERMINATED,10.128.15.208:217798,512,0.840349,0.166947,0.968096,30,52,374,43,0.00359304,0.000114823,94,0.730677,1,6.43983,0.889744,0.564569,0.206897
objective_39430a1b,TERMINATED,10.128.15.208:217896,2048,0.887349,0.156886,0.124482,48,28,994,218,3.75547e-08,3.66773e-06,34,0.465648,1,15.1371,0.920228,0.550159,0.179775
objective_0862bd57,TERMINATED,10.128.15.208:217995,512,0.920363,0.250065,0.5536,42,35,417,251,3.13596e-06,2.08811e-08,48,0.497219,1,9.21092,0.914103,0.587825,0.263736
objective_3d5fb8e6,TERMINATED,10.128.15.208:218097,1024,0.875842,0.234187,0.473415,10,38,974,100,1.50958e-08,0.000431931,68,0.547366,1,16.3583,0.919231,0.585775,0.265306
objective_5ed3624d,TERMINATED,10.128.15.208:218196,512,0.80836,0.128315,0.800207,40,17,761,45,0.000321287,0.000426942,92,0.893487,1,18.7203,0.917949,0.556428,0.195652
objective_7dec4362,TERMINATED,10.128.15.208:218295,512,0.844168,0.185836,0.843625,27,64,611,139,0.0249694,8.67258e-06,42,0.647598,1,10.0816,0.90641,0.558431,0.19802
objective_fa914e28,TERMINATED,10.128.15.208:218395,4096,0.939904,0.186977,0.185759,26,31,321,76,1.17964e-08,0.000254401,44,0.268558,1,6.9142,0.910256,0.560565,0.204082
objective_77ea212a,TERMINATED,10.128.15.208:218495,512,0.969023,0.187374,0.687718,36,80,268,57,1.34979e-07,0.0123317,32,0.76758,1,4.29461,0.868946,0.621101,0.259259


2025-02-03 03:46:42,189	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/rahul_e_dev/ray_results/objective_2025-02-03_03-39-26' in 0.0552s.
2025-02-03 03:46:42,242	INFO tune.py:1041 -- Total run time: 435.56 seconds (435.44 seconds for the tuning loop).


In [21]:
from sklearn.model_selection import TunedThresholdClassifierCV

best_result = results.get_best_result()
pipeline = get_pipeline(RANDOM_SEED)
config = {k:v for k,v in best_result.config.items() if k in pipeline.get_params()}
pipeline.set_params(**config)

pipeline = pipeline.fit(X_train_val, y_train_val)

kfold = ContinuousStratifiedKFold(n_splits=3, shuffle=True, random_state=RANDOM_SEED)
th_pipeline = TunedThresholdClassifierCV(pipeline, scoring='f1', cv=kfold, thresholds=500, random_state=RANDOM_SEED)
th_pipeline.fit(X_train_val, y_train_val)

y_test_pred_prob = th_pipeline.predict_proba(X_test)[:, 1]
y_test_pred = th_pipeline.predict(X_test)
test_scores = calc_scores(y_test_pred_prob, y_test_pred, y_test.to_numpy())
test_scores = {k: np.round(v, 3) for k, v in test_scores.items()}



In [22]:
pd.DataFrame.from_records([test_scores])

Unnamed: 0,accuracy,balanced_accuracy,f1,precision,recall,roc_auc,average_precision,specificity,sensitivity,test_delong_auc,lb,ub
0,0.889,0.557,0.203,0.32,0.148,0.712,0.258,0.967,0.148,0.712,0.634,0.79


In [23]:
best_result.metrics_dataframe[
    list(test_scores.keys())
].round(3)

Unnamed: 0,accuracy,balanced_accuracy,f1,precision,recall,roc_auc,average_precision,specificity,sensitivity,test_delong_auc,lb,ub
0,0.908,0.581,0.248,0.325,0.2,0.728,0.278,0.962,0.2,0.728,0.674,0.786


In [13]:
import dill

with open('./saved_models/parasite.pkl', 'wb') as outfile:
    dill.dump(th_pipeline, outfile)