# JTVAE embedding
This is a molecule embedding using the Junction Tree VAE (JT-VAE), as implemented in DGL-LifeSci.

It is pretrained on LINCS + Trapnell + half of ZINC (~220K molecules total).
LINCS contains a `Cl.[Li]` molecule that fails during encoding, so it is assigned a dummy embedding: the mean of all successfully computed latent vectors.

In [16]:
import pickle
from pathlib import Path

import pandas as pd
import rdkit
import torch
from dgllife.data import JTVAEDataset, JTVAECollator
from dgllife.model import load_pretrained
from tqdm import tqdm

print(rdkit.__version__)
print(torch.__version__)
# The encoding loop moves graphs and the model to "cuda", so fail early with a
# clear message. A bare `assert` would also be skipped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required: the JT-VAE encoding loop runs the model on 'cuda'.")

2018.09.3
1.10.1


In [17]:
from dgllife.utils import JTVAEVocab
from dgllife.model import JTNNVAE

# Either use DGL-LifeSci's published "JTVAE_ZINC_no_kl" model, or (the default
# here) rebuild the JTNNVAE architecture and load weights from a local checkpoint.
from_pretrained = False
if not from_pretrained:
    # Training file is kept for provenance only; it is not read in this notebook.
    trainfile = "data/train_1f1775f24668d31640df46ce45fe3577.txt"
    modelfile = "pre_model_all/model.epoch-0"
    vocabfile = "pre_model_all/vocab_1f1775f24668d31640df46ce45fe3577.pkl"

    # NOTE: unpickling can execute arbitrary code -- only load vocab files you trust.
    vocab = pickle.loads(Path(vocabfile).read_bytes())

    # These sizes must match the checkpoint: load_state_dict is strict by default.
    model = JTNNVAE(vocab=vocab, hidden_size=450, latent_size=56, depth=3)
    model.load_state_dict(torch.load(modelfile, map_location="cpu"))
else:
    model = load_pretrained("JTVAE_ZINC_no_kl")


In [18]:
# Move to the GPU and switch to inference mode; this notebook only encodes.
# nn.Module.eval() returns the module itself.
model = model.to("cuda").eval()

In [19]:
smiles = pd.read_csv("../lincs_trapnell.smiles")
# JTVAEDataset reads a bare list of SMILES, one per line. Write only the
# "smiles" column, with no header and no index, so extra columns can't leak in.
smiles["smiles"].to_csv("jtvae_dataset.smiles", index=False, header=False)

In [20]:
# Dataset + collator over the header-less SMILES file, built with training=False.
dataset = JTVAEDataset("jtvae_dataset.smiles", vocab=model.vocab, training=False)
collator = JTVAECollator(training=False)
# NOTE(review): this DataLoader is not used by the encoding loop below, which
# calls the collator on one dataset item at a time instead.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collator,
    drop_last=True,
)

In [21]:
def get_data(idx):
    """Collate the single dataset item at ``idx`` into a batch of size one."""
    return collator([dataset[idx]])


errors = []   # (smiles, exception) pairs for molecules that failed to encode
smiles = []   # SMILES strings, aligned index-for-index with `latents`
latents = []  # one (1, D) latent tensor per successfully encoded molecule
for i in tqdm(range(len(dataset))):
    try:
        _, batch_tree_graphs, batch_mol_graphs = get_data(i)
        batch_tree_graphs = batch_tree_graphs.to("cuda")
        batch_mol_graphs = batch_mol_graphs.to("cuda")
        # Keep the T_mean/G_mean projections inside no_grad as well. Otherwise
        # every stored latent drags an autograd graph along for the whole run.
        with torch.no_grad():
            _, tree_vec, mol_vec = model.encode(batch_tree_graphs, batch_mol_graphs)
            latent = torch.cat([model.T_mean(tree_vec), model.G_mean(mol_vec)], dim=1)
    except Exception as e:
        # Record and skip molecules the JT-VAE pipeline cannot handle.
        errors.append((dataset.data[i], e))
    else:
        latents.append(latent)
        smiles.append(dataset.data[i])

100%|███████████████████████████████████████████████████████████████| 17869/17869 [50:15<00:00,  5.93it/s]


In [22]:
# There should only be one error, a Cl.[Li] molecule. Fail loudly if anything
# else broke, since the dummy-embedding step was written for exactly this case.
assert [failed for failed, _ in errors] == ["Cl.[Li]"], errors
errors

[('Cl.[Li]', KeyError('x'))]

In [26]:
# Add a dummy embedding for the Cl.[Li] molecule
dummy_emb = torch.mean(torch.concat(latents), dim=0).unsqueeze(dim=0)
assert dummy_emb.shape == latents[0].shape
smiles.append(errors[0][0])
latents.append(dummy_emb)
assert len(latents) == len(smiles)

