# Notebook 2: Use MCTS with ML evaluators to generate chemical compounds

This notebook demonstrates how to adjust MCTS search parameters to generate molecules with a more sophisticated chemical evaluator, MinervaChem.

Seed Python's `random` module so that the MCTS results are reproducible.

In [1]:
import random
seed = 42
random.seed(seed)
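
Why a seed is enough: re-seeding a generator replays exactly the same sequence of random choices, so random rollouts over SMILES symbols repeat from run to run. The stdlib sketch below checks this with a private `random.Random` instance, which behaves like the module-level generator seeded the same way. It assumes the search draws from Python's `random` module; if minervachem also uses NumPy's generator, that would need its own seed. The function `random_rollout` is hypothetical and not part of minervachem.

```python
import random

# Symbol alphabet copied from the Section 2 parameter cell
choices = ['C', 'O', '=', 'N', 'c', '1', 'F', '\n']

def random_rollout(seed, length=5):
    """Hypothetical helper: draw `length` random symbols from a generator seeded with `seed`."""
    rng = random.Random(seed)  # private generator, so the notebook's global state is untouched
    return [rng.choice(choices) for _ in range(length)]

# Same seed, same sequence of choices
print(random_rollout(42) == random_rollout(42))  # True
```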

In [2]:
from minervachem.mcts import Node, BondEnergy, utcbeam
from minervachem.mcts import make_tree_nodes, make_node_info, mass_plotting

import pandas as pd
from rdkit import Chem

## 1: Creating a MinervaChem model

First, we need to create a MinervaChem model to be used as an evaluator for MCTS.

The model is built from the ```qm9``` dataset. To prepare the dataset, follow the instructions in section 1 of ```/demos/2_Regression_and_visualization.ipynb```, which are also copied below:

>We'll be working with the QM9 dataset:
>
>Ramakrishnan, R., Dral, P., Rupp, M. et al. Quantum chemistry structures and properties of 134 kilo molecules. Sci Data 1, 140022 (2014). https://doi.org/10.1038/sdata.2014.22
>
>This dataset is publicly available, and we provide a Python script to preprocess it into a CSV file that is convenient to work with.
>
>The script is `/demos/process_qm9.py`. To run it, you will need to download three files from the QM9 dataset and place them in a new folder `/demos/qm9_data_files/`. The items needed are:
>
>1) `qm9_data_files/dsgdb9nsd.xyz.tar.bz2` - The main file, a tarball which contains inside it the properties for each molecule.
>2) `qm9_data_files/uncharacterized.txt` - A list of molecules which failed a consistency check in the data generation process. These are usually excluded from machine learning analyses.
>3) `qm9_data_files/atomref.txt` - Reference atom values for different types of energy calculations.
>
>The files can be downloaded from https://doi.org/10.6084/m9.figshare.978904
>
>Run the script using:
>
>    python preprocess_qm9.py
>
>The script takes approximately 1 minute to run. When it finishes, you should have a file called `/demos/qm9_processed.csv`.

Once the dataset exists, build the MinervaChem model as a pipeline object using the ```create_pipeline``` function from ```create_pipeline.py```. On an Apple M2 this takes about 3 minutes.

In [None]:
from create_pipeline import create_pipeline

df = pd.read_csv('../qm9_processed.csv')
df['mol'] = df['smiles'].map(lambda s: Chem.AddHs(Chem.MolFromSmiles(s)))

pipeline = create_pipeline(df)

## 2: Setting MCTS search parameters

As in Notebook 1, MCTS's search parameters can be tuned for a specific scenario (see Notebook 1 for a full explanation of each parameter).

In this example, we use a MinervaChem ML model to optimize for atomization energy and synthetic accessibility (SA) score.
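
Atomization energy is the energy of a molecule minus the energies of its isolated atoms, which is what the per-element reference values in `atomref.txt` are for. A bound molecule therefore has a negative atomization energy, consistent with the negative targets used below. The sketch computes it from atom counts. The helper `atomization_energy` is hypothetical, whether `process_qm9.py` derives its training target exactly this way is an assumption, and the numbers in the example are placeholders rather than QM9 data.

```python
from collections import Counter

def atomization_energy(molecular_energy, atom_counts, atom_ref):
    """Hypothetical helper: E_at = E_mol - sum_i n_i * E_atom(i).

    molecular_energy: total energy of the molecule
    atom_counts: element symbol -> number of atoms of that element
    atom_ref: element symbol -> energy of the isolated atom (same units)
    """
    missing = set(atom_counts) - set(atom_ref)
    if missing:
        raise KeyError(f"no reference energy for: {sorted(missing)}")
    reference_sum = sum(n * atom_ref[el] for el, n in atom_counts.items())
    return molecular_energy - reference_sum

# Placeholder numbers for illustration only (not QM9 values)
counts = Counter({'C': 1, 'H': 4})
refs = {'C': -10.0, 'H': -1.0}
print(atomization_energy(-15.0, counts, refs))  # -1.0
```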

In [4]:
e_at_target = -2000  # target atomization energy
e_at_max = -1000     # maximum atomization energy

sa_target = 0  # target synthetic accessibility (SA) score
sa_max = 5     # maximum SA score

# SMILES symbols to choose from during molecular generation
choices = ['C', 'O', '=', 'N', 'c', '1', 'F', '\n']

levels = 5  # number of search levels (controls molecule size)

num_sims = 10000   # number of simulations (default 10000); use a smaller value to test the code
turn = levels + 4  # counter that tracks the number of turns
beamsize = 10      # number of nodes kept at each level of the beam search
alpha = 2          # Boltzmann weighting parameter (see Notebook 1)
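
The `alpha` parameter above is described in terms of a Boltzmann factor. A common way such a parameter enters a search is as a temperature in Boltzmann (softmax) weighting of candidate scores: larger values flatten the distribution and encourage exploration. Whether `utcbeam` uses `alpha` exactly like this is an assumption; the function below is a generic illustration, not minervachem code.

```python
import math

def boltzmann_weights(scores, alpha):
    """Selection probabilities p_i proportional to exp(score_i / alpha).

    Larger alpha -> flatter distribution (more exploration);
    smaller alpha -> probability concentrates on the best score.
    """
    if alpha <= 0:
        raise ValueError("alpha must be positive")
    best = max(scores)
    # Subtract the max before exponentiating for numerical stability
    exps = [math.exp((s - best) / alpha) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [1.0, 0.5, 0.0]  # hypothetical child rewards
print(boltzmann_weights(scores, alpha=2))
print(boltzmann_weights(scores, alpha=0.1))
```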

## 3: Running MCTS with MinervaChem

Because MinervaChem is a more expensive evaluator than RDKit, MCTS takes much longer to run: under 30 seconds with RDKit, versus 5-10 minutes for this cell on an Apple M2.
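
In a generic MCTS, each of the `num_sims` simulations descends the tree by repeatedly choosing a child and ends with an evaluation, which is where a costlier model adds up. The textbook rule for the descent step is UCT (UCB1 applied to trees): pick the child that maximizes its mean reward plus an exploration bonus. The sketch below implements that generic rule. Whether `utcbeam` implements exactly this, and whether its `scalar` argument is the exploration constant `c`, are assumptions; the `Child` class is hypothetical.

```python
import math
from dataclasses import dataclass

@dataclass
class Child:
    """Hypothetical stand-in for a child node's statistics."""
    name: str
    visits: int = 0
    total_reward: float = 0.0

def uct_score(child, parent_visits, c):
    """UCB1 score: mean reward plus c * sqrt(ln N_parent / n_child)."""
    if child.visits == 0:
        return math.inf  # always try unvisited children first
    mean = child.total_reward / child.visits
    return mean + c * math.sqrt(math.log(parent_visits) / child.visits)

def select_child(children, c=0.4):  # c=0.4 mirrors scalar=0.4 (assumed correspondence)
    # Approximate the parent's visit count by the sum of its children's visits
    parent_visits = sum(ch.visits for ch in children)
    return max(children, key=lambda ch: uct_score(ch, parent_visits, c))

children = [Child('C', visits=10, total_reward=6.0),
            Child('O', visits=2, total_reward=1.0),
            Child('N')]
print(select_child(children).name)  # N (unvisited)
```

With `c=0` the rule is purely greedy and picks `'C'` (higher mean reward); with `c=0.4` the less-visited `'O'` wins because of its larger exploration bonus.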

In [None]:
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')  # silence RDKit warnings about invalid intermediate molecules

# Start the search from a single root node that uses the MinervaChem pipeline as its evaluator
current_nodes = [Node(BondEnergy(
    e_at_target=e_at_target,
    sa_target=sa_target,
    allchoices=choices,
    e_at_max=e_at_max,
    sa_max=sa_max,
    turn=turn,
    model=pipeline))]

for level in range(levels):
    # Run num_sims simulations from the current beam and keep the best `beamsize` children
    next_nodes = utcbeam(budget=num_sims, rootpop=current_nodes, beamsize=beamsize, alpha=alpha, scalar=0.4)

    print(f"Level {level}")
    for node in current_nodes:
        print(f"This is one of the current nodes: {node}")
        print(f"Num children: {len(node.children)}")
        for j, child in enumerate(node.children):
            print(j, child)

    print("These are the best children:")
    for node in next_nodes:
        print(node)

    current_nodes = next_nodes  # the best children become the next level's beam
    print("--------------------------------")
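
Stripped of the tree search, the loop above is a level-by-level beam search: every node in the beam is extended, and only the `beamsize` best results survive to the next level. The toy sketch below shows that skeleton with a made-up scoring function over SMILES symbols (without the `'\n'` end symbol). It runs no MCTS and is not `utcbeam`; `toy_score` and `beam_search` are hypothetical.

```python
import heapq

CHOICES = ['C', 'O', '=', 'N', 'c', '1', 'F']

def toy_score(s):
    """Made-up score for illustration: reward 'C', penalize '='."""
    return s.count('C') - 2 * s.count('=')

def beam_search(levels, beamsize, score=toy_score):
    beam = ['']  # start from the empty string
    for _ in range(levels):
        # Extend every string in the beam by every symbol
        candidates = [s + sym for s in beam for sym in CHOICES]
        # Keep only the `beamsize` highest-scoring candidates
        beam = heapq.nlargest(beamsize, candidates, key=score)
    return beam

print(beam_search(levels=3, beamsize=2))
```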

We can draw the molecules in the final beam as a grid, each labeled with its molecular formula.

In [None]:
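from rdkit.Chem import Draw, rdMolDescriptors  # submodules are not loaded by `from rdkit import Chem`

allsmiles = []
allmols = []

for node in current_nodes:
    smiles = node.state.smiles
    allsmiles.append(smiles)  # keep every SMILES so indices match current_nodes

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # skip anything RDKit cannot parse
        continue
    mol.SetProp("name", rdMolDescriptors.CalcMolFormula(mol))
    allmols.append(mol)

img = Draw.MolsToGridImage(
    allmols,
    legends=[mol.GetProp("name") for mol in allmols],
    subImgSize=(350, 350))

img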
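
The legends come from RDKit's `CalcMolFormula`, which follows the Hill convention for carbon-containing molecules: carbon first, hydrogen second, then the remaining elements alphabetically. The stdlib sketch below reproduces the Hill convention from element counts, including the no-carbon case where every element, hydrogen included, is sorted alphabetically. It is an illustration, not RDKit's implementation; `hill_formula` is hypothetical.

```python
def hill_formula(counts):
    """Format element counts (e.g. {'C': 2, 'H': 6, 'O': 1}) in Hill order."""
    def term(el):
        n = counts[el]
        return el if n == 1 else f"{el}{n}"

    elements = [el for el, n in counts.items() if n > 0]
    if 'C' in elements:
        # Carbon first, hydrogen second, everything else alphabetical
        rest = sorted(el for el in elements if el not in ('C', 'H'))
        order = ['C'] + (['H'] if 'H' in elements else []) + rest
    else:
        # No carbon: all elements alphabetical, hydrogen included
        order = sorted(elements)
    return ''.join(term(el) for el in order)

print(hill_formula({'C': 2, 'H': 6, 'O': 1}))  # C2H6O
print(hill_formula({'N': 1, 'H': 3}))          # H3N
```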

## 4: Plotting tree searches

We can visualize the search paths with the same visualization module used in Notebook 1. The value to plot is chosen with the ```params``` argument of ```mass_plotting```; the options are ```'visits'```, ```'reward'```, ```'e_at'```, and ```'sa_score'```.

In [None]:
for node, smiles in zip(current_nodes, allsmiles):
    tuples = make_tree_nodes(node=node, size=levels)
    node_info = make_node_info(node=node, size=levels)
    # display(node_info)  # uncomment to inspect node_info
    mass_plotting(node_info=node_info, params=['e_at'], tuples=tuples, smiles=smiles)
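
Judging by their names and arguments, `make_tree_nodes` and `make_node_info` turn the search tree around each final node, down to `levels` deep, into edges and per-node values that `mass_plotting` can draw. The sketch below shows one generic way to do that flattening: a breadth-first walk that records `(parent, child)` edges and each node's visit count. The `TreeNode` class and `flatten_tree` function are hypothetical stand-ins for minervachem's `Node` and helpers, and the real functions' return formats may differ.

```python
from collections import deque
from dataclasses import dataclass, field

@dataclass
class TreeNode:
    """Hypothetical stand-in for an MCTS node."""
    smiles: str
    visits: int = 0
    children: list = field(default_factory=list)

def flatten_tree(root, max_depth):
    """Breadth-first walk returning (parent, child) edges and visits keyed by SMILES."""
    edges, visits = [], {root.smiles: root.visits}
    queue = deque([(root, 0)])
    while queue:
        node, depth = queue.popleft()
        if depth == max_depth:
            continue  # do not expand past the requested depth
        for child in node.children:
            edges.append((node.smiles, child.smiles))
            visits[child.smiles] = child.visits  # duplicate SMILES would overwrite; fine for a sketch
            queue.append((child, depth + 1))
    return edges, visits

root = TreeNode('', 10, [TreeNode('C', 7, [TreeNode('CC', 4)]), TreeNode('O', 3)])
edges, visits = flatten_tree(root, max_depth=2)
print(edges)  # [('', 'C'), ('', 'O'), ('C', 'CC')]
```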