In [2]:
import sys

# Extend the import search path with the cluster checkout: the directory that
# contains the `bb_opt` package and the package's own `src` directory.
# NOTE(review): absolute, machine-specific paths — the notebook only runs where
# this layout exists.
sys.path.extend(['/cluster/sj1', '/cluster/sj1/bb_opt/src'])

In [3]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from collections import namedtuple
import torch.distributions as tdist
import bb_opt.src.reparam_trainer as reparam
import numpy as np
from scipy.stats import kendalltau
import matplotlib.pyplot as plt
import seaborn as sns
import bb_opt.src.dna_bopt as dbopt
from gpu_utils.utils import gpu_init
from tqdm import tnrange
import pandas as pd
from bb_opt.src.utils import train_val_test_split
from sklearn.model_selection import train_test_split
import utils

# Let the project helper pick a GPU, then build the torch device handle used by
# the rest of the notebook.
# NOTE(review): the device is always 'cuda:0' regardless of gpu_id — presumably
# gpu_init restricts the visible devices; confirm against gpu_utils.
gpu_id = gpu_init()
print(f"Running on GPU {gpu_id}")
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

%matplotlib inline

Running on GPU 1


In [5]:
# Immutable record of the hyperparameters handed to the model/training helpers.
# NOTE(review): later cells read params.num_train_latent_samples,
# params.ack_num_model_samples, params.ack_batch_size and
# params.mves_compute_batch_size, none of which are declared here — those cells
# raise AttributeError until the fields (and their values) are added.
Params = namedtuple(
    'params',
    'lr num_latents output_dist_std output_dist_fn prior_mean prior_std '
    'num_epochs num_samples batch_size device exp_noise_samples',
)

# Initial configuration, listed in field-declaration order.
params = Params(
    lr=1e-3,
    num_latents=20,
    output_dist_std=0.01,
    output_dist_fn=tdist.Normal,
    prior_mean=0.,
    prior_std=1.,
    num_epochs=1000,
    num_samples=10,
    batch_size=100,
    device='cuda',
    exp_noise_samples=2,
)

In [7]:
# Size of the pool drawn for the initial train/validation split.
n_train = 1000

project = "dna_binding"
dataset = "crx_ref_r1"

# NOTE(review): absolute cluster path; the notebook only runs where it exists.
root = "/cluster/sj1/bb_opt/"
data_dir = root+"data/"+project+"/"+dataset+"/"
inputs = np.load(data_dir+"inputs.npy")
labels = np.load(data_dir+"labels.npy")


# Fraction of the highest-labelled points kept out of the initial pool.
exclude_top = 0.1

idx = np.arange(labels.shape[0])

# Keep only the lowest-labelled points. The number kept is computed explicitly:
# the previous `[:-k]` slice silently produced an EMPTY pool whenever
# k == int(n * exclude_top) rounded down to 0 (`[:-0]` is `[:0]`).
num_excluded = int(labels.shape[0]*exclude_top)
num_kept = max(labels.shape[0] - num_excluded, 0)
sort_idx = labels.argsort()[:num_kept]
idx = idx[sort_idx]

# First draw n_train indices from the pool, then split that draw 90/10.
# NOTE(review): assumes train_val_test_split returns (train, val, test) index
# arrays and that, with a val share of 0, the remainder lands in the third
# output — confirm against bb_opt.src.utils.
train_idx, _, _ = train_val_test_split(idx, split=[n_train, 0])
train_idx2, _, test_idx2 = train_val_test_split(n_train, split=[0.9, 0])

# Map positions within the n_train draw back to indices into inputs/labels.
test_idx = train_idx[test_idx2]
train_idx = train_idx[train_idx2]

train_inputs = inputs[train_idx]
train_labels = labels[train_idx]

# The held-out 10% of the draw serves as the validation set.
val_inputs = inputs[test_idx]
val_labels = labels[test_idx]

In [8]:
# Standardise the labels using statistics from the training split only; the
# validation labels are shifted/scaled with the same two numbers.
# NOTE(review): re-running this cell normalises already-normalised labels.
train_label_mean, train_label_std = train_labels.mean(), train_labels.std()

train_labels, val_labels = (
    (split_labels - train_label_mean) / train_label_std
    for split_labels in (train_labels, val_labels)
)

In [10]:
import bb_opt.src.dna_bopt as dbopt

