In [None]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import os
from dataclasses import dataclass

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
import numpy as np
import tensorflow as tf
import selfies as sf
import exmol
from rdkit.Chem.Draw import rdDepictor

# Global plotting / depiction configuration for the whole notebook.
rdDepictor.SetPreferCoordGen(True)
sns.set_context("notebook")
sns.set_style(
    "dark",
    {
        "xtick.bottom": True,
        "ytick.left": True,
        "xtick.color": "#666666",
        "ytick.color": "#666666",
        "axes.edgecolor": "#666666",
        "axes.linewidth": 0.8,
        # NOTE(review): seaborn's style dict may silently drop keys that are
        # not style parameters (such as figure.dpi) -- confirm this has an effect.
        "figure.dpi": 300,
    },
)
color_cycle = ["#1BBC9B", "#F06060", "#5C4B51", "#F3B562", "#6e5687"]
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=color_cycle)

# Load the curated solubility dataset straight from the dmol-book repository.
soldata = pd.read_csv(
    "https://github.com/whitead/dmol-book/raw/master/data/curated-solubility-dataset.csv"
)
# Column position of "MolWt" in the loaded frame.
features_start_at = list(soldata.columns).index("MolWt")

# Seed every random number generator the notebook relies on. NumPy alone is
# not enough: TensorFlow has its own global seed, which drives weight
# initialisation and dataset shuffling during training.
np.random.seed(0)
tf.random.set_seed(0)

In [None]:
# Shuffle the rows and keep only 1% of them (fixed random_state so the
# subset is the same on every run).
# Reduced for CI!
shuffled_subset = soldata.sample(frac=0.01, random_state=0)
# Renumber the rows 0..n-1 so later positional/label lookups agree.
soldata = shuffled_subset.reset_index(drop=True)
soldata.head()

In [None]:
# Sanitize every SMILES with exmol (the element at index 1 of its result is
# used) and encode the sanitized string as SELFIES, keeping dataset order.
selfies_list = []
for smiles in soldata.SMILES:
    sanitized = exmol.sanitize_smiles(smiles)[1]
    selfies_list.append(sf.encoder(sanitized))

In [None]:
# Vocabulary = padding token + every token seen in the data + exmol's basic
# alphabet (so that mutated molecules can still be encoded).
basic = set(exmol.get_basic_alphabet())
data_vocab = set(sf.get_alphabet_from_selfies([s for s in selfies_list if s is not None]))
# '[Nop]' must stay at index 0: the Embedding layer is created with
# mask_zero=True, so id 0 is treated as padding.
vocab = ['[Nop]']
# Sort the union instead of relying on set iteration order, which for strings
# changes between interpreter runs (hash randomisation). Without this the
# token -> id mapping differs on every kernel restart. The pad token is
# removed from the union so it can never be assigned a second, non-zero id.
vocab.extend(sorted((data_vocab | basic) - {'[Nop]'}))
vocab_stoi = {token: index for index, token in enumerate(vocab)}

def selfies2ints(s):
    """Convert a SELFIES string into a list of vocabulary indices.

    The string is tokenised with ``sf.split_selfies``. ``'.'`` tokens are
    dropped, tokens present in ``vocab_stoi`` are mapped to their index and
    any other token is represented by ``np.nan``.

    :param s: SELFIES string
    :return: list with one entry (int index or ``np.nan``) per kept token
    """
    return [
        vocab_stoi.get(token, np.nan)  # nan flags an out-of-vocabulary token
        for token in sf.split_selfies(s)
        if token != '.'
    ]

def ints2selfies(v):
    """Map a sequence of vocabulary indices back to a SELFIES string.

    :param v: iterable of integer indices into ``vocab``
    :return: the concatenation of the corresponding tokens
    """
    tokens = (vocab[index] for index in v)
    return ''.join(tokens)

# Round-trip sanity check on the first molecule: SELFIES -> ints -> SELFIES
s = selfies_list[0]
v = selfies2ints(s)
so = ints2selfies(v)

print('selfies:', s)
print('selfies2ints:', v)
print('ints2selfes:', so)

# selfies2ints drops '.' tokens, so strip them from the reference string too
assert so == s.replace('.', '')

