In [3]:
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.misc import imsave as ims

In [4]:
def merge(images, size):
    print images.shape##64x28x28
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1]))#28x28 matrix of zeros

    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx / size[1]
        img[j*h:j*h+h, i*w:i*w+w] = image

    return img

In [5]:
class LatentAttention():
    def __init__(self):
        self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
        self.n_samples = self.mnist.train.num_examples
        print self.n_samples

        self.n_hidden = 500
        self.n_z = 20
        self.batchsize = 100

        self.images = tf.placeholder(tf.float32, [None, 784])
        image_matrix = tf.reshape(self.images,[-1, 28, 28, 1])
        z_mean, z_stddev = self.recognition(image_matrix)
        print z_mean.shape,z_mean.shape
        samples = tf.random_normal([self.batchsize,self.n_z],0,1,dtype=tf.float32)
        guessed_z = z_mean + (z_stddev * samples)
        print guessed_z.shape

        self.generated_images = self.generation(guessed_z)
        generated_flat = tf.reshape(self.generated_images, [self.batchsize, 28*28])

        self.generation_loss = -tf.reduce_sum(self.images * tf.log(1e-8 + generated_flat) + (1-self.images) * tf.log(1e-8 + 1 - generated_flat),1)

        self.latent_loss = 0.5 * tf.reduce_sum(tf.square(z_mean) + tf.square(z_stddev) - tf.log(tf.square(z_stddev)) - 1,1)
        self.cost = tf.reduce_mean(self.generation_loss + self.latent_loss)
        print self.cost.shape
        self.optimizer = tf.train.AdamOptimizer(0.001).minimize(self.cost)
        
        
    def recognition(self, input_images):
        with tf.variable_scope("recognition"):
            filter_size = 5
            n_filters = 16
            filter_shape = [filter_size, filter_size, 1, n_filters]
            W1 = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1))
            b1 = tf.Variable(tf.constant(0.1, shape=[n_filters]))
            conv1 = tf.nn.conv2d(input=input_images,
                                         filter=W1,
                                          strides=[1, 2, 2, 1],
                                          padding='SAME')
            h1=tf.nn.bias_add(conv1, b1)
            z1 = tf.nn.relu(h1)
            
            
            filter_size = 5
            n_filters = 32
            filter_shape = [filter_size, filter_size, 16, n_filters]
            W2 = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1))
            b2 = tf.Variable(tf.constant(0.1, shape=[n_filters]))
            conv2 = tf.nn.conv2d(input=z1,
                                         filter=W2,
                                          strides=[1, 2, 2, 1],
                                          padding='SAME')
            h2=tf.nn.bias_add(conv2, b2)
            z2 = tf.nn.relu(h2)
            
            
            h2_flat = tf.reshape(h2,[self.batchsize, 7*7*32])
            print h2_flat.shape

            w_mean = tf.layers.dense(h2_flat,self.n_z,activation=None,use_bias=True,bias_initializer=tf.zeros_initializer())
            print w_mean.shape
    
            w_stddev = tf.layers.dense(h2_flat,self.n_z,activation=None,use_bias=True,bias_initializer=tf.zeros_initializer())

        return w_mean, w_stddev
        
        
    def generation(self, z):
        with tf.variable_scope("generation"):
            z_develop=tf.layers.dense(z,7*7*32,activation=None,use_bias=True,bias_initializer=tf.zeros_initializer())
            z_matrix = tf.nn.relu(tf.reshape(z_develop, [self.batchsize, 7, 7, 32]))
            print z_matrix.shape
            
            
            filter_size = 5
            filter_shape = [filter_size, filter_size, 16,32]
            W1 = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1))
            
            conv1 = tf.nn.conv2d_transpose(z_matrix, W1,[self.batchsize, 14, 14, 16], strides=[1,2,2,1])
            
            
            #b1 = tf.Variable(tf.constant(0.1, shape=[16]))
            
            z1 = tf.nn.relu(conv1)
            
            print z1.shape
            
            filter_size = 5
            filter_shape = [filter_size, filter_size, 1,16]
            W2 = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1))
            
            conv2 = tf.nn.conv2d_transpose(z1, W2,[self.batchsize, 28, 28, 1], strides=[1,2,2,1])
            
            z2 = tf.nn.sigmoid(conv2)
            
            
            
        return z2
    
    def train(self):
        b=self.mnist.train.next_batch(self.batchsize)
        visualization= self.mnist.train.next_batch(self.batchsize)[0]#100 images
        print len(b),len(visualization)
        reshaped_vis = visualization.reshape(self.batchsize,28,28)#100x28x28
        ims("results/base.jpg",merge(reshaped_vis[:64],[8,8]))#pass 64 images
        # train
        saver = tf.train.Saver(max_to_keep=2)
        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            for epoch in range(10):
                for idx in range(int(self.n_samples / self.batchsize)):
                    batch = self.mnist.train.next_batch(self.batchsize)[0]
                    _, gen_loss, lat_loss = sess.run((self.optimizer, self.generation_loss, self.latent_loss), feed_dict={self.images: batch})
                    # dumb hack to print cost every epoch
                    if idx % (self.n_samples - 3) == 0:
                        print "epoch %d: genloss %f latloss %f" % (epoch, np.mean(gen_loss), np.mean(lat_loss))
                        saver.save(sess, os.getcwd()+"/training/train",global_step=epoch)
                        generated_test = sess.run(self.generated_images, feed_dict={self.images: visualization})
                        generated_test = generated_test.reshape(self.batchsize,28,28)
                        ims("results/"+str(epoch)+".jpg",merge(generated_test[:64],[8,8]))






In [6]:
model = LatentAttention()

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
55000
(100, 1568)
(100, 20)
(100, 20) (100, 20)
(100, 20)
(100, 7, 7, 32)
(100, 14, 14, 16)
()


In [8]:
model.train()

2 100
(64, 28, 28)
Instructions for updating:
Use `tf.global_variables_initializer` instead.
epoch 0: genloss 546.541443 latloss 26.686228
(64, 28, 28)
epoch 1: genloss 114.144508 latloss 23.428110
(64, 28, 28)
epoch 2: genloss 96.980049 latloss 24.964941
(64, 28, 28)
epoch 3: genloss 94.634842 latloss 24.394951
(64, 28, 28)
epoch 4: genloss 90.282257 latloss 24.470772
(64, 28, 28)
epoch 5: genloss 87.564743 latloss 23.690863
(64, 28, 28)
epoch 6: genloss 94.319870 latloss 24.259581
(64, 28, 28)
epoch 7: genloss 84.751640 latloss 24.703306
(64, 28, 28)
epoch 8: genloss 87.313240 latloss 24.946445
(64, 28, 28)
epoch 9: genloss 84.469246 latloss 24.784950
(64, 28, 28)
