In [1]:
import warnings
warnings.filterwarnings("ignore")

import sys
import itertools
from keras.layers import Input, Dense, Reshape, Flatten
from keras import layers, initializers
from keras.models import Model, load_model
import keras.backend as K
import numpy as np
from seqtools import SequenceTools as ST
from gfp_gp import SequenceGP
from util import AA, AA_IDX
from util import build_vae
from sklearn.model_selection import train_test_split, ShuffleSplit
from keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt
import pandas as pd
from gan import WGAN
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
import scipy.stats
from scipy.stats import norm
from scipy.optimize import minimize
from keras.utils.generic_utils import get_custom_objects
from util import one_hot_encode_aa, partition_data, get_balaji_predictions, get_samples
from util import convert_idx_array_to_aas, build_pred_vae_model, get_experimental_X_y
from util import get_gfp_X_y_aa
from losses import neg_log_likelihood
import json
# Notebook-wide matplotlib setting: render figures at 300 DPI.
plt.rcParams['figure.dpi'] = 300
class color:
    """ANSI terminal escape sequences used to decorate printed output.

    Wrap text as ``color.BOLD + text + color.END``; ``END`` resets all
    attributes that an earlier sequence switched on.
    """

    # Text attributes and the reset sequence.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    END = '\033[0m'

    # Foreground colours.
    DARKCYAN = '\033[36m'
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    PURPLE = '\033[95m'
    CYAN = '\033[96m'

Using TensorFlow backend.


In [2]:
def build_model(M):
    """Build an uncompiled feed-forward network for inputs of shape (M, 20).

    The input is flattened, passed through a 50-unit ELU dense layer and then
    projected to 2 linear outputs.

    Args:
        M: size of the first (non-batch) input axis.

    Returns:
        A Keras ``Model`` mapping the (M, 20) input to the 2-unit output.
    """
    seq_input = Input(shape=(M, 20))
    flattened = Flatten()(seq_input)
    hidden = Dense(50, activation='elu')(flattened)
    head = Dense(2)(hidden)
    return Model(inputs=seq_input, outputs=head)

def evaluate_ground_truth(X_aa, ground_truth, save_file=None):
    """Evaluate the ground-truth model on ``X_aa`` and optionally save the result.

    Args:
        X_aa: input array handed unchanged to ``ground_truth.predict``.
        ground_truth: object exposing ``predict(X, print_every=...)`` that
            returns a 2-D array; only column 0 is kept.
        save_file: if not None, destination passed to ``np.save`` for the
            evaluated values.

    Returns:
        The 1-D array of ground-truth values (column 0 of the prediction).
        Previously nothing was returned, so the computed values were lost
        unless ``save_file`` was given.
    """
    y_gt = ground_truth.predict(X_aa, print_every=100000)[:, 0]
    if save_file is not None:
        np.save(save_file, y_gt)
    return y_gt
        
def train_and_save_oracles(X_train, y_train, n=10, suffix='', batch_size=100):
    """Train ``n`` independent oracle networks and save each one to disk.

    Every oracle is built with ``build_model``, compiled with Adam and the
    ``neg_log_likelihood`` loss, and fit for up to 100 epochs with early
    stopping on the loss of a 10% validation split (patience 5).

    Args:
        X_train: training inputs; ``X_train.shape[1]`` is passed to ``build_model``.
        y_train: training targets.
        n: number of oracles to train.
        suffix: string appended to each saved file name.
        batch_size: mini-batch size used for fitting.

    Side effects:
        Writes ``models/oracle_<i><suffix>.h5`` for ``i`` in ``range(n)``.
    """
    for oracle_idx in range(n):
        oracle = build_model(X_train.shape[1])
        oracle.compile(optimizer='adam', loss=neg_log_likelihood)

        # A fresh callback per oracle so that patience counters are not shared.
        stopper = EarlyStopping(monitor='val_loss',
                                min_delta=0,
                                patience=5,
                                verbose=1)

        oracle.fit(X_train, y_train,
                   epochs=100,
                   batch_size=batch_size,
                   validation_split=0.1,
                   callbacks=[stopper],
                   verbose=2)

        save_path = "models/oracle_%i%s.h5" % (oracle_idx, suffix)
        oracle.save(save_path)

