## Import packages

In [None]:
from dataclasses import dataclass, field
import math
from pathlib import Path
import time
from typing import Callable, Union, Iterable

from lightning import pytorch as pl
import numpy as np
import pandas as pd
from rdkit import Chem
import torch
import os
import chemprop

from chemprop import data, featurizers, models
from chemprop.models import MPNN

## Define helper function to make model predictions from SMILES

In [130]:
def make_prediction(
    models: list[MPNN],
    trainer: pl.Trainer,
    smiles: list[str],
) -> np.ndarray:
    """Predicts with every model on a list of SMILES and averages the results.

    Parameters
    ----------
    models : list of MPNN
        The models to make predictions with. Must not be empty.
    trainer : pl.Trainer
        The Lightning trainer used to run the prediction loop.
    smiles : list of str
        The SMILES to make predictions on.

    Returns
    -------
    np.ndarray
        The element-wise mean of the individual models' predictions. The
        loader is not shuffled and batches are concatenated along axis 0, so
        the first axis follows the order of ``smiles``.

    Raises
    ------
    ValueError
        If ``models`` is empty.
    """
    if not models:
        # Without this guard the averaging below fails with an opaque ZeroDivisionError.
        raise ValueError("At least one model is required to make predictions.")

    test_data = [data.MoleculeDatapoint.from_smi(smi) for smi in smiles]
    test_dset = data.MoleculeDataset(test_data)
    test_loader = data.build_dataloader(
        test_dset, batch_size=1, num_workers=0, shuffle=False
    )

    with torch.inference_mode():
        per_model_preds = [
            torch.cat(trainer.predict(model, test_loader), 0).cpu().numpy()
            for model in models
        ]

    # Ensemble predictions: plain average over the models
    return sum(per_model_preds) / len(models)

## Classes/functions relevant to Monte Carlo Tree Search

Adapted from Chemprop v1's [interpret.py](https://github.com/chemprop/chemprop/blob/master/chemprop/interpret.py), with additional documentation.

In [131]:
@dataclass
class MCTSNode:
    """Represents a node in a Monte Carlo Tree Search.

    Parameters
    ----------
    smiles : str
        The SMILES for the substructure at this node.
    atoms : iterable of int
        The atom indices in the substructure at this node. Stored as a set.
    W : float
        The total action value, which indicates how likely the deletion will lead to a good rationale.
    N : int
        The visit count, which indicates how many times this node has been visited. It is used to balance exploration and exploitation.
    P : float
        The predicted property score of the new subgraph after the deletion, shown as R in the original paper.
    children : list of MCTSNode
        The nodes reachable from this node. Defaults to a fresh empty list per instance.
    """

    smiles: str
    atoms: Iterable[int]
    W: float = 0
    N: int = 0
    P: float = 0
    children: list["MCTSNode"] = field(default_factory=list)

    def __post_init__(self):
        # Normalize whatever iterable was passed in to a set, dropping duplicates.
        self.atoms = set(self.atoms)

    def Q(self) -> float:
        """
        Returns
        -------
        float
            The mean action value of the node, or 0.0 if it has never been visited.
        """
        return self.W / self.N if self.N > 0 else 0.0

    def U(self, n: int, c_puct: float = 10.0) -> float:
        """
        Parameters
        ----------
        n : int
            The sum of the visit count of this node's siblings.
        c_puct : float
            A constant that controls the level of exploration.

        Returns
        -------
        float
            The exploration value of the node.
        """
        return c_puct * self.P * math.sqrt(n) / (1 + self.N)

