# VAE for stick-breaking gaussian vs concrete

In [1]:
import tensorflow as tf
import sys
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
sys.path.append('/Users/Cybele/GIT/birkhoff/birkhoff/')
sys.path.append('/Users/Cybele/GIT/birkhoff/')
import numpy as np

import categorical as cat
%matplotlib inline
slim=tf.contrib.slim
Bernoulli = tf.contrib.distributions.Bernoulli
Dirichlet = tf.contrib.distributions.Dirichlet
from scipy.special import gammaln

[[ -2.07232658e+01  -2.07232658e+01  -2.07232658e+01 ...,  -1.60425682e+01
   -2.07232658e+01  -6.25831864e+00]
 [ -2.07232658e+01  -2.07232658e+01  -2.07232658e+01 ...,  -1.37222759e-05
   -2.07232658e+01  -1.11964969e+01]
 [ -2.07232658e+01  -2.07232658e+01  -2.07232658e+01 ...,  -2.07232658e+01
   -9.99999972e-10  -2.07232658e+01]
 ..., 
 [ -2.07232658e+01  -2.07232658e+01  -2.07232658e+01 ...,  -2.07232658e+01
   -2.07232658e+01  -1.11826596e+01]
 [ -2.07232658e+01  -2.07232658e+01  -2.07232658e+01 ...,  -2.07232658e+01
   -2.07232658e+01  -2.07232658e+01]
 [ -2.07232658e+01  -2.07232658e+01  -9.43282997e+00 ...,  -2.07232658e+01
   -2.07232658e+01  -2.07232658e+01]]


In [2]:
#Create gaussian-stick-breaking-sigmoid related functions(more generally, any stick-breaking reparametrizable one)
def psi_to_pi(psi):
    #return tf.cumsum(tf.log(1-psi), axis=1)
    #print psi.get_shape()[0]
    log1 = tf.concat(1, [ tf.log(psi), tf.zeros((tf.shape(psi)[0],1))])
    log2 = tf.cumsum(tf.concat(1, [ tf.zeros((tf.shape(psi)[0],1)), tf.log(1-psi) ]) ,axis=1)
    return tf.exp(log2+log1)

def sample_pi_from_gaussian(params, temperature, eps=1e-9):
    logit_mu, logit_sigma = tf.split(1, 2, params) 
    #log_s = tf.maximum(tf.minimum(log_std,10),-10)
    sigma_min= 1e-9
    sigma_max= 5
    mu_min = -5
    mu_max = 5
    mu = mu_min +(mu_max-mu_min)*tf.sigmoid(logit_mu)
    sigma = sigma_min +(sigma_max-sigma_min)*tf.sigmoid(logit_sigma)
    sample = (mu + sigma * tf.random_normal(tf.shape(mu), mean=0.0, stddev=1.0)) / temperature 
    psi =  (eps) +(1-2*eps)*tf.sigmoid(sample)
    return (sample,psi,psi_to_pi(psi))

def log_density_pi_gaussian(sample, params, temperature, eps=1e-9):
    logit_mu, logit_sigma = tf.split(1, 2, params)
    K = tf.shape(sample)[1]+1
    N = tf.shape(sample)[0]
    sigma_min= 1e-9
    sigma_max= 5
    mu_min = -5
    mu_max = 5
    mu = mu_min +(mu_max-mu_min)*tf.sigmoid(logit_mu)

    sigma = sigma_min +(sigma_max-sigma_min)*tf.sigmoid(logit_sigma) / temperature
    
    var =  tf.pow(sigma, 2.0) / tf.pow(temperature, 2.0)
    psi =  (eps) +(1-2*eps) * tf.sigmoid(sample)
    seq = tf.cast(tf.tile(tf.reshape(tf.range(K-2,0,-1),[1,-1]),[N,1]),dtype=tf.float32)
    return  -tf.reduce_sum(tf.multiply(seq, tf.log(1 - tf.slice(psi,[0,0],[N,K-2]))), axis=1) + tf.reduce_sum(-0.5 * tf.div(tf.pow(sample-mu / temperature, 2.0), var)  - 0.5 * tf.log(2*np.pi) - 0.5* tf.log(var) , axis=1) \
       + tf.reduce_sum(-tf.log(psi) - tf.log(1-psi), axis=1) 
    #return tf.reduce_sum(tf.div(tf.pow(sample-mu, 2.0), var)   , axis=1)

In [None]:
#Gumbel-related function
def sample_gumbel(shape, eps=1e-20): 
    """Sample from Gumbel(0, 1)"""
    U = tf.random_uniform(shape,minval=0,maxval=1)
    return -tf.log(-tf.log(U + eps) + eps)


def gumbel_softmax_sample(logits, temperature): 
    """ Draw a sample from the Gumbel-Softmax distribution"""
    y = logits + sample_gumbel(tf.shape(logits))
    return tf.nn.softmax( y / temperature)