In [3]:
def weighted_ml_opt(X_train, oracles, ground_truth, vae_0, weights_type='dbas',
        LD=20, iters=20, samples=500, homoscedastic=False, homo_y_var=0.1,
        quantile=0.95, verbose=False, alpha=1, train_gt_evals=None,
        cutoff=1e-6, it_epochs=10, enc1_units=50):
    """Iteratively refit a VAE on its own samples using oracle-derived weights.

    Iteration 0 evaluates ``X_train`` and copies the weights of ``vae_0`` into
    a freshly built VAE. Every later iteration draws ``samples`` latent
    vectors, decodes and samples sequences, scores them with the oracles,
    computes per-sample weights according to ``weights_type`` and refits the
    VAE on the samples whose weight is at least ``cutoff``.

    Args:
        X_train: initial sequences; ``X_train.shape[1]`` is used as the
            sequence length and ``argmax`` over the last axis yields the
            index form passed to ``ground_truth``.
        oracles: passed to ``get_balaji_predictions`` to obtain predictive
            means and variances.
        ground_truth: object with ``predict(X_aa, print_every=...)``; column 0
            of its result is used.
        vae_0: initial VAE exposing ``encoder_``, ``decoder_`` and ``vae_``;
            its weights seed the working VAE and its decoder is used for the
            'cbas' density ratio.
        weights_type: one of 'cbas', 'dbas', 'rwr', 'cem-pi'.
        LD: latent dimensionality.
        iters: number of iterations (including iteration 0).
        samples: number of sequences sampled per iteration.
        homoscedastic: if True, replace the oracle variance with ``homo_y_var``.
        homo_y_var: constant variance used when ``homoscedastic`` is True.
        quantile: quantile (in [0, 1]) used by 'cbas', 'dbas' and 'cem-pi'.
        verbose: print one summary line per iteration.
        alpha: scale inside the exponential for 'rwr' weights.
        train_gt_evals: optional precomputed ground-truth values for
            ``X_train``; used at iteration 0 instead of calling ``ground_truth``.
        cutoff: samples with weight below this value are dropped before refitting.
        it_epochs: number of epochs for each refit.
        enc1_units: forwarded to ``build_vae``.

    Returns:
        traj: array (iters, 7) with columns max/mean/std of ground truth,
            max/mean/std of oracle mean, and mean oracle variance.
        oracle_samples: array (iters, samples) of oracle means.
        gt_samples: array (iters, samples) of ground-truth values.
        max_dict: dict with 'oracle_max', 'oracle_max_seq' and
            'gt_of_oracle_max' for the best oracle value seen.
    """
    # NOTE(review): 'fbvae' passes this check but has no weighting branch in
    # this function; that scheme is implemented by fb_opt.
    assert weights_type in ['cbas', 'dbas','rwr', 'cem-pi', 'fbvae']
    L = X_train.shape[1]
    vae = build_vae(latent_dim=LD,
                    n_tokens=20, seq_length=L,
                    enc1_units=enc1_units)

    traj = np.zeros((iters, 7))
    oracle_samples = np.zeros((iters, samples))
    gt_samples = np.zeros((iters, samples))
    oracle_max_seq = None
    oracle_max = -np.inf
    gt_of_oracle_max = -np.inf
    y_star = -np.inf

    for t in range(iters):
        ### Take Samples ###
        # Drawn on every iteration (also t == 0) so the random stream is unchanged.
        zt = np.random.randn(samples, LD)
        if t > 0:
            Xt_p = vae.decoder_.predict(zt)
            Xt = get_samples(Xt_p)
        else:
            Xt = X_train

        ### Evaluate ground truth and oracle ###
        yt, yt_var = get_balaji_predictions(oracles, Xt)
        if homoscedastic:
            yt_var = np.ones_like(yt) * homo_y_var
        Xt_aa = np.argmax(Xt, axis=-1)
        if t == 0 and train_gt_evals is not None:
            yt_gt = train_gt_evals
        else:
            yt_gt = ground_truth.predict(Xt_aa, print_every=1000000)[:, 0]

        ### Calculate weights for different schemes ###
        if t > 0:
            if weights_type == 'cbas':
                # Density ratio between the initial and the current decoder
                # for the sampled sequences, given the same latent draws.
                log_pxt = np.sum(np.log(Xt_p) * Xt, axis=(1, 2))
                X0_p = vae_0.decoder_.predict(zt)
                log_px0 = np.sum(np.log(X0_p) * Xt, axis=(1, 2))
                w1 = np.exp(log_px0-log_pxt)
                # The threshold only ever increases across iterations.
                y_star_1 = np.percentile(yt, quantile*100)
                if y_star_1 > y_star:
                    y_star = y_star_1
                w2= scipy.stats.norm.sf(y_star, loc=yt, scale=np.sqrt(yt_var))
                weights = w1*w2
            elif weights_type == 'cem-pi':
                # Keep samples whose probability of beating the best training
                # ground-truth value lies above the requested quantile.
                pi = scipy.stats.norm.sf(max_train_gt, loc=yt, scale=np.sqrt(yt_var))
                pi_thresh = np.percentile(pi, quantile*100)
                weights = (pi > pi_thresh).astype(int)
            elif weights_type == 'dbas':
                y_star_1 = np.percentile(yt, quantile*100)
                if y_star_1 > y_star:
                    y_star = y_star_1
                weights = scipy.stats.norm.sf(y_star, loc=yt, scale=np.sqrt(yt_var))
            elif weights_type == 'rwr':
                weights = np.exp(alpha*yt)
                weights /= np.sum(weights)
        else:
            weights = np.ones(yt.shape[0])
            max_train_gt = np.max(yt_gt)

        yt_max_idx = np.argmax(yt)
        yt_max = yt[yt_max_idx]
        if yt_max > oracle_max:
            oracle_max = yt_max
            # Slice [idx:idx+1] keeps a 2-D row for the converter. The old
            # slice [idx-1:idx] selected the previous row and was empty for
            # idx == 0, so the stored sequence did not match oracle_max.
            best_row = Xt_aa[yt_max_idx:yt_max_idx+1]
            try:
                oracle_max_seq = convert_idx_array_to_aas(best_row)[0]
            except IndexError:
                print(best_row)
            gt_of_oracle_max = yt_gt[yt_max_idx]

        ### Record and print results ##
        if t == 0:
            # The training set may not have exactly `samples` rows, so draw
            # `samples` entries with replacement.
            rand_idx = np.random.randint(0, len(yt), samples)
            oracle_samples[t, :] = yt[rand_idx]
            gt_samples[t, :] = yt_gt[rand_idx]
        else:
            oracle_samples[t, :] = yt
            gt_samples[t, :] = yt_gt

        traj[t, 0] = np.max(yt_gt)
        traj[t, 1] = np.mean(yt_gt)
        traj[t, 2] = np.std(yt_gt)
        traj[t, 3] = np.max(yt)
        traj[t, 4] = np.mean(yt)
        traj[t, 5] = np.std(yt)
        traj[t, 6] = np.mean(yt_var)

        if verbose:
            print(weights_type.upper(), t, traj[t, 0], color.BOLD + str(traj[t, 1]) + color.END, 
                  traj[t, 2], traj[t, 3], color.BOLD + str(traj[t, 4]) + color.END, traj[t, 5], traj[t, 6])

        ### Train model ###
        if t == 0:
            vae.encoder_.set_weights(vae_0.encoder_.get_weights())
            vae.decoder_.set_weights(vae_0.decoder_.get_weights())
            vae.vae_.set_weights(vae_0.vae_.get_weights())
        else:
            cutoff_idx = np.where(weights < cutoff)
            Xt = np.delete(Xt, cutoff_idx, axis=0)
            yt = np.delete(yt, cutoff_idx, axis=0)
            weights = np.delete(weights, cutoff_idx, axis=0)
            # Nothing to fit on if every sample fell below the cutoff.
            if Xt.shape[0] > 0:
                vae.fit([Xt], [Xt, np.zeros(Xt.shape[0])],
                      epochs=it_epochs,
                      batch_size=10,
                      shuffle=False,
                      sample_weight=[weights, weights],
                      verbose=0)

    max_dict = {'oracle_max' : oracle_max, 
                'oracle_max_seq': oracle_max_seq, 
                'gt_of_oracle_max': gt_of_oracle_max}
    return traj, oracle_samples, gt_samples, max_dict

