In [None]:
import os
import json
import numpy as np
import tensorflow as tf

from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

In [None]:
# Load the MVC metadata. Later cells index `mvc_info` by image number and read
# each entry's "viewId" to build the conditioning label.
# A `with` block is used so the file handle is closed deterministically instead
# of being left to the garbage collector.
with open(r"S:\Projects\MVC\mvc_info.json") as info_file:
    mvc_info = json.load(info_file)

In [None]:
# ---- Data dimensions --------------------------------------------------------
img_size = 64              # side length of the square images (plot() reshapes to img_size x img_size)
X_dim = img_size ** 2      # length of one flattened image vector
y_dim = 8                  # width of the one-hot label vector fed as the condition `c`
input_size = 10000         # number of images read from disk

# ---- Model sizes ------------------------------------------------------------
z_dim = 512                # width of the latent code z
h_dim = 2048               # width of the single hidden layer in encoder and decoder

# ---- Optimisation -----------------------------------------------------------
mb_size = 20               # examples per training minibatch
lr = 1e-5                  # learning-rate setting
                           # NOTE(review): the optimizer call visible in this notebook does
                           # not pass `lr` - confirm whether it is meant to be used.

In [None]:
# Read the first `input_size` images and build one one-hot label row per image.
# Image k is expected at <folder>\k.png and its metadata at mvc_info[k].
flat_pixels = []
one_hot_rows = []
for idx in range(input_size):
    picture = Image.open("S:\\Projects\\MVC\\64x64_bw\\" + str(idx) + ".png")
    # Flatten to a vector of length X_dim; reshape raises if the pixel count differs.
    flat_pixels.append(np.array(picture).reshape(X_dim))

    # One-hot encode the view id of this image.
    row = np.zeros(y_dim)
    row[mvc_info[idx]["viewId"]] = 1
    one_hot_rows.append(row)

# Scale pixel values by 1/255 and store as float32.
# presumably the PNGs are 8-bit greyscale, giving values in [0, 1] - verify against the data
input_images = np.array(flat_pixels).astype("float32") / 255.0
input_labels = np.array(one_hot_rows)

In [None]:
def plot(samples):
    """Render flattened images as tiles of a 4x4 grid and return the figure.

    Args:
        samples: iterable of array-likes, each reshapeable to
            (img_size, img_size). The grid has 16 cells, so indexing the
            grid fails for a 17th sample.

    Returns:
        The matplotlib Figure holding the tiles. The caller is responsible
        for saving and closing it.
    """
    fig = plt.figure(figsize=(16, 16))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for cell, flat_image in enumerate(samples):
        ax = plt.subplot(grid[cell])
        # Hide the frame, ticks and labels so only the image is visible.
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        tile = flat_image.reshape(img_size, img_size)
        ax.imshow(tile, cmap='Greys_r')

    return fig


def xavier_init(size):
    """Return a random-normal initial value for a weight matrix.

    The standard deviation is 1 / sqrt(size[0] / 2), i.e. sqrt(2 / size[0]),
    so it depends only on the first entry of `size`.

    Args:
        size: shape of the tensor to create; size[0] is used as the fan-in.

    Returns:
        The result of tf.random_normal with the requested shape and the
        standard deviation described above.
    """
    fan_in = size[0]
    half_fan_in = fan_in / 2.
    stddev = 1. / tf.sqrt(half_fan_in)
    return tf.random_normal(shape=size, stddev=stddev)

def Q(X, c):
    """Encoder Q(z | X, c): map an input and its condition to Gaussian parameters.

    X and c are concatenated along axis 1, passed through one ReLU hidden
    layer (module-level Q_W1 / Q_b1), and then through two separate linear
    heads: Q_W2_mu / Q_b2_mu for the mean and Q_W2_sigma / Q_b2_sigma for the
    log-variance.

    Args:
        X: input tensor; assumed (batch, X_dim) - confirm against the placeholders.
        c: condition tensor; assumed (batch, y_dim) - confirm against the placeholders.

    Returns:
        (z_mu, z_logvar) tuple of tensors.
    """
    joint = tf.concat([X, c], 1)
    hidden = tf.nn.relu(tf.matmul(joint, Q_W1) + Q_b1)
    mean_head = tf.matmul(hidden, Q_W2_mu) + Q_b2_mu
    logvar_head = tf.matmul(hidden, Q_W2_sigma) + Q_b2_sigma
    return mean_head, logvar_head

def P(z, c):
    """Decoder P(X | z, c): map a latent code and its condition to pixel outputs.

    z and c are concatenated along axis 1, passed through one ReLU hidden
    layer (module-level P_W1 / P_b1) and a linear output layer
    (P_W2 / P_b2).

    Args:
        z: latent tensor; assumed (batch, z_dim) - confirm against the placeholders.
        c: condition tensor; assumed (batch, y_dim) - confirm against the placeholders.

    Returns:
        (prob, logits): the sigmoid of the output layer and the raw
        pre-sigmoid values, in that order.
    """
    joint = tf.concat([z, c], 1)
    hidden = tf.nn.relu(tf.matmul(joint, P_W1) + P_b1)
    raw_output = tf.matmul(hidden, P_W2) + P_b2
    return tf.nn.sigmoid(raw_output), raw_output

