## Code to Generate Explanations for Different Scents
Uses the exmol package: https://github.com/ur-whitelab/exmol

In [None]:
# Imports
from typing import *
import exmol

import seaborn as sns
import pandas as pd
import numpy as np
import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw
from IPython.display import display, SVG
from rdkit.Chem.Draw import MolToImage as mol2img, DrawMorganBit  # type: ignore
from rdkit.Chem import rdchem, MACCSkeys, AllChem  # type: ignore
import skunk
import cairosvg
import matplotlib
import warnings

warnings.filterwarnings("ignore")

import pyrfume  # for loading dataset

# Packages needed for GNN model (model used when creating spaces)
##not needed if generating descriptor explanations or NLEs using a previously created sample space
import haiku as hk
import jax
import jax.numpy as jnp
import tensorflow as tf

# Plotting style
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import urllib.request

# Download the IBM Plex Mono font into the working directory and register it with
# matplotlib under the family name "plexmono", which rcParams below selects.
# NOTE(review): the font is re-downloaded every time this cell runs.
urllib.request.urlretrieve(
    "https://github.com/google/fonts/raw/main/ofl/ibmplexmono/IBMPlexMono-Regular.ttf",
    "IBMPlexMono-Regular.ttf",
)
fe = font_manager.FontEntry(fname="IBMPlexMono-Regular.ttf", name="plexmono")
font_manager.fontManager.ttflist.append(fe)
# Global matplotlib style applied to every figure produced afterwards in this notebook.
plt.rcParams.update(
    {
        "axes.facecolor": "#f5f4e9",
        "grid.color": "#AAAAAA",
        "axes.edgecolor": "#333333",
        "figure.facecolor": "#FFFFFF",
        "axes.grid": False,
        "axes.prop_cycle": plt.cycler("color", plt.cm.Dark2.colors),
        "font.family": fe.name,
        "figure.figsize": (3.5, 3.5 / 1.2),
        "ytick.left": True,
        "xtick.bottom": True,
    }
)


# Seed NumPy's and TensorFlow's global random generators (JAX uses explicit PRNG keys instead).
np.random.seed(0)
tf.random.set_seed(0)

