# Notebook 1: Use MCTS to generate chemical compounds based on user requirements

This notebook demonstrates how to manipulate MCTS search parameters to generate molecules for a simple use-case with RDKit.

Set a seed for reproducible MCTS results.

In [1]:
import random
seed = 42
random.seed(seed)

In [2]:
from minervachem.mcts import Node, LogP, utcbeam
from minervachem.mcts import make_tree_nodes, make_node_info, mass_plotting
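from rdkit import Chem
from rdkit.Chem import Descriptors, Draw, rdMolDescriptors  # load the submodules used below as Chem.Descriptors, Chem.Draw, Chem.rdMolDescriptors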
import pandas as pd
import matplotlib.pyplot as plt


## 1: Setting MCTS search parameters

### Chemistry-specific parameters

MCTS can be used to optimize for any quantifiable chemical property.

In this example, we search for and generate molecules that trend toward a target solubility value.

We use RDKit's function ```Descriptors.MolLogP()``` to calculate logP, the octanol-water partition coefficient, which measures lipophilicity and is commonly used as a proxy for solubility.

Set a target value for logP and the maximum allowed logP value.

In [3]:
logp_goal = -0.5 # target logP value
logp_max = 20 # max allowed logP value

The maximum allowed logP value, ```logp_max```, is used when calculating the reward, which depends on the distance between the current SMILES string's logP and the target value. If that logP is too far from the target, the resulting reward does not sufficiently reward or penalize the MCTS search.
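
To make the role of ```logp_max``` concrete, here is a minimal sketch of a distance-based reward. It is an illustration under stated assumptions, not minervachem's actual reward function: the hypothetical helper `distance_reward` treats the maximum as the largest distance from the target that still matters, clips larger distances, and rescales the result to the range 0 to 1.

```python
def distance_reward(value, target, max_value):
    """Hypothetical reward: 1.0 at the target, falling linearly to 0.0.

    Assumes max_value is the largest distance from the target that should
    still influence the reward; larger distances are clipped.
    """
    if max_value <= 0:
        raise ValueError("max_value must be positive")
    distance = min(abs(value - target), max_value)
    return 1.0 - distance / max_value


# A molecule exactly on target earns the full reward.
on_target = distance_reward(-0.5, target=-0.5, max_value=20)
# Anything at or beyond the cutoff earns nothing, however far off it is.
far_off = distance_reward(40.0, target=-0.5, max_value=20)
```

Under this sketch, a loose cutoff such as 20 means a molecule one logP unit off target still scores 0.95, so near misses and exact hits look almost the same to the search. A tighter cutoff steepens the signal.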

This value is chosen somewhat arbitrarily. A good guide is to look at the distribution of logP values for molecules of similar size and chemical makeup, and to set the maximum some number of standard deviations from the mean. The number of standard deviations depends on how tightly you want to constrain the search.

### An example for estimating maximum value for a target

As an example, let's look at the QM9 dataset. To create it, follow the instructions from section 1 of ```/demos/2_Regression_and_visualization.ipynb```. The instructions are also copied below:

>We'll be working with the QM9 dataset:
>
>Ramakrishnan, R., Dral, P., Rupp, M. et al. Quantum chemistry structures and properties of 134 kilo molecules. Sci Data 1, 140022 (2014). https://doi.org/10.1038/sdata.2014.22
>
>This dataset is publicly available, and we provide a Python script to preprocess it into a CSV file that is convenient to work with.
>
>The script is `/demos/process_qm9.py`, and to run it you will need to download a few things from the QM9 dataset, and place them in a new folder `/demos/qm9_data_files/`. The items needed are:
>
>
>1) `qm9_data_files/dsgdb9nsd.xyz.tar.bz2` - The main file, a tarball which contains inside it the properties for each molecule.
>2) `qm9_data_files/uncharacterized.txt` - A list of molecules which failed a consistency check in the data generation process. These are usually excluded from machine learning analyses.
>3) `qm9_data_files/atomref.txt` - Reference atom values for different types of energy calculations.
>
>https://doi.org/10.6084/m9.figshare.978904
>
>Run the script using:
>
>    python preprocess_qm9.py
>    
>    
>The script takes approximately 1 minute to run. When complete, you should now have a file called `/demos/qm9_processed.csv`.

In [None]:
df = pd.read_csv('../qm9_processed.csv')

df['logp'] = df['smiles'].apply(lambda x: Chem.Descriptors.MolLogP(Chem.MolFromSmiles(x)))


mean = df['logp'].mean()
std = df['logp'].std()

plt.hist(df['logp'], bins=10, color='blue', edgecolor='black', alpha=0.7)
plt.title('Distribution of LogP values of QM9')
plt.xlabel('LogP')
plt.ylabel('Count')

num_std = 3

plt.axvline(mean, color='red', linestyle='--', linewidth=1.5, label=f'Mean: {mean:.2f}')
plt.axvline(mean + std*num_std, color='green', linestyle='--', linewidth=1.5, label=f'Mean + {num_std} Std: {mean + std*num_std:.2f}')
plt.axvline(mean - std*num_std, color='green', linestyle='--', linewidth=1.5, label=f'Mean - {num_std} Std: {mean - std*num_std:.2f}')
plt.legend()
plt.show()

If there are other chemical properties you'd like to consider in the search, they can easily be added to the MCTS state.

In this example, we also consider a target synthesizability score (SA score) in addition to a target logP value.

SA score is evaluated with ```SA_Score.sascorer.calculateScore()``` from RDKit's Contrib directory, which returns a value from 1 (easy to synthesize) to 10 (hard to synthesize).

In [5]:
sa_goal = 0 # target synthesizability score
sa_max = 10 # max allowed SA score value

The reward is calculated as a weighted average of the distances from the target logP and the target SA score.
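
As a rough illustration of how two properties can be folded into one number, the sketch below takes per-property scores that have already been scaled to the range 0 to 1 and combines them with a weighted average. The helper `weighted_reward`, the scores, and the weights are all hypothetical; the weighting minervachem actually applies is not shown in this notebook.

```python
def weighted_reward(rewards, weights):
    """Weighted average of per-property rewards (hypothetical helper)."""
    total_weight = sum(weights[name] for name in rewards)
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive number")
    return sum(rewards[name] * weights[name] for name in rewards) / total_weight


# Illustrative per-property rewards in [0, 1]; the numbers are made up.
rewards = {"logp": 0.9, "sa": 0.5}
# Assumed weighting: logP counts three times as much as SA score.
weights = {"logp": 3.0, "sa": 1.0}
combined = weighted_reward(rewards, weights)
```

Raising the weight of one property pulls the combined reward toward that property's score, which is the lever to use when one target matters more than the other.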

In this demonstration, we represent molecules as SMILES strings, and MCTS generates molecules by growing those strings one symbol at a time. You can set the "move pool" by designating the SMILES symbols MCTS is allowed to add to the string.

Note: ```\n``` is a terminal symbol.

In [6]:
# SMILES symbols that are choices for molecular generation
choices = ['C', 'O', '=', 'N', 'c', '1', 'S', 'P', 'F', '\n']
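
The idea of growing a SMILES string from a move pool can be sketched in plain Python. This is not how minervachem implements its rollouts; `random_rollout` is a hypothetical helper that appends random symbols until it draws the terminal symbol or reaches a length cap.

```python
import random

# Same move pool as in the cell above; "\n" ends the string.
choices = ['C', 'O', '=', 'N', 'c', '1', 'S', 'P', 'F', '\n']
TERMINAL = '\n'


def random_rollout(prefix, max_symbols, rng):
    """Append random symbols until the terminal symbol or the length cap."""
    smiles = prefix
    while len(smiles) < max_symbols:
        symbol = rng.choice(choices)
        if symbol == TERMINAL:
            break  # the terminal symbol closes the string and is not kept
        smiles += symbol
    return smiles


rng = random.Random(42)
candidates = [random_rollout("C", max_symbols=5, rng=rng) for _ in range(5)]
```

Many strings produced this way are not valid SMILES (an unclosed ring such as `C1`, for example), so any evaluator has to tolerate candidates that RDKit cannot parse.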

Set the desired number of levels of the tree search. In the context of molecule generation, this is also the number of SMILES symbols in the optimized molecule.

In [7]:
levels = 5

### MCTS-specific parameters

Now, you can set parameters that will influence MCTS's search behavior.

```num_sims``` is the computational budget for MCTS: the number of simulations to run.

```turn``` is a counter that tracks the number of levels of the tree search.

```beamsize``` is the number of top-k candidates to select. Our version of MCTS does not select the single best child; it selects the k best children.

```alpha``` is the constant in the Boltzmann distribution. To lessen the distance between large and small child values, and to increase the chance that a low-value child is selected for exploration, we apply a Boltzmann distribution to the raw values of the children.

In [8]:
num_sims = 10000 # number of simulations, default 10000; set smaller to test code

turn = levels + 4 # counter to track number of turns

beamsize = 10 # set the beamsize

alpha = 2 # Boltzmann constant
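
The effect of ```beamsize``` and ```alpha``` can be illustrated with a small standalone sketch. It assumes that ```alpha``` acts like a temperature in a softmax-style Boltzmann weighting; whether minervachem applies it in exactly this way is an assumption. The helpers `boltzmann_weights` and `top_k`, and the child values, are hypothetical.

```python
import heapq
import math


def boltzmann_weights(values, alpha):
    """Softmax-style weights; alpha is treated as a temperature (assumption)."""
    highest = max(values)  # subtract the max for numerical stability
    exps = [math.exp((v - highest) / alpha) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(children, values, k):
    """Return the k children with the highest values, best first."""
    ranked = heapq.nlargest(k, zip(values, children), key=lambda pair: pair[0])
    return [child for _, child in ranked]


# Made-up child values for four partial SMILES strings.
children = ["CC", "CO", "CN", "C="]
values = [4.0, 2.0, 1.0, 0.5]

sharp = boltzmann_weights(values, alpha=0.5)  # low temperature: best child dominates
flat = boltzmann_weights(values, alpha=2.0)   # higher temperature: weights move closer together
beam = top_k(children, values, k=2)           # keep the two best children
```

With the larger ```alpha```, the weakest child receives a noticeably bigger share of the probability mass, which is the exploration effect described above. The beam keeps several children alive instead of committing to a single branch.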

## 2: Running MCTS

With all user inputs set, we can now run MCTS to generate a batch of optimized molecules. With a fast evaluator like RDKit, this takes less than 30 seconds on an Apple M2 chip.

For further details on how MCTS works in general, refer to:
1. Browne, C. B., Powley, E., Whitehouse, D., Lucas, S. M., Cowling, P. I., Rohlfshagen, P., ... & Colton, S. (2012). A survey of Monte Carlo tree search methods. IEEE Transactions on Computational Intelligence and AI in Games, 4(1), 1-43.
2. Świechowski, M., Godlewski, K., Sawicki, B., & Mańdziuk, J. (2023). Monte Carlo tree search: A review of recent modifications and applications. Artificial Intelligence Review, 56(3), 2497-2562.
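
For orientation, the selection step of standard UCT, as covered in the surveys above, scores each child by its mean reward plus an exploration bonus. The sketch below shows that textbook UCB1-style rule. The `exploration` default of 0.4 mirrors the ```scalar=0.4``` argument passed to ```utcbeam``` in the next cell, but treating ```scalar``` as this exploration constant is an assumption, and `uct_score` is a hypothetical helper rather than part of minervachem.

```python
import math


def uct_score(total_reward, visits, parent_visits, exploration=0.4):
    """UCB1-style score: mean reward plus an exploration bonus."""
    if visits == 0:
        return math.inf  # unvisited children are tried first
    mean_reward = total_reward / visits
    bonus = exploration * math.sqrt(2.0 * math.log(parent_visits) / visits)
    return mean_reward + bonus


# Two children of a parent visited 100 times (made-up statistics).
well_explored = uct_score(total_reward=60.0, visits=80, parent_visits=100)
rarely_tried = uct_score(total_reward=3.0, visits=5, parent_visits=100)
```

Here the rarely tried child has the lower mean reward (0.6 versus 0.75) yet receives the higher score, because its exploration bonus is much larger. A larger exploration constant strengthens that pull toward under-visited branches.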

In [None]:
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')  # silence RDKit warnings triggered by invalid molecules

# Start from a single root node that holds the initial search state.
current_nodes = [Node(LogP(
    logp_target=logp_goal,
    sa_target=sa_goal,
    allchoices=choices,
    logp_max=logp_max,
    turn=turn))]
for level in range(levels):
    # Run MCTS from the current nodes and keep the best children.
    next_nodes = utcbeam(budget=num_sims, rootpop=current_nodes, beamsize=beamsize, alpha=alpha, scalar=0.4)
    for node in current_nodes:
        print(f"Level {level}")
        print(f"This is one of the current nodes: {node}")
        print(f"Num Children: {len(node.children)}")
        for j, child in enumerate(node.children):
            print(j, child)
    print("These are the best children:")
    for node in next_nodes:
        print(node)
    current_nodes = next_nodes
    print("--------------------------------")

We can display the graphical representations of the generated molecules.

In [None]:
allsmiles = []
allmols = []

for node in current_nodes:
    smiles = node.state.smiles
    allsmiles.append(smiles)  # kept in node order for the tree plots in section 3

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # RDKit returns None for SMILES it cannot parse
        continue
    allmols.append(mol)
    mol.SetProp("name", Chem.rdMolDescriptors.CalcMolFormula(mol))
    mol.SetProp("logP", f"{Chem.Descriptors.MolLogP(mol):.3f}")
    # mol.SetProp("SA score", f"{sascorer.calculateScore(mol):.3f}")

# Draw the molecules in a grid, labelled with their molecular formulas.
img = Chem.Draw.MolsToGridImage(
    allmols,
    legends=[mol.GetProp('name') for mol in allmols],
    subImgSize=(350, 350))

img

For further reading on similar work applying MCTS to molecular generation, refer to:

1. Yang, X., Zhang, J., Yoshizoe, K., Terayama, K., & Tsuda, K. (2017). ChemTS: an efficient Python library for de novo molecular generation. Science and Technology of Advanced Materials, 18(1), 972-976.
2. Yang, X., Aasawat, T. K., & Yoshizoe, K. (2020). Practical massively parallel Monte-Carlo tree search applied to molecular design. arXiv preprint arXiv:2006.10504.


## 3: Plot tree searches

To help understand MCTS's search behavior with the set parameters, we can plot each final molecule's tree. Each visualized tree shows node values of interest at every level. Options for plotting values are ```'visits'```, ```'reward'```, ```'logp'```, and ```'sa_score'```.

In [None]:
for i, node in enumerate(current_nodes):
    tuples = make_tree_nodes(node=node, size=levels)
    node_info = make_node_info(node=node, size=levels)
    # display(node_info)
    mass_plotting(node_info=node_info, params=['reward', 'logp'], tuples=tuples, smiles=allsmiles[i])