In [16]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import tensorflow as tf

from utils import load_features_and_labels
from models.gaussian_process import train_gp_model

import rdkit
import rdkit.Chem as Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem

In [4]:
smiles, X, X_p, y = load_features_and_labels('./processed_data/jtnn_features.csv' ,'e_iso_pi')

In [5]:
model = train_gp_model(X, y)


Beginning training loop...

mean R^2: 0.8601 +- 0.0144
mean RMSE: 23.8784 +- 1.4816
mean MAE: 14.8252 +- 0.6878



In [15]:
n_iter = 20
lr = 2.0

curr_vec = None
visited = []

jtnnVAE = ()

for step in range(n_iter):
  curr_vec_tf = tf.Variable(curr_vec, name='curr_vec')
  with tf.GradientTape() as tape:
    wavelength, _ = model.predict_f(curr_vec_tf)
  grad = tape.gradient(wavelength, curr_vec_tf)
  curr_vec = curr_vec.data + lr * grad.data
  visited.append(curr_vec)

l, r = 0, n_iter - 1
while l < r - 1:
  mid = (l + r) / 2
  new_vec = visited[mid]
  tree_vec, mol_vec = # split tf variable
  new_smiles = jtnnVAE.decode(tree_vec, mol_vec, prob_decode=False)
  if new_smiles is None:
    r = mid - 1
    continue

  new_mol = Chem.MolFromSmiles(new_smiles)
  fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
  sim = DataStructs.TanimotoSimilarity(fp1, fp2) 
  if sim < sim_cutoff:
      r = mid - 1
  else:
      l = mid

tree_vec,mol_vec = torch.chunk(visited[l], 2, dim=1)
new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
new_mol = Chem.MolFromSmiles(new_smiles)
fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
sim = DataStructs.TanimotoSimilarity(fp1, fp2) 



tf.Tensor([[0.11229529]], shape=(1, 1), dtype=float64)
tf.Tensor(
[[ 1.81992289e-03  1.97104244e-03  1.41004114e-03  1.55459689e-02
   1.65529002e-02 -7.73708644e-03  8.84009349e-03  1.45640911e-02
  -6.73003740e-03  5.02547915e-04  1.76446229e-03  9.56246277e-03
   1.15347834e-02  7.28755851e-04 -6.80421270e-03 -6.07509846e-03
   0.00000000e+00 -1.72284991e-02 -7.00637148e-03  2.93790804e-03
   1.41696001e-02 -5.87784293e-03  5.68474396e-04  8.78960136e-04
   0.00000000e+00 -7.20706491e-03  2.21996935e-03 -6.20273673e-03
  -6.42833186e-03 -3.79273532e-03  6.92857416e-03 -6.52852651e-03
   2.48528459e-03  1.05787305e-03  2.47934490e-03  5.08678987e-03
   1.06363020e-02  4.32101562e-03  2.49000493e-03 -1.65036782e-02
  -9.21693353e-03  1.46647016e-02  1.37459041e-03  8.08790041e-03
  -9.94336272e-03  9.44255935e-03 -5.65814163e-03 -8.36161163e-03
   1.91349394e-03  3.27806026e-03  2.30262157e-03  7.25989297e-03
   2.67484379e-03  7.13353160e-03  7.57079746e-03 -3.80583575e-03
  -8.74011