In [4]:
def fb_opt(X_train, oracles, ground_truth, vae_0, weights_type='fbvae',
        LD=20, iters=20, samples=500, 
        quantile=0.8, verbose=False, train_gt_evals=None,
        it_epochs=10, enc1_units=50):
    """Threshold-based optimisation loop that refits a VAE on a rolling pool.

    Iteration 0 scores ``X_train`` with the oracles, fixes the threshold
    ``fb_thresh`` at the ``quantile`` percentile of those scores and copies
    the weights of ``vae_0`` into a freshly built VAE. Each later iteration
    samples ``samples`` sequences from the VAE, prepends those whose oracle
    mean is at least ``fb_thresh`` to the pool, drops the same number of rows
    from the end of the pool, and refits the VAE on the pool for one epoch.

    Args:
        X_train: initial pool of sequences; ``X_train.shape[1]`` is used as
            the sequence length.
        oracles: passed to ``get_balaji_predictions``; only the means are used.
        ground_truth: object with ``predict(X_aa, print_every=...)``; column 0
            of its result is used.
        vae_0: initial VAE exposing ``encoder_``, ``decoder_`` and ``vae_``.
        weights_type: must be 'fbvae'; only used for the assertion and the
            verbose prefix.
        LD: latent dimensionality.
        iters: number of iterations (including iteration 0).
        samples: number of sequences sampled per iteration.
        quantile: quantile (in [0, 1]) of the initial oracle scores used as
            the fixed threshold.
        verbose: print one summary line per iteration.
        train_gt_evals: optional precomputed ground-truth values for ``X_train``.
        it_epochs: accepted for signature parity; NOTE(review): not used, the
            refit below always runs with ``epochs=1``.
        enc1_units: forwarded to ``build_vae``.

    Returns:
        traj: array (iters, 7) with columns max/mean/std of ground truth over
            the pool, max/mean/std of oracle mean over the pool, and the
            number of new samples above the threshold (0 at iteration 0).
        oracle_samples: array (iters, samples) of oracle means drawn with
            replacement from the pool.
        gt_samples: array (iters, samples) of the matching ground-truth values.
        max_dict: dict with 'oracle_max', 'oracle_max_seq' and
            'gt_of_oracle_max' for the best oracle value seen.
    """
    assert weights_type in ['fbvae']
    L = X_train.shape[1]
    vae = build_vae(latent_dim=LD,
                    n_tokens=20, seq_length=L,
                    enc1_units=enc1_units)

    traj = np.zeros((iters, 7))
    oracle_samples = np.zeros((iters, samples))
    gt_samples = np.zeros((iters, samples))
    oracle_max_seq = None
    oracle_max = -np.inf
    gt_of_oracle_max = -np.inf
    for t in range(iters):
        ### Take Samples and evaluate ground truth and oracle ##
        # Drawn on every iteration (also t == 0) so the random stream is unchanged.
        zt = np.random.randn(samples, LD)
        if t > 0:
            Xt_sample_p = vae.decoder_.predict(zt)
            Xt_sample = get_samples(Xt_sample_p)
            yt_sample, _ = get_balaji_predictions(oracles, Xt_sample)
            Xt_aa_sample = np.argmax(Xt_sample, axis=-1)
            yt_gt_sample = ground_truth.predict(Xt_aa_sample, print_every=1000000)[:, 0]
        else:
            Xt = X_train
            yt, _ = get_balaji_predictions(oracles, Xt)
            Xt_aa = np.argmax(Xt, axis=-1)
            # Threshold is fixed once, from the initial pool.
            fb_thresh = np.percentile(yt, quantile*100)
            if train_gt_evals is not None:
                yt_gt = train_gt_evals
            else:
                yt_gt = ground_truth.predict(Xt_aa, print_every=1000000)[:, 0]

        ### Calculate threshold ###
        if t > 0:
            threshold_idx = np.where(yt_sample >= fb_thresh)[0]
            n_top = len(threshold_idx)
            sample_arrs = [Xt_sample, yt_sample, yt_gt_sample, Xt_aa_sample]
            full_arrs = [Xt, yt, yt_gt, Xt_aa]

            for l in range(len(full_arrs)):
                full_arr = full_arrs[l]
                sample_top = sample_arrs[l][threshold_idx]
                # Prepend the accepted samples and trim the tail so the pool
                # keeps its size (same result as deleting the last n_top rows).
                full_arrs[l] = np.concatenate([sample_top, full_arr])[:full_arr.shape[0]]
            Xt, yt, yt_gt, Xt_aa = full_arrs
        yt_max_idx = np.argmax(yt)
        yt_max = yt[yt_max_idx]
        if yt_max > oracle_max:
            oracle_max = yt_max
            # Slice [idx:idx+1] keeps a 2-D row for the converter. The old
            # slice [idx-1:idx] selected the previous row and was empty for
            # idx == 0, so the stored sequence did not match oracle_max.
            best_row = Xt_aa[yt_max_idx:yt_max_idx+1]
            try:
                oracle_max_seq = convert_idx_array_to_aas(best_row)[0]
            except IndexError:
                print(best_row)
            gt_of_oracle_max = yt_gt[yt_max_idx]

        ### Record and print results ##

        # The pool size need not equal `samples`, so draw with replacement.
        rand_idx = np.random.randint(0, len(yt), samples)
        oracle_samples[t, :] = yt[rand_idx]
        gt_samples[t, :] = yt_gt[rand_idx]

        traj[t, 0] = np.max(yt_gt)
        traj[t, 1] = np.mean(yt_gt)
        traj[t, 2] = np.std(yt_gt)
        traj[t, 3] = np.max(yt)
        traj[t, 4] = np.mean(yt)
        traj[t, 5] = np.std(yt)
        if t > 0:
            traj[t, 6] = n_top
        else:
            traj[t, 6] = 0

        if verbose:
            print(weights_type.upper(), t, traj[t, 0], color.BOLD + str(traj[t, 1]) + color.END, 
                  traj[t, 2], traj[t, 3], color.BOLD + str(traj[t, 4]) + color.END, traj[t, 5], traj[t, 6])

        ### Train model ###
        if t == 0:
            vae.encoder_.set_weights(vae_0.encoder_.get_weights())
            vae.decoder_.set_weights(vae_0.decoder_.get_weights())
            vae.vae_.set_weights(vae_0.vae_.get_weights())
        else:
            vae.fit([Xt], [Xt, np.zeros(Xt.shape[0])],
                  epochs=1,
                  batch_size=10,
                  shuffle=False,
                  verbose=0)

    max_dict = {'oracle_max' : oracle_max, 
                'oracle_max_seq': oracle_max_seq, 
                'gt_of_oracle_max': gt_of_oracle_max}
    return traj, oracle_samples, gt_samples, max_dict


