In [None]:
import os
import json
import numpy as np
import scipy.misc
import tensorflow as tf
from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

In [None]:
# Expose only CUDA device 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# --- data ---
input_size = 50000  # number of image files loaded from disk
X_dim = 128         # side length of the (square, 3-channel) images
y_dim = 4           # length of the one-hot label vector

# --- model ---
z_dim = 400         # size of the latent vector
h_dim = 512         # width of the dense hidden layers

# --- training ---
mb_size = 20        # source images consumed per training step

In [None]:
# Load files "0.jpg" .. "<input_size - 1>.jpg" into one float32 array and
# build a one-hot label for every image.
image_dir = "S:\\Projects\\MVC\\128x128_clean\\"

# Preallocate the final float32 array instead of collecting a Python list and
# converting it afterwards; this avoids holding a list, a uint8 array and a
# float32 copy of the whole data set at the same time.
# Every image must decode to (X_dim, X_dim, 3); any other shape raises a
# ValueError on assignment.
input_images = np.empty((input_size, X_dim, X_dim, 3), dtype="float32")
for i in range(input_size):
    # The context manager releases the file handle once the pixels are copied.
    with Image.open(image_dir + str(i) + ".jpg") as img:
        input_images[i] = np.asarray(img)
# Divide the decoded pixel values by 255 in place.
input_images /= 255.0

# Image i gets class i % y_dim as a one-hot row (float64, shape
# [input_size, y_dim]); y_dim replaces the previously hard-coded 4.
input_labels = np.eye(y_dim)[np.arange(input_size) % y_dim]

In [None]:
def plot(samples):
    """Draw each entry of `samples` in its own cell of a 4x4 grid.

    Args:
        samples: iterable of images accepted by `plt.imshow`; entry k is drawn
            in grid cell k, so at most 16 entries fit.

    Returns:
        The 4x4-inch matplotlib figure holding the tiles. The caller is
        responsible for saving and closing it.
    """
    figure = plt.figure(figsize=(4, 4))
    layout = gridspec.GridSpec(4, 4)
    layout.update(wspace=0.05, hspace=0.05)

    for cell, image in enumerate(samples):
        axes = plt.subplot(layout[cell])
        # Strip all decoration so only the picture itself is visible.
        axes.set_yticklabels([])
        axes.set_xticklabels([])
        axes.axis('off')
        axes.set_aspect('equal')
        plt.imshow(image)

    return figure

def bn(inputs, training=False):
    """Apply `tf.layers.batch_normalization` to `inputs`.

    Args:
        inputs: tensor to normalise.
        training: forwarded to `tf.layers.batch_normalization`. With the
            default `False` (the previous hard-coded behaviour) the layer
            normalises with its moving mean/variance instead of the statistics
            of the current batch. When `True` is passed, the ops collected in
            `tf.GraphKeys.UPDATE_OPS` must be run together with the train op so
            that the moving averages are actually updated.

    Returns:
        The normalised tensor, same shape as `inputs`.
    """
    return tf.layers.batch_normalization(inputs, training=training)

def conv(inputs, filters, kernel_size, strides, activation=tf.nn.relu):
    """Build a 2-D convolution layer that always uses "same" padding.

    Args:
        inputs: input tensor of the convolution.
        filters: number of output channels.
        kernel_size: size of the convolution window.
        strides: stride of the convolution.
        activation: activation applied to the output (ReLU by default; pass
            `None` for a linear layer).

    Returns:
        The output tensor of `tf.layers.conv2d`.
    """
    return tf.layers.conv2d(
        inputs=inputs,
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding="same",
        activation=activation,
    )

In [None]:
# Drop whatever graph a previous run of this cell left behind.
tf.reset_default_graph()

# Graph inputs (the leading None is the batch dimension):
#   X - image fed to the encoder
#   Y - image the reconstruction loss compares the decoder output against
#   c - label vector concatenated to both the encoder and the decoder input
#   z - latent vector fed straight into the decoder
X = tf.placeholder(dtype=tf.float32, shape=(None, X_dim, X_dim, 3))
Y = tf.placeholder(dtype=tf.float32, shape=(None, X_dim, X_dim, 3))
c = tf.placeholder(dtype=tf.float32, shape=(None, y_dim))
z = tf.placeholder(dtype=tf.float32, shape=(None, z_dim))