In [132]:
def find_clusters(mol: Chem.Mol) -> tuple[list[tuple[int, ...]], list[list[int]]]:
    """Splits a molecule into clusters of atoms.

    Jin et al. [1]_ only allow deletion of one peripheral non-aromatic bond or one
    peripheral ring from each state, so a cluster here is either a non-ring bond
    or a ring from the smallest set of smallest rings.

    Parameters
    ----------
    mol : RDKit molecule
        The molecule to find clusters in.

    Returns
    -------
    tuple
        A tuple containing:
        - list of tuples: The atom indices of each cluster, non-ring bonds first, then rings.
        - list of lists of int: For each atom, the indices of the clusters it belongs to.

    References
    ----------
    .. [1] Jin, Wengong, Regina Barzilay, and Tommi Jaakkola. "Multi-objective molecule generation using interpretable substructures." International conference on machine learning. PMLR, 2020. https://arxiv.org/abs/2002.03244
    """
    atom_count = mol.GetNumAtoms()
    if atom_count == 1:
        # Special case: a lone atom is its own (only) cluster.
        return [(0,)], [[0]]

    # Each bond outside a ring is a two-atom cluster ...
    clusters = [
        (bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx())
        for bond in mol.GetBonds()
        if not bond.IsInRing()
    ]
    # ... and each ring contributes one cluster holding all of its atoms.
    clusters += [tuple(ring) for ring in Chem.GetSymmSSSR(mol)]

    # Invert the mapping: atom index -> indices of the clusters containing it.
    memberships = [[] for _ in range(atom_count)]
    for cluster_idx, members in enumerate(clusters):
        for atom_idx in members:
            memberships[atom_idx].append(cluster_idx)

    return clusters, memberships

In [133]:
def extract_subgraph_from_mol(mol: Chem.Mol, selected_atoms: set[int]) -> tuple[Chem.Mol, list[int]]:
    """Builds the sub-molecule spanned by a set of atom indices.

    Parameters
    ----------
    mol : RDKit molecule
        The molecule from which to extract a subgraph. It is copied, not modified.
    selected_atoms : set of int
        The indices of atoms which form the subgraph to be extracted.

    Returns
    -------
    tuple
        A tuple containing:
        - RDKit molecule: The subgraph.
        - list of int: Root atom indices, i.e. the selected atoms (indexed as in
          ``mol``) that have at least one neighbor outside the selection.
    """
    keep = set(selected_atoms)

    # A root is a kept atom bonded to at least one atom that will be removed.
    roots = [
        idx
        for idx in keep
        if any(nei.GetIdx() not in keep for nei in mol.GetAtomWithIdx(idx).GetNeighbors())
    ]

    editable = Chem.RWMol(mol)

    for root_idx in roots:
        root_atom = editable.GetAtomWithIdx(root_idx)
        root_atom.SetAtomMapNum(1)  # tag the attachment point
        has_kept_aromatic_bond = any(
            bond.GetBondType() == Chem.rdchem.BondType.AROMATIC
            and bond.GetBeginAtom().GetIdx() in keep
            and bond.GetEndAtom().GetIdx() in keep
            for bond in root_atom.GetBonds()
        )
        # A root left without any aromatic bond inside the selection loses its aromatic flag.
        if not has_kept_aromatic_bond:
            root_atom.SetIsAromatic(False)

    # Removal runs from the highest index down; presumably so that pending
    # indices stay valid while atoms are being deleted.
    doomed = [a.GetIdx() for a in editable.GetAtoms() if a.GetIdx() not in keep]
    for atom_idx in sorted(doomed, reverse=True):
        editable.RemoveAtom(atom_idx)

    return editable.GetMol(), roots