In [5]:
def train_experimental_oracles():
    """Train and save the oracle ensembles for the three experimental settings.

    Setting ``i`` (0, 1, 2) uses an ensemble of 1, 5 or 20 oracles
    respectively, with ``random_state = i + 1`` and 5000 training points.
    Files are written by ``train_and_save_oracles`` with the suffix
    ``_5k_<num_models>_<random_state>``.
    """
    TRAIN_SIZE = 5000
    train_size_str = "%ik" % (TRAIN_SIZE/1000)
    num_models = [1, 5, 20]
    for i, nm in enumerate(num_models):
        RANDOM_STATE = i+1
        X_train, y_train, _  = get_experimental_X_y(random_state=RANDOM_STATE, train_size=TRAIN_SIZE)
        suffix = '_%s_%i_%i' % (train_size_str, nm, RANDOM_STATE)
        train_and_save_oracles(X_train, y_train, batch_size=10, n=nm, suffix=suffix)

In [6]:
def train_experimental_vaes(its=(0, 2)):
    """Train and save the initial VAE for the selected experimental settings.

    For each setting index ``i`` in ``its`` the data split with
    ``random_state = i + 1`` is loaded, a VAE is fit on it for 100 epochs
    (batch size 10), and the encoder, decoder and full-VAE weights are written
    to ``models/vae_0_*_weights_5k_<random_state>.h5``.

    Args:
        its: iterable of setting indices to train. The default ``(0, 2)``
            keeps the original behaviour, which does not produce the files
            for setting 1; ``run_experimental_weighted_ml(1)`` loads the
            ``..._2.h5`` weights, so pass ``its=(0, 1, 2)`` (or ``(1,)``) to
            generate them.
    """
    TRAIN_SIZE = 5000
    train_size_str = "%ik" % (TRAIN_SIZE/1000)
    suffix = '_%s' % train_size_str
    for i in its:
        RANDOM_STATE = i + 1
        X_train, _, _  = get_experimental_X_y(random_state=RANDOM_STATE, train_size=TRAIN_SIZE)
        vae_0 = build_vae(latent_dim=20,
                  n_tokens=20, 
                  seq_length=X_train.shape[1],
                  enc1_units=50)
        # Second target is an all-zeros vector, one entry per training row.
        vae_0.fit([X_train], [X_train, np.zeros(X_train.shape[0])],
                  epochs=100,
                  batch_size=10,
                  verbose=2)
        vae_0.encoder_.save_weights("models/vae_0_encoder_weights%s_%i.h5"% (suffix, RANDOM_STATE))
        vae_0.decoder_.save_weights("models/vae_0_decoder_weights%s_%i.h5"% (suffix, RANDOM_STATE))
        vae_0.vae_.save_weights("models/vae_0_vae_weights%s_%i.h5"% (suffix, RANDOM_STATE))

