In [1]:
from dataclasses import dataclass

import pandas as pd
import torch
from rdkit import Chem

from hgraph import PairVocab, HierVAE, common_atom_vocab
from hgraph.hgnn import make_cuda
from preprocess import tensorize

In [2]:
# Parse the vocabulary file: every line is stripped of surrounding
# "\r", "\n" and spaces, split on whitespace, and the resulting token lists
# are handed to PairVocab.
vocab_file = "data/chembl/vocab.txt"
# Read through a context manager so the file handle is closed as soon as the
# vocabulary has been built, instead of lingering until garbage collection.
with open(vocab_file) as vocab_handle:
    vocab = PairVocab([line.strip("\r\n ").split() for line in vocab_handle])


In [3]:
# load SMILES
def canonicalize(smiles: str) -> str:
    """Return the RDKit canonical SMILES string for ``smiles``.

    Args:
        smiles: SMILES string to normalize.

    Returns:
        The string produced by ``Chem.MolToSmiles`` for the parsed molecule,
        with isomeric information enabled.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` into a molecule.
    """
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles reports a parse failure by returning None instead of
    # raising. Fail here with the offending input rather than passing None on
    # to MolToSmiles, which would raise an opaque argument error.
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    return Chem.MolToSmiles(mol, isomericSmiles=True)
# NOTE(review): absolute, machine-specific location — adjust before running
# this notebook on another machine.
df_filename = "/home/simon/Documents/ETH/Masters_thesis/chemical_CPA/embeddings/grover/data/embeddings/grover_base.parquet"
# The SMILES strings are read from the index of the parquet table and each one
# is passed through canonicalize.
smiles = list(map(canonicalize, pd.read_parquet(df_filename).index))

In [4]:
# Tensorize each molecule on its own so that a failure can be attributed to a
# single index; only the index and the exception class are printed.
for i, single_smiles in enumerate(smiles):
    try:
        batches, tensors, all_orders = tensorize([single_smiles], vocab)
    except Exception as err:
        print(i, type(err))

236 <class 'KeyError'>
402 <class 'KeyError'>
413 <class 'KeyError'>
457 <class 'KeyError'>
511 <class 'KeyError'>
551 <class 'KeyError'>
588 <class 'KeyError'>
652 <class 'KeyError'>
716 <class 'KeyError'>
771 <class 'KeyError'>
803 <class 'AttributeError'>
837 <class 'KeyError'>
842 <class 'KeyError'>
857 <class 'KeyError'>
863 <class 'KeyError'>
907 <class 'KeyError'>
909 <class 'KeyError'>
1063 <class 'KeyError'>
1072 <class 'KeyError'>


KeyboardInterrupt: 

In [None]:
# Tensorize the first 200 molecules in a single call. The per-molecule scan
# earlier in this notebook printed its first failure at index 236, so this
# slice ends before every failing index it reported.
batches, tensors, all_orders = tensorize(smiles[:200], vocab)
# make_cuda returns the tree and graph tensors consumed by the encoder below
# (presumably placing them on the GPU — confirm in hgraph.hgnn).
tree_tensors, graph_tensors = make_cuda(tensors)

In [None]:
# Hyperparameter container; an instance is passed to HierVAE when the model
# is constructed.
# NOTE(review): none of the attributes below has a type annotation, so
# @dataclass registers no fields for them. They are plain class attributes and
# the generated __init__ accepts no arguments — Args() is the only valid call.
@dataclass
class Args:
    seed= 7
    nsample= 10000
    rnn_type= "LSTM"
    # Layer / embedding / latent dimensions.
    # NOTE(review): presumably these must match the checkpoint that is loaded
    # into the model — confirm against the training configuration.
    hidden_size=250
    embed_size=250
    batch_size=50
    latent_size=32
    # The T / G suffixes presumably refer to the tree and graph parts of the
    # model — TODO confirm in hgraph.
    depthT=15
    depthG=15
    diterT=1
    diterG=3
    dropout=0.0
    # Bound at class-definition time to the notebook-level `vocab` object and
    # the imported `common_atom_vocab`.
    vocab=vocab
    atom_vocab=common_atom_vocab

In [None]:
# Checkpoint file holding the pretrained weights.
model_fname = "ckpt/chembl-pretrained/model.ckpt"

# Build the network from the settings in Args, then move it to the GPU.
model = HierVAE(Args())
model = model.cuda()

# Only element 0 of the object stored in the checkpoint is used as the
# state dict.
model.load_state_dict(torch.load(model_fname)[0])

# Switch to evaluation mode; the value returned by eval() is rebound to `model`.
model = model.eval()

In [None]:
# Run the encoder and the latent projection with gradient tracking disabled.
with torch.no_grad():
    # The encoder call yields four values; the third one is discarded here.
    root_vecs, tree_vecs, _, graph_vecs = model.encoder(tree_tensors, graph_tensors)
    # TODO What's the difference between the first and second root_vecs?
    # NOTE(review): the first `root_vecs` is the encoder output; the second is
    # whatever model.rsample derives from it using model.R_mean / model.R_var.
    # With perturb=False this is presumably the latent mean without sampling
    # noise — confirm against HierVAE.rsample.
    root_vecs, root_kl = model.rsample(
        root_vecs, model.R_mean, model.R_var, perturb=False
    )

In [None]:
# Display the shape of the tensor returned by model.rsample.
root_vecs.shape