def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
    logits: [batch_size, n_class] unnormalized log-probs
    temperature: non-negative scalar
    hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
    If hard=True, then the returned sample will be one-hot, otherwise it will
    be a probabilitiy distribution that sums to 1 across classes
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        k = tf.shape(logits)[-1]
        #y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype)
        y_hard = tf.cast(tf.equal(y,tf.reduce_max(y,1,keep_dims=True)),y.dtype)
        y = tf.stop_gradient(y_hard - y) + y
    return y

def log_density_pi_concrete(pi,params,temperature):
    K=tf.cast(tf.shape(pi)[1],dtype=tf.float32)
    return -tf.log(tf.reduce_sum(tf.pow(pi, -temperature) * tf.exp(params), axis=1)) * K  \
    + tf.reduce_sum(params + (-temperature - 1) * tf.log(pi), axis=1) + (K - 1) * tf.log(temperature) + tf.lgamma(K+0.0)
         


In [3]:
K=5 # number of classes
N=10 # number of categorical distributions


In [4]:
alpha=0.3
# Create gaussian VAE

x = tf.placeholder(tf.float32,[None,784])
net = slim.stack(x,slim.fully_connected,[512,256])
params = tf.reshape(slim.fully_connected(net,2*(K-1)*N,activation_fn=None),[-1,2*(K-1)])

tau = tf.Variable(5.0,name="temperature")
 
sample,psi,pi = sample_pi_from_gaussian(params,tau)
z = tf.reshape(pi,[-1,N,K])

log_dens=log_density_pi_gaussian(sample, params, tau)

net = slim.stack(slim.flatten(z),slim.fully_connected,[256,512])
logits_x = slim.fully_connected(net,784,activation_fn=None)
p_x = Bernoulli(logits=logits_x)
p_z = Dirichlet(tf.ones(K)*alpha)


In [None]:
# input image x (shape=(batch_size,784))

#Debug Gaussian VAE: verify the outputs by tensorflow code are the same as in categorical.py
# variational posterior q(y|x), i.e. the encoder (shape=(batch_size,200))
#x = tf.placeholder(tf.float32,[None,784])
#net = slim.stack(x,slim.fully_connected,[512,256])
K=5
N=1000
params = tf.concat(1, [ -1*tf.ones((1000,K-1)), 0*tf.ones((1000,K-1)) ]) 
tau = tf.Variable(1.0,name="temperature")
with tf.Session() as sess:
    sample,psi,pi = sample_pi_from_gaussian(params,tau)
    log_dens=log_density_pi_gaussian(sample, params, tau)
    np_sample=sess.run(sample,feed_dict={tau:1})
    np_log_dens=sess.run(log_dens,feed_dict={tau:1})
    np_pi=sess.run(pi,feed_dict={tau:1})
    np_psi = sess.run(psi,feed_dict={tau:1})
    np_params = sess.run(params,feed_dict={tau:1})
    print np.mean(np_log_dens)
    print np.mean(cat.log_density_gaussian_psi(np_sample, np_params, 0,10000))


In [None]:

# Create gumbell VAE
alpha=1
x = tf.placeholder(tf.float32,[None,784])
net = slim.stack(x,slim.fully_connected,[512,256])
params = tf.reshape(slim.fully_connected(net,K*N,activation_fn=None),[-1,K])

tau = tf.Variable(5.0,name="temperature")
 
pi = gumbel_softmax_sample(params,tau)
z = tf.reshape(pi,[-1,N,K])

log_dens=log_density_pi_concrete(pi, params, tau)

net = slim.stack(slim.flatten(z),slim.fully_connected,[256,512])
logits_x = slim.fully_connected(net,784,activation_fn=None)
p_x = Bernoulli(logits=logits_x)
p_z = Dirichlet(tf.ones(K)*alpha)


In [None]:

#Debug Gumbell
K=5
N=100000
tau = tf.Variable(5.0,name="temperature")
params = -1*tf.ones((N,K))
with tf.Session() as sess:
    
    pi = gumbel_softmax_sample(params,tau)
    log_dens=log_density_pi_concrete(pi, params, tau)
    np_sample=sess.run(pi,feed_dict={tau:1})
    np_log_dens=sess.run(log_dens,feed_dict={tau:1})
    np_pi=sess.run(pi,feed_dict={tau:1})
   
    np_params = sess.run(params,feed_dict={tau:1})
    
    print np.mean(np_log_dens)
    print np.mean(cat.log_density_pi_concrete(np_sample, np_params, 1))


In [5]:
# loss and train ops (Run either the gumbel or the gaussian for now)
#kl_tmp = tf.reshape((log_dens+tf.lgamma(K+0.0)),[-1,N])
kl_tmp = tf.reshape((log_dens-p_z.log_prob(pi)),[-1,N])
KL = tf.reduce_sum(kl_tmp,1)
elbo=tf.reduce_sum(p_x.log_prob(x),1) - KL