In [11]:
def run_experimental_weighted_ml(it, repeats=3):
    """Run every optimisation scheme on one experimental setting and save results.

    Loads the pre-trained initial VAE and oracle ensemble for setting ``it``,
    then for each repeat runs 'cbas', 'rwr', 'dbas', 'cem-pi' (through
    ``weighted_ml_opt``) and 'fbvae' (through ``fb_opt``).

    Args:
        it: setting index in {0, 1, 2}; selects an ensemble of 1, 5 or 20
            oracles and ``random_state = it + 1``.
        repeats: number of independent repetitions per scheme.

    Side effects:
        Reads ``models/vae_0_*_weights_5k_<rs>.h5``,
        ``models/oracle_<i>_5k_<n>_<rs>.h5`` and the ground-truth model under
        ``data/gfp_gp``. Writes ``results/<scheme>_{traj,oracle_samples,
        gt_samples}<suffix>.npy`` and ``results/<scheme>_max<suffix>.json``.
        Prints the run suffix before each run.
    """
    assert it in [0, 1, 2]

    TRAIN_SIZE = 5000
    train_size_str = "%ik" % (TRAIN_SIZE/1000)
    num_models = [1, 5, 20][it]
    RANDOM_STATE = it + 1

    X_train, y_train, gt_train  = get_experimental_X_y(random_state=RANDOM_STATE, train_size=TRAIN_SIZE)

    vae_suffix = '_%s_%i' % (train_size_str, RANDOM_STATE)
    oracle_suffix = '_%s_%i_%i' % (train_size_str, num_models, RANDOM_STATE)

    vae_0 = build_vae(latent_dim=20,
                  n_tokens=20, 
                  seq_length=X_train.shape[1],
                  enc1_units=50)

    vae_0.encoder_.load_weights("models/vae_0_encoder_weights%s.h5" % vae_suffix)
    vae_0.decoder_.load_weights("models/vae_0_decoder_weights%s.h5"% vae_suffix)
    vae_0.vae_.load_weights("models/vae_0_vae_weights%s.h5"% vae_suffix)

    ground_truth = SequenceGP(load=True, load_prefix="data/gfp_gp")

    # The custom loss must be registered before load_model can deserialise the oracles.
    loss = neg_log_likelihood
    get_custom_objects().update({"neg_log_likelihood": loss})
    oracles = [load_model("models/oracle_%i%s.h5" % (i, oracle_suffix)) for i in range(num_models)]

    test_kwargs = [
                   {'weights_type':'cbas', 'quantile': 1},
                   {'weights_type':'rwr', 'alpha': 20},
                   {'weights_type':'dbas', 'quantile': 0.95},
                   {'weights_type':'cem-pi', 'quantile': 0.8},
                   {'weights_type': 'fbvae', 'quantile': 0.8}
    ]

    base_kwargs = {
        'homoscedastic': False,
        'homo_y_var': 0.01,
        'train_gt_evals':gt_train,
        'samples':100,
        'cutoff':1e-6,
        'it_epochs':10,
        'verbose':True,
        'LD': 20,
        'enc1_units':50,
        'iters': 50
    }

    if num_models==1:
        # With a single oracle, use a constant variance equal to its mean
        # squared error on the training set.
        base_kwargs['homoscedastic'] = True
        base_kwargs['homo_y_var'] = np.mean((get_balaji_predictions(oracles, X_train)[0] - y_train)**2)

    # Entries of base_kwargs that are removed before calling fb_opt.
    fb_excluded_keys = ('homoscedastic', 'homo_y_var', 'cutoff', 'it_epochs')

    for k in range(repeats):
        for scheme_kwargs in test_kwargs:
            test_name = scheme_kwargs['weights_type']
            suffix = "_%s_%i_%i" % (train_size_str, RANDOM_STATE, k)
            if base_kwargs['iters'] > 100:
                suffix += '_long'
            print(suffix)

            # base_kwargs is applied last, so it wins on any shared key.
            kwargs = {}
            kwargs.update(scheme_kwargs)
            kwargs.update(base_kwargs)

            if test_name == 'fbvae':
                # Plain loop instead of a side-effect comprehension that
                # reused the name of the outer loop variable `k`.
                for key in fb_excluded_keys:
                    kwargs.pop(key)
                run_opt = fb_opt
            else:
                run_opt = weighted_ml_opt

            test_traj, test_oracle_samples, test_gt_samples, test_max = run_opt(np.copy(X_train), oracles, ground_truth, vae_0, **kwargs)

            np.save('results/%s_traj%s.npy' %(test_name, suffix), test_traj)
            np.save('results/%s_oracle_samples%s.npy' % (test_name, suffix), test_oracle_samples)
            np.save('results/%s_gt_samples%s.npy'%(test_name, suffix), test_gt_samples )

            with open('results/%s_max%s.json'% (test_name, suffix), 'w') as outfile:
                json.dump(test_max, outfile)
        

