# Notebook 3: Delving further into MCTS-generated chemicals

This notebook explores some of the non-fully expanded molecules generated by MCTS in the RDKit solubility use-case.

Set a seed for reproducible MCTS results.

In [1]:
import random

import numpy as np

# Seed every global random number generator this notebook draws from, so that a
# Restart & Run All reproduces the same results:
#   * the stdlib `random` module (presumably what the MCTS search samples from --
#     the notebook text states this seed makes the MCTS results reproducible), and
#   * NumPy's global generator, which the random-SMILES baseline uses through
#     np.random.choice. Seeding only `random` leaves that baseline different on every run.
seed = 42
random.seed(seed)
np.random.seed(seed)

In [2]:
from minervachem.mcts.tree import Node, LogP, utcbeam, top_percent_search
from rdkit import Chem

import pandas as pd
from rdkit import RDLogger
# Invalid candidate molecules make RDKit emit warnings; turn those log messages off
# so they do not flood the cell outputs.
RDLogger.DisableLog('rdApp.*') # disables every RDKit log channel matching 'rdApp.*'

## 1: Get MCTS-generated molecules

We are using the molecules from the RDKit solubility use case.

In [None]:
# --- Objective -------------------------------------------------------------
logp_goal = -0.5  # logP value the search is aiming for
logp_max = 20     # largest logP value allowed

sa_goal = 0       # synthesizability (SA) score the search is aiming for
sa_max = 10       # largest SA score allowed (defined here but not passed to LogP below)

# --- Search space ----------------------------------------------------------
# SMILES tokens the generator may choose from at every step
choices = ['C', 'O', '=', 'N', 'c', '1', 'S', 'P', 'F', '\n']

levels = 5  # size of molecule; also the number of beam-search rounds run below

# --- Search settings -------------------------------------------------------
num_sims = 10000    # simulations per round (default 10000); lower it for a quick test run
turn = levels + 4   # turn counter handed to the LogP state
beamsize = 10       # number of nodes kept after each round
alpha = 2           # passed to utcbeam as `alpha` (Boltzmann constant)

# Start from a single root node wrapping the initial LogP state.
root_state = LogP(
    logp_target=logp_goal,
    sa_target=sa_goal,
    allchoices=choices,
    logp_max=logp_max,
    turn=turn,
)
current_nodes = [Node(root_state)]

# Each round runs utcbeam from the current population and continues from whatever it returns.
# NOTE(review): scalar=0.4 is forwarded unchanged; its meaning is defined by utcbeam.
for level in range(levels):
    next_nodes = utcbeam(
        budget=num_sims,
        rootpop=current_nodes,
        beamsize=beamsize,
        alpha=alpha,
        scalar=0.4,
    )

    # Report every node of the current population together with its children ...
    for node in current_nodes:
        print(f"Level {level}")
        print(f"This is one of the current nodes: {node}")
        print(f"Num Children: {len(node.children)}")
        for idx, child in enumerate(node.children):
            print(idx, child)

    # ... then the nodes utcbeam selected to carry into the next round.
    print("These are the best children:")
    for best_child in next_nodes:
        print(best_child)
    print("--------------------------------")

    current_nodes = next_nodes

## 2: Get the non-fully expanded children 

We do this by looping back through the fully expanded children and saving the top 1%, 5% and 10% of children based on reward value.

In [4]:
# Fractions of children to keep, ranked by reward: top 1%, 5% and 10%.
tiers = [0.01, 0.05, 0.10]

# Walk every surviving beam node up to the root of its search tree. The walk does not
# depend on the tier, so it is done once up front. Beam nodes that belong to the same
# tree all end at the same root; each root is kept only once (by identity) because
# searching the same tree once per beam node repeats the work and appends the same
# nodes to the result lists several times.
root_nodes = []
seen_root_ids = set()
for beam_node in current_nodes:
    root_node = beam_node
    while root_node.parent is not None:
        root_node = root_node.parent
    if id(root_node) not in seen_root_ids:
        seen_root_ids.add(id(root_node))
        root_nodes.append(root_node)

# nonexpanded[tier] -> nodes top_percent_search returned for that tier
# all_nodes0        -> second return value of top_percent_search, pooled over all calls
nonexpanded = {tier: [] for tier in tiers}
all_nodes0 = []
for tier in tiers:
    for root_node in root_nodes:
        top_nodes, searched_nodes = top_percent_search(root_node, top_percent=tier)
        nonexpanded[tier].extend(top_nodes)
        all_nodes0.extend(searched_nodes)

