In [4]:
!pip install -q torch==1.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
!pip install -q torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q dive-into-graphs
!pip install -q p_tqdm



In [18]:
# Enable IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel (mode 2 = reload all modules).
# Re-running this cell only prints an "already loaded" notice.
%load_ext autoreload
%autoreload 2
import os
import json
import argparse

import numpy as np
import torch

from dig.ggraph.dataset import ZINC250k, ZINC800
from molecule_optimizer.externals.fast_jtnn.datautils import SemiMolTreeFolder
from molecule_optimizer.runner.semi_jtvae import SemiJTVAEGeneratorPredictor
from torch_geometric.data import DenseDataLoader

import rdkit
# Import the logging submodule explicitly. Accessing `rdkit.RDLogger` after a
# bare `import rdkit` only works if some other import has already loaded the
# submodule, which makes the cell depend on import order.
from rdkit import RDLogger

# Only let CRITICAL messages from RDKit's logger through.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

import warnings
# NOTE(review): this silences every Python warning for the rest of the session
# (including deprecation and numerical warnings); consider narrowing the filter.
warnings.filterwarnings("ignore")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# Read the experiment configuration. The context manager closes the file
# handle deterministically instead of leaving it to the garbage collector.
with open("training/configs/rand_gen_zinc250k_config_dict.json") as config_file:
    conf = json.load(config_file)

In [7]:
# Instantiate ZINC800 under the configured data root. The returned object is
# discarded; presumably the constructor is run for its on-disk processing side
# effect (a later cell reads a processed file under the same root) -- TODO confirm.
print("Processing Dataset...")
zinc800_options = dict(
    root=conf["data"]["root"],
    one_shot=False,
    use_aug=False,
)
_ = ZINC800(**zinc800_options)

Processing Dataset...


In [8]:
# Load the processed dataset file from under the configured data root.
# NOTE(review): torch.load unpickles its input; only point it at trusted files.
jt_data_path = os.path.join(conf["data"]["root"], 'zinc_800_jt/processed/data.pt')
zinc_800_jt = torch.load(jt_data_path)

# The first entry carries the targets in its `.y` attribute; the last entry is
# used as the SMILES collection.
labels = zinc_800_jt[0].y
smiles = zinc_800_jt[-1]

In [9]:
# Build the semi-supervised JT-VAE runner from `smiles`.
# The captured output shows a progress bar over 800 items, so the constructor
# appears to iterate over every input -- presumably to build the vocabulary
# that is later read from `runner.vocab`; TODO confirm.
runner = SemiJTVAEGeneratorPredictor(smiles)

  0%|          | 0/800 [00:00<?, ?it/s]

In [10]:
# Assemble the model hyper-parameters: architecture sizes come straight from
# the config file, label statistics are computed from the loaded targets.
model_params = {
    key: conf["model"][key]
    for key in ("hidden_size", "latent_size", "depthT", "depthG")
}
label_mean = float(torch.mean(labels))
label_var = float(torch.var(labels))
model_params.update(
    label_size=1,
    label_mean=label_mean,
    label_var=label_var,
)

runner.get_model("rand_gen", model_params)



In [12]:
import pickle

# Snapshot the runner to disk. The `with` block guarantees the handle is
# flushed and closed before any later cell reads the file back.
# NOTE(review): the file is a binary pickle despite the '.xml' extension; the
# name is kept because the reload cell opens the same path.
with open('runner.xml', 'wb') as filehandler:
    pickle.dump(runner, filehandler)

In [13]:
# Restore the runner snapshot from disk, closing the handle as soon as the
# object has been read.
# NOTE(review): pickle.load can execute arbitrary code; only load snapshots
# produced by this notebook or another trusted source.
with open('runner.xml', 'rb') as filehandler:
    runner = pickle.load(filehandler)

In [15]:
# Run the runner's preprocessing over the SMILES/label pairs (the captured
# output shows a progress bar over 800 items).
# NOTE(review): `labels` is rebound to the returned value, so re-running this
# cell feeds already-processed labels back in -- confirm that preprocess is
# idempotent on labels, or keep the raw labels under a separate name.
preprocessed, labels = runner.preprocess(smiles, labels)

100%|██████████| 800/800 [01:22<00:00,  9.74it/s]


In [16]:
# Wrap the preprocessed data and labels in the batching loader, using the
# runner's vocabulary and the batch/worker settings from the config file.
batch_size = conf["batch_size"]
worker_count = conf["num_workers"]
loader = SemiMolTreeFolder(
    preprocessed, labels, runner.vocab, batch_size, num_workers=worker_count
)


In [20]:
# Start joint generator/predictor training from epoch 0. Every schedule and
# optimisation setting is forwarded from the config file under its own name;
# only `print_iter` is fixed here so that metrics are printed every iteration.
print("Training model...")
config_driven_args = {
    name: conf[name]
    for name in (
        "lr",
        "anneal_rate",
        "clip_norm",
        "num_epochs",
        "alpha",
        "beta",
        "max_beta",
        "step_beta",
        "anneal_iter",
        "kl_anneal_iter",
        "save_iter",
    )
}
runner.train_gen_pred(
    loader=loader,
    load_epoch=0,
    print_iter=1,
    **config_driven_args,
)

Training model...
Model #Params: 4261K
[1] Alpha: 0.100, Beta: 0.000, Loss: 3.70, KL: 0.16, MAE: 0.03, Word Loss: 2.78, Topo Loss: 0.71, Assm Loss: 0.21, Pred Loss: 0.00, Word: 0.00, Topo: 0.98, Assm: 0.75, PNorm: 93.79, GNorm: 27.63
[2] Alpha: 0.100, Beta: 0.000, Loss: 3.66, KL: 0.39, MAE: 0.04, Word Loss: 2.77, Topo Loss: 0.71, Assm Loss: 0.19, Pred Loss: 0.00, Word: 0.02, Topo: 1.15, Assm: 0.90, PNorm: 93.82, GNorm: 26.59
[3] Alpha: 0.100, Beta: 0.000, Loss: 3.39, KL: 0.99, MAE: 0.04, Word Loss: 2.59, Topo Loss: 0.66, Assm Loss: 0.14, Pred Loss: 0.00, Word: 0.13, Topo: 1.14, Assm: 1.20, PNorm: 93.87, GNorm: 27.56
[4] Alpha: 0.100, Beta: 0.000, Loss: 3.29, KL: 2.45, MAE: 0.06, Word Loss: 2.49, Topo Loss: 0.63, Assm Loss: 0.16, Pred Loss: 0.01, Word: 0.30, Topo: 1.17, Assm: 0.94, PNorm: 93.92, GNorm: 32.06
[5] Alpha: 0.100, Beta: 0.000, Loss: 3.26, KL: 3.25, MAE: 0.06, Word Loss: 2.46, Topo Loss: 0.65, Assm Loss: 0.15, Pred Loss: 0.01, Word: 0.53, Topo: 1.18, Assm: 1.08, PNorm: 93.97,

KeyboardInterrupt: 