In [12]:
def test_cbas_q(qs = [0.5, 0.75, 0.95, 1]):
    """Sweep the 'cbas' quantile on experimental setting 0 and save each run.

    Loads the initial VAE and the oracle ensemble for setting 0
    (``random_state = 1``), then calls ``weighted_ml_opt`` with
    ``weights_type='cbas'`` once per value in ``qs``, for 100 iterations of
    100 samples.

    Args:
        qs: quantile values to test.

    Side effects:
        Prints the run suffix and writes
        ``results/cbas_{traj,oracle_samples,gt_samples}_qtest_5k_1_<q>.npy``
        and ``results/cbas_max_qtest_5k_1_<q>.json`` for every ``q``.
    """
    it = 0
    TRAIN_SIZE = 5000
    RANDOM_STATE = it + 1
    num_models = [1, 5, 20][it]
    train_size_str = "%ik" % (TRAIN_SIZE/1000)

    X_train, y_train, gt_train = get_experimental_X_y(random_state=RANDOM_STATE,
                                                      train_size=TRAIN_SIZE)

    vae_suffix = '_%s_%i' % (train_size_str, RANDOM_STATE)
    oracle_suffix = '_%s_%i_%i' % (train_size_str, num_models, RANDOM_STATE)

    # Rebuild the initial VAE architecture and restore its saved weights.
    vae_0 = build_vae(latent_dim=20, n_tokens=20,
                      seq_length=X_train.shape[1], enc1_units=50)
    vae_0.encoder_.load_weights("models/vae_0_encoder_weights%s.h5" % vae_suffix)
    vae_0.decoder_.load_weights("models/vae_0_decoder_weights%s.h5"% vae_suffix)
    vae_0.vae_.load_weights("models/vae_0_vae_weights%s.h5"% vae_suffix)

    ground_truth = SequenceGP(load=True, load_prefix="data/gfp_gp")

    # Register the custom loss before deserialising the saved oracles.
    get_custom_objects().update({"neg_log_likelihood": neg_log_likelihood})
    oracles = []
    for oracle_idx in range(num_models):
        oracles.append(load_model("models/oracle_%i%s.h5" % (oracle_idx, oracle_suffix)))

    test_kwargs = []
    for q in qs:
        test_kwargs.append({'weights_type':'cbas', 'quantile': q})

    base_kwargs = {
        'homoscedastic': False,
        'homo_y_var': 0.01,
        'train_gt_evals':gt_train,
        'samples':100,
        'cutoff':1e-6,
        'it_epochs':10,
        'verbose':True,
        'LD': 20,
        'enc1_units':50,
        'iters':100
    }

    if num_models==1:
        # Single oracle: use its training-set mean squared error as a
        # constant predictive variance.
        train_pred_mean = get_balaji_predictions(oracles, X_train)[0]
        base_kwargs['homoscedastic'] = True
        base_kwargs['homo_y_var'] = np.mean((train_pred_mean - y_train)**2)

    for run_kwargs in test_kwargs:
        test_name = run_kwargs['weights_type']
        suffix = "_qtest_%s_%i_%.2f" % (train_size_str, RANDOM_STATE, run_kwargs['quantile'])
        print(suffix)

        # Shared settings are applied last, so they win on any common key.
        kwargs = dict(run_kwargs)
        kwargs.update(base_kwargs)
        test_traj, test_oracle_samples, test_gt_samples, test_max = weighted_ml_opt(
            np.copy(X_train), oracles, ground_truth, vae_0, **kwargs)

        array_outputs = (
            ('results/%s_traj%s.npy', test_traj),
            ('results/%s_oracle_samples%s.npy', test_oracle_samples),
            ('results/%s_gt_samples%s.npy', test_gt_samples),
        )
        for path_template, values in array_outputs:
            np.save(path_template % (test_name, suffix), values)

        with open('results/%s_max%s.json'% (test_name, suffix), 'w') as outfile:
            json.dump(test_max, outfile)

