## Chemprop: Molecular Property Prediction

[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/chemprop)](https://badge.fury.io/py/chemprop)
[![PyPI version](https://badge.fury.io/py/chemprop.svg)](https://badge.fury.io/py/chemprop)
[![Build Status](https://github.com/chemprop/chemprop/workflows/tests/badge.svg)](https://github.com/chemprop/chemprop)

This repository contains message passing neural networks for molecular property prediction as described in the paper [Analyzing Learned Molecular Representations for Property Prediction](https://pubs.acs.org/doi/abs/10.1021/acs.jcim.9b00237) and as used in the paper [A Deep Learning Approach to Antibiotic Discovery](https://www.cell.com/cell/fulltext/S0092-8674(20)30102-1).

**Documentation:** Full documentation of Chemprop is available at https://chemprop.readthedocs.io/en/latest/.

**Website:** A web prediction interface with some trained Chemprop models is available at [chemprop.csail.mit.edu](http://chemprop.csail.mit.edu).

**Tutorial:** These [slides](https://docs.google.com/presentation/d/14pbd9LTXzfPSJHyXYkfLxnK8Q80LhVnjImg8a3WqCRM/edit?usp=sharing) provide a Chemprop tutorial and highlight recent additions as of April 28th, 2020.

In [None]:
!git clone https://github.com/chemprop/chemprop.git
%cd /content/chemprop

!pip install -e .
!pip install rdkit-pypi==2021.3.1.5

In [None]:
!tar -zxvf data.tar.gz

### Usage: Chemprop Train

In [13]:
import logging
from typing import Callable

from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm

from chemprop.args import TrainArgs
from chemprop.data import MoleculeDataLoader, MoleculeDataset
from chemprop.models import MoleculeModel
from chemprop.nn_utils import compute_gnorm, compute_pnorm, NoamLR

In [14]:
def train(model: MoleculeModel,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    loss_sum = iter_count = 0

    for batch in tqdm(data_loader, total=len(data_loader), leave=False):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch, data_weights_batch = \
            batch.batch_graph(), batch.features(), batch.targets(), batch.atom_descriptors(), \
            batch.atom_features(), batch.bond_features(), batch.data_weights()

        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

        if args.target_weights is not None:
            target_weights = torch.Tensor(args.target_weights)
        else:
            target_weights = torch.ones_like(targets)
        data_weights = torch.Tensor(data_weights_batch).unsqueeze(1)

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch)

        # Move tensors to correct device
        mask = mask.to(preds.device)
        targets = targets.to(preds.device)
        target_weights = target_weights.to(preds.device)
        data_weights = data_weights.to(preds.device)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1) for target_index in range(preds.size(1))], dim=1) * target_weights * data_weights * mask
        else:
            loss = loss_func(preds, targets) * target_weights * data_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += 1

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum = iter_count = 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
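
The masking step in `train` is easier to follow on a tiny example. Below is a minimal pure-Python sketch of it, not Chemprop's implementation: it assumes a squared-error loss, leaves out `target_weights` and `data_weights`, and uses a hypothetical helper named `masked_mean_loss`. It only mirrors how `None` targets are zero-filled, masked out, and left out of the normalising count.

```python
from typing import List, Optional


def masked_mean_loss(preds: List[List[float]],
                     targets: List[List[Optional[float]]]) -> float:
    """Mean squared error over the targets that are present (hypothetical helper)."""
    # 1.0 where a target exists, 0.0 where it is missing (None)
    mask = [[float(t is not None) for t in row] for row in targets]
    # Missing targets are replaced by 0 so the arithmetic is defined;
    # the mask removes their contribution afterwards
    filled = [[0.0 if t is None else t for t in row] for row in targets]

    total = 0.0
    for pred_row, target_row, mask_row in zip(preds, filled, mask):
        for p, t, m in zip(pred_row, target_row, mask_row):
            total += (p - t) ** 2 * m

    # Normalise by the number of present targets, like loss.sum() / mask.sum()
    n_present = sum(sum(row) for row in mask)
    return total / n_present


preds = [[0.5, 0.0], [1.0, 2.0]]
targets = [[1.0, None], [1.0, 0.0]]  # second task is unlabeled for the first molecule
loss = masked_mean_loss(preds, targets)
```

Because the missing entry is multiplied by a zero mask, changing the prediction for that entry leaves `loss` unchanged.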

In [17]:
!python train.py --data_path data/tox21.csv --dataset_type classification --save_dir tox21_checkpoints

Command line
python train.py --data_path data/tox21.csv --dataset_type classification --save_dir tox21_checkpoints
Args
{'activation': 'ReLU',
 'aggregation': 'mean',
 'aggregation_norm': 100,
 'atom_descriptor_scaling': True,
 'atom_descriptors': None,
 'atom_descriptors_path': None,
 'atom_descriptors_size': 0,
 'atom_features_size': 0,
 'atom_messages': False,
 'batch_size': 50,
 'bias': False,
 'bond_feature_scaling': True,
 'bond_features_path': None,
 'bond_features_size': 0,
 'cache_cutoff': 10000,
 'checkpoint_dir': None,
 'checkpoint_frzn': None,
 'checkpoint_path': None,
 'checkpoint_paths': None,
 'class_balance': False,
 'config_path': None,
 'crossval_index_dir': None,
 'crossval_index_file': None,
 'crossval_index_sets': None,
 'cuda': True,
 'data_path': 'data/tox21.csv',
 'data_weights_path': None,
 'dataset_type': 'classification',
 'depth': 3,
 'device': device(type='cuda'),
 'dropout': 0.0,
 'empty_cache': False,
 'ensemble_size': 1,
 'epochs': 30,
 'explicit_h': Fal

### Usage: Chemprop Predict

In [18]:
from typing import List

import torch
from tqdm import tqdm

from chemprop.data import MoleculeDataLoader, MoleculeDataset, StandardScaler
from chemprop.models import MoleculeModel

In [19]:
def predict(model: MoleculeModel,
            data_loader: MoleculeDataLoader,
            disable_progress_bar: bool = False,
            scaler: StandardScaler = None) -> List[List[float]]:
    """
    Makes predictions on a dataset using a single model.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param disable_progress_bar: Whether to disable the progress bar.
    :param scaler: A :class:`~chemprop.data.scaler.StandardScaler` object fit on the training targets.
    :return: A list of lists of predictions. The outer list is molecules while the inner list is tasks.
    """
    model.eval()

    preds = []

    for batch in tqdm(data_loader, disable=disable_progress_bar, leave=False):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch = \
            batch.batch_graph(), batch.features(), batch.atom_descriptors(), batch.atom_features(), batch.bond_features()

        # Make predictions
        with torch.no_grad():
            batch_preds = model(mol_batch, features_batch, atom_descriptors_batch,
                                atom_features_batch, bond_features_batch)

        batch_preds = batch_preds.data.cpu().numpy()

        # Inverse scale if regression
        if scaler is not None:
            batch_preds = scaler.inverse_transform(batch_preds)

        # Collect vectors
        batch_preds = batch_preds.tolist()
        preds.extend(batch_preds)

    return preds
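
For regression models, `predict` maps the network's outputs back to the original target scale via `scaler.inverse_transform`. The sketch below shows the idea with a hypothetical `TinyScaler` class. It assumes plain per-task standardisation (subtract the mean, divide by the standard deviation) and is not Chemprop's `StandardScaler`.

```python
from typing import List


class TinyScaler:
    """Hypothetical per-task standardiser, used only for illustration."""

    def __init__(self, means: List[float], stds: List[float]) -> None:
        self.means = means
        self.stds = stds

    def transform(self, rows: List[List[float]]) -> List[List[float]]:
        # Scale each task (column) to zero mean and unit variance
        return [[(x - m) / s for x, m, s in zip(row, self.means, self.stds)]
                for row in rows]

    def inverse_transform(self, rows: List[List[float]]) -> List[List[float]]:
        # Undo the scaling so predictions are in the original units
        return [[x * s + m for x, m, s in zip(row, self.means, self.stds)]
                for row in rows]


scaler = TinyScaler(means=[10.0, -2.0], stds=[2.0, 0.5])
scaled_preds = [[0.5, -1.0], [-1.0, 2.0]]  # what a model trained on scaled targets emits
restored = scaler.inverse_transform(scaled_preds)
```

Classification models are typically trained without target scaling, which is why `predict` skips this step when `scaler` is `None`.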

In [21]:
!python predict.py --test_path data/tox21.csv --checkpoint_dir tox21_checkpoints --preds_path tox21_preds.csv

Loading training args
Loading data
7831it [00:00, 159388.15it/s]
100% 7831/7831 [00:00<00:00, 348246.81it/s]
Validating SMILES
Test size = 7,831
Predicting with an ensemble of 1 models
  0% 0/1 [00:00<?, ?it/s]
Loading pretrained parameter "encoder.encoder.0.cached_zero_vector".
Loading pretrained parameter "encoder.encoder.0.W_i.weight".
Loading pretrained parameter "encoder.encoder.0.W_h.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.bias".
Loading pretrained parameter "ffn.1.weight".
Loading pretrained parameter "ffn.1.bias".
Loading pretrained parameter "ffn.4.weight".
Loading pretrained parameter "ffn.4.bias".
Moving model to cuda

  0% 0/157 [00:00<?, ?it/s]
  1% 1/157 [00:00<00:46,  3.37it/s]
  2% 3/157 [00:00<00:21,  7.18it/s]
  4% 7/157 [00:00<00:12, 11.87it/s]
 10% 15/157 [00:00<00:05, 24.53it/s]
 12% 19/157 [00:00<00:05, 27.23it/s]
 15% 24/157 [00:01<00:04, 32.62it/

### Usage: Chemprop Fingerprint

In [22]:
import csv
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from chemprop.args import PredictArgs, TrainArgs
from chemprop.data import get_data, get_data_from_smiles, MoleculeDataLoader, MoleculeDataset
from chemprop.utils import load_args, load_checkpoint, makedirs, timeit, load_scalers, update_prediction_args
from chemprop.features import set_reaction, set_explicit_h
from chemprop.models import MoleculeModel

In [23]:
@timeit()
def molecule_fingerprint(args: PredictArgs, smiles: List[List[str]] = None) -> List[List[Optional[float]]]:
    """
    Loads data and a trained model and uses the model to encode fingerprint vectors for the data.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param smiles: List of list of SMILES to make predictions on.
    :return: A list of fingerprint vectors (list of floats)
    """

    print('Loading training args')
    train_args = load_args(args.checkpoint_paths[0])

    # Update args with training arguments
    update_prediction_args(predict_args=args, train_args=train_args, validate_feature_sources=False)
    args: Union[PredictArgs, TrainArgs]

    # Set explicit H option and reaction option
    set_explicit_h(train_args.explicit_h)
    set_reaction(train_args.reaction, train_args.reaction_mode)

    print('Loading data')
    if smiles is not None:
        full_data = get_data_from_smiles(
            smiles=smiles,
            skip_invalid_smiles=False,
            features_generator=args.features_generator
        )
    else:
        full_data = get_data(path=args.test_path, smiles_columns=args.smiles_columns, target_columns=[], ignore_columns=[], skip_invalid_smiles=False,
                             args=args, store_row=True)

    print('Validating SMILES')
    full_to_valid_indices = {}
    valid_index = 0
    for full_index in range(len(full_data)):
        if all(mol is not None for mol in full_data[full_index].mol):
            full_to_valid_indices[full_index] = valid_index
            valid_index += 1

    test_data = MoleculeDataset([full_data[i] for i in sorted(full_to_valid_indices.keys())])

    # Edge case if empty list of smiles is provided
    if len(test_data) == 0:
        return [None] * len(full_data)

    print(f'Test size = {len(test_data):,}')

    # Create data loader
    test_data_loader = MoleculeDataLoader(
        dataset=test_data,
        batch_size=args.batch_size,
        num_workers=args.num_workers
    )

    # Load model
    print('Encoding smiles into a fingerprint vector from a single model')
    if len(args.checkpoint_paths) != 1:
        raise ValueError("Fingerprint generation only supports one model, cannot use an ensemble")

    model = load_checkpoint(args.checkpoint_paths[0], device=args.device)
    scaler, features_scaler, atom_descriptor_scaler, bond_feature_scaler = load_scalers(args.checkpoint_paths[0])

    # Normalize features
    if args.features_scaling or train_args.atom_descriptor_scaling or train_args.bond_feature_scaling:
        test_data.reset_features_and_targets()
        if args.features_scaling:
            test_data.normalize_features(features_scaler)
        if train_args.atom_descriptor_scaling and args.atom_descriptors is not None:
            test_data.normalize_features(atom_descriptor_scaler, scale_atom_descriptors=True)
        if train_args.bond_feature_scaling and args.bond_features_size > 0:
            test_data.normalize_features(bond_feature_scaler, scale_bond_features=True)

    # Make fingerprints
    model_preds = model_fingerprint(
        model=model,
        data_loader=test_data_loader
    )

    # Save predictions
    print(f'Saving predictions to {args.preds_path}')
    assert len(test_data) == len(model_preds)
    makedirs(args.preds_path, isfile=True)

    # Copy predictions over to full_data
    total_hidden_size = args.hidden_size * args.number_of_molecules
    # Column names are the same for every row, so build them once
    fingerprint_columns = [f'fp_{i}' for i in range(total_hidden_size)]
    for full_index, datapoint in enumerate(full_data):
        valid_index = full_to_valid_indices.get(full_index, None)
        preds = model_preds[valid_index] if valid_index is not None else ['Invalid SMILES'] * total_hidden_size

        for i, column in enumerate(fingerprint_columns):
            datapoint.row[column] = preds[i]

    # Write predictions (newline='' avoids blank rows on some platforms)
    with open(args.preds_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=args.smiles_columns + fingerprint_columns, extrasaction='ignore')
        writer.writeheader()
        for datapoint in full_data:
            writer.writerow(datapoint.row)

    return model_preds

def model_fingerprint(model: MoleculeModel,
                      data_loader: MoleculeDataLoader,
                      disable_progress_bar: bool = False) -> List[List[float]]:
    """
    Encodes the provided molecules into the latent fingerprint vectors, according to the provided model.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param disable_progress_bar: Whether to disable the progress bar.
    :return: A list of fingerprint vector lists.
    """
    model.eval()

    fingerprints = []

    for batch in tqdm(data_loader, disable=disable_progress_bar, leave=False):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, atom_descriptors_batch = batch.batch_graph(), batch.features(), batch.atom_descriptors()

        # Make predictions
        with torch.no_grad():
            batch_fp = model.fingerprint(mol_batch, features_batch, atom_descriptors_batch)

        # Collect vectors
        batch_fp = batch_fp.data.cpu().tolist()

        fingerprints.extend(batch_fp)

    return fingerprints

def chemprop_fingerprint() -> None:
    """
    Parses Chemprop prediction arguments and saves the latent representation vectors for
    provided molecules, according to a previously trained model.
    """
    molecule_fingerprint(args=PredictArgs().parse_args())
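
The bookkeeping in `molecule_fingerprint` that maps rows of the full dataset to rows of the valid-only dataset can be shown in isolation. This is a minimal sketch with a hypothetical helper `align_to_full`. It uses a list of booleans in place of parsed molecules and only mirrors the `full_to_valid_indices` logic and the `'Invalid SMILES'` placeholder from the function above.

```python
from typing import Dict, List, Union

Row = List[Union[float, str]]


def align_to_full(is_valid: List[bool],
                  valid_rows: List[List[float]],
                  width: int,
                  placeholder: str = 'Invalid SMILES') -> List[Row]:
    """Expand results for valid inputs back to one row per original input."""
    # Map each position in the full dataset to its position among valid inputs
    full_to_valid: Dict[int, int] = {}
    valid_index = 0
    for full_index, ok in enumerate(is_valid):
        if ok:
            full_to_valid[full_index] = valid_index
            valid_index += 1

    # Invalid inputs get a placeholder row of the same width
    aligned: List[Row] = []
    for full_index in range(len(is_valid)):
        mapped = full_to_valid.get(full_index)
        aligned.append(valid_rows[mapped] if mapped is not None else [placeholder] * width)
    return aligned


is_valid = [True, False, True]            # the second SMILES failed to parse
fingerprints = [[0.1, 0.2], [0.3, 0.4]]   # one vector per valid molecule
rows = align_to_full(is_valid, fingerprints, width=2)
```

Keeping one output row per input row means the written CSV lines up with the input file even when some SMILES are invalid.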

In [24]:
!python fingerprint.py --test_path data/tox21.csv --checkpoint_dir tox21_checkpoints --preds_path tox21_fingerprint.csv

Loading training args
Loading data
7831it [00:00, 157366.03it/s]
100% 7831/7831 [00:00<00:00, 343997.77it/s]
Validating SMILES
Test size = 7,831
Encoding smiles into a fingerprint vector from a single model
Loading pretrained parameter "encoder.encoder.0.cached_zero_vector".
Loading pretrained parameter "encoder.encoder.0.W_i.weight".
Loading pretrained parameter "encoder.encoder.0.W_h.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.bias".
Loading pretrained parameter "ffn.1.weight".
Loading pretrained parameter "ffn.1.bias".
Loading pretrained parameter "ffn.4.weight".
Loading pretrained parameter "ffn.4.bias".
Moving model to cuda
Saving predictions to tox21_fingerprint.csv
Elapsed time = 0:00:12


### Usage: Chemprop Interpret

In [25]:
import math
from typing import Callable, Dict, List, Set, Tuple

import numpy as np
from rdkit import Chem

from chemprop.args import InterpretArgs
from chemprop.data import get_data_from_smiles, get_header, get_smiles, MoleculeDataLoader, MoleculeDataset
from chemprop.train import predict
from chemprop.utils import load_args, load_checkpoint, load_scalers, timeit

MIN_ATOMS = 15
C_PUCT = 10

In [26]:
class ChempropModel:
    """A :class:`ChempropModel` is a wrapper around a :class:`~chemprop.models.model.MoleculeModel` for interpretation."""

    def __init__(self, args: InterpretArgs) -> None:
        """
        :param args: A :class:`~chemprop.args.InterpretArgs` object containing arguments for interpretation.
        """
        self.args = args
        self.train_args = load_args(args.checkpoint_paths[0])

        # If features were used during training, they must be used when predicting
        if ((self.train_args.features_path is not None or self.train_args.features_generator is not None)
                and args.features_generator is None):
            raise ValueError('Features were used during training so they must be specified again during prediction '
                             'using the same type of features as before (with --features_generator <generator> '
                             'and using --no_features_scaling if applicable).')

        if self.train_args.atom_descriptors_size > 0 or self.train_args.atom_features_size > 0 or self.train_args.bond_features_size > 0:
            raise NotImplementedError('The interpret function does not yet work with additional atom or bond features')

        self.scaler, self.features_scaler, self.atom_descriptor_scaler, self.bond_feature_scaler = load_scalers(args.checkpoint_paths[0])
        self.checkpoints = [load_checkpoint(checkpoint_path, device=args.device) for checkpoint_path in args.checkpoint_paths]

    def __call__(self, smiles: List[str], batch_size: int = 500) -> List[List[float]]:
        """
        Makes predictions on a list of SMILES.

        :param smiles: A list of SMILES to make predictions on.
        :param batch_size: The batch size.
        :return: A list of lists of floats containing the predicted values.
        """
        test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False, features_generator=self.args.features_generator)
        valid_indices = [i for i in range(len(test_data)) if all(mol is not None for mol in test_data[i].mol)]
        test_data = MoleculeDataset([test_data[i] for i in valid_indices])

        if self.train_args.features_scaling:
            test_data.normalize_features(self.features_scaler)
        if self.train_args.atom_descriptor_scaling and self.args.atom_descriptors is not None:
            test_data.normalize_features(self.atom_descriptor_scaler, scale_atom_descriptors=True)
        if self.train_args.bond_feature_scaling and self.args.bond_features_size > 0:
            test_data.normalize_features(self.bond_feature_scaler, scale_bond_features=True)

        test_data_loader = MoleculeDataLoader(dataset=test_data, batch_size=batch_size)

        sum_preds = []
        for model in self.checkpoints:
            model_preds = predict(
                model=model,
                data_loader=test_data_loader,
                scaler=self.scaler,
                disable_progress_bar=True
            )
            sum_preds.append(np.array(model_preds))

        # Ensemble predictions
        sum_preds = sum(sum_preds)
        avg_preds = sum_preds / len(self.checkpoints)

        return avg_preds

class MCTSNode:
    """A :class:`MCTSNode` represents a node in a Monte Carlo Tree Search."""

    def __init__(self, smiles: str, atoms: List[int], W: float = 0, N: int = 0, P: float = 0) -> None:
        """
        :param smiles: The SMILES for the substructure at this node.
        :param atoms: A list of atom indices represented by this node.
        :param W: The W value of this node.
        :param N: The N value of this node.
        :param P: The P value of this node.
        """
        self.smiles = smiles
        self.atoms = set(atoms)
        self.children = []
        self.W = W
        self.N = N
        self.P = P

    def Q(self) -> float:
        return self.W / self.N if self.N > 0 else 0

    def U(self, n: int) -> float:
        return C_PUCT * self.P * math.sqrt(n) / (1 + self.N)

def find_clusters(mol: Chem.Mol) -> Tuple[List[Tuple[int, ...]], List[List[int]]]:
    """
    Finds clusters within the molecule.

    :param mol: An RDKit molecule.
    :return: A tuple containing a list of atom tuples representing the clusters
             and a list of lists of atoms in each cluster.
    """
    n_atoms = mol.GetNumAtoms()
    if n_atoms == 1:  # special case
        return [(0,)], [[0]]

    clusters = []
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if not bond.IsInRing():
            clusters.append((a1, a2))

    ssr = [tuple(x) for x in Chem.GetSymmSSSR(mol)]
    clusters.extend(ssr)

    atom_cls = [[] for _ in range(n_atoms)]
    for i in range(len(clusters)):
        for atom in clusters[i]:
            atom_cls[atom].append(i)

    return clusters, atom_cls

def __extract_subgraph(mol: Chem.Mol, selected_atoms: Set[int]) -> Tuple[Chem.Mol, List[int]]:
    """
    Extracts a subgraph from an RDKit molecule given a set of atom indices.

    :param mol: An RDKit molecule from which to extract a subgraph.
    :param selected_atoms: The atoms which form the subgraph to be extracted.
    :return: A tuple containing an RDKit molecule representing the subgraph
             and a list of root atom indices from the selected indices.
    """
    selected_atoms = set(selected_atoms)
    roots = []
    for idx in selected_atoms:
        atom = mol.GetAtomWithIdx(idx)
        bad_neis = [y for y in atom.GetNeighbors() if y.GetIdx() not in selected_atoms]
        if len(bad_neis) > 0:
            roots.append(idx)

    new_mol = Chem.RWMol(mol)

    for atom_idx in roots:
        atom = new_mol.GetAtomWithIdx(atom_idx)
        atom.SetAtomMapNum(1)
        aroma_bonds = [bond for bond in atom.GetBonds() if bond.GetBondType() == Chem.rdchem.BondType.AROMATIC]
        aroma_bonds = [bond for bond in aroma_bonds if
                       bond.GetBeginAtom().GetIdx() in selected_atoms and bond.GetEndAtom().GetIdx() in selected_atoms]
        if len(aroma_bonds) == 0:
            atom.SetIsAromatic(False)

    remove_atoms = [atom.GetIdx() for atom in new_mol.GetAtoms() if atom.GetIdx() not in selected_atoms]
    remove_atoms = sorted(remove_atoms, reverse=True)
    for atom in remove_atoms:
        new_mol.RemoveAtom(atom)

    return new_mol.GetMol(), roots

def extract_subgraph(smiles: str, selected_atoms: Set[int]) -> Tuple[str, List[int]]:
    """
    Extracts a subgraph from a SMILES given a set of atom indices.

    :param smiles: A SMILES from which to extract a subgraph.
    :param selected_atoms: The atoms which form the subgraph to be extracted.
    :return: A tuple containing a SMILES representing the subgraph
             and a list of root atom indices from the selected indices.
    """
    # try with kekulization
    mol = Chem.MolFromSmiles(smiles)
    Chem.Kekulize(mol)
    subgraph, roots = __extract_subgraph(mol, selected_atoms)
    try:
        subgraph = Chem.MolToSmiles(subgraph, kekuleSmiles=True)
        subgraph = Chem.MolFromSmiles(subgraph)
    except Exception:
        subgraph = None

    mol = Chem.MolFromSmiles(smiles)  # de-kekulize
    if subgraph is not None and mol.HasSubstructMatch(subgraph):
        return Chem.MolToSmiles(subgraph), roots

    # If fails, try without kekulization
    subgraph, roots = __extract_subgraph(mol, selected_atoms)
    subgraph = Chem.MolToSmiles(subgraph)
    subgraph = Chem.MolFromSmiles(subgraph)

    if subgraph is not None:
        return Chem.MolToSmiles(subgraph), roots
    else:
        return None, None

def mcts_rollout(node: MCTSNode,
                 state_map: Dict[str, MCTSNode],
                 orig_smiles: str,
                 clusters: List[Set[int]],
                 atom_cls: List[Set[int]],
                 nei_cls: List[Set[int]],
                 scoring_function: Callable[[List[str]], List[float]]) -> float:
    """
    A Monte Carlo Tree Search rollout from a given :class:`MCTSNode`.

    :param node: The :class:`MCTSNode` from which to begin the rollout.
    :param state_map: A mapping from SMILES to :class:`MCTSNode`.
    :param orig_smiles: The original SMILES of the molecule.
    :param clusters: Clusters of atoms.
    :param atom_cls: Atom indices in the clusters.
    :param nei_cls: Neighboring clusters.
    :param scoring_function: A function for scoring subgraph SMILES using a Chemprop model.
    :return: The score of this MCTS rollout.
    """
    cur_atoms = node.atoms
    if len(cur_atoms) <= MIN_ATOMS:
        return node.P

    # Expand if this node has never been visited
    if len(node.children) == 0:
        cur_cls = {i for i, x in enumerate(clusters) if x <= cur_atoms}
        for i in cur_cls:
            leaf_atoms = [a for a in clusters[i] if len(atom_cls[a] & cur_cls) == 1]
            if len(nei_cls[i] & cur_cls) == 1 or (len(clusters[i]) == 2 and len(leaf_atoms) == 1):
                new_atoms = cur_atoms - set(leaf_atoms)
                new_smiles, _ = extract_subgraph(orig_smiles, new_atoms)
                if new_smiles in state_map:
                    new_node = state_map[new_smiles]  # merge identical states
                else:
                    new_node = MCTSNode(new_smiles, new_atoms)
                if new_smiles:
                    node.children.append(new_node)

        state_map[node.smiles] = node
        if len(node.children) == 0:
            return node.P  # cannot find leaves

        scores = scoring_function([[x.smiles] for x in node.children])
        for child, score in zip(node.children, scores):
            child.P = score

    sum_count = sum(c.N for c in node.children)
    selected_node = max(node.children, key=lambda x: x.Q() + x.U(sum_count))
    v = mcts_rollout(selected_node, state_map, orig_smiles, clusters, atom_cls, nei_cls, scoring_function)
    selected_node.W += v
    selected_node.N += 1

    return v

def mcts(smiles: str,
         scoring_function: Callable[[List[str]], List[float]],
         n_rollout: int,
         max_atoms: int,
         prop_delta: float) -> List[MCTSNode]:
    """
    Runs the Monte Carlo Tree Search algorithm.

    :param smiles: The SMILES of the molecule to perform the search on.
    :param scoring_function: A function for scoring subgraph SMILES using a Chemprop model.
    :param n_rollout: The number of MCTS rollouts to perform.
    :param max_atoms: The maximum number of atoms allowed in an extracted rationale.
    :param prop_delta: The minimum required property value for a satisfactory rationale.
    :return: A list of rationales each represented by a :class:`MCTSNode`.
    """
            
    mol = Chem.MolFromSmiles(smiles)
    if mol.GetNumAtoms() > 50:
        n_rollout = 1

    clusters, atom_cls = find_clusters(mol)
    nei_cls = [0] * len(clusters)
    for i, cls in enumerate(clusters):
        nei_cls[i] = [nei for atom in cls for nei in atom_cls[atom]]
        nei_cls[i] = set(nei_cls[i]) - {i}
        clusters[i] = set(list(cls))
    for a in range(len(atom_cls)):
        atom_cls[a] = set(atom_cls[a])

    root = MCTSNode(smiles, set(range(mol.GetNumAtoms())))
    state_map = {smiles: root}
    for _ in range(n_rollout):
        mcts_rollout(root, state_map, smiles, clusters, atom_cls, nei_cls, scoring_function)

    rationales = [node for _, node in state_map.items() if len(node.atoms) <= max_atoms and node.P >= prop_delta]

    return rationales

@timeit()
def interpret(args: InterpretArgs) -> None:
    """
    Runs interpretation of a Chemprop model using the Monte Carlo Tree Search algorithm.

    :param args: A :class:`~chemprop.args.InterpretArgs` object containing arguments for interpretation.
    """

    if args.number_of_molecules != 1:
        raise ValueError("Interpreting is currently only available for single-molecule models.")
    
    global C_PUCT, MIN_ATOMS

    chemprop_model = ChempropModel(args)

    def scoring_function(smiles: List[str]) -> List[float]:
        return chemprop_model(smiles)[:, args.property_id - 1]

    C_PUCT = args.c_puct
    MIN_ATOMS = args.min_atoms

    all_smiles = get_smiles(path=args.data_path, smiles_columns=args.smiles_columns)
    header = get_header(path=args.data_path)

    property_name = header[args.property_id] if len(header) > args.property_id else 'score'
    print(f'smiles,{property_name},rationale,rationale_score')

    for smiles in all_smiles:
        score = scoring_function([smiles])[0]
        if score > args.prop_delta:
            rationales = mcts(
                smiles=smiles[0],
                scoring_function=scoring_function,
                n_rollout=args.rollout,
                max_atoms=args.max_atoms,
                prop_delta=args.prop_delta
            )
        else:
            rationales = []

        if len(rationales) == 0:
            print(f'{smiles},{score:.3f},,')
        else:
            min_size = min(len(x.atoms) for x in rationales)
            min_rationales = [x for x in rationales if len(x.atoms) == min_size]
            rats = sorted(min_rationales, key=lambda x: x.P, reverse=True)
            print(f'{smiles},{score:.3f},{rats[0].smiles},{rats[0].P:.3f}')

def chemprop_interpret() -> None:
    """Runs interpretation of a Chemprop model.

    This is the entry point for the command line command :code:`chemprop_interpret`.
    """
    interpret(args=InterpretArgs().parse_args())
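
Child selection in `mcts_rollout` picks the node that maximises `Q() + U(n)`, where `Q` is the average value seen so far and `U` is an exploration bonus that shrinks as a node is visited more often. The sketch below reproduces just that rule with a hypothetical `Node` dataclass and `select_child` helper. The default `c_puct=10.0` is taken from the `C_PUCT = 10` constant above, and the statistics are made up for illustration.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for MCTSNode holding only the search statistics."""
    name: str
    W: float = 0.0  # total value accumulated over visits
    N: int = 0      # visit count
    P: float = 0.0  # prior: model score of this substructure

    def q(self) -> float:
        # Average value; 0 for unvisited nodes
        return self.W / self.N if self.N > 0 else 0.0

    def u(self, total_visits: int, c_puct: float) -> float:
        # Exploration bonus: large for high-prior, rarely visited nodes
        return c_puct * self.P * math.sqrt(total_visits) / (1 + self.N)


def select_child(children: List[Node], c_puct: float = 10.0) -> Node:
    total_visits = sum(c.N for c in children)
    return max(children, key=lambda c: c.q() + c.u(total_visits, c_puct))


children = [
    Node('a', W=2.7, N=3, P=0.9),  # well explored, high average value
    Node('b', W=0.5, N=1, P=0.6),  # less explored
]
chosen = select_child(children)
```

Here the less-visited node `b` wins because its exploration bonus outweighs the higher average value of `a`. When no child has been visited yet, `sqrt(0)` makes every bonus zero, so `max` returns the first child.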

In [27]:
!python interpret.py --data_path data/tox21.csv --checkpoint_dir tox21_checkpoints/fold_0/ --property_id 1

Streaming output truncated to the last 5000 lines.
[18:37:15] Aromatic bonds on non aromatic atom 8
[18:37:16] Aromatic bonds on non aromatic atom 7
[18:37:16] Aromatic bonds on non aromatic atom 7
[18:37:16] Aromatic bonds on non aromatic atom 7
[18:37:17] Aromatic bonds on non aromatic atom 4
[18:37:17] Aromatic bonds on non aromatic atom 5
[18:37:17] Aromatic bonds on non aromatic atom 7
[18:37:17] Aromatic bonds on non aromatic atom 8
[18:37:17] Aromatic bonds on non aromatic atom 8
[18:37:17] Aromatic bonds on non aromatic atom 7
[18:37:17] Aromatic bonds on non aromatic atom 8
[18:37:19] Aromatic bonds on non aromatic atom 7
[18:37:19] Aromatic bonds on non aromatic atom 8
[18:37:19] Aromatic bonds on non aromatic atom 8
[18:37:19] Aromatic bonds on non aromatic atom 7
[18:37:20] Aromatic bonds on non aromatic atom 4
[18:37:20] Aromatic bonds on non aromatic atom 6
[18:37:20] Aromatic bonds on non aromatic atom 8
[18:37:20] Aromatic bonds on non aromatic atom 7
[18: