In [4]:
#Code modified from that given in "Efficient Interpolation & Exploration with STONED SELFIES" tutorial 
#Github link for that tutorial: https://github.com/aspuru-guzik-group/stoned-selfies
#Paper associated with that tutorial: "Beyond generative models: superfast traversal, optimization, novelty, exploration and discovery (STONED) algorithm for molecules using SELFIES"
#Paper link: https://doi.org/10.26434/chemrxiv.13383266.v2

In [5]:
print('Remember to update CUDA_VISIBLE_DEVICES')
#For GPU nodes, edit value below based on allocated GPU
#The variable is set in this cell, which appears before the cell that imports tensorflow/jax
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

Remember to update CUDA_VISIBLE_DEVICES


In [14]:
#Imports
import time 
import selfies
import rdkit
import random
import numpy as np
import random
from rdkit import Chem
from selfies import encoder, decoder
from rdkit.Chem import MolFromSmiles as smi2mol
from rdkit.Chem import AllChem
from rdkit.DataStructs.cDataStructs import TanimotoSimilarity
from rdkit.Chem import Mol
from rdkit.Chem.AtomPairs.Sheridan import GetBPFingerprint, GetBTFingerprint
from rdkit.Chem.Pharm2D import Generate, Gobbi_Pharm2D
from rdkit.Chem import Draw
from rdkit.Chem import rdFMCS


from rdkit.Chem import MolToSmiles as mol2smi
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

#!pip install matplotlib numpy pandas seaborn jax jaxlib dm-haiku tensorflow 

import tensorflow as tf
import seaborn as sns
import jax.numpy as jnp
import jax
import jax.experimental.optimizers as opt
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import haiku as hk

import warnings
#Silence all Python warnings for the rest of the session
warnings.filterwarnings('ignore')
#Plot styling: seaborn 'notebook' context and 'dark' style with custom tick/edge settings
sns.set_context('notebook')
sns.set_style('dark',  {'xtick.bottom':True, 'ytick.left':True, 'xtick.color': '#666666', 'ytick.color': '#666666',
                        'axes.edgecolor': '#666666', 'axes.linewidth':     0.8 , 'figure.dpi': 300})
#Custom palette installed as matplotlib's default colour cycle
color_cycle = ['#1BBC9B', '#F06060', '#5C4B51', '#F3B562', '#6e5687']
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=color_cycle) 
#Seed the NumPy and TensorFlow global random number generators
#NOTE(review): the `random` module is imported above but is not seeded here
np.random.seed(0)
tf.random.set_seed(0)

**GNN Model Related Code** 