In [None]:
@dataclass
class Config:
    """Hyperparameters and dataset sizes shared by the data pipeline and model."""

    # number of entries in the token vocabulary (Embedding input_dim)
    vocab_size: int
    # number of examples; set from len(selfies_list) below
    example_number: int
    # number of examples per batch in the tf.data pipelines
    batch_size: int
    # buffer size passed to Dataset.shuffle for the training split
    buffer_size: int
    # width of the token embedding (Embedding output_dim)
    embedding_dim: int
    # number of units in the GRU layer
    rnn_units: int
    # width of the dense hidden layer after the GRU
    hidden_dim: int


# Concrete settings used for the rest of the notebook; keyword arguments are
# listed in the same order as the Config fields.
config = Config(
    vocab_size=len(vocab),
    example_number=len(selfies_list),
    batch_size=16,
    buffer_size=10000,
    embedding_dim=256,
    rnn_units=128,
    hidden_dim=128,
)

In [None]:
# now get sequences
# A single validity mask is used for BOTH the inputs and the labels so the two
# can never get out of step (previously the inputs were filtered with
# `s is not None` and the labels with `bool(s)`, which disagree on '').
valid_mask = np.array([s is not None for s in selfies_list], dtype=bool)
encoded = [selfies2ints(s) for s, keep in zip(selfies_list, valid_mask) if keep]
# zero-pad at the end of each sequence
padded_seqs = tf.keras.preprocessing.sequence.pad_sequences(encoded, padding="post")
solubility_labels = soldata.Solubility.values[valid_mask]
assert len(padded_seqs) == len(solubility_labels), "inputs and labels are misaligned"

# Now build dataset
data = tf.data.Dataset.from_tensor_slices((padded_seqs, solubility_labels))

