# LIME paper: Recurrent Neural Network for Blood brain barrier permeation

## Import packages and set up RNN

In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, FancyBboxPatch
from matplotlib.offsetbox import AnnotationBbox
import seaborn as sns
import textwrap
import skunk
import matplotlib as mpl
import numpy as np
import tensorflow as tf
import selfies as sf
import exmol
from dataclasses import dataclass
from rdkit.Chem.Draw import rdDepictor, MolsToGridImage
from rdkit.Chem import MolFromSmiles, MACCSkeys

rdDepictor.SetPreferCoordGen(True)
import matplotlib.font_manager as font_manager
import urllib.request

# Download the IBM Plex Mono font into the working directory (needs network
# access) and register it with matplotlib under the name "plexmono".
urllib.request.urlretrieve(
    "https://github.com/google/fonts/raw/main/ofl/ibmplexmono/IBMPlexMono-Regular.ttf",
    "IBMPlexMono-Regular.ttf",
)
fe = font_manager.FontEntry(fname="IBMPlexMono-Regular.ttf", name="plexmono")
font_manager.fontManager.ttflist.append(fe)
# Global matplotlib styling shared by every figure in this notebook.
plt.rcParams.update(
    {
        "axes.facecolor": "#f5f4e9",
        "grid.color": "#AAAAAA",
        "axes.edgecolor": "#333333",
        "figure.facecolor": "#FFFFFF",
        "axes.grid": False,
        # NOTE: this cycle is replaced by `color_cycle` a few lines below.
        "axes.prop_cycle": plt.cycler("color", plt.cm.Dark2.colors),
        "font.family": fe.name,
        "figure.figsize": (3.5, 3.5 / 1.2),
        "ytick.left": True,
        "xtick.bottom": True,
    }
)

# NOTE(review): "#F06060" appears twice in this list - confirm that is intended.
color_cycle = ["#F06060", "#1BBC9B", "#F06060", "#5C4B51", "#F3B562", "#6e5687"]
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=color_cycle)
mpl.rcParams["font.size"] = 10
# Load the BBBP table; later cells read its "smiles", "p_np" and "name" columns.
bbb_data = pd.read_csv("../paper1_CFs/BBBP.csv")
# features_start_at = list(bbb_data.columns).index("MolWt")
# Seed NumPy's global RNG for reproducibility.
np.random.seed(0)

In [None]:
# Shuffle the rows once with a fixed seed and renumber the index 0..N-1, so the
# positional train/validation/test slices taken later are random but repeatable.
bbb_data = bbb_data.sample(frac=1, random_state=0).reset_index(drop=True)
bbb_data.head()

In [None]:
from rdkit.Chem import MolToSmiles


def _randomize_smiles(mol, isomericSmiles=True):
    """Return a randomized, non-canonical SMILES string for an RDKit molecule.

    Args:
        mol: RDKit molecule to serialize.
        isomericSmiles: Forwarded to ``MolToSmiles``.

    Returns:
        A SMILES string written with a random atom ordering; the kekulized
        form is chosen with probability 0.5.
    """
    # Imported locally so the helper does not depend on a later notebook cell
    # having executed `import random` first.
    import random

    return MolToSmiles(
        mol,
        canonical=False,
        doRandom=True,
        isomericSmiles=isomericSmiles,
        kekuleSmiles=random.random() < 0.5,
    )

In [None]:
import random

smiles = list(bbb_data["smiles"])
permeabilities = list(bbb_data["p_np"])

# Number of randomized SMILES to attempt per molecule; 0 disables augmentation
# and leaves exactly one entry per input molecule.
aug_data = 0

aug_smiles = []
aug_perm = []
for sml, sol in zip(smiles, permeabilities):
    # The original SMILES is always kept as the first variant.
    new_smls = [sml]
    aug_perm.append(sol)
    for _ in range(aug_data):
        try:
            new_sml = _randomize_smiles(MolFromSmiles(sml))
        except Exception:
            # Best effort: skip molecules that cannot be parsed or rewritten.
            # `Exception` (not a bare except) keeps Ctrl-C/SystemExit working.
            continue
        # Only keep variants not seen yet, each with a copy of the label.
        if new_sml not in new_smls:
            new_smls.append(new_sml)
            aug_perm.append(sol)
    aug_smiles.extend(new_smls)