### GNN model related code
The code below is adapted from the examples in the "Predicting DFT Energies with GNNs" and "Interpretability and Deep Learning" sections of the "Deep Learning for Molecules and Materials" textbook (https://dmol.pub/applied/QM9.html).

In [None]:
# Parameters for GNN model (parameters being read in)
node_feat_length = 256
message_feat_length = 256
graph_feat_length = 512
weights_stddevGNN = 0.01

# Load data from pyrfume
scentdata = pyrfume.load_data("leffingwell/leffingwell_data.csv", remote=True)


# Code to generate list of all scent labels (scentClasses)
numMolecules = len(scentdata.odor_labels_filtered)
# NOTE(review): must match len(scentClasses) and the model's output width; confirm
# if scentClasses.csv is ever regenerated.
numClasses = 112  # No odorless class
scentClasses = pd.read_csv("scentClasses.csv")["Scent"].tolist()

# odor_labels_filtered appears to hold stringified lists such as "['fruity', 'green']":
# deleting brackets, quotes and spaces, then splitting on commas, recovers the labels.
_label_strip_table = str.maketrans("", "", "[]' ")
moleculeScentList = []
for i in range(numMolecules):
    scentList = scentdata.odor_labels_filtered[i].translate(_label_strip_table).split(",")
    if "odorless" in scentList:
        # "odorless" is not one of the model's classes
        scentList.remove("odorless")
    moleculeScentList.append(scentList)


def gen_smiles2graph(sml):
    """Convert a SMILES string into the ``(nodes, adj)`` arrays consumed by the GNN.

    Hydrogens are made explicit with ``AddHs`` before featurization.

    :param sml: SMILES string
    :returns: ``(nodes, adj)``. ``nodes`` has shape (N, node_feat_length), with a 1 at
        column ``atom.GetAtomicNum()`` and a 1 in the last column for ring atoms.
        ``adj`` is an (N, N) symmetric 0/1 adjacency matrix with self-loops; bond
        order is not encoded.
    :raises ValueError: if RDKit cannot parse ``sml``
    :raises Warning: for bond types other than single/double/triple/aromatic
    """
    m = rdkit.Chem.MolFromSmiles(sml)
    if m is None:
        # MolFromSmiles returns None on parse failure; fail clearly instead of
        # letting AddHs raise an opaque error.
        raise ValueError(f"Could not parse SMILES: {sml!r}")
    m = rdkit.Chem.AddHs(m)
    supported_bonds = {
        rdkit.Chem.rdchem.BondType.SINGLE,
        rdkit.Chem.rdchem.BondType.DOUBLE,
        rdkit.Chem.rdchem.BondType.TRIPLE,
        rdkit.Chem.rdchem.BondType.AROMATIC,
    }
    N = m.GetNumAtoms()  # includes the hydrogens added above
    nodes = np.zeros((N, node_feat_length))
    for atom in m.GetAtoms():
        idx = atom.GetIdx()
        nodes[idx, atom.GetAtomicNum()] = 1
        # Add in whether atom is in a ring or not for one-hot encoding
        if atom.IsInRing():
            nodes[idx, -1] = 1

    adj = np.zeros((N, N))
    for bond in m.GetBonds():
        bond_type = bond.GetBondType()
        if bond_type not in supported_bonds:
            # The exception class is kept as Warning; the message is now formatted
            # instead of crashing with TypeError on str + BondType.
            raise Warning(f"Unsupported bond type {bond_type}; cannot build graph")
        u, v = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        adj[u, v] = 1
        adj[v, u] = 1
    adj += np.eye(N)  # self-loops so each atom also sees its own features
    return nodes, adj


# Function that creates label vector given list of strings describing scent of molecule as input
# Each index in label vector corresponds to specific scent -> if output has a 0 at index i, then molecule does not have scent i
# If label vector has 1 at index i, then molecule does have scent i


def createLabelVector(scentsList, classes=None, n_classes=None):
    """Return a multi-hot label vector for a molecule's scent labels.

    :param scentsList: scent names; each must appear in ``classes``. Empty strings
        (which a "[]" label string parses to) are ignored.
    :param classes: ordered list of class names; defaults to the global ``scentClasses``
    :param n_classes: length of the vector; defaults to the global ``numClasses``
        when ``classes`` is defaulted, otherwise ``len(classes)``
    :returns: float array of shape (n_classes,) with 1 at each present scent's index
    :raises ValueError: if a non-empty scent is not in ``classes``
    """
    if classes is None:
        classes = scentClasses
        if n_classes is None:
            n_classes = numClasses
    elif n_classes is None:
        n_classes = len(classes)

    labelVector = np.zeros(n_classes)
    for scent in scentsList:
        if not scent:
            continue
        labelVector[classes.index(scent)] = 1
    return labelVector


def generateGraphs():
    """Yield ``((nodes, adj), labels)`` for each molecule in ``scentdata``, in row order."""
    for idx in range(numMolecules):
        yield gen_smiles2graph(scentdata.smiles[idx]), createLabelVector(
            moleculeScentList[idx]
        )


# Get graph data for training, testing & validation sets.
# Each element is ((nodes (N, node_feat_length), adjacency (N, N)), labels); N varies
# per molecule. output_signature replaces the deprecated output_types/output_shapes.
data = tf.data.Dataset.from_generator(
    generateGraphs,
    output_signature=(
        (
            tf.TensorSpec(shape=(None, node_feat_length), dtype=tf.float32),
            tf.TensorSpec(shape=(None, None), dtype=tf.float32),
        ),
        tf.TensorSpec(shape=(None,), dtype=tf.float32),
    ),
)


class GNNLayer(hk.Module):
    """One message-passing layer over a single (unbatched) molecular graph.

    Takes and returns a ``(nodes, edges, features)`` tuple, so layers can be
    chained directly.

    The parameter names (``we``, ``b``, ``wv``, ``wu``) and their creation order
    determine the keys of the saved parameter dict and the init RNG use; do not
    rename or reorder them.
    """

    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        # NOTE(review): output_size is stored but never read; the graph-feature
        # width comes from the incoming ``features``. Kept for API compatibility.
        self.output_size = output_size

    def __call__(self, inputs):
        # split input into nodes, edges & features
        nodes, edges, features = inputs
        # nodes: (N, Nf) --> N = # atoms, Nf = node_feature_length
        # edges: (N, N) adjacency matrix
        # features: (Gf,) --> Gf = graph_feature_length

        graph_feature_len = features.shape[-1]  # Gf
        node_feature_len = nodes.shape[-1]  # Nf
        message_feature_len = message_feat_length  # Mf

        w_init = hk.initializers.RandomNormal(stddev=weights_stddevGNN)
        we = hk.get_parameter(  # (Nf, Mf)
            "we", shape=[node_feature_len, message_feature_len], init=w_init
        )
        b = hk.get_parameter("b", shape=[message_feature_len], init=w_init)  # (Mf,)
        wv = hk.get_parameter(  # (Mf, Nf)
            "wv", shape=[message_feature_len, node_feature_len], init=w_init
        )
        wu = hk.get_parameter(  # (Nf, Gf)
            "wu", shape=[node_feature_len, graph_feature_len], init=w_init
        )

        # Per-sender messages computed once: (N, Mf). Tiling nodes to (N, N, Nf)
        # before the matmul would repeat this product N times.
        messages = nodes @ we
        # ek[i, j] = leaky_relu(b + edges[i, j] * messages[j]) -> (N, N, Mf).
        # The bias sits outside the edge mask, so non-adjacent pairs contribute
        # leaky_relu(b). This matches the trained parameters and is kept as is.
        ek = jax.nn.leaky_relu(
            b + messages[jnp.newaxis, :, :] * edges[..., jnp.newaxis]
        )

        # Optional edge update (disabled; would also need new_edges returned):
        # new_edges = jnp.any(ek, axis=-1)
        # new_edges = hk.LayerNorm(axis=[0,1], create_scale=False, create_offset=False, eps=1e-05)(new_edges)

        # sum over neighbors j -> ebar shape = N x Mf
        ebar = jnp.sum(ek, axis=1)

        # dense layer + residual -> new_nodes shape = N x Nf
        new_nodes = jax.nn.leaky_relu(ebar @ wv) + nodes

        # Layer normalization over both the atom and feature axes (no learned scale/offset)
        new_nodes = hk.LayerNorm(
            axis=[0, 1], create_scale=False, create_offset=False, eps=1e-05
        )(new_nodes)

        # sum over atoms -> global_node_features shape = Nf
        global_node_features = jnp.sum(new_nodes, axis=0)

        # dense layer + residual -> new_features shape = Gf
        new_features = jax.nn.leaky_relu(global_node_features @ wu) + features

        # Edges are passed through unchanged.
        return new_nodes, edges, new_features


def model_fn(x, num_gnn_layers=4):
    """Forward pass: stacked GNNLayers, then two dense layers; returns logits.

    :param x: ``(nodes, edges)`` for one molecule, as produced by ``gen_smiles2graph``
    :param num_gnn_layers: number of GNNLayer blocks. The default of 4 matches the
        saved ``opt_params`` layout (``gnn_layer`` ... ``gnn_layer_3``).
    :returns: array of ``numClasses`` logits (no sigmoid applied)
    """
    nodes, edges = x
    features = jnp.ones(graph_feat_length)
    x = nodes, edges, features

    # Haiku names same-class modules in creation order (gnn_layer, gnn_layer_1, ...),
    # so building them in a loop yields the same parameter names as writing them out.
    for _ in range(num_gnn_layers):
        x = GNNLayer(output_size=graph_feat_length)(x)

    # 2 dense layers. NOTE: there is no activation between them, so together they
    # act as one linear map. Kept as is for compatibility with the trained parameters.
    logits = hk.Linear(numClasses)(x[-1])
    logits = hk.Linear(numClasses)(logits)

    return logits  # Model now returns logits


model = hk.without_apply_rng(hk.transform(model_fn))

# Initialize model. Haiku needs one example input to infer every parameter shape,
# so a single graph is pulled from the dataset.
rng = jax.random.PRNGKey(0)
sampleData = data.take(1)
for dataVal in sampleData:  # take(1) yields at most one element
    (nodes_i, edges_i), yi = dataVal
nodes_i, edges_i, yi = nodes_i.numpy(), edges_i.numpy(), yi.numpy()
xi = (nodes_i, edges_i)

# Freshly initialised weights; the trained weights are loaded into opt_params separately.
params = model.init(rng, xi)

In [None]:
# Load optimal parameters for GNN model
print("Edit fileName to change parameters being loaded")
fileName = "optParams_dry-waterfall-17.npy"  # Currently optimal parameters, edit when get better model
# allow_pickle=True unpickles arbitrary objects: only load parameter files you trust.
paramsArr = jnp.load(fileName, allow_pickle=True)
# paramsArr entries 0-19 are laid out as four per GNN layer (b, we, wu, wv), then
# (b, w) for each of the two dense layers. Rebuild the nested dict model.apply expects.
opt_params = {
    ("gnn_layer" if layer == 0 else f"gnn_layer_{layer}"): {
        key: paramsArr[4 * layer + offset]
        for offset, key in enumerate(("b", "we", "wu", "wv"))
    }
    for layer in range(4)
}
opt_params["linear"] = {"b": paramsArr[16], "w": paramsArr[17]}
opt_params["linear_1"] = {"b": paramsArr[18], "w": paramsArr[19]}

### Explanation & plotting related functions

#### Modified functions for selecting counterfactuals in which only the label of the selected scent is flipped

In [None]:
def _select_examples(cond, examples, nmols):
    """Pick up to ``nmols`` examples satisfying ``cond``, favouring cluster diversity.

    First takes the most similar qualifying example from each cluster, trims that
    list to ``nmols`` by similarity, then tops it up with the most similar remaining
    qualifying examples. Examples already chosen are never added twice.

    :param cond: predicate on an example (e.g. "is a counterfactual")
    :param examples: candidates with ``cluster`` and ``similarity`` attributes
    :param nmols: desired number of examples
    :returns: list of examples satisfying ``cond``, without duplicates
    """

    # similarity, zeroed unless the example is in cluster i and satisfies cond
    def cluster_score(e, i):
        return (e.cluster == i) * cond(e) * e.similarity

    def rank(v):
        return v.similarity * cond(v)

    result = []
    for i in set(e.cluster for e in examples):
        close_counter = max(examples, key=lambda e, i=i: cluster_score(e, i))
        # check if actually is (since call could have been zero)
        if cluster_score(close_counter, i):
            result.append(close_counter)

    # trim, in case we had too many clusters
    result = sorted(result, key=rank, reverse=True)[:nmols]

    # fill in remaining, skipping examples already selected (compared by identity)
    fill = max(0, nmols - sum(cond(e) for e in result))
    chosen = {id(e) for e in result}
    for e in sorted(examples, key=rank, reverse=True):
        if fill <= 0:
            break
        if id(e) in chosen:
            continue
        result.append(e)
        chosen.add(id(e))
        fill -= 1

    return [e for e in result if cond(e)]


def cf_explain(examples, nmols, *, min_similarity=0.3, max_flips=1):
    """Find counterfactuals whose predicted labels differ from the base for (almost) only one scent.

    Variant of exmol's ``cf_explain``. After cluster-aware selection of
    counterfactuals (examples whose ``yhat`` differs from the base's), each
    candidate's hard labels for every scent in ``scentClasses`` (via ``my_model``)
    are compared with the base molecule's. A candidate is kept only if at most
    ``max_flips`` labels differ and its similarity exceeds ``min_similarity``.
    Base labels and each candidate's flipped labels are printed.

    :param examples: output of ``exmol.sample_space``; ``examples[0]`` is the base molecule
    :param nmols: desired number of counterfactuals
    :param min_similarity: candidates need similarity strictly greater than this
    :param max_flips: maximum number of scent labels allowed to differ from the base
    :returns: ``[base] + counterfactuals``; each counterfactual's ``label`` is set in place
    """
    base = examples[0]

    def is_counter(e):
        return e.yhat != base.yhat

    scent_classes = np.array(scentClasses)

    def predict_all_labels(smiles):
        # hard 0/1 prediction for every scent class
        return np.array([my_model(smiles, scent) for scent in scent_classes])

    origin_labels = predict_all_labels(base.smiles)
    print(origin_labels)
    print(scent_classes[origin_labels == 1])

    candidates = examples[1:]
    selection = _select_examples(is_counter, candidates, len(candidates))
    selection = sorted(
        selection, key=lambda v: v.similarity * is_counter(v), reverse=True
    )

    result = []
    for e in selection:
        # checked before predicting, so nmols <= 0 selects nothing
        if len(result) >= nmols:
            break
        flips = origin_labels != predict_all_labels(e.smiles)
        print(f"flipped labels: {scent_classes[flips]}")
        if np.sum(flips) <= max_flips and e.similarity > min_similarity:
            result.append(e)

    for i, r in enumerate(result):
        r.label = f"Counterfactual {i+1}"

    return examples[:1] + result

#### Cosine similarity function used to compute similarity for descriptor explanations with multiple base molecules

In [None]:
# compute dot product with labels
def cosine_similarity_base(df, bases, llists):
    """Add a ``label_dot`` column: cosine similarity of each row's label vector to its base's.

    Rows are visited in index order. A row whose index is in ``bases`` becomes the
    current base (``label_dot`` = 1). Every following row is compared with that base.
    All-zero label vectors (on either side) give 0 instead of NaN.

    :param df: DataFrame of samples; modified in place and returned
    :param bases: index labels of the base (origin) rows
    :param llists: 2-D array of label vectors, indexed positionally by df index labels
        (assumes df has a default 0..n-1 index, as after ``read_csv``)
    :returns: ``df`` with the ``label_dot`` column added
    :raises ValueError: if a non-base row appears before any base row
    """
    base_set = set(bases)  # O(1) membership instead of scanning a list per row
    label_dot = np.zeros(len(df))
    base = None
    for pos, j in enumerate(df.index):
        if j in base_set:
            base = j
            label_dot[pos] = 1.0
            continue
        if base is None:
            raise ValueError(f"row {j!r} appears before any base molecule")
        u = np.asarray(llists[base], dtype=float)
        v = np.asarray(llists[j], dtype=float)
        norms = np.linalg.norm(u) * np.linalg.norm(v)
        label_dot[pos] = 0.0 if norms == 0 else (u @ v) / norms
    # Single column assignment; the chained df["label_dot"][j] = ... form silently
    # does nothing under pandas Copy-on-Write.
    df["label_dot"] = label_dot
    return df

#### Functions for plotting the explanation-model fit and for creating a list of exmol Examples from a CSV file

In [None]:
# Code in this cell taken from Solubility-RNN.ipynb notebook from https://github.com/ur-whitelab/exmol/blob/main/paper2_LIME/Solubility-RNN.ipynb
def weighted_mean(x, w):
    """Return the ``w``-weighted mean of ``x``."""
    numerator = np.sum(x * w)
    return numerator / np.sum(w)


def weighted_cov(x, y, w):
    """Return the ``w``-weighted covariance of ``x`` and ``y``, normalised by ``sum(w)``."""
    x_dev = x - weighted_mean(x, w)
    y_dev = y - weighted_mean(y, w)
    return np.sum(w * x_dev * y_dev) / np.sum(w)


def weighted_correlation(x, y, w):
    """Return the ``w``-weighted Pearson correlation coefficient of ``x`` and ``y``."""
    cross = weighted_cov(x, y, w)
    spread = np.sqrt(weighted_cov(x, x, w) * weighted_cov(y, y, w))
    return cross / spread


def plot_correlation(space, descriptor_type):
    """Plot the weighted least-squares fit of the LIME surrogate next to the chemical space.

    Panel "A" (left) scatters the model predictions ``yhat`` against the surrogate
    ``g = x @ beta + mean(yhat)``. Each point's colour and opacity come from its
    similarity weight, and the weighted correlation is annotated.
    Panel "B" (right) shows the space coloured by squared residual via
    ``plot_space_by_fit_multiple_bases``.

    :param space: list of exmol Examples, possibly containing several base molecules
    :param descriptor_type: descriptor type passed to ``exmol.lime_explain``
        (e.g. "ECFP" or "MACCS")
    """
    beta = exmol.lime_explain(space, descriptor_type)
    fkw = {"figsize": (6, 4)}

    fig = plt.figure(figsize=(10, 5))
    matplotlib.rc("axes", titlesize=12)
    matplotlib.rc("font", size=16)
    ax_dict = fig.subplot_mosaic("AABBB")
    ax_fit = ax_dict["A"]

    # Plot space by fit (the helper draws on ax_dict["B"] and returns nothing)
    base_examples = [e for e in space if e.is_origin == True]  # noqa: E712
    plot_space_by_fit_multiple_bases(
        space,
        base_examples,
        figure_kwargs=fkw,
        mol_size=(200, 200),
        offset=1,
        ax=ax_dict["B"],
        beta=beta,
    )

    # Similarity weights. The 1e-6 offset avoids dividing by zero similarity, and
    # points with (near-)zero or negative weight are dropped.
    w = np.array([1 / (1 + (1 / (e.similarity + 0.000001) - 1) ** 5) for e in space])
    non_zero = w > 10 ** (-6)
    w = w[non_zero]
    N = w.shape[0]

    ys = np.array([e.yhat for e in space])[non_zero].reshape(N).astype(float)
    x_mat = np.array([list(e.descriptors.descriptors) for e in space])[
        non_zero
    ].reshape(N, -1)
    # Surrogate prediction, with the mean of yhat added back.
    y_wls = x_mat @ beta
    y_wls += np.mean(ys)

    lower = np.min(ys)
    higher = np.max(ys)

    # Per-point RGBA colours from Oranges, with alpha set to the weight. This is an
    # array of colours, not a colormap, so it is passed only as `c`.
    norm = plt.Normalize(min(w), max(w))
    point_colors = plt.cm.Oranges(w)
    point_colors[:, -1] = w

    corr = weighted_correlation(ys, y_wls, w)

    ax_fit.plot(
        np.linspace(lower, higher, 100),
        np.linspace(lower, higher, 100),
        "--",
        linewidth=2,
    )
    ax_fit.scatter(ys, y_wls, s=50, marker=".", c=point_colors)
    # Anchor the annotation in axes coordinates (bottom-right) so it stays inside
    # the axes. Data-unit offsets such as max(ys) - 10 fall outside the plotted
    # range when yhat spans less than ~10 units.
    ax_fit.text(
        0.97,
        0.03,
        f"weighted \ncorrelation = {corr:.3f}",
        transform=ax_fit.transAxes,
        ha="right",
        va="bottom",
    )
    ax_fit.set_xlabel(r"$\hat{y}$")
    ax_fit.set_ylabel(r"$g$")
    ax_fit.set_title("Weighted Least Squares Fit")
    ax_fit.set_xlim(lower - 0.1, higher + 0.1)
    ax_fit.set_ylim(lower - 0.1, higher + 0.1)
    ax_fit.set_aspect(1.0 / ax_fit.get_data_ratio(), adjustable="box")
    sm = plt.cm.ScalarMappable(cmap=plt.cm.Oranges, norm=norm)
    cbar = plt.colorbar(sm, orientation="horizontal", pad=0.15, ax=ax_fit)
    cbar.set_label("Chemical similarity")
    plt.tight_layout()

In [None]:
# Modified function for multiple bases based on plot_space_by_fit function in exmol plot_utils.py file
def plot_space_by_fit_multiple_bases(
    examples,
    exps,
    beta,
    mol_size=(200, 200),
    mol_fontsize=8,
    offset=0,
    ax=None,
    figure_kwargs=None,
    cartoon=False,
    rasterized=False,
):
    """Scatter the chemical space coloured by the squared residual of the LIME fit.

    Each example is drawn at ``e.position`` and coloured by
    ``(yhat - mean(yhat) - x @ beta) ** 2``. The examples in ``exps`` (e.g. the
    base molecules) are overplotted with black edges on the same colour scale.

    :param examples: exmol Examples with ``position``, ``yhat`` and ``descriptors``
    :param exps: Examples to highlight (their descriptors must already be computed)
    :param beta: surrogate coefficients from ``exmol.lime_explain``
    :param mol_size: unused; kept for signature compatibility with exmol (molecule
        image annotation is disabled)
    :param mol_fontsize: unused; see ``mol_size``
    :param offset: unused; see ``mol_size``
    :param ax: axes to draw on; if None a new figure is created from ``figure_kwargs``
    :param figure_kwargs: kwargs for ``plt.figure`` (default ``{"figsize": (12, 8)}``)
    :param cartoon: draw outlined "cartoon" markers instead of plain points
    :param rasterized: rasterize the cartoon-style layers
    """
    if figure_kwargs is None:
        figure_kwargs = {"figsize": (12, 8)}
    if ax is None:
        ax = plt.figure(**figure_kwargs).gca()

    yhat_mean = np.mean(np.array([e.yhat for e in examples]).astype(float))

    def squared_residuals(exs):
        # residual of the mean-centred prediction vs. the linear surrogate
        centred = np.array([e.yhat for e in exs]).astype(float) - yhat_mean
        x_mat = np.array([list(e.descriptors.descriptors) for e in exs]).reshape(
            len(exs), -1
        )
        return (centred - x_mat @ beta) ** 2

    # use resids as colors
    colors = squared_residuals(examples)
    normalizer = plt.Normalize(min(colors), max(colors))
    cmap = "PuBu_r"

    space_x = [e.position[0] for e in examples]
    space_y = [e.position[1] for e in examples]
    if cartoon:
        # plot shading, lines, front; the coloured front layer feeds the colorbar
        ax.scatter(space_x, space_y, 50, "0.0", lw=2, rasterized=rasterized)
        ax.scatter(space_x, space_y, 50, "1.0", lw=0, rasterized=rasterized)
        im = ax.scatter(
            space_x,
            space_y,
            50,
            c=normalizer(colors),
            cmap=cmap,
            lw=2,
            alpha=0.1,
            rasterized=rasterized,
        )
    else:
        im = ax.scatter(
            space_x,
            space_y,
            40,
            c=normalizer(colors),
            cmap=cmap,
            edgecolors="grey",
            linewidth=0.25,
        )
    ax.set_aspect(1.0 / ax.get_data_ratio(), adjustable="box")
    # Attach the colorbar to `ax` explicitly rather than to whatever axes is current.
    cbar = ax.figure.colorbar(
        im, ax=ax, orientation="horizontal", aspect=35, pad=0.05
    )
    cbar.set_label("squared error")

    # now plot cfs/annotated points, coloured by their own squared residuals so they
    # share the colorbar's scale
    if exps:
        ax.scatter(
            [e.position[0] for e in exps],
            [e.position[1] for e in exps],
            c=normalizer(squared_residuals(exps)),
            cmap=cmap,
            edgecolors="black",
        )

    # Molecule-image annotation (exmol.plot_utils._image_scatter) is disabled.
    ax.axis("off")
    ax.set_aspect("auto")

In [None]:
def createExampleListfromDataFrame(data):
    """Build a list of ``exmol.Example`` objects from a DataFrame (e.g. read from CSV).

    :param data: DataFrame with columns smiles, selfies, similarity, yhat, position,
        is_origin, cluster and label; the DataFrame index is used as each Example's index.
        ``position`` holds the string form of an array, e.g. "[ 1.2 -0.5]".
    :returns: list of ``exmol.Example``, one per row, in row order
    """

    def parse_position(text):
        # "[ 1.2 -0.5]" -> [1.2, -0.5]; split() tolerates any run of whitespace/newlines
        return [float(p) for p in text.replace("[", "").replace("]", "").split()]

    # Convert each column to a list once. Calling .tolist() inside the per-row loop
    # made the original quadratic in the number of rows.
    rows = zip(
        data.smiles.tolist(),
        data.selfies.tolist(),
        data.similarity.tolist(),
        data.yhat.tolist(),
        data.index.tolist(),
        data.position.tolist(),
        data.is_origin.tolist(),
        data.cluster.tolist(),
        data.label.tolist(),
    )
    return [
        exmol.Example(
            smiles,
            selfies,
            similarity,
            yhat,
            index,
            parse_position(position),
            is_origin,
            cluster,
            label,
        )
        for smiles, selfies, similarity, yhat, index, position, is_origin, cluster, label in rows
    ]

#### Create modified alphabet to use when creating explanations

In [None]:
# Create modified alphabet to use: exmol's basic SELFIES alphabet minus the boron
# tokens ([B], [#B], [=B]) and the halogens [I], [F], [Cl], [Br]. It is passed to
# exmol.sample_space through method_kwargs.
alphabet = exmol.get_basic_alphabet()
to_remove = ["[B]", "[#B]", "[=B]", "[I]", "[F]", "[Cl]", "[Br]"]
alphabet -= set(to_remove)

In [None]:
# Check that alphabet does not contain boron, iodine, fluorine, bromine nor chlorine.
# Disabled by wrapping it in a string literal; remove the quotes to run it.
# Boron counts as present if ANY of its tokens is present, hence `or`.
"""
print(alphabet)

boronInAlphabet = ("[#B]" in alphabet) or ("[B]" in alphabet) or ("[=B]" in alphabet)
iodineInAlphabet = ("[I]" in alphabet)
fluorineInAlphabet = ("[F]" in alphabet)
bromineInAlphabet = ("[Br]" in alphabet)
chlorineInAlphabet = ("[Cl]" in alphabet)

#Boron is in exmol basic alphabet
print("Check that using modified basic alphabet (i.e. no Boron, Iodine, Fluorine, Bromine, nor Chlorine)")
print(f"Boron is in basic alphabet? {boronInAlphabet}")
print(f"Iodine is in basic alphabet? {iodineInAlphabet}")
print(f"Fluorine is in basic alphabet? {fluorineInAlphabet}")
print(f"Bromine is in basic alphabet? {bromineInAlphabet}")
print(f"Chlorine is in basic alphabet? {chlorineInAlphabet}")
"""

### Generating counterfactual explanations for ethyl benzoate ("fruity")

In [None]:
# For counterfactuals, need hard yhat values: load the per-scent decision thresholds
# (columns Scent and Threshold) that my_model uses to binarize probabilities.
thresholds = pd.read_csv(
    filepath_or_buffer="ThresholdsForMaxF1_OdorlessClassRemoved_dry-waterfall-17_trainAndValid.csv"
)

In [None]:
# Model function that takes in SMILES string and scent string as input (rather than molecular graph + parameter)
# Output is prediction on whether molecule has a certain scent (1) or not (0)
def my_model(smilesString, scentString):
    """Return 1 if the GNN predicts ``scentString`` for the molecule, else 0.

    The sigmoid probability for the scent is compared with that scent's threshold
    from ``thresholds`` (the one that maximizes the F1 score).
    """
    graph = gen_smiles2graph(smilesString)
    class_pos = scentClasses.index(scentString)
    scent_rows = thresholds.index[thresholds.Scent == scentString].tolist()
    cutoff = thresholds.Threshold[scent_rows].tolist()[0]
    probability = jax.nn.sigmoid(model.apply(opt_params, graph))[class_pos]
    return 1 if probability > cutoff else 0

In [None]:
# Molecules being examined/generating counterfactuals for:
# (SMILES strings. In this notebook fruity_ethylbenzoate is only referenced by the
# disabled sampling cell below, and fatty_24decadienal is not referenced at all.)
fruity_ethylbenzoate = "CCOC(=O)C1=CC=CC=C1"  # Ethyl benzoate (https://pubchem.ncbi.nlm.nih.gov/compound/Ethyl-benzoate#section=InChI)

fatty_24decadienal = "CCCCCC=CC=CC=O"  # 2,4 Decadienal (https://pubchem.ncbi.nlm.nih.gov/compound/2_4-Decadienal#section=InChIKey)

In [None]:
# Uncomment code below to generate ethyl benzoate counterfactuals around fruity
# The code is disabled by wrapping it in a string literal. It samples 20000 molecules
# around fruity_ethylbenzoate, scored by the hard 0/1 classifier my_model for
# "fruity" and using the restricted `alphabet`. It then keeps one counterfactual via
# this notebook's cf_explain. The next cell loads a saved result from
# cfs_ethylbenzoate_fruity.csv instead.
"""
samples_ethylbenzoate_fruity = exmol.sample_space(
    fruity_ethylbenzoate,
    lambda smi, sel: my_model(smi, "fruity"),
    batched=False,
    preset="medium",
    num_samples=20000,
    method_kwargs= {'alphabet': alphabet}
)

cfs_ethylbenzoate_fruity = cf_explain(samples_ethylbenzoate_fruity, nmols=1)
exmol.plot_cf(cfs_ethylbenzoate_fruity[:2], figure_kwargs = {'figsize': (10,8)}, mol_size=(300, 300), nrows = 1)
plt.show()
"""

In [None]:
# Read in ethyl benzoate cfs file & create image
# NOTE(review): the CSV is presumably the saved output of cf_explain (base molecule
# first, then counterfactuals); confirm against how it was written.
cfs_ethylbenzoate_fruity = createExampleListfromDataFrame(
    pd.read_csv("cfs_ethylbenzoate_fruity.csv")
)

# Uncomment lines below to save results to csv
# (samples_ethylbenzoate_fruity only exists if the disabled sampling cell above was run)
# cfs_ethylbenzoate_fruity_data = pd.DataFrame(cfs_ethylbenzoate_fruity)
# cfs_ethylbenzoate_fruity_data.to_csv("cfs_ethylbenzoate_fruity.csv", index=False)
# samples_ethylbenzoate_fruity_data = pd.DataFrame(samples_ethylbenzoate_fruity)
# samples_ethylbenzoate_fruity_data.to_csv("samples_ethylbenzoate_fruity.csv", index=False)

# Plot the first two entries (base molecule and first counterfactual) side by side.
exmol.plot_cf(
    cfs_ethylbenzoate_fruity[:2],
    figure_kwargs={"figsize": (10, 8)},
    mol_size=(300, 300),
    nrows=1,
)
plt.show()
# plt.savefig("cfs_ethylbenzoate_fruity.png")

### Generate descriptor explanations & natural language explanations for "fatty"

In [None]:
# Function takes in a SMILES string for a molecule and returns the logit predictions for all scent classes
def my_model_allScents_logits(smilesString):
    """Return the raw GNN logits for every scent class (one entry per ``scentClasses`` item)."""
    return model.apply(opt_params, gen_smiles2graph(smilesString))


# Function takes in a SMILES string for a molecule and returns the logit predictions for the specified scent
def my_model_logits(smilesString, scentString):
    """Return the raw GNN logit for ``scentString`` (an entry of ``scentClasses``)."""
    graph = gen_smiles2graph(smilesString)
    class_pos = scentClasses.index(scentString)
    return model.apply(opt_params, graph)[class_pos]

#### Create sample space

In [None]:
# Uncomment code in cell below to generate sample space
# The code is disabled by wrapping it in a string literal. For every molecule
# labelled with `scent`, it samples 200 molecules around it (scored by the logit
# for `scent`). It then adds the logits for all scent classes as extra columns and
# writes space_<scent>.csv; the DataFrame index is saved as the first column.
# NOTE(review): the generated column names contain a literal " + logits" suffix
# (e.g. "fatty + logits"). The reader cell below selects columns by position, so
# the names do not matter there.
"""
#For each positive example, generate sample space (using STONED) around that example (use that example as the base molecule)
space_total = []
scent = 'fatty'
for i in range(numMolecules): #sample space created using positive examples
    molecule = scentdata.smiles[i]
    if(moleculeScentList[i].count(scent) == 1): 
        sampleSpace = exmol.sample_space(molecule, lambda smi, sel: my_model_logits(smi,scent), batched=False, preset='medium', num_samples=200, method_kwargs= {'alphabet': alphabet})
        space_total.extend(sampleSpace)

space_total = pd.DataFrame(space_total)

#Get predictions for all scent classes for all molecules in the space
predictions_logits = []
mols = space_total['smiles']
for mol in mols:
    preds_logits = my_model_allScents_logits(mol)
    predictions_logits.append(preds_logits)

#Save the space along with scent class predictions to a csv file
scentClassesNamesWithLogits = []
for c in scentClasses:
    scentLogitString = f'{c} + logits'
    scentClassesNamesWithLogits.append(scentLogitString)

space_total[scentClassesNamesWithLogits] = predictions_logits
scentFileName = f'space_{scent}.csv'
space_total.to_csv(scentFileName)
space_total.head()
"""

#### Read in sample space from csv file & create list of exmol Examples from it

In [None]:
# Download file for "fatty" space from figshare
# With no destination given, urlretrieve saves to a temporary file and returns its
# path (`filename`) together with the response headers (not used afterwards).
filename, headers = urllib.request.urlretrieve(
    "https://figshare.com/ndownloader/files/38472275"
)

In [None]:
scent = "fatty"
# NOTE: below reading in a subset of the "fatty" sample space from figshare, code not working right now
# NOTE(review): the status above is not verified here; confirm before relying on results.
# Parse the CSV once. Columns 1-10 hold the exmol Example fields and columns 11-122
# hold the 112 per-scent logits (column 0 is presumably the index written by to_csv).
raw = pd.read_csv(filename, usecols=np.arange(1, 123))
df = raw.iloc[:, :10].copy()  # copy so adding columns below is not a chained assignment
labels = raw.iloc[:, np.append([0], np.arange(10, 122))]
del raw
llists = labels.to_numpy()[:, 1:]
bases = list(df[df["is_origin"] == True].index)

# Scale each sample's similarity by how closely its scent-logit vector matches its base's.
df = cosine_similarity_base(df, bases, llists)
df["similarity"] = df["similarity"] * df["label_dot"]

# Use smaller set of data (just to check code does not crash) - comment out line below when generating results
df = df.sample(frac=0.2, random_state=0).reset_index(drop=True)

samples = createExampleListfromDataFrame(df)

In [None]:
# delete pandas dataframe to save space: `samples` holds plain Python values (built
# via .tolist()), so the parsed columns are no longer needed. Rebinding `df` to a
# fresh, column-less frame with the same index releases all column data in one step
# (deleting columns one by one can make pandas re-copy the remaining data) and
# leaves `df` in the same empty-columns state as before.
df = pd.DataFrame(index=df.index)

#### Create descriptor explanation plots for "fatty" & plot model fit for each explanatory model

In [None]:
beta_ECFP = exmol.lime_explain(
    samples, descriptor_type="ECFP"
)  # create ECFP descriptor explanation

# Render the descriptor-importance plot to SVG and save it as <scent>_ecfp.png.
# NOTE(review): plot_descriptors receives only `samples`; it presumably uses the ECFP
# attributions that lime_explain just stored on them, so keep this after lime_explain.
svg_ECFP = exmol.plot_descriptors(
    samples, figure_kwargs={"figsize": (9, 5)}, return_svg=True
)
cairosvg.svg2png(svg_ECFP, write_to=f"{scent}_ecfp.png", background_color="white")

# NOTE(review): base_examples is not used in this cell (plot_correlation recomputes it).
base_examples = [e for e in samples if e.is_origin == True]

# plot_correlation re-runs lime_explain for "ECFP" before plotting the fit.
plot_correlation(samples, "ECFP")  # plot ECFP model fit
# plt.savefig(f'{scent}_ECFPfitAndCorrelation.png', bbox_inches="tight", transparent=False)

In [None]:
beta_MACCS = exmol.lime_explain(
    samples, descriptor_type="MACCS"
)  # create MACCS descriptor explanation

# Render the descriptor-importance plot to SVG and save it as <scent>_maccs.png.
# NOTE(review): plot_descriptors receives only `samples`; it presumably uses the MACCS
# attributions that lime_explain just stored on them, so keep this after lime_explain.
svg_MACCS = exmol.plot_descriptors(
    samples, figure_kwargs={"figsize": (9, 5)}, return_svg=True
)
cairosvg.svg2png(svg_MACCS, write_to=f"{scent}_maccs.png", background_color="white")

# NOTE(review): base_examples is not used in this cell (plot_correlation recomputes it).
base_examples = [e for e in samples if e.is_origin == True]

# plot_correlation re-runs lime_explain for "MACCS" before plotting the fit.
plot_correlation(samples, "MACCS")  # plot MACCS model fit
# plt.savefig(f'{scent}_MACCSfitAndCorrelation.png', bbox_inches="tight", transparent=False)

#### Create natural language explanations for "fatty" scent

In [None]:
# Natural-language explanation. lime_explain is run for each descriptor type
# immediately before text_explain for that type (text_explain presumably reads the
# attributions lime_explain stored on `samples`). The two lists are merged, and an
# explanation is generated for the scent.
exmol.lime_explain(samples, "ecfp")
s1_ecfp = exmol.text_explain(samples, "ecfp", presence_thresh=0.1)
exmol.lime_explain(samples, "maccs")
s2_maccs = exmol.text_explain(samples, "maccs", presence_thresh=0.1)

# print(f's1_ecfp: {s1_ecfp}\n')
# print(f's2_maccs: {s2_maccs}\n')

# Keep at most the first five merged explanations.
# NOTE(review): filter=1.96 looks like a significance cutoff; confirm in exmol docs.
explanation_list = exmol.merge_text_explains(s1_ecfp, s2_maccs, filter=1.96)[:5]

# explanation_list = exmol.merge_text_explains(ecfp_standardized,maccs_standardized, filter=1.96)[:5]
nle = exmol.text_explain_generate(explanation_list, f"{scent} scent")
print(nle)

#### Additional code to examine positive examples for the "fatty" scent

In [None]:
# Draw every base (origin) molecule of the sampled space, six per row. As the last
# expression of the cell, the grid image is displayed by Jupyter.
base_examples = [e for e in samples if e.is_origin]
rdkit.Chem.Draw.MolsToGridImage(
    [rdkit.Chem.MolFromSmiles(s.smiles) for s in base_examples], molsPerRow=6
)

In [None]:
maccs_string = "Is there a C=O double bond?"
# Count base molecules whose MACCS descriptor named `maccs_string` equals 1.
count = sum(
    1
    for base_mol in base_examples
    if base_mol.descriptors.descriptors[
        base_mol.descriptors.descriptor_names.index(maccs_string)
    ]
    == 1
)
print(
    f'For the "fatty" base molecules, {count}/{len(base_examples)} match the descriptor: {maccs_string}'
)