In [None]:
# Counterfactual generation using exmol package

In [None]:
# Restrict the process to a single GPU. On GPU nodes, edit the value below so it
# matches the device that was allocated to you.
# NOTE(review): presumably this must run before TensorFlow/JAX are imported to take effect — confirm.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
# Imports
#!pip install matplotlib numpy pandas seaborn jax jaxlib dm-haiku tensorflow exmol
import exmol
import tensorflow as tf
import seaborn as sns
import jax.numpy as jnp
import jax
import jax.experimental.optimizers as opt
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import haiku as hk
import numpy as np
import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw
import sklearn.metrics

import warnings

# Suppress every Python warning for the rest of the session (this also hides
# deprecation warnings from the libraries imported above).
warnings.filterwarnings("ignore")

# Global seaborn/matplotlib appearance used by every figure in the notebook
sns.set_context("notebook")
sns.set_style(
    "dark",
    {
        "xtick.bottom": True,
        "ytick.left": True,
        "xtick.color": "#666666",
        "ytick.color": "#666666",
        "axes.edgecolor": "#666666",
        "axes.linewidth": 0.8,
        "figure.dpi": 300,
    },
)
# Custom colour cycle applied to all subsequent matplotlib axes
color_cycle = ["#1BBC9B", "#F06060", "#5C4B51", "#F3B562", "#6e5687"]
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=color_cycle)

# Seed the NumPy and TensorFlow global random generators
np.random.seed(0)
tf.random.set_seed(0)

### Counterfactual Generation with GNN Model

**GNN Model Related Code** 