# Add the final beam itself under the key "best" next to the percentage tiers.
all_nodes = {**nonexpanded, "best": current_nodes}

For ease of data manipulation and visualization, we can put the molecules and their metadata in a pandas dataframe. We will drop any duplicate molecules at the same time.

In [None]:
# Build one record per node whose SMILES RDKit can parse, tagged with the tier
# ("top percent") it was collected under.
data = []
for tier, nodes in all_nodes.items():
    for node in nodes:
        state = node.state
        raw_smiles = state.smiles

        mol = Chem.MolFromSmiles(raw_smiles)
        if not mol:
            # RDKit could not build a molecule from this string; leave it out.
            continue

        canon_smiles = Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)
        record = {
            "top percent": tier,
            "visits": node.visits,
            "reward": node.reward / node.visits,  # reward divided by number of visits
            "logp": state.logp,
            "sa_score": state.sa_score,
            "smiles": raw_smiles,
            "canon smiles": canon_smiles,
            "size with terminal": len(state.moves),
            # same string length, but not counting the "\n" token
            "size without terminal": len(raw_smiles) - raw_smiles.count("\n"),
            "turn": state.turn,
        }
        data.append(record)

# A molecule can show up in several tiers; keep only its first occurrence,
# identified by canonical SMILES.
df = pd.DataFrame(data).drop_duplicates(subset=['canon smiles'], keep='first')
df

## 3: Now, we can do any exploration

We can make some plots.

In [9]:
import seaborn as sns
import matplotlib.pyplot as plt
from rdkit import DataStructs
from rdkit.Chem.Fingerprints import FingerprintMols
import numpy as np

In [None]:
# Distinct values of the 'turn' column, in ascending order
turns = sorted(df['turn'].unique())

# logP distribution of the collected molecules, one colour per turn
ax = sns.histplot(x=df['logp'], hue=df['turn'], palette='viridis_r', alpha=0.3)

# Dashed red line at the target logP, drawn from y=1 up to the current top of the axis
y_top = ax.get_ylim()[1]
ax.vlines(x=logp_goal, ymin=1, ymax=y_top, colors='r', linestyle='--')

ax.set_yscale("log")
plt.show()

For example, we can evaluate the chemical diversity of the MCTS generated molecules by comparing them to some randomly generated molecules via Tanimoto similarity.

In [None]:
# Random baseline: build random token strings, keep the ones RDKit accepts, and compute
# the Tanimoto similarity of every pair of them.

# SMILES component "token" choices -- same as the MCTS search
smiles_choices = ['C', 'O', '=', 'N', 'c', '1', 'S', 'P', 'F', '\n']

# Range of string lengths to sample, taken from the sizes of the MCTS molecules.
maxn = df['size with terminal'].max()
minn = df['size with terminal'].min()
# np.arange excludes its stop value, hence maxn + 1. Without it the largest MCTS size
# is never sampled, and when minn == maxn the range is empty and np.random.choice raises.
numbers = np.arange(minn, maxn + 1, 1)

# Number of random strings to generate: one per MCTS molecule
number = len(df)

###########################

smiles_lst = []
unique_smiles = set()  # strings generated so far, for the uniqueness check
for _ in range(number):
    while True:  # keep sampling until we get a string we have not seen yet
        # Random length first, then that many random tokens
        tmpn = int(np.random.choice(numbers, 1)[0])
        things = ''.join(np.random.choice(smiles_choices, tmpn).tolist())

        if things not in unique_smiles:
            unique_smiles.add(things)
            smiles_lst.append(things)
            break

# Keep only the strings RDKit can canonicalize.
can_smiles = []
for ds in smiles_lst:
    try:
        cs = Chem.CanonSmiles(ds)
    except Exception:
        # Strings for which canonicalization fails are reported and skipped. Catching
        # Exception (not a bare except) keeps Ctrl-C / kernel interrupts working.
        print('Invalid SMILES:', ds)
    else:
        can_smiles.append(cs)

# Molecules and fingerprints for the valid strings
mols = [Chem.MolFromSmiles(x) for x in can_smiles]

# The fingerprint size is fixed explicitly so fingerprints of all molecules are comparable.
fps = [FingerprintMols.FingerprintMol(x, minPath=1, maxPath=7, fpSize=2048,
                               bitsPerHash=2, useHs=True, tgtDensity=0.0,
                               minSize=128) for x in mols]

