# JTVAE embedding
This notebook computes molecule embeddings with the Junction Tree VAE (JT-VAE), as implemented in DGL-LifeSci.

The model was trained on LINCS + Trapnell + half of ZINC (~220K molecules in total).
LINCS contains a `Cl.[Li]` molecule that fails during encoding; it is assigned a dummy embedding (the mean of all other embeddings).

In [1]:
import pickle

import pandas as pd
import rdkit
import torch
from dgllife.data import JTVAEDataset, JTVAECollator
from dgllife.model import load_pretrained
from tqdm import tqdm

# Log library versions so the run can be reproduced later.
print(rdkit.__version__)
print(torch.__version__)
# The model and every batch are moved to "cuda" below, so fail early (with a clear
# message) when no GPU is available.
assert torch.cuda.is_available(), "This notebook requires a CUDA-capable GPU (model and graphs are moved to 'cuda')."

[Using backend: pytorch
07:37:25] /opt/dgl/src/runtime/tensordispatch.cc:43: TensorDispatcher: dlopen failed: /home/icb/simon.boehm/miniconda3/envs/jtvae_dgl/lib/python3.7/site-packages/dgl/tensoradapter/pytorch/libtensoradapter_pytorch_1.10.1.so: cannot open shared object file: No such file or directory


2018.09.3
1.10.1


In [2]:
from dgllife.model import JTNNVAE

# Choose the weights: DGL-LifeSci's published "JTVAE_ZINC_no_kl" checkpoint, or
# (default) our own checkpoint together with the vocabulary it was trained with.
from_pretrained = False
if not from_pretrained:
    trainfile = "data/train_077a9bedefe77f2a34187eb57be2d416.txt"  # training SMILES; not read below
    modelfile = "data/model-vaetrain-final.pt"
    vocabfile = "data/vocab-final.pkl"

    # NOTE: pickle.load can execute arbitrary code -- only load trusted vocab files.
    with open(vocabfile, "rb") as vocab_fh:
        vocab = pickle.load(vocab_fh)

    # These hyperparameters must match the checkpoint, otherwise load_state_dict fails.
    model = JTNNVAE(vocab=vocab, hidden_size=450, latent_size=56, depth=3)
    state_dict = torch.load(modelfile, map_location="cpu")  # load on CPU; moved to the GPU in a later cell
    model.load_state_dict(state_dict)
else:
    model = load_pretrained("JTVAE_ZINC_no_kl")


In [3]:
# Move the model to the GPU and switch it to inference mode: the rest of the notebook
# only encodes/reconstructs, so training-only behavior (e.g. dropout, if present) must be off.
model = model.to("cuda").eval()

In [4]:
smiles = pd.read_csv("../lincs_trapnell.smiles")
# JTVAEDataset reads one bare SMILES per line: write only the `smiles` column,
# with no header row and no pandas index.
smiles["smiles"].to_csv("jtvae_dataset.smiles", index=False, header=False)

In [5]:
# Wrap the headerless SMILES file in DGL-LifeSci's JT-VAE dataset (non-training mode,
# using the model's own vocabulary) and iterate over it one molecule at a time, in file order.
dataset = JTVAEDataset("jtvae_dataset.smiles", vocab=model.vocab, training=False)
collator = JTVAECollator(training=False)
dataloader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=collator,
    batch_size=1,
    shuffle=False,
    drop_last=True,
)

## Reconstruction demo
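Reconstruct the first 11 molecules to sanity-check reconstruction quality. It is poor: in the run below, none of the decoded SMILES (first line of each pair) matches its input SMILES (second line).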

In [6]:
# Decode the first N_RECONSTRUCT molecules and count exact SMILES matches.
N_RECONSTRUCT = 11  # iterations 0..10
device = "cuda"
# NOTE(review): reconstruct may sample from the latent distribution; seed so the demo is repeatable.
torch.manual_seed(0)
acc = 0.0
tot = 0  # stays 0 if the dataloader is empty, so the final division is guarded
with torch.no_grad():  # inference only -- don't build autograd graphs
    for it, (tree, tree_graph, mol_graph) in enumerate(dataloader):
        if it >= N_RECONSTRUCT:
            break
        tot = it + 1
        # Separate name so the `smiles` DataFrame read earlier is not overwritten.
        input_smiles = tree.smiles
        tree_graph = tree_graph.to(device)
        mol_graph = mol_graph.to(device)
        dec_smiles = model.reconstruct(tree_graph, mol_graph)
        print(dec_smiles)
        print(input_smiles)
        print()
        if dec_smiles == input_smiles:
            acc += 1
print('Final acc: {:.4f}'.format(acc / tot if tot else 0.0))



C[C@H](C#N)NC(=O)c1oc(-c2ccccc2)cc1C(F)F
C[C@H](NC(=O)/C(C#N)=C/c1cccc(Br)n1)c1ccccc1

Cc1nn(C)c(C)c1Cn1nc(-c2ccc(Cl)c(F)c2)ccc1=S
Cc1cc(Nc2cc(CN3CCOCC3)c3nc(C)c(Cc4ccc(Cl)cc4F)n3n2)[nH]n1

Cc1cc(-n2nccc2C(=O)NCc2ccc3ccc(Cl)cc3n2)ccc1F
Cc1cc(N2CCOCC2)cc2[nH]c(-c3c(NCC(O)c4cccc(Cl)c4)cc[nH]c3=O)nc12

Cc1ccc(N2CCC(NC3CC3)CC2)cc1
Cl.Cl.c1ccc([C@@H]2C[C@H]2NC2CCNCC2)cc1

O=C(/C=C/c1ccc2cccnc2c1)N1CCNCC1=O
O=C(c1ccc(/C=C/c2n[nH]c3ccccc23)cc1)N1CCNCC1

C[C@H]1C[C@@H](N2N=C(C(=O)Cc3ccccc3)NC2(C)C)CC[C@H]1C
Cc1nnc(C(C)C)n1C1CC2CCC(C1)N2CCC(NC(=O)C1CCC(F)(F)CC1)c1ccccc1

C[C@H](N)CN1CCC[C@@H]2CCC[C@H]([C@H](O)CO)[C@@H]21
NC(=O)c1ncn([C@@H]2O[C@H](CO)[C@@H](O)[C@H]2O)c1N

N#CCNC(=O)c1ccc(Nc2cnc(N3CCCC3)nc2)cc1
N#CCNC(=O)c1ccc(-c2ccnc(Nc3ccc(N4CCOCC4)cc3)n2)cc1

C#Cc1ccc(NC(=O)NO)c(OC)c1
C#Cc1cccc(Nc2ncnc3cc(OC)c(OCCCCCCC(=O)NO)cc23)c1

O=C(CN1C(=O)C=C(c2ccccc2)C12CC=CC2)N1CCCCC1
O=C1CCC(N2C(=O)c3ccccc3C2=O)C(=O)N1

NC(=O)c1cc(-c2cccc(Cl)c2OCc2cccs2)[nH]c(=O)c1
NC(=O)C1CCCc2c1[nH]c1ccc(Cl)cc21

F

## Generate embeddings for all LINCS + Trapnell molecules

In [7]:
def get_data(idx):
    """Collate the single dataset item at position ``idx`` into a one-molecule batch."""
    return collator([dataset[idx]])


# Encode every molecule, one at a time. Failures are recorded in `errors` (inspected in
# the next cell) instead of aborting this ~1h loop.
errors = []
smiles = []
latents = []
for i in tqdm(range(len(dataset))):
    try:
        _, batch_tree_graphs, batch_mol_graphs = get_data(i)
        batch_tree_graphs = batch_tree_graphs.to("cuda")
        batch_mol_graphs = batch_mol_graphs.to("cuda")
        # The projection heads must run under no_grad too; otherwise every stored latent
        # carries an autograd graph (holding its GPU inputs) for the rest of the loop.
        with torch.no_grad():
            _, tree_vec, mol_vec = model.encode(batch_tree_graphs, batch_mol_graphs)
            # Embedding = concat of T_mean(tree_vec) and G_mean(mol_vec), with no sampling
            # (presumably the means of the tree / graph latent distributions).
            latent = torch.cat([model.T_mean(tree_vec), model.G_mean(mol_vec)], dim=1)
    except Exception as e:
        errors.append((dataset.data[i], e))
    else:
        latents.append(latent)
        smiles.append(dataset.data[i])

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17869/17869 [1:10:38<00:00,  4.22it/s]


In [8]:
# There should only be one error, a Cl.[Li] molecule. Any other failure is a new,
# unexpected problem, so fail loudly instead of silently continuing.
assert [smi for smi, _ in errors] == ["Cl.[Li]"], errors
errors

[('Cl.[Li]', KeyError('x'))]

In [9]:
# Give every molecule that failed to encode (in practice only Cl.[Li]) a dummy
# embedding: the mean of all successfully computed latents.
# SMILES that already have a row are skipped, so re-running this cell neither appends
# duplicates nor folds the dummy into its own mean.
encoded_smiles = set(smiles)
failed_smiles = [smi for smi in dict.fromkeys(s for s, _ in errors) if smi not in encoded_smiles]
if failed_smiles:
    dummy_emb = torch.cat(latents).mean(dim=0, keepdim=True)
    assert dummy_emb.shape == latents[0].shape
    for smi in failed_smiles:
        smiles.append(smi)
        latents.append(dummy_emb)
assert len(latents) == len(smiles)

In [10]:
# Flatten each per-molecule latent (a one-row batch) into a 1-D numpy vector, then
# write the table -- indexed by SMILES, one latent_<k> column per dimension -- to parquet.
np_latents = []
for latent_tensor in latents:
    np_latents.append(latent_tensor.squeeze().cpu().detach().numpy())
latent_dim = np_latents[0].shape[0]
column_names = ["latent_{}".format(k) for k in range(1, latent_dim + 1)]
final_df = pd.DataFrame(np_latents, index=smiles, columns=column_names)
final_df.to_parquet("data/jtvae_dgl.parquet")

In [11]:
# Preview the embedding table: one row per SMILES, one latent_<k> column per latent dimension.
final_df

Unnamed: 0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,latent_10,...,latent_47,latent_48,latent_49,latent_50,latent_51,latent_52,latent_53,latent_54,latent_55,latent_56
C[C@H](NC(=O)/C(C#N)=C/c1cccc(Br)n1)c1ccccc1,-1.574519,-6.199383,0.361449,-2.896032,0.508433,-2.463944,-2.274414,0.915923,-1.768996,-5.671151,...,-0.381136,0.114776,0.461026,-0.232696,0.149239,0.008222,0.064641,1.459532,-0.240004,0.192129
Cc1cc(Nc2cc(CN3CCOCC3)c3nc(C)c(Cc4ccc(Cl)cc4F)n3n2)[nH]n1,2.289662,5.075178,4.964231,-7.993801,-1.601092,7.161238,2.391373,-3.190383,1.082587,-9.432123,...,-0.269714,0.135338,0.240984,-0.346155,0.155523,0.214027,-0.142053,1.222742,-0.241584,0.011364
Cc1cc(N2CCOCC2)cc2[nH]c(-c3c(NCC(O)c4cccc(Cl)c4)cc[nH]c3=O)nc12,4.548813,3.436257,6.246993,-2.812740,-4.044300,2.656566,2.754469,-3.622663,-4.065948,-6.316630,...,-0.285767,0.132375,0.272686,-0.329809,0.154618,0.184377,-0.112274,1.256856,-0.241356,0.037407
Cl.Cl.c1ccc([C@@H]2C[C@H]2NC2CCNCC2)cc1,5.660413,4.218781,2.332154,-2.789022,1.138375,0.966016,-2.565607,1.725014,-3.196076,1.320138,...,-0.243737,0.120570,0.182537,-0.232914,0.121171,0.276718,-0.018734,1.226653,-0.262646,0.031372
O=C(c1ccc(/C=C/c2n[nH]c3ccccc23)cc1)N1CCNCC1,-3.375466,2.009230,8.609015,-2.512949,-3.200943,-3.950001,-2.787516,2.496357,6.775537,2.045950,...,-0.299430,0.129854,0.299669,-0.315896,0.153847,0.159139,-0.086928,1.285893,-0.241162,0.059574
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Cc1cc(CS(=O)(=O)c2ccccc2)cc(OCc2ccc(CN3CCC[C@@H]3CO)cc2)c1,5.240765,2.263940,8.101819,0.707216,-2.464874,-3.738584,-2.760966,3.972898,0.451858,2.640987,...,-0.327352,0.124701,0.354811,-0.287463,0.152272,0.107565,-0.035130,1.345232,-0.240767,0.104873
CN(C)CCOc1ccc(/C(=C(\CCCl)c2ccccc2)c2ccccc2)cc1,-1.581464,5.444716,9.411486,-4.419342,-3.032458,-5.367414,2.688151,0.213778,1.863454,2.824382,...,-0.309993,0.127905,0.320529,-0.305140,0.153251,0.139629,-0.067333,1.308341,-0.241013,0.076711
CC1(C)C=Cc2c(ccc3c2[N+]([O-])=C2C3=C[C@@]34NC(=O)[C@]5(CCCN5C3=O)C[C@H]4C2(C)C)O1,0.949761,-2.565722,2.936663,0.848347,-4.364933,5.616033,-1.832742,2.019709,2.333825,-2.484026,...,-0.246827,0.139561,0.195786,-0.369461,0.156814,0.256302,-0.184509,1.174103,-0.241908,-0.025767
C[C@@H]1CC(=O)NN=C1c1ccc(N)c([N+](=O)[O-])c1,4.238235,3.362757,2.420582,-6.184925,-3.117164,6.042106,1.130967,3.782147,0.415811,1.656704,...,-0.306768,0.128500,0.314160,-0.308424,0.153433,0.145586,-0.073316,1.301487,-0.241058,0.071478


In [12]:
# Cross-check inputs: the SMILES indexing the embedding table vs. the original input list.
smiles2 = final_df.index
smiles = pd.read_csv(filepath_or_buffer="../lincs_trapnell.smiles")

In [13]:
# Every input SMILES must have an embedding row and vice versa; fail loudly if not.
same_molecules = set(smiles["smiles"]) == set(smiles2)
assert same_molecules, "SMILES in the embedding table differ from the input list"
same_molecules

True