In [134]:
def extract_subgraph(
    smiles: str, selected_atoms: set[int]
) -> tuple[Union[str, None], Union[list[int], None]]:
    """Extracts a subgraph from a SMILES given a set of atom indices.

    The extraction is first attempted on a kekulized copy of the molecule; if
    that result cannot be parsed back or is not a substructure of the original
    molecule, it is retried without kekulization.

    Parameters
    ----------
    smiles : str
        The SMILES string from which to extract a subgraph.
    selected_atoms : set of int
        The indices of atoms which form the subgraph to be extracted.

    Returns
    -------
    tuple
        A tuple containing:
        - str: SMILES representing the subgraph.
        - list of int: Root atom indices from the selected indices.

        ``(None, None)`` is returned when ``smiles`` cannot be parsed or when no
        parsable subgraph SMILES could be produced.
    """
    # try with kekulization
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparsable input: report failure the same way as a failed extraction
        # instead of passing None on to Chem.Kekulize.
        return None, None
    Chem.Kekulize(mol)
    subgraph, roots = extract_subgraph_from_mol(mol, selected_atoms)
    try:
        subgraph = Chem.MolToSmiles(subgraph, kekuleSmiles=True)
        subgraph = Chem.MolFromSmiles(subgraph)
    except Exception:
        # Best effort: any failure here just triggers the fallback below.
        subgraph = None

    mol = Chem.MolFromSmiles(smiles)  # de-kekulize
    if subgraph is not None and mol.HasSubstructMatch(subgraph):
        return Chem.MolToSmiles(subgraph), roots

    # If fails, try without kekulization
    subgraph, roots = extract_subgraph_from_mol(mol, selected_atoms)
    subgraph = Chem.MolFromSmiles(Chem.MolToSmiles(subgraph))
    if subgraph is None:
        return None, None

    return Chem.MolToSmiles(subgraph), roots

In [135]:
def mcts_rollout(
    node: MCTSNode,
    state_map: dict[str, MCTSNode],
    orig_smiles: str,
    clusters: list[set[int]],
    atom_cls: list[set[int]],
    nei_cls: list[set[int]],
    scoring_function: Callable[[list[str]], list[float]],
    min_atoms: int = 15,
    c_puct: float = 10.0,
) -> float:
    """Performs one recursive Monte Carlo Tree Search rollout starting at ``node``.

    On the first visit of a node its children are generated by deleting one
    removable cluster at a time, then scored in a single call to
    ``scoring_function``. The child maximizing ``Q + U`` is followed
    recursively and its statistics (``W``, ``N``) are updated on the way back.

    Parameters
    ----------
    node : MCTSNode
        The MCTSNode from which to begin the rollout. Children are appended to it in place.
    state_map : dict
        A mapping from SMILES to MCTSNode. Expanded nodes are recorded in it.
    orig_smiles : str
        The original SMILES of the molecule.
    clusters : list
        For each cluster, the set of atom indices it contains.
    atom_cls : list
        For each atom, the set of indices of the clusters it belongs to.
    nei_cls : list
        For each cluster, the set of indices of its neighboring clusters.
    scoring_function : function
        A function for scoring subgraph SMILES using a Chemprop model.
    min_atoms : int
        The rollout stops at nodes with at most this many atoms.
    c_puct : float
        The constant controlling the level of exploration.

    Returns
    -------
    float
        The score of this MCTS rollout.
    """
    atoms_here = node.atoms

    # Terminal state: the subgraph is already small enough.
    if len(atoms_here) <= min_atoms:
        return node.P

    # First visit: expand the node.
    if not node.children:
        # Clusters that are still fully contained in the current subgraph.
        present = {idx for idx, members in enumerate(clusters) if members <= atoms_here}

        for idx in present:
            members = clusters[idx]
            # Leaf atoms belong to exactly one of the clusters still present.
            leaves = [a for a in members if len(atom_cls[a] & present) == 1]

            # A cluster may be deleted when
            # 1. exactly one of its neighbor clusters is still present (so the
            #    remaining graph stays connected), or
            # 2. it is a two-atom cluster with exactly one leaf atom.
            single_neighbor = len(nei_cls[idx] & present) == 1
            dangling_bond = len(members) == 2 and len(leaves) == 1
            if not (single_neighbor or dangling_bond):
                continue

            remaining = atoms_here - set(leaves)
            child_smiles, _ = extract_subgraph(orig_smiles, remaining)
            if not child_smiles:
                continue  # extraction failed; nothing to add

            child = state_map.get(child_smiles)  # merge identical states
            if child is None:
                child = MCTSNode(child_smiles, remaining)
            node.children.append(child)

        state_map[node.smiles] = node
        if not node.children:
            return node.P  # cannot find leaves

        child_scores = scoring_function([c.smiles for c in node.children])
        for child, child_score in zip(node.children, child_scores):
            child.P = child_score

    # Selection: follow the child with the best exploitation + exploration value.
    total_visits = sum(c.N for c in node.children)

    def selection_value(candidate: MCTSNode) -> float:
        return candidate.Q() + candidate.U(total_visits, c_puct=c_puct)

    chosen = max(node.children, key=selection_value)
    value = mcts_rollout(
        chosen,
        state_map,
        orig_smiles,
        clusters,
        atom_cls,
        nei_cls,
        scoring_function,
        min_atoms=min_atoms,
        c_puct=c_puct,
    )

    # Backup: credit the followed child with the rollout's value.
    chosen.W += value
    chosen.N += 1

    return value

