In [1]:
import torch, os
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from dig.ggraph.dataset import ZINC250k, ZINC800
from dig.ggraph.method import JTVAE
from dig.ggraph.method.JTVAE.fast_jtnn import MolTreeFolder
from dig.ggraph.evaluation import RandGenEvaluator
from tqdm import tqdm

from rdkit import RDLogger, Chem
# Turn off RDKit log channels matching 'rdApp.*' so they do not flood cell output.
RDLogger.DisableLog('rdApp.*')



In [2]:
print('Loading ZINC250k...')
# Root directory of the dataset files (machine-specific absolute path).
data_dir = "/mnt/data/shared/jacob/CHEMIR/GNN/data"
# Keep only the first `limit` entries; None or a non-positive value keeps everything.
limit = 100

dataset = ZINC250k(one_shot=False, root=data_dir)
# NOTE(review): torch.load unpickles arbitrary objects -- only point this at trusted files.
smiles = torch.load(os.path.join(dataset.processed_dir, "data.pt"))[-1]
vocab = list(torch.load(os.path.join(dataset.processed_dir, "vocab.pt")))
trees = torch.load(os.path.join(dataset.processed_dir, "trees.pt"))
# The None test must come first and be joined with `and`: with `or`, limit=None
# raised TypeError on `None > 0`, and limit=0 truncated everything to empty.
if limit is not None and limit > 0:
    # NOTE(review): if dataset.slices stores N+1 boundary offsets per key, cutting it
    # to `limit` entries leaves limit-1 items -- verify against the dataset class.
    dataset.slices = {k: v[:limit] for k, v in dataset.slices.items()}
    smiles = smiles[:limit]
    trees = trees[:limit]


Loading ZINC250k...


In [3]:
# Define model: wrap the SMILES list, then build the 'rand_gen' network from `config`.
jtvae = JTVAE(smiles)
config = dict(
    hidden_size=420,
    latent_size=56,
    depthT=20,
    depthG=3,
)
jtvae.get_model('rand_gen', config)



In [4]:
# Vocabulary entries the model wrapper holds that are absent from the cached vocab.pt.
set(jtvae.vocab).difference(vocab)

set()

In [5]:
# Re-run the wrapper's own preprocessing on the SMILES so the result can be compared
# with the cached `trees` in the cells below (about a minute for 100 molecules in
# the recorded run).
preprocessed = jtvae.preprocess(smiles)

100%|██████████| 100/100 [01:07<00:00,  1.49it/s]


In [6]:
# List equality between the fresh and the cached trees, compared element by element.
# NOTE(review): the recorded result is False although the SMILES checks in the
# following cells pass; presumably the tree class defines no __eq__, so elements
# are compared by identity -- verify.
preprocessed == trees

False

In [7]:
# Each freshly preprocessed tree must carry the same SMILES as the cached tree at
# the same position. zip() stops at the shorter input, so the lengths are compared
# first; otherwise a shorter `preprocessed` could still report True.
len(trees) == len(preprocessed) and all(
    tree0.smiles == tree.smiles for tree0, tree in zip(trees, preprocessed)
)

True

In [8]:
# Node-level check: for every pair of cached/fresh trees, the nodes must match in
# count and in SMILES, position by position. zip() silently truncates to the shorter
# input, so both the tree lists and each pair of node lists are length-checked first.
# Assumes `.nodes` is a sized sequence (it is already iterated by zip) -- TODO confirm.
len(trees) == len(preprocessed) and all(
    len(tree0.nodes) == len(tree.nodes)
    and all(node0.smiles == node.smiles for node0, node in zip(tree0.nodes, tree.nodes))
    for tree0, tree in zip(trees, preprocessed)
)

True

In [9]:
# Loader over the freshly preprocessed trees, using the wrapper's vocabulary.
loader = MolTreeFolder(preprocessed, jtvae.vocab, 32, num_workers=4)

# Epoch bookkeeping.
load_epoch, num_epochs = 0, 1
total_step = 0  # TODO args.load_epoch

# Optimiser-related values.
lr, anneal_rate, anneal_iter = 1e-3, 0.9, 40000
clip_norm = 50.0

# Beta schedule values.
beta = 0.0  # TODO args.beta
max_beta, step_beta = 1.0, 0.002
kl_anneal_iter = 2000

# Reporting / checkpoint intervals.
print_iter, save_iter = 50, 5000

# All settings are passed positionally; the order below is the order the call expects.
jtvae.train_rand_gen(
    loader,
    load_epoch,
    lr,
    anneal_rate,
    clip_norm,
    num_epochs,
    beta,
    max_beta,
    step_beta,
    anneal_iter,
    kl_anneal_iter,
    print_iter,
    save_iter,
)

Model #Params: 4142K




In [16]:
# Rebuild the loader with the same arguments as the training cell and pull a single
# batch from it for inspection.
loader = MolTreeFolder(preprocessed, jtvae.vocab, 32, num_workers=4)
batch = next(iter(loader))

In [17]:
# Number of top-level items in one batch (4 in the recorded run).
len(batch)

4

In [11]:
# Generate with run_rand_gen(1) and parse every returned SMILES string into an RDKit molecule.
samples = [Chem.MolFromSmiles(generated) for generated in jtvae.run_rand_gen(1)]

In [12]:
# Score the generated molecules against the training SMILES and print the metrics.
evaluator = RandGenEvaluator()
res_dict = dict(mols=samples, train_smiles=smiles)
results = evaluator.eval(res_dict)
print(results)

Valid Ratio: 1/1 = 100.00%
Unique Ratio: 1/1 = 100.00%
Novel Ratio: 1/1 = 100.00%
{'valid_ratio': 100.0, 'unique_ratio': 100.0, 'novel_ratio': 100.0}
