In [1]:
!pip install -q torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric #-f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q dive-into-graphs



In [2]:
# Enable IPython's autoreload extension in mode 2, so edits to imported project
# modules (e.g. molecule_optimizer) are reloaded before code is executed.
%load_ext autoreload
%autoreload 2
import os
import json
import argparse
import pickle 

import numpy as np
import pandas as pd
import torch

from molecule_optimizer.externals.fast_jtnn.datautils import SemiMolTreeFolder
from molecule_optimizer.runner.semi_jtvae import SemiJTVAEGeneratorPredictor
from torch_geometric.data import DenseDataLoader

import rdkit
# Import the submodule explicitly. A bare `import rdkit` does not guarantee that
# `rdkit.RDLogger` has been loaded, so attribute access on the package can fail
# unless some other import happened to load it first.
from rdkit import RDLogger

# Raise RDKit's log threshold so only CRITICAL-level messages are emitted.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

import warnings

# NOTE(review): this hides *every* Python warning (including deprecation
# warnings from torch/pandas) for the rest of the session; consider narrowing
# it with `category=` / `module=` filters.
warnings.filterwarnings("ignore")

In [3]:
# Load the run configuration (model sizes, optimiser and schedule settings).
# The `with` block guarantees the file handle is closed after parsing.
with open("training/configs/rand_gen_zinc250k_config_dict.json") as config_file:
    conf = json.load(config_file)

In [4]:
# Raw dataset: one row per molecule; later cells read its SMILES and LogP columns.
# NOTE(review): the name `csv` shadows the stdlib module; kept because later
# cells refer to it.
csv = pd.read_csv(filepath_or_buffer="ZINC_310k.csv")

In [5]:
# Molecule strings used to build the runner.
smiles = csv.loc[:, "SMILES"]

In [6]:
# Build the SemiJTVAEGeneratorPredictor once and cache it on disk; the next cell
# loads the runner from this cache.
# NOTE(review): the cache is a pickle despite the ".xml" extension; the name is
# kept so existing caches are still found.
runner_cache_path = "runner.xml"
if not os.path.exists(runner_cache_path):
    runner = SemiJTVAEGeneratorPredictor(smiles)
    # Dump to a temporary file and atomically move it into place. Otherwise an
    # interrupted or failed dump would leave a truncated cache that the check
    # above would treat as valid on every later run.
    tmp_cache_path = runner_cache_path + ".tmp"
    try:
        with open(tmp_cache_path, "wb") as f:
            pickle.dump(runner, f)
        os.replace(tmp_cache_path, runner_cache_path)
    finally:
        # Only still present if the dump or the move failed.
        if os.path.exists(tmp_cache_path):
            os.remove(tmp_cache_path)

In [7]:
# Load the runner from its pickle cache (written by the previous cell or an
# earlier run).
# SECURITY: unpickling can execute arbitrary code; only load a runner.xml
# this notebook produced.
with open("runner.xml", mode="rb") as cache_file:
    runner = pickle.load(cache_file)

In [8]:
# LogP regression targets as a float32 tensor (one value per CSV row).
# Convert via `.to_numpy()`: torch.tensor has no fast path for pandas objects,
# so a Series is converted element by element, which is slow and depends on
# the Series index rather than just its values.
labels = torch.tensor(csv["LogP"].to_numpy(), dtype=torch.float32)

In [9]:
# Build the "rand_gen" model: architecture sizes come from the config, and the
# label statistics are passed in (presumably to normalise LogP predictions;
# TODO confirm in SemiJTVAEGeneratorPredictor).
# NOTE(review): mean/var are computed on the raw CSV labels, i.e. before
# `runner.get_processed_labels` in the next cell (which leaves 292191 labels),
# and torch's var defaults to the unbiased (N-1) estimator. Confirm both are
# intended.
model_conf = conf["model"]
model_params = {
    "hidden_size": model_conf["hidden_size"],
    "latent_size": model_conf["latent_size"],
    "depthT": model_conf["depthT"],
    "depthG": model_conf["depthG"],
    "label_size": 1,
    "label_mean": float(labels.mean()),
    "label_var": float(labels.var()),
}
runner.get_model("rand_gen", model_params)

In [10]:
# Replace the raw labels with the runner's processed version (presumably filtered
# and ordered to match `runner.processed_smiles`; TODO confirm), and grab the
# preprocessed molecules they pair with.
# NOTE(review): this rebinds `labels`, so re-running this cell on its own feeds
# already-processed labels back into get_processed_labels. Re-run from the LogP
# cell instead.
labels = runner.get_processed_labels(labels)
preprocessed = runner.processed_smiles

In [11]:
# Number of labels after processing; left as the last expression so the cell
# displays it.
n_labels = len(labels)
n_labels

292191

In [12]:
# Wrap the preprocessed molecules, their labels and the runner's vocabulary in
# a SemiMolTreeFolder. Batch size and worker count come from the config.
batch_size = conf["batch_size"]
worker_count = conf["num_workers"]
loader = SemiMolTreeFolder(
    preprocessed, labels, runner.vocab, batch_size, num_workers=worker_count
)

In [None]:
print("Training model...")
# Hyper-parameters forwarded verbatim from the config, looked up in this order.
forwarded_keys = (
    "lr",
    "anneal_rate",
    "clip_norm",
    "num_epochs",
    "alpha",
    "beta",
    "max_beta",
    "step_beta",
    "anneal_iter",
    "kl_anneal_iter",
)
forwarded_hparams = {key: conf[key] for key in forwarded_keys}
runner.train_gen_pred(
    loader=loader,
    load_epoch=0,  # presumably "start from scratch, no checkpoint"; TODO confirm
    print_iter=100,  # a progress line every 100 iterations (see output below)
    save_iter=conf["save_iter"],
    **forwarded_hparams,
)

Training model...
Model #Params: 5207K
[100] Alpha: 250.000, Beta: 0.000, Loss: 528.10, KL: 156.63, MAE: 0.02341, Word Loss: 199.06, Topo Loss: 56.74, Assm Loss: 17.34, Pred Loss: 1.02, Word: 0.49, Topo: 1.57, Assm: 1.12, PNorm: 102.92, GNorm: 50.00
[200] Alpha: 250.000, Beta: 0.000, Loss: 298.32, KL: 272.62, MAE: 0.02483, Word Loss: 127.71, Topo Loss: 28.91, Assm Loss: 16.51, Pred Loss: 0.50, Word: 0.86, Topo: 1.80, Assm: 1.19, PNorm: 106.52, GNorm: 50.00
