In [1]:
import pandas as pd
from sklearn.model_selection import GroupKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

In [30]:
# Load the per-pose clash/docking readout and keep only rows that are not
# negative controls (only train on barbed-end toxins, per the original note).
clash = pd.read_csv("clash_readout.csv")
keep_rows = clash["IS_NEG_CONTROL"].eq(0)
train_df = clash.loc[keep_rows].copy()

In [31]:
# Collapse the docking poses to ONE ROW PER (toxin, receptor_state, box_type)
# combination -- not one row per toxin: a toxin docked into several receptor
# states / boxes yields several rows here. The CV below groups folds by
# "toxin", so those sibling rows never straddle a train/test split.
agg = (
    train_df
    .groupby(["toxin", "receptor_state", "box_type"], as_index=False)
    .agg(
        # lowest (min) Vina score over the poses in the group
        vina_score_best=("vina_score", "min"),
        self_clash_med_06=("self_clash_0.6", "median"),
        # NOTE(review): the only feature taken from the 0.4 clash column; every
        # other clash feature uses 0.6 -- confirm the asymmetry is intended.
        self_clash_max_04=("self_clash_0.4", "max"),
        incoming_clash_med_06=("incoming_clash_0.6", "median"),
        incoming_clash_max_06=("incoming_clash_0.6", "max"),
        # IS_BLOCKER is carried along as "label", but the classifier target
        # later in this notebook is IS_POTENT (from cell potency), not this.
        label=("IS_BLOCKER", "max"),
    )
)

In [33]:
# Per-ligand descriptor table. The CSV was written together with its pandas
# index, which reads back as an "Unnamed: 0" column; drop it (if present) so
# the junk column is not merged into the training frames.
feats = (
    pd.read_csv("ligand_features_with_flags.csv")
    .drop(columns=["Unnamed: 0"], errors="ignore")
)

