### Variational Autoencoder 

In [1]:
import warnings,os
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim # I lkie slim 
import matplotlib.pyplot as plt
# warnings.filterwarnings("ignore") # Stop showing annoying warnings
tf.logging.set_verbosity(tf.logging.ERROR) # I like old-style MNIST w/o warnings
from tensorflow.examples.tutorials.mnist import input_data
from sklearn.utils import shuffle
from util import gpu_sess
mnist = input_data.read_data_sets('data', one_hot=True)
%matplotlib inline
print ("TF version is %s."%(tf.__version__))

Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
TF version is 1.8.0.


### Define vae class

In [4]:
class vae_class(object):
    def __init__(self,_name='VAE',_xDim=784,_zDim=10,_hDims=[64,64],_cDim=0,
                 _actv=tf.nn.relu,_bn=slim.batch_norm,
                 _lr=0.001,_beta1=0.9,_beta2=0.9,_epsilon=0.1,
                 _VERBOSE=True):
        self.name  = _name # Name
        self.xDim  = _xDim # Dimension of input
        self.zDim  = _zDim # Dimension of latent vector
        self.hDims = _hDims # Dimention of hidden layer(s)
        self.cDim  = _cDim # Dimention of conditional vector 
        self.actv  = _actv # Activation function 
        self.bn    = _bn # Batch norm (slim.batch_norm / None)
        self.lr    = _lr # Learning rate 
        self.beta1 = _beta1 # Adam related (beta1)
        self.beta2 = _beta2 # Adam related (beta2)
        self.epsilon = _epsilon # Adam related (epsilon)
        self.VERBOSE = _VERBOSE
        if self.VERBOSE:
            print ("[%s] xdim:[%d] zdim:[%d] hdim:%s cdim:[%d]"\
                % (self.name,self.xDim,self.zDim,self.hDims,self.cDim))
        # Make model 
        self._build_model()
        self._build_graph()
        self._check_params()
    def _build_model(self):
        # Placeholders
        self.x  = tf.placeholder(tf.float32, shape=[None,self.xDim], name="x") # This will be inputs
        self.z  = tf.placeholder(tf.float32, shape=[None,self.zDim], name="z") # Latent vectors
        self.c  = tf.placeholder(tf.float32, shape=[None,self.cDim], name="c") # Conditioning vectors
        self.q  = tf.placeholder(tf.float32, shape=[None], name="q") # Weighting vectors 
        self.kp = tf.placeholder(tf.float32) # Keep prob.
        self.klWeight = tf.placeholder(tf.float32) # KL weight heuristics 
        self.isTraining = tf.placeholder(dtype=tf.bool,shape=[]) # Training flag
        # Build graph
        self.bnInit = {'beta':tf.constant_initializer(0.),'gamma':tf.constant_initializer(1.)}
        self.bnParams = {'is_training':self.isTraining,'decay':0.9,'epsilon':1e-5,
                           'param_initializers':self.bnInit,'updates_collections':None}
        with tf.variable_scope(self.name,reuse=False) as scope:
            with slim.arg_scope([slim.fully_connected],activation_fn=self.actv,
                               weights_initializer=tf.random_normal_initializer(stddev=0.1),
                               biases_initializer=tf.constant_initializer(value=0.0),
                               normalizer_fn=self.bn,normalizer_params=self.bnParams,
                               weights_regularizer=None):
                _net = self.x
                self.N = tf.shape(self.x)[0] # Number of current inputs 
                # Encoder 
                for hIdx in range(len(self.hDims)): # Loop over hidden layers
                    _hDim = self.hDims[hIdx]
                    _net = slim.fully_connected(_net,_hDim,scope='enc_lin'+str(hIdx))
                    _net = slim.dropout(_net,keep_prob=self.kp,is_training=self.isTraining
                                        ,scope='enc_dr'+str(hIdx))
                # Latent vector z (NO ACTIVATION!)
                self.zMuEncoded = slim.fully_connected(_net,self.zDim,scope='zMuEncoded',activation_fn=None)
                self.zLogVarEncoded = slim.fully_connected(_net,self.zDim,scope='zLogVarEncoded',activation_fn=None)
                # Define z sampler (reparametrization trick)
                self.eps = tf.random_normal(shape=(self.N,self.zDim),mean=0.,stddev=1.,dtype=tf.float32)
                self.zSample = self.zMuEncoded+tf.sqrt(tf.exp(self.zLogVarEncoded))*self.eps
                # Concatenate the condition vector to the sampled latent vector
                if self.cDim != 0:
                    self.zEncoded = tf.concat([self.zSample,self.c],axis=1)
                else:
                    self.zEncoded = self.zSample
                # Decoder 
                _net = self.zEncoded
                for hIdx in range(len(self.hDims)): # Loop over hidden layers
                    _hDim = self.hDims[len(self.hDims)-hIdx-1]
                    _net = slim.fully_connected(_net,_hDim,scope='dec_lin'+str(hIdx))
                    _net = slim.dropout(_net,keep_prob=self.kp,is_training=self.isTraining
                                        ,scope='dec_dr'+str(hIdx))
                # Reconstruct output (NO ACTIVATION)
                self.xRecon = slim.fully_connected(_net,self.xDim,scope='xRecon',activation_fn=None)
        # Additional graph for debugging purposes
        with tf.variable_scope(self.name,reuse=True) as scope:
            with slim.arg_scope([slim.fully_connected],activation_fn=self.actv,
                               weights_initializer=tf.random_normal_initializer(stddev=0.1),
                               biases_initializer=tf.constant_initializer(value=0.0),
                               normalizer_fn=self.bn,normalizer_params=self.bnParams,
                               weights_regularizer=None):
                # Start from given z, instead of sampled z
                if self.cDim != 0:
                    self.zGiven = tf.concat([self.z,self.c],axis=1)
                else:
                    self.zGiven = self.z
                # Decoder 
                _net = self.zGiven
                for hIdx in range(len(self.hDims)): # Loop over hidden layers
                    _hDim = self.hDims[len(self.hDims)-hIdx-1]
                    _net = slim.fully_connected(_net,_hDim,scope='dec_lin'+str(hIdx))
                # Reconstruct output (NO ACTIVATION)
                self.xGivenZ = slim.fully_connected(_net,self.xDim,scope='xRecon',activation_fn=None)
    def _build_graph(self):
        # Original VAE losses
        # Recon loss
        self._reconLoss = 1./2.*tf.norm(self.xRecon-self.x,ord=1,axis=1)
        self._reconLossWeighted = tf.nn.softplus(self.q)*self._reconLoss
        self.reconLossWeighted = tf.reduce_mean(self._reconLossWeighted)
        # KL loss
        self._klLoss = 0.5*tf.reduce_sum(tf.exp(self.zLogVarEncoded)+self.zMuEncoded**2-1.-self.zLogVarEncoded,1)
        self._klLossWeighted = tf.nn.softplus(self.q)*self._klLoss
        self.klLossWeighted = self.klWeight*tf.reduce_mean(self._klLossWeighted)
        # Total loss
        self.totalLoss = self.reconLossWeighted + self.klLossWeighted
        # Solver
        self.optm = tf.train.AdamOptimizer(
            learning_rate=self.lr,beta1=self.beta1,beta2=self.beta2,epsilon=self.epsilon)\
            .minimize(self.totalLoss)
    # Check parameters
    def _check_params(self):
        _g_vars = tf.global_variables()
        self.g_vars = [var for var in _g_vars if '%s/'%(self.name) in var.name]
        if self.VERBOSE:
            print ("==== Global Variables ====")
        for i in range(len(self.g_vars)):
            w_name  = self.g_vars[i].name
            w_shape = self.g_vars[i].get_shape().as_list()
            if self.VERBOSE:
                print (" [%02d] Name:[%s] Shape:[%s]" % (i,w_name,w_shape))
    # Train
    def train(self,_sess,_X,_C,_Q,_maxIter,_batchSize,_PRINT_EVERY=100,_PLOT_EVERY=100):
        # X: inputs [N x D]
        # C: condition vectors [N x 1]
        # Q: weighting vectors [N x 1]
        self.sess = _sess
        # Initialize variables
        self.sess.run(tf.global_variables_initializer())
        # Train
        _Xtrain = _X
        nX = _X.shape[0]
        for _iter in range(_maxIter):
            randIdx = np.random.permutation(nX)[:_batchSize] # Random indices every iteration
            xBatch = _X[randIdx,:] # X batch
            # Q batch
            if _Q == None:
                qBatch = np.ones(shape=(_batchSize))
            else:
                qBatch = _Q[randIdx]
            if _C == None: # Original VAE (without conditioning)
                feeds = {self.x:xBatch,self.q:qBatch,self.klWeight:1.0,self.isTraining:True,self.kp:0.9}
            else: # Conditional VAE
                cBatch = _C[randidx,:]
                feeds = {self.x:xBatch,self.c:cBatch,self.q:qBatch,self.klWeight:1.0,self.isTraining:True,self.kp:0.9}
            # Train
            opers = [self.optm,self.totalLoss,self.reconLossWeighted,self.klLossWeighted]
            _,totalLossVal,reconLossWeightedVal,klLossWeightedVal = self.sess.run(opers,feed_dict=feeds)
            # Print 
            if (_iter%_PRINT_EVERY) == 0:
                print ("[%04d/%d] Loss: %.2f(recon:%.2f+kl:%.2f)"%
                       (_iter,_maxIter,totalLossVal,reconLossWeightedVal,klLossWeightedVal))
            # Plot
            if (_iter%_PLOT_EVERY) == 0:
                nR,nC     = 1,10
                zRandn = 1.*np.random.randn(nR*nC,self.zDim)
                

