In [1]:
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import tensorflow as tf
import numpy as np
import random
import logpdf
import time

from sklearn.metrics import accuracy_score

np.random.seed(19538)
random.seed(19538)
tf.set_random_seed(19538)

In [2]:
import tensorflow.contrib.layers as tcl
class VAE:
    def __init__(self):
        self.name = 'VAE'
        
        # classifier params
        self.hidden_size = 600
        
        # encode & decoder params
        self.z_dim = 50
        self.x_dim = 28*28
        
        # training
        # learning rate
        self.lr_decay_factor = 0.95
        self.learning_rate = 1e-3
        # first moment decay
        self.beta1 = 1-1e-1
        # second moment decay
        self.beta2 = 1-1e-3
        
        self._build_graph()
        self._build_train_op()
        self.check_parameters()
    
    def check_parameters(self):
        for var in tf.trainable_variables():
            print('%s: %s' % (var.name, var.get_shape()))
        print()
    
    def get_collection(self, collections):
        return [var for var in tf.get_collection(collections)]
    
    def reparameterize(self, mu, logvar):
        batch_size, eps_dim = tf.shape(mu)[0], tf.shape(mu)[1]
        std = tf.exp(logvar * 0.5)
        eps = tf.random_normal([batch_size, eps_dim])
        z = mu + eps * std
        return z
    
    def encode(self, x, reuse = False):
        with tf.variable_scope(self.name+'/encoder', reuse = reuse):
            h1 = tcl.fully_connected(x,  self.hidden_size, activation_fn = tf.nn.softplus)
            h1 = tf.nn.dropout(h1, keep_prob = self.encode_keep_prob)
            h2 = tcl.fully_connected(h1, self.hidden_size, activation_fn = tf.nn.softplus)
            h2 = tf.nn.dropout(h2, keep_prob = self.encode_keep_prob)
            mu     = tcl.fully_connected(h2, self.z_dim, activation_fn = None)
            logvar = tcl.fully_connected(h2, self.z_dim, activation_fn = None)
            z = self.reparameterize(mu, logvar)
            return z, mu, logvar
    
    def decode(self, z, reuse = False):
        with tf.variable_scope(self.name+'/decoder', reuse = reuse):
            h1 = tcl.fully_connected(z,  self.hidden_size, activation_fn = tf.nn.softplus)
            h1 = tf.nn.dropout(h1, keep_prob = self.decode_keep_prob)
            h2 = tcl.fully_connected(h1, self.hidden_size, activation_fn = tf.nn.softplus)
            h2 = tf.nn.dropout(h2, keep_prob = self.decode_keep_prob)
            x  = tcl.fully_connected(h2, self.x_dim, activation_fn = tf.nn.softmax) 
            return x
        
    def L(self, x, recon_x, z, mu_z, logvar_z):
        # (batch_size, z_dim) -> batch_size,
        kld = tf.reduce_sum(logpdf.KLD(mu_z, logvar_z), 1)
        # (batch_size, 784)   -> batch_size,
        logpx = tf.reduce_sum(logpdf.bernoulli(recon_x, x), 1)
        loss = kld - logpx
        return loss
    
    def _build_graph(self, reuse = False):
        self.x = tf.placeholder(tf.float32, shape = (None, self.x_dim))
        self.encode_keep_prob = tf.placeholder(tf.float32)
        self.decode_keep_prob = tf.placeholder(tf.float32)
        '''
            labelled data, encoder & decoder
        '''
        # encoder, labelled data
        self.z, self.mu_z, self.logvar_z = self.encode(self.x, reuse = reuse)

        # decoder, labelled data
        self.x_recon = self.decode(self.z, reuse = reuse)
        
        # loss of labelled data, refered as L(x, y)
        self.loss = self.L(self.x, self.x_recon, self.z, self.mu_z, self.logvar_z)
        self.loss = tf.reduce_mean(self.loss, 0)

        trainable_vars_key = tf.GraphKeys.TRAINABLE_VARIABLES
        encoder_vars = tf.get_collection(key=trainable_vars_key, scope=self.name+"/encoder")
        decoder_vars = tf.get_collection(key=trainable_vars_key, scope=self.name+"/decoder")
        tcl.apply_regularization(tcl.l2_regularizer(1.0), encoder_vars)
        tcl.apply_regularization(tcl.l2_regularizer(1.0), decoder_vars)

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        reg_constant = 1e-6
        self.loss += reg_constant * tf.reduce_sum(reg_losses)
        
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=3,\
            pad_step_number=True, keep_checkpoint_every_n_hours=5.0)
            
    def _build_train_op(self):
        self.global_step = tf.Variable(0, name="global_step", trainable = False)
        self.lr = tf.Variable(self.learning_rate, trainable=False, 
                    dtype=tf.float32)
        optimizer = tf.train.AdamOptimizer(self.lr)
        grads_and_vars = optimizer.compute_gradients(self.loss)
        def ClipIfNotNone(grad):
            if grad is None:
                return grad
            return tf.clip_by_value(grad, -1, 1)
        capped_gvs = [(ClipIfNotNone(grad), var) for grad, var in grads_and_vars]
        self.train_op = optimizer.apply_gradients(capped_gvs, self.global_step)
        self.lr_decay_op = self.lr.assign(
                self.lr * self.lr_decay_factor)
        
    def lr_decay(self, sess):
        _ = sess.run([self.lr_decay_op])
    
    def _build_generator_op(self):
        self._z   = tf.placeholder(tf.float32, shape = (None, self.z_dim))
        self._gen = self.decode(self._z, reuse = True)
    
    def optimize(self, sess, x):
        feed_dict = {
            self.x: x,
            self.encode_keep_prob: 1,
            self.decode_keep_prob: 1,
        }
        _, loss = sess.run([self.train_op, self.loss], feed_dict = feed_dict)
        return loss
    
    def recon_x(self, sess, x):
        feed_dict = {
            self.x: x,
            self.encode_keep_prob: 1,
            self.decode_keep_prob: 1,
        }
        x_recon = sess.run([self.x_recon], feed_dict = feed_dict)[0]
        return x_recon
    
    def generate(self, sess, z):
        feed_dict = {
            self._z : z,
            self.decode_keep_prob: 1,
        }
        _gen = sess.run([self._gen], feed_dict = feed_dict)[0]
        return _gen
    
    def save(self, sess, path = 'models/vae/ckpt'):
        self.saver.save(sess, path, global_step = self.global_step)
    
    def get_repr(self, sess, x, stats_only = False):
        feed_dict = {
            self.x: x,
            self.encode_keep_prob: 1,
            self.decode_keep_prob: 1,
        }
        if stats_only:
            mu, logvar = sess.run([self.mu_z, self.logvar_z], feed_dict = feed_dict)
            return mu, logvar
        else:
            z = sess.run([self.z], feed_dict = feed_dict)[0]
            return z

