In [33]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../../src')
from rdkit import Chem
from dfs_transformer.utils import Smiles2Mol, Mol2Smiles, DFSCode2Graph, Graph2Mol, isValid, Smiles2DFSCode, DFSCode2Smiles, isValidMoleculeDFSCode
from dfs_transformer.utils import load_selfattn_wandb, load_selfattn_local, computeChemicalValidityAndNovelty, parseChempropAtomFeatures, parseChempropBondFeatures
from dfs_transformer.utils import FeaturizedDFSCodes2Nx, Mol2Nx, Nx2Mol
import os.path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import dfs_transformer as dfs
import numpy as np
from ml_collections import ConfigDict
import yaml
import functools
import tqdm
import traceback
from einops import rearrange


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Pretrained self-attention model to analyse. A local copy of its wandb artifact is
# reused when it exists; set force_download = True to always fetch it from wandb.
pretrained_model = "r2r-30-c4s-noprop"
pretrained_project = "pubchem_newencoding"
force_download = False
local_artifact = "../../wandb/artifacts/%s" % pretrained_model
if force_download or not os.path.isdir(local_artifact):
    bert, cfg = load_selfattn_wandb(pretrained_model, wandb_dir="../../wandb",
                                    pretrained_project=pretrained_project)
else:
    bert, cfg = load_selfattn_local(local_artifact)

In [3]:
# Put the model in evaluation mode for the analysis (Module.eval() is train(False)).
bert = bert.train(False)

In [4]:
# Dataset settings (paths, size limits, featurization flags) used to build the
# validation set below.
with open("../../config/selfattn/data/pubchem10K.yaml") as config_file:
    data_cfg_raw = yaml.load(config_file, Loader=yaml.FullLoader)
d = ConfigDict(data_cfg_raw)

In [5]:
# Take the no_features flag from the config loaded together with `bert`, so the
# validation data is featurized the same way (presumably as during pretraining).
d.no_features = cfg.data.no_features

In [6]:
# Collate function matching the pretraining objective of the loaded model.
# Both supported modes use the "rnd2rnd" masking mode; "rnd2rnd_entry" only
# switches the implementation to dfs.collate_BERT_entries.
# NOTE(review): `fraction_missing` suggests the collate masks a random subset and
# no seed is set, so the counts computed below may vary between runs — confirm.
if cfg.training.mode == "rnd2rnd":
    collate_impl = dfs.collate_BERT
elif cfg.training.mode == "rnd2rnd_entry":
    collate_impl = dfs.collate_BERT_entries
else:
    # Fail here rather than leave `collate_fn` undefined (or stale from an
    # earlier kernel run) for the DataLoader cell.
    raise ValueError("unsupported training mode: %r" % cfg.training.mode)
collate_fn = functools.partial(collate_impl,
                               mode="rnd2rnd",
                               fraction_missing=cfg.training.fraction_missing,
                               use_loops=cfg.model.use_loops)

In [7]:
# Validation split. '../.' is prepended to d.valid_path
# (NOTE(review): presumably re-roots a './'-relative config path two levels up,
# matching the '../../' paths used elsewhere in this notebook — confirm).
validset = dfs.PubChem(
    '../.' + d.valid_path,
    max_nodes=d.max_nodes,
    max_edges=d.max_edges,
    noFeatures=d.no_features,
    molecular_properties=d.molecular_properties,
    useDists=d.useDists,
    useHs=d.useHs,
    filter_unencoded=True,
)


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 17.25it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9978/9978 [00:13<00:00, 736.18it/s]


In [8]:
# Iterate the validation set in dataset order (shuffle=False), batching with the
# collate function chosen above.
validloader = DataLoader(
    validset,
    collate_fn=collate_fn,
    batch_size=50,
    shuffle=False,
    num_workers=4,
    prefetch_factor=2,
    pin_memory=True,
)

In [53]:
# Use the first GPU when available, otherwise the CPU
# (assign device = torch.device('cpu') here to force CPU execution).
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Helpers that move (possibly nested) batch/prediction structures between devices.
to_cuda = functools.partial(dfs.utils_to_cuda, device=device)
to_cpu = functools.partial(dfs.utils_to_cuda, device='cpu')

In [62]:
def dict_collect_correct(pred, target, key):
    """Collect every head's predictions at positions where head `key` is right.

    Parameters
    ----------
    pred : dict[str, torch.Tensor]
        Per-head logits, each indexed as (d0, d1, n_classes); the first two
        axes are flattened into a single position axis.
    target : dict[str, torch.Tensor]
        Per-head class targets indexed as (d0, d1); entries equal to -1 are
        ignored.
    key : str
        Head whose correctness selects the positions.

    Returns
    -------
    dict[str, torch.Tensor]
        For each head in `pred`, a 1-D tensor of its argmax predictions at the
        positions where ``target[key] != -1`` and the argmax of ``pred[key]``
        equals ``target[key]`` (in position order).
    """
    with torch.no_grad():
        # Predicted class per flattened position, for every head (computed once).
        pred_labels = {name: torch.argmax(logits.flatten(0, 1), dim=1)
                       for name, logits in pred.items()}
        tgt = target[key].flatten()
        keep = (tgt != -1) & (pred_labels[key] == tgt)
        return {name: labels[keep] for name, labels in pred_labels.items()}

