In [1]:
import sys
# Make the project packages importable: presumably the first path exposes the
# top-level modules imported below (hsic, bayesian_opt, chemvae_bopt, ...) and
# the second exposes `bb_opt.src.utils` / `gpu_utils` — confirm.
# NOTE(review): absolute, cluster-specific paths; consider reading them from an
# environment variable so the notebook can run on other machines.
sys.path.append('/cluster/sj1/bb_opt/src')
sys.path.append('/cluster/sj1/')

In [2]:
import os

import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn
import numpy as np
import h5py
from collections import namedtuple
import torch.distributions as tdist
from gpu_utils.utils import gpu_init
from tqdm import tnrange
import pandas as pd
from bb_opt.src.utils import train_val_test_split

# Reserve a GPU for this kernel and report which one was picked.
gpu_id = gpu_init()
print(f"Running on GPU {gpu_id}")
# Use the first visible CUDA device when one exists, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

%matplotlib inline

Running on GPU 0


In [4]:
import hsic
import bayesian_opt as bopt
import chemvae_bopt as cbopt
import reparam_trainer as reparam

Using TensorFlow backend.


In [5]:
from chemvae_keras import vae_utils
from chemvae_keras import mol_utils as mu

In [6]:
# Width of the chemical VAE's encoded representation (presumably its latent
# size). It is used as the property predictor's input width
# (`prop_pred_num_input_features`) and matches the (10000, 196) shape printed
# when the data is loaded.
chemvae_num_z = 196

In [7]:
# All hyper-parameters for this experiment, grouped as: training of the
# stochastic property predictor, predictor architecture, batch acquisition,
# retraining, and label scoring.
Params = namedtuple('params', [
    # --- training ---
    'lr',
    'output_dist_std',
    'output_dist_fn',
    'prior_mean',
    'prior_std',
    'num_epochs',
    'num_samples',
    'train_batch_size',
    'device',
    'exp_noise_samples',
    'hsic_train_lambda',

    # --- property-predictor architecture ---
    'prop_activation',
    'prop_pred_num_input_features',
    'prop_pred_num_random_inputs',
    'prop_pred_num_hidden',
    'prop_pred_dropout',
    'prop_pred_depth',
    'prop_pred_growth_factor',
    'prop_batchnorm',

    # --- batch acquisition ---
    'ack_batch_size',
    'num_queries',
    'batch_opt_lr',
    'batch_opt_num_iter',
    'input_opt_lr',
    'mves_kernel_fn',
    'input_opt_num_iter',
    'ack_num_model_samples',
    'hsic_diversity_lambda',

    # --- retraining ---
    'retrain_iters',

    # --- scoring ---
    'score_fn', # order is logp, qed, sas
])

params = Params(
    train_batch_size=100,
    output_dist_std=0.01,
    output_dist_fn=tdist.Normal,
    num_samples=10,
    exp_noise_samples=2,
    lr=1e-3,
    prior_mean=0.,
    prior_std=1.,
    # Fall back to the CPU on hosts without CUDA (mirrors the fallback used
    # for `device` in the setup cell); a hard-coded 'cuda' crashes there.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    num_epochs=1000,
    hsic_train_lambda=20.,

    prop_activation='relu',
    prop_pred_num_input_features=chemvae_num_z,
    prop_pred_num_random_inputs=20,
    prop_pred_num_hidden=50,
    prop_pred_dropout=False,
    prop_pred_depth=2,
    prop_pred_growth_factor=1.5,
    prop_batchnorm=True,

    ack_batch_size=5,
    num_queries=1,
    batch_opt_lr=1e-3,
    batch_opt_num_iter=100,
    input_opt_lr=1e-3,
    mves_kernel_fn='mixrq_kernels',
    input_opt_num_iter=100,
    ack_num_model_samples=100,
    hsic_diversity_lambda=1.,

    retrain_iters=100,

    # Index 0 of the property vector, i.e. logp (see the field-order note above).
    score_fn=lambda x : x[0],
)

# Every later cell moves tensors to this device.
device = params.device

In [23]:
import bayesian_opt as bopt
import chemvae_bopt as cbopt
import reparam_trainer as reparam