In [3]:
dataset_size = 50000
n_batch_size = 100
n_epoch      = 100
max_iter = dataset_size/n_batch_size*n_epoch
from utils import (load_mnist, batch_generator, time_since)
MNIST_PATH = './data/mnist_28.pkl.gz'
train_x, train_y, valid_x, valid_y, test_x, test_y = load_mnist(MNIST_PATH)
batch_gen = batch_generator(zip(train_x), n_batch_size, n_epoch)

gen_x = []

# Set config for tensorflow session.
tf_config = tf.ConfigProto(
    device_count = {'GPU': 1}, # single gpu
)
tf_config.gpu_options.allow_growth=True
with tf.Session() as sess:
    model = VAE()
    sess.run(tf.global_variables_initializer())
    
    start = time.time()
    for x_batch in batch_gen:
        x_batch = zip(*x_batch)[0]
        loss = model.optimize(sess, x_batch)
        
        if(model.global_step.eval() % 500 == 0):
            print('Time: %s'        % time_since(start))
            print('Iteration %d/%d' % (model.global_step.eval(), max_iter))
            print('loss: %.2f'      % loss)
            print()
            
            # save cpkt
            model.save(sess)

        if model.global_step.eval() % 1000 == 0:
            model.lr_decay(sess)
            
            # save imgs
            # x_recon = model.recon_x(sess, x_batch)
            # gen_x.append(x_recon)

VAE/encoder/fully_connected/weights:0: (784, 600)
VAE/encoder/fully_connected/biases:0: (600,)
VAE/encoder/fully_connected_1/weights:0: (600, 600)
VAE/encoder/fully_connected_1/biases:0: (600,)
VAE/encoder/fully_connected_2/weights:0: (600, 50)
VAE/encoder/fully_connected_2/biases:0: (50,)
VAE/encoder/fully_connected_3/weights:0: (600, 50)
VAE/encoder/fully_connected_3/biases:0: (50,)
VAE/decoder/fully_connected/weights:0: (50, 600)
VAE/decoder/fully_connected/biases:0: (600,)
VAE/decoder/fully_connected_1/weights:0: (600, 600)
VAE/decoder/fully_connected_1/biases:0: (600,)
VAE/decoder/fully_connected_2/weights:0: (600, 784)
VAE/decoder/fully_connected_2/biases:0: (784,)

Time: 0 m 4 s
Iteration 500/50000
loss: 573.53

Time: 0 m 8 s
Iteration 1000/50000
loss: 553.91

Time: 0 m 11 s
Iteration 1500/50000
loss: 554.72

Time: 0 m 15 s
Iteration 2000/50000
loss: 558.06

Time: 0 m 19 s
Iteration 2500/50000
loss: 547.06

Time: 0 m 22 s
Iteration 3000/50000
loss: 584.04

Time: 0 m 26 s
Iterati

In [4]:
len(gen_x)

0

In [5]:
gen_x[25][0]

IndexError: list index out of range

In [None]:
import matplotlib
import matplotlib.pyplot as plt
% matplotlib inline
def display(x):
    x = np.clip(x, 0, 1)
    nrows, ncols = 10, 10
    n_count  = 0
    fig, axarr = plt.subplots(nrows = nrows, ncols = ncols)
    for i in range(nrows):
        for j in range(ncols):
            axarr[i, j].axis('off')
            axarr[i, j].imshow(x[n_count], cmap='gray')
            n_count += 1
    plt.show()

In [None]:
display(np.reshape(gen_x[80], (100,28,28)))