# Score-based diffusion model for molecule generation.

# 1. Introduction to Score-Based Diffusion Modeling through Stochastic Differential Equations.

Diffusion models are a class of generative models that have attracted wide attention in recent years for producing high-quality samples. A forward diffusion process gradually corrupts data with noise, and the model learns to reverse it, turning a simple noise distribution into a complex data distribution through step-by-step denoising. Two classic families are Score Matching with Langevin Dynamics (SMLD) and Denoising Diffusion Probabilistic Models (DDPM). Score-based generative models (SGMs) formulated with stochastic differential equations unify both: SMLD corresponds to the Variance Exploding SDE (VESDE) and DDPM to the Variance Preserving SDE (VPSDE). Here we use the SDE formulation, specifically the VESDE.

### *Variance Exploding Stochastic Differential Equation (VESDE) Diffusion*

Variance Exploding Stochastic Differential Equations (VESDE) diffusion models are a class of score-based generative models. These models leverage the principles of stochastic processes, particularly stochastic differential equations (SDEs), to model the evolution of data points over continuous time. (Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." arXiv preprint arXiv:2011.13456 (2020).)

### *Core Concept*

At the heart of VESDE diffusion is a forward process that gradually perturbs data with noise according to a stochastic differential equation. Under this process the variance of the perturbed data grows without bound (it "explodes") as time progresses. This is the key feature that distinguishes the VESDE from other SDE-based models such as the Variance Preserving SDE (VPSDE), whose variance stays bounded.

### *Mathematical Formulation*

The forward process of an SDE-based diffusion model, including the VESDE, is an Itô SDE of the general form:

$$dx_{t} = f(x_t, t)dt + g(t)dW_t$$

where:

- $x_t$ represents the data point at time $t$
- $f(x_t, t)$ is the drift term that dictates the deterministic part of the evolution.
- $g(t)$ is the diffusion term that controls the magnitude of the stochastic component.
- $W_t$ is a standard Wiener process (also known as Brownian motion).
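For the VESDE specifically, Song et al. (2020) set the drift to zero and tie the diffusion coefficient to an increasing noise scale $\sigma(t)$, which gives a Gaussian perturbation kernel in closed form. It is shown here with the geometric schedule from that paper; this project's `DiffSchedule` is not shown, so whether it uses the same schedule is an assumption.

$$f(x_t, t) = 0, \qquad g(t) = \sqrt{\frac{\mathrm{d}\left[\sigma^2(t)\right]}{\mathrm{d}t}}, \qquad \sigma(t) = \sigma_{\min}\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right)^{t},\quad t \in [0, 1]$$

$$p_{0t}(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ x_0,\ \left[\sigma^2(t) - \sigma^2(0)\right] I\right)$$

With $\sigma_{\max} \gg \sigma_{\min}$, the added variance grows from nearly zero to roughly $\sigma_{\max}^2$, which is the "exploding" behavior described above. When $\sigma_{\max}$ is large compared with the spread of the data, $x_1$ is approximately distributed as $\mathcal{N}(0, \sigma_{\max}^2 I)$.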

### *Training and Sampling*

Training a VESDE model amounts to learning to reverse this diffusion process. The network is trained to denoise corrupted samples (denoising score matching), which teaches it the score $\nabla_x \log p_t(x)$ of the perturbed data distribution at every time $t$. With this score, the reverse process is described by a reverse-time SDE that can be solved numerically, for example with the Euler-Maruyama method.
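Concretely, Song et al. (2020) give the reverse-time SDE below, where $\bar{W}_t$ is a Wiener process running backward in time and $\mathrm{d}t$ is a negative increment. Replacing the unknown score with the network's estimate $s_\theta(x, t)$ and taking one Euler-Maruyama step of size $\Delta t > 0$ backward, with the VESDE's zero drift, yields the update that the `sample` method in this section's code implements:

$$\mathrm{d}x = \left[f(x, t) - g(t)^2\,\nabla_x \log p_t(x)\right]\mathrm{d}t + g(t)\,\mathrm{d}\bar{W}_t$$

$$x_{t-\Delta t} = x_t + g(t)^2\, s_\theta(x_t, t)\,\Delta t + g(t)\sqrt{\Delta t}\; z, \qquad z \sim \mathcal{N}(0, I)$$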

Once trained, sampling from a VESDE model involves simulating the reverse-time SDE starting from a noise distribution (typically Gaussian noise) and evolving it backward in time to obtain a sample from the learned data distribution.
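To see the sampler work end to end without a trained network, the sketch below runs the same Euler-Maruyama loop on a one-dimensional toy problem where the score is known exactly. It assumes data drawn from $\mathcal{N}(0, 0.5^2)$ and the simple VESDE $\mathrm{d}x = \sigma^t\,\mathrm{d}W$ with $\sigma = 25$ (an illustrative choice, not necessarily this project's schedule). `marginal_prob_std`, `diffusion_coeff`, and `toy_score` here are illustrative stand-ins, not the project's own functions.

```python
import numpy as np

SIGMA = 25.0  # assumed noise-scale hyperparameter of the toy schedule


def marginal_prob_std(t):
    # Std of x_t - x_0 for the toy VESDE dx = SIGMA**t dW.
    return np.sqrt((SIGMA ** (2 * t) - 1.0) / (2.0 * np.log(SIGMA)))


def diffusion_coeff(t):
    # g(t) = SIGMA**t, so g(t)**2 = d/dt [marginal_prob_std(t)**2].
    return SIGMA ** t


def toy_score(x, t, data_std):
    # Exact score when the data are N(0, data_std**2):
    # p_t = N(0, data_std**2 + marginal_prob_std(t)**2).
    var = data_std ** 2 + marginal_prob_std(t) ** 2
    return -x / var


def euler_maruyama_sample(n, data_std, num_steps=1000, eps=1e-3, seed=0):
    rng = np.random.default_rng(seed)
    # Start from the prior N(0, marginal_prob_std(1)**2) at t = 1.
    x = rng.standard_normal(n) * marginal_prob_std(1.0)
    time_steps = np.linspace(1.0, eps, num_steps)  # decreasing times
    dt = time_steps[0] - time_steps[1]  # uniform, positive step size
    mean_x = x
    for t in time_steps:
        g = diffusion_coeff(t)
        # Reverse-time step: move along the score, then add fresh noise.
        mean_x = x + g ** 2 * toy_score(x, t, data_std) * dt
        x = mean_x + np.sqrt(dt) * g * rng.standard_normal(n)
    # Like VESDE.sample, return the noise-free mean of the last step.
    return mean_x


samples = euler_maruyama_sample(20_000, data_std=0.5)
print(samples.std())  # should be close to 0.5
```

Because the samples should follow the data distribution, their standard deviation should come out close to 0.5; increasing `num_steps` shrinks the remaining discretization error.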

### *Molecule Generation*

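Atomic coordinates live in a continuous space, so we use the diffusion model to generate three-dimensional molecular conformations. The atom types (`atomic_numbers`) are given as input and only the positions are diffused; the noise is drawn with zero center of mass, so the model never has to represent a translation of the whole molecule.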

Here is the main code for the score-based model:

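```python
from typing import List, Tuple

import torch

# DiffSchedule and the helpers used below (get_batch_mask_for_nodes,
# get_full_edges_index, sample_center_gravity_zero_gaussian_batch,
# remove_mean_batch) are project helpers defined outside this excerpt.


class VESDE(torch.nn.Module):
    def __init__(
        self,
        score_model: torch.nn.Module,
        schedule: DiffSchedule,
    ) -> None: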
        super().__init__()
        self.score_model = score_model
        self.schedule = schedule
        
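    def forward(self, pos, atomic_numbers, mask):
        # Training step: perturb the positions at the diffusion times returned by
        # the schedule (one per molecule) and return the denoising score-matching loss.
        t = self.schedule(mask)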
        pos_dim = pos.shape[-1]
        nodes_mask = get_batch_mask_for_nodes(mask)
        edge_index = get_full_edges_index(nodes_mask, remove_self_edge=False)

        # sample zero CoM noise
        noise = sample_center_gravity_zero_gaussian_batch(
            (pos.shape[0], pos_dim), nodes_mask
        )
        nodes_t = t[nodes_mask]
        std = self.schedule.marginal_prob_std(nodes_t)
        perturbed_pos = pos + noise * std[:, None]
        
        # compute score
        score = self.score_model(atomic_numbers, nodes_t, perturbed_pos, edge_index)
        score = score / std[:, None] # normalize score
        if torch.any(torch.isnan(score)):
            print('nan in score, resetting to randn')
            score = torch.randn_like(score, requires_grad=True)
        score = remove_mean_batch(score, nodes_mask)
        
        # Denoising score matching: score * std should predict -noise.
        l2loss = torch.mean(torch.sum((score * std[:, None] + noise)**2, dim=-1))
        return l2loss

    @torch.no_grad()
    def sample(
        self,
        atomic_numbers: torch.Tensor,
        mask: torch.Tensor,
        num_steps: int=500,
        t_mode: str='linear',
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        '''
        Sample a batch of molecules with the Euler-Maruyama sampler and return
        the final positions together with the sampling trajectory.
        '''
        device = atomic_numbers.device
        pos_shape = [atomic_numbers.size(0), 3]
        t = torch.ones(len(atomic_numbers), device=atomic_numbers.device)
        nodes_mask = get_batch_mask_for_nodes(mask)
        
        # sample zero CoM noise as initial position
        init_com = sample_center_gravity_zero_gaussian_batch(
            (pos_shape[0], pos_shape[1]), nodes_mask
        )
        init_pos = init_com * self.schedule.marginal_prob_std(t)[:, None]
        edge_index = get_full_edges_index(nodes_mask, remove_self_edge=False)
        num_steps = torch.tensor(num_steps, device=device)
        time_steps = self.schedule.sample_t(num_steps, mode=t_mode)
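        # Assuming time_steps decreases toward 0: step_sizes[i] = time_steps[i] - time_steps[i + 1],
        # and the last step runs from time_steps[-1] down to 0.
        step_sizes = torch.cat((-torch.diff(time_steps), time_steps[-1].unsqueeze(0)))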
        
        # sample batch of mols
        pos = init_pos
        trajs = []
        for time_step, step_size in zip(time_steps, step_sizes):
            batch_time_step = torch.ones(pos.size(0), device=device) * time_step
            g = self.schedule.diffusion_coeff(batch_time_step)
            score = self.score_model(atomic_numbers, batch_time_step, pos, edge_index)
            
            # normalize score
            if torch.any(torch.isnan(score)):
                print('nan in score, resetting to randn')
                score = torch.randn_like(score)
            score = score / self.schedule.marginal_prob_std(batch_time_step)[:, None]
            score = remove_mean_batch(score, nodes_mask)
            # Euler-Maruyama step of the reverse-time SDE (the VESDE has zero drift)
            mean_pos = pos + (g**2)[:, None] * score * step_size
            noise = sample_center_gravity_zero_gaussian_batch(
                (pos_shape[0], pos_shape[1]), nodes_mask
            )
            pos = mean_pos + torch.sqrt(step_size) * g[:, None] * noise
            trajs.append(mean_pos)
        return mean_pos, trajs
```
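The class relies on project helpers whose source is not shown here. Assuming they do what their names suggest, the center-of-mass handling can be sketched with NumPy: a hypothetical `remove_mean_batch` subtracts each molecule's mean position (an unweighted centroid), and a hypothetical `sample_center_gravity_zero_gaussian_batch` draws Gaussian noise and projects it onto the zero-mean subspace. Unlike the original, which passes `nodes_mask`, this sketch assumes a per-atom molecule index `batch`. Keeping both the noise and the score mean-free removes the translational degree of freedom.

```python
import numpy as np


def remove_mean_batch(x, batch):
    """Subtract each molecule's mean position from its atoms (hypothetical re-implementation).

    x: (num_atoms, 3) array; batch: (num_atoms,) molecule index of every atom.
    """
    num_mols = batch.max() + 1
    sums = np.zeros((num_mols, x.shape[1]))
    np.add.at(sums, batch, x)  # per-molecule coordinate sums
    counts = np.bincount(batch, minlength=num_mols)[:, None]
    return x - (sums / counts)[batch]


def sample_center_gravity_zero_gaussian_batch(shape, batch, rng):
    """Gaussian noise whose per-molecule mean is zero (hypothetical re-implementation)."""
    noise = rng.standard_normal(shape)
    return remove_mean_batch(noise, batch)


rng = np.random.default_rng(0)
batch = np.repeat(np.arange(4), 13)  # 4 molecules of 13 atoms, as in the C4H7NO demo
noise = sample_center_gravity_zero_gaussian_batch((batch.size, 3), batch, rng)
```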

# 2. Equivariant Graph Neural Networks.

Equivariant Graph Neural Networks (EGNNs) are graph neural networks designed to be equivariant under transformations of their nodes' 3D coordinates, such as rotations, translations, and reflections. If the input is transformed, vector-valued outputs (for example, a force or a score on each atom) undergo the same transformation, while scalar outputs (for example, an energy) stay invariant. This property is especially useful for tasks with spatial symmetries, such as molecular modeling and physical simulation, because the symmetry is built into the model instead of having to be learned from data.
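The definition can be checked numerically. The sketch below uses a toy coordinate update, not PaiNN: each atom moves along the relative vectors $x_i - x_j$, weighted by a function of invariant quantities only (squared distances and per-atom scalar features), which makes the update equivariant to rotations and translations. It compares "apply, then transform" with "transform, then apply" using the maximum absolute difference, since a signed maximum could hide large negative deviations.

```python
import numpy as np


def toy_equivariant_layer(pos, h):
    """x_i' = x_i + sum_j (x_i - x_j) * phi(|x_i - x_j|^2, h_i, h_j)."""
    diff = pos[:, None, :] - pos[None, :, :]  # (N, N, 3) relative vectors
    dist2 = np.sum(diff ** 2, axis=-1)  # (N, N) squared distances (invariant)
    phi = np.tanh(h[:, None] * h[None, :]) * np.exp(-dist2)  # invariant edge weights
    return pos + np.sum(diff * phi[..., None], axis=1)


def random_rotation(rng):
    # QR of a Gaussian matrix gives a random orthogonal matrix; fix the column
    # signs, then flip one column if needed so that the determinant is +1.
    q, r = np.linalg.qr(rng.standard_normal((3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


rng = np.random.default_rng(0)
pos = rng.standard_normal((9, 3))
h = rng.standard_normal(9)  # per-atom scalar features (stand-in for atom types)
rot = random_rotation(rng)
shift = rng.standard_normal(3)

out_then_transform = toy_equivariant_layer(pos, h) @ rot + shift
transform_then_out = toy_equivariant_layer(pos @ rot + shift, h)
print(np.abs(out_then_transform - transform_then_out).max())  # floating-point round-off only
```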

PaiNN (Polarizable Atom Interaction Neural Network) is an equivariant graph neural network designed for molecular modeling and property prediction. Its design draws on physical interaction models: each atom carries both scalar (rotation-invariant) and vector (rotation-equivariant) features, which lets the network predict tensorial properties of molecules as well as scalar ones. (Schütt, Kristof, Oliver Unke, and Michael Gastegger. "Equivariant message passing for the prediction of tensorial properties and molecular spectra." International Conference on Machine Learning. PMLR, 2021.)

The diffusion model needs a neural network that predicts the noise (equivalently, the score) on each atom. Molecules are naturally represented as graphs, and the predicted score must rotate together with the input coordinates, so we use the equivariant PaiNN as the score network.
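Fitting the noise and fitting the score are the same task here. Let $\bar{\sigma}_t$ denote the marginal standard deviation at time $t$ (what `marginal_prob_std` returns in the code earlier in this section) and $\epsilon \sim \mathcal{N}(0, I)$. The conditional score then has a closed form, so denoising score matching reduces to predicting $-\epsilon$. The `l2loss` in `VESDE.forward` has this form, averaged over atoms:

$$\tilde{x} = x + \bar{\sigma}_t\,\epsilon, \qquad \nabla_{\tilde{x}} \log p_t(\tilde{x} \mid x) = -\frac{\tilde{x} - x}{\bar{\sigma}_t^2} = -\frac{\epsilon}{\bar{\sigma}_t}$$

$$\mathcal{L}(\theta) = \mathbb{E}_{t,\,x,\,\epsilon}\left[\left\lVert \bar{\sigma}_t\, s_\theta(\tilde{x}, t) + \epsilon \right\rVert_2^2\right]$$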

Let's verify the equivariance of PaiNN:

```python
import torch
from e3nn import o3
from torch_geometric.nn import radius_graph

from model.painn import PaiNN

model = PaiNN()
atomic_numbers = torch.tensor([1, 6, 6, 1, 1, 1, 1, 1, 1], dtype=torch.long) # H, C, C, H, H, H, H, H, H
t = torch.tensor([0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4], dtype=torch.float) # diffusion time, the same for every atom
pos = torch.tensor([[ 0.0072, -0.5687,  0.0000],
    [-1.2854,  0.2499,  0.0000],
    [ 1.1304,  0.3147,  0.0000],
    [ 0.0392, -1.1972,  0.8900],
    [ 0.0392, -1.1972, -0.8900],
    [-1.3175,  0.8784,  0.8900],
    [-1.3175,  0.8784, -0.8900],
    [-2.1422, -0.4239,  0.0000],
    [ 1.9857, -0.1365,  0.0000]], dtype=torch.float)

# Interatomic distances are rotation invariant, so the same graph serves both inputs.
edge_index = radius_graph(pos, r=1.70, batch=None, loop=False)
rot = o3.rand_matrix()  # random 3x3 rotation matrix
pos_rot = pos @ rot
out_rot = model(atomic_numbers, t, pos, edge_index) @ rot  # rotate the output
out = model(atomic_numbers, t, pos_rot, edge_index)  # output for the rotated input
print('The molecular input to the model rotates, and the output should rotate simultaneously. \nSo this number should be zero:', (out - out_rot).max())
```

```
The molecular input to the model rotates, and the output should rotate simultaneously. 
So this number should be zero: tensor(1.0687e-11, grad_fn=<MaxBackward1>)
```


```python
# print the architecture of PaiNN
print(model)
```

```
PaiNN(
  (atom_emb): AtomEmbedding(
    (embeddings): Embedding(83, 512)
  )
  (t_emb): GaussianFourierProjection()
  (embedding): Linear(in_features=1024, out_features=512, bias=True)
  (radial_basis): RadialBasis(
    (envelope): PolynomialEnvelope()
    (rbf): GaussianSmearing()
  )
  (message_layers): ModuleList(
    (0-5): 6 x PaiNNMessage()
  )
  (update_layers): ModuleList(
    (0-5): 6 x PaiNNUpdate(
      (vec_proj): Linear(in_features=512, out_features=1024, bias=False)
      (xvec_proj): Sequential(
        (0): Linear(in_features=1024, out_features=512, bias=True)
        (1): ScaledSiLU(
          (_activation): SiLU()
        )
        (2): Linear(in_features=512, out_features=1536, bias=True)
      )
    )
  )
  (out_xh): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): ScaledSiLU(
      (_activation): SiLU()
    )
    (2): Linear(in_features=256, out_features=8, bias=True)
  )
  (out_dpos): PaiNNOutput(
    (output_network): ModuleList(
```

# 3. Results

Load the trained PyTorch Lightning model from its checkpoint.

```python
from model.pl import pl_module

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lightning_module = pl_module.load_from_checkpoint(
    checkpoint_path="checkpoint.ckpt",
    map_location=device,
)
```

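Define the composition of the molecules to be generated, C4H7NO. Because the graph neural network is permutation-equivariant (atoms are treated as an unordered set, and reordering them only reorders the per-atom outputs), the order of the atoms does not matter.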

```python
# Atomic number 6 is carbon, 1 is hydrogen, 7 is nitrogen, 8 is oxygen
C4 = torch.tensor([6, 6, 6, 6], dtype=torch.long, device=device)
H7 = torch.tensor([1, 1, 1, 1, 1, 1, 1], dtype=torch.long, device=device)
N1 = torch.tensor([7], dtype=torch.long, device=device)
O1 = torch.tensor([8], dtype=torch.long, device=device)

atomic_numbers = torch.cat([C4, H7, N1, O1])
assert len(atomic_numbers) == 13
```
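Permutation equivariance can be checked the same way as rotation equivariance. The sketch below uses a toy fully connected message-passing function, not PaiNN, with atomic numbers as the only atom features: shuffling the input atoms shuffles the per-atom outputs in exactly the same way, so the order in which the C4H7NO atoms are listed does not change the result.

```python
import numpy as np


def toy_message_passing(pos, z):
    """Per-atom vector output from an order-free sum over all other atoms (toy model)."""
    diff = pos[:, None, :] - pos[None, :, :]  # (N, N, 3) relative vectors
    weight = np.exp(-np.sum(diff ** 2, axis=-1)) * (z[:, None] + z[None, :])
    return np.sum(diff * weight[..., None], axis=1)  # sum over neighbors j


rng = np.random.default_rng(0)
z = np.array([6, 6, 6, 6, 1, 1, 1, 1, 1, 1, 1, 7, 8])  # C4H7NO
pos = rng.standard_normal((13, 3))
perm = rng.permutation(13)

out = toy_message_passing(pos, z)
out_perm = toy_message_passing(pos[perm], z[perm])
print(np.allclose(out_perm, out[perm]))  # True: outputs follow the atoms
```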

```python
# Predefine 4 molecules of identical composition
atomic_numbers = atomic_numbers.repeat(4, 1).reshape(-1)
# mask holds the number of atoms in each molecule of the batch
mask = torch.tensor([13, 13, 13, 13], dtype=torch.long, device=device)
assert len(atomic_numbers) == 52
```
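The batch is described only by `mask`, the number of atoms in each molecule. The project's `get_batch_mask_for_nodes` and `get_full_edges_index` are not shown; assuming they build a per-atom molecule index and a fully connected graph within each molecule (self-edges kept when `remove_self_edge=False`), a NumPy sketch with hypothetical names `batch_index_from_counts` and `full_edge_index` looks like this:

```python
import numpy as np


def batch_index_from_counts(mask):
    """Molecule index of every atom, e.g. [2, 3] -> [0, 0, 1, 1, 1] (hypothetical helper)."""
    return np.repeat(np.arange(len(mask)), mask)


def full_edge_index(batch, remove_self_edge=False):
    """All (src, dst) atom pairs within the same molecule, as a (2, E) array."""
    same_mol = batch[:, None] == batch[None, :]
    if remove_self_edge:
        np.fill_diagonal(same_mol, False)
    src, dst = np.nonzero(same_mol)
    return np.stack([src, dst])


mask = np.array([13, 13, 13, 13])
batch = batch_index_from_counts(mask)
edge_index = full_edge_index(batch)
print(batch.shape, edge_index.shape)  # (52,) (2, 676)
```

Edges never cross molecules, so messages stay inside each molecule. The dense comparison costs O(N²) in the total number of atoms in the batch, which is fine for small demo batches like this one.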

Sample molecules from Gaussian noise by integrating the reverse-time SDE with the Euler-Maruyama sampler.

```python
from datetime import datetime
import os
from model.io import write_batch_xyz

mols_pos, trajs = lightning_module.en_diffusion.sample(
    atomic_numbers,
    mask,
    num_steps=500,
    t_mode='cosine',
)

time_point = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
save_dir = f'demo/{time_point}'
os.makedirs(save_dir, exist_ok=True)
write_batch_xyz(save_dir, atomic_numbers, mols_pos, mask)
print(f"saved samples to {save_dir}")
```

```
saved samples to demo/2024-05-31_21-10-08
```
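The `write_batch_xyz` helper from `model.io` is not shown. The XYZ format it presumably writes is plain text: an atom-count line, a comment (title) line, then one `symbol x y z` line per atom. Open Babel reads such files and perceives bonds from interatomic distances, which is how the SMILES strings in the next step are obtained. Below is a minimal sketch of a writer with the same `mol_{i}.xyz` file naming, assuming that layout; `write_xyz_files` and `SYMBOLS` are hypothetical names.

```python
import os
import tempfile

import numpy as np

# Element symbols for the atomic numbers used in this demo only (hypothetical table).
SYMBOLS = {1: "H", 6: "C", 7: "N", 8: "O"}


def write_xyz_files(save_dir, atomic_numbers, positions, mask):
    """Write mol_0.xyz, mol_1.xyz, ... into save_dir, one file per molecule.

    A hypothetical stand-in for write_batch_xyz: atomic_numbers is (N,),
    positions is (N, 3), and mask lists the atom count of each molecule.
    """
    paths = []
    start = 0
    for i, n in enumerate(mask):
        n = int(n)
        path = os.path.join(save_dir, f"mol_{i}.xyz")
        with open(path, "w") as f:
            f.write(f"{n}\n")      # line 1: number of atoms
            f.write(f"mol_{i}\n")  # line 2: free-form comment (the title)
            for num, xyz in zip(atomic_numbers[start:start + n], positions[start:start + n]):
                coords = " ".join(f"{c:.6f}" for c in xyz)
                f.write(f"{SYMBOLS[int(num)]} {coords}\n")
        paths.append(path)
        start += n
    return paths


# Two toy molecules (CH2O and CH4) with random coordinates, only to exercise the format.
save_dir = tempfile.mkdtemp()
atomic_numbers = np.array([6, 8, 1, 1, 6, 1, 1, 1, 1])
positions = np.random.default_rng(0).standard_normal((9, 3))
paths = write_xyz_files(save_dir, atomic_numbers, positions, [4, 5])
```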


Open Babel perceives bonds from the generated XYZ coordinates and converts each molecule to a SMILES string.

```python
import subprocess

smis = []
for i in range(4):
    # obabel prints the SMILES string, a tab, and the molecule title
    smi = subprocess.run(
        ['obabel', f'./demo/{time_point}/mol_{i}.xyz', '-osmi'],
        capture_output=True,
        text=True,
    ).stdout
    smis.append(smi.split('\t')[0].strip())
```

```python
smis
```

```
['[C]12CN[C@H]1[CH]2.O', 'C1C=C(CO1)N', '[C](C(=O)[NH])(C)C', 'C1CNCC1=O']
```

Visualize the generated molecules.

```python
from model.visualize import draw_mol

for i in range(4):
    draw_mol(f'./demo/{time_point}/mol_{i}.xyz')
```

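The samples are diverse: all four match the requested C4H7NO composition but differ in structure. Two of them, `C1C=C(CO1)N` and `C1CNCC1=O`, are valid closed-shell molecules that satisfy the Lewis octet rule. The other two, as perceived by Open Babel from the XYZ coordinates, contain radical centers (`[C]`, `[CH]`, `[NH]`), and the first is split into a bicyclic fragment plus a separate water molecule, so not every sample is chemically valid yet.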