In [41]:
def dict_collect_errors(pred, target, key):
    """Collect every head's predictions at positions where head `key` is wrong.

    Parameters
    ----------
    pred : dict[str, torch.Tensor]
        Per-head logits, each indexed as (d0, d1, n_classes); the first two
        axes are flattened into a single position axis.
    target : dict[str, torch.Tensor]
        Per-head class targets indexed as (d0, d1); entries equal to -1 are
        ignored.
    key : str
        Head whose correctness selects the positions.

    Returns
    -------
    dict[str, torch.Tensor]
        For each head in `pred`, a 1-D tensor of its argmax predictions at the
        positions where ``target[key] != -1`` and the argmax of ``pred[key]``
        differs from ``target[key]`` (in position order).
    """
    with torch.no_grad():
        # Predicted class per flattened position, for every head (computed once).
        pred_labels = {name: torch.argmax(logits.flatten(0, 1), dim=1)
                       for name, logits in pred.items()}
        tgt = target[key].flatten()
        keep = (tgt != -1) & (pred_labels[key] != tgt)
        return {name: labels[keep] for name, labels in pred_labels.items()}

In [None]:
def dict_collect_errors_that_copy_cannot_resolve(inputs, pred, target, key):
    """Placeholder — not implemented yet.

    Judging by its name and the extra `inputs` parameter, this is meant to be a
    variant of ``dict_collect_errors`` that also consults the model inputs
    (TODO: confirm the intended semantics before implementing).

    Raises
    ------
    NotImplementedError
        Always.
    """
    # The earlier draft body sat after this raise (unreachable), never used
    # `inputs`, and was a verbatim copy of dict_collect_errors, so it was dropped.
    raise NotImplementedError("not implemented...")

In [51]:
# Move the model onto the selected device before running inference.
bert = bert.to(device=device)

In [125]:
# Output head used to split validation positions into mispredicted (errors_acc)
# and correctly predicted (correct_acc) sets in the collection loop.
investigate_key = 'atomic_num_to'

In [126]:
# Run the model over the validation set and gather, for every output head, the
# argmax predictions at positions where `investigate_key` was predicted wrongly
# (errors_acc) or correctly (correct_acc).
errors_parts = {}
correct_parts = {}
with torch.no_grad():  # inference only: don't build an autograd graph per batch
    for data in tqdm.tqdm(validloader):
        data = [to_cuda(d) for d in data]
        pred = to_cpu(bert(data[0]))
        target = to_cpu(data[1])
        batch_results = (
            (errors_parts, dict_collect_errors(pred, target, investigate_key)),
            (correct_parts, dict_collect_correct(pred, target, investigate_key)),
        )
        for parts, collected in batch_results:
            for name, val in collected.items():
                parts.setdefault(name, []).append(val)
# Concatenate once at the end instead of re-allocating with torch.cat every batch.
errors_acc = {name: torch.cat(vals) for name, vals in errors_parts.items()}
correct_acc = {name: torch.cat(vals) for name, vals in correct_parts.items()}

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 199/199 [00:13<00:00, 14.33it/s]


In [127]:
# Classify the positions where `investigate_key` was mispredicted by the
# (predicted) DFS indices of the edge endpoints:
#   fwd        -- dfs_to == dfs_from + 1
#   fwd branch -- any other dfs_from < dfs_to
#   bwd        -- everything else, i.e. dfs_from >= dfs_to (dfs_from == dfs_to lands here too)
err_from = errors_acc['dfs_from']
err_to = errors_acc['dfs_to']
n_fwd = int((err_from == err_to - 1).sum())
n_fwd_branch = int((err_from < err_to).sum()) - n_fwd  # fwd is a subset of from < to
n_bwd = len(err_from) - n_fwd - n_fwd_branch
print('fwd', n_fwd, 'bwd', n_bwd, 'fwd branch', n_fwd_branch)

fwd 536 bwd 8 fwd branch 301


In [128]:
# Classify the positions where `investigate_key` was predicted correctly by the
# (predicted) DFS indices of the edge endpoints:
#   fwd        -- dfs_to == dfs_from + 1
#   fwd branch -- any other dfs_from < dfs_to
#   bwd        -- everything else, i.e. dfs_from >= dfs_to (dfs_from == dfs_to lands here too)
ok_from = correct_acc['dfs_from']
ok_to = correct_acc['dfs_to']
n_fwd2 = int((ok_from == ok_to - 1).sum())
n_fwd_branch2 = int((ok_from < ok_to).sum()) - n_fwd2  # fwd is a subset of from < to
n_bwd2 = len(ok_from) - n_fwd2 - n_fwd_branch2
print('fwd', n_fwd2, 'bwd', n_bwd2, 'fwd branch', n_fwd_branch2)

fwd 23754 bwd 3647 fwd branch 6891


In [129]:
def _error_rate(n_wrong, n_right):
    """Return n_wrong / (n_wrong + n_right), or nan when the category is empty."""
    total = n_wrong + n_right
    return n_wrong / total if total else float('nan')


# Error rate of `investigate_key` within each edge category.
print('fwd', _error_rate(n_fwd, n_fwd2),
      'bwd', _error_rate(n_bwd, n_bwd2),
      'fwd branch', _error_rate(n_fwd_branch, n_fwd_branch2))

fwd 0.022066694112803622 bwd 0.002188782489740082 fwd branch 0.041852057842046715


In [130]:
# Overall fraction of mispredicted `investigate_key` positions across all edge categories.
n_wrong_all = n_fwd + n_bwd + n_fwd_branch
n_right_all = n_fwd2 + n_bwd2 + n_fwd_branch2
n_wrong_all / (n_wrong_all + n_right_all)

0.024048723567749095