In [6]:
loss=tf.reduce_mean(-elbo)
lr=tf.constant(0.001)
train_op=tf.train.AdamOptimizer(learning_rate=lr).minimize(loss,var_list=slim.get_model_variables())
init_op=tf.initialize_all_variables()


Instructions for updating:
Use `tf.global_variables_initializer` instead.


Instructions for updating:
Use `tf.global_variables_initializer` instead.


In [7]:
# get data
data = input_data.read_data_sets('/tmp/', one_hot=True).train 

Extracting /tmp/train-images-idx3-ubyte.gz
Extracting /tmp/train-labels-idx1-ubyte.gz
Extracting /tmp/t10k-images-idx3-ubyte.gz
Extracting /tmp/t10k-labels-idx1-ubyte.gz


In [21]:
BATCH_SIZE=100
NUM_ITERS=14
tau0=1.0 # initial temperature
np_temp=tau0
np_lr=0.001
ANNEAL_RATE=0.00003
MIN_TEMP=0.5

In [44]:
dat=[]
sess=tf.InteractiveSession()
sess.run(init_op)
for i in range(1,NUM_ITERS):
  np_x,np_y=data.next_batch(BATCH_SIZE)
  _,np_loss,np_log_dens,np_sample,np_logits_x,np_pi,np_params,np_px_logx=sess.run([train_op,loss,log_dens,sample,logits_x,
                                                                        pi,params,p_x.log_prob(x)],{
      x:np_x,
      tau:np_temp,
      lr:np_lr
    })
  if i % 100 == 1:
    dat.append([i,np_temp,np_loss])
  if i % 1000 == 1:
    np_temp=np.maximum(tau0*np.exp(-ANNEAL_RATE*i),MIN_TEMP)
    np_lr*=0.9
  if i % 2 == 1:
      print('Step %d, ELBO: %0.3f' % (i,-np_loss)) 
  assert np.isnan(np_loss) is False  
        
        

Step 1, ELBO: -562.205


AssertionError: 

In [49]:
print np.isnan(np_sample)

[[False False False False]
 [False False False False]
 [False False False False]
 ..., 
 [False False False False]
 [False False False False]
 [False False False False]]


In [None]:
np_x1,_=data.next_batch(100)
np_x2,np_y1 = sess.run([p_x.mean(),z],{x:np_x1})

In [None]:
import matplotlib.animation as animation

In [None]:
def save_anim(data,figsize,filename):
  fig=plt.figure(figsize=(figsize[1]/10.0,figsize[0]/10.0))
  im = plt.imshow(data[0].reshape(figsize),cmap=plt.cm.gray,interpolation='none')
  plt.gca().set_axis_off()
  #fig.tight_layout()
  fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)
  def updatefig(t):
    im.set_array(data[t].reshape(figsize))
    return im,
  anim=animation.FuncAnimation(fig, updatefig, frames=100, interval=50, blit=True, repeat=True)
  Writer = animation.writers['imagemagick']
  writer = Writer(fps=1, metadata=dict(artist='Me'), bitrate=1800)
  anim.save(filename, writer=writer)
  return

In [None]:
dat=np.array(dat).T

In [None]:
f,axarr=plt.subplots(1,2)
axarr[0].plot(dat[0],dat[1])
axarr[0].set_ylabel('Temperature')
axarr[1].plot(dat[0],dat[2])
axarr[1].set_ylabel('-ELBO')

In [None]:
M=100*N
np_y = np.zeros((M,K))
np_y[range(M),np.random.choice(K,M)] = 1
np_y = np.reshape(np_y,[100,N,K])

In [None]:
x_p=p_x.mean()
np_x= sess.run(x_p,{z:np_y})

In [None]:
np_y = np_y.reshape((10,10,N,K))
np_y = np.concatenate(np.split(np_y,10,axis=0),axis=3)
np_y = np.concatenate(np.split(np_y,10,axis=1),axis=2)
y_img = np.squeeze(np_y)

In [None]:
np_x = np_x.reshape((10,10,28,28))
# split into 10 (1,10,28,28) images, concat along columns -> 1,10,28,280
np_x = np.concatenate(np.split(np_x,10,axis=0),axis=3)
# split into 10 (1,1,28,280) images, concat along rows -> 1,1,280,280
np_x = np.concatenate(np.split(np_x,10,axis=1),axis=2)
x_img = np.squeeze(np_x)

In [None]:
f,axarr=plt.subplots(1,2,figsize=(15,15))
# samples
axarr[0].matshow(y_img,cmap=plt.cm.gray)
axarr[0].set_title('Z Samples')
# reconstruction
axarr[1].imshow(x_img,cmap=plt.cm.gray,interpolation='none')
axarr[1].set_title('Generated Images')

In [None]:
f.tight_layout()
f.savefig('Gaussian-Cont.png')