aug_df_bbb = pd.DataFrame(data={"smiles": aug_smiles, "p_np": aug_perm})

print(f"The dataset was augmented from {len(bbb_data)} to {len(aug_df_bbb)}.")

In [None]:
# Encode every SMILES as SELFIES; entries that cannot be encoded become None so
# positions stay aligned with `aug_df_bbb`.
selfies_list = []
for i, s in enumerate(aug_df_bbb.smiles):
    # Sanitize once and reuse the result for both the encoding and the table.
    canonical = exmol.sanitize_smiles(s)[1]
    try:
        selfies_list.append(sf.encoder(canonical))
    except (sf.EncoderError, TypeError):
        selfies_list.append(None)
    # Write through `.loc` on the frame itself: chained `bbb_data.smiles[i] = ...`
    # may assign to a temporary copy and is a no-op under pandas copy-on-write.
    # NOTE(review): row i of `aug_df_bbb` only matches row i of `bbb_data`
    # while augmentation is disabled (aug_data == 0).
    bbb_data.loc[i, "smiles"] = canonical
len(selfies_list)

In [None]:
# Vocabulary = padding token + every SELFIES symbol found in the data + the
# basic exmol alphabet used when sampling the chemical space.
basic = set(exmol.get_basic_alphabet())
data_vocab = set(
    sf.get_alphabet_from_selfies([s for s in selfies_list if s is not None])
)
# Index 0 is reserved for '[nop]' so it doubles as the padding id.
vocab = ['[nop]']
# Sort the tokens: set iteration order depends on per-process string hashing,
# so unsorted ids would differ between sessions and break a reloaded model.
vocab.extend(sorted(data_vocab.union(basic)))
vocab_stoi = {token: index for index, token in enumerate(vocab)}


def selfies2ints(s):
    """Convert a SELFIES string into a list of vocabulary indices.

    '.' symbols are dropped. A symbol that is missing from ``vocab_stoi`` is
    printed and represented by ``np.nan`` in the returned list.
    """
    ids = []
    for symbol in sf.split_selfies(s):
        if symbol == '.':
            # fragment separator: not part of the vocabulary, skip it
            continue
        index = vocab_stoi.get(symbol)
        if index is None:
            print(symbol)
            index = np.nan
        ids.append(index)
    return ids


def ints2selfies(v):
    """Map a sequence of vocabulary indices back to a SELFIES string."""
    symbols = [vocab[index] for index in v]
    return "".join(symbols)


# Round-trip sanity check on the first molecule: encoding to ints and decoding
# again must give back the SELFIES string (minus any '.' separators).
# NOTE(review): this fails with an error if selfies_list[0] is None.
s = selfies_list[0]
print('selfies:', s)
v = selfies2ints(s)
print('selfies2ints:', v)
so = ints2selfies(v)
print('ints2selfes:', so)
assert so == s.replace(
    '.', ''
)  # selfies2ints drops '.', so remove it from the reference string as well

In [None]:
# Hyperparameter container for the data pipeline and the RNN model
@dataclass
class Config:
    """Settings read when building the tf.data pipeline and the Keras model."""

    # Number of tokens in the vocabulary (Embedding input_dim).
    vocab_size: int
    # Number of examples; set from len(selfies_list), not read elsewhere in
    # the visible code.
    example_number: int
    # Batch size applied to the train/validation/test datasets.
    batch_size: int
    # Shuffle buffer size for the training dataset.
    buffer_size: int
    # Output dimension of the Embedding layer.
    embedding_dim: int
    # Units of the LSTM wrapped by the Bidirectional layer.
    rnn_units: int
    # Width of the dense hidden layer.
    hidden_dim: int
    # Rate used by every Dropout layer.
    drop_rate: float


# Concrete hyperparameters for this notebook; vocab_size and example_number are
# derived from the vocabulary and the encoded dataset built above.
config = Config(
    vocab_size=len(vocab),
    example_number=len(selfies_list),
    batch_size=128,
    buffer_size=10000,
    embedding_dim=64,
    hidden_dim=32,
    rnn_units=64,
    drop_rate=0.20,
)