def Q(X, c):
    """Encoder: map an image and a label vector to Gaussian latent parameters.

    The image goes through six conv + batch-norm layers (two of them with
    stride 2, so the spatial size shrinks to X_dim / 4 per side), is flattened,
    concatenated with `c` and passed through one dense hidden layer.

    Args:
        X: image batch, [batch, X_dim, X_dim, 3].
        c: label batch, [batch, y_dim].

    Returns:
        (z_mu, z_logvar): mean and log-variance of the latent Gaussian, each
        [batch, z_dim].
    """
    side = X_dim // 4

    net = X
    # (filters, strides) for the 3x3 feature-extraction convolutions.
    for filters, strides in [(16, 1), (16, 2), (32, 1), (32, 2), (64, 1)]:
        net = bn(conv(net, filters=filters, kernel_size=3, strides=strides))
    # A 1x1 convolution collapses the channels to one so the map can be flattened.
    net = bn(conv(net, filters=1, kernel_size=1, strides=1))
    X_ = tf.reshape(net, [-1, side * side])

    inputs = tf.concat(axis=1, values=[X_, c])
    h = bn(tf.layers.dense(inputs, h_dim, activation=tf.nn.relu))
    # Both heads are linear. A ReLU here would force every mean to be >= 0 and
    # every log-variance to be >= 0 (i.e. variance >= 1), so the posterior could
    # never use negative means or become narrower than the unit-variance prior.
    z_mu = tf.layers.dense(h, z_dim, activation=None)
    z_logvar = tf.layers.dense(h, z_dim, activation=None)
    return z_mu, z_logvar

def P(z, c):
    """Decoder: turn a latent vector and a label vector into an image.

    All variables live in the "P" scope with AUTO_REUSE, so every call of this
    function shares one set of weights.

    Args:
        z: latent batch, [batch, z_dim].
        c: label batch, [batch, y_dim].

    Returns:
        (prob, big_img): the sigmoid of the output and the raw output logits,
        each [batch, X_dim, X_dim, 3].
    """
    quarter = X_dim // 4
    half = X_dim // 2
    latent_and_label = tf.concat(axis=1, values=[z, c])

    with tf.variable_scope("P", reuse=tf.AUTO_REUSE):
        # Dense stem producing a single-channel quarter-resolution map.
        hidden = bn(tf.layers.dense(latent_and_label, h_dim, activation=tf.nn.relu))
        flat_map = bn(tf.layers.dense(hidden, quarter * quarter))
        net = tf.reshape(flat_map, [-1, quarter, quarter, 1])

        # Alternate convolutions with image resizing: quarter -> half -> full size.
        net = bn(conv(net, filters=64, kernel_size=3, strides=1))
        net = tf.image.resize_images(net, [half, half])
        net = bn(conv(net, filters=32, kernel_size=3, strides=1))
        net = tf.image.resize_images(net, [X_dim, X_dim])
        net = bn(conv(net, filters=32, kernel_size=3, strides=1))
        net = conv(net, filters=16, kernel_size=3, strides=1)

        # Final layer is linear: its output is used as logits.
        net = conv(net, filters=3, kernel_size=3, strides=1, activation=None)
        out_logits = tf.reshape(net, [-1, X_dim, X_dim, 3])

    return tf.nn.sigmoid(out_logits), out_logits

def sample_z(mu, log_var):
    """Draw a latent sample as `mu + sigma * eps` with `eps ~ N(0, I)`.

    Args:
        mu: mean tensor.
        log_var: log-variance tensor, same shape as `mu`.

    Returns:
        A tensor shaped like `mu`.
    """
    noise = tf.random_normal(shape=tf.shape(mu))
    # exp(log_var / 2) is the standard deviation.
    sigma = tf.exp(log_var / 2)
    return mu + sigma * noise


# =============================== TRAINING ====================================

# Encode (X, c), draw a latent sample and decode it to output logits.
z_mu, z_logvar = Q(X, c)
z_sample = sample_z(z_mu, z_logvar)
_, logits = P(z_sample, c)

# Decoder fed from the `z` placeholder; shares the "P" variables with the call above.
X_samples, _ = P(z, c)

# NOTE(review): Q and P build their bn() layers without a training flag, i.e.
# with tf.layers.batch_normalization's default training=False -- confirm that
# normalising with the (never updated) moving statistics is intended.

