# MMACE Paper: Random Forest for Blood-Brain Barrier

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
import rdkit, rdkit.Chem, rdkit.Chem.Draw
from rdkit.Chem.Draw import IPythonConsole
import numpy as np
import skunk
import mordred, mordred.descriptors
import exmol as exmol
from rdkit.Chem.Draw import rdDepictor
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, plot_roc_curve

# Ask RDKit's depictor to prefer CoordGen when laying out molecules
rdDepictor.SetPreferCoordGen(True)

# Have the RDKit notebook integration emit SVG drawings
IPythonConsole.ipython_useSVG = True
sns.set_context("notebook")
sns.set_style(
    "dark",
    {
        "xtick.bottom": True,
        "ytick.left": True,
        "xtick.color": "#666666",
        "ytick.color": "#666666",
        "axes.edgecolor": "#666666",
    },
)
# seaborn's set_style only applies rc keys that belong to its style
# definition and silently discards the rest, so the line width and figure
# dpi have to be set through matplotlib directly to take effect.
mpl.rcParams.update({"axes.linewidth": 0.8, "figure.dpi": 300})
color_cycle = ["#1BBC9B", "#F06060", "#F3B562", "#6e5687", "#5C4B51"]
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=color_cycle)
# Seed NumPy's global RNG so later random draws are repeatable
np.random.seed(0)

In [None]:
# Load the BBBP table; later cells read its `smiles` and `p_np` columns
data = pd.read_csv("BBBP.csv")
# Show the first rows as the cell output
data.head()

In [None]:
def largest_mol(smiles):
    """Return the longest '.'-separated fragment of a SMILES string.

    Length is measured in characters.  When several fragments are equally
    long, the one that appears last in the string is returned.
    """
    fragments = smiles.split(".")
    # max() keeps the first maximum it meets, so scan from the end to
    # favour the last of equally long fragments
    return max(reversed(fragments), key=len)

In [None]:
# Mordred calculator over the registered descriptors, with 3D ones ignored
calc = mordred.Calculator(mordred.descriptors, ignore_3D=True)

# Parse the largest fragment of every SMILES in the table; entries that
# cannot be parsed come back as None
molecules = [rdkit.Chem.MolFromSmiles(largest_mol(s)) for s in data.smiles]

# Boolean row mask (None is falsy) and the molecules that parsed
valid_mol_idx = list(map(bool, molecules))
valid_mols = [mol for mol, ok in zip(molecules, valid_mol_idx) if ok]

# Reuse the pickled descriptor table when it exists; otherwise compute the
# descriptors and cache them for the next run
try:
    raw_features = pd.read_pickle("raw_features.pb")
except FileNotFoundError:
    raw_features = calc.pandas(valid_mols, nproc=8, quiet=True)
    raw_features.to_pickle("raw_features.pb")

In [None]:
# p_np values for the rows whose SMILES parsed successfully
labels = data.loc[valid_mol_idx, "p_np"]

In [None]:
# Per-column mean and standard deviation of the raw descriptors; both are
# read by feature_convert
fm, fs = raw_features.mean(), raw_features.std()


def feature_convert(f):
    """Standardize descriptor columns using the module-level statistics.

    Subtracts the column means ``fm`` and divides by the column standard
    deviations ``fs``.  The arithmetic is done on a copy: applying ``-=`` and
    ``/=`` to the argument itself rewrote the caller's frame, so the frame
    the statistics were computed from ended up already scaled and re-running
    the cell derived ``fm``/``fs`` from standardized data.

    Args:
        f: DataFrame of raw descriptors whose columns line up with the
            index of ``fm`` and ``fs``.

    Returns:
        A new frame of the same type as ``f`` holding the standardized
        values.  Columns with a zero standard deviation may come out
        non-finite.
    """
    scaled = f.copy()
    scaled -= fm
    scaled /= fs
    return scaled


# Standardize the descriptors and move to a plain float array
features = feature_convert(raw_features).values.astype(float)

# Some columns contain NaN/inf after standardizing (likely because their
# std was 0); keep only the columns that are finite for every molecule
features_select = np.isfinite(features).all(axis=0)
features = features[:, features_select]

In [None]:
# Hold out 20% of the molecules for evaluation
X_train, X_test, y_train, y_test = train_test_split(
    features, labels, test_size=0.2, shuffle=True
)

# Fit the classifier and report the ROC AUC on the held-out split
clf = RandomForestClassifier(max_depth=8, random_state=0)
clf.fit(X_train, y_train)
predicted = clf.predict(X_test)
print("AUC", roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1]))