In [11]:
# Configuration actually used below; it replaces the `params` built right after
# the Params declaration. Arguments follow the field-declaration order.
params = Params(
    lr=1e-4,
    num_latents=15,
    output_dist_std=1.,
    output_dist_fn=tdist.Normal,
    prior_mean=0.,
    prior_std=3.,
    num_epochs=500,
    num_samples=20,
    batch_size=10,
    device='cuda',
    exp_noise_samples=3,
)

In [12]:
# Build the network together with the two distribution objects the training
# helpers take (qz, e_dist); the input width comes from the data matrix.
model, qz, e_dist = dbopt.get_model_nn(
    inputs.shape[1],
    params.num_latents,
    params.prior_std,
)

# Running histories; the training cell appends to these in place.
train_losses, train_kl_losses, train_hsic_losses, val_losses = [], [], [], []
train_corrs, val_corrs = [], []

num_batches: 90


In [None]:
# Convert the four numpy splits to float32 tensors held on the CPU.
train_X, train_Y, val_X, val_Y = (
    torch.FloatTensor(split_array, device='cpu')
    for split_array in (train_inputs, train_labels, val_inputs, val_labels)
)

In [403]:
import bb_opt.src.dna_bopt as dbopt

In [None]:
# Train on the current splits and fold the per-run logs into the histories.
# NOTE(review): params.num_train_latent_samples is not a field of Params as
# declared in this notebook — this call raises AttributeError until it is added.
data = [train_X, train_Y, val_X, val_Y]
logging = dbopt.train(
    params, params.num_train_latent_samples, data, model, qz, e_dist
)
(i_losses, i_kl_losses, i_hsic_losses,
 i_val_losses, i_corrs, i_val_corrs) = logging

# Loss histories.
train_losses.extend(i_losses)
train_kl_losses.extend(i_kl_losses)
train_hsic_losses.extend(i_hsic_losses)
val_losses.extend(i_val_losses)

# Rank-correlation histories.
train_corrs.extend(i_corrs)
val_corrs.extend(i_val_corrs)

In [None]:
# Loss curves: one figure per loss, with the full history on the left and only
# the most recent entries on the right. Every panel is titled (the recent-kl
# panel previously had no title).
loss_panels = [
    (train_losses, "log prob loss", 1000),
    (train_kl_losses, "kl loss", 3000),
    (train_hsic_losses, "hsic loss", 3000),
]
for loss_history, loss_name, recent_window in loss_panels:
    plt.figure(figsize=(15, 8))

    plt.subplot(121)
    plt.plot(loss_history)
    plt.title(loss_name)

    plt.subplot(122)
    plt.plot(loss_history[-recent_window:])
    plt.title("Recent " + loss_name)

# Train vs. validation correlation history.
plt.figure(figsize=(15, 4))

plt.plot(train_corrs, label="train_corrs")
plt.plot(val_corrs, label="val_corrs")
plt.legend()
plt.title("Kendall Tau");

title = "DNA Binding - CRX"
train_title = title + " (train)"
val_title = title + " (val)"

# Predicted-vs-actual joint plots and Pearson correlation for each split.
# NOTE(review): `e` is only assigned in a later cell
# (reparam.generate_prior_samples), so this cell fails on a fresh
# Restart & Run All.
# NOTE(review): `jointplot` is not imported under that bare name anywhere in the
# visible cells — presumably a project helper; qualify or import it.
if n_train > 1:
    preds = reparam.predict(train_X, model, qz, e)[:, :, 0].mean(1)
    jointplot(preds, train_labels, train_title)
    print('train_corrcoef:', np.corrcoef(preds, train_labels)[0, 1])

preds = reparam.predict(val_X, model, qz, e)[:, :, 0].mean(1)
jointplot(preds, val_labels, val_title)
print('val_corrcoef:', np.corrcoef(preds, val_labels)[0, 1])

In [None]:
# Full candidate set (every input and its label) as tensors on the compute
# device. torch.tensor keeps the numpy dtype, so these are not necessarily
# float32 like train_X/train_Y.
X, Y = (
    torch.tensor(full_array, device=device)
    for full_array in (inputs, labels)
)

In [None]:
import bayesian_opt as bopt

In [None]:
# Draw the samples from e_dist that the prediction and ensemble helpers consume.
# NOTE(review): params.ack_num_model_samples is not a field of Params as
# declared in this notebook — add it before running.
e = reparam.generate_prior_samples(
    params.ack_num_model_samples,
    e_dist,
)

