In [1]:
# Hyperparameter values recorded for this run, one `name:value` per line.
# Most of these names are the keys read from `conf` by the loader and training
# cells further down; `print_iter` is passed as a literal 100 there instead.
# NOTE(review): presumably a hand-copied snapshot of the JSON config — confirm
# it is still in sync with the file that is actually loaded.
# lr:0.001
# anneal_rate:0.9
# batch_size:32
# clip_norm:50
# num_epochs:5
# alpha:250
# beta:0
# max_beta:1
# step_beta:0.002
# anneal_iter:40000
# kl_anneal_iter:2000
# print_iter:100
# save_iter:5000
# num_workers:4

In [2]:
!pip install -q torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu116.html
!pip install -q dive-into-graphs

[0m

In [3]:
!pip install -q toolz
!pip install -q wandb

[0m

In [4]:
%load_ext autoreload
%autoreload 2
import os
import json
import argparse
import pickle 

import numpy as np
import pandas as pd
import torch

from molecule_optimizer.externals.fast_jtnn.datautils import SemiMolTreeFolder, SemiMolTreeFolderTest
from molecule_optimizer.runner.semi_jtvae import SemiJTVAEGeneratorPredictor
from torch_geometric.data import DenseDataLoader

import rdkit

# Raise the RDKit logger threshold to CRITICAL so that lower-severity RDKit
# messages are not printed while molecules are processed.
lg = rdkit.RDLogger.logger() 
lg.setLevel(rdkit.RDLogger.CRITICAL)

import warnings
# NOTE(review): this suppresses every Python warning for the rest of the session,
# which can also hide relevant deprecation or numerical warnings; consider a
# narrower filter (by category or module) once the noisy source is known.
warnings.filterwarnings("ignore")

import math

In [5]:
# Load the run configuration. The loader and training cells read their hyperparameters
# from this dict (e.g. conf["lr"], conf["batch_size"], conf["model"][...]).
# A context manager is used so the file handle is closed deterministically instead of
# being left open until garbage collection.
with open("training/configs/rand_gen_zinc250k_config_dict.json") as config_file:
    conf = json.load(config_file)

In [6]:
# Read the molecule table from the working directory; later cells use its
# 'SMILES' and 'QED' columns.
# NOTE(review): the variable name `csv` would shadow the standard-library `csv`
# module if that module were ever imported in this notebook.
csv = pd.read_csv("ZINC_310k.csv")

In [8]:
# SMILES column of the loaded table; this is the input later handed to
# SemiJTVAEGeneratorPredictor.
smiles = csv['SMILES']

In [10]:
# Keep only the first 60000 rows. The same literal 60000 is repeated when the
# QED labels are sliced, so both must be changed together to stay aligned.
smiles = smiles[:60000]

In [13]:
# QED values for the first 60000 rows as a float32 tensor (one label per row).
# The row count must match the slice applied to `smiles`.
labels = torch.tensor(csv['QED'].iloc[:60000].to_numpy()).float()

In [8]:
# Disabled variant of the runner cache that used the file name 'runner.xml';
# the active cell in this notebook uses 'runner_20.xml' instead.
# if 'runner.xml' not in os.listdir("."):
#     runner = SemiJTVAEGeneratorPredictor(smiles)
#     with open('runner.xml', 'wb') as f:
#         pickle.dump(runner, f)

In [15]:
# Build the runner from the SMILES subset only when no cached copy exists.
# The saved progress bars show this step taking roughly 20 minutes, hence the cache.
# NOTE: despite the '.xml' suffix the cache file is a pickle, not XML; the name is kept
# because the loading cell opens the same path.
# os.path.exists checks the one path directly instead of listing the whole directory.
if not os.path.exists('runner_20.xml'):
    runner = SemiJTVAEGeneratorPredictor(smiles)
    with open('runner_20.xml', 'wb') as f:
        pickle.dump(runner, f)

100%|██████████| 12/12 [00:35<00:00,  2.98s/it]
100%|██████████| 12/12 [19:24<00:00, 97.07s/it]


In [9]:
# Disabled loader for the 'runner.xml' cache; the active cell in this notebook
# loads 'runner_20.xml' instead.
# with open('runner.xml', 'rb') as f:
#     runner = pickle.load(f)

In [16]:
# Load the cached runner. The file is a pickle despite its '.xml' suffix.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file,
# so only load a cache that this notebook (or a trusted source) produced.
with open('runner_20.xml', 'rb') as f:
    runner = pickle.load(f)

In [17]:
# Ask the runner for the "rand_gen" model. The four architecture sizes are copied from
# conf["model"]; the label settings describe a single scalar label together with the
# mean and variance of the current `labels` tensor.
# NOTE(review): going by the execution counts, these statistics are computed before
# `labels` is replaced by runner.get_processed_labels(...), whose result is shorter
# (56519 entries) — confirm that using pre-processing statistics is intended.
runner.get_model(
    "rand_gen",
    {
        **{
            key: conf["model"][key]
            for key in ("hidden_size", "latent_size", "depthT", "depthG")
        },
        "label_size": 1,
        "label_mean": float(torch.mean(labels)),
        "label_var": float(torch.var(labels)),
    },
)

In [18]:
# Swap the raw labels for the runner's processed labels and fetch the runner's
# processed SMILES (the recorded outputs show both with 56519 entries).
# WARNING: `labels` is overwritten here, so re-running this cell passes the
# already-processed labels back into get_processed_labels. Use Restart & Run All
# rather than re-executing this cell on its own.
labels = runner.get_processed_labels(labels)
preprocessed = runner.processed_smiles

In [19]:
# Sanity check: number of labels after processing (should equal len(preprocessed)).
len(labels)

56519

In [20]:
# The train/test split indexes `preprocessed` and `labels` with the same permutation,
# so the two must have the same length; fail loudly here if they do not.
assert len(preprocessed) == len(labels), "processed SMILES and labels are misaligned"
len(preprocessed)

56519

In [21]:
# Split sizes used by the following cells.
# Alternative (currently disabled) test-set size:
# N_TEST = 10000
N_TEST = 200      # number of shuffled samples held out as the test split
VAL_FRAC = 0.05   # fraction of the remaining training samples used for validation

In [22]:
# Shuffle the sample indices once and split them: the first N_TEST shuffled indices
# form the test set, the rest the training set (validation is carved out later).
# A seeded generator is used so the partition is identical on every
# Restart & Run All; the original unseeded np.random.permutation produced a
# different split (and therefore non-comparable metrics) on each run.
SPLIT_SEED = 0
perm_id = np.random.default_rng(SPLIT_SEED).permutation(len(labels))
test_idx, train_idx = perm_id[:N_TEST], perm_id[N_TEST:]

# Index SMILES and labels with the same index arrays so they stay aligned.
labels_np = labels.numpy()

X_train = preprocessed[train_idx]
L_train = torch.tensor(labels_np[train_idx])

X_test = preprocessed[test_idx]
L_test = torch.tensor(labels_np[test_idx])

In [23]:
# Number of training samples to move into the validation split:
# VAL_FRAC of the current training set, rounded down.
val_cut = math.floor(len(X_train) * VAL_FRAC)

In [24]:
# Carve the validation split off the front of the (already shuffled) training data:
# the first `val_cut` samples become validation, the remainder stays as training.
# WARNING: X_train / L_train are overwritten, so re-running this cell removes another
# `val_cut` samples from training and changes the validation set; use Restart & Run All.
X_Val, X_train = X_train[:val_cut], X_train[val_cut:]
L_Val, L_train = L_train[:val_cut], L_train[val_cut:]

In [25]:
# Training data source: training SMILES and labels, the runner's vocabulary, and
# batch size / worker count taken from the config.
train_loader = SemiMolTreeFolder(
    X_train,
    L_train,
    runner.vocab,
    conf["batch_size"],
    # presumably the fraction of training samples treated as labelled — TODO confirm
    # against SemiMolTreeFolder; hard-coded here rather than read from `conf`.
    label_pct=0.05,
    num_workers=conf["num_workers"],
)

In [26]:
# Test data source: built from the held-out test split with the same vocabulary,
# batch size and worker count as the training loader.
test_loader = SemiMolTreeFolderTest(
    X_test,
    L_test,
    runner.vocab,
    conf["batch_size"],
    num_workers=conf["num_workers"],
)

In [27]:
# Validation data source: built from the validation split using the same loader
# class and settings as the test loader.
val_loader = SemiMolTreeFolderTest(
    X_Val,
    L_Val,
    runner.vocab,
    conf["batch_size"],
    num_workers=conf["num_workers"],
)

In [None]:
# Run runner.train_gen_pred on the three loaders, starting from epoch 0.
# All optimisation hyperparameters are forwarded unchanged from `conf` under the
# same keyword names; only the logging interval (print_iter) is fixed at 100.
print("Training model...")
runner.train_gen_pred(
    loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    load_epoch=0,
    print_iter=100,
    **{
        name: conf[name]
        for name in (
            "lr",
            "anneal_rate",
            "clip_norm",
            "num_epochs",
            "alpha",
            "beta",
            "max_beta",
            "step_beta",
            "anneal_iter",
            "kl_anneal_iter",
            "save_iter",
        )
    },
)

Training model...
Model #Params: 4732K
[Train][100] Alpha: 250.000, Beta: 0.000, Loss: 354.66, KL: 48.97, MAE: 0.12068, Word Loss: 93.83, Topo Loss: 24.13, Assm Loss: 8.91, Pred Loss: 0.91, Word: 27.50, Topo: 80.39, Assm: 54.93, PNorm: 101.85, GNorm: 50.00
[Train][200] Alpha: 250.000, Beta: 0.000, Loss: 233.86, KL: 78.11, MAE: 0.09622, Word Loss: 60.27, Topo Loss: 17.28, Assm Loss: 8.43, Pred Loss: 0.59, Word: 46.34, Topo: 87.11, Assm: 58.82, PNorm: 105.63, GNorm: 50.00
[Train][300] Alpha: 250.000, Beta: 0.000, Loss: 189.19, KL: 109.50, MAE: 0.08460, Word Loss: 52.84, Topo Loss: 13.22, Assm Loss: 8.19, Pred Loss: 0.46, Word: 51.24, Topo: 90.44, Assm: 59.88, PNorm: 108.75, GNorm: 50.00
[Train][400] Alpha: 250.000, Beta: 0.000, Loss: 177.42, KL: 146.13, MAE: 0.08182, Word Loss: 49.17, Topo Loss: 11.51, Assm Loss: 8.07, Pred Loss: 0.43, Word: 54.87, Topo: 91.78, Assm: 61.69, PNorm: 111.50, GNorm: 50.00
[Train][500] Alpha: 250.000, Beta: 0.000, Loss: 177.50, KL: 152.52, MAE: 0.08204, Word 

In [None]:
# Run runner.train_gen_pred_supervised on the same loaders, starting from epoch 0.
# Hyperparameters are forwarded from `conf` under the same keyword names, except
# alpha (fixed at 50 here instead of conf["alpha"]) and print_iter (fixed at 100).
# NOTE(review): the output saved with this cell logs "Alpha: 1.000" and a different
# parameter count (5207K vs 4732K for the other training cell) — the saved output may
# come from an earlier run, or the method may treat alpha differently; confirm.
# NOTE(review): this call reuses the same `runner` as the previous training cell;
# verify whether the model should be re-created before running it.
print("Training model...")
runner.train_gen_pred_supervised(
    loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    load_epoch=0,
    alpha=50,
    print_iter=100,
    **{
        name: conf[name]
        for name in (
            "lr",
            "anneal_rate",
            "clip_norm",
            "num_epochs",
            "beta",
            "max_beta",
            "step_beta",
            "anneal_iter",
            "kl_anneal_iter",
            "save_iter",
        )
    },
)

Training model...
Model #Params: 5207K
[Train][100] Alpha: 1.000, Beta: 0.000, Loss: 60.34, KL: 166.62, MAE: 0.11884, Word Loss: 45.12, Topo Loss: 9.87, Assm Loss: 4.47, Pred Loss: 0.89, Word: 28.59, Topo: 84.86, Assm: 54.80, PNorm: 103.63, GNorm: 35.51
[Train][200] Alpha: 1.000, Beta: 0.000, Loss: 38.86, KL: 164.02, MAE: 0.10257, Word Loss: 28.00, Topo Loss: 6.06, Assm Loss: 4.13, Pred Loss: 0.67, Word: 50.60, Topo: 91.49, Assm: 59.06, PNorm: 108.04, GNorm: 46.30
[Train][300] Alpha: 1.000, Beta: 0.000, Loss: 33.42, KL: 184.05, MAE: 0.09147, Word Loss: 23.80, Topo Loss: 5.25, Assm Loss: 3.84, Pred Loss: 0.53, Word: 59.80, Topo: 92.62, Assm: 60.15, PNorm: 111.32, GNorm: 35.47
[Train][400] Alpha: 1.000, Beta: 0.000, Loss: 30.39, KL: 195.26, MAE: 0.08689, Word Loss: 21.25, Topo Loss: 4.83, Assm Loss: 3.84, Pred Loss: 0.48, Word: 63.47, Topo: 93.24, Assm: 61.96, PNorm: 114.29, GNorm: 25.53
[Train][500] Alpha: 1.000, Beta: 0.000, Loss: 28.59, KL: 218.78, MAE: 0.08508, Word Loss: 19.85, Topo