In [24]:
# Load a 10000-molecule subset of ZINC250k along with the VAE, the one-hot
# SMILES and their encoded representations; `score_fn` presumably maps each
# molecule's property vector to the scalar label used as the target.
zinc250k, vae, smiles_one_hot, smiles_z, labels = cbopt.load_zinc250k(
    score_fn=params.score_fn,
    num_to_load=10000,
)

read zinc250k from file




Using standarized functions? True
Standarization: estimating mu and std values ...done!
initialized VAEUtils
done 0K samples
converted to one_hot
(10000, 196)
processed encoded representation


In [None]:
# NOTE(review): `exclude_top` is not used anywhere in this notebook —
# presumably meant to hold out the best-scoring fraction; confirm or remove.
exclude_top = 0.01
n_train = 1000            # molecules used for train + validation combined
random_selection = False

# The regressor consumes the VAE encodings: `prop_pred_num_input_features` is
# chemvae_num_z, and `smiles_z` is presumably the (num_loaded, 196) array
# printed when the data was loaded. (`inputs` was previously undefined.)
inputs = smiles_z

# Index the array we actually slice below, not a separate container.
num_total = len(inputs)
idx = np.arange(num_total)

if random_selection:
    # Draw n_train random rows, then split those 90/10 into train/validation.
    train_idx, _, _ = train_val_test_split(idx, split=[n_train, 0])
    # Split *positions within* train_idx; pass an index array, as in the call above.
    train_idx2, _, test_idx2 = train_val_test_split(np.arange(len(train_idx)), split=[0.9, 0])

    test_idx = train_idx[test_idx2]
    train_idx = train_idx[train_idx2]
else:
    # Deterministic: first 90% of the first n_train rows train, the rest validate.
    num_train = int(0.9 * n_train)
    train_idx = idx[:num_train]
    test_idx = idx[num_train:n_train]

assert len(inputs) == len(labels), "inputs and labels are not row-aligned"
assert len(np.intersect1d(train_idx, test_idx)) == 0, "train/validation overlap"

train_inputs = inputs[train_idx]
train_labels = labels[train_idx]
val_inputs = inputs[test_idx]
val_labels = labels[test_idx]

In [None]:
# Standardize both splits with statistics from the training split only, so
# the validation labels do not influence the scaling.
train_label_mean, train_label_std = train_labels.mean(), train_labels.std()
train_labels, val_labels = (
    (train_labels - train_label_mean) / train_label_std,
    (val_labels - train_label_mean) / train_label_std,
)

In [None]:
# Stochastic property predictor and (presumably) a Gaussian variational
# distribution over its random inputs.
model = cbopt.PropertyPredictor(params)
qz = reparam.GaussianQz(params.prop_pred_num_random_inputs)

# Prior over the random inputs: an independent Normal(prior_mean, prior_std)
# per input, built by shifting/scaling standard-normal parameters on `device`.
mu_e = torch.zeros(params.prop_pred_num_random_inputs, requires_grad=False).to(device)
std_e = torch.ones_like(mu_e)
e_dist = tdist.Normal(loc=mu_e + params.prior_mean, scale=std_e * params.prior_std)

In [None]:
# Fresh, independent history lists for the training curves.
train_losses, train_kl_losses, train_hsic_losses, val_losses = [], [], [], []
train_corrs, val_corrs = [], []

# float32 tensors for both splits, in the order the trainer expects.
train_X, train_Y, val_X, val_Y = (
    torch.FloatTensor(array)
    for array in (train_inputs, train_labels, val_inputs, val_labels)
)

In [None]:
data = [train_X, train_Y, val_X, val_Y]

# NOTE(review): a bare `train` is not defined anywhere in this notebook (the
# defining cell seems to have been deleted — execution counts jump from 7 to
# 23), so this failed on a fresh kernel. Presumably it lives in the
# `reparam_trainer` module imported above; confirm the name there.
(
    train_losses,
    train_kl_losses,
    train_hsic_losses,
    val_losses,
    train_corrs,
    val_corrs,
) = reparam.train(params, data, model, qz, e_dist)

