In [None]:
import torch
import pickle
from tqdm import tqdm
import sys
sys.path.append('.') 
from jtnn import Vocab, JTNNVAE
import numpy as np

In [2]:

# Build the Vocab from the whitespace-stripped lines of the vocab file.
# The file is read inside a `with` block so the handle is closed as soon as
# the lines are consumed instead of being left open until garbage collection.
with open('qvae/data/merged/vocab.txt') as vocab_file:
    vocab = Vocab([line.strip() for line in vocab_file])

# Load test data
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# point this at chunk files that come from a trusted source.
with open('qvae/data/merged/train_shuffled/chunk_test.pkl', 'rb') as f:
    test_data = pickle.load(f)[:500]  # keep only the first 500 entries


In [6]:
print(f"Testing on {len(test_data)} molecules\n")

model_path = 'qvae/var_model/model.epoch-5'
print(f"\n=== Testing {model_path} ===")

# NOTE(review): the positional arguments (450, 56, 3) are presumably hidden
# size, latent size and message-passing depth -- confirm against jtnn.JTNNVAE.
model = JTNNVAE(vocab, 450, 56, 3, stereo=True)
try:
    # Deserialize onto the CPU so loading does not depend on the device the
    # checkpoint was saved from; load_state_dict copies the values into the
    # model's own parameters, which are moved to the GPU below.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
except Exception:
    # `except Exception` (not a bare `except`) lets KeyboardInterrupt and
    # SystemExit propagate untouched. The failure is re-raised because
    # continuing would silently test a randomly initialised model.
    print(f"Could not load {model_path}")
    raise
model = model.cuda()
# Switch to evaluation mode; as the last expression of the cell this also
# displays the module summary.
model.eval()


Testing on 500 molecules


=== Testing qvae/var_model/model.epoch-5 ===


