In [1]:
from frozen_conf_xtb2 import calculate_XTB_energy
import torch
import numpy as np
from edm_qm9_utils.analyze import check_stability_yy
from edm_qm9_utils.rdkit_functions import build_molecule
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import QED
from rdkit.Geometry import Point3D
import py3Dmol
from tqdm import tqdm
import pdb
from scipy.stats import wasserstein_distance
import matplotlib.pyplot as plt
from fcd_torch import FCD as FCDMetric
from e3fp.pipeline import fprints_from_mol, confs_from_smiles
from e3fp.fingerprint.metrics.fprint_metrics import tanimoto
import os
# from frozen_conf_xtb.utils import calculate_XTB_energy

import random
import utils_yy.utils as utils_yy

In [2]:
qm9_atom_list = ['H', 'C', 'O', 'N', 'F']
# log_dir_ld = './logs/job6_latent_ddpm_qm9/'
log_dir_ld='../AE_geom_uncond_weights_and_data/job17_latent_ddpm_qm9_spatial_graphs/'

### 2D valid and unique

In [3]:
smiless = torch.load(log_dir_ld + 'sample_smiles.pt')

res_list = []

for nrun in range(3):
    num_mol = 0
    smiles_sample = []
    for smi in tqdm(smiless[nrun*10000:(nrun+1)*10000]):
        try:
            mol = Chem.MolFromSmiles(smi)
            Chem.SanitizeMol(mol)
            atom_list = []
            for atom in mol.GetAtoms():
                atom_list.append(atom.GetSymbol())
            if len(set(atom_list).difference(qm9_atom_list)) > 0:
                continue

            mol = Chem.RemoveHs(mol)
            smi = Chem.CanonSmiles( Chem.MolToSmiles(mol) )
            smiles_sample.append(smi)
        except:
            None

        num_mol += 1

    valid = len(smiles_sample) / num_mol
    unique = len(set(smiles_sample)) / len(smiles_sample)
    valid_and_unique = valid * unique
#     print(valid, unique, valid_and_unique, num_mol)
    res_list.append([valid, unique, valid_and_unique])

res_list = np.array(res_list)
print(res_list.mean(axis=0), res_list.std(axis=0))

100%|██████████| 10000/10000 [00:05<00:00, 1863.03it/s]
0it [00:00, ?it/s]


ZeroDivisionError: division by zero

### 3D stable

In [6]:
mols = torch.load(log_dir_ld + 'sample_conformer.pt')

res_list = []

for nrun in range(3):
    valid_num = 0
    total_num = 0
    atom_valid = 0
    for mol in tqdm(mols[nrun*10000:(nrun+1)*10000]):    
        atom_list = []
        for atom in mol.GetAtoms():
            atom_list.append(atom.GetSymbol())
        if len(set(atom_list).difference(qm9_atom_list)) > 0:
            continue

        AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
        positions = mol.GetConformers()[0].GetPositions()

        out = check_stability_yy(positions, atom_list)

        if out[0]:
            valid_num += 1
        total_num += 1

        atom_valid += out[1] / out[2]

#     print(valid_num / total_num, atom_valid / total_num, valid_num, total_num)
    res_list.append([valid_num / total_num, atom_valid / total_num])

res_list = np.array(res_list)
print(res_list.mean(axis=0), res_list.std(axis=0))

100%|██████████| 10000/10000 [00:32<00:00, 307.41it/s]
0it [00:00, ?it/s]


ZeroDivisionError: division by zero

In [5]:
mols = torch.load(log_dir_ld + 'sample_conformer.pt')

res_list = []

for nrun in range(3):
    valid_num = 0
    total_num = 0
    atom_valid = 0
    for mol in tqdm(mols[nrun*10000:(nrun+1)*10000]):    
        atom_list = []
        for atom in mol.GetAtoms():
            atom_list.append(atom.GetSymbol())
        if len(set(atom_list).difference(qm9_atom_list)) > 0:
            continue

#         AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
        positions = mol.GetConformers()[0].GetPositions()

        out = check_stability_yy(positions, atom_list)

        if out[0]:
            valid_num += 1
        total_num += 1

        atom_valid += out[1] / out[2]

