In [14]:
import os
import os.path as osp
import torch
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm import tqdm
import pickle
from edm_qm9_utils.property_prediction import main_qm9_prop
import utils_yy.utils as utils_yy

In [15]:
import faulthandler
import sys

# Have the interpreter dump a Python traceback to stderr on a fatal error
# (e.g. a crash inside a compiled extension); only the current thread is dumped.
faulthandler.enable(file=sys.stderr, all_threads=False)
try:
    import torch_sparse
except Exception as e:
    # The handler must name a real exception class: a misspelled class name is
    # only evaluated when the import fails, and then raises NameError instead
    # of running this branch.
    print('failed to import torch_sparse:', repr(e), file=sys.stderr)
    # Schedule a traceback dump after a 1 second timeout, then stop the session.
    faulthandler.dump_traceback_later(1)
    quit()

In [18]:
def evaluate(condition='alpha', log_dir=None, classifier_dir=None, data_dir=None):
    """Score conditionally generated molecules with a pretrained property classifier.

    Loads the sampled conformers and their target condition values from
    ``log_dir``, predicts the property ``condition`` for every sample that
    passes the QM9 filter (at most 29 atoms, only H/C/N/O/F), and prints the
    mean absolute error between predictions and targets.

    Normalisation: ``cond_mean`` is the mean of the property on ``valid.npz``
    and ``cond_mad`` is 10x its mean absolute deviation.  Classifier outputs
    are mapped back with ``pred * cond_mad / 10 + cond_mean`` and are clamped
    to the range spanned by the training half of ``train.npz``.

    Args:
        condition: Key of the property inside the ``.npz`` files; also the
            suffix of the default ``log_dir`` / ``classifier_dir``.
        log_dir: Directory with ``sample_conformer.pt`` and ``condition.pt``.
            Defaults to the job21 log directory for ``condition``.
        classifier_dir: Directory with the classifier's ``args.pickle`` and
            ``best_checkpoint.npy``.  Defaults to ``exp_class_<condition>``.
        data_dir: Directory with ``valid.npz`` and ``train.npz``.

    Returns:
        None.  All statistics are reported through ``print``.
    """
    qm9_atom_list = ['H', 'C', 'O', 'N', 'F']
    atom_encoder = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
    max_num_atoms = 29        # samples with more atoms than this are skipped
    max_num_samples = 100000  # only the first samples of the log are scored

    if log_dir is None:
        log_dir = '../AE_geom_uncond_weights_and_data/job21_latent_ddpm_qm9_spatial_graph_condition_' + condition
    if classifier_dir is None:
        classifier_dir = '../e3_diffusion_for_molecules/qm9/property_prediction/checkpoints/QM9/Property_Classifiers/exp_class_' + condition
    if data_dir is None:
        data_dir = '../e3_diffusion_for_molecules/qm9/latent_diffusion/emb_2d_3d_spatial_graphs/'

    def get_classifier(dir_path='', device='cpu'):
        """Rebuild the property classifier stored in ``dir_path`` and load its weights."""
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at an args.pickle that you produced yourself.
        with open(osp.join(dir_path, 'args.pickle'), 'rb') as f:
            args_classifier = pickle.load(f)
        args_classifier.device = device
        args_classifier.model_name = 'egnn'
        classifier = main_qm9_prop.get_model(args_classifier)
        classifier_state_dict = torch.load(osp.join(dir_path, 'best_checkpoint.npy'), map_location=torch.device('cpu'))
        classifier.load_state_dict(classifier_state_dict)
        # Inference only: make sure train-time layers (if any) are disabled.
        classifier.eval()
        return classifier

    model = get_classifier(classifier_dir)

    # Every `data[key]` on an NpzFile re-reads the array from the archive, so
    # fetch each array once and close the file handle right away.
    with np.load(osp.join(data_dir, 'valid.npz')) as data:
        cond_valid = data[condition]
    cond_mean = cond_valid.mean()
    cond_mad = np.abs(cond_valid - cond_mean).mean() * 10

    with np.load(osp.join(data_dir, 'train.npz')) as data:
        cond_all = data[condition]
        num_data = data['emb_2d'].shape[0]

    # A private RandomState(42) produces the same permutation as
    # np.random.seed(42) + np.random.permutation, without resetting the
    # global NumPy RNG for the rest of the session.
    idx_perm = np.random.RandomState(42).permutation(num_data)
    idx_train = idx_perm[num_data // 2:]  # second half is the training split; the first half is unused here
    cond_train = cond_all[idx_train]

    condition_train = torch.tensor((cond_train - cond_mean) / cond_mad)
    cond_max, cond_min = condition_train.max().item(), condition_train.min().item()
    print(cond_max, cond_min, cond_mean, cond_mad)

    v_max, v_min = cond_train.max().item(), cond_train.min().item()
    # Clamp bounds in the classifier's output units (cond_mad carries an extra factor of 10).
    pred_max = (v_max - cond_mean) / cond_mad * 10
    pred_min = (v_min - cond_mean) / cond_mad * 10
    print(pred_max, pred_min)
    scale = v_max - v_min
    print(v_max, v_min)
    print('random baseline', scale / 3)

    mol_list = torch.load(osp.join(log_dir, 'sample_conformer.pt'))
    condition_list = torch.load(osp.join(log_dir, 'condition.pt'))
    print(condition_list.size())

    pred_list = []
    label_list = []
    for mol, cond in zip(tqdm(mol_list[:max_num_samples]), condition_list):
        # Keep only samples the QM9-trained classifier can handle.
        atom_list = [atom.GetSymbol() for atom in mol.GetAtoms()]
        n_nodes = len(atom_list)
        if n_nodes == 0 or n_nodes > max_num_atoms:
            continue
        if not set(atom_list).issubset(qm9_atom_list):
            continue

        # Featurization: one-hot atom types, 3D coordinates, complete graph.
        atom_types = torch.tensor([atom_encoder[atom] for atom in atom_list], dtype=torch.int64)
        nodes = torch.nn.functional.one_hot(atom_types, num_classes=5).float()
        atom_positions = torch.tensor(mol.GetConformer().GetPositions(), dtype=torch.float32)

        _, edge_index = utils_yy.construct_complete_graph(n_nodes, return_index=True, add_self_loop=False)
        edges = [edge_index[0], edge_index[1]]

        atom_mask = torch.ones((n_nodes, 1))
        edge_mask = torch.ones((edge_index.shape[1], 1))

        with torch.no_grad():
            pred = model(h0=nodes, x=atom_positions, edges=edges, edge_attr=None, node_mask=atom_mask, edge_mask=edge_mask, n_nodes=n_nodes)
        pred = torch.clamp(pred, min=float(pred_min), max=float(pred_max))

        # Back to the property's original units.
        pred_list.append(pred.item() * cond_mad / 10 + cond_mean)
        label_list.append(cond.item())

    print(len(pred_list))
    if not pred_list:
        # The mean of an empty array would be nan (plus a RuntimeWarning).
        print('no valid molecules to evaluate')
        return
    print(np.abs(np.array(pred_list) - np.array(label_list)).mean())

In [17]:
# Score the samples generated for the 'alpha' condition; prints the number of scored molecules and their MAE.
evaluate(condition='alpha')

0.9102288484573364 -1.0507861375808716 75.37342 62.727723121643066
9.102288057650055 -10.507861181150194
132.47000122070312 9.460000038146973
random baseline 41.003333727518715
test1
torch.Size([3000])


100%|██████████| 3000/3000 [00:13<00:00, 229.51it/s] 


1600
34.47611467440916


In [19]:
# Score the samples generated for the 'gap' condition; prints the number of scored molecules and their MAE.
evaluate(condition='gap')

0.947828471660614 -0.5386840105056763 0.25221726 0.390242263674736
9.478284826116898 -5.386839999595246
0.6220999956130981 0.041999999433755875
random baseline 0.1933666653931141
test1
torch.Size([3000])


100%|██████████| 3000/3000 [00:16<00:00, 187.42it/s]

2141
0.24559635314623593





In [8]:
# NOTE(review): same call as an earlier cell, and the execution counts are out of order.
# Scores the 'alpha' samples; prints the number of scored molecules and their MAE.
evaluate(condition='alpha')

0.9102288484573364 -1.0507861375808716 75.37342 62.727723121643066
9.102288057650055 -10.507861181150194
132.47000122070312 9.460000038146973
random baseline 41.003333727518715
test1
torch.Size([3000])


100%|██████████| 3000/3000 [00:09<00:00, 326.41it/s] 

1600
34.47611467440916





In [9]:
# NOTE(review): same call as an earlier cell, and the execution counts are out of order.
# Scores the 'gap' samples; prints the number of scored molecules and their MAE.
evaluate(condition='gap')

0.947828471660614 -0.5386840105056763 0.25221726 0.390242263674736
9.478284826116898 -5.386839999595246
0.6220999956130981 0.041999999433755875
random baseline 0.1933666653931141
test1
torch.Size([3000])


100%|██████████| 3000/3000 [00:16<00:00, 183.57it/s]

2141
0.24559635314623593





In [10]:
# Score the samples generated for the 'homo' condition; prints the number of scored molecules and their MAE.
evaluate(condition='homo')

0.7526819705963135 -1.1657202243804932 -0.24028876 0.1615406945347786
7.526819599545699 -11.657201893308272
-0.11869999766349792 -0.428600013256073
random baseline 0.10330000519752502
test1
torch.Size([3000])


100%|██████████| 3000/3000 [00:12<00:00, 244.63it/s]

1698
0.10721347864300645





In [11]:
# Score the samples generated for the 'lumo' condition; prints the number of scored molecules and their MAE.
evaluate(condition='lumo')

0.47794172167778015 -0.4841439723968506 0.011928127 0.37990380078554153
4.77941704545956 -4.841439664678198
0.19349999725818634 -0.1720000058412552
random baseline 0.12183333436648051
test1
torch.Size([3000])


100%|██████████| 3000/3000 [00:15<00:00, 198.98it/s]

2002
0.12984439249141988





In [12]:
# Score the samples generated for the 'mu' condition; prints the number of scored molecules and their MAE.
evaluate(condition='mu')

1.9160064458847046 -0.22752515971660614 2.6750875 11.757326126098633
19.160064322765635 -2.2752515523038177
25.202199935913086 0.0
random baseline 8.40073331197103
test1
torch.Size([3000])


100%|██████████| 3000/3000 [00:10<00:00, 275.60it/s] 

1456
7.654095210307646





In [13]:
# Score the samples generated for the 'Cv' condition; prints the number of scored molecules and their MAE.
evaluate(condition='Cv')

0.4776511788368225 -0.7890406250953674 31.620028 32.11751937866211
4.776511788526875 -7.890406281196496
46.96099853515625 6.2779998779296875
random baseline 13.560999552408854
test1
torch.Size([3000])


100%|██████████| 3000/3000 [00:09<00:00, 304.45it/s] 


1533
25.521800447719922


In [2]:
def evaluate_ood(condition='alpha', log_dir=None, classifier_dir=None, data_dir=None):
    """Score a slice of conditionally generated molecules against an extended value range.

    Same pipeline as the in-distribution evaluation, with these differences
    visible in the code below: only ``mol_list[-100000:-90000]`` (and the
    matching slice of the conditions) is scored, each molecule is
    MMFF-optimised before its coordinates are read, there is no atom-count
    filter, and the upper clamp on the prediction is 1.5x the training maximum.

    Normalisation: ``cond_mean`` is the mean of the property on ``valid.npz``
    and ``cond_mad`` is 10x its mean absolute deviation.  Classifier outputs
    are mapped back with ``pred * cond_mad / 10 + cond_mean``.

    Args:
        condition: Key of the property inside the ``.npz`` files; also the
            suffix of the default ``log_dir`` / ``classifier_dir``.
        log_dir: Directory with ``sample_conformer.pt`` and ``condition.pt``.
            Defaults to the job8 log directory for ``condition``.
        classifier_dir: Directory with the classifier's ``args.pickle`` and
            ``best_checkpoint.npy``.  Defaults to ``exp_class_<condition>``.
        data_dir: Directory with ``valid.npz`` and ``train.npz``.

    Returns:
        None.  All statistics are reported through ``print``.
    """
    qm9_atom_list = ['H', 'C', 'O', 'N', 'F']
    atom_encoder = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
    sample_slice = slice(-100000, -90000)  # portion of the sample log that is scored
    ood_factor = 1.5                       # widening applied to the upper clamp and the value range

    # The defaults are machine-specific absolute paths; pass the keyword
    # arguments to run this anywhere else.
    if log_dir is None:
        log_dir = 'logs/job8_latent_ddpm_qm9_condition_' + condition
    if classifier_dir is None:
        classifier_dir = '/scratch/user/yuning.you/project/graph_latent_diffusion/e3_diffusion_for_molecules_official/qm9/property_prediction/checkpoints/QM9/Property_Classifiers/exp_class_' + condition
    if data_dir is None:
        data_dir = '/scratch/user/yuning.you/project/graph_latent_diffusion/e3_diffusion_for_molecules/qm9/latent_diffusion/emb_2d_3d_4layer_new/'

    def get_classifier(dir_path='', device='cpu'):
        """Rebuild the property classifier stored in ``dir_path`` and load its weights."""
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at an args.pickle that you produced yourself.
        with open(osp.join(dir_path, 'args.pickle'), 'rb') as f:
            args_classifier = pickle.load(f)
        args_classifier.device = device
        args_classifier.model_name = 'egnn'
        classifier = main_qm9_prop.get_model(args_classifier)
        classifier_state_dict = torch.load(osp.join(dir_path, 'best_checkpoint.npy'), map_location=torch.device('cpu'))
        classifier.load_state_dict(classifier_state_dict)
        # Inference only: make sure train-time layers (if any) are disabled.
        classifier.eval()
        return classifier

    model = get_classifier(classifier_dir)

    # Every `data[key]` on an NpzFile re-reads the array from the archive, so
    # fetch each array once and close the file handle right away.
    with np.load(osp.join(data_dir, 'valid.npz')) as data:
        cond_valid = data[condition]
    cond_mean = cond_valid.mean()
    cond_mad = np.abs(cond_valid - cond_mean).mean() * 10

    with np.load(osp.join(data_dir, 'train.npz')) as data:
        cond_all = data[condition]
        num_data = data['emb_2d'].shape[0]

    # A private RandomState(42) produces the same permutation as
    # np.random.seed(42) + np.random.permutation, without resetting the
    # global NumPy RNG for the rest of the session.
    idx_perm = np.random.RandomState(42).permutation(num_data)
    idx_train = idx_perm[num_data // 2:]  # second half is the training split; the first half is unused here
    cond_train = cond_all[idx_train]

    condition_train = torch.tensor((cond_train - cond_mean) / cond_mad)
    cond_max, cond_min = condition_train.max().item(), condition_train.min().item()
    print(cond_max, cond_min, cond_mean, cond_mad)

    v_max, v_min = cond_train.max().item(), cond_train.min().item()
    # Clamp bounds in the classifier's output units (cond_mad carries an extra
    # factor of 10); only the upper bound is widened.
    pred_max = (v_max - cond_mean) / cond_mad * 10 * ood_factor
    pred_min = (v_min - cond_mean) / cond_mad * 10
    print(pred_max, pred_min)
    scale = (v_max - v_min) * ood_factor
    baseline_factor = (0.25 ** 3 / 3 - 0.25 ** 2 / 2 + 0.25 / 2) / 0.25
    print(v_max, v_min)
    print('random baseline', scale * baseline_factor)

    mol_list = torch.load(osp.join(log_dir, 'sample_conformer.pt'))
    condition_list = torch.load(osp.join(log_dir, 'condition.pt'))

    pred_list = []
    label_list = []
    for mol, cond in zip(tqdm(mol_list[sample_slice]), condition_list[sample_slice]):
        # Keep only samples made of the atom types the classifier knows.
        atom_list = [atom.GetSymbol() for atom in mol.GetAtoms()]
        n_nodes = len(atom_list)
        if n_nodes == 0:
            continue
        if not set(atom_list).issubset(qm9_atom_list):
            continue

        # Featurization: one-hot atom types, 3D coordinates, complete graph.
        atom_types = torch.tensor([atom_encoder[atom] for atom in atom_list], dtype=torch.int64)
        nodes = torch.nn.functional.one_hot(atom_types, num_classes=5).float()

        # Relaxes the conformer in place before its coordinates are read; the
        # status code returned by RDKit is not inspected.
        AllChem.MMFFOptimizeMolecule(mol)
        atom_positions = torch.tensor(mol.GetConformer().GetPositions(), dtype=torch.float32)

        _, edge_index = utils_yy.construct_complete_graph(n_nodes, return_index=True, add_self_loop=False)
        edges = [edge_index[0], edge_index[1]]

        atom_mask = torch.ones((n_nodes, 1))
        edge_mask = torch.ones((edge_index.shape[1], 1))

        with torch.no_grad():
            pred = model(h0=nodes, x=atom_positions, edges=edges, edge_attr=None, node_mask=atom_mask, edge_mask=edge_mask, n_nodes=n_nodes)
        pred = torch.clamp(pred, min=float(pred_min), max=float(pred_max))

        # Back to the property's original units.
        pred_list.append(pred.item() * cond_mad / 10 + cond_mean)
        label_list.append(cond.item())

    print(len(pred_list))
    if not pred_list:
        # The mean of an empty array would be nan (plus a RuntimeWarning).
        print('no valid molecules to evaluate')
        return
    print(np.abs(np.array(pred_list) - np.array(label_list)).mean())

In [11]:
# Extended-range evaluation for the 'alpha' condition; prints the number of scored molecules and their MAE.
evaluate_ood(condition='alpha')

0.9102288484573364 -1.0507861375808716 75.37342 62.727723121643066
13.653432086475082 -10.507861181150194
132.47000122070312 9.460000038146973
random baseline 73.03718820214272


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [05:53<00:00, 28.28it/s]

9985
32.064227763569484





In [12]:
# Extended-range evaluation for the 'gap' condition; prints the number of scored molecules and their MAE.
evaluate_ood(condition='gap')

0.947828471660614 -0.5386840105056763 0.25221726 0.390242263674736
14.217427239175347 -5.386839999595246
0.6220999956130981 0.041999999433755875
random baseline 0.3444343727314845


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [01:25<00:00, 117.06it/s]

9381
0.36313442086942527





In [13]:
# Extended-range evaluation for the 'homo' condition; prints the number of scored molecules and their MAE.
evaluate_ood(condition='homo')

0.7526819705963135 -1.1657202243804932 -0.24028876 0.1615406945347786
11.290229399318548 -11.657201893308272
-0.11869999766349792 -0.428600013256073
random baseline 0.18400313425809145


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [01:22<00:00, 120.86it/s]

5936
0.10930117979748938





In [3]:
# Extended-range evaluation for the 'lumo' condition; prints the number of scored molecules and their MAE.
evaluate_ood(condition='lumo')

0.47794172167778015 -0.4841439723968506 0.011928127 0.37990380078554153
7.16912556818934 -4.841439664678198
0.19349999725818634 -0.1720000058412552
random baseline 0.2170156268402934


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [09:01<00:00, 18.46it/s]

9472
0.17869909136619652





In [4]:
# Extended-range evaluation for the 'mu' condition; prints the number of scored molecules and their MAE.
evaluate_ood(condition='mu')

1.9160064458847046 -0.22752515971660614 2.6750875 11.757326126098633
28.740096484148452 -2.2752515523038177
25.202199935913086 0.0
random baseline 14.963806211948395


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:37<00:00, 264.70it/s]

1579
22.18587002372699





In [5]:
# Extended-range evaluation for the 'Cv' condition; prints the number of scored molecules and their MAE.
evaluate_ood(condition='Cv')

0.4776511788368225 -0.7890406250953674 31.620028 32.11751937866211
7.164767682790313 -7.890406281196496
46.96099853515625 6.2779998779296875
random baseline 24.15553045272827


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [04:17<00:00, 38.78it/s]

9991
31.126622606334703