In [None]:
# Encode every molecule that has a SELFIES string and zero-pad (post) to a
# common length.
# One mask drives both the sequences and the labels. The previous label mask
# used bool(s), which disagrees with `s is not None` for an empty string and
# would shift labels against sequences.
has_selfies = np.array([s is not None for s in selfies_list], dtype=bool)
encoded = [selfies2ints(s) for s in selfies_list if s is not None]
padded_seqs = tf.keras.preprocessing.sequence.pad_sequences(encoded, padding="post")

permeabilities = aug_df_bbb.p_np.values[has_selfies]

# Rows were shuffled when the data was loaded, so positional slices are random.
N = len(padded_seqs)
split = int(0.1 * N)

# First 10%: test set.
test_data = tf.data.Dataset.from_tensor_slices(
    (padded_seqs[:split], permeabilities[:split])
).batch(config.batch_size)

nontest = tf.data.Dataset.from_tensor_slices(
    (
        padded_seqs[split:],
        permeabilities[split:],
    )
)

# Next `split` examples: validation set. Everything after that: training set
# (shuffled, batched and prefetched).
val_data = nontest.take(split).batch(config.batch_size)
train_data = (
    nontest.skip(split)
    .shuffle(config.buffer_size)
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [None]:
# Embedding -> bidirectional LSTM -> dense head, with dropout between stages.
model = tf.keras.Sequential(
    [
        # id 0 is the padding token, so it is masked for downstream layers
        tf.keras.layers.Embedding(
            input_dim=config.vocab_size,
            output_dim=config.embedding_dim,
            mask_zero=True,
        ),
        tf.keras.layers.Dropout(config.drop_rate),
        # recurrent encoder
        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(config.rnn_units)),
        tf.keras.layers.Dropout(config.drop_rate),
        # dense hidden layer
        tf.keras.layers.Dense(config.hidden_dim, activation="relu"),
        tf.keras.layers.Dropout(config.drop_rate),
        # single output with no activation: a raw logit, to match the
        # from_logits=True loss used at compile time
        tf.keras.layers.Dense(1),
    ]
)

model.summary()

In [None]:
model.compile(
    tf.optimizers.Adam(1e-3),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    # The network outputs raw logits, so the class boundary is 0
    # (sigmoid(0) == 0.5). The plain "accuracy" string would threshold the
    # logits at 0.5 instead. The metric keeps the name "accuracy" so the
    # history keys "accuracy"/"val_accuracy" are unchanged.
    metrics=[tf.keras.metrics.BinaryAccuracy(name="accuracy", threshold=0.0)],
)
# verbose=0 silences output, to get progress bar set verbose=1
result = model.fit(train_data, validation_data=val_data, epochs=100, verbose=1)

In [None]:
# Persist the trained model under "bbbp-rnn" in the working directory.
# NOTE(review): the token ids are not saved with the model; reloading it in a
# new session requires rebuilding exactly the same `vocab`.
model.save("bbbp-rnn")
# NOTE(review): the commented-out path below names a different (solubility)
# model - confirm before uncommenting.
# model = tf.keras.models.load_model('solubility-rnn-accurate/')

In [None]:
# Training history: loss on the left panel, accuracy on the right panel, each
# with the training and validation curve per epoch.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
panels = (
    (ax1, "loss", "val_loss", "Loss"),
    (ax2, "accuracy", "val_accuracy", "Accuracy"),
)
for axis, train_key, val_key, y_label in panels:
    axis.plot(result.history[train_key], label="training")
    axis.plot(result.history[val_key], label="validation")
    axis.legend()
    axis.set_xlabel("Epoch")
    axis.set_ylabel(y_label)
fig.tight_layout()
fig.savefig("bbp-rnn-loss-acc.png", dpi=180)
fig.show()

In [None]:
from sklearn.metrics import roc_curve
from sklearn.metrics import auc

# Collect model outputs and true labels over the held-out test batches.
score_batches = []
label_batches = []

for x, y in test_data:
    score_batches.append(model(x).numpy().flatten())
    label_batches.append(y.numpy().flatten())

prediction = np.concatenate(score_batches)
test_y = np.concatenate(label_batches)

# ROC curve and area under it, computed from the raw model outputs.
fpr_keras, tpr_keras, thresholds_keras = roc_curve(test_y, prediction)
auc_keras = auc(fpr_keras, tpr_keras)

