# Score-based diffusion model for molecule generation

First, load the trained PyTorch Lightning model.

In [1]:
import torch
from model import pl_module

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Restore the trained LightningModule from its checkpoint. The weights are
# mapped onto `device` (rather than unconditionally onto the CPU) because the
# input tensors in the following cells are created with `device=device`;
# keeping the model on the CPU would cause a device mismatch on GPU machines.
lightning_module = pl_module.load_from_checkpoint(
    checkpoint_path="tb_logs/painn/39929/checkpoints/scorenet-epoch=399-avg_val_loss=0.714.ckpt",
    map_location=device,
)
# Move the module explicitly as well, so that its parameters and buffers end
# up on `device` regardless of how the checkpoint loader applies map_location.
lightning_module = lightning_module.to(device)
# print architecture of the lightning model
print(lightning_module)

  return torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count


pl_module(
  (diffschedule): DiffSchedule()
  (score_model): PaiNN(
    (atom_emb): AtomEmbedding(
      (embeddings): Embedding(83, 256)
    )
    (t_emb): GaussianFourierProjection()
    (embedding): Linear(in_features=512, out_features=256, bias=True)
    (radial_basis): RadialBasis(
      (envelope): PolynomialEnvelope()
      (rbf): GaussianSmearing()
    )
    (message_layers): ModuleList(
      (0-5): 6 x PaiNNMessage()
    )
    (update_layers): ModuleList(
      (0-5): 6 x PaiNNUpdate(
        (vec_proj): Linear(in_features=256, out_features=512, bias=False)
        (xvec_proj): Sequential(
          (0): Linear(in_features=512, out_features=256, bias=True)
          (1): ScaledSiLU(
            (_activation): SiLU()
          )
          (2): Linear(in_features=256, out_features=768, bias=True)
        )
      )
    )
    (out_xh): Sequential(
      (0): Linear(in_features=256, out_features=128, bias=True)
      (1): ScaledSiLU(
        (_activation): SiLU()
      )
      (2)

Predefine the composition of the molecule using the molecular formula C4H7NO. Since graph neural networks are permutation-invariant, the order of the atoms does not matter.

In [2]:
# Build the C4H7NO composition element by element
# (atomic numbers: carbon = 6, hydrogen = 1, nitrogen = 7, oxygen = 8).
tensor_kwargs = dict(dtype=torch.long, device=device)
C4 = torch.tensor([6] * 4, **tensor_kwargs)
H7 = torch.tensor([1] * 7, **tensor_kwargs)
N1 = torch.tensor([7], **tensor_kwargs)
O1 = torch.tensor([8], **tensor_kwargs)

# Concatenate the per-element tensors into a single 13-atom molecule.
atomic_numbers = torch.cat((C4, H7, N1, O1))
assert atomic_numbers.shape[0] == 13

In [3]:
# Predefine a batch of 4 molecules of identical composition.
# The batch is rebuilt from the per-element tensors (C4, H7, N1, O1) instead of
# from `atomic_numbers` itself, so re-running this cell always produces the
# same 52 atoms rather than repeating an already-repeated tensor.
num_molecules = 4
single_molecule = torch.cat([C4, H7, N1, O1])
atoms_per_molecule = len(single_molecule)

# Flat list of atomic numbers: molecule 0's atoms, then molecule 1's, and so on.
atomic_numbers = single_molecule.repeat(num_molecules)
# One entry per molecule holding its atom count (13 each).
# NOTE(review): presumably the sampler uses this to split the flat atom list
# into molecules — confirm against `en_diffusion.sample`.
mask = torch.tensor(
    [atoms_per_molecule] * num_molecules, dtype=torch.long, device=device
)
assert len(atomic_numbers) == num_molecules * atoms_per_molecule == 52
assert mask.tolist() == [13, 13, 13, 13]

Sample molecules from Gaussian noise by integrating the reverse-time SDE with an Euler–Maruyama sampler.

In [4]:
from datetime import datetime
import os
from model.io import write_batch_xyz

# Draw samples for the whole batch; the call's two results are unpacked into
# `mols_pos` and `trajs`.
sampler = lightning_module.en_diffusion
mols_pos, trajs = sampler.sample(atomic_numbers, mask, num_steps=1000, t_mode="linear")

# Write the samples into a folder under demo/ named after the current time,
# so successive runs do not overwrite each other.
time_point = f"{datetime.now():%Y-%m-%d_%H-%M-%S}"
save_dir = "demo/" + time_point
os.makedirs(save_dir, exist_ok=True)
write_batch_xyz(save_dir, atomic_numbers, mols_pos, mask)
print("saved samples to " + save_dir)


saved samples to demo/2024-05-28_01-20-44


The SMILES strings of the generated molecules are identified automatically using Open Babel.

In [5]:
import subprocess

# Convert every saved .xyz file to a SMILES string with the `obabel` CLI.
smis = []
for i in range(4):
    # The command is passed as an argument list without a shell: the path is
    # never re-parsed by a shell, and a missing `obabel` executable raises
    # FileNotFoundError instead of silently producing an empty string.
    result = subprocess.run(
        ['obabel', f'./demo/{time_point}/mol_{i}.xyz', '-osmi'],
        capture_output=True,
        text=True,
        check=True,  # surface a failed conversion as CalledProcessError
    )
    smi = result.stdout
    # Keep only the text before the first tab (presumably obabel prints
    # "<SMILES>\t<title>" — confirm), and drop any stray whitespace/newline.
    smis.append(smi.split('\t')[0].strip())

In [6]:
# Display the list of SMILES strings stored in `smis`.
smis

['C(C#CCN)O', 'C[C](C(=O)C)[NH]', 'C1(CC1)NC=O', 'C(=O)[C@@H]1CCN1']

Visualize the generated molecules.

In [7]:
from model.visualize import draw_mol

# Draw the four structures mol_0.xyz ... mol_3.xyz from this run's demo folder.
xyz_files = [f'./demo/{time_point}/mol_{idx}.xyz' for idx in range(4)]
for xyz_file in xyz_files:
    draw_mol(xyz_file)

The results show that our model can generate diverse and stable structures.