# VAE for stick-breaking gaussian vs concrete

In [1]:
import tensorflow as tf
import sys
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
sys.path.append('/Users/Cybele/GIT/birkhoff/birkhoff/')
sys.path.append('/Users/Cybele/GIT/birkhoff/')
import numpy as np

import categorical as cat
%matplotlib inline
slim=tf.contrib.slim
Bernoulli = tf.contrib.distributions.Bernoulli
from scipy.special import gammaln

In [2]:
#Create gaussian-stick-breaking-sigmoid related functions(more generally, any stick-breaking reparametrizable one)
def psi_to_pi(psi):
    #return tf.cumsum(tf.log(1-psi), axis=1)
    #print psi.get_shape()[0]
    log1 = tf.concat(1, [ tf.log(psi), tf.zeros((tf.shape(psi)[0],1))])
    log2 = tf.cumsum(tf.concat(1, [ tf.zeros((tf.shape(psi)[0],1)), tf.log(1-psi) ]) ,axis=1)
    return tf.exp(log2+log1)

def sample_pi_from_gaussian(params, temperature, eps=1e-15):
    mu, logit_sigma = tf.split(1, 2, params) 
    #log_s = tf.maximum(tf.minimum(log_std,10),-10)
    sigma_min= 1e-9
    sigma_max= 2
    sigma = sigma_min +(sigma_max-sigma_min)*tf.sigmoid(logit_sigma)
    sample = (mu + sigma * tf.random_normal(tf.shape(mu), mean=0.0, stddev=1.0)) / temperature 
    psi =  tf.maximum(tf.minimum(tf.sigmoid(sample), 1-eps), eps)
    return (sample,psi,psi_to_pi(psi))

def log_density_pi_gaussian(sample, params, temperature, eps=1e-15):
    mu, logit_sigma = tf.split(1, 2, params)
    K = tf.shape(sample)[1]+1
    N = tf.shape(sample)[0]
    sigma_min= 1e-9
    sigma_max= 2
    sigma = sigma_min +(sigma_max-sigma_min)*tf.sigmoid(logit_sigma) / temperature
    
    var =  tf.pow(sigma, 2.0) / tf.pow(temperature, 2.0)
    psi = tf.maximum(tf.minimum(tf.sigmoid(sample), 1-eps), eps)
    seq = tf.cast(tf.tile(tf.reshape(tf.range(K-2,0,-1),[1,-1]),[N,1]),dtype=tf.float32)
    return  -tf.reduce_sum(tf.multiply(seq, tf.log(1 - tf.slice(psi,[0,0],[N,K-2]))), axis=1) + tf.reduce_sum(-0.5 * tf.div(tf.pow(sample-mu / temperature, 2.0), var)  - 0.5 * tf.log(2*np.pi) - 0.5* tf.log(var) , axis=1) \
       + tf.reduce_sum(-tf.log(psi) - tf.log(1-psi), axis=1) 
    #return tf.reduce_sum(tf.div(tf.pow(sample-mu, 2.0), var)   , axis=1)

In [3]:
#Gumbel-related function
def sample_gumbel(shape, eps=1e-20): 
    """Sample from Gumbel(0, 1)"""
    U = tf.random_uniform(shape,minval=0,maxval=1)
    return -tf.log(-tf.log(U + eps) + eps)


def gumbel_softmax_sample(logits, temperature): 
    """ Draw a sample from the Gumbel-Softmax distribution"""
    y = logits + sample_gumbel(tf.shape(logits))
    return tf.nn.softmax( y / temperature)


def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
    logits: [batch_size, n_class] unnormalized log-probs
    temperature: non-negative scalar
    hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
    If hard=True, then the returned sample will be one-hot, otherwise it will
    be a probabilitiy distribution that sums to 1 across classes
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        k = tf.shape(logits)[-1]
        #y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype)
        y_hard = tf.cast(tf.equal(y,tf.reduce_max(y,1,keep_dims=True)),y.dtype)
        y = tf.stop_gradient(y_hard - y) + y
    return y

def log_density_pi_concrete(pi,params,temperature):
    K=tf.cast(tf.shape(pi)[1],dtype=tf.float32)
    return -tf.log(tf.reduce_sum(tf.pow(pi, -temperature) * tf.exp(params), axis=1)) * K  \
    + tf.reduce_sum(params + (-temperature - 1) * tf.log(pi), axis=1) + (K - 1) * tf.log(temperature) + tf.lgamma(K+0.0)
         


In [4]:
K=10 # number of classes
N=30 # number of categorical distributions


In [None]:

# Create gaussian VAE

x = tf.placeholder(tf.float32,[None,784])
net = slim.stack(x,slim.fully_connected,[512,256])
params = tf.reshape(slim.fully_connected(net,2*(K-1)*N,activation_fn=None),[-1,2*(K-1)])

tau = tf.Variable(5.0,name="temperature")
 
sample,psi,pi = sample_pi_from_gaussian(params,tau)
z = tf.reshape(pi,[-1,N,K])

log_dens=log_density_pi_gaussian(sample, params, tau)

net = slim.stack(slim.flatten(z),slim.fully_connected,[256,512])
logits_x = slim.fully_connected(net,784,activation_fn=None)
p_x = Bernoulli(logits=logits_x)



In [None]:
# input image x (shape=(batch_size,784))