JTNNVAE(
  (embedding): Embedding(2279, 450)
  (jtnn): JTNNEncoder(
    (embedding): Embedding(2279, 450)
    (W_z): Linear(in_features=900, out_features=450, bias=True)
    (W_r): Linear(in_features=450, out_features=450, bias=False)
    (U_r): Linear(in_features=450, out_features=450, bias=True)
    (W_h): Linear(in_features=900, out_features=450, bias=True)
    (W): Linear(in_features=900, out_features=450, bias=True)
  )
  (jtmpn): JTMPN(
    (W_i): Linear(in_features=40, out_features=450, bias=False)
    (W_h): Linear(in_features=450, out_features=450, bias=False)
    (W_o): Linear(in_features=485, out_features=450, bias=True)
  )
  (mpn): MPN(
    (W_i): Linear(in_features=50, out_features=450, bias=False)
    (W_h): Linear(in_features=450, out_features=450, bias=False)
    (W_o): Linear(in_features=489, out_features=450, bias=True)
  )
  (decoder): JTNNDecoder(
    (embedding): Embedding(2279, 450)
    (W_z): Linear(in_features=900, out_features=450, bias=True)
    (U_r): Linear

In [15]:

# Round-trip a single test entry through the model: encode it, project the
# encodings with T_mean / G_mean, then decode from those projections.
mol_tree = test_data[0]
mol_batch = [mol_tree]

# This cell only runs inference, so autograd bookkeeping is disabled to avoid
# building (and keeping alive) a computation graph for the forward passes.
with torch.no_grad():
    tree_mess, tree_vec, mol_vec = model.encode(mol_batch)
    # tree_mess is indexed via .keys(), so report its keys rather than a shape.
    print("keys of tree_mess: " + str(tree_mess.keys()))
    print("shape of tree_vec: " + str(tree_vec.shape))
    print("shape of mol_vec: " + str(mol_vec.shape))

    tree_mean = model.T_mean(tree_vec)
    mol_mean = model.G_mean(mol_vec)

    # prob_decode=False is passed through to model.decode unchanged.
    # NOTE(review): presumably selects deterministic decoding -- confirm in jtnn.
    result = model.decode(tree_mean, mol_mean, prob_decode=False)


shape of tree_mess: dict_keys([(5, 14), (6, 14), (7, 14), (14, 4), (11, 15), (12, 15), (4, 3), (15, 10), (3, 2), (10, 9), (2, 1), (9, 8), (1, 13), (8, 13), (13, 0), (0, 13), (13, 1), (13, 8), (1, 2), (8, 9), (2, 3), (9, 10), (3, 4), (10, 15), (4, 14), (15, 11), (15, 12), (14, 5), (14, 6), (14, 7)])
shape of tree_vec: torch.Size([1, 450])
shape of mol_vec: torch.Size([1, 450])


[22:54:04] Explicit valence for atom # 1 C, 5, is greater than permitted
[22:54:04] Explicit valence for atom # 5 C, 6, is greater than permitted
[22:54:04] Explicit valence for atom # 5 C, 6, is greater than permitted
[22:54:04] Explicit valence for atom # 2 C, 5, is greater than permitted
[22:54:04] Explicit valence for atom # 3 C, 5, is greater than permitted
[22:54:04] Explicit valence for atom # 2 C, 5, is greater than permitted
[22:54:04] Explicit valence for atom # 5 C, 6, is greater than permitted
[22:54:04] Explicit valence for atom # 5 C, 6, is greater than permitted
[22:54:04] Explicit valence for atom # 2 C, 5, is greater than permitted
[22:54:04] Explicit valence for atom # 3 C, 5, is greater than permitted
[22:54:04] Explicit valence for atom # 2 C, 5, is greater than permitted
[22:54:04] Explicit valence for atom # 4 C, 5, is greater than permitted
[22:54:04] Explicit valence for atom # 4 C, 5, is greater than permitted
[22:54:04] Explicit valence for atom # 4 C, 5, is g

In [11]:
# Read every array out of the .npz archive while it is open, then let the
# `with` block close the underlying file. `latent` becomes a plain dict, so
# `latent[k]` and `latent.keys()` keep working after the archive is closed.
with np.load("qvae/latent_vectors_sample.npz") as archive:
    latent = {name: archive[name] for name in archive.files}

# Print each stored array in archive order (NumPy abbreviates large arrays).
for k in latent.keys():
    print(latent[k])

[[-7.14630485e-02  2.00638986e+00 -5.70057929e-01 ...  1.86683893e-01
   7.27568686e-01 -8.57842207e-01]
 [ 2.66846791e-02 -4.52017927e+00 -2.49691558e+00 ... -1.67793840e-01
   1.95688471e-01 -8.15476477e-01]
 [ 1.35045469e+00  2.64833778e-01  5.01492798e-01 ... -2.83942297e-02
   1.07684452e-03 -6.08146250e-01]
 ...
 [ 1.47505879e-01 -6.54589951e-01 -7.79010236e-01 ...  4.22240049e-03
   1.09001204e-01 -7.28557825e-01]
 [ 3.42498541e+00 -1.25488186e+00 -1.04263616e+00 ... -9.53609124e-03
  -2.27499068e-01 -8.51416171e-01]
 [-6.77770853e-01 -6.39140487e-01 -8.37219834e-01 ...  1.15038723e-01
   1.82776988e-01 -6.80854201e-01]]
['Cn1cc(NC(=O)N2CCCCCC2c2ccncc2)cn1' 'Cc1nccc(NC(C)CCC[NH3+])n1'
 'O=S(=O)(c1ccc(Cl)cc1C(F)(F)F)N1CCCC(c2nc3ccccc3o2)C1' ...
 'CC(C)Oc1cccnc1C(=O)N1CCN(C(=O)C[NH+]2CCCC2)CC1'
 'Cc1c(Cl)cccc1NC(=O)C(C)S(=O)(=O)Cc1nnnn1C1CC1'
 'COc1ccccc1C(=O)CC1(O)C(=O)Nc2ccc([N+](=O)[O-])cc21']