#     print(valid_num / total_num, atom_valid / total_num, valid_num, total_num)
    res_list.append([valid_num / total_num, atom_valid / total_num])

res_list = np.array(res_list)
print(res_list.mean(axis=0), res_list.std(axis=0))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:16<00:00, 623.36it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:16<00:00, 620.27it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:16<00:00, 610.79it/s]

[0.00229521 0.3261557 ] [0.00017812 0.00240093]





In [3]:
mols = torch.load(log_dir_ld + 'sample_conformer.pt')

res_list = []

for nrun in range(3):
    valid_num = 0
    total_num = 0
    atom_valid = 0
    for mol in tqdm(mols[nrun*10000:(nrun+1)*10000]):    
        atom_list = []
        for atom in mol.GetAtoms():
            atom_list.append(atom.GetSymbol())
        if len(set(atom_list).difference(qm9_atom_list)) > 0:
            continue

        try:
            AllChem.EmbedMolecule(mol)
            AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
            positions = mol.GetConformers()[0].GetPositions()
        except:
            total_num += 1
            continue

        out = check_stability_yy(positions, atom_list)

        if out[0]:
            valid_num += 1
        total_num += 1

        atom_valid += out[1] / out[2]

#     print(valid_num / total_num, atom_valid / total_num, valid_num, total_num)
    res_list.append([valid_num / total_num, atom_valid / total_num])