# plot_roc_curve opens its own default-sized figure unless it is given an
# axes, which left the 4x3 in / 300 dpi figure below empty and unsaved.
# Hand it this figure's axes so the requested size and dpi are the ones
# written to disk.
# NOTE(review): plot_roc_curve was removed in scikit-learn 1.2; on newer
# versions RocCurveDisplay.from_estimator is the replacement, and the
# import at the top of the notebook would need to change with it.
plt.figure(figsize=(4, 3), dpi=300)
plot_roc_curve(clf, X_test, y_test, ax=plt.gca())
plt.plot([0, 1], [0, 1], linestyle="--")
# Keep the axis labels inside the small canvas
plt.tight_layout()
plt.savefig("RF-ROC.png")

In [None]:
def model_eval(smiles, _=None):
    """Predict a class label for each SMILES string with the fitted forest.

    Descriptors are computed with ``calc``, standardized by
    ``feature_convert``, restricted to the ``features_select`` columns, and
    passed through ``np.nan_to_num`` before ``clf.predict`` is called.

    Args:
        smiles: iterable of SMILES strings.
        _: ignored; lets the function be called with a second positional
            argument.

    Returns:
        Whatever ``clf.predict`` returns for the prepared feature array.
    """
    parsed = [rdkit.Chem.MolFromSmiles(s) for s in smiles]
    descriptors = calc.pandas(parsed, nproc=8, quiet=True)
    scaled = feature_convert(descriptors).values.astype(float)
    # Unusual SMILES can yield odd descriptor values; nan_to_num turns
    # NaN/inf entries into finite numbers before prediction
    return clf.predict(np.nan_to_num(scaled[:, features_select]))


# p_np values of the rows selected by the boolean parse mask
labels = data.iloc[valid_mol_idx]["p_np"]

In [None]:
# SMILES of the rows that parsed, in the same order as `labels`
_valid_smiles = data.iloc[valid_mol_idx].smiles.values

# One example from each end of the label range: the first position holding
# the smallest label and the first holding the largest
example_neg = largest_mol(_valid_smiles[np.argmin(labels)])
example_pos = largest_mol(_valid_smiles[np.argmax(labels)])

# Check what the model predicts for the two examples
example_neg_y, example_pos_y = model_eval([example_neg, example_pos])
print("neg:", example_neg, "\npos:", example_pos)
print(example_neg_y, example_pos_y)

In [None]:
import syngen

# Generate candidate molecules around the negative example with syngen
mols, props = syngen.chemical_space(
    example_neg, use_mannifold=True, samples=1000, steps=1, threshold=0.5, max=2000
)
print(len(mols))
mols = mols[:5000]

# Convert the candidates to SMILES with any "~" characters removed.  A
# dedicated name is used here: rebinding `data` would replace the BBBP
# DataFrame that earlier cells index, so they would fail when re-run.
candidate_smiles = [rdkit.Chem.MolToSmiles(m).replace("~", "") for m in mols]

# Label the supplied candidates with the model through exmol
space = exmol.sample_space(
    example_neg, model_eval, data=candidate_smiles, preset="custom", quiet=True
)

In [None]:
# Run exmol's cf_explain on the sampled space and print what it returns
exps = exmol.cf_explain(space)
print(exps)

In [None]:
# Figure keyword arguments; the chemical-space plot further down reuses them
fkw = {"figsize": (8, 6)}
mpl.rc("axes", titlesize=12)
# Draw the explanations in a single row
exmol.plot_cf(exps, figure_kwargs=fkw, mol_size=(450, 400), nrows=1)

# Save the current figure at 180 dpi
plt.savefig("rf-simple.png", dpi=180)

In [None]:
# NOTE(review): `font` is not referenced anywhere else in the visible cells
font = {"family": "normal", "weight": "normal", "size": 22}

# Plot the sampled space together with the explanations
exmol.plot_space(
    space,
    exps,
    figure_kwargs=fkw,
    mol_size=(300, 200),
    offset=0,
    cartoon=True,
    rasterized=True,
)

# Empty scatter series act as legend handles, coloured with the two ends
# of the viridis colormap
viridis = plt.get_cmap("viridis")
for legend_label, cmap_pos in (("Crosses BBBP", 1.0), ("Does Not Cross", 0.0)):
    plt.scatter([], [], label=legend_label, s=150, color=viridis(cmap_pos))

plt.legend(fontsize=22)
plt.tight_layout()
plt.savefig("rf-space.png", dpi=180)