def sample_z(mu, log_var):
    """Draw z = mu + exp(log_var / 2) * eps with eps from tf.random_normal.

    Args:
        mu: tensor of means.
        log_var: tensor of log-variances, combined elementwise with `mu`.

    Returns:
        A tensor built from `mu`, `log_var` and freshly drawn noise of the
        same shape as `mu`.
    """
    # Noise op is created first, matching the shape of `mu`.
    noise = tf.random_normal(shape=tf.shape(mu))
    # exp(log_var / 2) turns a log-variance into a standard deviation.
    std_dev = tf.exp(log_var / 2)
    return mu + std_dev * noise


In [None]:
# Start from an empty default graph so that re-running this cell does not stack
# a second copy of the placeholders and variables on top of the first.
tf.reset_default_graph()

# Everything below is pinned to the GPU with index 3.
# NOTE(review): presumably this fails on hosts with fewer than four GPUs unless
# soft placement is enabled on the session - confirm for the target machine.
# The statement order in this block is significant: tf.Variable is called
# without explicit names, so reordering would change the auto-generated names.
with tf.device('/device:GPU:3'):

# =============================== Q(z|X) ======================================

    # Graph inputs: flattened images, one-hot condition, and a latent code that
    # is fed directly when sampling from the decoder.
    X = tf.placeholder(tf.float32, shape=[None, X_dim])
    c = tf.placeholder(tf.float32, shape=[None, y_dim])
    z = tf.placeholder(tf.float32, shape=[None, z_dim])

    # Encoder hidden layer; its input is the image concatenated with the condition.
    Q_W1 = tf.Variable(xavier_init([X_dim + y_dim, h_dim]))
    Q_b1 = tf.Variable(tf.zeros(shape=[h_dim]))

    # Encoder head producing the latent mean.
    Q_W2_mu = tf.Variable(xavier_init([h_dim, z_dim]))
    Q_b2_mu = tf.Variable(tf.zeros(shape=[z_dim]))

    # Encoder head that Q() uses for z_logvar (a log-variance, despite the
    # "sigma" in the variable names).
    Q_W2_sigma = tf.Variable(xavier_init([h_dim, z_dim]))
    Q_b2_sigma = tf.Variable(tf.zeros(shape=[z_dim]))


    # =============================== P(X|z) ======================================

    # Decoder hidden layer; its input is the latent code concatenated with the condition.
    P_W1 = tf.Variable(xavier_init([z_dim + y_dim, h_dim]))
    P_b1 = tf.Variable(tf.zeros(shape=[h_dim]))

    # Decoder output layer: one logit per pixel of the flattened image.
    P_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
    P_b2 = tf.Variable(tf.zeros(shape=[X_dim]))

In [None]:
# =============================== TRAINING ====================================

with tf.device('/device:GPU:3'):
    # Encode the batch, draw a latent sample, and decode it back to pixel logits.
    z_mu, z_logvar = Q(X, c)
    z_sample = sample_z(z_mu, z_logvar)
    _, logits = P(z_sample, c)

    # Sampling from random z (fed through the `z` placeholder)
    X_samples, _ = P(z, c)

    # Reconstruction term: per-example sigmoid cross-entropy summed over pixels,
    # i.e. the negative of E[log P(X|z)].
    recon_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=X), 1)
    # D_KL(Q(z|X) || P(z)) against a standard-normal prior; closed form because
    # both distributions are Gaussian.
    kl_loss = 0.5 * tf.reduce_sum(tf.exp(z_logvar) + z_mu**2 - 1. - z_logvar, 1)
    # VAE loss, averaged over the minibatch
    vae_loss = tf.reduce_mean(recon_loss + kl_loss)

    # NOTE(review): the config cell defines `lr`, but it is not passed here, so
    # Adam runs with its default learning rate. Left unchanged because switching
    # would alter training dynamics - confirm which rate is intended.
    solver = tf.train.AdamOptimizer().minimize(vae_loss)

saver = tf.train.Saver()

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Both output locations must exist before the first save: 'out/' receives the
# sample grids and 'checkpoints/' receives the model checkpoint.
for out_dir in ('out/', 'checkpoints/'):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

i = 0    # running index of the saved sample image
cur = 0  # end offset of the most recent minibatch within the input arrays

for it in range(1000000):
    # Walk through the data in consecutive slices of mb_size. Wrap to the start
    # only when a full minibatch no longer fits, so the final full minibatch is
    # used too and a size that is not a multiple of mb_size cannot produce
    # empty or misaligned slices.
    if cur + mb_size > input_size:
        cur = 0
    X_mb = input_images[cur:cur + mb_size]
    y_mb = input_labels[cur:cur + mb_size]
    cur += mb_size

    _, loss = sess.run([solver, vae_loss], feed_dict={X: X_mb, c: y_mb})

    if it % 100 == 0:
        # Condition all 16 preview samples on one randomly chosen label.
        # NOTE(review): only label indices 0..3 are ever drawn although the
        # label vector has y_dim entries - confirm whether that is intended.
        y = np.zeros(shape=[16, y_dim])
        idx = np.random.randint(0, 4)
        y[:, idx] = 1.

        print('Iter: {}'.format(it) + '   Loss: {:.4}'.format(loss) + ',   Output ' + str(i) + ': ' + str(idx))

        samples = sess.run(X_samples,
                           feed_dict={z: np.random.randn(16, z_dim), c: y})

        # Overwrites the same checkpoint prefix every time.
        saver.save(sess, "checkpoints/model.ckpt")

        fig = plot(samples)
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        # Release the figure so figures do not accumulate over the long run.
        plt.close(fig)