In [1]:
import os
import torch
import numpy as np
import rdkit
from rdkit import Chem
import matplotlib.pyplot as plt
import pandas as pd
from rdkit.Chem import AllChem
from tqdm import tqdm 
from copy import deepcopy
import warnings
warnings.filterwarnings('ignore')

In [2]:
from torch_geometric.data import Data, Batch 
import nglview
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import Draw
from ase.visualize import view
from ase import Atoms
from xtb.ase.calculator import XTB
IPythonConsole.ipython_useSVG = True 
IPythonConsole.molSize = 300, 300
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False
%matplotlib notebook



In [3]:
import py3Dmol
from rdkit import Chem
from rdkit.Chem import AllChem
import ipywidgets
from ipywidgets import interact, interactive, fixed, IntSlider

In [4]:
import pickle

In [5]:
def MolTo3DView(mol, size=(300, 300), style="stick", surface=False, opacity=0.5):
    """Draw a molecule in 3D with py3Dmol.

    Args:
    ----
        mol: rdMol, molecule to show (serialised with ``Chem.MolToMolBlock``)
        size: tuple(int, int), canvas size as (width, height)
        style: str, type of drawing molecule;
               one of 'line', 'stick', 'sphere', 'cartoon'.
               The former misspelling 'carton' is still accepted and is
               treated as 'cartoon'.
        surface: bool, display SAS
        opacity: float, opacity of surface, range 0.0-1.0
    Return:
    ----
        viewer: py3Dmol.view, a class for constructing embedded 3Dmol.js views in ipython notebooks.
    """
    # 3Dmol.js calls this style 'cartoon'; 'carton' was a typo that passed the
    # check below but produced an empty rendering. Keep it as an alias so
    # existing callers continue to work.
    if style == 'carton':
        style = 'cartoon'
    assert style in ('line', 'stick', 'sphere', 'cartoon')
    mblock = Chem.MolToMolBlock(mol)
    viewer = py3Dmol.view(width=size[0], height=size[1])
    viewer.addModel(mblock, 'mol')
    viewer.setStyle({style:{}})
    if surface:
        viewer.addSurface(py3Dmol.SAS, {'opacity': opacity})
    viewer.zoomTo()
    return viewer

In [6]:
def repeat_data(data: Data, num_repeat) -> Batch:
    """Build a Batch made of ``num_repeat`` independent deep copies of ``data``."""
    copies = []
    for _ in range(num_repeat):
        # Each entry is its own deep copy, so the batch members share no state.
        copies.append(deepcopy(data))
    return Batch.from_data_list(copies)

def repeat_batch(batch: Batch, num_repeat) -> Batch:
    """Build a Batch containing the graphs of ``batch`` repeated ``num_repeat`` times.

    The graphs keep their original order inside every repetition, and every
    repetition is a fresh deep copy of the unpacked data list.
    """
    graphs = batch.to_data_list()
    # The inner iterable is re-evaluated for each repetition, so the whole
    # list is deep-copied once per repeat.
    repeated = [graph for _ in range(num_repeat) for graph in deepcopy(graphs)]
    return Batch.from_data_list(repeated)

In [7]:
from e3moldiffusion import chem
from geom.data import GeomDataModule, MolFeaturization
from geom.train_coordsatomsbonds import Trainer
from geom.dataset_infos import get_dataset_info
from evaluation.diffusion_analyze import check_stability
from evaluation.rdkit_functions import BasicMolecularMetrics

In [8]:
def compute_distances(pos):
    """Return the pairwise Euclidean distances between the rows of ``pos``.

    Only the strictly-upper-triangular entries of the distance matrix are
    returned, flattened in row-major order: (0, 1), (0, 2), ..., (n-2, n-1).
    ``pos`` is indexed along dim 0 (one row per point).
    """
    num_points = pos.size(0)
    # Row/column indices of every unordered pair i < j.
    rows, cols = torch.triu_indices(num_points, num_points, 1)
    pairwise = torch.cdist(pos, pos)
    return pairwise[rows, cols]

