# In [24]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.distributions import Normal
from tensorflow.distributions import Bernoulli

# In [25]:
class DenseLayer(object):
    '''Fully connected layer: ``activation(X @ weights + bias)``.'''

    def __init__(self, n_in, n_out, activation=tf.nn.relu):
        '''Create the trainable weight matrix and bias vector.

        Args:
            n_in: number of input features (rows of the weight matrix).
            n_out: number of output features (columns of the weight matrix).
            activation: callable applied to the affine output.
        '''
        # weights come from tf.random_normal with stddev 2 / sqrt(n_in)
        init_std = 2 / np.sqrt(n_in)
        initial_weights = tf.random_normal(shape=(n_in, n_out), stddev=init_std)
        self.weights = tf.Variable(initial_weights)

        # biases start at zero
        initial_bias = tf.constant(0.0, shape=[n_out])
        self.bias = tf.Variable(initial_bias)

        self.activation = activation

    def feed_forward(self, X):
        '''Return the activation of the affine transform of ``X``.'''
        affine = tf.matmul(X, self.weights) + self.bias
        return self.activation(affine)

# In [27]:
class VariationalAutoencoder:
    '''Variational autoencoder with a Gaussian latent code and Bernoulli output.

    The encoder maps the input to the means and standard deviations of a
    Normal distribution over the latent code; the decoder maps a latent code
    to Bernoulli logits over the input dimensions. The whole graph, the
    training op and the session are created in ``__init__``.
    '''

    def __init__(self, n_input, n_list):
        '''Build the graph, the training op and an initialized session.

        Args:
            n_input: number of input features.
            n_list: encoder layer sizes. The last entry is the number of
                latent features; the entries before it are hidden layer
                sizes. The decoder uses the hidden sizes in reverse order.
        '''
        # input
        self.X = tf.placeholder(tf.float32, shape=(None, n_input))

        hidden_sizes = n_list[:-1]
        # latent features number
        latent = n_list[-1]

        # encoder: hidden layers, then a linear layer emitting 2 * latent
        # values (the means followed by the pre-softplus standard deviations)
        self.encoder_layers = []
        previous = n_input
        for current in hidden_sizes:
            self.encoder_layers.append(DenseLayer(previous, current))
            previous = current
        # use `previous`, not the loop variable, so that an encoder with no
        # hidden layers (single-entry n_list) still gets a valid input size
        encoder_output = DenseLayer(previous, latent * 2, activation=lambda x: x)
        self.encoder_layers.append(encoder_output)

        encoded = self._run_layers(self.encoder_layers, self.X)
        self.means = encoded[:, :latent]
        # std must be positive, 1e-6 for smoothing
        self.std = tf.nn.softplus(encoded[:, latent:]) + 1e-6

        # reparameterization trick
        normal = Normal(loc=self.means, scale=self.std)
        self.Z = normal.sample()

        # decoder: mirrored hidden layers, then a linear layer emitting logits
        self.decoder_layers = []
        previous = latent
        for current in reversed(hidden_sizes):
            self.decoder_layers.append(DenseLayer(previous, current))
            previous = current
        decoder_output = DenseLayer(previous, n_input, activation=lambda x: x)
        self.decoder_layers.append(decoder_output)

        # posterior predictive: decode the latent code sampled from the encoder
        logits = self._run_layers(self.decoder_layers, self.Z)
        self.pred = Bernoulli(logits=logits)
        self.post_pred = self.pred.sample()
        self.post_pred_probs = tf.nn.sigmoid(logits)

        # prior predictive: decode a single draw from a standard normal
        scale = np.ones(latent, dtype=np.float32)
        loc = scale * 0
        std_norm = Normal(loc, scale)
        Z_std = std_norm.sample(1)
        prior_logits = self._run_layers(self.decoder_layers, Z_std)
        prior_pred_dist = Bernoulli(logits=prior_logits)
        self.prior_pred = prior_pred_dist.sample()
        self.prior_pred_probs = tf.nn.sigmoid(prior_logits)
        # the attribute was originally spelled `prior_pred_pros`; keep it as an alias
        self.prior_pred_pros = self.prior_pred_probs

        # decode a latent code supplied by the caller
        self.Z_input = tf.placeholder(tf.float32, shape=(None, latent))
        input_logits = self._run_layers(self.decoder_layers, self.Z_input)
        self.prior_pred_from_in_probs = tf.nn.sigmoid(input_logits)

        # cost: ELBO = log-likelihood of the input under the decoder output
        # minus KL(encoder distribution || standard normal), summed over the
        # feature axis per sample and then over the batch
        expected_log_likelihood = tf.reduce_sum(self.pred.log_prob(self.X), 1)
        kl = tf.reduce_sum(tf.distributions.kl_divergence(normal, std_norm), 1)
        self.elbo = tf.reduce_sum(expected_log_likelihood - kl)
        # maximize the ELBO by minimizing its negative
        self.train_op = tf.train.RMSPropOptimizer(learning_rate=0.001).minimize(-self.elbo)
        # `optimizer` was the name originally assigned here; keep it as an alias
        self.optimizer = self.train_op

        self.init = tf.global_variables_initializer()
        self.session = tf.Session()
        self.session.run(self.init)

    @staticmethod
    def _run_layers(layers, X):
        '''Feed ``X`` through ``layers`` in order and return the final output.'''
        for layer in layers:
            X = layer.feed_forward(X)
        return X

    def fit(self, X, epochs=10, batch=50):
        '''Train on mini-batches of ``X`` and plot the ELBO of every batch.

        Args:
            X: training data, one sample per row.
            epochs: number of passes over the data.
            batch: mini-batch size; a trailing partial batch is dropped.
        '''
        X = np.asarray(X)
        costs = []
        n_batches = len(X) // batch
        for epoch in range(epochs):
            # draw a fresh sample order each epoch instead of shuffling X in
            # place, so the caller's array keeps its row order
            order = np.random.permutation(len(X))
            for b in range(n_batches):
                c_batch = X[order[b * batch:(b + 1) * batch]]
                _, c = self.session.run(
                    (self.train_op, self.elbo), feed_dict={self.X: c_batch})
                costs.append(c)
        plt.plot(costs)
        plt.show()

    def transform(self, X):
        '''Return the encoder means (latent representation) for ``X``.'''
        return self.session.run(self.means, feed_dict={self.X: X})