In [None]:
# Greedy acquisition: each round, add the candidates with the highest mean
# ensemble prediction to the training set and retrain.
# NOTE(review): params.ack_batch_size and params.num_train_latent_samples are
# not fields of Params as declared in this notebook — add them before running.
skip_idx = set(train_idx)  # indices into inputs/labels already in the training set
for ack_iter in range(10):
    model_ensemble = reparam.generate_ensemble_from_stochastic_net(model, e)
    preds = model_ensemble(X, resize_at_end=True) # (num_candidate_points, num_samples)
    preds = preds.transpose(0, 1)  # was `tranpose`, which raised AttributeError

    # Mean prediction per candidate. detach() makes the numpy conversion valid
    # even if the ensemble output tracks gradients; astype gives a writable
    # float copy for the masking below.
    ei = preds.mean(dim=0).view(-1).detach().cpu().numpy().astype(float)
    # skip_idx was built but never consulted, so the same top points were picked
    # every round; mask everything already acquired.
    ei[np.fromiter(skip_idx, dtype=np.int64)] = -np.inf
    ei_sortidx = np.argsort(ei)

    ack_idx = ei_sortidx[-params.ack_batch_size:]
    skip_idx.update(ack_idx.tolist())
    ack_ei = X[ack_idx]

    # Labels of the acquired points, standardised with the same training
    # statistics as train_Y, so inputs and targets stay the same length.
    ack_labels = (labels[ack_idx] - train_label_mean) / train_label_std

    # .to(train_X)/.to(train_Y) matches dtype and device; X lives on `device`
    # while the training tensors were created on the CPU.
    train_X = torch.cat([train_X, ack_ei.to(train_X)], dim=0)
    train_Y = torch.cat([train_Y, torch.as_tensor(ack_labels).to(train_Y)], dim=0)
    data = [train_X, train_Y, val_X, val_Y]
    logging = dbopt.train(params, params.num_train_latent_samples, data, model, qz, e_dist)

In [None]:
# Batch acquisition via dbopt.acquire_batch_mves_sid, retraining after every
# round.
# NOTE(review): params.mves_compute_batch_size and
# params.num_train_latent_samples are not fields of Params as declared in this
# notebook — add them before running.
# NOTE(review): skip_idx is updated but not passed to the acquisition helper, so
# nothing visible here prevents re-selecting a point — confirm the helper's
# contract.
skip_idx = set(train_idx)
for ack_iter in range(10):
    # Rebuild the ensemble each round, as the mean-prediction loop does, instead
    # of reusing whatever `model_ensemble` an earlier cell left behind.
    model_ensemble = reparam.generate_ensemble_from_stochastic_net(model, e)
    preds = model_ensemble(X, resize_at_end=True) # (num_candidate_points, num_samples)
    preds = preds.transpose(0, 1)  # was `tranpose`, which raised AttributeError

    # Tensor.max(dim=...) returns a (values, indices) pair; only the values are
    # wanted here (calling .view on the pair raised AttributeError).
    max_pred = preds.max(dim=1)[0].view(-1)
    mves_idx = dbopt.acquire_batch_mves_sid(params, max_pred, preds, params.mves_compute_batch_size)
    skip_idx.update(mves_idx)

    # Build the index on X's own device rather than the hard-coded
    # params.device string, so CPU-only runs also work.
    ack_index = torch.tensor(mves_idx, device=X.device)
    ack_mves = X[ack_index]

    # Labels of the acquired points, standardised with the same training
    # statistics as train_Y, so inputs and targets stay the same length.
    ack_labels = (labels[ack_index.cpu().numpy()] - train_label_mean) / train_label_std

    # .to(train_X)/.to(train_Y) matches dtype and device; X lives on `device`
    # while the training tensors were created on the CPU.
    train_X = torch.cat([train_X, ack_mves.to(train_X)], dim=0)
    train_Y = torch.cat([train_Y, torch.as_tensor(ack_labels).to(train_Y)], dim=0)
    data = [train_X, train_Y, val_X, val_Y]
    logging = dbopt.train(params, params.num_train_latent_samples, data, model, qz, e_dist)

In [None]:
for ack_iter in range(10):
    preds = model_ensemble(X, resize_at_end=True) # (num_candidate_points, num_samples)
    preds = preds.tranpose(0, 1)
    
    max_pred_idx = preds.argmax(dim=1).view(-1)
    max_pred = preds.max(dim=1).view(-1)