## Adversarial Variational Bayes ##
- https://arxiv.org/pdf/1701.04722.pdf
- Use adversarial training along with a more flexible inference model 
- Reformulates VAE so that the noise is input into the model along with data
- This allows model to learn arbitrary probability distributions in the latent space (unlike imposing a Gaussian for traditional VAEs)
- They also introduce a discriminator which takes pairs of data and latent representation and tries to distinguish actual pairs from the data and generated pairs from the current model
- The new optimization objective aims to maximize (over inference and generative parameters) the log probability of observing the data, $x$, given the learned latent representation, $z$, minus the discriminator, $T$'s, ability to tell that the pair $x, z$ were generated

$$\max_{\theta,\phi} \mathbf{E}_{p_{D}(x)}\mathbf{E}_{\epsilon}\left(-T^*(x, z_\phi(x, \epsilon)) + \text{log} p_\theta(x \mid z_\phi(x, \epsilon) ) \right) $$

- $z_\phi$ is the inference model
- $\epsilon$ is Gaussian noise
- $p_\theta$ is the generative model
- $T^*$ is the optimal discriminator

In [1]:
import tensorflow as tf
from keras import objectives
import numpy as np
import pandas as pd
import seaborn as sns
from tensorflow.python.debug.lib import debug_utils
from sklearn.cross_validation import train_test_split
np.set_printoptions(precision=3)
np.random.seed(18181)  # for reproducibility

%matplotlib inline

import plotly.plotly as py
import plotly.graph_objs as go

import matplotlib.pyplot as plt
sns.set(style="white", color_codes=True)

from tensorflow.python.client import device_lib
device_lib.list_local_devices()

Using TensorFlow backend.


[name: "/cpu:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 4765815740749994049, name: "/gpu:0"
 device_type: "GPU"
 memory_limit: 15927646618
 locality {
   bus_id: 2
 }
 incarnation: 15212392117987625565
 physical_device_desc: "device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:82:00.0"]

In [2]:
# Simple NN implementation in TF

def _build_layer(layer_input, 
                 input_dim, 
                 output_dim, 
                 name=None, 
                 activation=None):
    
    # Create variable named "weights".
    weights = tf.get_variable("weights", 
                              [input_dim, output_dim], 
                              dtype=tf.float32,
                              initializer=tf.random_normal_initializer(stddev=np.sqrt(2./input_dim), 
                                                                       seed=18181))
    
    # Create variable named "biases".
    biases = tf.get_variable("biases", 
                             output_dim, 
                             dtype=tf.float32,
                             initializer=tf.constant_initializer(0.0))
    
    #print weights.name
    x = tf.matmul(layer_input, weights) + biases
    if activation is None:
        return tf.nn.elu(x)
    else:
        return activation(x)
    
def build_deep_net(net_input, 
                   input_dim, 
                   hidden_dim, 
                   output_dim, 
                   n_layers, 
                   name,
                   activation=None,
                   parent_scope=False):
    
    assert(n_layers > 2)
    
    if parent_scope:
        # Input layer
        with tf.variable_scope(name + "_input", reuse=True):
            current_input = _build_layer(net_input, input_dim, hidden_dim)

        # Hidden layers
        for i in range(n_layers - 2):
            with tf.variable_scope(name + "_hidden" + str(i), reuse=True):
                current_input = _build_layer(current_input, hidden_dim, hidden_dim)

        # Output layer
        with tf.variable_scope(name + "_output", reuse=True):
            # We only allow activation specification of the last layer
            output = _build_layer(current_input, hidden_dim, output_dim, name=name, activation=activation)
            
    else:
        # Create a new scope
        with tf.variable_scope(name):

            # Input layer
            with tf.variable_scope(name + "_input"):
                current_input = _build_layer(net_input, input_dim, hidden_dim)

            # Hidden layers
            for i in range(n_layers - 2):
                with tf.variable_scope(name + "_hidden" + str(i)):
                    current_input = _build_layer(current_input, hidden_dim, hidden_dim)

            # Output layer
            with tf.variable_scope(name + "_output"):
                output = _build_layer(current_input, hidden_dim, output_dim, name=name, activation=activation)
    
    return output

def build_layer(layer_input, input_dim, output_dim, name, activation=None):
    # Single layer net
    with tf.variable_scope(name):
        return _build_layer(layer_input, input_dim, output_dim, activation)

In [3]:
# Generate test data 
n_pts = 15000

def genSpiral(frac, noiseStd = 0.01, spins = 2):
    x = frac*np.cos(2*np.pi*frac*spins) + np.random.normal(0, noiseStd)
    y = frac*np.sin(2*np.pi*frac*spins) + np.random.normal(0, noiseStd)
    return([x, y])