In [None]:
def plot_curve(values, name, recent=3000):
    """Plot a full training curve (left) beside its last `recent` entries (right).

    values: sequence of per-step metric values.
    name:   title of the left panel; the right panel is titled "Recent <name>".
    """
    plt.figure(figsize=(15, 8))
    plt.subplot(121)
    plt.plot(values)
    plt.title(name)
    plt.subplot(122)
    plt.plot(values[-recent:])
    plt.title("Recent " + name)


def to_numpy(x):
    """Return `x` as a NumPy array, detaching and moving torch tensors to the host."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def plot_pred_vs_true(preds, labels, title):
    """Joint scatter (with marginals) of predicted vs. true standardized labels."""
    grid = sns.jointplot(x=preds, y=labels)
    grid.set_axis_labels("predicted", "true")
    grid.fig.suptitle(title)


plot_curve(train_losses, "log prob loss", recent=1000)
plot_curve(train_kl_losses, "kl loss")
plot_curve(train_hsic_losses, "hsic loss")

plt.figure(figsize=(15, 4))
plt.plot(train_corrs, label="train_corrs")
plt.plot(val_corrs, label="val_corrs")
plt.legend()
plt.title("Kendall Tau");

# Prior samples for prediction. `e` was previously only created in a later
# cell, so this cell depended on out-of-order execution.
e = reparam.generate_prior_samples(params.ack_num_model_samples, e_dist)

# The labels come from score_fn = logp on ZINC250k (the old title was copied
# from a DNA-binding notebook). Update if score_fn changes.
title = "ZINC250k - logP"
train_title = title + " (train)"
val_title = title + " (val)"

if n_train > 1:
    preds = to_numpy(reparam.predict(train_X, model, qz, e)[:, :, 0].mean(1))
    plot_pred_vs_true(preds, train_labels, train_title)
    print('train_corrcoef:', np.corrcoef(preds, train_labels)[0, 1])

preds = to_numpy(reparam.predict(val_X, model, qz, e)[:, :, 0].mean(1))
plot_pred_vs_true(preds, val_labels, val_title)
print('val_corrcoef:', np.corrcoef(preds, val_labels)[0, 1])

In [None]:
# Draw `ack_num_model_samples` samples from the random-input prior `e_dist`
# and turn the stochastic predictor into an ensemble (presumably one member
# per prior sample — confirm in reparam_trainer) for batch acquisition.
e = reparam.generate_prior_samples(params.ack_num_model_samples, e_dist)
model_ensemble = reparam.generate_ensemble_from_stochastic_net(model, e)

In [None]:
input_shape = model.input_shape()

for ack_iter in range(params.num_queries):
    # Three batch-acquisition strategies on the model ensemble: direct input
    # optimisation with an HSIC diversity term, and (presumably) gradient-based
    # expected improvement and MVES acquisition.
    good_points = bopt.optimize_model_input(
        params,
        input_shape,
        model_ensemble,
        hsic_diversity_lambda=params.hsic_diversity_lambda
    )
    ei_batch = bopt.acquire_batch_via_grad_ei(
        params,
        model_ensemble,
        input_shape
    )
    mves_batch = bopt.acquire_batch_via_grad_mves(
        params,
        model_ensemble,
        input_shape
    )

    # Property vectors for each proposed batch (presumably obtained by decoding
    # through the VAE), in the order: input-opt, EI, MVES.
    ack_props = [
        cbopt.acquire_properties(batch, vae, params.device)
        for batch in (good_points, ei_batch, mves_batch)
    ]

    ack_labels = []
    for props in ack_props:
        # `shape[0]` is an int; iterate over range(...) rather than the int itself.
        new_labels = [params.score_fn(props[j]) for j in range(props.shape[0])]
        ack_labels.append(torch.tensor(new_labels))
    # Combine the per-strategy 1-D label tensors into a (num_strategies, batch)
    # tensor. torch.tensor() cannot build a tensor from a list of multi-element
    # tensors. torch.stack needs every strategy to return the same batch size.
    ack_labels = torch.stack(ack_labels)
    assert ack_labels.ndimension() == 2
    # Best score per strategy, as a (values, indices) pair.
    max_ack = torch.max(ack_labels, dim=1)

    print(max_ack)