In [9]:
def compute_mmff_energy(mol):
    """Return the MMFF force-field energy of conformer 0 of ``mol``.

    The value is whatever ``CalcEnergy()`` of the RDKit MMFF force field
    reports for the molecule's current coordinates.
    """
    mmff_props = AllChem.MMFFGetMoleculeProperties(mol)
    force_field = AllChem.MMFFGetMoleculeForceField(mol, mmff_props, confId=0)
    force_field.Initialize()
    return force_field.CalcEnergy()

In [10]:
# Print the GPU status table reported by the driver for this compute node.
!nvidia-smi

Mon Apr 17 15:55:14 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 495.29.05    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-PCI...  On   | 00000000:17:00.0 Off |                    0 |
| N/A   34C    P0    44W / 250W |      0MiB / 40536MiB |      0%   E. Process |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [11]:
# Sanity check: list the shared project directory that the data paths below point into.
!echo "checking if the compute node has access to the hpfs directory where data is stored."
!ls /hpfs/projects/mlcs

checking if the compute node has access to the hpfs directory where data is stored.
2022_deepfrag_collab  e3moldiffusion	   potency_prediction
cct		      mdse		   ProteinMassSpecPred
cct-working	      methods_development  README-PROJECTS.txt
CrystalMD	      mixmd		   sms_workspace
del		      pdm


In [12]:
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [13]:
# Root directory of the project data on the shared file system.
datapath = "/hpfs/projects/mlcs/e3moldiffusion/"

# Table read from qm9/test_info.csv.
test_info_df = pd.read_csv(os.path.join(datapath, "qm9", "test_info.csv"))

# NOTE: despite its name, the final dict is keyed by mol_id and maps to the SMILES string.
smiles_to_mol_id = pd.read_csv(os.path.join(datapath, "qm9", "smiles_to_mol_id.csv"))
smiles_to_mol_id = dict(zip(smiles_to_mol_id.mol_id, smiles_to_mol_id.smiles))
#test_info_df["smiles"] = test_info_df["mol_id"].map(lambda x: smiles_to_mol_id.get(x))

In [14]:
# Name of the dataset to evaluate. It is passed to get_dataset_info and used to
# build the logs/<dataset>/run<run>/ checkpoint path further down.
#dataset = "qm9"
dataset = "drugs"

In [15]:
# Dataset metadata; its "atom_encoder" entry is inverted below, and the whole
# object is handed to check_stability / BasicMolecularMetrics.
# remove_h=False presumably keeps explicit hydrogens -- TODO confirm in geom.dataset_infos.
dataset_info = get_dataset_info(dataset_name=dataset, remove_h=False)

In [16]:
# Load the precomputed atom-type and node-count statistics.
# NOTE(review): the file names are hard-coded to the "drugs" dataset and do not
# follow the `dataset` variable -- confirm before switching datasets.
# NOTE: pickle.load executes arbitrary code from the file; only load trusted files.
db_path = "/hpfs/projects/mlcs/e3moldiffusion"
with open(os.path.join(db_path, "drugs_atom_types.pickle"), "rb") as f:
    atom_types_dict = pickle.load(f)

with open(os.path.join(db_path, "drugs_num_nodes.pickle"), "rb") as f:
    num_nodes_dict = pickle.load(f)

In [17]:
# Largest key in num_nodes_dict (iterating a dict yields its keys).
max_num_nodes = max(num_nodes_dict)

In [18]:
# Dense table over every node count from 0 up to AND INCLUDING max_num_nodes.
# The upper bound must be max_num_nodes + 1: range(max_num_nodes) would drop
# the largest key of num_nodes_dict, so that size would be missing from the
# distribution.
empirical_distribution_num_nodes = {i: num_nodes_dict.get(i) for i in range(max_num_nodes + 1)}

# Node counts absent from num_nodes_dict (value None) get 0.
# Despite the name this is still a dict here; a later cell converts it to a tensor.
empirical_distribution_num_nodes_tensor = {
    key: (0 if value is None else value)
    for key, value in empirical_distribution_num_nodes.items()
}
# print(empirical_distribution_num_nodes_tensor)