In [27]:
# Flatten each (1, D) latent tensor into a NumPy vector. Detach first, since the
# tensors may still be attached to an autograd graph. Then save one row per
# SMILES, with columns latent_1 ... latent_D.
np_latents = [t.detach().cpu().numpy().squeeze() for t in latents]
final_df = pd.DataFrame(
    np_latents,
    index=smiles,
    columns=[f"latent_{k}" for k in range(1, len(np_latents[0]) + 1)],
)
final_df.to_parquet("data/jtvae_dgl.parquet")

In [25]:
# Preview the saved table: one row per SMILES, one column per latent dimension.
final_df

Unnamed: 0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,latent_10,...,latent_47,latent_48,latent_49,latent_50,latent_51,latent_52,latent_53,latent_54,latent_55,latent_56
C[C@H](NC(=O)/C(C#N)=C/c1cccc(Br)n1)c1ccccc1,1.429828,-12.035091,0.489625,0.986142,-4.443451,-3.747700,-1.329681,1.072835,2.113698,2.438861,...,1.166213,0.537067,0.384840,-0.623976,-0.897625,0.084227,-0.209080,1.511904,-0.557535,1.185793
Cc1cc(Nc2cc(CN3CCOCC3)c3nc(C)c(Cc4ccc(Cl)cc4F)n3n2)[nH]n1,-2.008734,3.468817,5.856979,-7.614322,-0.542796,3.988372,2.222589,0.689923,-0.130659,-0.419161,...,-0.374772,0.471480,0.395210,0.181184,0.836576,0.055511,-0.744534,-0.230085,-0.297810,0.284168
Cc1cc(N2CCOCC2)cc2[nH]c(-c3c(NCC(O)c4cccc(Cl)c4)cc[nH]c3=O)nc12,0.886675,5.676350,7.658895,-0.507045,-5.845931,3.011087,3.310966,4.373243,0.136572,-2.537051,...,-0.512504,0.375018,-0.121645,-0.032188,1.316664,-0.518432,-0.309929,-1.380790,-0.265494,-0.245688
Cl.Cl.c1ccc([C@@H]2C[C@H]2NC2CCNCC2)cc1,5.945081,2.647774,10.389327,-2.971686,-6.238258,-0.583933,5.741209,1.338418,-2.780702,4.936309,...,0.550587,0.149925,0.061534,-0.301705,-0.490559,0.042261,0.214776,0.919423,-0.135971,0.713558
O=C(c1ccc(/C=C/c2n[nH]c3ccccc23)cc1)N1CCNCC1,-1.057689,3.401932,6.699368,-2.915870,-3.852238,1.118359,-2.190202,0.448093,0.951138,4.746724,...,0.140141,0.321975,-0.075490,-0.211345,-0.456241,0.177017,0.354334,0.985909,0.086379,0.727438
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Cc1cc(CS(=O)(=O)c2ccccc2)cc(OCc2ccc(CN3CCC[C@@H]3CO)cc2)c1,5.623644,5.570826,2.457215,6.833165,-5.539162,-0.434083,-1.206793,0.832395,2.513798,4.813797,...,0.528235,0.312193,-0.696733,-0.232044,-0.237644,0.122018,-0.500457,0.115442,-0.657997,0.737079
CN(C)CCOc1ccc(/C(=C(\CCCl)c2ccccc2)c2ccccc2)cc1,1.209961,-2.287903,1.391642,1.202353,-1.680503,-1.807198,3.338384,-1.620294,0.594872,1.918776,...,0.096867,0.254615,-0.346058,-0.112815,-0.661905,0.132621,0.539877,1.072693,0.103202,0.816531
CC1(C)C=Cc2c(ccc3c2[N+]([O-])=C2C3=C[C@@]34NC(=O)[C@]5(CCCN5C3=O)C[C@H]4C2(C)C)O1,2.400828,-7.876045,-1.187408,3.642219,-6.089412,-1.792802,1.092967,7.154920,1.991817,-2.085588,...,-0.726810,0.372121,0.186016,-0.006123,0.946836,-0.410416,0.034444,-0.738778,0.305928,-0.150375
C[C@@H]1CC(=O)NN=C1c1ccc(N)c([N+](=O)[O-])c1,2.345975,2.942803,2.595910,0.994162,-2.953399,5.007487,2.610762,5.818273,0.155535,1.398346,...,-0.226300,0.457870,0.546527,-0.094465,1.204776,-0.509441,0.005396,-0.867492,-0.321617,-0.144999


In [30]:
# Cross-check: reload the input SMILES and take the index of the saved table.
smiles, smiles2 = pd.read_csv("../lincs_trapnell.smiles"), final_df.index

In [39]:
# Every input SMILES has an embedding row and vice versa.
# NOTE(review): set comparison ignores duplicate rows on either side.
set(smiles["smiles"]) == set(smiles2)

True