## Inference Notebook 

We provide a minimal notebook to run sampling and inference with an EQGAT-Diff model trained on the GEOM-Drugs dataset.
This notebook is located in the `inference/` subdirectory.

We append the `../eqgat_diff` directory to `sys.path` so that all required modules can be imported.

The GEOM-Drugs dataset is stored in `../data/geom/`. To keep the upload small, we include only the dataset statistics, such as the empirical distributions of molecule size and of atom and edge features.

The original training, validation, and test sets are currently not provided, since this notebook serves only as an inference showcase.

The model weights are currently available upon request.

In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import sys
sys.path.append("../eqgat_diff")

In [3]:
import argparse
import os
import pickle
import warnings

import numpy as np
import torch
from tqdm import tqdm

warnings.filterwarnings(
    "ignore", category=UserWarning, message="TypedStorage is deprecated"
)


In [4]:
import rdkit
from rdkit.Chem.Draw import IPythonConsole
from rdkit import Chem
import nglview
IPythonConsole.ipython_useSVG = True 
IPythonConsole.molSize = 400, 400
IPythonConsole.drawOptions.addAtomIndices = True
IPythonConsole.ipython_3d = True



In [5]:
class dotdict(dict):
    """dot.notation access to dictionary attributes"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
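
As a quick illustration of how `dotdict` behaves, the sketch below repeats the class so that it runs on its own. The keys `dataset`, `gpus`, and `not_set` are illustrative values chosen here, not a statement about what the checkpoint contains. Because `__getattr__` is bound to `dict.get`, reading a missing attribute returns `None` instead of raising `AttributeError`, which is worth keeping in mind when a hyperparameter name is misspelled.

```python
class dotdict(dict):
    """dot.notation access to dictionary attributes"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


# Hypothetical hyperparameters, for illustration only.
example = dotdict({"dataset": "drugs"})
example.gpus = 1  # attribute assignment writes a dictionary key

dataset_name = example.dataset    # same as example["dataset"]
missing_value = example.not_set   # missing key: dict.get returns None
has_gpus_key = "gpus" in example  # the key was created by the assignment above
```

Since missing keys are silently `None`, an explicit `"key" in hparams` check is the safer way to test whether a setting is present.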


In [6]:
from experiments.diffusion_discrete import Trainer
from experiments.data.geom.geom_dataset_adaptive import (
    GeomDataModule as DataModule,
)
from experiments.data.data_info import GeneralInfos as DataInfos

In [7]:
model_path = "../weights/geom/best_mol_stab.ckpt"

In [8]:
# Load the hyperparameters stored in the checkpoint
ckpt = torch.load(model_path, map_location="cpu")
hparams = ckpt["hyper_parameters"]
hparams["select_train_subset"] = False
hparams["diffusion_pretraining"] = False
hparams["num_charge_classes"] = 6
hparams = dotdict(hparams)

hparams.load_ckpt_from_pretrained = None
hparams.load_ckpt = None
hparams.gpus = 1

print(f"Loading {hparams.dataset} Datamodule.")

Loading drugs Datamodule.


In [9]:
hparams.dataset_root = "../data/geom"

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
device

device(type='cuda')

In [11]:
datamodule = DataModule(hparams, only_stats=True)

In [12]:
dataset_info = DataInfos(datamodule, hparams)

In [13]:
train_smiles = (
    list(datamodule.train_dataset.smiles)
    if hparams.dataset != "pubchem"
    else datamodule.train_smiles
)
prop_norm, prop_dist = None, None

In [14]:
model = Trainer.load_from_checkpoint(
    model_path,
    dataset_info=dataset_info,
    smiles_list=train_smiles,
    prop_norm=prop_norm,
    prop_dist=prop_dist,
    load_ckpt_from_pretrained=None,
    load_ckpt=None,
    run_evaluation=True,
    strict=False,
).to(device)
model = model.eval()

In [15]:
# Create the output directory (and its parent "tmp") if it does not exist yet
save_dir = "tmp/geom"
os.makedirs(save_dir, exist_ok=True)

### Generating

We generate 100 molecules in total, split into two batches of 50 samples each.
On an H100, sampling takes about 57 s per batch.  
The progress bar iterates over the $T=500$ reverse diffusion time steps.  

We pass `device="cpu"` to `model.run_evaluation` because the evaluation statistics can be computed on the CPU.
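
To make the batching arithmetic concrete, here is a minimal sketch of how a total sample count could be split into batches. `split_into_batches` is a hypothetical helper written for this illustration and is not taken from the repository; in particular, placing any remainder in a final smaller batch is an assumption. The 57 s figure is the approximate H100 timing quoted above and covers sampling only, not the subsequent evaluation.

```python
from typing import List


def split_into_batches(ngraphs: int, batch_size: int) -> List[int]:
    """Return per-batch sample counts for `ngraphs` samples (hypothetical helper)."""
    if ngraphs < 0 or batch_size <= 0:
        raise ValueError("ngraphs must be >= 0 and batch_size must be > 0")
    n_full, remainder = divmod(ngraphs, batch_size)
    # Assumption: a leftover smaller than batch_size forms one final batch.
    return [batch_size] * n_full + ([remainder] if remainder else [])


batches = split_into_batches(100, 50)
seconds_per_batch = 57  # approximate H100 timing quoted in the text
estimated_sampling_seconds = len(batches) * seconds_per_batch
```

With `ngraphs=100` and `batch_size=50` this gives two batches of 50, so roughly two minutes of pure sampling time under the stated timing.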

In [16]:
step = 0
ngraphs = 100
batch_size = 50

In [17]:
with torch.no_grad():
    results_dict, generated_smiles, stable_molecules = model.run_evaluation(
        step=step,
        dataset_info=model.dataset_info,
        ngraphs=ngraphs,
        bs=batch_size,
        return_molecules=True,
        verbose=True,
        inner_verbose=True,
        run_test_eval=True,
        device="cpu",
        save_dir=save_dir,
    )

Creating 100 graphs in [50, 50] batches


100%|██████████| 500/500 [00:57<00:00,  8.72it/s]
100%|██████████| 500/500 [00:58<00:00,  8.49it/s]


Analyzing molecule stability
Error messages: AtomValence 8, Kekulize 10, other 0,  -- No error 78
Validity over 100 molecules: 78.00%
Number of connected components of 100 molecules: mean:1.04 max:2.00
Connected components of 100 molecules: 96.00
Sparsity level on local rank 0: 95 %
Run time=0:03:34.810729
{'mol_stable': 0.9200000166893005, 'atm_stable': 0.9973741769790649, 'validity': 0.7799999713897705, 'sanitize_validity': 0.82, 'novelty': 1.0, 'uniqueness': 1.0, 'sampling/NumNodesW1': 1.3309286832809448, 'sampling/AtomTypesTV': 0.041406456381082535, 'sampling/EdgeTypesTV': 0.024203144013881683, 'sampling/ChargeW1': 0.0018256631447002292, 'sampling/ValencyW1': 0.013406765647232533, 'sampling/BondLengthsW1': 0.0007810961687937379, 'sampling/AnglesW1': 0.7763100862503052, 'connected_components': 96.0, 'bulk_similarity': 0.11161660142854338, 'bulk_diversity': 0.8921453170515329, 'kl_score': 0.7774316506076151, 'QED': 0.6076937689257759, 'SA': 0.7434615384615384, 'LogP': 2.3834048717948

In [18]:
results_dict

Unnamed: 0,mol_stable,atm_stable,validity,sanitize_validity,novelty,uniqueness,sampling/NumNodesW1,sampling/AtomTypesTV,sampling/EdgeTypesTV,sampling/ChargeW1,...,kl_score,QED,SA,LogP,Lipinski,Diversity,step,epoch,run_time,ngraphs
0,0.92,0.997374,0.78,0.82,1.0,1.0,1.330929,0.041406,0.024203,0.001826,...,0.777432,0.607694,0.743462,2.383405,4.923077,0.699524,0,0,0:03:34.810729,100
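
The reported `validity` can be cross-checked against the `Error messages` line of the log. The sketch below parses that line with a regular expression; `parse_error_counts` is a hypothetical helper written here, and the pattern assumes the exact log format shown above. Note that the four counts add up to 96 rather than 100. The 4 unaccounted molecules coincide with the 4% that are not a single connected component, but the log does not state this link explicitly, so treat it as an observation rather than a documented rule.

```python
import re
from typing import Dict

LOG_LINE = "Error messages: AtomValence 8, Kekulize 10, other 0,  -- No error 78"


def parse_error_counts(line: str) -> Dict[str, int]:
    """Extract the per-category counts from an 'Error messages' log line."""
    pattern = r"AtomValence (\d+), Kekulize (\d+), other (\d+),\s+-- No error (\d+)"
    match = re.search(pattern, line)
    if match is None:
        raise ValueError("line does not match the expected log format")
    atom_valence, kekulize, other, no_error = map(int, match.groups())
    return {
        "atom_valence": atom_valence,
        "kekulize": kekulize,
        "other": other,
        "no_error": no_error,
    }


counts = parse_error_counts(LOG_LINE)
n_molecules = 100
validity = counts["no_error"] / n_molecules  # fraction of error-free molecules
n_accounted = sum(counts.values())           # molecules covered by the four counts
```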


In [19]:
nglview.show_rdkit(stable_molecules[0].rdkit_mol)

NGLWidget()

### Generating only valid molecules

In [20]:
with torch.no_grad():
    results_dict, generated_smiles, stable_molecules = model.generate_valid_samples(
        dataset_info=model.dataset_info,
        ngraphs=100,
        bs=batch_size,
        return_molecules=True,
        verbose=True,
        inner_verbose=True,
        device="cpu",
        save_dir=save_dir,
    )

Creating 100 graphs in [50, 50] batches


100%|██████████| 500/500 [00:57<00:00,  8.64it/s]
100%|██████████| 500/500 [00:51<00:00,  9.64it/s]


Analyzing molecule stability
Error messages: AtomValence 5, Kekulize 9, other 0,  -- No error 80
Validity over 100 molecules: 80.00%
Number of connected components of 100 molecules: mean:1.06 max:2.00
Connected components of 100 molecules: 94.00
Creating 40 graphs in [40] batches


100%|██████████| 500/500 [00:46<00:00, 10.84it/s]


Analyzing molecule stability
Error messages: AtomValence 2, Kekulize 2, other 0,  -- No error 34
Validity over 40 molecules: 85.00%
Number of connected components of 40 molecules: mean:1.05 max:2.00
Connected components of 40 molecules: 95.00
Analyzing molecule stability
Error messages: AtomValence 0, Kekulize 0, other 0,  -- No error 114
Validity over 114 molecules: 100.00%
Number of connected components of 114 molecules: mean:1.00 max:1.00
Connected components of 114 molecules: 100.00
Sparsity level on local rank 0: 95 %
Run time=0:07:14.022570
{'mol_stable': 1.0, 'atm_stable': 1.0, 'validity': 1.0, 'sanitize_validity': 1.0, 'novelty': 1.0, 'uniqueness': 1.0, 'sampling/NumNodesW1': 1.8128700256347656, 'sampling/AtomTypesTV': 0.03898381441831589, 'sampling/EdgeTypesTV': 0.027824774384498596, 'sampling/ChargeW1': 0.002566906390711665, 'sampling/ValencyW1': 0.012142395600676537, 'sampling/BondLengthsW1': 0.0007226535235531628, 'sampling/AnglesW1': 0.5905836820602417, 'connected_componen
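
The log above suggests a resampling scheme: 100 molecules are generated, 80 of them are valid, another 40 are generated, and the final analysis covers 114 valid molecules. A minimal sketch of such a loop follows. It is not the repository's implementation: `collect_valid` and `fake_sampler` are hypothetical names, and the oversampling factor of 2 is only inferred from the log (20 missing molecules, 40 requested).

```python
import math
from typing import Callable, List, Tuple


def collect_valid(
    sample_batch: Callable[[int], List[bool]],
    target: int,
    oversample: float = 2.0,
    max_rounds: int = 10,
) -> Tuple[int, List[int]]:
    """Resample until at least `target` valid samples have been collected.

    `sample_batch(n)` must return one validity flag per generated sample.
    Returns the number of valid samples and the request size of each round.
    """
    n_valid = 0
    requests: List[int] = []
    n_request = target
    for _ in range(max_rounds):
        flags = sample_batch(n_request)
        requests.append(n_request)
        n_valid += sum(flags)
        if n_valid >= target:
            return n_valid, requests
        # Assumption: request `oversample` times the number still missing.
        n_request = math.ceil(oversample * (target - n_valid))
    raise RuntimeError("target not reached within max_rounds")


# Stand-in sampler that replays the valid counts seen in the log (80, then 34).
valid_per_round = iter([80, 34])


def fake_sampler(n: int) -> List[bool]:
    n_ok = next(valid_per_round)
    return [True] * n_ok + [False] * (n - n_ok)


n_valid, requests = collect_valid(fake_sampler, target=100)
```

Replaying the logged counts through this loop yields request sizes of 100 and 40 and a total of 114 valid samples, which matches the numbers printed above.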

In [21]:
nglview.show_rdkit(stable_molecules[0].rdkit_mol)

NGLWidget()

In [22]:
nglview.show_rdkit(stable_molecules[1].rdkit_mol)

NGLWidget()

In [23]:
nglview.show_rdkit(stable_molecules[-1].rdkit_mol)

NGLWidget()

In [24]:
nglview.show_rdkit(stable_molecules[-2].rdkit_mol)

NGLWidget()