In [19]:
# Convert the dict built above into a float tensor, in dict insertion order.
# NOTE: this rebinds the same name from dict to tensor, so the cell is not
# idempotent -- re-run the previous cell before re-running this one.
empirical_distribution_num_nodes_tensor = torch.tensor(list(empirical_distribution_num_nodes_tensor.values())).float()

In [20]:
# Invert the atom encoder so that the encoder's values map back to its keys
# (integer class index -> element symbol, as shown in the next cell's output).
integer_to_el = {index: symbol for symbol, index in dataset_info["atom_encoder"].items()}

In [21]:
# Display the class-index -> element-symbol mapping.
integer_to_el

{0: 'H',
 1: 'B',
 2: 'C',
 3: 'N',
 4: 'O',
 5: 'F',
 6: 'Al',
 7: 'Si',
 8: 'P',
 9: 'S',
 10: 'Cl',
 11: 'As',
 12: 'Br',
 13: 'I',
 14: 'Hg',
 15: 'Bi'}

In [22]:
def get_atomic_number(x):
    """Return the atomic number of the element that class index ``x`` maps to in ``integer_to_el``."""
    symbol = integer_to_el.get(x)
    return Chem.GetPeriodicTable().GetAtomicNumber(symbol)

In [24]:
# Index of the training run; used to build the logs/<dataset>/run<run>/ path below.
run = 2

In [25]:
# List the files of the selected run directory (checkpoints and lightning logs).
os.listdir(f"logs/{dataset}/run{str(run)}/")

['lightning_logs',
 'last-v1.ckpt',
 'epoch=1-step=101753-v1.ckpt',
 'epoch=1-step=101753.ckpt',
 'last.ckpt']

In [36]:
# Restore the trained model from the run's last checkpoint, move it to the
# selected device and switch it to eval mode.
# strict=False is forwarded to load_from_checkpoint unchanged.
model = (
    Trainer.load_from_checkpoint(f'logs/{dataset}/run{str(run)}/last.ckpt', strict=False)
    .to(device)
    .eval()
)

In [37]:
# Total number of scalar entries over all parameters that require gradients.
sum(param.numel() for param in filter(lambda param: param.requires_grad, model.parameters()))

686432

In [38]:
# Print a few hyper-parameters of the loaded model, space-separated on one line.
print(*(
    getattr(model.hparams, name)
    for name in (
        "fully_connected",
        "local_global_model",
        "cutoff_local",
        "cutoff_global",
        "num_layers",
    )
))

False True 5.0 10.0 5


In [39]:
# Number of graphs per sampling call.
# NOTE(review): the sampling cells below pass the literal 100 instead of this variable.
num_graphs = 100

In [40]:
# Scratch calculation; presumably the total sample count for 100 rounds of 100 graphs -- TODO confirm or delete.
100**2

10000

In [41]:
def generate_graphs(num_graphs: int = 100, verbose=False):
    """Sample ``num_graphs`` molecules from the notebook-level ``model``.

    Runs ``model.reverse_sampling`` under ``torch.no_grad()`` using the
    notebook-level ``device`` and ``empirical_distribution_num_nodes_tensor``,
    with trajectory saving enabled.

    Args:
        num_graphs: number of graphs requested from the sampler.
        verbose: forwarded to ``model.reverse_sampling``.

    Returns:
        A 5-tuple ``(pos_splits, atom_types_split, atom_types_integer_split,
        edge_types, trajs)``:

        * ``pos_splits``: per-graph chunks of the sampled positions, on CPU.
        * ``atom_types_split``: per-graph chunks of the atom-type output, on CPU.
        * ``atom_types_integer_split``: per-graph chunks of the argmax of the
          atom-type output over the last dim, on CPU.
        * ``edge_types``, ``trajs``: returned exactly as produced by
          ``reverse_sampling`` (not split and not moved to CPU).
    """
    with torch.no_grad():
        pos, atom_types, edge_types, batch_num_nodes, trajs = model.reverse_sampling(
            num_graphs=num_graphs,
            device=device,
            empirical_distribution_num_nodes=empirical_distribution_num_nodes_tensor,
            verbose=verbose,
            save_traj=True,
        )

    # Per-graph node counts; transferred to the host once and reused for every split.
    split_sizes = batch_num_nodes.cpu().tolist()

    pos_splits = pos.detach().cpu().split(split_sizes, dim=0)
    atom_types_split = atom_types.detach().cpu().split(split_sizes, dim=0)

    atom_types_integer = torch.argmax(atom_types, dim=-1)
    atom_types_integer_split = atom_types_integer.detach().cpu().split(split_sizes, dim=0)

    return pos_splits, atom_types_split, atom_types_integer_split, edge_types, trajs

