## Inference Notebook 

We provide a minimal notebook that runs sampling and inference with an EQGAT-Diff model trained on the QM9 dataset.
This notebook is in the `inference/` subdirectory.

We append the `../eqgat_diff` directory to `sys.path` so that all required modules can be imported.

The QM9 dataset is saved in `../data/qm9/`. To keep the repository small, we uploaded only the dataset statistics, such as the empirical distributions of molecule size and of atom and edge features.
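To illustrate why the statistics alone can be enough for sampling, here is a minimal sketch of drawing molecule sizes from an empirical size distribution. The counts in `size_counts` and the helper `sample_num_atoms` are made up for illustration; they are not the real QM9 statistics or the repository's API.

```python
import random

# Hypothetical counts: number of atoms -> number of training molecules.
# These values are invented for illustration, not the real QM9 statistics.
size_counts = {3: 10, 4: 30, 5: 60}


def sample_num_atoms(counts, n_samples, seed=0):
    """Draw molecule sizes in proportion to their empirical frequency."""
    rng = random.Random(seed)
    sizes = list(counts.keys())
    weights = list(counts.values())
    return rng.choices(sizes, weights=weights, k=n_samples)


sampled_sizes = sample_num_atoms(size_counts, n_samples=5)
print(sampled_sizes)
```

Each sampled size would then fix the number of atoms of one generated molecule before the reverse diffusion starts.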

The original training, validation, and test sets are currently not provided, as this notebook only serves as an inference showcase.

For now, the model weights are available upon request.

In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import sys
sys.path.append("../eqgat_diff")

In [3]:
import argparse
import os
import pickle
import warnings

import numpy as np
import torch
from tqdm import tqdm

warnings.filterwarnings(
    "ignore", category=UserWarning, message="TypedStorage is deprecated"
)

In [4]:
import rdkit
from rdkit.Chem.Draw import IPythonConsole
from rdkit import Chem
import nglview
IPythonConsole.ipython_useSVG = True 
IPythonConsole.molSize = 400, 400
IPythonConsole.drawOptions.addAtomIndices = True
IPythonConsole.ipython_3d = True



In [5]:
class dotdict(dict):
    """dot.notation access to dictionary attributes"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
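
As a quick illustration of how `dotdict` behaves, the sketch below repeats the class and exercises attribute access. Because `__getattr__` is bound to `dict.get`, a missing key yields `None` instead of raising `AttributeError`. The keys used here are example values only.

```python
class dotdict(dict):
    """dot.notation access to dictionary attributes"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


cfg = dotdict({"dataset": "qm9"})
cfg.gpus = 1                  # equivalent to cfg["gpus"] = 1
missing = cfg.not_a_real_key  # dict.get returns None for unknown keys
print(cfg.dataset, cfg["gpus"], missing)  # prints: qm9 1 None
```

One consequence of this design is that a misspelled hyperparameter name silently evaluates to `None`, so it is worth double-checking key names when overriding values.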


In [6]:
from experiments.diffusion_discrete import Trainer
from experiments.data.qm9.qm9_dataset import (
    QM9DataModule as DataModule,
)
from experiments.data.data_info import GeneralInfos as DataInfos

In [7]:
model_path = "../weights/qm9/best_mol_stab.ckpt"
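
Since the weights are not shipped with the repository, it can help to check that the checkpoint file is present before loading it. This is an optional sketch; the helper name `checkpoint_available` is hypothetical.

```python
import os


def checkpoint_available(path):
    """Return True if a checkpoint file exists at `path`."""
    return os.path.isfile(path)


model_path = "../weights/qm9/best_mol_stab.ckpt"
if not checkpoint_available(model_path):
    print(f"No checkpoint found at {model_path}; request the weights first.")
```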

In [8]:
# Load the hyperparameters stored in the checkpoint and override a few for inference
ckpt = torch.load(model_path, map_location="cpu")
hparams = ckpt["hyper_parameters"]
hparams["select_train_subset"] = False
hparams["diffusion_pretraining"] = False
hparams["num_charge_classes"] = 6
hparams = dotdict(hparams)

hparams.load_ckpt_from_pretrained = None
hparams.load_ckpt = None
hparams.gpus = 1

print(f"Loading {hparams.dataset} Datamodule.")

Loading qm9 Datamodule.


In [9]:
hparams.dataset_root = "../data/qm9"

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
device

device(type='cuda')

In [11]:
datamodule = DataModule(hparams, only_stats=True)

In [12]:
dataset_info = DataInfos(datamodule, hparams)

In [13]:
train_smiles = (
    list(datamodule.train_dataset.smiles)
    if hparams.dataset != "pubchem"
    else datamodule.train_smiles
)
prop_norm, prop_dist = None, None

In [14]:
model = Trainer.load_from_checkpoint(model_path,
                                     dataset_info=dataset_info,
                                     smiles_list=train_smiles,
                                     prop_norm=prop_norm,
                                     prop_dist=prop_dist,
                                     load_ckpt_from_pretrained=None,
                                     load_ckpt=None,
                                     run_evaluation=True,
                                     strict=False,
                                ).to(device)
model = model.eval()

In [15]:
save_dir = "tmp/qm9"
# Creates "tmp" and "tmp/qm9" if they do not exist yet
os.makedirs(save_dir, exist_ok=True)

### Generating

We generate 100 molecules in total, in batches of 50 samples.  
On an H100, sampling takes about 15 s per batch.  
Each progress bar iterates over the $T=500$ reverse diffusion time steps.  

We pass `device="cpu"` to `model.run_evaluation` because the evaluation statistics can be computed on the CPU.
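
The batching arithmetic above can be sketched as follows. The helper `split_into_batches` is hypothetical and only mirrors the description (100 graphs, 50 per batch); it is not taken from the repository. The 15 s figure is the rough H100 estimate quoted above and covers sampling only.

```python
def split_into_batches(ngraphs, batch_size):
    """Split `ngraphs` into full batches plus one smaller remainder batch."""
    full, rest = divmod(ngraphs, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


batches = split_into_batches(100, 50)
seconds_per_batch = 15  # rough H100 estimate, sampling only
print(batches, len(batches) * seconds_per_batch)  # prints: [50, 50] 30
```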

In [16]:
step = 0
ngraphs = 100
batch_size = 50

In [17]:
with torch.no_grad():
    results_dict, generated_smiles, stable_molecules = model.run_evaluation(
        step=step,
        dataset_info=model.dataset_info,
        ngraphs=ngraphs,
        bs=batch_size,
        return_molecules=True,
        verbose=True,
        inner_verbose=True,
        run_test_eval=True,
        device="cpu",
        save_dir=save_dir,
    )

Creating 100 graphs in [50, 50] batches


100%|██████████| 500/500 [00:14<00:00, 33.38it/s]
100%|██████████| 500/500 [00:15<00:00, 32.22it/s]


Analyzing molecule stability
Error messages: AtomValence 1, Kekulize 0, other 0,  -- No error 99
Validity over 100 molecules: 99.00%
Number of connected components of 100 molecules: mean:1.00 max:1.00
Connected components of 100 molecules: 100.00
Sparsity level on local rank 0: 87 %
Run time=0:00:47.795955
{'mol_stable': 0.9800000190734863, 'atm_stable': 0.9982896447181702, 'validity': 0.9900000095367432, 'sanitize_validity': 0.99, 'novelty': 0.4848484992980957, 'uniqueness': 1.0, 'sampling/NumNodesW1': 0.49896612763404846, 'sampling/AtomTypesTV': 0.017466576769948006, 'sampling/EdgeTypesTV': 0.010771473869681358, 'sampling/ChargeW1': 0.004673334304243326, 'sampling/ValencyW1': 0.025954075157642365, 'sampling/BondLengthsW1': 0.010890845209360123, 'sampling/AnglesW1': 0.9403378963470459, 'connected_components': 100.0, 'bulk_similarity': 0.0760997809949149, 'bulk_diversity': 0.9184974415653269, 'kl_score': 0.9143664893385509, 'QED': 0.4479850513289507, 'SA': 0.5816161616161617, 'LogP': -

In [18]:
results_dict

| | mol_stable | atm_stable | validity | sanitize_validity | novelty | uniqueness | sampling/NumNodesW1 | sampling/AtomTypesTV | sampling/EdgeTypesTV | sampling/ChargeW1 | ... | kl_score | QED | SA | LogP | Lipinski | Diversity | step | epoch | run_time | ngraphs |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.98 | 0.99829 | 0.99 | 0.99 | 0.484848 | 1.0 | 0.498966 | 0.017467 | 0.010771 | 0.004673 | ... | 0.914366 | 0.447985 | 0.581616 | -0.249105 | 4.949495 | 0.880493 | 0 | 0 | 0:00:47.795955 | 100 |
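
To make the first few metrics easier to read, here is a minimal sketch of one common way to compute validity, uniqueness, and novelty from SMILES strings. It assumes canonical SMILES and the definitions stated in the comments; the repository's exact implementation may differ. The reported novelty of 0.4848 equals 48/99, which would be consistent with novelty being measured over the 99 valid, unique molecules.

```python
def summarize_samples(generated, train_smiles):
    """Compute simple generation metrics.

    `generated` holds one canonical SMILES string per sample,
    or None where the molecule could not be sanitized.
    """
    valid = [s for s in generated if s is not None]
    unique = set(valid)
    novel = unique - set(train_smiles)
    return {
        # fraction of all samples that are valid
        "validity": len(valid) / len(generated) if generated else 0.0,
        # fraction of valid samples that are distinct
        "uniqueness": len(unique) / len(valid) if valid else 0.0,
        # fraction of distinct samples not seen in training
        "novelty": len(novel) / len(unique) if unique else 0.0,
    }


# Toy data for illustration only
toy_generated = ["CCO", "CCO", "CCN", None]
toy_train = ["CCO"]
toy_metrics = summarize_samples(toy_generated, toy_train)
print(toy_metrics)
```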


In [19]:
nglview.show_rdkit(stable_molecules[0].rdkit_mol)

NGLWidget()

### Generating only valid molecules

In [20]:
with torch.no_grad():
    results_dict, generated_smiles, stable_molecules = model.generate_valid_samples(
        dataset_info=model.dataset_info,
        ngraphs=ngraphs,
        bs=batch_size,
        return_molecules=True,
        verbose=True,
        inner_verbose=True,
        device="cpu",
        save_dir=save_dir,
    )

Creating 100 graphs in [50, 50] batches


100%|██████████| 500/500 [00:15<00:00, 32.14it/s]
100%|██████████| 500/500 [00:15<00:00, 32.84it/s]


Analyzing molecule stability
Error messages: AtomValence 0, Kekulize 0, other 0,  -- No error 100
Validity over 100 molecules: 100.00%
Number of connected components of 100 molecules: mean:1.00 max:1.00
Connected components of 100 molecules: 100.00
Analyzing molecule stability
Error messages: AtomValence 0, Kekulize 0, other 0,  -- No error 100
Validity over 100 molecules: 100.00%
Number of connected components of 100 molecules: mean:1.00 max:1.00
Connected components of 100 molecules: 100.00
Sparsity level on local rank 0: 87 %
Run time=0:00:58.614853
{'mol_stable': 1.0, 'atm_stable': 1.0, 'validity': 1.0, 'sanitize_validity': 1.0, 'novelty': 0.6600000262260437, 'uniqueness': 1.0, 'sampling/NumNodesW1': 0.5176378488540649, 'sampling/AtomTypesTV': 0.02442755550146103, 'sampling/EdgeTypesTV': 0.007449071388691664, 'sampling/ChargeW1': 0.009620936587452888, 'sampling/ValencyW1': 0.03428169712424278, 'sampling/BondLengthsW1': 0.0013829406816512346, 'sampling/AnglesW1': 0.7992836236953735,
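
The notebook does not show how `generate_valid_samples` works internally. One plausible reading is a rejection-style loop that keeps sampling batches and retains only valid molecules until the target count is reached. The sketch below illustrates that idea with toy stand-ins; `collect_valid`, `toy_sampler`, and `toy_is_valid` are hypothetical and not part of the repository.

```python
import random


def collect_valid(sample_batch, is_valid, n_target, batch_size, max_rounds=100):
    """Sample batches until `n_target` valid items are collected."""
    kept = []
    for _ in range(max_rounds):
        if len(kept) >= n_target:
            break
        batch = sample_batch(batch_size)
        kept.extend(m for m in batch if is_valid(m))
    return kept[:n_target]


# Toy stand-ins: "molecules" are random floats, about 90% count as valid.
rng = random.Random(0)


def toy_sampler(n):
    return [rng.random() for _ in range(n)]


def toy_is_valid(x):
    return x < 0.9


valid_samples = collect_valid(toy_sampler, toy_is_valid, n_target=100, batch_size=50)
print(len(valid_samples))
```

The `max_rounds` cap guards against an endless loop if the validity rate is very low.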