# now split into val, test, train and batch:
# first 10% -> test, next 10% -> validation, remainder -> training
N = len(data)
split = int(0.1 * N)
test_data = data.take(split).batch(config.batch_size)
nontest = data.skip(split)
val_data = nontest.take(split).batch(config.batch_size)
# only the training split is shuffled and prefetched
train_data = (
    nontest.skip(split)
    .shuffle(config.buffer_size)
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [None]:
# Sequence regressor: token embedding -> GRU -> dense hidden layer -> scalar
model = tf.keras.Sequential(
    [
        # embedding; mask_zero=True marks index 0 as padding to be masked
        tf.keras.layers.Embedding(
            input_dim=config.vocab_size,
            output_dim=config.embedding_dim,
            mask_zero=True,
        ),
        # recurrent layer
        tf.keras.layers.GRU(config.rnn_units),
        # dense hidden layer
        tf.keras.layers.Dense(config.hidden_dim, activation="relu"),
        # single output without activation, since this is regression
        tf.keras.layers.Dense(1),
    ]
)

model.summary()

In [None]:
# Adam with learning rate 1e-4, mean squared error objective
model.compile(
    optimizer=tf.optimizers.Adam(learning_rate=1e-4),
    loss="mean_squared_error",
)
# Train for 100 epochs, tracking the validation split; verbose=2 prints one line per epoch
result = model.fit(
    train_data,
    validation_data=val_data,
    epochs=100,
    verbose=2,
)

In [None]:
# Training vs. validation loss per epoch, as recorded by model.fit
fig, ax = plt.subplots()
for history_key, curve_label in (("loss", "training"), ("val_loss", "validation")):
    ax.plot(result.history[history_key], label=curve_label)
ax.legend()
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
plt.show()

In [None]:
# Collect predictions and true labels over every batch of the test split
yhat = []
test_y = []
for x, y in test_data:
    yhat.extend(model(x).numpy().flatten())
    test_y.extend(y.numpy().flatten())
yhat = np.array(yhat)
test_y = np.array(test_y)

# plot test data: dotted identity line plus predicted-vs-true points
plt.plot(test_y, test_y, ":")
plt.plot(test_y, yhat, ".")
# Anchor the annotations on the full test set. The loop variable `y` only
# holds the LAST batch, so using it placed the text relative to a subset.
text_x = test_y.min() - 7
text_top = test_y.max()
plt.text(text_x, text_top - 2, f"correlation = {np.corrcoef(test_y, yhat)[0,1]:.3f}")
# the value printed as "loss" is the root-mean-squared error on the test set
plt.text(text_x, text_top - 3, f"loss = {np.sqrt(np.mean((test_y - yhat)**2)):.3f}")
plt.title("Testing Data")
plt.savefig("rnn-fit.png", dpi=300)
plt.show()

## CF explanation:

In the following example, let's say we would like our molecules to have a solubility value of -3.5. Here we use the `counterstone` algorithm to create counterfactual explanations. In other words, we would like to see the minimal mutations that could be made to our input structure to reach our desired solubility.

In [None]:
def predictor_function(smile_list, selfies):
    """Predict solubility for a batch of SELFIES strings with the trained model.

    :param smile_list: not used by this function; only ``selfies`` is read
    :param selfies: iterable of SELFIES strings to score
    :return: 1-D array of predictions, with ``nan`` for every input that
        contained an out-of-vocabulary token or whose token ids do not sum
        to a positive number
    """
    token_ids = [selfies2ints(s) for s in selfies]
    # Validity multiplier per input. sum() is nan as soon as one token was
    # unknown, and `nan > 0` is False, so such inputs are marked nan.
    validity = []
    for ids in token_ids:
        validity.append(1.0 if sum(ids) > 0 else np.nan)
    # Replace nan ids by 0 so the sequences can still be fed to the model
    cleaned = [np.nan_to_num(ids, nan=0) for ids in token_ids]
    model_input = tf.keras.preprocessing.sequence.pad_sequences(cleaned, padding="post")
    predictions = np.reshape(model.predict(model_input), (-1))
    # multiplying by nan blanks out the predictions of invalid inputs
    return predictions * validity

In [None]:
# Quick smoke call of the predictor on two hand-written SELFIES strings; the
# second one contains '[Nop]', the token placed at index 0 of `vocab`.
predictor_function([], ["[C][C][O]", "[C][C][Nop][O]"])

In [None]:
# Sampling settings forwarded to exmol: 2500 samples drawn with the basic
# alphabet and at most 2 mutations.
stoned_kwargs = dict(
    num_samples=2500,
    alphabet=exmol.get_basic_alphabet(),
    max_mutations=2,
)
# Sample around the molecule at index 4 and score the samples with the model
space = exmol.sample_space(
    soldata.SMILES[4],
    predictor_function,
    stoned_kwargs=stoned_kwargs,
)
# NOTE(review): 0.5 is passed as the second positional argument, presumably
# the required change in the prediction -- confirm against the exmol docs.
exps = exmol.rcf_explain(space, 0.5, nmols=4)

In [None]:
# Draw the counterfactuals in a single row and save a raster copy
fkw = dict(figsize=(10, 3))
exmol.plot_cf(exps, figure_kwargs=fkw, mol_size=(450, 400), nrows=1)
plt.savefig("rnn-simple.png", bbox_inches="tight", dpi=180)

# Also write the SVG string produced by exmol to disk
svg = exmol.insert_svg(exps, mol_fontsize=16)
with open("rnn-simple.svg", "w") as svg_file:
    svg_file.write(svg)

In [None]:
fkw = {"figsize": (10, 4)}
# not referenced anywhere else in this cell
font = {"family": "normal", "weight": "normal", "size": 22}

# Plot the sampled space with the selected explanations highlighted
exmol.plot_space(space, exps, figure_kwargs=fkw, mol_size=(100, 100), offset=1)
ax = plt.gca()
# NOTE(review): assumes the axes child at index 1 is the artist that carries
# the colour mapping -- confirm if exmol's plotting changes.
plt.colorbar(
    ax.get_children()[1],
    ax=[ax],
    label="Solubility [Log M]",
    location="left",
    shrink=0.8,
)
plt.savefig("rnn-space.png", bbox_inches="tight", dpi=180)
svg = exmol.insert_svg(exps, mol_fontsize=16)
# Unlike the other figures this SVG goes into a sub-directory; create it
# first so a fresh checkout does not fail with FileNotFoundError.
os.makedirs("svg_figs", exist_ok=True)
with open("svg_figs/rnn-space.svg", "w") as f:
    f.write(svg)

In [None]:
# Re-sample around the same molecule (index 4), this time with the "wide" preset
space = exmol.sample_space(
    soldata.SMILES[4],
    predictor_function,
    preset="wide",
)
# Same second positional argument (0.5) as in the earlier rcf_explain call
exps = exmol.rcf_explain(space, 0.5)

In [None]:
fkw = dict(figsize=(8, 6))
font = dict(family="normal", weight="normal", size=22)

# Plot the wide-preset space together with its explanations
exmol.plot_space(space, exps, figure_kwargs=fkw, mol_size=(200, 200), offset=1)
ax = plt.gca()
# the axes child at index 1 is handed to the colorbar as the mappable
mappable = ax.get_children()[1]
plt.colorbar(mappable, ax=[ax], location="left", label="Solubility [Log M]")
plt.savefig("rnn-wide.png", bbox_inches="tight", dpi=180)

# Write the SVG string produced by exmol to disk as well
svg = exmol.insert_svg(exps, mol_fontsize=16)
with open("rnn-space-wide.svg", "w") as svg_file:
    svg_file.write(svg)

## Figure showing the effect of mutation number and alphabet


In [None]:
# For 1, 3 and 5 mutations: sample a space, keep it for the histograms below,
# and collect one "Decrease" explanation per setting.
exps = []
spaces = []
for i in (1, 3, 5):
    # min == max pins the number of mutations to exactly i
    stoned_kwargs = dict(
        num_samples=2500,
        alphabet=exmol.get_basic_alphabet(),
        min_mutations=i,
        max_mutations=i,
    )
    space = exmol.sample_space(
        soldata.SMILES[4], predictor_function, stoned_kwargs=stoned_kwargs
    )
    spaces.append(space)
    e = exmol.rcf_explain(space, nmols=2)
    # the first entry of the first result is added once, ahead of the others
    if not exps:
        exps.append(e[0])
    # first non-origin explanation whose label mentions "Decrease", if any
    decreasing = next(
        (ei for ei in e if not ei.is_origin and "Decrease" in ei.label),
        None,
    )
    if decreasing is not None:
        decreasing.label = f"Mutations = {i}"
        exps.append(decreasing)

In [None]:
# One-row figure of the explanations collected for each mutation count
fkw = dict(figsize=(10, 4))
exmol.plot_cf(
    exps,
    figure_kwargs=fkw,
    mol_fontsize=26,
    mol_size=(400, 400),
    nrows=1,
)
plt.savefig("rnn-mutations.png", bbox_inches="tight", dpi=180)

# Write the SVG string produced by exmol to disk as well
svg = exmol.insert_svg(exps, mol_fontsize=16)
with open("rnn-mutations.svg", "w") as svg_file:
    svg_file.write(svg)

In [None]:
# Histogram of similarity values for each sampled space (one panel per
# mutation count); the first element of every space is skipped.
fig, axs = plt.subplots(1, 3, figsize=(8, 3), dpi=180, squeeze=True, sharey=True)
for i, n in enumerate([1, 3, 5]):
    panel = axs[i]
    similarities = [e.similarity for e in spaces[i][1:]]
    panel.hist(similarities, bins=99, edgecolor="none")
    panel.set_title(f"Mutations = {n}")
    panel.set_xlim(0, 1)
plt.tight_layout()
plt.savefig("rnn-mutation-hist.png", bbox_inches="tight", dpi=180)

In [None]:
# Three alphabets to compare: exmol's basic one, the tokens found in the
# training data, and the SELFIES semantic-robust alphabet.
basic = exmol.get_basic_alphabet()
# Skip None entries, exactly as done when the model vocabulary was built, so
# an unencodable molecule cannot break the alphabet extraction.
train = sf.get_alphabet_from_selfies([s for s in selfies_list if s is not None])
wide = sf.get_semantic_robust_alphabet()

alphs = {"Basic": basic, "Training Data": train, "SELFIES": wide}

exps = []
for alphabet_name, alphabet in alphs.items():
    stoned_kwargs = {
        "num_samples": 2500 // 2,
        "alphabet": alphabet,
        "max_mutations": 2,
    }
    space = exmol.sample_space(
        soldata.SMILES[4], predictor_function, stoned_kwargs=stoned_kwargs
    )
    e = exmol.rcf_explain(space, nmols=2)
    # the first entry of the first result is added once, ahead of the others
    if len(exps) == 0:
        exps.append(e[0])
    # keep the first non-origin explanation labelled as a decrease
    for ei in e:
        if not ei.is_origin and "Decrease" in ei.label:
            ei.label = f"Alphabet = {alphabet_name}"
            exps.append(ei)
            break

In [None]:
fkw = {"figsize": (10, 4)}
exmol.plot_cf(exps, figure_kwargs=fkw, mol_fontsize=26, mol_size=(400, 400), nrows=1)
plt.savefig("rnn-alphabets.png", bbox_inches="tight", dpi=180)
svg = exmol.insert_svg(exps, mol_fontsize=16)