# Generate three point clouds
x_train = np.random.multivariate_normal((0,0), np.diag([1,1]), n_pts/3)
x_train = np.concatenate([x_train, np.random.multivariate_normal((-10,-20), np.diag([1,1]), n_pts/3)])
x_train = np.concatenate([x_train, np.random.multivariate_normal((30,20), np.diag([1,1]), n_pts/3)])
np.random.shuffle(x_train)
x_train = x_train.astype(np.float32)

In [4]:
# Test data plot
trace = go.Scatter(
    x = x_train[:500,0],
    y = x_train[:500,1],
    mode = 'markers',
    marker=dict(
        size='4'#,
        #color = x_train[indices, markers['CD3']], 
        #colorscale='Viridis',
        #showscale=True
    )
)

plt_data = [trace]

layout = go.Layout(
    xaxis=dict(
       range=[-40, 40]
    ),
    yaxis=dict(
        range=[-40,40]
   ),
   height=800,
   width=800,
)

fig = go.Figure(data=plt_data, layout=layout)
py.iplot(fig, filename='avb')

In [76]:
tf.reset_default_graph()

original_dim = 2
latent_dim = 1
layers = 6 # >= 3
batch_size = 200
hidden_dim = 256
epsilon_std = 1.
clipNorm = 0.05
M = 30 # number of noise vectors
lr = 10**-6 # Learning Rate
epochs = 5000

x = tf.placeholder(tf.float32, [batch_size, original_dim], name='x')
z = tf.random_normal([batch_size, latent_dim], name="z")

## Adaptive Contrast ##
- Using "Adaptive Contrast" from the paper
    - Allows us to compare the current inference model to an adaptive distribution instead of the prior $p(z)$
    - Estimate moments of inference model
    - Generate $m$ noise vectors ($\epsilon$)
    - Each noise vector is passed through a small NN producing a vector with same dimensionality as latent space ... $v_{i,k}(\epsilon)$
    - Another set of NNs are used to estimate $m$ coefficients from the input $x$ .... $a_i(x)$
    - The latent space representation is a linear combination of the learned noise vectors using these coefficients
    
    $$z_k = \sum_{i=1}^m v_{i,k}(\epsilon_i )a_{i,k}(x)$$

In [77]:
# Adaptive distribution r_alpha(z | x)
# Generate noise vector nets
noise_basis = []

for i in range(M):
    noise_basis.append(build_deep_net(   tf.random_normal([batch_size, latent_dim], 
                                                          stddev=epsilon_std, 
                                                          dtype=tf.float32), 
                                         latent_dim, 
                                         128, 
                                         latent_dim, 
                                         16, 
                                         name="alpha_noise_basis_{}".format(i)))
    
v_basis = tf.stack(noise_basis, axis=1, name="v_basis")  
a_phis = tf.expand_dims(build_deep_net( x, 
                                        original_dim, 
                                        hidden_dim, 
                                        M, 
                                        layers, 
                                        name="alpha_a"), axis=1)

# Moments estimated from minibatch - for each noise basis vector
alpha_mini_mu = tf.reduce_mean(v_basis, axis=0, name="alpha_mini_mu")
alpha_mini_sigma = tf.reduce_mean( (v_basis - alpha_mini_mu)**2 , axis=0, name="alpha_mini_sigma")

# Estimate moments of current inference model for x
alpha_mu_x = tf.reduce_sum(alpha_mini_mu * tf.transpose(a_phis, perm=[0,2,1]), axis=1, name="alpha_mu_x")
alpha_sigma_x = tf.reduce_sum(alpha_mini_sigma * tf.transpose(a_phis, perm=[0,2,1])**2, axis=1, name="alpha_sigma_x")

## Encoder/Inference ##
- Creates a latent representation of data

$$q_\phi\left( z \mid x\right)$$

- Using the reparameratization trick from Kingma and Welling...

$$z_\phi(x, \epsilon)$$

In [78]:
# Inference model - q_phi(z | x)
# Linear combination of the noise basis vectors from above
z_phi = tf.squeeze(tf.matmul(a_phis, v_basis), axis=1, name="z_phi") 

# Z-score transform it
z_hat = tf.identity((z_phi - alpha_mu_x) / alpha_sigma_x, name="z_hat")

## Decoder/Generative ##
- Reconstruct original from latent space representation
$$p_\theta(x \mid z) $$

In [79]:
# Decoder - P(x | z) - the output of this should be the original x
p_theta = build_deep_net(z_hat, 
                         latent_dim, 
                         hidden_dim, 
                         original_dim, 
                         layers,
                         activation=tf.identity,
                         name="p_theta")