In [42]:
# Draw one batch of 100 molecules; generate_graphs returns per-graph positions,
# per-graph atom-type outputs, per-graph integer atom types, edge types and trajectories.
pos_splits, atom_types_split, atom_types_integer_split, edge_types, trajs = generate_graphs(num_graphs=100, verbose=True)

range(0, 300)


100%|█████████████████████████████████████████████████████████████████████████████████| 300/300 [00:09<00:00, 32.61it/s]


In [43]:
# Inspect the last entry of the saved trajectories (a list of CUDA tensors in the output below).
# NOTE: this dumps large tensors; consider showing only shapes.
trajs[-1]

[tensor([[-0.0517,  0.0115,  2.2417, -0.1671,  1.2384],
         [ 2.1482,  0.3406,  1.4604,  0.1913,  0.3228],
         [-0.6322,  0.2937, -1.6957, -0.4924, -0.4617],
         ...,
         [ 1.3650,  0.7329,  1.2621,  2.9053, -2.2012],
         [ 3.3432,  1.3959, -0.4638,  0.0307, -2.3568],
         [ 0.3112,  0.0834,  1.2971,  0.0709, -0.9850]], device='cuda:0'),
 tensor([[-1.0223, -1.7970,  3.4896, -0.6731,  2.3610],
         [ 2.8121,  0.2732,  1.9619,  1.0704,  1.3491],
         [-1.0121,  0.4078, -2.4062, -0.8746, -0.2162],
         ...,
         [ 0.1325,  0.4784,  3.3406,  4.0479, -2.0342],
         [ 3.9129,  0.9970,  0.3734,  0.0469, -2.2721],
         [-0.8055, -0.7029,  1.7638,  0.9144, -0.8742]], device='cuda:0'),
 tensor([[-1.7686, -2.6804,  4.6947, -1.0301,  2.9108],
         [ 3.0793,  1.0419,  2.1482,  1.9835,  1.6791],
         [-2.4836,  0.7435, -2.8849, -1.2514, -0.4917],
         ...,
         [-0.4413,  1.3631,  4.7512,  5.9297, -2.7957],
         [ 4.2594,  1.15

In [44]:
# Inspect the raw edge-type output of the sampler (still on the GPU, not split per graph).
edge_types

tensor([[-6.6809e+01, -8.7443e+01,  3.0257e+02, -7.3789e-02,  1.4163e+02],
        [ 1.4677e+02,  6.5421e+01,  8.3243e+01,  1.0104e+02,  5.7417e+01],
        [-1.9047e+02, -2.5807e+01, -1.6747e+02, -8.4243e+01, -6.7821e+01],
        ...,
        [-6.6667e+01,  5.3999e+01,  2.3168e+02,  2.4587e+02, -1.7320e+02],
        [ 1.8684e+02,  5.6573e+01, -6.4324e+01,  1.9058e+01, -2.2659e+02],
        [-1.2624e+02, -5.2638e+01,  1.6326e+02,  2.7233e+01, -4.9321e+01]],
       device='cuda:0')

In [45]:
# Index of the generated molecule to visualise.
i = 0

# Coordinates of molecule i as a NumPy array.
pos_np = pos_splits[i].detach().cpu().numpy()
# Integer class indices of molecule i, converted to atomic numbers for ASE.
numbers = atom_types_integer_split[i].detach().cpu().numpy()
numbers = np.array(list(map(get_atomic_number, numbers)))

atoms = Atoms(numbers=numbers, positions=pos_np)
view(atoms, viewer='x3d')

In [None]:
# Collect 100 batches of 100 generated molecules each.
pos_splits_list = []
xohes_splits_list = []
for i in tqdm(range(100), total=100):
    # generate_graphs returns five values (positions, atom-type outputs,
    # integer atom types, edge types, trajectories). Unpacking them into two
    # names raised "ValueError: too many values to unpack"; keep only the
    # positions and the integer atom types, which are what the evaluation uses.
    pos_splits, _atom_types_split, xohes_integer_split, _edge_types, _trajs = generate_graphs(
        num_graphs=100, verbose=False
    )
    pos_splits_list.append(pos_splits)
    xohes_splits_list.append(xohes_integer_split)

In [None]:
# Flatten the per-batch tuples into one flat list with one entry per molecule.
# NOTE: both names are rebound to their flattened versions, so this cell is not
# idempotent -- running it a second time would iterate over the individual
# tensors and flatten them as well.
pos_splits_list = [item for sublist in pos_splits_list for item in sublist]
xohes_splits_list = [item for sublist in xohes_splits_list for item in sublist]

In [None]:
# Sanity check: both flat lists should have the same number of molecules.
print(len(pos_splits_list), len(xohes_splits_list))

In [None]:
# Pair every molecule's positions with its atom types as (positions, atom_types) tuples.
processed_list = list(zip(pos_splits_list, xohes_splits_list))

In [None]:
# Inspect the atom types (second tuple element) of the first molecule.
processed_list[0][1]

In [None]:
# Whether the RDKit-based metrics cell below should run.
use_rdkit = True

# Running totals accumulated over all generated molecules.
molecule_stable = nr_stable_bonds = n_atoms = n_molecules = 0

for mol in tqdm(processed_list, total=len(processed_list)):
    pos, atom_type = mol
    validity_results = check_stability(pos, atom_type, dataset_info)

    # validity_results[0..2] are accumulated as: molecule-stable flag,
    # count added to nr_stable_bonds, and number of atoms.
    molecule_stable = molecule_stable + int(validity_results[0])
    nr_stable_bonds = nr_stable_bonds + int(validity_results[1])
    n_atoms = n_atoms + int(validity_results[2])
    n_molecules = n_molecules + 1

# Validity
fraction_mol_stable = molecule_stable / n_molecules
fraction_atm_stable = nr_stable_bonds / n_atoms
validity_dict = dict(
    mol_stable=fraction_mol_stable,
    atm_stable=fraction_atm_stable,
)

In [None]:
# Optionally evaluate the same (positions, atom types) pairs with
# BasicMolecularMetrics from evaluation.rdkit_functions.
# `use_rdkit` is set in the stability-evaluation cell above.
if use_rdkit:
    metrics = BasicMolecularMetrics(dataset_info)
    rdkit_metrics = metrics.evaluate(processed_list)
    # print("Unique molecules:", rdkit_metrics[1])

In [None]:
# Display the molecule-level and atom-level stability fractions computed above.
validity_dict

## Results

Epoch End 0:   
Validity over 10000 molecules: 71.42%  
Uniqueness over 7142 valid molecules: 99.99%  
{'mol_stable': 0.0, 'atm_stable': 0.7456064817881325}  


Epoch End 1:  
Validity over 10000 molecules: 79.19%  
Uniqueness over 7919 valid molecules: 100.00%  
{'mol_stable': 0.0009, 'atm_stable': 0.8086492992130017}  


Epoch End 2:   
Validity over 10000 molecules: 67.77%  
Uniqueness over 6777 valid molecules: 100.00%  
{'mol_stable': 0.0036, 'atm_stable': 0.8304651868569646}  


Epoch End 3:  
Validity over 10000 molecules: 82.24%  
Uniqueness over 8224 valid molecules: 99.99%  
{'mol_stable': 0.0012, 'atm_stable': 0.8334680141897178}

Epoch End 4:  
Validity over 10000 molecules: 86.94%  
Uniqueness over 8694 valid molecules: 99.99%  
{'mol_stable': 0.0019, 'atm_stable': 0.8298923506487436}  

Epoch End 9:  
Validity over 10000 molecules: 85.84%  
Uniqueness over 8584 valid molecules: 100.00%  
{'mol_stable': 0.0064, 'atm_stable': 0.8536572844069668}  