### Run

In [5]:
tf.reset_default_graph()
tf.set_random_seed(0); np.random.seed(0)
V = vae_class(_name='VAE',_xDim=784,_zDim=10,_hDims=[64,64],_cDim=0, 
             _actv=tf.nn.relu,_bn=slim.batch_norm,
             _lr=0.001,_beta1=0.9,_beta2=0.9,_epsilon=0.1,
             _VERBOSE=True)
sess = gpu_sess()
# We will use MNIST
X = mnist.train.images
Y = mnist.train.labels
V.train(_sess=sess,_X=X,_C=None,_Q=None,_maxIter=(int)(1e4),_batchSize=256,_PRINT_EVERY=100,_PLOT_EVERY=100)

[VAE] xdim:[784] zdim:[10] hdim:[64, 64] cdim:[0]
==== Global Variables ====
 [00] Name:[VAE/enc_lin0/weights:0] Shape:[[784, 64]]
 [01] Name:[VAE/enc_lin0/BatchNorm/beta:0] Shape:[[64]]
 [02] Name:[VAE/enc_lin0/BatchNorm/moving_mean:0] Shape:[[64]]
 [03] Name:[VAE/enc_lin0/BatchNorm/moving_variance:0] Shape:[[64]]
 [04] Name:[VAE/enc_lin1/weights:0] Shape:[[64, 64]]
 [05] Name:[VAE/enc_lin1/BatchNorm/beta:0] Shape:[[64]]
 [06] Name:[VAE/enc_lin1/BatchNorm/moving_mean:0] Shape:[[64]]
 [07] Name:[VAE/enc_lin1/BatchNorm/moving_variance:0] Shape:[[64]]
 [08] Name:[VAE/zMuEncoded/weights:0] Shape:[[64, 10]]
 [09] Name:[VAE/zMuEncoded/BatchNorm/beta:0] Shape:[[10]]
 [10] Name:[VAE/zMuEncoded/BatchNorm/moving_mean:0] Shape:[[10]]
 [11] Name:[VAE/zMuEncoded/BatchNorm/moving_variance:0] Shape:[[10]]
 [12] Name:[VAE/zLogVarEncoded/weights:0] Shape:[[64, 10]]
 [13] Name:[VAE/zLogVarEncoded/BatchNorm/beta:0] Shape:[[10]]
 [14] Name:[VAE/zLogVarEncoded/BatchNorm/moving_mean:0] Shape:[[10]]
 [15] N