## Discriminator ##
- Tries to distinguish generated from real pairs of $x$ and $z$
- $\text{z_hat}$ is the Z-score transformed version of $z$ generated by the inference model
- That way, the discriminator just compares this to the multivariate Gaussian

In [80]:
# Discriminator - Same Network - Distinguish generated pairs (x,z) from real
T_psi_generated = build_deep_net(tf.concat([x, z_hat], axis=1), 
                                           original_dim + latent_dim, 
                                           hidden_dim, 
                                           1, 
                                           layers, 
                                           name="T_psi",
                                           parent_scope=False)

with tf.variable_scope("T_psi"): # "real" input, but using same network as above - using adaptive contrast
    T_psi_real = build_deep_net(tf.concat([x, z], axis=1), 
                                           original_dim + latent_dim, 
                                           hidden_dim, 
                                           1, 
                                           layers, 
                                           name="T_psi",
                                           parent_scope=True)

## Losses and Optimizer ##
- These are negated from the paper so we can apply gradient descent
- We use MSE of receonstruction for $p(x\mid z)$

In [81]:
loss_theta = objectives.mean_squared_error(x, p_theta) # P(x|z) should be high (so minimize recon error)
loss_phi = T_psi_generated + loss_theta # Phi is the inference network. We want to minimize the 
                                        # reconstruction error and the discriminator's ability to distinguish
                                        # real from generated samples

# This is original since it's subtracted in paper
# Added the squared value loss to regularize and to prevent saturation of either
loss_psi =  tf.log(tf.sigmoid(T_psi_generated) + 10**-5) + \
            tf.log(1 - tf.sigmoid(T_psi_real) + 10**-5) 
            #0.1*tf.square(T_psi_generated) + \
            #0.1*tf.square(T_psi_real)

In [82]:
def compute_apply_grad(loss, keyword=None, learn_rate=0.001, clip_norm=0.01):
    """ Computes the gradients of variables matching keyword. Returns an optimizer"""
    opt = tf.train.AdamOptimizer(learning_rate=learn_rate)
    if keyword is None:
        cur_vars = [v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]
    else:
        cur_vars = [v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if keyword in v.name]
    g = opt.compute_gradients(loss, cur_vars)
    g = [(tf.clip_by_norm(gr, clip_norm), v) for gr, v in g]
    optimize = opt.apply_gradients(g)
    return optimize

In [None]:
optimizer_all = compute_apply_grad(loss_theta + loss_phi + loss_psi, learn_rate=lr, clip_norm=clipNorm)

## Training ##

In [83]:
avg_theta = tf.reduce_mean(loss_theta) # Over batch
avg_phi = tf.reduce_mean(loss_phi) 
avg_psi = tf.reduce_mean(loss_psi)

In [84]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [None]:
ctr = 0
for e in range(epochs):
    for i in range(0, x_train.shape[0], batch_size):
        ex = x_train[i:i+batch_size,:]
        
        _, at, ap, aps = sess.run([optimizer_all,
                                   avg_theta,
                                   avg_phi,
                                   avg_psi],
                                  feed_dict={x: ex})            
        ctr += 1
    if e % 20 == 0:    
        statusStr = "theta(generative)={:.2f}, phi(inference)={:.2f}, psi(discrimination)={:.2f}"
        print "Epoch:", '%04d' % (e+1), statusStr.format(at, ap, aps)

Epoch: 0001 theta(generative)=201.10, phi(inference)=200.10, psi(discrimination)=-1.63
Epoch: 0021 theta(generative)=171.37, phi(inference)=170.40, psi(discrimination)=-2.08


In [75]:
# Predict on some training data
x_sample = x_train[np.random.choice(x_train.shape[0], batch_size, replace=False), :]
pred_z, pred_x = sess.run([z_phi, p_theta], feed_dict={x: x_sample})

# Create a trace
pred = go.Scatter(
    x = pred_x[:,0],
    y = pred_x[:,1],
    mode = 'markers',
    marker=dict(
        size='4'
    )
)

orig = go.Scatter(
    x = x_sample[:,0],
    y = x_sample[:,1],
    mode = 'markers',
    marker=dict(
        size='4'
    )
)

z_plt = go.Scatter(
    x = pred_z,
    y = np.repeat(0, pred_z.shape[0]),
    mode = 'markers',
    marker=dict(
        size='4'
    )
)

plt_data = [pred, orig, z_plt]

layout = go.Layout(
   height=800,
   width=800,
   xaxis=dict(
       range=[-40, 40]
    ),
    yaxis=dict(
        range=[-40,40]
   )
)

fig = go.Figure(data=plt_data, layout=layout)
py.iplot(fig, filename='avb')