# Compare every unordered pair exactly once: fingerprint n against all later ones.
firsts, seconds, sim = [], [], []
for n in range(len(fps) - 1):  # the last fingerprint has no later partner
    s = DataStructs.BulkTanimotoSimilarity(fps[n], fps[n + 1:])
    # fps and can_smiles are index-aligned, so s[i] belongs to can_smiles[n + 1 + i];
    # extend once per n instead of re-slicing the list for every single pair.
    firsts.extend([can_smiles[n]] * len(s))
    seconds.extend(can_smiles[n + 1:])
    sim.extend(s)

df_rand = pd.DataFrame({'first_smi': firsts, 'second_smi': seconds, 'similarity': sim})

df_rand

In [None]:
# Pairwise Tanimoto similarity of the MCTS molecules, computed separately within each tier.
# 0 indicates no common bits between fingerprints
# 1 indicates identical bits between fingerprints

tiers = list(df['top percent'].unique())
firsts, seconds, sim, top_p_first, top_p_second = [], [], [], [], []

for k in tiers:
    # Canonical SMILES of the molecules collected under tier k
    can_smiles = df.loc[df['top percent'] == k, 'canon smiles'].to_list()

    # Molecules and fingerprints
    mols = [Chem.MolFromSmiles(x) for x in can_smiles]

    # fpSize must match the random baseline (2048). With tgtDensity=0.0 the fingerprint
    # is not folded, so a different fpSize here would compare the two similarity
    # distributions on fingerprints of different length.
    fps = [FingerprintMols.FingerprintMol(x, minPath=1, maxPath=7, fpSize=2048,
                                bitsPerHash=2, useHs=True, tgtDensity=0.0,
                                minSize=128) for x in mols]

    # Compare every unordered pair exactly once: fingerprint n against all later ones.
    for n in range(len(fps) - 1):  # the last fingerprint has no later partner
        s = DataStructs.BulkTanimotoSimilarity(fps[n], fps[n + 1:])
        n_pairs = len(s)
        # fps and can_smiles are index-aligned, so s[i] belongs to can_smiles[n + 1 + i].
        firsts.extend([can_smiles[n]] * n_pairs)
        seconds.extend(can_smiles[n + 1:])
        sim.extend(s)
        # Both molecules of a pair were selected from tier k, so both carry k.
        top_p_first.extend([k] * n_pairs)
        top_p_second.extend([k] * n_pairs)

df_final = pd.DataFrame({'first_smi':firsts, 'second_smi':seconds, 'similarity':sim, 'top p first_smi': top_p_first, 'top p second_smi': top_p_second})

df_final

In [None]:
# Tanimoto similarity lies in [0, 1] (0 = nothing in common, 1 = identical),
# so use bins of width 0.01 over that range.
bins = np.arange(0, 1.01, 0.01)

# MCTS molecules: stacked counts, coloured by the tier of the first molecule of each pair
ax = sns.histplot(
    data=df_final,
    x='similarity',
    hue='top p first_smi',
    bins=bins,
    stat='count',
    multiple='stack',
    palette='coolwarm',
    alpha=0.5,
    legend=True,
)

# Random baseline, drawn in black on the same axes
sns.histplot(data=df_rand, x='similarity', bins=bins, color='black', alpha=0.5, ax=ax)

ax.set_yscale("log")
ax.set_title("Chemical Diversity of Top X% of Non-Expanded Children")
ax.set_xlabel('Tanimoto Similarity')
ax.set_ylabel('Counts')


In [None]:
# Normalised (density) histograms of pairwise Tanimoto similarity: one per tier of
# MCTS molecules, plus the randomly generated baseline.
bins = np.arange(0,1.01,0.01)
unique_categories = df_final['top p first_smi'].unique()
for category in unique_categories:
    subset = df_final[df_final['top p first_smi'] == category]
    # Percentage tiers are fractions (0.01 -> "Top 1%"), but the final beam is tagged
    # with the string "best"; applying the float format to a string raises ValueError,
    # so string categories are used as the label as-is.
    if isinstance(category, str):
        label = category
    else:
        label = f"Top {category*100:.0f}%"
    plt.hist(subset['similarity'], bins=bins, alpha=0.5, density=True, label=label)
plt.hist(df_rand['similarity'], bins=bins, label='randomly generated', density=True, alpha=0.5)

plt.legend()
plt.yscale("log")
plt.title("Chemical Diversity of Children with Top n% Reward values")
plt.xlabel('Tanimoto Similarity')
# density=True normalises each histogram, so the y axis shows a density, not counts
plt.ylabel('Density')