#Debug Gaussian VAE: verify the outputs by tensorflow code are the same as in categorical.py
# variational posterior q(y|x), i.e. the encoder (shape=(batch_size,200))
#x = tf.placeholder(tf.float32,[None,784])
#net = slim.stack(x,slim.fully_connected,[512,256])
K=5
N=1000
params = tf.concat(1, [ -1*tf.ones((1000,K-1)), 0*tf.ones((1000,K-1)) ]) 
tau = tf.Variable(1.0,name="temperature")
with tf.Session() as sess:
    sample,psi,pi = sample_pi_from_gaussian(params,tau)
    log_dens=log_density_pi_gaussian(sample, params, tau)
    np_sample=sess.run(sample,feed_dict={tau:1})
    np_log_dens=sess.run(log_dens,feed_dict={tau:1})
    np_pi=sess.run(pi,feed_dict={tau:1})
    np_psi = sess.run(psi,feed_dict={tau:1})
    np_params = sess.run(params,feed_dict={tau:1})
    print np.mean(np_log_dens)
    print np.mean(cat.log_density_gaussian_psi(np_sample, np_params, 0,10000))


In [5]:

# Create gumbell VAE

x = tf.placeholder(tf.float32,[None,784])
net = slim.stack(x,slim.fully_connected,[512,256])
params = tf.reshape(slim.fully_connected(net,K*N,activation_fn=None),[-1,K])

tau = tf.Variable(5.0,name="temperature")
 
pi = gumbel_softmax_sample(params,tau)
z = tf.reshape(pi,[-1,N,K])

log_dens=log_density_pi_concrete(pi, params, tau)

net = slim.stack(slim.flatten(z),slim.fully_connected,[256,512])
logits_x = slim.fully_connected(net,784,activation_fn=None)
p_x = Bernoulli(logits=logits_x)



In [None]:

#Debug Gumbell
K=5
N=100000
tau = tf.Variable(5.0,name="temperature")
params = -1*tf.ones((N,K))
with tf.Session() as sess:
    
    pi = gumbel_softmax_sample(params,tau)
    log_dens=log_density_pi_concrete(pi, params, tau)
    np_sample=sess.run(pi,feed_dict={tau:1})
    np_log_dens=sess.run(log_dens,feed_dict={tau:1})
    np_pi=sess.run(pi,feed_dict={tau:1})
   
    np_params = sess.run(params,feed_dict={tau:1})
    
    print np.mean(np_log_dens)
    print np.mean(cat.log_density_pi_concrete(np_sample, np_params, 1))


In [6]:
# loss and train ops (Run either the gumbel or the gaussian for now)
kl_tmp = tf.reshape((log_dens+tf.lgamma(K+0.0)),[-1,N])
KL = tf.reduce_sum(kl_tmp,1)
elbo=tf.reduce_sum(p_x.log_prob(x),1) - KL


In [7]:
loss=tf.reduce_mean(-elbo)
lr=tf.constant(0.001)
train_op=tf.train.AdamOptimizer(learning_rate=lr).minimize(loss,var_list=slim.get_model_variables())
init_op=tf.initialize_all_variables()

Instructions for updating:
Use `tf.global_variables_initializer` instead.


Instructions for updating:
Use `tf.global_variables_initializer` instead.


In [8]:
# get data
data = input_data.read_data_sets('/tmp/', one_hot=True).train 

Extracting /tmp/train-images-idx3-ubyte.gz
Extracting /tmp/train-labels-idx1-ubyte.gz
Extracting /tmp/t10k-images-idx3-ubyte.gz
Extracting /tmp/t10k-labels-idx1-ubyte.gz


In [9]:
BATCH_SIZE=100
NUM_ITERS=200
tau0=1.0 # initial temperature
np_temp=tau0
np_lr=0.001
ANNEAL_RATE=0.00003
MIN_TEMP=0.5

In [10]:
dat=[]
sess=tf.InteractiveSession()
sess.run(init_op)
for i in range(1,NUM_ITERS):
  np_x,np_y=data.next_batch(BATCH_SIZE)
  _,np_loss=sess.run([train_op,loss],{
      x:np_x,
      tau:np_temp,
      lr:np_lr
    })
  if i % 100 == 1:
    dat.append([i,np_temp,np_loss])
  if i % 1000 == 1:
    np_temp=np.maximum(tau0*np.exp(-ANNEAL_RATE*i),MIN_TEMP)
    np_lr*=0.9
  if i % 10 == 1:
      print('Step %d, ELBO: %0.3f' % (i,-np_loss))

Step 1, ELBO: -1412.705
Step 11, ELBO: -1174.455
Step 21, ELBO: -1084.537
Step 31, ELBO: -1092.381
Step 41, ELBO: -1075.355
Step 51, ELBO: -1073.973
Step 61, ELBO: -1076.603
Step 71, ELBO: -1070.497
Step 81, ELBO: -1088.307
Step 91, ELBO: -1091.449
Step 101, ELBO: -1068.103
Step 111, ELBO: -1071.765
Step 121, ELBO: -1067.802
Step 131, ELBO: -1063.069
Step 141, ELBO: -1062.998
Step 151, ELBO: -1087.623
Step 161, ELBO: -1066.692
Step 171, ELBO: -1083.626
Step 181, ELBO: -1088.090
Step 191, ELBO: -1069.358