plt.figure(figsize=(5, 3.5), dpi=100)
plt.plot(fpr_keras, tpr_keras, label="AUC = {:.3f}".format(auc_keras))
# dashed diagonal from (0, 0) to (1, 1) for reference
plt.plot([0, 1], [0, 1], linestyle="--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend()
plt.savefig("bbbp-rnn-roc.png", dpi=300)
plt.show()

## LIME explanations

In the following example, we find out which descriptors influence the predicted blood-brain barrier (BBB) permeation of a molecule. Given a molecule of interest, we create a perturbed chemical space around it using the `stoned` method and then use `lime` to find out which descriptors affect the model's permeation prediction for that molecule.

### Wrapper function for RNN, to use in STONED

In [None]:
# Predictor function is used as input to sample_space function
def predictor_function(smile_list, selfies):
    """Predict raw model outputs (logits) for a batch of molecules.

    Args:
        smile_list: SMILES strings; only used when ``selfies`` is empty, in
            which case they are encoded to SELFIES.
        selfies: SELFIES strings for the same molecules, or an empty list.

    Returns:
        1-D array with one prediction per molecule. An entry is ``nan`` when
        the molecule contains a token missing from the vocabulary, or when it
        encodes to an empty or all-padding sequence.
    """
    if len(selfies) < 1:
        selfies = [sf.encoder(s) for s in smile_list]
    if len(selfies) == 0:
        # Nothing to predict; pad_sequences cannot handle an empty batch.
        return np.array([])
    encoded = [selfies2ints(s) for s in selfies]
    # selfies2ints marks unknown tokens with nan, which makes sum(e) nan and
    # the comparison False, so such sequences are flagged invalid.
    valid = np.array([1.0 if sum(e) > 0 else np.nan for e in encoded])
    # Replace nan by the padding id so the sequences can still go through
    # the model.
    encoded = [np.nan_to_num(e, nan=0) for e in encoded]
    padded_seqs = tf.keras.preprocessing.sequence.pad_sequences(encoded, padding="post")
    labels = np.reshape(model.predict(padded_seqs), (-1))
    # Apply the mask. It was computed but never used before, so predictions
    # for invalid molecules were silently returned as if they were valid.
    return labels * valid

### Other plotting utilities

In [None]:
def space_fit_plot(space, beta, mode="regression"):
    """Plot the sampled chemical space next to a fit-quality panel for LIME.

    Panel "B" is drawn by ``exmol.plot_utils.plot_space_by_fit``. Panel "A"
    compares the model outputs stored on the examples (``e.yhat``) with the
    linear surrogate ``x_mat @ beta + mean(yhat)``.

    Args:
        space: Sequence of exmol examples. Each must expose ``similarity``,
            ``yhat`` and ``descriptors.descriptors``. ``space[0]`` is passed
            to exmol as the example to highlight.
        beta: Surrogate coefficients, one per descriptor column.
        mode: ``"regression"`` draws a parity plot and saves
            ``weighted_fit.svg``. ``"classification"`` draws a ROC curve,
            saves ``space_fit.svg`` and shows the figure. Any other value
            leaves panel "A" empty.

    Side effects:
        Changes matplotlib's global rc settings (axes title size, font size)
        and writes an SVG file into the working directory.
    """
    fkw = {"figsize": (10, 4)}

    fig = plt.figure(figsize=(10, 5))
    # NOTE: these rc calls are global and persist after the function returns.
    mpl.rc("axes", titlesize=12)
    mpl.rc("font", size=16)
    ax_dict = fig.subplot_mosaic("AABBB")
    fit_ax = ax_dict["A"]

    # Plot space by fit
    exmol.plot_utils.plot_space_by_fit(
        space,
        [space[0]],
        figure_kwargs=fkw,
        mol_size=(200, 200),
        offset=1,
        ax=ax_dict["B"],
        beta=beta,
    )

    # Similarity-based weights; examples with a negligible weight are dropped
    # from the fit-quality computation.
    w = np.array([1 / (1 + (1 / (e.similarity + 0.000001) - 1) ** 5) for e in space])
    non_zero = w > 10 ** (-6)
    w = w[non_zero]
    N = w.shape[0]

    ys = np.array([e.yhat for e in space])[non_zero].reshape(N).astype(float)
    x_mat = np.array([list(e.descriptors.descriptors) for e in space])[
        non_zero
    ].reshape(N, -1)
    # Surrogate prediction, shifted by the mean model output.
    y_wls = x_mat @ beta + np.mean(ys)

    lower = np.min(ys)
    higher = np.max(ys)

    # Per-point RGBA colours from the weights; alpha is the weight itself.
    norm = plt.Normalize(min(w), max(w))
    point_colors = plt.cm.Oranges(w)
    point_colors[:, -1] = w

    def weighted_mean(x, w):
        return np.sum(x * w) / np.sum(w)

    def weighted_cov(x, y, w):
        return np.sum(
            w * (x - weighted_mean(x, w)) * (y - weighted_mean(y, w))
        ) / np.sum(w)

    def weighted_correlation(x, y, w):
        return weighted_cov(x, y, w) / np.sqrt(
            weighted_cov(x, x, w) * weighted_cov(y, y, w)
        )

    corr = weighted_correlation(ys, y_wls, w)

    if mode == "regression":
        diagonal = np.linspace(lower, higher, 100)
        fit_ax.plot(diagonal, diagonal, "--", linewidth=2)
        # `c` already holds explicit RGBA colours. The RGBA array used to be
        # passed as `cmap=` too, which is not a colormap and is ignored (with
        # a warning) by matplotlib.
        fit_ax.scatter(ys, y_wls, s=50, marker=".", c=point_colors)
        # NOTE(review): the label position uses fixed data-unit offsets
        # (-10, +1); check it lands inside the axes for this model's range.
        fit_ax.text(
            max(ys) - 10,
            min(ys) + 1,
            f"weighted \ncorrelation = {corr:.3f}",
            fontsize=10,
        )
        fit_ax.set_xlabel(r"$\hat{y}$")
        fit_ax.set_ylabel(r"$g$")
        fit_ax.set_title("Weighted Least Squares Fit")
        fit_ax.set_xlim(lower, higher)
        fit_ax.set_ylim(lower, higher)
        fit_ax.set_aspect(1.0 / fit_ax.get_data_ratio(), adjustable="box")
        # The colour bar is built from the weight range computed above.
        sm = plt.cm.ScalarMappable(cmap=plt.cm.Oranges, norm=norm)
        cbar = plt.colorbar(sm, orientation="horizontal", pad=0.15, ax=fit_ax)
        cbar.set_label("Chemical similarity")
        plt.tight_layout()
        plt.savefig("weighted_fit.svg", dpi=300, bbox_inches="tight", transparent=False)
    elif mode == "classification":
        # Imported here so the function works even if the cell that imports
        # these names at notebook level has not been run.
        from sklearn.metrics import auc, roc_curve

        # NOTE(review): roc_curve expects binary labels as its first argument;
        # `ys` holds model outputs - confirm they are binarized before using
        # this mode.
        fpr_keras, tpr_keras, thresholds_keras = roc_curve(ys, y_wls)
        auc_keras = auc(fpr_keras, tpr_keras)
        fit_ax.plot(fpr_keras, tpr_keras, label="AUC = {:.3f}".format(auc_keras))
        fit_ax.plot([0, 1], [0, 1], linestyle="--")
        fit_ax.set_xlabel("False Positive Rate")
        fit_ax.set_ylabel("True Positive Rate")
        fit_ax.set_aspect(1.0 / fit_ax.get_data_ratio(), adjustable="box")
        fit_ax.legend()
        plt.tight_layout()
        plt.savefig("space_fit.svg", dpi=300)
        plt.show()

### Descriptor explanations

In [None]:
# Example drugs for a quick check of the predictor. Nicotine and caffeine are
# looked up by name in the BBBP table (this fails with IndexError if the name
# is absent); the others are given as literal SMILES.
# smi = 'Cc1onc(-c2ccccc2Cl)c1C(=O)NC1C(=O)N2C1SC(C)(C)C2C(=O)O'
# smi = soldata.SMILES[1400]
ibuprofen = "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"  # ibuprofen is known to cross BBB
nicotine = bbb_data.smiles[bbb_data.name == "nicotine"].values[0]  # treat anxiety
caffeine = bbb_data.smiles[bbb_data.name == "caffeine"].values[0]
chlorpromazine = "CN(CCCN1c2ccccc2Sc2c1cc(Cl)cc2)C"
gleevec = "Cc1ccc(cc1Nc2nccc(n2)c3cccnc3)NC(=O)c4ccc(cc4)CN5CCN(CC5)C"
# mol = MolFromSmiles(smi)
# from rdkit.Chem.Draw import MolToFile

# MolToFile(mol, 'mol_paper.svg')
# An empty SELFIES list makes predictor_function encode the SMILES itself.
predictor_function([ibuprofen, nicotine, caffeine, chlorpromazine, gleevec], [])

In [None]:
# Single-molecule prediction; the empty list makes predictor_function encode
# the SMILES to SELFIES itself.
diazepam = "CN1C(=O)CN=C(C2=C1C=CC(=C2)Cl)C3=CC=CC=C3"
predictor_function([diazepam], [])

In [None]:
# Sample a local chemical space around the molecule to be explained.
# Make sure SMILES doesn't contain multiple fragments
# smi = soldata.SMILES[1400]
# print(smi)
# NOTE(review): the drug name is usually spelled "alprazolam"; the spelling
# here is kept because the output file names written later reuse it.
alprozolam = "CC1=NN=C2N1C3=C(C=C(C=C3)Cl)C(=NC2)C4=CC=CC=C4"
# STONED settings: 5000 samples, at most one mutation each, drawn from the
# basic exmol alphabet (which is also part of the model vocabulary).
stoned_kwargs = {
    "num_samples": 5000,
    "alphabet": exmol.get_basic_alphabet(),
    "max_mutations": 1,
}
space = exmol.sample_space(alprozolam, predictor_function, stoned_kwargs=stoned_kwargs)

In [None]:
# Filter space
from synspace.reos import REOS

reos = REOS()

# Keep only sampled molecules that pass the REOS filter. The base molecule
# (space[0]) is always kept: later cells use `filtered_space[0]` as the
# molecule being explained, so dropping it would silently explain another one.
filtered_space = [
    e
    for position, e in enumerate(space)
    if position == 0 or reos.process_mol(MolFromSmiles(e.smiles)) == ("ok", "ok")
]

len(filtered_space)

In [None]:
from IPython.display import display, SVG

# Fit a LIME surrogate and plot the descriptor attributions for each
# descriptor family, then draw the fit-quality figure for that surrogate.
desc_type = ["Classic", "ECFP", "MACCS"]
# NOTE(review): `fkw` is not used anywhere in this cell.
fkw = {"figsize": (6, 4)}
for d in desc_type:
    beta = exmol.lime_explain(filtered_space, descriptor_type=d)
    if d == "Classic":
        exmol.plot_descriptors(filtered_space, output_file=f"alprozolam_{d}.svg")
    else:
        # These plots are requested as SVG; the matplotlib figure is closed
        # and the SVG is displayed through skunk instead.
        svg = exmol.plot_descriptors(
            filtered_space, output_file=f"alprozolam_{d}.svg", return_svg=True
        )
        plt.close()
        skunk.display(svg)
    space_fit_plot(filtered_space, beta, mode="regression")

In [None]:
# Refit the surrogate on ECFP descriptors, then build the similarity map for
# the first example in the filtered space.
exmol.lime_explain(filtered_space, "ECFP")
# NOTE(review): `_` is also the name IPython uses for the last cell output;
# a dedicated variable name would be safer.
_ = exmol.plot_utils.similarity_map_using_tstats(filtered_space[0], return_svg=True)

In [None]:
# Render the SVG stored in `_` by the previous cell.
display(SVG(_))

In [None]:
# Natural-language explanation of the ECFP surrogate.
# NOTE(review): OPEN_AI_KEY is read but not passed on in this cell; presumably
# exmol picks the key up from the environment itself - confirm.
OPEN_AI_KEY = os.environ.get("OPENAI_API_KEY")
exmol.lime_explain(filtered_space, "ecfp")
s1_ecfp = exmol.text_explain(filtered_space, "ecfp")
# The property name is given to the text generator. This model predicts
# blood-brain barrier permeation, not aqueous solubility (that name was a
# leftover from the solubility notebook).
explanation = exmol.text_explain_generate(s1_ecfp, "blood-brain barrier permeation")
print(explanation)