In [15]:
from dgllife.model import load_pretrained
from dgllife.data import JTVAEDataset, JTVAECollator
import torch
from rdkit import Chem
import pandas as pd
import rdkit
from tqdm import tqdm
print(rdkit.__version__)

2018.09.3


In [10]:
model = load_pretrained("JTVAE_ZINC_no_kl")

Downloading JTVAE_ZINC_no_kl_pre_trained.pth from https://data.dgl.ai/pre_trained/jtvae_ZINC_no_kl.pth...
Pretrained model loaded


In [11]:
def canonicalize(smiles: str):
    return Chem.MolToSmiles(Chem.MolFromSmiles(smiles))

In [12]:
df_filename = "/home/simon/Documents/ETH/Masters_thesis/chemical_CPA/embeddings/grover/data/embeddings/grover_base.parquet"
smiles = [canonicalize(s) for s in pd.read_parquet(df_filename).index]
file = "all_smiles.txt"
with open(file, "w") as f:
    for smile in smiles:
        f.write(smile + "\n")

In [23]:
dataset = JTVAEDataset("all_smiles.txt", vocab=model.vocab, training=False)
collator =JTVAECollator(training=False)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,collate_fn=collator, drop_last=True)

In [40]:
get_data = lambda idx: collator([dataset[idx]])
errors = []
for i in tqdm(range(len(dataset))):
    try:
        _, batch_tree_graphs, batch_mol_graphs = get_data(i)
        _, tree_vec, mol_vec = model.encode(batch_tree_graphs, batch_mol_graphs)
        latent = torch.cat([model.T_mean(tree_vec), model.G_mean(mol_vec)], dim=1)
    except Exception as e:
        errors.append(e)

 81%|████████  | 14443/17868 [23:43<05:37, 10.15it/s] 


KeyboardInterrupt: 

In [55]:
s = set([e.args[0] for e in errors])

In [58]:
len(errors)

3667

In [56]:
len(s)

137

In [31]:
collator([dataset[10]])

(<dgllife.utils.jtvae.mol_tree.MolTree at 0x7f0a25b10198>,
 Graph(num_nodes=15, num_edges=28,
       ndata_schemes={'wid': Scheme(shape=(), dtype=torch.int64)}
       edata_schemes={}),
 Graph(num_nodes=17, num_edges=34,
       ndata_schemes={'x': Scheme(shape=(39,), dtype=torch.float32)}
       edata_schemes={'x': Scheme(shape=(50,), dtype=torch.float32)}))

In [7]:
tree_mean = model.T_mean(tree_vec)

In [8]:
mol_mean = model.G_mean(mol_vec)

In [9]:
tree_mean.shape

torch.Size([1, 28])

In [61]:
mol_mean.shape

torch.Size([1, 28])

In [63]:
model.decode(tree_mean, mol_mean, prob_decode=False)

'Cc1ccnc(NC=O)c1'

In [67]:
torch.cat([model.T_mean(tree_vec), model.G_mean(mol_vec)], dim=1).shape

torch.Size([1, 56])

In [70]:
latent.shape

torch.Size([1, 56])