In [1]:
from dgllife.model import load_pretrained
from dgllife.data import JTVAEDataset, JTVAECollator
import torch
from rdkit import Chem
import pandas as pd
import rdkit
from tqdm import tqdm
# Record the library versions in the notebook output for reproducibility.
print(rdkit.__version__)
print(torch.__version__)
# Stop early when no CUDA device is visible: later cells move the model and
# the input graphs to "cuda" unconditionally.
assert torch.cuda.is_available()

Using backend: pytorch


2018.09.3
1.10.0


In [2]:
# Load the pretrained model registered under the name "JTVAE_ZINC_no_kl";
# the recorded output shows the checkpoint being downloaded on first use.
model = load_pretrained("JTVAE_ZINC_no_kl")

Downloading JTVAE_ZINC_no_kl_pre_trained.pth from https://data.dgl.ai/pre_trained/jtvae_ZINC_no_kl.pth...
Pretrained model loaded


In [3]:
# Move the model to the GPU; the encoding loop sends its graphs to "cuda" too.
model = model.to("cuda")

In [None]:
def canonicalize(smiles: str) -> str:
    """Return the RDKit canonical form of a SMILES string.

    Args:
        smiles: SMILES string to canonicalise.

    Returns:
        The result of ``Chem.CanonSmiles`` called with ``useChiral=1``.
    """
    canonical = Chem.CanonSmiles(smiles, useChiral=1)
    return canonical

In [11]:
# Canonicalise every SMILES in the first column of the header-less input file.
# This goes through the shared `canonicalize` helper rather than an inline
# MolFromSmiles/MolToSmiles round trip, so every cell in the notebook applies
# the same canonicalisation settings.
recanon_smiles = [
    canonicalize(smiles)
    for smiles in pd.read_csv("lincs_trapnell.smiles", header=None)[0].values
]

In [13]:
# Read the first column of the header-less SMILES file and canonicalise each
# entry with `canonicalize`; the resulting Series is written to disk next.
x = pd.read_csv("lincs_trapnell.smiles", header=None)[0].apply(canonicalize)

In [16]:
# Write one canonical SMILES per line, with no index column and no header row.
# `header=False` is the documented way to suppress the header; `None` only
# worked because it is falsy.
x.to_csv("recanonicalized.smiles", index=False, header=False)

In [17]:
# Build the dataset from the recanonicalised SMILES file using the model's
# vocabulary, with training disabled on both the dataset and the collator.
dataset = JTVAEDataset(
    "recanonicalized.smiles",
    vocab=model.vocab,
    training=False,
)
collator = JTVAECollator(training=False)

# One molecule per batch, in file order.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collator,
    drop_last=True,
)

In [22]:
def get_data(idx):
    """Collate the single dataset item at position ``idx`` as a one-element list."""
    return collator([dataset[idx]])


# Put the model in inference mode; this loop only reads embeddings from it.
model.eval()

errors = []         # exceptions raised while encoding, in encounter order
error_indices = []  # dataset index matching each entry of `errors`
latents = {}        # dataset index -> latent tensor, copied to the CPU

# Encoding needs no gradients; without no_grad() every iteration would build
# an autograd graph that is never used.
with torch.no_grad():
    for i in tqdm(range(len(dataset))):
        try:
            _, batch_tree_graphs, batch_mol_graphs = get_data(i)
            batch_tree_graphs = batch_tree_graphs.to("cuda")
            batch_mol_graphs = batch_mol_graphs.to("cuda")
            _, tree_vec, mol_vec = model.encode(batch_tree_graphs, batch_mol_graphs)
            # Latent = tree mean and graph mean concatenated along dim 1.
            latent = torch.cat([model.T_mean(tree_vec), model.G_mean(mol_vec)], dim=1)
            # Keep every embedding instead of only the last one; store on the
            # CPU so the collected results do not accumulate in GPU memory.
            latents[i] = latent.cpu()
        except Exception as e:
            # Best-effort: record the failure with its index and carry on,
            # so failing molecules can be identified afterwards.
            errors.append(e)
            error_indices.append(i)

 22%|██▏       | 3956/17869 [10:14<36:02,  6.43it/s]  


KeyboardInterrupt: 

In [21]:
# Print the canonical form once and after a second pass; the two lines of
# recorded output can then be compared by eye.
print(canonicalize("C1=CC=CCCC2CCCC(CCNCCOCCCCC=CCCCCCCC=C1)O2"))
print(canonicalize(canonicalize("C1=CC=CCCC2CCCC(CCNCCOCCCCC=CCCCCCCC=C1)O2")))

C1=CC=CCCC2CCCC(CCNCCOCCCCC=CCCCCCCC=C1)O2
C1=CC=CCCC2CCCC(CCNCCOCCCCC=CCCCCCCC=C1)O2


In [14]:
# Distinct first arguments of the collected exceptions.
s = {e.args[0] for e in errors}

In [23]:
# Number of exceptions collected by the encoding loop.
len(errors)

188

In [31]:
# Inspect what the collator returns for a single item (index 10); the recorded
# output is a 3-tuple of a MolTree and two Graph objects.
collator([dataset[10]])

(<dgllife.utils.jtvae.mol_tree.MolTree at 0x7f0a25b10198>,
 Graph(num_nodes=15, num_edges=28,
       ndata_schemes={'wid': Scheme(shape=(), dtype=torch.int64)}
       edata_schemes={}),
 Graph(num_nodes=17, num_edges=34,
       ndata_schemes={'x': Scheme(shape=(39,), dtype=torch.float32)}
       edata_schemes={'x': Scheme(shape=(50,), dtype=torch.float32)}))

In [7]:
# Pass the tree embedding through the model's T_mean layer.
# NOTE(review): `tree_vec` is only assigned inside the encoding loop cell, whose
# execution count is later than this cell's, so this depended on kernel state
# from an earlier run; verify it still works after Restart & Run All.
tree_mean = model.T_mean(tree_vec)

In [8]:
# Pass the molecule-graph embedding through the model's G_mean layer.
# NOTE(review): `mol_vec` is only assigned inside the encoding loop cell, which
# has a later execution count than this cell; verify on a fresh kernel.
mol_mean = model.G_mean(mol_vec)

In [9]:
# Shape of the tree mean; the recorded output is (1, 28).
tree_mean.shape

torch.Size([1, 28])

In [61]:
# Shape of the graph mean; the recorded output is (1, 28).
mol_mean.shape

torch.Size([1, 28])

In [63]:
# Decode the two means; the recorded output is a SMILES string.
# prob_decode=False presumably selects deterministic decoding - TODO confirm.
model.decode(tree_mean, mol_mean, prob_decode=False)

'Cc1ccnc(NC=O)c1'

In [67]:
# Same concatenation as in the encoding loop (tree mean and graph mean along
# dim 1); the recorded output shape is (1, 56).
torch.cat([model.T_mean(tree_vec), model.G_mean(mol_vec)], dim=1).shape

torch.Size([1, 56])

In [70]:
# Shape of `latent` as left behind by the encoding loop's most recent
# successful iteration; the recorded output is (1, 56).
latent.shape

torch.Size([1, 56])