<p style="color:purple; font-weight:bold">
[OA-ReactDiff] is the first diffusion-based generative model for generating <U>3D chemical reactions</U>. It not only accelerates the search for 3D structures in chemical reactions by a factor of 1000, but also generates and explores <U>new and unknown</U> chemical reactions.
</p>

After completing this tutorial, you will be able to:

* Learn about some applications and potential difficulties of diffusion models in chemistry and materials science

* Understand the principles of running [OA-ReactDiff] and apply them to your projects

**It will take [at most] 20 minutes to read and run through the tutorial, so let's get started!**

Link to article 👉 [OA-ReactDiff](https://arxiv.org/abs/2304.06174)

<p style="color:blue; font-weight:bold">
If you have any ideas about applying OA-ReactDiff to your own work, but are not sure if it's entirely appropriate or you don't know how to get started, please feel free to contact the author, Chenru Duan, via email: duanchenru@gmail.com
</p>

<div>
    <img src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/20204/4af440e9b48747388fbd9ec9533b482a/mGNnjUDyymMo9ML3xOiF1Q.png" alt="cover_art" width="500" title="cover_art">
</div>

# Table of Contents

* [Introduction to the Denoising Diffusion Probabilistic Models](#ddpm)
* [Generating a molecule](#single-molecule)
* [How about generating a chemical reaction?](#reaction)
* [OA-ReactDiff usage scenarios](#application)
    * [1. No preprocessing required, generation is so simple](#no-preprocessing)
    * [2. The gift of stochastic processes to compensate for chemical intuition.](#stochasticity)
    * [3. Give me some atoms and I'll generate all the reactions.](#exploration)
* [The future lies in generative AI](#outlook)


# A brief introduction to Denoising Diffusion Probabilistic Models (DDPM) <a id ='ddpm'></a>

DDPM is a generative model that learns to reverse a gradual noising (diffusion) process: by modeling this process, it turns random noise into clean data samples.

These models focus on modeling the distribution of the data by simulating a gradual evolution between a simple distribution and a more complex one. The basic idea of a diffusion model is to transform a simple, easy-to-sample distribution, usually a normal distribution, into the complex data distribution through a sequence of small steps. Once the model has learned this transformation, it can generate new samples by starting from a point drawn from the simple distribution and gradually "denoising" it into a sample from the data distribution.

Training a DDPM relies on a fixed, known noise schedule for the forward diffusion process, so that the model can learn the relationship between clean and noisy data at each step (in practice, by predicting the noise that was added). During generation, DDPM starts from pure Gaussian noise and iteratively applies the learned denoising steps to obtain a realistic data sample (e.g., an image).
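To make the forward (noising) process and the training objective concrete, here is a minimal NumPy sketch of the standard DDPM closed-form noising step $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$ and the simplified noise-prediction loss. The linear beta schedule, the function names and the `eps_model` placeholder are illustrative assumptions, not OA-ReactDiff's code.

```python
import numpy as np

def linear_alpha_bar(T, beta_start=1e-4, beta_end=0.02):
    """Cumulative signal fraction alpha_bar_t for t = 1..T (assumed linear beta schedule)."""
    betas = np.linspace(beta_start, beta_end, T)
    return np.cumprod(1.0 - betas)

def q_sample(x0, t, alpha_bar, rng):
    """Draw x_t ~ q(x_t | x_0) in one shot; t is 1-indexed."""
    eps = rng.standard_normal(x0.shape)
    a = alpha_bar[t - 1]
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps, eps

def ddpm_loss(eps_model, x0, alpha_bar, rng):
    """Simplified DDPM objective: mean squared error of the predicted noise."""
    t = int(rng.integers(1, len(alpha_bar) + 1))  # uniform t in {1, ..., T}
    x_t, eps = q_sample(x0, t, alpha_bar, rng)
    return np.mean((eps_model(x_t, t) - eps) ** 2)

rng = np.random.default_rng(0)
alpha_bar = linear_alpha_bar(1000)
x0 = np.ones(100_000)                            # toy "clean data"
x_T, _ = q_sample(x0, 1000, alpha_bar, rng)      # at t = T the data is almost pure noise
zero_model = lambda x_t, t: np.zeros_like(x_t)   # hypothetical untrained stand-in for a network
loss = ddpm_loss(zero_model, x0, alpha_bar, rng)
```

A trained network replaces `zero_model`; generation then runs the learned denoising steps from $t = T$ back to $t = 0$.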

<div>
    <img src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/20204/bcae8391a60845f990e6387bb4ac3977/ox_I-p5BPAeFzFg3Ojh8eQ.png" alt="image_diffusion" width="850" title="image_diffusion">
    <p style='font-size:0.8rem; font-weight:bold'>Figure 1｜Diffusion generation model in generating images</p>
</div>

Supplemental Reading 🔽

[An Introduction to Diffusion Models for Machine Learning](https://encord.com/blog/diffusion-models/#:~:text=Diffusion%20models%20are%20a%20class%20of%20generative%20models%20that%20simulate,a%20sequence%20of%20invertible%20operations.)

[Diffusion Models: A Comprehensive Survey of Methods and Applications](https://arxiv.org/pdf/2209.00796.pdf)

For those who want to play with DDPM image generation and read Chinese, please refer to this notebook tutorial: [Diffusion Models: You still don't understand the basics of diffusion modeling?](https://nb.bohrium.dp.tech/detail/2412844875)

# Tired of images? How about generating molecules?<a id ='single-molecule'></a>

Recently, the framework of DDPM has been progressively applied to tasks related to chemical molecules. Examples include generating [3D structures of organic small molecules](https://arxiv.org/abs/2203.17003), [protein-small molecule docking](https://arxiv.org/abs/2210.01776), and [protein structure-based drug design](https://arxiv.org/abs/2210.13695).

The core difference between the generation of molecules and the image generation above is their <span style="color: orange;">**representation and symmetry**</span>:

* <span style="color: orange;">*Representation*</span>: a color image of NxN pixels is usually represented as an (N, N, 3) tensor, where the last dimension holds the three RGB color channels. A molecule, in contrast, is represented atom by atom, with each atom described by (atom type, 3D coordinates).

* <span style="color: orange;">*Symmetry*</span>: For images, strict symmetry is generally not required; data augmentation is commonly used so that the model "remembers" that an object rotated in 2D is still the same object. For molecules, however, we learn 3D objects directly and require high accuracy, so augmentation is often too expensive and not accurate enough. We therefore need to build the symmetry into the model itself.
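As a minimal sketch of the two representations (the H/C/N/O element vocabulary, the approximate water geometry and the helper name `one_hot_atoms` are our own illustrative assumptions), an image is a fixed grid, while a molecule is a set of atoms, each with a one-hot type vector and 3D coordinates. Centering the coordinates removes the translational part of SE(3).

```python
import numpy as np

ELEMENTS = ["H", "C", "N", "O"]  # assumed element vocabulary for this sketch

def one_hot_atoms(symbols):
    """One-hot encode atom types; returns an array of shape (n_atoms, n_elements)."""
    idx = [ELEMENTS.index(s) for s in symbols]
    return np.eye(len(ELEMENTS))[idx]

# An N x N RGB image is a fixed grid: shape (N, N, 3)
image = np.zeros((32, 32, 3))

# A molecule is a set of atoms: categorical types + continuous 3D coordinates
symbols = ["O", "H", "H"]
pos = np.array([[0.00, 0.00, 0.0],
                [0.96, 0.00, 0.0],
                [-0.24, 0.93, 0.0]])  # approximate water geometry, in Angstrom
h = one_hot_atoms(symbols)            # shape (3, 4)
x = pos - pos.mean(axis=0)            # centering removes the translational degree of freedom
```

Unlike pixels, the atoms have no canonical order, and rotating `x` describes the same molecule, which is why the model itself must respect these symmetries.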

<div>
    <img src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/20204/bcae8391a60845f990e6387bb4ac3977/1FzbkAcENo9bPWbm1xJHOQ.png" alt="mol_diffusion" width="850" title="mol_diffusion">
    <p style='font-size:0.8rem; font-weight:bold'>Figure 2｜Diffusion Generation Modeling of a Single Molecule</p>
</div>

## SE(3) symmetry is indispensable when generating molecules

When generating 3D molecules, the model must respect the physical symmetries of the molecule. One way to achieve this is to use graph neural networks with SE(3) symmetry, i.e., equivariance to rotations and translations. Intuitively, rotating the molecule by $D(g)$ and then applying the model $f$ gives the same result as applying the model first and then transforming its output by the corresponding $D'(g)$, i.e., $f \circ D(g) = D'(g) \circ f$. For scalar outputs, $D'(g)$ is the identity (the output is invariant); for vector outputs, the output rotates together with the input.
 
Siyuan's earlier notebook gives a [detailed introduction to SE(3) symmetry](https://nb.bohrium.dp.tech/detail/9619364424) for those who want to dig deeper 📚. Here we give just one example to build intuition.
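Before turning to LEFTNet, the "rotate-then-predict equals predict-then-rotate" property can be checked on a tiny, dependency-free toy model. The model below is our own illustration (not LEFTNet): it builds per-atom scalars from interatomic distances and per-atom vectors from relative positions, which is enough to make it invariant and equivariant respectively. Rotations use the same row-vector convention as `torch.matmul(pos, rot)` in the notebook.

```python
import numpy as np

def random_rotation(rng):
    """Random proper rotation matrix (det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.standard_normal((3, 3)))
    q = q * np.sign(np.diag(r))        # fix column signs so the factorization is unique
    if np.linalg.det(q) < 0:           # turn a reflection into a rotation
        q[:, 0] = -q[:, 0]
    return q

def toy_equivariant_model(pos):
    """Per-atom scalar (invariant) and vector (equivariant) outputs from relative positions."""
    diff = pos[:, None, :] - pos[None, :, :]   # (n, n, 3) relative vectors
    dist = np.linalg.norm(diff, axis=-1)       # (n, n) distances, unchanged by rotation
    w = np.exp(-dist ** 2)                     # smooth distance-based weights
    np.fill_diagonal(w, 0.0)                   # no self-interaction
    h = w.sum(axis=1)                          # scalar per atom
    v = (w[..., None] * diff).sum(axis=1)      # vector per atom
    return h, v

rng = np.random.default_rng(0)
pos = rng.standard_normal((3, 3))              # a 3-atom "molecule"
R = random_rotation(rng)
shift = rng.standard_normal(3)

h, v = toy_equivariant_model(pos)
h_g, v_g = toy_equivariant_model(pos @ R + shift)  # rotate + translate, then predict
print(np.abs(h - h_g).max())      # scalars: should be ~0 (floating-point error only)
print(np.abs(v @ R - v_g).max())  # vectors: predict-then-rotate matches, ~0
```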

In [None]:
# --- Importing and defining some functions ----
import torch
import py3Dmol
import numpy as np

from typing import Optional
from torch import tensor
from e3nn import o3
from torch_scatter import scatter_mean

from oa_reactdiff.model import LEFTNet

default_float = torch.float64
torch.set_default_dtype(default_float)  # Use double precision for more accurate testing


def remove_mean_batch(
    x: tensor, 
    indices: Optional[tensor] = None
) -> tensor:
    """Remove the mean from each batch in x

    Args:
        x (tensor): input tensor.
        indices (Optional[tensor], optional): batch indices. Defaults to None.

    Returns:
        tensor: output tensor with batch mean as 0.
    """
    if indices is None:
        return x - torch.mean(x, dim=0)
    mean = scatter_mean(x, indices, dim=0)
    x = x - mean[indices]
    return x


def draw_in_3dmol(mol: str, fmt: str = "xyz") -> py3Dmol.view:
    """Draw the molecule

    Args:
        mol (str): str content of molecule.
        fmt (str, optional): format. Defaults to "xyz".

    Returns:
        py3Dmol.view: output viewer
    """
    viewer = py3Dmol.view(1024, 576)
    viewer.addModel(mol, fmt)
    viewer.setStyle({'stick': {}, "sphere": {"radius": 0.36}})
    viewer.zoomTo()
    return viewer


def assemble_xyz(z: list, pos: tensor) -> str:
    """Assembling atomic numbers and positions into xyz format

    Args:
        z (list): chemical elements
        pos (tensor): 3D coordinates

    Returns:
        str: xyz string
    """
    natoms = len(z)
    xyz = f"{natoms}\n\n"
    for _z, _pos in zip(z, pos.numpy()):
        xyz += f"{_z}\t" + "\t".join([str(x) for x in _pos]) + "\n"
    return xyz

### Building a LEFTNet model

We run a simple test to verify SE(3) symmetry. Since the model is only for testing, a very small one suffices.

Note: [LEFTNet](https://arxiv.org/abs/2304.04757) is a recent state-of-the-art SE(3)-equivariant graph neural network. Although we use LEFTNet here, the symmetry properties it exhibits are not specific to this model (other equivariant models, such as [EGNN](https://arxiv.org/pdf/2102.09844.pdf), behave the same way).

In [None]:
num_layers = 2
hidden_channels = 8
in_hidden_channels = 4
num_radial = 4

model = LEFTNet(
    num_layers=num_layers,
    hidden_channels=hidden_channels,
    in_hidden_channels=in_hidden_channels,
    num_radial=num_radial,
    object_aware=False,
)

sum(p.numel() for p in model.parameters() if p.requires_grad)


### Creating a water molecule as a test system

In [None]:
h = torch.rand(3, in_hidden_channels)
z = ["O", "H", "H"]
pos = tensor([
    [0, 0, 0],
    [1, 0, 0],
    [0, 1, 0],
]).double()  # For convenience, we'll set the H-O-H angle here to 90 degrees
edge_index = tensor([
    [0, 0, 1, 1, 2, 2],
    [1, 2, 0, 2, 0, 1]
]).long()  # Fully connected graph: each undirected edge is listed in both directions

In [None]:
xyz = assemble_xyz(z, pos)
view = draw_in_3dmol(xyz, "xyz")

view  # Display Molecules

### Forward prediction of a scalar and vector using LEFTNet

In [None]:
_h, _pos, __ = model.forward(
    h=h,
    pos=remove_mean_batch(pos),
    edge_index=edge_index,
)  # _h is the output feature, _pos is the output position, the former is a scalar, the latter is a vector

### Randomly rotate our water molecule and run it through LEFTNet again

In [None]:
rot = o3.rand_matrix()
pos_rot = torch.matmul(pos, rot).double()

_h_rot, _pos_rot, __ = model.forward(
    h=h,
    pos=remove_mean_batch(pos_rot),
    edge_index=edge_index,
)

### Whether scalar or vector, the output is always equivariant

In [None]:
torch.max(
    torch.abs(
        _h - _h_rot
    )
)  # The h should remain the same after rotation

In [None]:
torch.max(
    torch.abs(
        torch.matmul(_pos, rot).double() - _pos_rot
    )
)  # The output positions should rotate together with the input

# Reactions have been the core of chemistry since the age of alchemy. How about generating chemical reactions directly?<a id ='reaction'></a>

Since the ancient days of alchemy, chemistry has been a discipline of <span style="color: orange;">**studying and controlling how substances react with each other**</span>. Now that we can generate a single molecule, why not extend the diffusion model to <span style="color: orange;">**generating a group of molecules**</span>, such as the "reactants, transition state, products" that every chemist cares about?

<div>
    <img src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/20204/cc00542f53e6459e90e866f6c009e25e/2voIGX7BDbS6pZAOIqRohw.png" alt="reaction_diffusion" width="350" title="reaction_diffusion">
    <p style='font-size:0.8rem; font-weight:bold'>Figure 3: Illustration of a chemical reaction. The black line is the energy curve; the reactants and products correspond to local minima, and the transition state corresponds to a saddle point.</p>
</div>


With this in mind, we generalize diffusion generation to systems of multiple molecules. During training, we simultaneously add noise to the reactants (blue), transition state (yellow) and products (red), and use a graph neural network with a special symmetry (described in detail later) to predict the three added noises. Two generation modes were developed.

1. <span style="color: orange;">*Random generation*</span>. Here we generate a new chemical reaction directly from the three Gaussian distributions, i.e., a **new, "unknown"** set of reactants, transition state and products.
2. <span style="color: orange;">*Conditional generation*</span>. In chemical reactions, part of the information is usually known. Here we can use inpainting for conditional generation; e.g., given the reactants and products, we can directly generate the transition state. At each denoising step, the known reactants and products are noised to the current noise level and combined with the transition state predicted by the graph neural network. Iterating step by step, we obtain "reactants (known), transition state (unknown), products (known)".
<div>
    <img src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/20204/bcae8391a60845f990e6387bb4ac3977/OnS4d6hjIcgEaIZcsyKxbQ.png" alt="reaction_diffusion" width="850" title="reaction_diffusion">
    <p style='font-size:0.8rem; font-weight:bold'>Figure 4｜Detailed Process of Chemical Reaction Generated by Diffusion Generation Model</p>
</div>
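The conditional-generation (inpainting) idea can be sketched in a few lines of NumPy. This is a simplified, RePaint-style loop under our own assumptions: a linear beta schedule, a hypothetical placeholder `eps_model` in place of the trained graph network, and no resampling (the notebook's `inpaint` call later also uses `resamplings` and `jump_length`). At each reverse step, every fragment takes one denoising step, and then the known fragments (reactants, products) are overwritten by their clean values re-noised to the matching level.

```python
import numpy as np

T = 200
# betas[t] for t = 1..T; index 0 is a dummy so that alpha_bar[0] == 1 (no noise)
betas = np.concatenate([[0.0], np.linspace(1e-4, 0.05, T)])
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

def q_sample(x0, t, rng):
    """Noise clean data x0 to level t (t = 0 returns x0 unchanged)."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

def reverse_step(x_t, t, eps_model, rng):
    """One DDPM ancestral step from x_t to x_{t-1}."""
    eps = eps_model(x_t, t)
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps) / np.sqrt(alphas[t])
    if t > 1:
        mean = mean + np.sqrt(betas[t]) * rng.standard_normal(x_t.shape)
    return mean

def inpaint(x_known, known_mask, eps_model, rng):
    """Generate the masked-out fragments while keeping the known ones consistent."""
    x = rng.standard_normal(x_known.shape)
    for t in range(T, 0, -1):
        x_gen = reverse_step(x, t, eps_model, rng)  # model proposal for all fragments
        x_fix = q_sample(x_known, t - 1, rng)       # known fragments at the same noise level
        x = known_mask * x_fix + (1.0 - known_mask) * x_gen
    return x

# Usage: fragments (reactant, TS, product), each with 4 atoms x 3 coordinates
rng = np.random.default_rng(0)
x_known = rng.standard_normal((3, 4, 3))            # the TS slot content is ignored
known_mask = np.array([1.0, 0.0, 1.0])[:, None, None]
zero_eps_model = lambda x_t, t: np.zeros_like(x_t)  # hypothetical stand-in for the trained GNN
out = inpaint(x_known, known_mask, zero_eps_model, rng)
```

At the last step the known fragments are returned unchanged, while the TS slot is entirely produced by the reverse process.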


## When we rotate any molecule of a chemical reaction, does that reaction change?
When generating a single molecule or a large multi-molecular system, we generally only require that the model output transforms accordingly when an SE(3) transformation (e.g., a global rotation) is applied to the whole system (second row of the figure below).

In chemical reactions, however, this symmetry is **not strong enough**. In reactions, we also require **SE(3) symmetry with respect to each individual object** (third row of the figure below). For example, when we rotate one reactant, only the model's output for that reactant should rotate, while the rest of the reaction stays unchanged; the model should not treat it as a "brand new" reaction.

<div>
    <img src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/20204/bcae8391a60845f990e6387bb4ac3977/7R-rKlaPFQvpW5rVdQbvIw.png" alt="reaction_symmetry" width="350" title="reaction_symmetry">
    <p style='font-size:0.8rem; font-weight:bold'>Figure 5｜Behavior of chemical reactions under different symmetries (top) after different transformations (left)</p>
</div>

## System-wide SE(3) symmetry fails to preserve the symmetry of individual objects in a reaction


### Consider a hypothetical "electrolysis of water" reaction, H2O -> H2 + O.
Note: This reaction process is *purely fictitious* and is only a simplified model to help us understand the symmetry required in reactions. We have also omitted the transition state structure and consider only the symmetry relationship between reactants and products; including the transition state does not change the conclusion.

In [None]:
ns = [3, ] + [2, 1]  # reactant 3 atoms (H2O), product 2 atoms (H2), 1 atom (O radical)
ntot = np.sum(ns)
mask = tensor([0, 0, 0, 1, 1, 1])  # To distinguish between reactants and products
z = ["O", "H", "H"] + ["H", "H", "O"]
pos_react = tensor([
    [0, 0, 0],
    [1, 0, 0],
    [0, 1, 0],
]).double()  # For convenience, we'll set the H-O-H angle here to 90 degrees
pos_prod = tensor([
    [0, 3, -0.4],
    [0, 3, 0.4],
    [0, -3, 0],
])  # Separating the H2 and O radicals
pos = torch.cat(
    [pos_react, pos_prod],
    dim=0,
)  # Concatenate reactants and products
h = torch.rand(ntot, in_hidden_channels)

In [None]:
xyz = assemble_xyz(z, pos)
view = draw_in_3dmol(xyz, "xyz")

view


When applying system-wide SE(3) symmetry to a chemical reaction, all atoms need to be connected so that the reactants and products can "interact" with each other through the graph neural network. Otherwise, changing the reactants would not change the products at all, and the problem would become even worse 😉

In [None]:
from oa_reactdiff.tests.model.utils import (
    generate_full_eij,
    get_cut_graph_mask,
)

edge_index = generate_full_eij(ntot)


### Forward prediction

In [None]:
_h, _pos, __ = model.forward(
    h=h,
    pos=remove_mean_batch(pos, mask),
    edge_index=edge_index,
)


### Let's rotate the reactants.

In [None]:
rot = o3.rand_matrix()
pos_react_rot = torch.matmul(pos_react, rot).double()

pos_rot = torch.cat(
    [pos_react_rot, pos_prod],
    dim=0,
)  # Combine rotated H2O and unrotated H2 and O radicals

xyz = assemble_xyz(z, pos_rot)
view = draw_in_3dmol(xyz, "xyz")

view


### Predict this reaction again. Note that both inputs describe the same chemical reaction

In [None]:
_h_rot, _pos_rot, __ = model.forward(
    h=h,
    pos=remove_mean_batch(pos_rot, mask),
    edge_index=edge_index,
)

### But both the scalar and the vector outputs lose "equivariance" ❌

In [None]:
torch.max(
    torch.abs(
        _h - _h_rot
    )
)  # The h should remain the same after rotation

In [None]:
_pos_rot_prime = torch.cat(
    [
        torch.matmul(_pos[:3], rot),
        _pos[3:]
    ]
)
torch.max(
    torch.abs(
        _pos_rot_prime  - _pos_rot
    )
)  # The output positions should rotate together with the input


## The need for an <span style="color: orange;">**"object-aware"**</span> SE(3) model in chemical reactions

To address this challenge and extend the framework of generative modeling to chemical reactions, we have developed a **generalized approach**, [OA-ReactDiff](https://arxiv.org/abs/2304.06174), that extends the SE(3) model to an <span style="color: orange;">**"object-aware"**</span> SE(3) model. This approach can be applied to a variety of graph neural networks and Transformers.

The technical details are in the article linked above 🔗 for those interested; in this notebook, we directly demonstrate the improved results.
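To build intuition for why "object awareness" helps, here is a toy NumPy model. It is our own illustration, not LEFTNet's object-aware mechanism (which is described in the article): vector messages are computed only within each object, while objects exchange only rotation- and translation-invariant scalars (here, each object's radius of gyration). Rotating one object then rotates only that object's vector outputs and leaves everything else unchanged.

```python
import numpy as np

def random_rotation(rng):
    """Random proper rotation matrix (det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.standard_normal((3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

def toy_object_aware_model(pos, obj):
    """Toy model: geometric (vector) messages stay inside each object;
    between objects, only invariant scalars are exchanged."""
    objects = np.unique(obj)
    rg = {}  # invariant summary of each object: radius of gyration
    for o in objects:
        p = pos[obj == o]
        rg[o] = np.sqrt(((p - p.mean(axis=0)) ** 2).sum(axis=1).mean())
    h = np.zeros(len(pos))
    v = np.zeros_like(pos)
    for i in range(len(pos)):
        diff = pos[i] - pos[obj == obj[i]]        # relative vectors within the same object
        w = np.exp(-(diff ** 2).sum(axis=1))      # invariant radial weights
        context = sum(rg[o] for o in objects if o != obj[i])  # invariant cross-object info
        h[i] = w.sum() + context
        v[i] = (1.0 + context) * (w[:, None] * diff).sum(axis=0)
    return h, v

rng = np.random.default_rng(0)
pos = rng.standard_normal((6, 3))
obj = np.array([0, 0, 0, 1, 1, 1])   # reactant atoms, product atoms
R = random_rotation(rng)
pos_rot = pos.copy()
pos_rot[:3] = pos[:3] @ R            # rotate only the reactant

h, v = toy_object_aware_model(pos, obj)
h_rot, v_rot = toy_object_aware_model(pos_rot, obj)
v_expected = np.concatenate([v[:3] @ R, v[3:]])   # only the reactant's vectors rotate

# Treating all atoms as a single object gives a "global" model, which breaks
one_obj = np.zeros(6, dtype=int)
h_g, _ = toy_object_aware_model(pos, one_obj)
h_g_rot, _ = toy_object_aware_model(pos_rot, one_obj)
```

The differences `h - h_rot` and `v_expected - v_rot` vanish up to floating-point error, whereas `h_g - h_g_rot` does not, mirroring the LEFTNet experiments above and below.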



### Create an "Object-Aware" LEFTNet

In [None]:
model_oa = LEFTNet(
    num_layers=num_layers,
    hidden_channels=hidden_channels,
    in_hidden_channels=in_hidden_channels,
    num_radial=num_radial,
    object_aware=True,  # Using the object-aware model
)


We assign different atoms to their respective objects, which are "reactants and products" in this case.

In [None]:
subgraph_mask = get_cut_graph_mask(edge_index, 3)  # The first 3 atoms (indices 0-2) belong to the reactant


### Repeat the rotation test again


In [None]:
_h, _pos, __ = model_oa.forward(
    h=h,
    pos=remove_mean_batch(pos, mask),
    edge_index=edge_index,
    subgraph_mask=subgraph_mask,
)

In [None]:
rot = o3.rand_matrix()
pos_react_rot = torch.matmul(pos_react, rot).double()

pos_rot = torch.cat(
    [pos_react_rot, pos_prod],
    dim=0,
)

_h_rot, _pos_rot, __ = model_oa.forward(
    h=h,
    pos=remove_mean_batch(pos_rot, mask),
    edge_index=edge_index,
    subgraph_mask=subgraph_mask,
)

### We verified that the "object-aware" LEFTNet will maintain the desired symmetry ✅

In [None]:
torch.max(
    torch.abs(
        _h - _h_rot
    )
)  # The h should remain the same after rotation

In [None]:
_pos_rot_prime = torch.cat(
    [
        torch.matmul(_pos[:3], rot),
        _pos[3:]
    ]
)

torch.max(
    torch.abs(
        _pos_rot_prime  - _pos_rot
    )
)  # The output positions should rotate together with the input

# From days to seconds, OA-ReactDiff accelerates the generation of transition state (TS) structures<a id ='application'></a>

Experimentally, reactants and products are relatively long-lived, which makes them comparatively easy to characterize (e.g., by NMR or mass spectrometry). However, <span style="color: red;">**experimentally characterizing transition state structures is very difficult**</span>: transition states exist only for an extremely short time and are higher in energy than the reactants or products, so they are hard to isolate and study directly. Computationally, although methods such as the nudged elastic band have been developed, they are time-consuming, often requiring about a day on a 24-core CPU machine, and have <span style="color: red;">**a very low**</span> success rate (around 30%) 😭.

With [the OA-ReactDiff architecture developed in our article](https://arxiv.org/abs/2304.06174), the search time for transition states can be reduced **from several days to 6 seconds**, greatly accelerating the exploration of chemical reaction mechanisms, improving success rates, and enabling the study of more complex chemical reaction networks. We have prepared three examples to demonstrate the application scenarios of OA-ReactDiff 🎬.

## Correct symmetry removes the need for all cumbersome pre-processing<a id ='no-preprocessing'></a>
Computationally, conventional machine learning methods usually require a one-to-one atom mapping between reactants and products, and when there are multiple reactants, their relative positions must also be carefully arranged. These pre-processing steps are not only time-consuming but also practically infeasible for many unknown reactions.

Because OA-ReactDiff respects the symmetries required by chemical reactions, we **don't need to do any pre-processing for the reactions**.
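One ingredient behind "no atom mapping needed" is that graph neural networks built from pairwise messages are permutation equivariant: reordering the input atoms simply reorders the outputs. The hypothetical toy layer below (our own example, not OA-ReactDiff's network) illustrates this property.

```python
import numpy as np

def toy_message_passing(h, pos):
    """One round of fully connected message passing (hypothetical toy layer)."""
    diff = pos[:, None, :] - pos[None, :, :]           # (n, n, 3) relative vectors
    w = np.exp(-(diff ** 2).sum(axis=-1))              # (n, n) distance-based weights
    np.fill_diagonal(w, 0.0)                           # no self-messages
    h_new = h + w @ h                                  # aggregate neighbor features
    pos_new = pos + (w[..., None] * diff).sum(axis=1)  # equivariant coordinate update
    return h_new, pos_new

rng = np.random.default_rng(0)
h = rng.standard_normal((5, 4))     # 5 atoms, 4 features each
pos = rng.standard_normal((5, 3))
perm = rng.permutation(5)           # scramble the atom order

h_out, pos_out = toy_message_passing(h, pos)
h_perm, pos_perm = toy_message_passing(h[perm], pos[perm])
# Scrambling the input order only scrambles the outputs in the same way (prints True True)
print(np.allclose(h_perm, h_out[perm]), np.allclose(pos_perm, pos_out[perm]))
```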

In [None]:
# --- Importing necessary functions ---
from torch.utils.data import DataLoader

from oa_reactdiff.trainer.pl_trainer import DDPMModule


from oa_reactdiff.dataset import ProcessedTS1x
from oa_reactdiff.diffusion._schedule import DiffSchedule, PredefinedNoiseSchedule

from oa_reactdiff.diffusion._normalizer import FEATURE_MAPPING
from oa_reactdiff.analyze.rmsd import batch_rmsd

from oa_reactdiff.utils.sampling_tools import (
    assemble_sample_inputs,
    write_tmp_xyz,
)


### Load the pre-trained model and redefine the noise schedule.

In [None]:
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")

ddpm_trainer = DDPMModule.load_from_checkpoint(
    checkpoint_path="./pretrained-ts1x-diff.ckpt",
    map_location=device,
)
ddpm_trainer = ddpm_trainer.to(device)

In [None]:
noise_schedule: str = "polynomial_2"
timesteps: int = 150
precision: float = 1e-5

gamma_module = PredefinedNoiseSchedule(
    noise_schedule=noise_schedule,
    timesteps=timesteps,
    precision=precision,
)
schedule = DiffSchedule(
    gamma_module=gamma_module,
    norm_values=ddpm_trainer.ddpm.norm_values
)
ddpm_trainer.ddpm.schedule = schedule
ddpm_trainer.ddpm.T = timesteps
ddpm_trainer = ddpm_trainer.to(device)
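For intuition about the `polynomial_2` schedule set above: as far as we understand, it follows the polynomial schedule of the E(3)-equivariant diffusion model (EDM) family, where the signal variance behaves like $\alpha_t^2 \approx (1-(t/T)^2)^2$, slightly shifted by `precision` so that it never reaches exactly 0 or 1. The sketch below reproduces that functional form under this assumption; the exact parameterization and any clipping inside `PredefinedNoiseSchedule` may differ, and `polynomial_alpha2` is our own name.

```python
import numpy as np

def polynomial_alpha2(timesteps, power=2.0, precision=1e-5):
    """Signal variance alpha_t^2 for t = 0..T under an assumed polynomial schedule."""
    t = np.linspace(0.0, 1.0, timesteps + 1)              # normalized time t / T
    alpha2 = (1.0 - t ** power) ** 2                      # 1 at t = 0, 0 at t = T
    return (1.0 - 2.0 * precision) * alpha2 + precision   # keep values strictly inside (0, 1)

alpha2 = polynomial_alpha2(timesteps=150, power=2.0, precision=1e-5)
sigma2 = 1.0 - alpha2                     # noise variance (variance-preserving process)
gamma = np.log(sigma2) - np.log(alpha2)   # negative log signal-to-noise ratio, rises with t
```

Fewer timesteps (150 here) make sampling cheaper, at the cost of coarser denoising steps.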


### Prepare the dataset and data loader, and select a reaction involving multiple molecules

In [None]:
dataset = ProcessedTS1x(
    npz_path="./oa_reactdiff/data/transition1x/train.pkl",
    center=True,
    pad_fragments=0,
    device=device,
    zero_charge=False,
    remove_h=False,
    single_frag_only=False,
    swapping_react_prod=False,
    use_by_ind=True,
)
loader = DataLoader(
    dataset, 
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=dataset.collate_fn
)
itl = iter(loader)
idx = -1

len(dataset)

In [None]:
for _ in range(4):  # The 4th sample happens to be a multimolecular reaction
    representations, res = next(itl)
idx += 1
n_samples = representations[0]["size"].size(0)
fragments_nodes = [
    repre["size"] for repre in representations
]
conditions = torch.tensor([[0] for _ in range(n_samples)], device=device)



The order of atoms in the reactants is intentionally scrambled relative to the products, and the two molecules in the products are pulled far apart (10 Å), effectively to "infinity".

In [None]:
new_order_react = torch.randperm(representations[0]["size"].item())
for k in ["pos", "one_hot", "charge"]:
    representations[0][k] = representations[0][k][new_order_react]
    
xh_fixed = [
    torch.cat(
        [repre[feature_type] for feature_type in FEATURE_MAPPING],
        dim=1,
    )
    for repre in representations
]

### Generating TS structures for this set of "reactants, products" without any atomic ordering and fragment arrangement.

In [None]:
out_samples, out_masks = ddpm_trainer.ddpm.inpaint(
    n_samples=n_samples,
    fragments_nodes=fragments_nodes,
    conditions=conditions,
    return_frames=1,
    resamplings=5,
    jump_length=5,
    timesteps=None,
    xh_fixed=xh_fixed,
    frag_fixed=[0, 2],
)

### Evaluating the quality of generated TS structures



Two structures whose RMSD is below 0.1 Å are almost indistinguishable to the naked eye.

<div>
    <img src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/20204/bcae8391a60845f990e6387bb4ac3977/4vUVwvt2hGtpcZwTw296Xg.png" alt="rmsd_visual" width="850" title="rmsd_visual">
    <p style='font-size:0.8rem; font-weight:bold'>Figure 6 | Structural differences between two molecules at different RMSD values (The C atoms of the two molecules are indicated by blue and yellowish colors, respectively)</p>
</div>
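RMSD comparisons like those in Figure 6 are made after optimally superimposing the two structures. A minimal sketch of that idea is the Kabsch algorithm below, assuming both structures list the same atoms in the same order. `batch_rmsd` and `pymatgen_rmsd` in this notebook may additionally handle atom matching, so this is an illustration, not their implementation.

```python
import numpy as np

def kabsch_rmsd(p, q):
    """RMSD between two (n_atoms, 3) structures after optimal superposition.
    Assumes both arrays list the same atoms in the same order."""
    p = p - p.mean(axis=0)                    # remove translation
    q = q - q.mean(axis=0)
    h = p.T @ q                               # 3x3 covariance matrix
    u, _, vt = np.linalg.svd(h)
    d = np.sign(np.linalg.det(vt.T @ u.T))    # guard against reflections
    r = vt.T @ np.diag([1.0, 1.0, d]) @ u.T   # optimal rotation (column-vector convention)
    p_aligned = p @ r.T
    return float(np.sqrt(((p_aligned - q) ** 2).sum(axis=1).mean()))

rng = np.random.default_rng(0)
p = rng.standard_normal((10, 3))
theta = 0.7
rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
q = p @ rz.T + np.array([1.0, -2.0, 0.5])     # rotated and translated copy of p
noisy = p + 0.05 * rng.standard_normal(p.shape)
print(kabsch_rmsd(p, q))       # ~0: identical up to a rigid motion
print(kabsch_rmsd(p, noisy))   # small but nonzero
```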

We save the generated transition state and compute its RMSD against the true transition state of this reaction: it is only 0.02 Å.



In [None]:
rmsds = batch_rmsd(
    fragments_nodes, 
    out_samples[0],
    xh_fixed,
    idx=1,
)
write_tmp_xyz(
    fragments_nodes, 
    out_samples[0], 
    idx=[0, 1, 2], 
    localpath="./demo/inpainting"
)

rmsds = [min(1, _x) for _x in rmsds]
[(ii, round(rmsd, 2)) for ii, rmsd in enumerate(rmsds)], np.mean(rmsds), np.median(rmsds)

The atom ordering of the generated transition state does not need to match that of the reactants at all, showing that OA-ReactDiff does not require pre-processing of the atom arrangement.

In [None]:
! cat ./demo/inpainting/gen_0_react.xyz

In [None]:
! cat ./demo/inpainting/gen_0_ts.xyz


Likewise, although the products consist of multiple molecules, they need no specific geometric arrangement (we had already pulled the two molecules to "infinity" before generation).

In [None]:
draw_in_3dmol(
    open("./demo/inpainting/gen_0_prod.xyz", "r").read(),
    "xyz"
)


## Surprises from unintended reactions: gifts from stochastic processes<a id ='stochasticity'></a>

Because noise is injected throughout sampling, the diffusion model has a certain amount of randomness: even for the same given 3D structures of the reactants and products, the generated transition state structure is not the same every time.

So a natural question arises: do these randomly generated structures, apart from the one that is identical to the true transition state, <span style="color: orange;">**have any chemical significance and practical value**</span>❓
    
We will delve into a case study to find the answer to this question.

In [None]:
from glob import glob
import plotly.express as px

from oa_reactdiff.analyze.rmsd import xyz2pmg, pymatgen_rmsd

from pymatgen.core import Molecule
from collections import OrderedDict


def draw_reaction(react_path: str, idx: int = 0, prefix: str = "gen") -> py3Dmol.view:
    """Draw the {reactants, transition states, products} of the reaction.

    Args:
        react_path (str): path to the reaction.
        idx (int, optional): index for the generated reaction. Defaults to 0.
        prefix (str, optional): prefix for distinguishing true sample and generated structure.
            Defaults to "gen".

    Returns:
        py3Dmol.view: viewer showing the reactants, transition state and products side by side
    """
    with open(f"{react_path}/{prefix}_{idx}_react.xyz", "r") as fo:
        natoms = int(fo.readline()) * 3
    mol = f"{natoms}\n\n"
    for ii, t in enumerate(["react", "ts", "prod"]):
        pmatg_mol = xyz2pmg(f"{react_path}/{prefix}_{idx}_{t}.xyz")
        pmatg_mol_prime = Molecule(
            species=pmatg_mol.atomic_numbers,
            coords=pmatg_mol.cart_coords + 8 * ii,
        )
        mol += "\n".join(pmatg_mol_prime.to(fmt="xyz").split("\n")[2:]) + "\n"
    viewer = py3Dmol.view(1024, 576)
    viewer.addModel(mol, "xyz")
    viewer.setStyle({'stick': {}, "sphere": {"radius": 0.3}})
    viewer.zoomTo()
    return viewer

### Draw our example reaction

From bottom left to top right, the structures are the reactants, transition state and products.

In [None]:
draw_reaction("./demo/example-3/ground_truth", prefix="sample")



Take a closer look at the transition state structure.

In [None]:
draw_in_3dmol(
    open("./demo/example-3/ground_truth/sample_0_ts.xyz", "r").read(),
    "xyz"
)



### How many different TSs do we generate for the same reaction?

To save time, we have already used OA-ReactDiff to generate multiple transition state structures for this pair of reactants and products. This reaction corresponds to [the example on the right side of Figure 3 in our article](https://arxiv.org/abs/2304.06174).


Collect the generated transition states and sort them by their file index.

In [None]:
opt_ts_path = "./demo/example-3/opt_ts/"
opt_ts_xyzs = glob(f"{opt_ts_path}/*ts.opt.xyz")

order_dict = {}
for xyz in opt_ts_xyzs:

    order_dict.update(
        {int(xyz.split("/")[-1].split(".")[0]): xyz}
    )
order_dict = OrderedDict(sorted(order_dict.items()))

opt_ts_xyzs = []
ind_dict = {}
for ii, v in enumerate(order_dict.values()):
    opt_ts_xyzs.append(v)
    ind_dict.update(
        {ii: v}
    )



Calculate the pairwise RMSD matrix between these transition states. As the earlier comparison figure shows, two configurations with an RMSD below 0.1 Å can be considered the same.

In [None]:
n_ts = len(opt_ts_xyzs)
rmsd_mat = np.ones((n_ts, n_ts)) * -2.5
for ii in range(n_ts):
    for jj in range(ii+1, n_ts):
        try:
            rmsd_mat[ii, jj] = np.log10(
                pymatgen_rmsd(
                    opt_ts_xyzs[ii],
                    opt_ts_xyzs[jj],
                    ignore_chirality=True,
                )
            )
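        except Exception as err:  # keep the default value if the RMSD cannot be computed
            print(f"RMSD failed for pair ({ii}, {jj}): {err}")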
        rmsd_mat[jj, ii] = rmsd_mat[ii, jj]



Cluster these transition states by structure using K-Means, with each row of the RMSD matrix serving as the feature vector of one transition state.

In [None]:
from sklearn.cluster import KMeans

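def reorder_matrix(matrix, n_clusters):
    # Cluster the rows of the (symmetric) RMSD matrix with K-Means;
    # random_state makes the cluster labels reproducible between runs
    row_clusters = KMeans(n_clusters=n_clusters, random_state=0).fit_predict(matrix)
    
    # Sort indices by cluster label; the matrix is symmetric,
    # so the same permutation is applied to rows and columns
    row_permutation = np.argsort(row_clusters)
    col_permutation = row_permutation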

    # Apply the permutation to the matrix
    reordered_matrix = matrix[row_permutation][:, col_permutation]

    return reordered_matrix, row_permutation, row_clusters


n = n_ts  # Number of overall transition states
n_clusters = 6  # The number of clusters in our K-Means

reordered_matrix, row_permutation, row_clusters = reorder_matrix(rmsd_mat, n_clusters)
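K-Means requires choosing `n_clusters` in advance (6 here). As an alternative that this notebook does not use, one could group transition states directly with the 0.1 Å criterion: treat any pair with RMSD below the cutoff as the same structure and take connected components. A minimal union-find sketch follows; the function name `threshold_groups` is our own. Since `rmsd_mat` stores log10 values, it would be passed as `10 ** rmsd_mat`.

```python
import numpy as np

def threshold_groups(rmsd, cutoff=0.1):
    """Label structures so that any pair with RMSD < cutoff shares a group
    (connected components, found with a small union-find)."""
    n = rmsd.shape[0]
    parent = list(range(n))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    for i in range(n):
        for j in range(i + 1, n):
            if rmsd[i, j] < cutoff:
                parent[find(i)] = find(j)  # merge the two groups

    roots = [find(i) for i in range(n)]
    label_of = {root: k for k, root in enumerate(dict.fromkeys(roots))}
    return np.array([label_of[root] for root in roots])

# Usage with a tiny hand-made RMSD matrix (in Angstrom)
rmsd = np.array([[0.00, 0.05, 1.20, 1.10],
                 [0.05, 0.00, 1.30, 1.00],
                 [1.20, 1.30, 0.00, 0.02],
                 [1.10, 1.00, 0.02, 0.00]])
labels = threshold_groups(rmsd, cutoff=0.1)   # groups {0, 1} and {2, 3}
```

Unlike K-Means, this chains structures together: if A is close to B and B to C, all three share a group even when A and C differ by more than the cutoff.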



Display the clustering results. Many transition state structures are clearly very similar to each other (each red block along the diagonal is one cluster).

Note: We have applied a log10 transformation to the RMSD values: "-2" corresponds to 0.01 Å, and "-1" corresponds to 0.1 Å.

In [None]:
from IPython.display import HTML

fig = px.imshow(
    reordered_matrix, 
    color_continuous_scale="Oryel_r",
    range_color=[-2, -0.3],
)
fig.layout.font.update({"size": 18, "family": "Arial"})

fig.layout.update({"width": 650, "height": 500})
# HTML(fig.to_html())
fig.show()


List which cluster (IDs '0' to '5') each transition state belongs to.

Note: The cluster containing `./demo/example-3/opt_ts/26.ts.opt.xyz` corresponds to the true transition state of the given {reactant, product}.

In [None]:
import json

cluster_dict = {}
for ii, cluster in enumerate(row_clusters):
    cluster = str(cluster)
    if cluster not in cluster_dict:
        cluster_dict[cluster] = [ind_dict[ii]]
    else:
        cluster_dict[cluster] += [ind_dict[ii]]

cluster_dict = OrderedDict(sorted(cluster_dict.items()))
cluster_dict


The geometries of transition states in different clusters are completely different, as the example below illustrates.

In [None]:
draw_in_3dmol(
    open("./demo/example-3/generated/gen_36_ts.xyz", "r").read(),
    "xyz"
)

### Are the remaining TSs worthless?

Of course not. Each transition state structure corresponds to a distinct chemical reaction. In computational studies, however, <span style="color: red;">**the chemical reactions that any one person can explore are limited**</span> by personal experience and by the computational methods available.

This limitation is particularly serious when studying unknown, complex reactions. It can cause us to overlook plausible reactions and misjudge the reaction mechanism, which in turn misguides the design of catalytic materials. [Dr. Qiyuan Zhao's article](https://www.nature.com/articles/s43588-021-00101-3) explores these phenomena in depth; interested readers can dive into it.

In our case, thanks to the stochastic nature of the generation process, OA-ReactDiff not only found the transition state of the reaction we wanted, but also explored five related "unintended" chemical reactions. This feature can <span style="color: orange;">**complement existing reaction-exploration frameworks that rely on chemical intuition**</span> 😄.


For example, the reaction we generated below is **an unintended new reaction**.

In [None]:
draw_reaction("./demo/example-3/irc", prefix="cluster", idx=1)


## Give me some atoms, and I'll generate all the reactions among them <a id ='exploration'></a>

In addition to searching for transition states of known reactions, OA-ReactDiff can directly and unconditionally <span style="color: orange;">**generate new chemical reactions, which can be used to explore reaction networks**</span>.

Here, we imagine a beaker containing only four atoms (one each of C, N, O, and H) and use OA-ReactDiff to explore the possible reactions among them.


### Randomly Generated Reactions

In [None]:
xyz_path = "./demo/CNOH/"
n_samples = 128  # Total number of reactions to generate
natm = 4  # Number of atoms in each structure
# One entry per structure in a reaction (reactant, transition state, product),
# each giving the atom count for every sample
fragments_nodes = [
    torch.tensor([natm] * n_samples, device=device),
    torch.tensor([natm] * n_samples, device=device),
    torch.tensor([natm] * n_samples, device=device),
]

conditions = torch.tensor([[0]] * n_samples, device=device)
h0 = assemble_sample_inputs(
    atoms=["C"] * 1 + ["O"] * 1 + ["N"] * 1 + ["H"] * 1,  # Atom types: one each of C, O, N, and H
    device=device,
    n_samples=n_samples,
    frag_type=False,
)

We strongly recommend running this step on a V100 GPU machine, where it takes about 30 seconds. Otherwise, it may take around 10 minutes, so you can make yourself a cup of 🍵 or do some push-ups.

In [None]:
out_samples, out_masks = ddpm_trainer.ddpm.sample(
    n_samples=n_samples,
    fragments_nodes=fragments_nodes,
    conditions=conditions,
    return_frames=1,
    timesteps=None,
    h0=h0,
)
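A rough mental model of what a DDPM sampler like `ddpm.sample` does: start from Gaussian noise and apply one denoising update per timestep, so each of the many steps needs a model evaluation (which is presumably why a GPU helps so much). The toy sketch below is a generic 1D DDPM ancestral sampler, not OA-ReactDiff's implementation. It assumes a linear β schedule with T = 1000, and it replaces the trained network with an exact ("oracle") noise predictor for a hypothetical data distribution whose only values are -1 and +1. All function names are hypothetical. Different random seeds land on different modes, the same kind of stochasticity that lets OA-ReactDiff return several distinct transition states and reactions.

```python
import math
import random
from typing import List

# Toy, hypothetical illustration of generic DDPM ancestral sampling (not OA-ReactDiff code).


def linear_betas(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> List[float]:
    """Assumed linear noise schedule beta_1..beta_T, stored 0-based."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]


def oracle_eps(x: float, alpha_bar: float) -> float:
    """Exact E[eps | x_t] when x_0 is -1 or +1 with equal probability.

    Here E[x_0 | x_t] = tanh(sqrt(alpha_bar) * x_t / (1 - alpha_bar)); a trained
    network would instead learn to predict eps from x_t.
    """
    s, v = math.sqrt(alpha_bar), 1.0 - alpha_bar
    x0_mean = math.tanh(s * x / v)
    return (x - s * x0_mean) / math.sqrt(v)


def ddpm_sample(seed: int, T: int = 1000) -> float:
    """Start from x_T ~ N(0, 1) and denoise one step at a time down to x_0."""
    rng = random.Random(seed)
    betas = linear_betas(T)
    alphas = [1.0 - b for b in betas]
    alpha_bars: List[float] = []
    prod = 1.0
    for a in alphas:
        prod *= a
        alpha_bars.append(prod)  # cumulative product alpha_bar_t

    x = rng.gauss(0.0, 1.0)  # pure Gaussian noise
    for t in reversed(range(T)):
        a, ab, b = alphas[t], alpha_bars[t], betas[t]
        eps = oracle_eps(x, ab)
        mean = (x - b / math.sqrt(1.0 - ab) * eps) / math.sqrt(a)
        if t > 0:
            # Posterior variance: (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
            var = (1.0 - alpha_bars[t - 1]) / (1.0 - ab) * b
            x = mean + math.sqrt(var) * rng.gauss(0.0, 1.0)
        else:
            x = mean  # the final step adds no noise
    return x


# Same "model", different seeds: samples fall near -1 or +1 depending on the random path
print([round(ddpm_sample(seed), 3) for seed in range(8)])
```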

Save the generated results as xyz files.

In [None]:
write_tmp_xyz(
    fragments_nodes, 
    out_samples[0], 
    idx=[0, 1, 2], 
    ex_ind=0,
    localpath=xyz_path,
)



Visualize one of the generated reactions (change `idx` to view another one).

In [None]:
idx = 2
assert idx < n_samples
views = draw_reaction(xyz_path, idx)
views



### Analyzing reaction results



#### Finding unique stable molecules

In [None]:
from glob import glob

from pymatgen.io.xyz import XYZ
from openbabel import pybel

from oa_reactdiff.analyze.rmsd import pymatgen_rmsd


def xyz_to_smiles(fname: str) -> str:
    """Convert molecules in xyz format to smiles format

    Args:
        fname (str): path to the xyz file.

    Returns:
        str: SMILES string.
    """
    mol = next(pybel.readfile("xyz", fname))
    smi = mol.write(format="can")
    return smi.split()[0].strip()


In [None]:
xyzfiles = glob(f"{xyz_path}/gen*_react.xyz") + glob(f"{xyz_path}/gen*_prod.xyz")
xyz_converter = XYZ(mol=None)
mol = xyz_converter.from_file(xyzfiles[0]).molecule
unique_mols = {xyzfiles[0]: mol}  # xyz path -> structure, seeded with the first file
for _xyzfile in xyzfiles:
    _mol = xyz_converter.from_file(_xyzfile).molecule
    # Smallest RMSD between this structure and any structure kept so far
    min_rmsd = min(
        pymatgen_rmsd(mol, _mol, ignore_chirality=True, threshold=0.5)
        for mol in unique_mols.values()
    )
    if min_rmsd > 0.1:  # More than 0.1 Å from every kept structure: treat it as a new molecule
        unique_mols[_xyzfile] = _mol

len(unique_mols)
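The deduplication above is a greedy filter: a structure is kept only if its RMSD to every structure already kept exceeds 0.1 Å. The sketch below reproduces that logic with a plain Kabsch (SVD-based) RMSD in NumPy. It is an illustration, not what `pymatgen_rmsd` does internally: it assumes both structures list the same atoms in the same order (no atom-permutation matching), and the `allow_reflection` flag is our hypothetical analogue of `ignore_chirality=True`, letting mirror images count as identical. The names `kabsch_rmsd` and `greedy_unique` are hypothetical.

```python
from typing import List, Sequence

import numpy as np


def kabsch_rmsd(P, Q, allow_reflection: bool = False) -> float:
    """Hypothetical helper: RMSD of two (N, 3) arrays after optimal superposition.

    Assumes atom i of P corresponds to atom i of Q.
    """
    P = np.asarray(P, dtype=float)
    Q = np.asarray(Q, dtype=float)
    P = P - P.mean(axis=0)  # remove translation
    Q = Q - Q.mean(axis=0)
    U, _, Vt = np.linalg.svd(P.T @ Q)  # SVD of the 3x3 covariance matrix
    D = np.eye(3)
    if not allow_reflection and np.linalg.det(Vt.T @ U.T) < 0:
        D[2, 2] = -1.0  # force a proper rotation (det = +1)
    R = Vt.T @ D @ U.T  # optimal rotation mapping P onto Q
    diff = P @ R.T - Q
    return float(np.sqrt((diff ** 2).sum(axis=1).mean()))


def greedy_unique(structures: Sequence, threshold: float = 0.1, **kwargs) -> List[int]:
    """Hypothetical helper: indices kept because they differ from all earlier kept ones."""
    kept: List[int] = []
    for i, X in enumerate(structures):
        if all(kabsch_rmsd(X, structures[j], **kwargs) > threshold for j in kept):
            kept.append(i)
    return kept


# Toy usage: a rotated + translated copy is a duplicate, a stretched copy is not
base = np.array([[0.0, 0.0, 0.0], [1.2, 0.0, 0.0], [0.0, 1.1, 0.0], [0.0, 0.0, 0.9]])
theta = np.pi / 6
Rz = np.array(
    [[np.cos(theta), -np.sin(theta), 0.0], [np.sin(theta), np.cos(theta), 0.0], [0.0, 0.0, 1.0]]
)
moved = base @ Rz.T + np.array([3.0, -1.0, 2.0])
stretched = base * 1.5
print(greedy_unique([base, moved, stretched]))
```

Applying the filter greedily means the result depends on file order, which matches the original cell's behavior of seeding `unique_mols` with the first file.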


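Only consider stable molecules here: we drop any structure whose SMILES contains a "." (i.e. it has dissociated into separate fragments, such as radical pairs) and remove duplicate SMILES.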

In [None]:
unique_idx = []  # positions (in unique_mols) of the unique stable molecules
unique_smiles = []
for idx, file in enumerate(unique_mols):
    smi = xyz_to_smiles(file)
    # Skip duplicate SMILES and multi-fragment structures ("." separator)
    if smi not in unique_smiles and "." not in smi:
        unique_smiles.append(smi)
        unique_idx.append(idx)
unique_idx, unique_smiles  # Indices into unique_mols and the corresponding SMILES



Try listing all the stable four-atom molecules made of one C, one N, one O, and one H yourself, and check whether the molecules generated here cover all of them 😁

Now let's draw one of them.

In [None]:
idx = np.random.choice(unique_idx, 1)[0]

draw_in_3dmol(open(list(unique_mols.keys())[idx], "r").read(), "xyz")
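One way to check coverage programmatically is to compare the generated SMILES against a reference list. The four entries below (isocyanic, cyanic, fulminic, and isofulminic acid) are the most commonly discussed HCNO isomers; the list and the helper `coverage` are our own illustration. Because one molecule can be written as several SMILES strings, both lists must be canonicalized with the same toolkit before comparing; in this notebook that could be Open Babel, e.g. `lambda s: pybel.readstring("smi", s).write("can").split()[0]`. Bond perception from xyz coordinates may also fail to recover formal charges (as in fulminic acid), so treat a "missing" entry as a prompt to inspect the structures rather than as a verdict.

```python
from typing import Callable, Dict, Iterable, Set

# Hypothetical reference list of the four classic HCNO isomers
REFERENCE_HCNO: Dict[str, str] = {
    "isocyanic acid": "N=C=O",
    "cyanic acid": "OC#N",
    "fulminic acid": "C#[N+][O-]",
    "isofulminic acid": "[C-]#[N+]O",
}


def coverage(
    generated: Iterable[str],
    reference: Dict[str, str],
    canonicalize: Callable[[str], str] = lambda s: s,
) -> Dict[str, Set[str]]:
    """Hypothetical helper: split reference molecules into found and missing."""
    seen = {canonicalize(s) for s in generated}
    found = {name for name, smi in reference.items() if canonicalize(smi) in seen}
    return {"found": found, "missing": set(reference) - found}


# Toy usage with already-matching strings; in the notebook, pass unique_smiles
print(coverage(["N=C=O", "OC#N"], REFERENCE_HCNO))
```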



#### Unique reaction pathways

We only analyze reactions in which both the reactant and the product are among the unique stable molecules identified above.

In [None]:
unique_paths = {}  # sample index -> {reactant SMILES, product SMILES} for each new pathway
path_index = {}  # "smiA & smiB" -> all sample indices that realize this pathway
for ii in range(n_samples):
    r_xyz = f"{xyz_path}/gen_{ii}_react.xyz"
    p_xyz = f"{xyz_path}/gen_{ii}_prod.xyz"
    path = {xyz_to_smiles(r_xyz), xyz_to_smiles(p_xyz)}
    # Keep only genuine reactions (reactant != product) between stable molecules
    if len(path) < 2 or not path.issubset(unique_smiles):
        continue
    if path not in unique_paths.values():
        unique_paths[ii] = path
    sorted_smi = " & ".join(sorted(path))
    path_index.setdefault(sorted_smi, []).append(ii)

# Collect every molecule that appears in at least one pathway
mols_in_paths = []
for path in unique_paths.values():
    for smi in path:
        if smi not in mols_in_paths:
            mols_in_paths.append(smi)

mols_in_paths, len(mols_in_paths)


All unique reactions among the stable four-atom CNOH molecules are listed below:

In [None]:
unique_paths, len(unique_paths)
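The unique pathways can be read as the edges of a small reaction network. A minimal sketch, assuming each pathway is a two-element set of SMILES like the values of `unique_paths`: it builds an undirected adjacency map and uses breadth-first search to find every species reachable from a starting molecule. The functions `build_network` and `reachable` and the placeholder species names are hypothetical.

```python
from collections import deque
from typing import Dict, Iterable, Set


def build_network(pathways: Iterable[Set[str]]) -> Dict[str, Set[str]]:
    """Hypothetical helper: undirected adjacency map, each pathway {A, B} links A and B."""
    graph: Dict[str, Set[str]] = {}
    for path in pathways:
        a, b = sorted(path)  # assumes exactly two species per pathway
        graph.setdefault(a, set()).add(b)
        graph.setdefault(b, set()).add(a)
    return graph


def reachable(graph: Dict[str, Set[str]], start: str) -> Set[str]:
    """Hypothetical helper: all species connected to `start` via one or more reactions."""
    seen = {start}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for nxt in graph.get(node, ()):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen


# Toy pathways with placeholder species; in the notebook, use unique_paths.values()
toy_paths = [{"A", "B"}, {"B", "C"}, {"D", "E"}]
net = build_network(toy_paths)
print(sorted(reachable(net, "A")))
```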


Select one at random and check whether the reaction is chemically reasonable 😄

In [None]:
idx = np.random.choice(list(unique_paths.keys()), 1)[0]
views = draw_reaction(xyz_path, idx)
views



## Outlook: Let's "Generate" the Future <a id ='outlook'></a>
<div class="alert alert-success">
    <b>🎖️Congratulations! You have finished reading this notebook.</b>
</div>
    
A major challenge in discovering new chemicals and materials is the enormous potential design space (on the order of $10^{60}$ candidates). This space is so large that even if we bought every NVIDIA graphics card and exhausted the global power supply, we still could not explore it exhaustively. Generative AI <span style="color: orange;">**sidesteps the need to screen this vast space**</span> and directly generates potentially valuable molecules and materials, opening new possibilities for materials discovery.
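A back-of-the-envelope check of why brute-force screening is hopeless. The throughput of $10^{12}$ candidate evaluations per second is a hypothetical, wildly optimistic figure chosen only for illustration; even so, enumerating $10^{60}$ candidates would take about $3 \times 10^{40}$ years, compared with an age of the universe of roughly $1.4 \times 10^{10}$ years.

```python
SPACE_SIZE = 10**60  # design-space size quoted above
EVALS_PER_SECOND = 10**12  # hypothetical, extremely optimistic throughput
SECONDS_PER_YEAR = 365.25 * 24 * 3600
AGE_OF_UNIVERSE_YEARS = 1.38e10

years_needed = SPACE_SIZE / EVALS_PER_SECOND / SECONDS_PER_YEAR
ratio = years_needed / AGE_OF_UNIVERSE_YEARS
print(f"{years_needed:.1e} years, about {ratio:.0e} times the age of the universe")
```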

Diffusion-based generative modeling for chemistry and materials is still in its infancy. If you have ideas about applying OA-ReactDiff to your own problems but are unsure whether it is suitable or where to start, please feel free to contact the author at duanchenru@gmail.com 👏🏻

