In [None]:
import os
import json
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

In [None]:
# Load the per-image metadata; later cells index it by image number and read
# each entry's "viewId". The context manager closes the file handle as soon as
# parsing finishes instead of leaving it to the garbage collector.
with open(r"S:\Projects\MVC\mvc_info.json") as info_file:
    mvc_info = json.load(info_file)

In [None]:
# --- Data ---
input_size = 10000  # number of images read from disk
X_dim = 64          # side length of the square images (flattened to X_dim * X_dim)
y_dim = 8           # length of the one-hot label vector

# --- Model ---
z_dim = 400         # size of the latent code
h_dim = 1024        # units in the hidden dense layers

# --- Training ---
mb_size = 20        # examples per minibatch

In [None]:
# Read the first `input_size` images (files named 0.png, 1.png, ...) as flat
# vectors and build a matching one-hot label for each from `mvc_info`.
input_images, input_labels = [], []
for i in range(input_size):
    pixels = np.array(Image.open("S:\\Projects\\MVC\\64x64_bw\\" + str(i) + ".png"))
    input_images.append(pixels.reshape(X_dim * X_dim))

    # One-hot encode this entry's "viewId" into a vector of length y_dim.
    labels = np.zeros(y_dim)
    labels[mvc_info[i]["viewId"]] = 1
    input_labels.append(labels)

# Stack into arrays; pixel values are converted to float32 and divided by 255.
input_images = np.array(input_images).astype("float32") / 255.0
input_labels = np.array(input_labels)

In [None]:
def plot(samples):
    """Draw flattened images as tiles of a 4x4 grid and return the figure.

    Args:
        samples: Iterable of arrays, each reshapeable to (X_dim, X_dim).
            The grid has 16 cells, so callers are expected to pass at most
            16 samples.

    Returns:
        The matplotlib figure holding the grid. The figure is not closed
        here; that is left to the caller.
    """
    grid_side = 4
    figure = plt.figure(figsize=(grid_side, grid_side))
    layout = gridspec.GridSpec(grid_side, grid_side)
    layout.update(wspace=0.05, hspace=0.05)

    for cell_index, flat_image in enumerate(samples):
        # One bare axes per tile: no frame, no tick labels, square pixels.
        axes = plt.subplot(layout[cell_index])
        plt.axis('off')
        axes.set_xticklabels([])
        axes.set_yticklabels([])
        axes.set_aspect('equal')
        tile = flat_image.reshape(X_dim, X_dim)
        plt.imshow(tile, cmap='Greys_r')

    return figure

def bn(inputs):
    """Apply `tf.layers.batch_normalization` to `inputs` with default settings.

    Args:
        inputs: Tensor to normalize.

    Returns:
        The output tensor of the batch-normalization layer.
    """
    normalized = tf.layers.batch_normalization(inputs=inputs)
    return normalized

def conv(inputs, filters, kernel_size, strides, activation=tf.nn.relu):
    """Thin wrapper around `tf.layers.conv2d` that always uses "same" padding.

    Args:
        inputs: Input tensor for the convolution.
        filters: Number of output filters.
        kernel_size: Size of the convolution window.
        strides: Stride of the convolution.
        activation: Activation applied to the output; ReLU by default,
            pass None for a linear layer.

    Returns:
        The output tensor of the convolution layer.
    """
    return tf.layers.conv2d(
        inputs=inputs,
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding="same",
        activation=activation,
    )

In [None]:
# Start from an empty default graph so re-running this cell does not pile up
# ops and variables from a previous run.
tf.reset_default_graph()