Unnamed: 0,Name,SMILES,valid,cLogP,TPSA,HBD,HBA,MW,RotB,Rings,...,FracCSP3,Lipinski_violations,Flag_high_cLogP,Flag_high_TPSA,Flag_high_RotB,Flag_high_MW,Flag_low_MW,Flag_any,Flag_Category,Flag_Reasons
0,2c,[H]C(N(CC)/C=C/[C@@H](C)[C@@H](OC)[C@@H](C)[C@...,True,5.633,91.37,0,7,541.77,20,0,...,0.833,2,True,False,True,False,False,True,TooFlexible,too_flexible;too_hydrophobic_cLogP;lipinski_vi...
1,2a,O=C([C@@H]1OCCC1)O[C@H](CC[C@H](C)[C@@H](OC)C[...,True,4.485,126.62,1,9,598.82,22,1,...,0.844,1,False,False,True,False,False,True,TooFlexible,too_flexible
2,2u,O=C([C@@H]1OCCC1)O[C@H](CC[C@H](C)[C@@H](OC)C[...,True,4.094,126.62,1,9,584.8,21,1,...,0.839,1,False,False,True,False,False,True,TooFlexible,too_flexible
3,Swinholide_A,C[C@@H]1O[C@H](C[C@@H](C1)OC)CC[C@@H]([C@H](O)...,True,10.108,288.28,8,20,1389.89,16,5,...,0.821,4,True,True,True,True,False,True,VeryLarge,very_large_MW;too_flexible;too_polar_TPSA;too_...
4,Aplyronine_A,C[C@@H]1CC[C@@H](/C(=C/C[C@H](C[C@H](/C=C/C[C@...,True,7.736,200.14,2,16,1076.46,23,1,...,0.746,3,True,True,True,True,False,True,VeryLarge,very_large_MW;too_flexible;too_polar_TPSA;too_...


In [34]:
# Attach per-ligand descriptors to every aggregated docking row.
feats = feats.rename(columns={"Name": "toxin"})

# The feature table must hold at most one row per toxin: validate= raises a
# MergeError instead of silently duplicating docking rows on a repeated name.
data = agg.merge(feats, on="toxin", how="left", validate="many_to_one")

# A left merge leaves NaN descriptors for toxins absent from the feature table;
# stop here with the offending names rather than let NaN reach the model.
missing_feats = agg.loc[~agg["toxin"].isin(feats["toxin"]), "toxin"].unique()
assert len(missing_feats) == 0, f"toxins missing from ligand features: {list(missing_feats)}"

Unnamed: 0,toxin,receptor_state,box_type,vina_score_best,self_clash_med_06,self_clash_max_04,incoming_clash_med_06,incoming_clash_max_06,label,SMILES,...,FracCSP3,Lipinski_violations,Flag_high_cLogP,Flag_high_TPSA,Flag_high_RotB,Flag_high_MW,Flag_low_MW,Flag_any,Flag_Category,Flag_Reasons
0,2a,AMPPNP,TIGHT,-5.3,0.0,2,226.0,232,1,O=C([C@@H]1OCCC1)O[C@H](CC[C@H](C)[C@@H](OC)C[...,...,0.844,1,False,False,True,False,False,True,TooFlexible,too_flexible
1,2c,AMPPNP,TIGHT,-5.7,0.0,0,201.0,207,1,[H]C(N(CC)/C=C/[C@@H](C)[C@@H](OC)[C@@H](C)[C@...,...,0.833,2,True,False,True,False,False,True,TooFlexible,too_flexible;too_hydrophobic_cLogP;lipinski_vi...
2,2u,AMPPNP,TIGHT,-5.5,0.0,0,220.0,228,1,O=C([C@@H]1OCCC1)O[C@H](CC[C@H](C)[C@@H](OC)C[...,...,0.839,1,False,False,True,False,False,True,TooFlexible,too_flexible
3,Aplyronine_A,AMPPNP,TIGHT,-5.3,0.0,0,261.0,358,1,C[C@@H]1CC[C@@H](/C(=C/C[C@H](C[C@H](/C=C/C[C@...,...,0.746,3,True,True,True,True,False,True,VeryLarge,very_large_MW;too_flexible;too_polar_TPSA;too_...
4,Bistramide_A,AMPPNP,TIGHT,-7.0,0.0,0,241.0,241,1,C/C=C/C(=O)C[C@H]1CC[C@@H]([C@@H](O1)CC(=O)NC[...,...,0.825,2,True,True,True,True,False,True,VeryLarge,very_large_MW;too_flexible;too_polar_TPSA;too_...


In [35]:
# Attach measured cell potency (expected: one value per toxin) to every row.
potency = pd.read_csv("cell_potency.csv")
potency = potency.rename(columns={"Name": "toxin"})

# validate= raises a MergeError if a toxin appears more than once in the
# potency table, instead of silently duplicating training rows.
data_train = data.merge(potency, on="toxin", how="left", validate="many_to_one")

In [36]:
# Model inputs: the aggregated docking/clash summaries first, then the
# per-ligand descriptors merged in from the feature table.
docking_cols = [
    "vina_score_best",
    "self_clash_med_06",
    "self_clash_max_04",
    "incoming_clash_med_06",
    "incoming_clash_max_06",
]
descriptor_cols = ["MW", "cLogP", "TPSA", "HBD", "HBA", "RotB", "Rings", "FracCSP3"]
X_cols = docking_cols + descriptor_cols
X = data_train[X_cols]

In [39]:
# Binary target: 1 = potent, 0 = not potent. The cutoff is intended to mean
# IC50 < 100 nM, which holds if CELL_POTENCY is log10(IC50 / nM) -- TODO confirm units.
POTENCY_CUTOFF = 2

# Toxins with no potency measurement arrive from the left merge as NaN, and
# NaN < cutoff is False -- they would be silently labelled "not potent".
missing_potency = data_train.loc[data_train["CELL_POTENCY"].isna(), "toxin"].unique()
assert len(missing_potency) == 0, f"toxins with no CELL_POTENCY: {list(missing_potency)}"

data_train["IS_POTENT"] = (data_train["CELL_POTENCY"] < POTENCY_CUTOFF).astype(int)

In [42]:
# 1-D integer label vector, aligned row-for-row with X.
y = data_train["IS_POTENT"].to_numpy()

In [45]:
# Keep every row of a toxin in the same fold; use up to 5 folds, fewer when
# there are fewer distinct toxins than that.
# NOTE(review): GroupKFold does not balance classes, so a held-out fold may
# contain a single class, leaving ROC-AUC undefined for it -- consider
# StratifiedGroupKFold if that happens.
groups = data_train["toxin"]
n_toxins = groups.nunique()
cv = GroupKFold(n_splits=min(5, n_toxins))

In [46]:
# Scaling lives inside the pipeline so it is re-fit on each training fold only
# (no statistics leak in from held-out toxins); class_weight="balanced"
# reweights the potent / non-potent classes by inverse frequency.
model = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression(
        penalty="l2",
        class_weight="balanced",
        max_iter=500,
    )),
])

# error_score="raise": by default cross_val_score converts a failing fold (e.g.
# NaN features; depending on the sklearn version also a test fold holding a
# single class, where ROC-AUC is undefined) into a NaN score plus a warning.
# Raising makes such folds impossible to overlook.
scores = cross_val_score(
    model, X, y,
    cv=cv,
    groups=groups,
    scoring="roc_auc",
    error_score="raise",
)

In [48]:
# Summarise the per-fold ROC-AUC values as mean and standard deviation.
auc_mean = scores.mean()
auc_std = scores.std()
print("Grouped ROC-AUC:", auc_mean, "+/-", auc_std)

Grouped ROC-AUC: 0.9 +/- 0.20000000000000004