In [13]:
# One-off training steps; uncomment to (re)train and save the initial VAEs
# and oracle ensembles that the runs below load from models/.
# train_experimental_vaes()
# train_experimental_oracles()

# Run every scheme once on setting 0 (single oracle, random_state 1).
run_experimental_weighted_ml(0, repeats=1)
# Optional quantile sweep for 'cbas' on the same setting.
# test_cbas_q(qs=[0.5, 0.75, 0.95, 1])

_5k_1_0
CBAS 0 3.1522492716011907 [1m3.1114846665730953[0m 0.036913624308314026 3.185194730758667 [1m3.1188620679855346[0m 0.02476560246847297 0.0007276618358598761
CBAS 1 3.343200333472878 [1m3.1372669984822266[0m 0.08542544879038347 3.1836001873016357 [1m3.1255055832862855[0m 0.03757263871629514 0.0007276618358598758
CBAS 2 3.3200958737603283 [1m3.206044893124411[0m 0.06534528851399343 3.1916959285736084 [1m3.149348895549774[0m 0.0232937581283955 0.0007276618358598758
CBAS 3 3.353829579982408 [1m3.2272006237767954[0m 0.06136721276436493 3.1934802532196045 [1m3.1577219104766847[0m 0.018347936844734566 0.0007276618358598758
CBAS 4 3.4121971446578723 [1m3.2309365151930445[0m 0.05999169217094124 3.1904842853546143 [1m3.1590127110481263[0m 0.017228692024066457 0.0007276618358598758
CBAS 5 3.4121971446578723 [1m3.2396768796166953[0m 0.05992245477338031 3.2004764080047607 [1m3.161492450237274[0m 0.018917833816742342 0.0007276618358598758
CBAS 6 3.3283396546772055 [1