In [1]:
!pip install -q torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu116.html
!pip install -q dive-into-graphs
!pip install -q toolz
!pip install -q wandb

[0m

In [3]:
%load_ext autoreload
%autoreload 2
import os
import json
import argparse
import pickle 

import numpy as np
import pandas as pd
import torch

from molecule_optimizer.externals.fast_jtnn.datautils import SemiMolTreeFolder, SemiMolTreeFolderTest
from molecule_optimizer.runner.semi_jtvae import SemiJTVAEGeneratorPredictor
from torch_geometric.data import DenseDataLoader

import rdkit

lg = rdkit.RDLogger.logger() 
lg.setLevel(rdkit.RDLogger.CRITICAL)

import warnings
warnings.filterwarnings("ignore")

import math

In [4]:
# Experiment hyper-parameters (model sizes under "model", plus lr / annealing /
# batch settings read by the training cell). The context manager closes the file
# handle instead of leaking it to the garbage collector.
with open("training/configs/rand_gen_zinc250k_config_dict.json") as config_file:
    conf = json.load(config_file)

In [5]:
# Raw dataset table; the "SMILES" column holds the molecule strings given to the runner,
# and the property columns (e.g. chem_prop) supply the labels.
csv = pd.read_csv("ZINC_310k.csv")
smiles = csv.loc[:, "SMILES"]

In [7]:
# Experiment configuration.
N_TEST = 10000       # number of (shuffled) molecules held out as the test split
VAL_FRAC = 0.05      # fraction of the remaining training molecules used for validation
chem_prop = "MolWt"  # CSV column used as the supervised label
load_epoch = 0       # forwarded to train_gen_pred_supervised (presumably a resume epoch)
SEED = 0             # fixes the train/val/test permutation and torch initialisation

# Seed the global RNGs so the split (np.random.permutation) and model
# initialisation are reproducible on Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [8]:
# Label for every row of the raw CSV, as a float32 tensor. Converting through
# `.to_numpy()` avoids torch iterating the pandas Series element by element.
labels = torch.tensor(csv[chem_prop].to_numpy(), dtype=torch.float32)
# A missing value would make the label mean/variance handed to get_model NaN.
assert not torch.isnan(labels).any(), f"column {chem_prop!r} contains missing values"

In [23]:
# Create the runner from the raw SMILES, then run the class-level preprocessing.
# It returns the processed molecules plus the indices later used to align labels.
# NOTE(review): re-running this cell rebinds `runner`, presumably discarding any model
# attached by the later get_model cell; run the notebook top to bottom.
runner = SemiJTVAEGeneratorPredictor(smiles)
(
    processed_smiles,
    processed_idxs,
) = SemiJTVAEGeneratorPredictor.preprocess(smiles)

In [11]:
# Build the "rand_gen" generator/predictor. Architecture sizes come from the JSON config.
# A single scalar label is used, and its mean/variance are passed in
# (presumably to normalise the prediction target).
# NOTE(review): `labels` here is still the raw CSV column (before get_processed_labels
# and before the train/val/test split), so test-set labels influence these
# statistics; consider computing them from L_train instead.
model_conf = conf["model"]
model_params = {
    key: model_conf[key]
    for key in ("hidden_size", "latent_size", "depthT", "depthG")
}
model_params["label_size"] = 1
model_params["label_mean"] = float(labels.mean())
model_params["label_var"] = float(labels.var())
runner.get_model("rand_gen", model_params)

In [12]:
# Align the labels with the preprocessed molecules using `processed_idxs`.
# The input is rebuilt from the untouched `csv` column rather than `labels = f(labels)`,
# so re-running this cell does not filter already-filtered labels a second time.
labels = runner.get_processed_labels(
    torch.tensor(csv[chem_prop].to_numpy(), dtype=torch.float32),
    processed_idxs,
)
preprocessed = processed_smiles

In [13]:

# Shuffle the processed dataset once, take the first N_TEST rows as the test split,
# then carve the first VAL_FRAC of the remaining rows off as the validation split.
n_total = len(labels)
assert N_TEST < n_total, f"N_TEST={N_TEST} leaves no training data (dataset has {n_total})"

# `perm_id` indexes the *processed* data, so the SMILES must be aligned with it first;
# indexing the raw `smiles` Series directly mislabels rows whenever preprocessing
# drops molecules.
# NOTE(review): assumes processed_idxs holds positions (or a boolean mask) into the
# raw `smiles` Series for the kept molecules; TODO confirm against preprocess().
smiles_aligned = smiles.iloc[np.asarray(processed_idxs)]
assert len(smiles_aligned) == n_total == len(preprocessed), "processed data and labels are misaligned"

perm_id = np.random.permutation(n_total)
test_idx, train_idx = perm_id[:N_TEST], perm_id[N_TEST:]
labels_np = labels.numpy()

X_train = preprocessed[train_idx]
X_train_smiles = smiles_aligned.iloc[train_idx]
L_train = torch.tensor(labels_np[train_idx])

X_test = preprocessed[test_idx]
X_test_smiles = smiles_aligned.iloc[test_idx]
L_test = torch.tensor(labels_np[test_idx])

# The training split is already shuffled, so a prefix is a random validation sample.
val_cut = math.floor(len(X_train) * VAL_FRAC)

X_Val = X_train[:val_cut]
X_Val_smiles = X_train_smiles[:val_cut]
L_Val = L_train[:val_cut]

X_train = X_train[val_cut:]
X_train_smiles = X_train_smiles[val_cut:]
L_train = L_train[val_cut:]

In [None]:
print("Training model...")

# Keyword arguments for the supervised generator/predictor run, in the same order
# as before. Most values come from the JSON config; alpha/beta start values,
# logging/saving intervals and label_pct are fixed here.
train_kwargs = dict(
    load_epoch=load_epoch,
    lr=conf["lr"],
    anneal_rate=conf["anneal_rate"],
    clip_norm=conf["clip_norm"],
    num_epochs=conf["num_epochs"],
    alpha=0.0,
    max_alpha=conf["max_alpha"],
    step_alpha=conf["step_alpha"],
    beta=0.0,
    max_beta=conf["max_beta"],
    step_beta=conf["step_beta"],
    anneal_iter=conf["anneal_iter"],
    alpha_anneal_iter=conf["alpha_anneal_iter"],
    kl_anneal_iter=conf["kl_anneal_iter"],
    print_iter=100,
    save_iter=1000,
    batch_size=conf["batch_size"],
    num_workers=conf["num_workers"],
    label_pct=0.5,
    chem_prop=chem_prop,
)
# NOTE(review): the test split is handed to the training routine; confirm it is
# only used for reporting and never for checkpoint/model selection.
runner.train_gen_pred_supervised(
    X_train, L_train,  # training split
    X_test, L_test,    # held-out test split
    X_Val, L_Val,      # validation split
    **train_kwargs,
)

Training model...
Model #Params: 5207K
[Train][9100] Alpha: 0.000, Beta: 0.000, Loss: 6.59, KL: 789.27, MAE: 90.13103, Word Loss: 4.55, Topo Loss: 1.20, Assm Loss: 0.84, Pred Loss: 2.76, Word: 90.45, Topo: 98.37, Assm: 92.49, PNorm: 235.40, GNorm: 21.08
[Train][9200] Alpha: 0.000, Beta: 0.000, Loss: 6.39, KL: 806.50, MAE: 91.06080, Word Loss: 4.47, Topo Loss: 1.13, Assm Loss: 0.80, Pred Loss: 2.78, Word: 90.38, Topo: 98.49, Assm: 92.93, PNorm: 237.90, GNorm: 22.20
[Train][9300] Alpha: 0.000, Beta: 0.000, Loss: 6.44, KL: 806.42, MAE: 91.96536, Word Loss: 4.33, Topo Loss: 1.25, Assm Loss: 0.86, Pred Loss: 2.92, Word: 90.80, Topo: 98.35, Assm: 92.68, PNorm: 240.23, GNorm: 23.92
[Train][9400] Alpha: 0.000, Beta: 0.000, Loss: 6.35, KL: 805.82, MAE: 89.43684, Word Loss: 4.38, Topo Loss: 1.14, Assm Loss: 0.84, Pred Loss: 2.73, Word: 90.73, Topo: 98.40, Assm: 92.54, PNorm: 242.54, GNorm: 23.66
[Train][9500] Alpha: 0.000, Beta: 0.000, Loss: 6.24, KL: 793.20, MAE: 92.61533, Word Loss: 4.34, Topo