Code modified from the example code given in the "Predicting DFT Energies with GNNs" and "Interpretability and Deep Learning" sections of the "Deep Learning for Molecules and Materials" textbook (https://whitead.github.io/dmol-book/applied/QM9.html).

In [7]:
#Load data --> file uploaded to jhub (locally stored)
scentdata = pd.read_csv('train.csv')

#Read in vocabulary text file --> this file gives all of the scent classes used in dataset
#A context manager is used so the file handle is closed as soon as the vocabulary is read
with open('vocabulary.txt') as file:
    #Create list that stores all scent classes (one class per line)
    #NOTE(review): split('\n') yields a trailing '' entry when the file ends with a newline,
    #which is then counted in numClasses -- confirm against the file used to train the saved parameters
    scentClasses = file.read().split('\n')
numClasses = len(scentClasses)
#print(numClasses)
#print(scentdata)
#print(scentClasses)

In [8]:
def gen_smiles2graph(sml):
    '''Convert a SMILES string into node features and an adjacency matrix.

    Explicit hydrogens are added before the graph is built, so N below counts
    hydrogens as well as heavy atoms.

    Parameters:
    sml (string) : SMILES string of the molecule

    Returns:
    nodes (np.ndarray, shape (N, 256)) : one-hot rows; column index = atomic number of the atom
    adj   (np.ndarray, shape (N, N))   : symmetric 0/1 adjacency matrix with ones on the diagonal

    Raises:
    ValueError : if RDKit cannot parse sml
    Warning    : if a bond is not single, double, triple or aromatic
    '''
    m = rdkit.Chem.MolFromSmiles(sml)
    if m is None:
        # MolFromSmiles signals a parse failure by returning None; fail with a clear message
        # instead of an opaque argument error from AddHs
        raise ValueError(f'Could not parse SMILES string: {sml!r}')
    m = rdkit.Chem.AddHs(m)
    supported_bond_types = (rdkit.Chem.rdchem.BondType.SINGLE,
                            rdkit.Chem.rdchem.BondType.DOUBLE,
                            rdkit.Chem.rdchem.BondType.TRIPLE,
                            rdkit.Chem.rdchem.BondType.AROMATIC)
    N = m.GetNumAtoms()
    nodes = np.zeros((N,256))
    for atom in m.GetAtoms():
        nodes[atom.GetIdx(), atom.GetAtomicNum()] = 1

    adj = np.zeros((N,N))
    for bond in m.GetBonds():
        bond_type = bond.GetBondType()
        if bond_type not in supported_bond_types:
            # The bond type is formatted into the message; concatenating it to a str raised TypeError before
            raise Warning(f'Unsupported bond type: {bond_type}')
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        # Bond order is not encoded: every supported bond is stored as 1, in both directions
        adj[u, v] = 1
        adj[v, u] = 1
    # Self-loops so each atom is treated as its own neighbour
    adj += np.eye(N)
    return nodes, adj

In [9]:
#Function that creates label vector given list of strings describing scent of molecule as input
#Each index in label vector corresponds to specific scent -> if output has a 0 at index i, then molecule does not have scent i
#If label vector has 1 at index i, then molecule does have scent i

def createLabelVector(scentsList):
    '''Build the multi-hot label vector for one molecule.

    Parameters:
    scentsList (list of strings) : scent names describing the molecule; each must appear in scentClasses

    Returns:
    labelVector (np.ndarray, shape (numClasses,)) : 1 at the index of every listed scent, 0 elsewhere

    Raises:
    ValueError : (from list.index) if a scent name is not present in scentClasses
    '''
    labelVector = np.zeros(numClasses)
    for scentName in scentsList:
        #Position of this scent in the vocabulary = its index in the label vector
        labelVector[scentClasses.index(scentName)] = 1
    return labelVector

def generateGraphs():
    '''Generator over the dataset: yields (graph, labels) for every row of scentdata.

    graph is the (nodes, adjacency) pair from gen_smiles2graph for the row's SMILES and
    labels is the vector from createLabelVector for the row's comma-separated SENTENCE field.
    '''
    numMolecules = len(scentdata)
    for rowIdx in range(numMolecules):
        molGraph = gen_smiles2graph(scentdata.SMILES[rowIdx])
        #Comma-separated scent descriptors of molecule rowIdx -> list of strings
        scentNames = scentdata.SENTENCE[rowIdx].split(',')
        yield molGraph, createLabelVector(scentNames)

#Check that generateGraphs() works for 1st molecule
#print(gen_smiles2graph(scentdata.SMILES[0]))
#print(scentdata.SENTENCE[0].split(','))
#print(np.nonzero(createLabelVector(scentdata.SENTENCE[0].split(','))))
#print(scentClasses[89])
#Element structure produced by generateGraphs(): ((nodes, adjacency), labels), all float32
graphElementTypes = ((tf.float32, tf.float32), tf.float32)
graphElementShapes = (
    (tf.TensorShape([None, 256]), tf.TensorShape([None, None])),
    tf.TensorShape([None]),
)
data = tf.data.Dataset.from_generator(
    generateGraphs,
    output_types=graphElementTypes,
    output_shapes=graphElementShapes,
)



In [17]:
#GNN Model
class GNNLayer(hk.Module): #TODO: If increase number of layers, stack features & new_features and shrink via dense layer
    '''One message-passing layer acting on a (nodes, edges, features) triple.

    Calling the layer returns an updated triple of the same structure, so layers can be chained.
    Parameter names ("we", "b", "wv", "wu") are the keys under which Haiku stores this layer's weights.
    '''

    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        # Stored only; __call__ below does not read output_size
        self.output_size = output_size

    def __call__(self, inputs):
        '''Apply the layer.

        Parameters:
        inputs : tuple (nodes, edges, features)

        Returns:
        (new_nodes, new_edges, new_features)
        '''
        # split input into nodes, edges & features
        nodes, edges, features = inputs
        #Nodes is of shape (N, Nf) --> N = # atoms, Nf = node_feature_length
        #Edges is of shape (N,N) (adjacency matrix)
        #Features is of shape (Gf) --> Gf = graph_feature_length

        graph_feature_len = features.shape[-1] #graph_feature_len (Gf)
        node_feature_len = nodes.shape[-1] #node_feature_len (Nf)
        message_feature_len = node_feature_len #message_feature_length (Mf)
        
        #Initialize weights (every parameter, including the bias b, is drawn from a normal with stddev 0.01)
        w_init = hk.initializers.RandomNormal(stddev = 0.01)
        
        #we is of shape (Nf,Mf)
        we = hk.get_parameter("we", shape=[node_feature_len, message_feature_len], init=w_init)
        
        #b is of shape (Mf)
        b = hk.get_parameter("b", shape=[message_feature_len], init=w_init)
        
        #wv is of shape (Mf,Nf)
        wv = hk.get_parameter("wv", shape=[message_feature_len, node_feature_len], init=w_init)
        
        #wu is of shape (Nf,Gf)
        wu = hk.get_parameter("wu", shape=[node_feature_len, graph_feature_len], init=w_init)
        
        # make nodes be N x N x Nf so we can just multiply directly (N = number of atoms)
        # ek is now shaped N x N x Mf
        # NOTE(review): `@` and `*` bind tighter than `+`, so this evaluates as b + ((nodes_repeated @ we) * edges[..., None]);
        # the bias is added after the edge mask, so entries where edges is 0 hold leaky_relu(b) rather than 0
        ek = jax.nn.leaky_relu(b + 
            jnp.repeat(nodes[jnp.newaxis,...], nodes.shape[0], axis=0) @ we * edges[..., None])

        #ek *= edges[...,None]
        
        #Update edges, use jnp.any to have new_edges be of shape N x N
        #(jnp.any reduces over the message axis: an entry is True if any message feature is non-zero)
        new_edges = jnp.any(ek, axis=-1)
        
        #Normalize over edge features w/layer normalization (over both axes, no learned scale/offset)
        new_edges = hk.LayerNorm(axis=[0,1], create_scale=False, create_offset=False, eps=1e-05)(new_edges)
    
        # take sum over neighbors (axis 1) to get ebar shape = N x Mf
        ebar = jnp.sum(ek, axis=1)
        
        # dense layer for new nodes to get new_nodes shape = N x Nf; the input nodes are added back
        new_nodes = jax.nn.leaky_relu(ebar @ wv) + nodes
        
        #Normalize over node features w/layer normalization (over both axes, no learned scale/offset)
        new_nodes = hk.LayerNorm(axis=[0,1], create_scale=False, create_offset=False, eps=1e-05)(new_nodes)
        
        # sum over nodes to get shape features so global_node_features shape = Nf
        global_node_features = jnp.sum(new_nodes, axis=0)
        
        # dense layer for new features so new_features shape = Gf; the input features are added back
        new_features = jax.nn.leaky_relu(global_node_features  @ wu) + features
        
        # return the full (nodes, edges, features) triple so layers can be chained
        return new_nodes, new_edges, new_features

    
def model_fn(x):
    '''Forward pass of the scent classifier.

    Parameters:
    x : tuple (nodes, edges) as produced by gen_smiles2graph

    Returns:
    logits : vector of length numClasses (one logit per scent class)
    '''
    nodes, edges = x
    #Graph-level feature vector starts as all ones
    state = (nodes, edges, jnp.ones(256))

    # 2 GNN layers (created in order, so Haiku names them gnn_layer and gnn_layer_1)
    for _ in range(2):
        state = GNNLayer(output_size=256)(state)

    # 2 linear layers applied to the graph-level features (last element of the triple)
    graphFeatures = state[-1]
    hidden = hk.Linear(numClasses)(graphFeatures)
    return hk.Linear(numClasses)(hidden) #Model returns logits


#Pure (init, apply) pair; apply takes no rng argument
model = hk.without_apply_rng(hk.transform(model_fn))

In [18]:
#Initialize model
rng = jax.random.PRNGKey(0)

#Pull a single (graph, labels) example from the dataset to give model.init concrete input shapes
sampleData = data.take(1)
for dataVal in sampleData: #Look into later how to get larger set
    graph_i, yi = dataVal
#Convert the TensorFlow tensors to NumPy arrays before handing them to JAX/Haiku
nodes_i, edges_i = (tensor.numpy() for tensor in graph_i)
yi = yi.numpy()
xi = (nodes_i, edges_i)

params = model.init(rng, xi)

In [19]:
#Load optimal parameters for GNN model 
print('Edit fileName to change parameters being loaded')
fileName = 'optParams_1000Epochs.npy' #NOT optimal parameters, need to edit once get best model
#NOTE(review): allow_pickle=True can execute arbitrary code when unpickling -- only load files you trust
paramsArr = jnp.load(fileName, allow_pickle = True)
#Re-assemble the flat array into the nested {module name: {parameter name: value}} layout
#NOTE(review): the cells visible in this notebook call model.apply(params, ...) and never use opt_params,
#so these loaded weights appear to be unused -- confirm which parameter set the predictions should use
opt_params = {
    'gnn_layer':   {'b': paramsArr[0], 'we': paramsArr[1], 'wu': paramsArr[2], 'wv': paramsArr[3]},
    'gnn_layer_1': {'b': paramsArr[4], 'we': paramsArr[5], 'wu': paramsArr[6], 'wv': paramsArr[7]},
    'linear':      {'b': paramsArr[8], 'w': paramsArr[9]},
    'linear_1':    {'b': paramsArr[10], 'w': paramsArr[11]},
}

Edit fileName to change parameters being loaded


**STONED-SELFIES related code**

Code below modified from that given in "Efficient Interpolation & Exploration with STONED SELFIES" tutorial 
- Github link for that tutorial: https://github.com/aspuru-guzik-group/stoned-selfies

Paper associated with that tutorial: "Beyond generative models: superfast traversal, optimization, novelty, exploration and discovery (STONED) algorithm for molecules using SELFIES"
- Paper link: https://doi.org/10.26434/chemrxiv.13383266.v2

In [20]:
def randomize_smiles(mol):
    '''Returns a random (dearomatized) SMILES string given an rdkit mol object of a molecule.

    Parameters:
    mol (rdkit.Chem.rdchem.Mol) :  RdKit mol object (None if invalid smile string smi)
    
    Returns:
    smiles (string) : randomized, non-canonical, kekulized SMILES without stereo information
                      (None if mol is None)
    '''
    if not mol:
        return None

    # Kekulize a copy so the caller's mol object is not modified as a side effect
    mol_kekulized = Chem.Mol(mol)
    Chem.Kekulize(mol_kekulized)
    return rdkit.Chem.MolToSmiles(mol_kekulized, canonical=False, doRandom=True, isomericSmiles=False,  kekuleSmiles=True) 


def sanitize_smiles(smi):
    '''Return a canonical smile representation of smi
    
    Parameters:
    smi (string) : smile string to be canonicalized 
    
    Returns:
    mol (rdkit.Chem.rdchem.Mol) : RdKit mol object                          (None if invalid smile string smi)
    smi_canon (string)          : Canonicalized smile representation of smi (None if invalid smile string smi)
    conversion_successful (bool): True/False to indicate if conversion was  successful 
    '''
    try:
        mol = smi2mol(smi, sanitize=True)
        if mol is None:
            # The parser reports an invalid SMILES by returning None rather than raising
            return (None, None, False)
        smi_canon = mol2smi(mol, isomericSmiles=False, canonical=True)
        return (mol, smi_canon, True)
    except Exception:
        # Any conversion error counts as failure; KeyboardInterrupt/SystemExit are no longer swallowed
        return (None, None, False)
    

def get_selfie_chars(selfie):
    '''Obtain a list of all selfie characters in string selfie
    
    Every '[...]' group is returned in order of appearance. Text outside brackets
    (for example a '.' separator) is skipped, and scanning stops at an unclosed '['.
    
    Parameters: 
    selfie (string) : A selfie string - representing a molecule 
    
    Example: 
    >>> get_selfie_chars('[C][=C][C][=C][C][=C][Ring1][Branch1_1]')
    ['[C]', '[=C]', '[C]', '[=C]', '[C]', '[=C]', '[Ring1]', '[Branch1_1]']
    
    Returns:
    chars_selfie: list of selfie characters present in molecule selfie
    '''
    chars_selfie = [] # A list of all SELFIE symbols from string selfie
    pos = 0
    while True:
        start = selfie.find('[', pos)
        if start == -1:
            break  # no further symbol; trailing text is ignored (previously this never terminated)
        end = selfie.find(']', start)
        if end == -1:
            break  # unclosed '[': stop instead of looping forever
        chars_selfie.append(selfie[start:end+1])
        # Continue after this symbol; scanning by position avoids re-slicing the string each step
        pos = end + 1
    return chars_selfie


class _FingerprintCalculator:
    ''' Calculate the fingerprint for a molecule, given the fingerprint type
    Parameters: 
        mol (rdkit.Chem.rdchem.Mol) : RdKit mol object (None if invalid smile string smi)
        fp_type (string)            :Fingerprint type  (choices: AP/PHCO/BPF,BTF,PATH,ECFP4,ECFP6,FCFP4,FCFP6)  
    Returns:
        RDKit fingerprint object
    '''

    def get_fingerprint(self, mol: Mol, fp_type: str):
        '''Dispatch to the get_<fp_type> method; raises Exception for an unknown fp_type.'''
        method_name = 'get_' + fp_type
        # Default of None makes the unsupported-type check below reachable
        # (without it getattr raised AttributeError before the check ran)
        method = getattr(self, method_name, None)
        if method is None:
            raise Exception(f'{fp_type} is not a supported fingerprint type.')
        return method(mol)

    def get_AP(self, mol: Mol):
        return AllChem.GetAtomPairFingerprint(mol, maxLength=10)

    def get_PHCO(self, mol: Mol):
        return Generate.Gen2DFingerprint(mol, Gobbi_Pharm2D.factory)

    def get_BPF(self, mol: Mol):
        return GetBPFingerprint(mol)

    def get_BTF(self, mol: Mol):
        return GetBTFingerprint(mol)

    def get_PATH(self, mol: Mol):
        return AllChem.RDKFingerprint(mol)

    def get_ECFP4(self, mol: Mol):
        return AllChem.GetMorganFingerprint(mol, 2)

    def get_ECFP6(self, mol: Mol):
        return AllChem.GetMorganFingerprint(mol, 3)

    def get_FCFP4(self, mol: Mol):
        return AllChem.GetMorganFingerprint(mol, 2, useFeatures=True)

    def get_FCFP6(self, mol: Mol):
        return AllChem.GetMorganFingerprint(mol, 3, useFeatures=True)


def get_fingerprint(mol: Mol, fp_type: str):
    ''' Module-level fingerprint getter: builds a '_FingerprintCalculator' and
        delegates to its get_fingerprint method.
        
    Parameters: 
        mol (rdkit.Chem.rdchem.Mol) : RdKit mol object (None if invalid smile string smi)
        fp_type (string)            :Fingerprint type  (choices: AP/PHCO/BPF,BTF,PATH,ECFP4,ECFP6,FCFP4,FCFP6)  
    Returns:
        RDKit fingerprint object
        
    '''
    calculator = _FingerprintCalculator()
    return calculator.get_fingerprint(mol=mol, fp_type=fp_type)

In [39]:
#Code below modified to call GNN model to check whether molecule has desired scent
def _predict_scent_label(smiles, scentIndex):
    '''Return the GNN's binary prediction for one scent class: 1 if the logit is > 0, else 0.'''
    #NOTE(review): evaluates the model with `params` (from model.init), exactly as the original code did;
    #the loaded `opt_params` is not used -- confirm which parameter set is intended
    logit = model.apply(params, gen_smiles2graph(smiles))[scentIndex]
    return 1 if logit > 0 else 0


def mutate_selfie(selfie, scentToChange, max_molecules_len, write_fail_cases=False, desired_pred=None):
    '''Return a mutated selfie string (only one mutation on selfie is performed)
    
    Mutations are done until a valid molecule with the desired scent prediction is obtained 
    Rules of mutation: With a 33.3% probability, either: 
        1. Add a random SELFIE character in the string
        2. Replace a random SELFIE character with another
        3. Delete a random character
    
    Parameters:
    selfie            (string)  : SELFIE string to be mutated 
    scentToChange     (string)  : String corresponding to scent that needs to be changed (mutated molecule should/should not have that scent)
    max_molecules_len (int)     : Mutations of SELFIE string are allowed up to this length
    write_fail_cases  (bool)    : If true, failed mutations are recorded in "selfie_failure_cases.txt"
    desired_pred      (int)     : Prediction (0 or 1) the mutated molecule must have for scentToChange.
                                  If None (default), it is the opposite of the prediction for `selfie`.
    
    Returns:
    selfie_mutated    (string)  : Mutated SELFIE string
    smiles_canon      (string)  : canonical smile of mutated SELFIE string
    '''
    fail_counter = 0
    chars_selfie = get_selfie_chars(selfie)
    sf = "".join(chars_selfie)
    
    #Find index that corresponds to scent string that needs to be changed
    scentIndex = scentClasses.index(scentToChange)
    if desired_pred is None:
        #Default target: flip whatever the model predicts for the input molecule
        desired_pred = 1 - _predict_scent_label(selfies.decoder(selfie), scentIndex)

    #Loop-invariant values are built once instead of on every attempt
    alphabet = list(selfies.get_semantic_robust_alphabet()) # 34 SELFIE characters 
    choice_ls = [1, 2, 3] # 1=Insert; 2=Replace; 3=Delete

    #NOTE(review): there is no upper bound on the number of attempts; the loop only ends when a
    #valid mutant with the desired prediction is found
    while True:
        fail_counter += 1
        random_choice = np.random.choice(choice_ls, 1)[0]
        if not chars_selfie:
            random_choice = 1 #Nothing to replace or delete in an empty string, so insert
        
        # Insert a character in a Random Location
        if random_choice == 1: 
            random_index = np.random.randint(len(chars_selfie)+1)
            random_character = np.random.choice(alphabet, size=1)[0]
            selfie_mutated_chars = chars_selfie[:random_index] + [random_character] + chars_selfie[random_index:]

        # Replace a random character 
        elif random_choice == 2:                         
            random_index = np.random.randint(len(chars_selfie))
            random_character = np.random.choice(alphabet, size=1)[0]
            selfie_mutated_chars = chars_selfie[:random_index] + [random_character] + chars_selfie[random_index+1:]
                
        # Delete a random character
        elif random_choice == 3: 
            random_index = np.random.randint(len(chars_selfie))
            selfie_mutated_chars = chars_selfie[:random_index] + chars_selfie[random_index+1:]
                
        else: 
            raise Exception('Invalid Operation trying to be performed')

        selfie_mutated = "".join(selfie_mutated_chars)
        
        #Cheap structural checks first: decode and sanitize the mutant
        try:
            smiles = decoder(selfie_mutated)
            mol, smiles_canon, done = sanitize_smiles(smiles)
        except Exception:
            done = False
            if fail_counter > 1 and write_fail_cases:
                with open("selfie_failure_cases.txt", "a+") as f:
                    f.write('Tried to mutate SELFIE: '+str(sf)+' To Obtain: '+str(selfie_mutated) + '\n')

        if not done or smiles_canon == "" or len(selfie_mutated_chars) > max_molecules_len:
            continue

        #Only valid mutants reach the (more expensive) model evaluation
        if _predict_scent_label(smiles, scentIndex) == desired_pred:
            return (selfie_mutated, smiles_canon)


def get_mutated_SELFIES(selfies_ls, num_mutations, scentToChange): 
    ''' Mutate all the SELFIES in 'selfies_ls' 'num_mutations' number of times. 
    
    The target prediction for each molecule is fixed once, as the opposite of the model's
    prediction for the ORIGINAL string, and enforced at every mutation step. (Re-deriving the
    target from the current string at each step flips the scent back on every second step.)
    
    Parameters:
    selfies_ls   (list)  : A list of SELFIES 
    num_mutations (int)  : number of mutations to perform on each SELFIES within 'selfies_ls'
    scentToChange (string): String describing predicted scent that should be changed (1->0 or 0->1) for mutated molecule
    
    Returns:
    selfies_ls   (list)  : A list of mutated SELFIES
    
    '''
    scentIndex = scentClasses.index(scentToChange)
    targets = [1 - _predict_scent_label(selfies.decoder(str_), scentIndex) for str_ in selfies_ls]

    for _ in range(num_mutations): 
        selfie_ls_mut_ls = []
        for str_, target in zip(selfies_ls, targets): 
            max_molecules_len = len(get_selfie_chars(str_)) + num_mutations
            selfie_mutated, smiles_canon = mutate_selfie(str_, scentToChange, max_molecules_len, desired_pred=target)
            selfie_ls_mut_ls.append(selfie_mutated)
        
        selfies_ls = selfie_ls_mut_ls
    return selfies_ls


def get_fp_scores(smiles_back, target_smi, fp_type): 
    '''Tanimoto similarity (on fp_type fingerprints) of each SMILES in a list to a
       known target structure (target_smi). 
       
    Parameters:
    smiles_back   (list) : A list of valid SMILES strings 
    target_smi (string)  : A valid SMILES string. Each smile in 'smiles_back' will be compared to this structure
    fp_type (string)     : Type of fingerprint  (choices: AP/PHCO/BPF,BTF,PATH,ECFP4,ECFP6,FCFP4,FCFP6) 
    
    Returns: 
    smiles_back_scores (list of floats) : List of fingerprint similarities, in the order of smiles_back
    '''
    # The target fingerprint is computed once and reused for every comparison
    fp_target = get_fingerprint(Chem.MolFromSmiles(target_smi), fp_type)
    return [
        TanimotoSimilarity(get_fingerprint(Chem.MolFromSmiles(candidate), fp_type), fp_target)
        for candidate in smiles_back
    ]

**Code to Generate Counterfactuals**

In [None]:
#Test
smi  = scentdata.SMILES[0] #Edit this line if want to change input molecule
scentString = 'fruity' #Edit this line if want to change scent examining (counterfactual will be molecule w/opposite prediction on this scentString)

#Validate the starting structure before it is used for anything else
mol = Chem.MolFromSmiles(smi)
if mol is None: 
    raise Exception('Invalid starting structure encountered')

inputMolGraph = gen_smiles2graph(smi)
inputMolScents = model.apply(params, inputMolGraph)
#Counterfactuals must have the opposite prediction: logit > 0 means the input has the scent
if(inputMolScents[scentClasses.index(scentString)] > 0):
    desiredPred = 0
else:
    desiredPred = 1

#print(inputMolScents)
print(f'Input molecule predicted scent for {scentString}: {inputMolScents[scentClasses.index(scentString)]}')
fp_type = 'PATH' #For now, using rdkit path fingerprint, can change later


total_time = time.time()
# num_random_samples = 50000 # For a more exhaustive search! 
num_random_samples = 100     
num_mutation_ls    = [1, 2, 3, 4, 5]

start_time = time.time()
randomized_smile_orderings  = [randomize_smiles(mol) for _ in range(num_random_samples)]

# Convert all the molecules to SELFIES
selfies_ls = [encoder(x) for x in randomized_smile_orderings]
print('Randomized molecules (in SELFIES) time: ', time.time()-start_time)


all_smiles_collect = []
all_smiles_collect_broken = []

start_time = time.time()
for num_mutations in num_mutation_ls: 
    # Mutate the SELFIES: 
    selfies_mut = get_mutated_SELFIES(selfies_ls.copy(), num_mutations=num_mutations, scentToChange = scentString)

    # Convert back to SMILES: 
    smiles_back = [decoder(x) for x in selfies_mut]
    all_smiles_collect = all_smiles_collect + smiles_back
    all_smiles_collect_broken.append(smiles_back)


print('Mutation obtainment time (back to smiles): ', time.time()-start_time)


# Work on:  all_smiles_collect
start_time = time.time()
canon_smi_ls = []
for item in all_smiles_collect: 
    mol, smi_canon, did_convert = sanitize_smiles(item)
    if mol is None or smi_canon == '' or not did_convert: 
        raise Exception('Invalid smile string found')
    canon_smi_ls.append(smi_canon)
# De-duplicate while keeping first-seen order, so later positional picks do not depend on set ordering
canon_smi_ls = list(dict.fromkeys(canon_smi_ls))
print('Unique mutated structure obtainment time: ', time.time()-start_time)

start_time = time.time()
canon_smi_ls_scores = get_fp_scores(canon_smi_ls, target_smi=smi, fp_type=fp_type)
print('Fingerprint calculation time: ', time.time()-start_time)
print('Total time: ', time.time()-total_time)

# Similarity bands are half-open (lower, upper] so a score of exactly 0.8 or 0.6 is not dropped

# Molecules with fingerprint similarity > 0.8
indices_thresh_8 = [i for i,x in enumerate(canon_smi_ls_scores) if x > 0.8]
mols_8 = [Chem.MolFromSmiles(canon_smi_ls[idx]) for idx in indices_thresh_8]

# Molecules with 0.6 < fingerprint similarity <= 0.8
indices_thresh_6 = [i for i,x in enumerate(canon_smi_ls_scores) if 0.6 < x <= 0.8]
mols_6 = [Chem.MolFromSmiles(canon_smi_ls[idx]) for idx in indices_thresh_6]

# Molecules with 0.4 < fingerprint similarity <= 0.6
indices_thresh_4 = [i for i,x in enumerate(canon_smi_ls_scores) if 0.4 < x <= 0.6]
mols_4 = [Chem.MolFromSmiles(canon_smi_ls[idx]) for idx in indices_thresh_4]

Input molecule predicted scent for fruity: -1.2717444896697998
Randomized molecules (in SELFIES) time:  0.004717111587524414


In [None]:
#Show Image of input molecule that the counterfactual molecules are compared against
molInput = Chem.MolFromSmiles(smi)
#Report the scent actually being examined (scentString) rather than a hard-coded name
print(f'Input molecule predicted scent for {scentString}: {inputMolScents[scentClasses.index(scentString)]}')
print('Input molecule:')

molInput

In [None]:
#Print prediction results (raw logits for scentString) for mutated molecules w/similarity > 0.8
logits_8 = []
for idx in indices_thresh_8:
    molGraph_8 = gen_smiles2graph(canon_smi_ls[idx])
    logits_8.append(model.apply(params, molGraph_8)[scentClasses.index(scentString)])
predictions_8 = np.asarray(logits_8)
print(f'Prediction for each molecule: {predictions_8}')

#Show image of all molecules w/fingerprint similarity > 0.8
print('Counterfactual molecules w/similarity > 0.8')
img=Draw.MolsToGridImage(mols_8,molsPerRow=4,subImgSize=(200,200))    
img

In [None]:
#Show images of molecules that have similarity > 0.8 & are of opposite scent as input molecule
#A logit > 0 is read as "has scent" (1) and anything else as 0 -- the same rule mutate_selfie uses,
#so a logit of exactly 0 is classified consistently
hasDesiredPred = (predictions_8 > 0) == (desiredPred == 1)
bestMols = [candidateMol for candidateMol, keep in zip(mols_8, hasDesiredPred) if keep]
numMolWithIncorrectScent = len(mols_8) - len(bestMols)

#Uncomment print line below to check that algorithm only gave molecules w/desired scent prediction
#print(f'Number of molecules generated w/stoned_selfies algorithm that do not have desired scent prediction:{numMolWithIncorrectScent}')
print('Counterfactual molecules w/similarity > 0.8 and correct predicted scent:')
img=Draw.MolsToGridImage(bestMols,molsPerRow=4,subImgSize=(200,200))    
img    

In [None]:
#Show image of molecule w/the highest similarity score & correct prediction alongside input molecule
#Also print out selfies strings for both
print('Input molecule & counterfactual molecule with highest similarity score & correct prediction')
topMolecules_correctPred = []
topMolecules_score = []

for idx in indices_thresh_8:
    currSMILES = canon_smi_ls[idx]
    currLogit = model.apply(params, gen_smiles2graph(currSMILES))[scentClasses.index(scentString)]
    #Same rule as mutate_selfie: logit > 0 -> 1 (has scent), otherwise 0
    currPred = 1 if currLogit > 0 else 0
    if currPred == desiredPred:
        topMolecules_correctPred.append(currSMILES)
        topMolecules_score.append(canon_smi_ls_scores[idx])

#Fail with a clear message instead of letting np.argmax raise on an empty sequence
if not topMolecules_correctPred:
    raise ValueError('No molecule with similarity > 0.8 has the desired scent prediction')

indexMostSimilar = np.argmax(np.asarray(topMolecules_score))
mostSimilarMol_correctPred = Chem.MolFromSmiles(topMolecules_correctPred[indexMostSimilar])
smiles_mostSimilarMol = topMolecules_correctPred[indexMostSimilar]
selfies_mostSimilarMol = selfies.encoder(smiles_mostSimilarMol)
selfies_inputMol = selfies.encoder(smi)

#Predictions are reported as probabilities (sigmoid of the logit)
print(f'Input molecule SELFIES:\n {selfies_inputMol}')
print(f'Input molecule prediction for scent {scentString}: {jax.nn.sigmoid(model.apply(params, gen_smiles2graph(smi))[scentClasses.index(scentString)])}')
print(f'\nCounterfactual molecule SELFIES:\n {selfies_mostSimilarMol}')
print(f'Counterfactual molecule prediction for scent {scentString}: {jax.nn.sigmoid(model.apply(params, gen_smiles2graph(smiles_mostSimilarMol))[scentClasses.index(scentString)])}')
img=Draw.MolsToGridImage([mostSimilarMol_correctPred, molInput],molsPerRow=4,subImgSize=(200,200),legends = ('Mutated molecule','Input molecule'))    
img

In [None]:
#Find the maximum common substructure (MCS) shared by the counterfactual & input molecule
mcsResult = rdFMCS.FindMCS(
    [mostSimilarMol_correctPred, molInput],
    matchValences = True,
    ringMatchesRingOnly=True,
    completeRingsOnly=True,
)
#SMARTS pattern of the MCS and the query molecule built from it
commonSubstructure = mcsResult.smartsString
commonSubstructure_mol = Chem.MolFromSmarts(commonSubstructure)
#Atom-index matches of the MCS in each molecule (computed here but not passed to a drawing call in this cell)
mostSimilarMol_correctPred_highlighted = mostSimilarMol_correctPred.GetSubstructMatches(commonSubstructure_mol)
molInput_highlighted = molInput.GetSubstructMatches(commonSubstructure_mol)
print('Input molecule')
molInput

In [None]:
#Display the counterfactual molecule (the bare last expression is the cell's output)
print('Counterfactual molecule')
mostSimilarMol_correctPred

In [None]:
#Run the cells below to isolate difference between the generated counterfactual & input molecule
#Try and isolate difference between the 2 molecules -> delete common substructure
#What is left after deleting the atoms matched by commonSubstructure_mol is displayed as the cell output
mostSimilarMol_substructRemoved = Chem.DeleteSubstructs(mostSimilarMol_correctPred,commonSubstructure_mol)
mostSimilarMol_substructRemoved

In [None]:
#Try w/another mutated molecule (not most similar one, but still has correct prediction & similarity > 0.8)
#NOTE(review): index 2 is hard-coded; this raises IndexError when topMolecules_correctPred has fewer than 3 entries
mutatedMol2 = Chem.MolFromSmiles(topMolecules_correctPred[2]) #This molecule has more than 1 difference from common substructure
commonSub2 = rdFMCS.FindMCS([mutatedMol2, molInput],matchValences = True, ringMatchesRingOnly=True,completeRingsOnly=True)
commonSub2 = Chem.MolFromSmarts(commonSub2.smartsString)
#The result of this call is neither stored nor displayed (it is not the last expression of the cell)
mutatedMol2.GetSubstructMatches(commonSub2)
mutatedMol2

In [None]:
#delete common substructure for mutated molecule 2
#The remainder after deleting the atoms matched by commonSub2 is displayed as the cell output
mutatedMol2_substructRemoved = Chem.DeleteSubstructs(mutatedMol2,commonSub2)
mutatedMol2_substructRemoved

In [None]:
#Try w/mutated molecule that differs much more (similarity between 0.6 & 0.8)
#NOTE(review): index 0 is hard-coded; this raises IndexError when mols_6 is empty
mutatedMol3 = mols_6[0]
commonSub3 = rdFMCS.FindMCS([mutatedMol3, molInput],matchValences = True, ringMatchesRingOnly=True,completeRingsOnly=True)
commonSub3 = Chem.MolFromSmarts(commonSub3.smartsString)
#The result of this call is neither stored nor displayed (it is not the last expression of the cell)
mutatedMol3.GetSubstructMatches(commonSub3)
mutatedMol3

In [None]:
#delete common substructure for mutated molecule 3
#The remainder after deleting the atoms matched by commonSub3 is displayed as the cell output
mutatedMol3_substructRemoved = Chem.DeleteSubstructs(mutatedMol3,commonSub3)
mutatedMol3_substructRemoved

In [None]:
#Print binary predictions for mutated molecules in the 0.6-0.8 similarity band
#Sigmoid turns each logit into a probability; > 0.5 is reported as 1 (has scent), otherwise 0
predictions_6 = np.asarray([jax.nn.sigmoid(model.apply(params, gen_smiles2graph(canon_smi_ls[idx]))[scentClasses.index(scentString)]) for idx in indices_thresh_6])
#Vectorized threshold; the array keeps its original dtype
predictions_6 = (predictions_6 > 0.5).astype(predictions_6.dtype)
print(predictions_6)
#Image of the first 10 molecules in this similarity band
print('Counterfactual molecules w/ 0.8 > similarity > 0.6:')
img=Draw.MolsToGridImage(mols_6[:10],molsPerRow=4,subImgSize=(200,200))    
img

In [None]:
#Print binary predictions for mutated molecules in the 0.4-0.6 similarity band
#Sigmoid turns each logit into a probability; > 0.5 is reported as 1 (has scent), otherwise 0
predictions_4 = np.asarray([jax.nn.sigmoid(model.apply(params, gen_smiles2graph(canon_smi_ls[idx]))[scentClasses.index(scentString)]) for idx in indices_thresh_4])
#Vectorized threshold; the array keeps its original dtype
predictions_4 = (predictions_4 > 0.5).astype(predictions_4.dtype)
print(predictions_4)
#Image of the first 10 molecules in this similarity band
print('Counterfactual molecules w/ 0.6 > similarity > 0.4:')
img=Draw.MolsToGridImage(mols_4[:10],molsPerRow=4,subImgSize=(200,200))    
img