with tf.device('/device:GPU:0'):
    # Graph inputs; the leading None leaves the batch dimension unspecified.
    X = tf.placeholder(dtype=tf.float32, shape=[None, X_dim * X_dim])  # flattened images
    c = tf.placeholder(dtype=tf.float32, shape=[None, y_dim])          # condition labels
    z = tf.placeholder(dtype=tf.float32, shape=[None, z_dim])          # latent codes fed in directly

    def Q(X, c):
        """Encoder: map images and condition labels to Gaussian posterior parameters.

        The flattened images are reshaped to X_dim x X_dim single-channel maps
        and passed through a small convolutional stack. The result is flattened,
        concatenated with the labels, and fed through a dense hidden layer that
        feeds two output heads.

        Args:
            X: Flattened images, X_dim * X_dim values per row.
            c: Condition labels, one row per example.

        Returns:
            A `(z_mu, z_logvar)` pair, each with z_dim values per row: the mean
            and the log-variance of the approximate posterior.
        """
        # The two stride-2 convolutions (with "same" padding) shrink each side by 4.
        small_side = X_dim // 4

        img = tf.reshape(X, [-1, X_dim, X_dim, 1])
        conv1 = conv(img, filters=16, kernel_size=3, strides=1)
        conv2 = conv(conv1, filters=32, kernel_size=3, strides=2)
        conv3 = conv(conv2, filters=32, kernel_size=3, strides=1)
        conv4 = conv(conv3, filters=64, kernel_size=3, strides=2)
        conv5 = conv(conv4, filters=64, kernel_size=3, strides=1)
        conv6 = conv(conv5, filters=1, kernel_size=1, strides=1)
        X_ = tf.reshape(conv6, [-1, small_side * small_side])

        inputs = tf.concat([X_, c], axis=1)
        h = tf.layers.dense(inputs, h_dim, activation=tf.nn.relu)

        # Both heads must be linear. A ReLU here would force the mean to be
        # non-negative and the log-variance to be >= 0 (std >= 1), so the
        # posterior could never be centred below zero or made narrower than
        # the unit-variance prior.
        z_mu = tf.layers.dense(h, z_dim, activation=None)
        z_logvar = tf.layers.dense(h, z_dim, activation=None)
        return z_mu, z_logvar

    def P(z, c):
        """Decoder: turn latent codes and condition labels into images.

        A dense stack produces a small single-channel map with a quarter of
        the image side length, which is then refined by convolutions and
        resized in two steps up to X_dim x X_dim.

        Args:
            z: Latent codes, one row per example.
            c: Condition labels, one row per example; concatenated to `z`
                along axis 1.

        Returns:
            A `(prob, big_img)` pair of flattened images with X_dim * X_dim
            values per row: `big_img` holds the raw logits and `prob` their
            sigmoid.
        """
        small_side = int(X_dim / 4)
        half_side = int(X_dim / 2)
        decoder_inputs = tf.concat([z, c], axis=1)

        # AUTO_REUSE lets repeated calls share a single set of decoder weights.
        with tf.variable_scope("P", reuse=tf.AUTO_REUSE):
            hidden = tf.layers.dense(decoder_inputs, h_dim, activation=tf.nn.relu)
            seed_pixels = tf.layers.dense(hidden, small_side * small_side)

            net = tf.reshape(seed_pixels, [-1, small_side, small_side, 1])
            net = conv(net, filters=64, kernel_size=3, strides=1)
            net = tf.image.resize_images(net, [half_side, half_side])
            net = conv(net, filters=32, kernel_size=3, strides=1)
            net = tf.image.resize_images(net, [X_dim, X_dim])
            net = conv(net, filters=32, kernel_size=3, strides=1)
            net = conv(net, filters=16, kernel_size=3, strides=1)
            # Final 1x1 convolution is linear so its output can be used as logits.
            net = conv(net, filters=1, kernel_size=1, strides=1, activation=None)
            big_img = tf.reshape(net, [-1, X_dim * X_dim])

        return tf.nn.sigmoid(big_img), big_img

    def sample_z(mu, log_var):
        """Draw a latent sample as `mu + std * noise` (reparameterization).

        Args:
            mu: Mean tensor.
            log_var: Log-variance tensor, same shape as `mu`.

        Returns:
            A tensor shaped like `mu`, built from standard-normal noise scaled
            by `exp(log_var / 2)` and shifted by `mu`.
        """
        noise = tf.random_normal(shape=tf.shape(mu))
        std = tf.exp(log_var / 2)
        return mu + std * noise


    # =============================== TRAINING ====================================

    # Encode the input batch, sample a latent code, and decode it back to logits.
    z_mu, z_logvar = Q(X, c)
    z_sample = sample_z(z_mu, z_logvar)
    _, logits = P(z_sample, c)

    # Decode latent codes fed in through the `z` placeholder (used for previews).
    X_samples, _ = P(z, c)

    # Reconstruction term: sigmoid cross-entropy between the decoder logits and
    # the input pixels, summed per example (the negative log-likelihood term).
    recon_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=X), 1)
    # KL term in closed form: D_KL(N(z_mu, exp(z_logvar)) || N(0, I)), summed
    # over the latent dimensions of each example.
    kl_loss = 0.5 * tf.reduce_sum(tf.exp(z_logvar) + z_mu**2 - 1. - z_logvar, 1)
    # VAE loss: batch mean of reconstruction + KL.
    vae_loss = tf.reduce_mean(recon_loss + kl_loss)

    solver = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(vae_loss)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Make sure the preview folder exists so savefig cannot fail on a missing
    # directory partway through training.
    os.makedirs('out', exist_ok=True)

    i = 0    # running index used to name the saved preview images
    cur = 0  # start offset of the next minibatch within the dataset

    for it in range(100000):
        # Walk through the data in order and wrap to the start after the final
        # slice, so every example (including the last minibatch) is visited.
        # If input_size is not a multiple of mb_size the final slice is shorter.
        stop = min(cur + mb_size, input_size)
        X_mb = input_images[cur:stop]
        y_mb = input_labels[cur:stop]
        cur = 0 if stop >= input_size else stop

        _, loss = sess.run([solver, vae_loss], feed_dict={X: X_mb, c: y_mb})

        if it % 1000 == 0 and it != 0:
            print('Iter: {}'.format(it))
            print('Loss: {:.4}'. format(loss))
            print()

            # Preview up to 16 images (the plot grid has 16 cells), all
            # conditioned on one randomly chosen label.
            n_preview = min(16, len(X_mb))
            y = np.zeros(shape=[n_preview, y_dim])
            y[:, np.random.randint(0, y_dim)] = 1.

            out_z = sess.run(z_sample, feed_dict={X: X_mb[:n_preview], c: y})
            samples = sess.run(X_samples, feed_dict={z: out_z, c: y})

            # Print the value range of the decoded previews.
            print(np.min(samples), np.max(samples))

            fig = plot(samples)
            plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
            i += 1
            plt.close(fig)