# Reconstruction term: sigmoid cross-entropy between the logits and the target
# image Y, summed over height, width and channels.
recon_loss = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=Y), [1, 2, 3])
# D_KL(Q(z|X,c) || N(0, I)): KL divergence between the encoder's Gaussian and
# the standard-normal prior, in closed form.
kl_loss = 0.5 * tf.reduce_sum(tf.exp(z_logvar) + z_mu**2 - 1. - z_logvar, 1)
# VAE loss, averaged over the batch
vae_loss = tf.reduce_mean(recon_loss + kl_loss)

solver = tf.train.AdamOptimizer(learning_rate=0.0002).minimize(vae_loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())


def _make_batch(start):
    """Build one training batch from `input_images[start:start + mb_size]`.

    The slice is treated as consecutive groups of `y_dim` images. Inside each
    group every image is paired, as encoder input, with every image of the
    same group as target, giving `y_dim * mb_size` rows.

    Returns:
        (X_mb, Y_mb, y_mb): encoder inputs, target images and the labels of
        the target images.
    """
    stop = start + mb_size
    images = input_images[start:stop]
    tail = images.shape[1:]
    # [groups, y_dim, H, W, C] -> repeat each group y_dim times along axis 1.
    grouped = images.reshape((mb_size // y_dim, y_dim) + tail)
    X_mb = np.tile(grouped, (1, y_dim, 1, 1, 1)).reshape((y_dim * mb_size,) + tail)
    Y_mb = np.repeat(images, y_dim, axis=0)
    y_mb = np.repeat(input_labels[start:stop], y_dim, axis=0)
    return X_mb, Y_mb, y_mb


if mb_size % y_dim != 0 or mb_size > input_size:
    raise ValueError("mb_size must be a multiple of y_dim and at most input_size")

# savefig does not create missing directories.
os.makedirs('out_5', exist_ok=True)

i = 0      # running index of the saved sample sheets
cur = 0    # start of the next mini-batch inside input_images

test_n = 16
test_it = 0

for it in range(100000):
    # Wrap around before a slice would run past the data. Every full
    # mini-batch, including the last one, is visited once per pass, and a
    # trailing partial batch is skipped.
    if cur + mb_size > input_size:
        cur = 0
    X_mb, Y_mb, y_mb = _make_batch(cur)
    cur += mb_size

    _, loss = sess.run([solver, vae_loss], feed_dict={X: X_mb, Y: Y_mb, c: y_mb})

    if it % 1000 == 0 and it != 0:
        # Convert the next test_n images to one randomly chosen label.
        y = np.zeros(shape=[test_n, y_dim])
        view = np.random.randint(0, y_dim)
        y[:, view] = 1.

        test_images = input_images[test_it * test_n:(test_it + 1) * test_n]
        out_z = sess.run(z_sample, feed_dict={X: test_images, c: y})
        samples = sess.run(X_samples, feed_dict={z: out_z, c: y})

        print('Iter: {}'.format(it) + ", View: " + str(view))
        print('Loss: {:.4}'.format(loss))
        print(np.min(samples), np.max(samples))
        print()

        test_it += 1

        fig = plot(samples)
        plt.savefig('out_5/{}.jpg'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        # Close the figure so figures do not accumulate over the run.
        plt.close(fig)

In [None]:
# Predict: for each of the first 10000 images, generate one output image per
# label and write it to "<index>_<label>.jpg".

# saver = tf.train.Saver()
# saver.save(sess, "checkpoints/model.ckpt")
# saver.restore(sess, "checkpoints/model.ckpt")

predict_dir = "S:\\Projects\\MVC\\128x128_vae\\"
# Image.save does not create missing directories.
os.makedirs(predict_dir, exist_ok=True)

# Row j is the one-hot vector of label j; it is the same for every image.
y = np.eye(y_dim)

for i in range(10000):
    # Feed the same source image once per requested label.
    source = np.repeat([input_images[i]], y_dim, axis=0)
    out_z = sess.run(z_sample, feed_dict={X: source, c: y})
    samples = sess.run(X_samples, feed_dict={z: out_z, c: y})

    for j in range(y_dim):
        # scipy.misc.imsave is deprecated and was removed in SciPy 1.2; it
        # also stretched every image to its own min/max before saving. The
        # sigmoid outputs are mapped from [0, 1] to [0, 255] directly instead.
        pixels = np.clip(np.rint(samples[j] * 255.0), 0, 255).astype("uint8")
        Image.fromarray(pixels).save(predict_dir + str(i) + "_" + str(j) + ".jpg")