res_list = np.array(res_list)
print(res_list.mean(axis=0), res_list.std(axis=0))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [08:40<00:00, 19.22it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [07:41<00:00, 21.65it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [08:57<00:00, 18.61it/s]

[0.82011158 0.89047233] [0.00367722 0.00348876]





### 2D distribution

In [3]:
data = np.load('/scratch/user/yuning.you/project/graph_latent_diffusion/e3_diffusion_for_molecules/qm9/latent_diffusion/emb_2d_3d/train.npz')

stats = np.zeros(10, dtype=int)
for ncharge in tqdm(data['charges']):
    idx = len([n for n in ncharge if n > 1])
    stats[idx] += 1
print(stats)

100%|████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [00:00<00:00, 295250.22it/s]

[    0     2     5     7    25    91   475  2404 13625 83366]





In [3]:
with open('/scratch/user/yuning.you/project/graph_latent_diffusion/e3_diffusion_for_molecules/data/smiles_qm9.txt', 'r') as f:
    smiles_train = [smi for smi in tqdm(f.read().split('\n')[:100000])]
    
with open('/scratch/user/yuning.you/project/graph_latent_diffusion/e3_diffusion_for_molecules/data/smiles_qm9.txt', 'r') as f:
    smiles_test = [smi for smi in tqdm(f.read().split('\n')[-13084:-1])]

100%|███████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [00:00<00:00, 4933372.54it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 13083/13083 [00:00<00:00, 5293659.97it/s]


In [8]:
def measure_2d_distribution(smi_list1, smi_list2, _3fold=True):
    qed_list1 = []
    for smi in tqdm(smi_list1):
        mol = Chem.MolFromSmiles(smi)
        prop = QED.properties(mol)
        qed = QED.qed(mol)
        qed_list1.append([prop[0], prop[1], prop[4], qed])
    random.shuffle(qed_list1)
    qed_list1 = np.array(qed_list1)
    
    qed_list2 = []
    for smi in tqdm(smi_list2):
        mol = Chem.MolFromSmiles(smi)
        prop = QED.properties(mol)
        qed = QED.qed(mol)
        qed_list2.append([prop[0], prop[1], prop[4], qed])
    qed_list2 = np.array(qed_list2)
    
    # print results
    prop_name = ['MW', 'ALogP', 'PSA', 'QED']
    if _3fold:
        for nprop, pname in enumerate(prop_name):
            num_run = qed_list1.shape[0] // 3
            res_list = [[v for v in utils_yy.statistical_metric(qed_list1[:num_run, nprop],
                                                                qed_list2[:, nprop]).values()],
                        [v for v in utils_yy.statistical_metric(qed_list1[num_run:num_run*2, nprop],
                                                                qed_list2[:, nprop]).values()],
                        [v for v in utils_yy.statistical_metric(qed_list1[num_run*2:, nprop],
                                                                qed_list2[:, nprop]).values()]]

            key_list = [k for k in utils_yy.statistical_metric(qed_list1[:num_run, nprop],
                                                               qed_list2[:, nprop]).keys()]

            print(pname)
            print({k: [np.mean([res_list[0][n], res_list[1][n], res_list[2][n]]),
                       np.std([res_list[0][n], res_list[1][n], res_list[2][n]])]
                   for n, k in enumerate(key_list)})
        
        kwargs_fcd = {'n_jobs': 8, 'device': 'cpu', 'batch_size': 32}
        fcd = [ FCDMetric(**kwargs_fcd)(smi_list1[:num_run], smi_list2),
                FCDMetric(**kwargs_fcd)(smi_list1[num_run:num_run*2], smi_list2),
                FCDMetric(**kwargs_fcd)(smi_list1[num_run*2:], smi_list2) ]
        print('FCD', np.mean(fcd), np.std(fcd))
        
    else:
        for nprop, pname in enumerate(prop_name):
            print(pname)
            print(utils_yy.statistical_metric(qed_list1[:, nprop], qed_list2[:, nprop]))
        
        kwargs_fcd = {'n_jobs': 8, 'device': 'cpu', 'batch_size': 32}
        print('FCD', FCDMetric(**kwargs_fcd)(smi_list1, smi_list2))

In [8]:
measure_2d_distribution(smiles_train, smiles_test, False)

100%|███████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [02:20<00:00, 713.06it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 13083/13083 [00:18<00:00, 700.05it/s]


MW
{'tvd': 0.006152899946495427, 'hd': 0.027250938652787673, 'wd': 0.11524130402583446}
ALogP
{'tvd': 0.002624401131239011, 'hd': 0.030700064941371456, 'wd': 0.007492805439065966}
PSA
{'tvd': 0.0036061522586562714, 'hd': 0.029461469455488783, 'wd': 0.14610340198731134}
QED
{'tvd': 0.0030609500879003257, 'hd': 0.03105649267647254, 'wd': 0.0005733954490561486}
FCD 0.0298101927620813


In [6]:
log_dir_edm = '/scratch/user/yuning.you/project/graph_latent_diffusion/e3_diffusion_for_molecules_official/qm9/'
with open(log_dir_edm + 'generate_smiles.txt', 'r') as f:
    smiless = f.read().split('\n')[:-1]

measure_2d_distribution(smiless, smiles_test)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 27501/27501 [00:38<00:00, 705.72it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 13083/13083 [00:17<00:00, 732.51it/s]


MW
{'tvd': [0.02935428198405649, 0.004752470157714284], 'hd': [0.17141575439215453, 0.0008324221259389271], 'wd': [1.7714501981004018, 0.044561127593093615]}
ALogP
{'tvd': [0.008299265308101352, 0.0005354909008047986], 'hd': [0.08299838748179272, 0.0006707623465969748], 'wd': [0.17530174062687146, 0.005110986789145081]}
PSA
{'tvd': [0.023748301545992025, 0.00201212106210784], 'hd': [0.11327890115482468, 0.0024406406175092672], 'wd': [2.611887956084219, 0.09649659085273403]}
QED
{'tvd': [0.007445369333508467, 0.0002900820798815914], 'hd': [0.06541844392446715, 0.0021318058924504326], 'wd': [0.005765394189978828, 0.0005135423771765456]}
FCD 0.580426556853895 0.003947453378448896


In [None]:
smiless = torch.load(log_dir_ld + 'sample_smiles.pt')
smiless1 = []

stats_tar = stats / stats.sum() * 30000
stats_gen = np.zeros(stats.shape)

for smi in tqdm(smiless):
    try:
        mol = Chem.MolFromSmiles(smi)
        mol = Chem.RemoveHs(mol)
        atom_list = []
        for atom in mol.GetAtoms():
            atom_list.append(atom.GetSymbol())
        if len(atom_list) > 9:
            continue
        if len(set(atom_list).difference(qm9_atom_list)) > 0:
            continue
        
        num_atom = len(atom_list)
        if stats_gen[num_atom] >= stats_tar[num_atom]:
            continue
        
        stats_gen[num_atom] += 1
        smiless1.append(smi)
    except:
        None

measure_2d_distribution(smiless1, smiles_test)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [00:25<00:00, 3995.41it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 30004/30004 [00:40<00:00, 733.17it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 13083/13083 [00:17<00:00, 727.86it/s]


MW
{'tvd': [0.022450151808060984, 0.0012644474612296625], 'hd': [0.08116829252875644, 0.00167361944333405], 'wd': [0.5999122989745161, 0.027839864635935553]}
ALogP
{'tvd': [0.009378349441438382, 0.0015421847509588763], 'hd': [0.06997382708192325, 0.0011263225847387278], 'wd': [0.10555170970283016, 0.001756792043358978]}
PSA
{'tvd': [0.012172530110945755, 0.0001771692903032922], 'hd': [0.09651223941333957, 0.0030863856017845164], 'wd': [1.925151165889883, 0.044115261258926834]}
QED
{'tvd': [0.010089493592951167, 6.627141043087967e-05], 'hd': [0.11516015716413623, 0.0020413791495492743], 'wd': [0.018175715857238813, 0.0006957501238402865]}


In [16]:
smiless = torch.load(log_dir_ld + 'sample_smiles.pt')
smiless1 = []

stats_tar = stats / stats.sum() * 30000
stats_gen = np.zeros(stats.shape)

for smi in tqdm(smiless):
    try:
        mol = Chem.MolFromSmiles(smi)
        mol = Chem.RemoveHs(mol)
        atom_list = []
        for atom in mol.GetAtoms():
            atom_list.append(atom.GetSymbol())
        if len(atom_list) > 9:
            continue
        if len(set(atom_list).difference(qm9_atom_list)) > 0:
            continue
        
#         num_atom = len(atom_list)
#         if stats_gen[num_atom] >= stats_tar[num_atom]:
#             continue
        
        stats_gen[num_atom] += 1
        smiless1.append(smi)
        
        if len(smiless1) > 30000:
            break
    except:
        None

measure_2d_distribution(smiless1, smiles_test)

  1%|▋                                                                                                | 689/100000 [00:00<00:42, 2334.51it/s]

ERROR! Session/line number was not unique in database. History logging moved to new session 119


 44%|█████████████████████████████████████████▌                                                     | 43713/100000 [00:11<00:14, 3824.82it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 30001/30001 [00:43<00:00, 690.00it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 13083/13083 [00:19<00:00, 672.26it/s]


MW
{'tvd': [0.06503261255239606, 0.0025482457597286695], 'hd': [0.2591391819198361, 0.00309431895076475], 'wd': [5.4630820434476, 0.12172579652712512]}
ALogP
{'tvd': [0.008884374195768529, 0.0011423928813224448], 'hd': [0.06434184958920419, 0.0034629449615702183], 'wd': [0.07050763718947335, 0.006751313857472962]}
PSA
{'tvd': [0.01054676109881432, 0.00024042704314579822], 'hd': [0.09394433453907099, 0.00395845558034042], 'wd': [1.3967515074453392, 0.11232934920778169]}
QED
{'tvd': [0.011007745975892871, 0.0011655775891214397], 'hd': [0.07328189839646003, 0.0014192370764311085], 'wd': [0.007266553271805258, 0.0005465462436851693]}
FCD 0.4428336272238518 0.009556058317030347


### 3D distribution

In [4]:
data = np.load('/scratch/user/yuning.you/project/graph_latent_diffusion/e3_diffusion_for_molecules/qm9/latent_diffusion/emb_2d_3d/test.npz')
print([k for k in data.keys()])

dataset_info = {'name': 'qm9', 'atom_decoder': {1:'H', 6:'C', 8:'O', 7:'N', 9:'F'}}
energy_list = []
for idx, natoms in enumerate(tqdm(data['num_atoms'])):
    x = data['charges'][idx][:natoms]
    pos = data['positions'][idx][:natoms]
    
    geometry = []
    for i, p in zip(x, pos):
        geometry.append([dataset_info['atom_decoder'][i]] + p.tolist())
    
    energy = calculate_XTB_energy(geometry)
    energy_list.append(energy)

['num_atoms', 'charges', 'positions', 'index', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'omega1', 'emb_2d', 'emb_3d', 'zpve_thermo', 'U0_thermo', 'U_thermo', 'H_thermo', 'G_thermo', 'Cv_thermo']


100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 13083/13083 [06:49<00:00, 31.98it/s]


In [5]:
# data = np.load('/scratch/user/yuning.you/project/graph_latent_diffusion/e3_diffusion_for_molecules/qm9/latent_diffusion/emb_2d_3d/train.npz')
# print([k for k in data.keys()])

# dataset_info = {'name': 'qm9', 'atom_decoder': {1:'H', 6:'C', 8:'O', 7:'N', 9:'F'}}
# energy_train_list = []
# for idx, natoms in enumerate(tqdm(data['num_atoms'])):
#     x = data['charges'][idx][:natoms]
#     pos = data['positions'][idx][:natoms]
    
#     geometry = []
#     for i, p in zip(x, pos):
#         geometry.append([dataset_info['atom_decoder'][i]] + p.tolist())
    
#     energy = calculate_XTB_energy(geometry)
#     energy_train_list.append(energy)

# torch.save(energy_train_list, './qm9_training_data_energy.pt')


energy_train_list = torch.load('./qm9_training_data_energy.pt')

print(utils_yy.statistical_metric(np.array(energy_train_list), np.array(energy_list)))

{'tvd': 0.003885490330963845, 'hd': 0.029642486119174343, 'wd': 0.031156386598832922}


In [5]:
mol_dir = '/scratch/user/yuning.you/project/graph_latent_diffusion/e3_diffusion_for_molecules_official/outputs/edm_qm9/eval/analyzed_molecules/'
atom2idx = {'H':1, 'C':6, 'O':8, 'N':7, 'F':9}
energy_edm_list = []
for fn in tqdm(os.listdir(mol_dir)):
    x, pos = [], []
    try:
        with open(mol_dir + fn, 'r') as f:
            data = f.read().split('\n')
    except:
        continue
    
    natoms = int(data[0])
    for d in data[2:2+natoms]:
        d = d.split()
        x.append(d[0])
        pos.append([float(d[1]), float(d[2]), float(d[3])])
    
    geometry = []
    for i, p in zip(x, pos):
        geometry.append([i] + p)
    
    energy = calculate_XTB_energy(geometry)
    energy_edm_list.append(energy)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30000/30000 [28:52<00:00, 17.31it/s]


In [7]:
random.shuffle(energy_edm_list)
num_run = len(energy_edm_list) // 3
res_list = [[v for v in utils_yy.statistical_metric(np.array(energy_edm_list[:num_run]),
                                                    np.array(energy_list)).values()],
            [v for v in utils_yy.statistical_metric(np.array(energy_edm_list[num_run:num_run*2]),
                                                    np.array(energy_list)).values()],
            [v for v in utils_yy.statistical_metric(np.array(energy_edm_list[num_run*2:]),
                                                    np.array(energy_list)).values()]]

key_list = [k for k in utils_yy.statistical_metric(np.array(energy_edm_list[:num_run]),
                                                   np.array(energy_list)).keys()]

print({k: [np.mean([res_list[0][n], res_list[1][n], res_list[2][n]]),
           np.std([res_list[0][n], res_list[1][n], res_list[2][n]])]
       for n, k in enumerate(key_list)})

{'tvd': [0.02810979388009885, 0.002992548617162737], 'hd': [0.13542332908722657, 0.0035131040141895818], 'wd': [0.28568282655895605, 0.0045918625204167465]}


In [7]:
mols = torch.load(log_dir_ld + 'sample_conformer.pt')

stats_tar = stats / stats.sum() * 30000
stats_gen = np.zeros(stats.shape)

energy_ld_list = []
for mol in tqdm(mols):    
    atom_list = []
    for atom in mol.GetAtoms():
        atom_list.append(atom.GetSymbol())
#     if len(atom_list) > 29:
#         continue
    if len(set(atom_list).difference(qm9_atom_list)) > 0:
        continue
    
    AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
    positions = mol.GetConformers()[0].GetPositions()

    geometry = []
    for idx, a in enumerate(atom_list):
        geometry.append([a] + positions[idx].tolist())

    mol = Chem.RemoveHs(mol)
    num_atom = len(mol.GetAtoms())
    if num_atom > 9 or stats_gen[num_atom] >= stats_tar[num_atom]:
        continue
    stats_gen[num_atom] += 1

    energy = calculate_XTB_energy(geometry)
    energy_ld_list.append(energy)


random.shuffle(energy_ld_list)
num_run = len(energy_ld_list) // 3
res_list = [[v for v in utils_yy.statistical_metric(np.array(energy_ld_list[:num_run]),
                                                    np.array(energy_list)).values()],
            [v for v in utils_yy.statistical_metric(np.array(energy_ld_list[num_run:num_run*2]),
                                                    np.array(energy_list)).values()],
            [v for v in utils_yy.statistical_metric(np.array(energy_ld_list[num_run*2:]),
                                                    np.array(energy_list)).values()]]

key_list = [k for k in utils_yy.statistical_metric(np.array(energy_ld_list[:num_run]),
                                                   np.array(energy_list)).keys()]

print({k: [np.mean([res_list[0][n], res_list[1][n], res_list[2][n]]),
           np.std([res_list[0][n], res_list[1][n], res_list[2][n]])]
       for n, k in enumerate(key_list)})

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [16:48<00:00, 99.11it/s]

{'tvd': [0.018782709851518756, 0.0018861222341144625], 'hd': [0.08252747344333826, 0.0024023886386782692], 'wd': [0.25082174310225563, 0.0026850363893648897]}





In [9]:
mols = torch.load(log_dir_ld + 'sample_conformer.pt')

stats_tar = stats / stats.sum() * 30000
stats_gen = np.zeros(stats.shape)

energy_ld_list = []
for mol in tqdm(mols):    
    atom_list = []
    for atom in mol.GetAtoms():
        atom_list.append(atom.GetSymbol())
#     if len(atom_list) > 29:
#         continue
    if len(set(atom_list).difference(qm9_atom_list)) > 0:
        continue
    
#     AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
    positions = mol.GetConformers()[0].GetPositions()

    geometry = []
    for idx, a in enumerate(atom_list):
        geometry.append([a] + positions[idx].tolist())

    mol = Chem.RemoveHs(mol)
    num_atom = len(mol.GetAtoms())
    if num_atom > 9 or stats_gen[num_atom] >= stats_tar[num_atom]:
        continue
    stats_gen[num_atom] += 1

    try:
        energy = calculate_XTB_energy(geometry)
    except:
        energy = None
    energy_ld_list.append(energy)

energy_mean = np.mean([e for e in energy_ld_list if e is not None])
energy_ld_list = [e if e is not None else energy_mean for e in energy_ld_list]

random.shuffle(energy_ld_list)
num_run = len(energy_ld_list) // 3
res_list = [[v for v in utils_yy.statistical_metric(np.array(energy_ld_list[:num_run]),
                                                    np.array(energy_list)).values()],
            [v for v in utils_yy.statistical_metric(np.array(energy_ld_list[num_run:num_run*2]),
                                                    np.array(energy_list)).values()],
            [v for v in utils_yy.statistical_metric(np.array(energy_ld_list[num_run*2:]),
                                                    np.array(energy_list)).values()]]

key_list = [k for k in utils_yy.statistical_metric(np.array(energy_ld_list[:num_run]),
                                                   np.array(energy_list)).keys()]

print({k: [np.mean([res_list[0][n], res_list[1][n], res_list[2][n]]),
           np.std([res_list[0][n], res_list[1][n], res_list[2][n]])]
       for n, k in enumerate(key_list)})

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [07:14<00:00, 230.22it/s]


{'tvd': [0.7739614360899444, 0.04337661685806737], 'hd': [0.7251906559774723, 0.03224364822916765], 'wd': [1015.7221504497578, 10.111491648653107]}


In [5]:
mols = torch.load(log_dir_ld + 'sample_conformer.pt')

stats_tar = stats / stats.sum() * 30000
stats_gen = np.zeros(stats.shape)

energy_rand_list = []
for mol in tqdm(mols):    
    atom_list = []
    for atom in mol.GetAtoms():
        atom_list.append(atom.GetSymbol())
#     if len(atom_list) > 29:
#         continue
    if len(set(atom_list).difference(qm9_atom_list)) > 0:
        continue
    
    try:
        mol = Chem.RemoveHs(mol)
        num_atom = len(mol.GetAtoms())
        if num_atom > 9 or stats_gen[num_atom] >= stats_tar[num_atom]:
            continue
        stats_gen[num_atom] += 1
        mol = Chem.AddHs(mol)
        
        AllChem.EmbedMolecule(mol)
        AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
        positions = mol.GetConformers()[0].GetPositions()

        geometry = []
        for idx, a in enumerate(atom_list):
            geometry.append([a] + positions[idx].tolist())

        energy = calculate_XTB_energy(geometry)
    
    except:
        energy = None

    energy_rand_list.append(energy)
    if len(energy_rand_list) > 30000:
        break

energy_mean = np.mean([e for e in energy_rand_list if e is not None])
energy_rand_list = [e if e is not None else energy_mean for e in energy_rand_list]

random.shuffle(energy_rand_list)
num_run = len(energy_rand_list) // 3
res_list = [[v for v in utils_yy.statistical_metric(np.array(energy_rand_list[:num_run]),
                                                    np.array(energy_list)).values()],
            [v for v in utils_yy.statistical_metric(np.array(energy_rand_list[num_run:num_run*2]),
                                                    np.array(energy_list)).values()],
            [v for v in utils_yy.statistical_metric(np.array(energy_rand_list[num_run*2:]),
                                                    np.array(energy_list)).values()]]

key_list = [k for k in utils_yy.statistical_metric(np.array(energy_rand_list[:num_run]),
                                                   np.array(energy_list)).keys()]

print({k: [np.mean([res_list[0][n], res_list[1][n], res_list[2][n]]),
           np.std([res_list[0][n], res_list[1][n], res_list[2][n]])]
       for n, k in enumerate(key_list)})

 73%|██████████████████████████████████████████████████████████████████████▊                          | 72980/100000 [15:12<05:37, 79.98it/s]


{'tvd': [0.04053214469936226, 0.0061896470352884635], 'hd': [0.0979752837054297, 0.0018371192676422707], 'wd': [0.28276628173420465, 0.014545023241275275]}


In [6]:
mols = torch.load(log_dir_ld + 'sample_conformer.pt')

stats_tar = stats / stats.sum() * 30000
stats_gen = np.zeros(stats.shape)

energy_ld_list = []
for mol in tqdm(mols):    
    atom_list = []
    for atom in mol.GetAtoms():
        atom_list.append(atom.GetSymbol())
#     if len(atom_list) > 29:
#         continue
    if len(set(atom_list).difference(qm9_atom_list)) > 0:
        continue
    
    AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
    positions = mol.GetConformers()[0].GetPositions()

    geometry = []
    for idx, a in enumerate(atom_list):
        geometry.append([a] + positions[idx].tolist())

    mol = Chem.RemoveHs(mol)
    num_atom = len(mol.GetAtoms())
    if num_atom > 9:
        continue
#     if num_atom > 9 or stats_gen[num_atom] >= stats_tar[num_atom]:
#         continue
#     stats_gen[num_atom] += 1

    energy = calculate_XTB_energy(geometry)
    energy_ld_list.append(energy)
    
    if len(energy_ld_list) > 30000:
        break


random.shuffle(energy_ld_list)
num_run = len(energy_ld_list) // 3
res_list = [[v for v in utils_yy.statistical_metric(np.array(energy_ld_list[:num_run]),
                                                    np.array(energy_list)).values()],
            [v for v in utils_yy.statistical_metric(np.array(energy_ld_list[num_run:num_run*2]),
                                                    np.array(energy_list)).values()],
            [v for v in utils_yy.statistical_metric(np.array(energy_ld_list[num_run*2:]),
                                                    np.array(energy_list)).values()]]

key_list = [k for k in utils_yy.statistical_metric(np.array(energy_ld_list[:num_run]),
                                                   np.array(energy_list)).keys()]

print({k: [np.mean([res_list[0][n], res_list[1][n], res_list[2][n]]),
           np.std([res_list[0][n], res_list[1][n], res_list[2][n]])]
       for n, k in enumerate(key_list)})

 44%|██████████████████████████████████████████▍                                                      | 43713/100000 [09:55<12:46, 73.40it/s]


{'tvd': [0.03497220961029582, 0.003946201454430411], 'hd': [0.19081330710655867, 0.003650139484239148], 'wd': [1.060405469194736, 0.01896135130480013]}