Code modified from the example code given in the "Predicting DFT Energies with GNNs" and "Interpretability and Deep Learning" sections of the "Deep Learning for Molecules and Materials" textbook (https://whitead.github.io/dmol-book/applied/QM9.html).

In [None]:
# --- GNN hyperparameters ---
node_feat_length = 256  # width of each per-atom feature vector
message_feat_length = 256  # width of each message vector inside a GNN layer
graph_feat_length = 512  # width of the graph-level feature vector
weights_stddevGNN = 0.01  # std-dev of the normal initializer for GNN parameters

# --- Dataset (CSV stored locally, next to the notebook) ---
scentdata = pd.read_csv("leffingwell_data_shuffled.csv")

numMolecules = len(scentdata.odor_labels_filtered)
numClasses = 113

# scentClasses: every distinct scent label, in order of first appearance.
# moleculeScentList: for each molecule, its list of scent labels.
scentClasses = []
moleculeScentList = []
for i in range(numMolecules):
    scentString = scentdata.odor_labels_filtered[i]
    # Drop list punctuation and spaces in one pass, e.g. "['a', 'b']" -> "a,b"
    scentList = scentString.translate(str.maketrans("", "", "[]' ")).split(",")
    moleculeScentList.append(scentList)
    for scent in scentList:
        if scent not in scentClasses:
            scentClasses.append(scent)

# Sanity checks on the parsed data
print(f"Is the number of scent classes 113?: {len(scentClasses)==113}")
print(f"Is the number of molecules 3523?: {len(moleculeScentList)==3523}")


def gen_smiles2graph(sml):
    """Convert a SMILES string into node features and an adjacency matrix.

    Hydrogens are added explicitly before the graph is built.

    :param sml: SMILES string of the molecule
    :returns: tuple ``(nodes, adj)``. ``nodes`` has shape ``(N, node_feat_length)``
        with a 1 in the column given by each atom's atomic number and a 1 in the
        last column for ring atoms. ``adj`` has shape ``(N, N)``, is symmetric,
        holds 1 for every bond and has 1 on the diagonal (self-loops).
    :raises ValueError: if RDKit cannot parse ``sml``
    :raises Warning: if a bond is not single, double, triple or aromatic
    """
    m = rdkit.Chem.MolFromSmiles(sml)
    if m is None:
        # MolFromSmiles signals a parse failure by returning None
        raise ValueError(f"Could not parse SMILES string: {sml!r}")
    m = rdkit.Chem.AddHs(m)

    # Bond types the graph accepts. The bond order itself is not encoded:
    # every accepted bond becomes a plain 1 in the adjacency matrix.
    supported_bond_types = {
        rdkit.Chem.rdchem.BondType.SINGLE,
        rdkit.Chem.rdchem.BondType.DOUBLE,
        rdkit.Chem.rdchem.BondType.TRIPLE,
        rdkit.Chem.rdchem.BondType.AROMATIC,
    }

    N = m.GetNumAtoms()
    nodes = np.zeros((N, node_feat_length))
    for atom in m.GetAtoms():
        nodes[atom.GetIdx(), atom.GetAtomicNum()] = 1
        # Last feature column flags ring membership
        if atom.IsInRing():
            nodes[atom.GetIdx(), -1] = 1

    adj = np.zeros((N, N))
    for bond in m.GetBonds():
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        bond_type = bond.GetBondType()
        if bond_type not in supported_bond_types:
            # Format the bond type explicitly: concatenating it to a str raised TypeError
            raise Warning(f"Ignoring bond order {bond_type}")
        adj[u, v] = 1
        adj[v, u] = 1
    # Self-loops so each atom also receives its own message
    adj += np.eye(N)
    return nodes, adj


def createLabelVector(scentsList):
    """Build the multi-hot label vector for one molecule.

    Entry ``k`` of the result is 1 when the molecule has the scent
    ``scentClasses[k]`` and 0 otherwise.

    :param scentsList: list of scent names describing the molecule
    :returns: NumPy array of length ``numClasses`` containing zeros and ones
    :raises ValueError: if a scent name is not present in ``scentClasses``
    """
    labelVector = np.zeros(numClasses)
    for scent in scentsList:
        # Position of this scent in the global class list
        labelVector[scentClasses.index(scent)] = 1
    return labelVector


def generateGraphs():
    """Yield ``((nodes, adj), labels)`` for every molecule in ``scentdata``.

    The graph comes from :func:`gen_smiles2graph` applied to the molecule's
    SMILES string; the labels come from :func:`createLabelVector`.
    """
    for molIdx in range(numMolecules):
        yield (
            gen_smiles2graph(scentdata.smiles[molIdx]),
            createLabelVector(moleculeScentList[molIdx]),
        )


# Wrap generateGraphs() in a tf.data.Dataset. Each element is
# ((nodes, adj), labels), all float32:
#   nodes  -> (None, node_feat_length)  one row per atom
#   adj    -> (None, None)              adjacency matrix
#   labels -> (None,)                   label vector from createLabelVector
# The leading dimensions are left as None because the atom count differs per molecule.
data = tf.data.Dataset.from_generator(
    generateGraphs,
    output_types=((tf.float32, tf.float32), tf.float32),
    output_shapes=(
        (tf.TensorShape([None, node_feat_length]), tf.TensorShape([None, None])),
        tf.TensorShape([None]),
    ),
)

In [None]:
class GNNLayer(
    hk.Module
):  # TODO: If increase number of layers, stack features & new_features and shrink via dense layer
    """One message-passing layer operating on ``(nodes, edges, features)``.

    The layer updates the node features and the graph-level features; the
    adjacency matrix is passed through unchanged. Both updates are residual
    (the input is added back after the activation).
    """

    def __init__(self, output_size, name=None):
        """Store ``output_size``; note that ``__call__`` does not read it."""
        super().__init__(name=name)
        self.output_size = output_size

    def __call__(self, inputs):
        """Apply the layer.

        :param inputs: tuple ``(nodes, edges, features)``; assumed shapes are
            ``(N, Nf)``, ``(N, N)`` and ``(Gf,)`` with N atoms
        :returns: tuple ``(new_nodes, edges, new_features)`` with the same shapes
        """
        # split input into nodes, edges & features
        nodes, edges, features = inputs
        # Nodes is of shape (N, Nf) --> N = # atoms, Nf = node_feature_length
        # Edges is of shape (N,N) (adjacency matrix)
        # Features is of shape (Gf) --> Gf = graph_feature_length

        graph_feature_len = features.shape[-1]  # graph_feature_len (Gf)
        node_feature_len = nodes.shape[-1]  # node_feature_len (Nf)
        message_feature_len = message_feat_length  # message_feature_length (Mf)

        # Initialize weights (the bias b below uses the same random-normal initializer)
        w_init = hk.initializers.RandomNormal(stddev=weights_stddevGNN)

        # we is of shape (Nf,Mf)
        we = hk.get_parameter(
            "we", shape=[node_feature_len, message_feature_len], init=w_init
        )

        # b is of shape (Mf)
        b = hk.get_parameter("b", shape=[message_feature_len], init=w_init)

        # wv is of shape (Mf,Nf)
        wv = hk.get_parameter(
            "wv", shape=[message_feature_len, node_feature_len], init=w_init
        )

        # wu is of shape (Nf,Gf)
        wu = hk.get_parameter(
            "wu", shape=[node_feature_len, graph_feature_len], init=w_init
        )

        # Tile nodes to N x N x Nf so that entry [i, j] holds the features of atom j.
        # ek[i, j] = leaky_relu(b + edges[i, j] * (nodes[j] @ we)), shape N x N x Mf.
        # The bias is added after the adjacency mask, so non-bonded pairs contribute
        # leaky_relu(b) rather than zero.
        ek = jax.nn.leaky_relu(
            b
            + jnp.repeat(nodes[jnp.newaxis, ...], nodes.shape[0], axis=0)
            @ we
            * edges[..., None]
        )

        # Uncomment lines below to update edges
        # Update edges, use jnp.any to have new_edges be of shape N x N
        # new_edges = jnp.any(ek, axis=-1)

        # Normalize over edge features w/layer normalization
        # new_edges = hk.LayerNorm(axis=[0,1], create_scale=False, create_offset=False, eps=1e-05)(new_edges)

        # take sum over neighbors (axis 1) to get ebar shape = N x Mf
        ebar = jnp.sum(ek, axis=1)

        # dense layer plus residual connection to get new_nodes shape = N x Nf
        new_nodes = jax.nn.leaky_relu(ebar @ wv) + nodes  # Use leaky ReLU

        # Layer normalization jointly over the atom and feature axes, without
        # learned scale or offset
        new_nodes = hk.LayerNorm(
            axis=[0, 1], create_scale=False, create_offset=False, eps=1e-05
        )(new_nodes)

        # sum over nodes so global_node_features shape = Nf
        global_node_features = jnp.sum(new_nodes, axis=0)

        # dense layer plus residual connection so new_features shape = Gf
        new_features = (
            jax.nn.leaky_relu(global_node_features @ wu) + features
        )  # Use leaky ReLU for activation

        return new_nodes, edges, new_features


def model_fn(x):
    """Forward pass of the scent classifier.

    Four stacked :class:`GNNLayer` modules are followed by two linear layers
    with no activation between them.

    :param x: tuple ``(nodes, edges)`` as produced by ``gen_smiles2graph``
    :returns: logits with one entry per scent class (no sigmoid applied)
    """
    node_feats, adjacency = x
    # Graph-level features start as a vector of ones
    state = (node_feats, adjacency, jnp.ones(graph_feat_length))

    # 4 GNN layers; the saved parameters expect exactly this many, created in this order
    for _ in range(4):
        state = GNNLayer(output_size=graph_feat_length)(state)

    # 2 dense layers applied to the graph-level features (last element of the state)
    hidden = hk.Linear(numClasses)(state[-1])
    return hk.Linear(numClasses)(hidden)


# Turn model_fn into an (init, apply) pair; apply is wrapped so it takes no RNG key
model = hk.without_apply_rng(hk.transform(model_fn))

# Initialize model
rng = jax.random.PRNGKey(0)
sampleData = data.take(1)
# The loop runs once: it unpacks the single element of the one-item dataset
for dataVal in sampleData:  # Look into later how to get larger set
    (nodes_i, edges_i), yi = dataVal
# Convert the TensorFlow tensors to NumPy arrays before handing them to JAX
nodes_i = nodes_i.numpy()
edges_i = edges_i.numpy()

yi = yi.numpy()
xi = (nodes_i, edges_i)

# Freshly initialised parameters for the example input.
# NOTE(review): the cells visible here predict with opt_params, not params — confirm params is still needed.
params = model.init(rng, xi)

In [None]:
# Load optimal parameters for GNN model (stored locally).
# NOTE: allow_pickle=True makes the loader unpickle arbitrary objects, so only
# point fileName at files from a trusted source.
print("Edit fileName to change parameters being loaded")
fileName = "optParams_100Epochs_astral-pond-171.npy"  # Currently optimal parameters, edit when get better model
paramsArr = jnp.load(fileName, allow_pickle=True)

# paramsArr is read as a flat sequence: four entries (b, we, wu, wv) for each of
# the four GNN layers, followed by (b, w) for each of the two linear layers.
opt_params = {
    layerName: {
        paramName: paramsArr[4 * layerNum + offset]
        for offset, paramName in enumerate(("b", "we", "wu", "wv"))
    }
    for layerNum, layerName in enumerate(
        ("gnn_layer", "gnn_layer_1", "gnn_layer_2", "gnn_layer_3")
    )
}
opt_params["linear"] = {"b": paramsArr[16], "w": paramsArr[17]}
opt_params["linear_1"] = {"b": paramsArr[18], "w": paramsArr[19]}

In [None]:
# Read in the per-scent decision thresholds (the values that maximize the F1 score
# on the test set). my_model looks rows up by the `Scent` column and reads the
# cutoff from the `Threshold` column.
thresholds = pd.read_csv("ThresholdsForMaxF1.csv")

**Counterfactual Generation Code using exmol package**

The code below is adapted from the examples in the exmol documentation (https://ur-whitelab.github.io/exmol/).

In [None]:
def my_model(smilesString, scentString):
    """Predict whether a molecule has a given scent.

    The SMILES string is converted to a molecular graph, passed through the GNN
    with the loaded ``opt_params``, and the sigmoid of the logit for
    ``scentString`` is compared against that scent's threshold (the one that
    maximizes the F1 score).

    :param smilesString: SMILES string of the molecule
    :param scentString: scent name; must be present in ``scentClasses`` and ``thresholds``
    :returns: 1 if the predicted probability exceeds the threshold, otherwise 0
    """
    graph = gen_smiles2graph(smilesString)
    classIdx = scentClasses.index(scentString)

    # Rows of the threshold table that belong to this scent; the first one is used
    matchingRows = thresholds.index[thresholds.Scent == scentString].tolist()
    cutoff = thresholds.Threshold[matchingRows].tolist()[0]

    probability = jax.nn.sigmoid(model.apply(opt_params, graph))[classIdx]
    return 1 if probability > cutoff else 0

In [None]:
def createExampleListfromDataFrame(data):
    """Convert a DataFrame of saved examples back into a list of ``exmol.Example``.

    Use this after reading counterfactuals or a sample space from CSV, before
    passing them to exmol plotting or explanation functions
    (Example definition: https://github.com/ur-whitelab/exmol/blob/main/exmol/data.py).

    :param data: DataFrame with columns ``smiles``, ``selfies``, ``similarity``,
        ``yhat``, ``position``, ``is_origin``, ``cluster`` and ``label``
    :returns: list with one ``exmol.Example`` per row, in row order
    """
    # Materialise each column once; calling .tolist() per row made this quadratic.
    # NOTE: ``data.index`` is the DataFrame's row index, not a CSV column named
    # "index" — the original behaviour is kept as is.
    columns = (
        data["smiles"].tolist(),
        data["selfies"].tolist(),
        data["similarity"].tolist(),
        data["yhat"].tolist(),
        data.index.tolist(),
        data["position"].tolist(),
        data["is_origin"].tolist(),
        data["cluster"].tolist(),
        data["label"].tolist(),
    )
    return [exmol.Example(*row) for row in zip(*columns)]


# Uncomment lines below for plotting counterfactuals from csv files saved
# cfs1 = createExampleListfromDataFrame(pd.read_csv('cfs1_green.csv'))
# exmol.plot_cf(cfs1, nrows=1)
# plt.savefig('cfs1_green.png') #Save image of counterfactuals

In [None]:
# SMILES strings for molecules that have "interesting" scents/structure-scent relations.
# Each assignment carries the source URL the string was taken from.
vanillin = "COc1cc(C=O)ccc1O"  # Vanillin smiles string (https://www.sigmaaldrich.com/US/en/product/SIGMA/V2375)
isovanillin = "COc1ccc(C=O)cc1O"  # Isovanillin smiles string (https://www.sigmaaldrich.com/US/en/product/aldrich/59940)

muscone = "CC1CCCCCCCCCCCCC(=O)C1"  # Muscone smiles string (https://pubchem.ncbi.nlm.nih.gov/compound/3-Methylcyclopentadecanone#section=InChI-Key)
muskKetone = "CC1=C(C(=C(C(=C1[N+](=O)[O-])C(C)(C)C)[N+](=O)[O-])C)C(=O)C"  # Musk ketone smiles string (https://pubchem.ncbi.nlm.nih.gov/compound/Musk-ketone#section=InChI)

methylAnthranilate = "COC(=O)C1=CC=CC=C1N"  # Methyl anthranilate smiles string (grape scent) (https://pubchem.ncbi.nlm.nih.gov/compound/Methyl-anthranilate#section=InChI)

In [None]:
# Display all scent class names; these are the valid `scentString` values for my_model
scentClasses

In [None]:
def _select_examples(cond, examples, nmols):
    result = []

    # similarity filtered by if cluster/counter
    def cluster_score(e, i):
        return (e.cluster == i) * cond(e) * e.similarity

    clusters = set([e.cluster for e in examples])
    for i in clusters:
        close_counter = max(examples, key=lambda e, i=i: cluster_score(e, i))
        # check if actually is (since call could have been zero)
        if cluster_score(close_counter, i):
            result.append(close_counter)

    # trim, in case we had too many cluster
    result = sorted(result, key=lambda v: v.similarity * cond(v), reverse=True)[:nmols]

    # fill in remaining
    ncount = sum([cond(e) for e in result])
    fill = max(0, nmols - ncount)
    result.extend(
        sorted(examples, key=lambda v: v.similarity * cond(v), reverse=True)[:fill]
    )

    return list(filter(cond, result))


def cf_explain(examples, nmols, max_flips=1, min_similarity=0.3):
    """From given :obj:`Examples<Example>`, find closest counterfactuals (see :doc:`index`)

    A candidate is kept only if it flips the target prediction while changing at
    most ``max_flips`` of all scent labels relative to the base molecule, and its
    similarity to the base molecule exceeds ``min_similarity``.

    :param examples: Output from :func:`sample_space`; the first entry is the base molecule
    :param nmols: Desired number of molecules
    :param max_flips: Maximum number of scent labels allowed to differ from the base molecule
    :param min_similarity: Similarity a counterfactual must exceed to be kept
    :returns: list with the base example followed by the accepted counterfactuals,
        whose ``label`` is set in place to ``"Counterfactual <k>"``
    """
    origin = examples[0]

    def is_counter(e):
        return e.yhat != origin.yhat

    scent_classes = np.array(scentClasses)
    origin_labels = np.array(
        [my_model(origin.smiles, scent) for scent in scent_classes]
    )
    print(origin_labels)
    print(scent_classes[origin_labels == 1])

    candidates = examples[1:]
    selection = _select_examples(is_counter, candidates, len(candidates))
    selection = sorted(
        selection, key=lambda v: v.similarity * is_counter(v), reverse=True
    )

    result = []
    seen_smiles = set()
    for e in selection:
        if len(result) >= nmols:
            break
        # selection is sorted by decreasing similarity, so nothing after this
        # point can pass the similarity cutoff; stop before scoring it
        if e.similarity <= min_similarity:
            break
        # _select_examples can return the same molecule more than once; score
        # and report each molecule only once
        if e.smiles in seen_smiles:
            continue
        seen_smiles.add(e.smiles)

        labels = np.array([my_model(e.smiles, scent) for scent in scent_classes])
        flips = origin_labels != labels
        print(f"flipped labels: {scent_classes[flips]}")
        if np.sum(flips) <= max_flips and e.similarity > min_similarity:
            result.append(e)

    for i, r in enumerate(result):
        r.label = f"Counterfactual {i+1}"

    return examples[:1] + result

#### Vanillin counterfactuals ('vanilla scent')

In [None]:
# Load a previously saved vanillin sample space (the file is written by the save
# cell further down) and display it.
samples_vanillin = pd.read_csv("samples_vanillin.csv")
samples_vanillin

In [None]:
def createExampleListfromDataFrame(data):
    """Convert a DataFrame of saved examples back into a list of ``exmol.Example``.

    :param data: DataFrame with columns ``smiles``, ``selfies``, ``similarity``,
        ``yhat``, ``position``, ``is_origin``, ``cluster`` and ``label``
    :returns: list with one ``exmol.Example`` per row, in row order
    """
    # Materialise each column once; calling .tolist() per row made this quadratic.
    # NOTE: ``data.index`` is the DataFrame's row index, not a CSV column named
    # "index" — the original behaviour is kept as is.
    columns = (
        data["smiles"].tolist(),
        data["selfies"].tolist(),
        data["similarity"].tolist(),
        data["yhat"].tolist(),
        data.index.tolist(),
        data["position"].tolist(),
        data["is_origin"].tolist(),
        data["cluster"].tolist(),
        data["label"].tolist(),
    )
    return [exmol.Example(*row) for row in zip(*columns)]

In [None]:
# Convert the loaded DataFrame into a list of exmol.Example; the name
# samples_vanillin is rebound from a DataFrame to that list.
samples_vanillin = createExampleListfromDataFrame(samples_vanillin)

In [None]:
# Generate the vanillin sample space, labelling every sampled molecule with the
# model's "vanilla" prediction (the second lambda argument is ignored).
# This replaces any samples_vanillin loaded from CSV in the cells above.
# NOTE(review): presumably only one of "load from CSV" / "regenerate" is meant to be run — confirm.
samples_vanillin = exmol.sample_space(
    vanillin,
    lambda smi, sel: my_model(smi, "vanilla"),
    batched=False,
    preset="medium",
    num_samples=50000,
)

In [None]:
# Select up to 10 vanillin counterfactuals; the first entry of the result is vanillin itself
cfs_vanillin = cf_explain(samples_vanillin, nmols=10)

In [None]:
# Plot sample space with vanillin as base molecule
# exmol.plot_space(samples_vanillin, cfs_vanillin)
# plt.savefig('sampleSpace_vanillin.png')

# Save the counterfactuals and the full sample space to CSV (these lines are
# active and overwrite existing files of the same name)
cfs_vanillin_data = pd.DataFrame(cfs_vanillin)
cfs_vanillin_data.to_csv("cfs_vanillin.csv", index=False)
samples_vanillin_data = pd.DataFrame(samples_vanillin)
samples_vanillin_data.to_csv("samples_vanillin.csv", index=False)

In [None]:
# Save only the vanillin counterfactuals to CSV (repeats the first two statements
# of the previous save cell)
cfs_vanillin_data = pd.DataFrame(cfs_vanillin)
cfs_vanillin_data.to_csv("cfs_vanillin.csv", index=False)

In [None]:
# Plot the first six entries (base molecule plus counterfactuals) on two rows
# and save the figure
exmol.plot_cf(cfs_vanillin[:6], nrows=2)  # , nrows=3)
plt.savefig("cfs_vanillin_truncated.png")

#### Methyl anthranilate ('grape' scent) 

In [None]:
# Methyl anthranilate
# Generate the methyl anthranilate sample space, labelling every sampled molecule
# with the model's "grape" prediction (the second lambda argument is ignored)
samples_methylAnthranilate = exmol.sample_space(
    methylAnthranilate,
    lambda smi, sel: my_model(smi, "grape"),
    batched=False,
    preset="medium",
    num_samples=50000,
)

In [None]:
# Load a previously saved sample space from CSV and convert it to a list of
# exmol.Example. This replaces the samples generated in the cell above.
# NOTE(review): the CSV is only written by a later cell, so on a fresh run this
# presumably fails unless the file already exists — confirm the intended order.
samples_methylAnthranilate = pd.read_csv("samples_methylAnthranilate.csv")
samples_methylAnthranilate = createExampleListfromDataFrame(samples_methylAnthranilate)

In [None]:
# Select up to 10 methyl anthranilate counterfactuals, plot them on three rows
# and save the figure
cfs_methylAnthranilate = cf_explain(samples_methylAnthranilate, nmols=10)
print(f"Methyl anthranilate Counterfactuals: ")
exmol.plot_cf(cfs_methylAnthranilate, nrows=3)
plt.savefig("cfs_methylAnthranilate.png")

In [None]:
# Save the counterfactuals and the full sample space to CSV (these lines are
# active and overwrite existing files of the same name)
cfs_methylAnthranilate_data = pd.DataFrame(cfs_methylAnthranilate)
cfs_methylAnthranilate_data.to_csv("cfs_methylAnthranilate.csv", index=False)
samples_methylAnthranilate_data = pd.DataFrame(samples_methylAnthranilate)
samples_methylAnthranilate_data.to_csv("samples_methylAnthranilate.csv", index=False)

In [None]:
# Save only the methyl anthranilate counterfactuals to CSV (repeats the first two
# statements of the previous save cell)
cfs_methylAnthranilate_data = pd.DataFrame(cfs_methylAnthranilate)
cfs_methylAnthranilate_data.to_csv("cfs_methylAnthranilate.csv", index=False)

In [None]:
# Plot the first six entries (base molecule plus counterfactuals) on two rows.
# NOTE(review): this writes to the same file as the full plot saved earlier and
# so overwrites it; the vanillin section uses a "_truncated" suffix for the
# equivalent figure — confirm which file name is intended.
exmol.plot_cf(cfs_methylAnthranilate[:6], nrows=2)
plt.savefig("cfs_methylAnthranilate.png")