In [136]:
def mcts(
    smiles: str,
    scoring_function: Callable[[list[str]], list[float]],
    n_rollout: int,
    max_atoms: int,
    prop_delta: float,
    min_atoms: int = 15,
    c_puct: float = 10.0,
) -> list[MCTSNode]:
    """Runs the Monte Carlo Tree Search algorithm.

    Parameters
    ----------
    smiles : str
        The SMILES of the molecule to perform the search on.
    scoring_function : function
        A function for scoring subgraph SMILES using a Chemprop model.
    n_rollout : int
        The number of MCTS rollouts to perform.
    max_atoms : int
        The maximum number of atoms allowed in an extracted rationale.
    prop_delta : float
        The minimum required property value for a satisfactory rationale.
    min_atoms : int
        The minimum number of atoms in a subgraph.
    c_puct : float
        The constant controlling the level of exploration.

    Returns
    -------
    list
        A list of rationales each represented by a MCTSNode: every recorded
        state with at most ``max_atoms`` atoms and a score of at least
        ``prop_delta``.

    Raises
    ------
    ValueError
        If ``smiles`` cannot be parsed by RDKit.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Fail here with a clear message rather than later with an unrelated AttributeError.
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    # Convert the cluster bookkeeping to sets, which the rollout's subset and
    # intersection tests operate on.
    clusters, atom_cls = find_clusters(mol)
    clusters = [set(cls) for cls in clusters]
    atom_cls = [set(cls_ids) for cls_ids in atom_cls]

    # Neighbors of a cluster: every other cluster sharing at least one atom with it.
    nei_cls = [
        {nei for atom in cls for nei in atom_cls[atom]} - {i}
        for i, cls in enumerate(clusters)
    ]

    root = MCTSNode(smiles, set(range(mol.GetNumAtoms())))
    state_map = {smiles: root}
    for _ in range(n_rollout):
        mcts_rollout(
            root,
            state_map,
            smiles,
            clusters,
            atom_cls,
            nei_cls,
            scoring_function,
            min_atoms=min_atoms,
            c_puct=c_puct,
        )

    return [
        node
        for node in state_map.values()
        if len(node.atoms) <= max_atoms and node.P >= prop_delta
    ]

## Load model

In [None]:
# Parent of the current working directory.
chemprop_dir = Path.cwd().parent

# Path to model checkpoint (.ckpt) or model file (.pt).
model_path = "../../results/models/model_GCN_cleaned_finetuned.ckpt"

In [138]:
# Load through the directly imported MPNN class rather than `models.MPNN`:
# a later cell rebinds the name `models`, which would break this cell on re-run.
mpnn = MPNN.load_from_file(model_path)
mpnn

  d = torch.load(model_path, map_location=map_location)


MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=1100, bias=False)
    (W_h): Linear(in_features=1100, out_features=1100, bias=False)
    (W_o): Linear(in_features=1172, out_features=1100, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): GraphTransform(
      (V_transform): Identity()
      (E_transform): Identity()
    )
  )
  (agg): NormAggregation()
  (bn): BatchNorm1d(1100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): MulticlassClassificationFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=1100, out_features=1100, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=1100, out_features=1100, bias=True)
      )
      (2): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_

## Load data to run interpretation for

In [139]:
# Parent of the current working directory (same value as assigned in the model-loading cell).
chemprop_dir = Path.cwd().parent
# CSV file holding the molecules to interpret.
test_path = "../../data/processed/data_cmnpd_after2000.csv"
# Name of the CSV column that contains the SMILES strings.
smiles_column = "SMILES"


In [140]:
# Read the molecules to interpret into a DataFrame.
# NOTE(review): the displayed table has an `Unnamed: 0` column, which looks like a
# saved row index; `index_col=0` may be intended here - confirm.
df_test = pd.read_csv(test_path)
df_test  # notebook display of the loaded table


Unnamed: 0,SMILES,labels
0,C(CCCCC(C)CCCC(C)C)CCC(=O)O,0
1,C(=C(/C)\C=O)/CC\C=C/C\C=C/CCCCCCCCCC,0
2,[C@@H](O)(\C=C\[C@@H](CCCCC)O)[C@]1([H])C[C@@]...,0
3,OCC(O)COCCCCCCCCCCCCCCCCCCCCCCCCCCC,0
4,O[C@H]1[C@H](O)[C@H](O[C@@H](OC[C@@H](O[C@@H](...,0
...,...,...
18397,C1(=O)CC(=C(COC(=O)C)[C@]([H])(C[C@]2(C)[C@@](...,2
18398,C1(=O)CC(=C(COC(=O)CC(C)C)[C@]([H])(C[C@]2(C)[...,2
18399,c1c(C)c(C(=O)Oc(cc(O)cc2C)c2O3)c3c(Cc4ccc(O)cc...,2
18400,[C@@]12([H])[C@@]3(C(=O)N[C@H]1CC(C)C)[C@@H](\...,2


In [141]:
# Row ranges of each kingdom within df_test. Choose the kingdom to interpret
# with `kingdom` instead of commenting lines in and out.
# The class to score and the output file name are set in later cells and must
# be kept in step with this choice by hand.
kingdom_rows = {
    "Animalia": slice(0, 11402),
    "Bacteria": slice(11402, 13934),
    "Fungi": slice(13934, 18403),
}
kingdom = "Fungi"

smiles_to_process = df_test.iloc[kingdom_rows[kingdom]][smiles_column].tolist()

# Show the first and last molecule as a quick check of the selected range.
print(smiles_to_process[0])
print(smiles_to_process[-1])

[C@H]1(O[C@H](CCC)OC[C@H]1O)CO
[C@@H]1(O)\C=C\C=C/[C@H](O)\C=C\C(=O)O[C@@H](C)CCC1


In [142]:
from rdkit import Chem

# Molecules with more atoms than this are skipped. Defined once so that the
# check and the message below cannot drift apart.
max_atoms_per_mol = 100

valid_smiles_to_process = []
invalid_smiles = []
too_large_smiles = []

for smiles in smiles_to_process:
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # RDKit could not parse this SMILES.
        invalid_smiles.append(smiles)
    elif mol.GetNumAtoms() > max_atoms_per_mol:
        too_large_smiles.append(smiles)
    else:
        valid_smiles_to_process.append(smiles)

if invalid_smiles:
    print(f"Found {len(invalid_smiles)} invalid SMILES, which were skipped.")
if too_large_smiles:
    print(
        f"Found {len(too_large_smiles)} SMILES with more than {max_atoms_per_mol} atoms, which were skipped."
    )

# From here on only the parsable, small-enough molecules are processed.
smiles_to_process = valid_smiles_to_process

Found 14 SMILES with more than 100 atoms, which were skipped.


## Set up trainer

In [None]:
# accelerator="auto" uses a GPU when one is available and otherwise falls back
# to CPU, so the notebook also runs on machines without a GPU.
trainer = pl.Trainer(logger=None, enable_progress_bar=False, accelerator="auto", devices=1)

## Run interpretation

In [145]:
# MCTS options
rollout = 1  # number of MCTS rollouts to perform (passed to mcts as n_rollout). For molecules with more than 50 atoms, consider keeping this at 1 to avoid long computation time

c_puct = 10.0  # constant that controls the level of exploration

max_atoms = 40  # maximum number of atoms allowed in an extracted rationale

min_atoms = 8  # the search stops shrinking a subgraph once it has this many atoms or fewer

prop_delta = 0.8  # Minimum score to count as positive.
# In this algorithm, a substructure is considered satisfactory if its predicted property is at least prop_delta.
# Only molecules whose own predicted score is strictly greater than prop_delta are searched at all.
# This value depends on the property you want to interpret.

num_rationales_to_keep = 1  # number of rationales to keep for each molecule

In [146]:
# Define the scoring function. "Score" for a substructure is the predicted property value of the substructure.

# Deliberately not named `models`: that would shadow the imported `chemprop.models` module.
ensemble = [mpnn]

# Index of the predicted class to interpret (0, 1 or 2). The result column
# name is derived from it so the two can never disagree.
class_index = 2
property_for_interpretation = f"class_{class_index}"


def scoring_function(smiles: list[str]) -> list[float]:
    """Scores SMILES by the predicted probability of the class being interpreted.

    Parameters
    ----------
    smiles : list of str
        The SMILES to score.

    Returns
    -------
    list of float
        One score per input SMILES. If prediction or indexing fails, the error
        is printed and a score of 0.0 is returned for every input.

    Raises
    ------
    ValueError
        If the selected predictions are not one-dimensional.
    """
    try:
        predictions = make_prediction(
            models=ensemble,
            trainer=trainer,
            smiles=smiles,
        )
    except Exception as e:
        # Best effort: one failing batch must not abort the whole run.
        print(f"Error during prediction: {e}")
        return [0.0] * len(smiles)

    try:
        # Presumably predictions are (molecule, task, class); task 0 is the only one used - TODO confirm.
        class_probs = predictions[:, 0, class_index]
    except IndexError as e:
        print(f"Error in indexing predictions: {e}")
        return [0.0] * len(smiles)

    if class_probs.ndim != 1:
        raise ValueError("Predictions should be a 1D array of class probabilities.")

    return class_probs.tolist()

In [None]:
import csv
import pandas as pd
from pathlib import Path


def _rationale_columns(rationales, n_keep):
    """Returns the CSV columns for the best `n_keep` rationales.

    Only rationales with the smallest atom count are considered, ordered by
    descending score. Missing entries are filled with None.
    """
    if rationales:
        min_size = min(len(x.atoms) for x in rationales)
        best = sorted(
            (x for x in rationales if len(x.atoms) == min_size),
            key=lambda x: x.P,
            reverse=True,
        )
    else:
        best = []

    columns = {}
    for i in range(n_keep):
        hit = best[i] if i < len(best) else None
        columns[f"rationale_{i}"] = hit.smiles if hit is not None else None
        columns[f"rationale_{i}_score"] = hit.P if hit is not None else None
    return columns


# Column order of the output file.
header = ["smiles", property_for_interpretation]
for i in range(num_rationales_to_keep):
    header += [f"rationale_{i}", f"rationale_{i}_score"]

output_file = "../../results/predictions/fungi_big_rationales.csv"      #animalia;bacteria

# Make sure the target directory exists before the long-running loop starts.
Path(output_file).parent.mkdir(parents=True, exist_ok=True)

# Keep the file open for the whole run instead of reopening it for every row.
with open(output_file, 'w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=header)
    writer.writeheader()

    for count, smiles in enumerate(smiles_to_process, start=1):
        score = scoring_function([smiles])[0]

        # Only search for rationales in molecules that are themselves predicted positive.
        if score > prop_delta:
            rationales = mcts(
                smiles=smiles,
                scoring_function=scoring_function,
                n_rollout=rollout,
                max_atoms=max_atoms,
                prop_delta=prop_delta,
                min_atoms=min_atoms,
                c_puct=c_puct,
            )
        else:
            rationales = []

        results_row = {"smiles": smiles, property_for_interpretation: score}
        results_row.update(_rationale_columns(rationales, num_rationales_to_keep))

        writer.writerow(results_row)
        f.flush()  # hand each finished row to the OS so an interrupted run keeps its results

        print(f"Processed SMILES {count}/{len(smiles_to_process)}. Results saved to {output_file}.")

print("All results have been processed and saved.")