In [None]:
import os
import json
import numpy as np
import scipy.misc
import tensorflow as tf
from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

In [None]:
# Load the MVC dataset metadata. A `with` block closes the file handle, and the
# encoding is explicit because JSON is UTF-8 while the Windows default codec is not.
# NOTE(review): `mvc_info` is not used by the cells visible here — confirm it is still needed.
with open(r"S:\Projects\MVC\mvc_info.json", encoding="utf-8") as mvc_info_file:
    mvc_info = json.load(mvc_info_file)

In [None]:
# Expose only the first GPU. This takes effect only if it is set before
# TensorFlow initializes CUDA, so it comes first in this cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Hyper-parameters for the MVC data.
mb_size = 20          # minibatch size
Z_dim = 100           # length of the generator's noise vector
X_dim = 64            # presumably the side length of the 64x64 images — TODO confirm
y_dim = 4             # length of each one-hot condition label
h_dim = 128           # hidden units in each hidden dense layer
input_size = 10000    # number of image files read from disk

In [None]:
# Load the first `input_size` MVC images (files named 0.jpg, 1.jpg, ...) and
# rescale pixel values from 0-255 to [0, 1] as float32.
image_dir = r"S:\Projects\MVC\64x64_clean"
image_arrays = []
for idx in range(input_size):
    # `with` closes each file instead of leaving the handle to PIL/GC.
    with Image.open(os.path.join(image_dir, "{}.jpg".format(idx))) as img:
        image_arrays.append(np.array(img))
# np.stack fails loudly if any image has a different shape (e.g. a grayscale JPEG
# among RGB ones), rather than producing an object array.
input_images = np.stack(image_arrays).astype("float32") / 255.0

# One-hot labels that cycle through 0..y_dim-1 by file index.
# The modulus is y_dim (previously a hard-coded 4), so the index always fits the vector.
# NOTE(review): these labels are not read from mvc_info — confirm placeholder labels are intended.
input_labels = np.eye(y_dim)[np.arange(input_size) % y_dim]

assert input_images.shape[0] == input_labels.shape[0] == input_size

In [None]:
def plot(samples, image_shape=None):
    """Draw samples as a square grid of grayscale images.

    Args:
        samples: sequence/array of images, one per grid cell. Each sample is
            reshaped to `image_shape` before drawing.
        image_shape: shape passed to reshape for every sample. When None, a
            square (side, side) is inferred from the sample size. 784-value
            samples therefore become 28x28 as before, and 4096-value samples
            become 64x64.

    Returns:
        The matplotlib Figure. The caller saves and closes it.

    Raises:
        ValueError: if `image_shape` is None and the sample size is not a
            perfect square.
    """
    samples = np.asarray(samples)
    n_samples = len(samples)
    # Smallest square grid holding every sample; 16 samples -> 4x4, as before.
    grid_side = max(1, int(np.ceil(np.sqrt(n_samples))))

    if image_shape is None and n_samples:
        n_values = samples[0].size
        side = int(round(np.sqrt(n_values)))
        if side * side != n_values:
            raise ValueError(
                "cannot infer a square image from {} values; pass image_shape".format(n_values))
        image_shape = (side, side)

    fig = plt.figure(figsize=(grid_side, grid_side))
    gs = gridspec.GridSpec(grid_side, grid_side)
    gs.update(wspace=0.05, hspace=0.05)

    for idx, sample in enumerate(samples):
        ax = fig.add_subplot(gs[idx])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(np.reshape(sample, image_shape), cmap='Greys_r')

    return fig

def bn(inputs, training=False):
    """Apply batch normalization via tf.layers.batch_normalization.

    Args:
        inputs: tensor to normalize.
        training: forwarded to tf.layers.batch_normalization.
            - With the default False, the layer normalizes with its moving
              mean/variance rather than the batch statistics. Those moving
              statistics are only updated when training=True.
            - When training, pass True (or a boolean tensor).
            - Also run the ops in tf.GraphKeys.UPDATE_OPS alongside the
              optimizer step.
            The default keeps the previous behavior for existing callers.

    Returns:
        The normalized tensor.
    """
    return tf.layers.batch_normalization(inputs, training=training)

def conv(inputs, filters, kernel_size, strides, activation=tf.nn.relu):
    """Build a 2-D convolution with "same" padding.

    Args:
        inputs: input tensor for tf.layers.conv2d.
        filters: number of output filters.
        kernel_size: convolution window size.
        strides: convolution stride(s).
        activation: output activation. ReLU by default; pass None for a
            linear layer.

    Returns:
        The convolution output tensor.
    """
    return tf.layers.conv2d(
        inputs=inputs,
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding="same",
        activation=activation,
    )

def dense(inputs, units, activation=tf.nn.relu, stddev=1.0):
    """Fully connected layer whose kernel uses a truncated-normal initializer.

    Args:
        inputs: input tensor for tf.layers.dense.
        units: output width.
        activation: output activation. ReLU by default; pass None for a
            linear/logit layer.
        stddev: standard deviation of the truncated-normal kernel initializer.
            Defaults to 1.0, the value the bare
            `tf.truncated_normal_initializer()` used before.

    Returns:
        The layer's output tensor, with last dimension `units`.
    """
    # NOTE(review): stddev=1.0 is large for layers with hundreds of inputs.
    # It can saturate the sigmoid outputs of D and G; GAN code commonly uses
    # ~0.02. The default is kept so existing graphs are unchanged.
    kernel_init = tf.truncated_normal_initializer(stddev=stddev)
    return tf.layers.dense(inputs, units, activation=activation, kernel_initializer=kernel_init)

In [None]:
# Switch the experiment to MNIST. This cell overrides the MVC hyper-parameters
# set earlier, and the model/training cell reads `mnist` directly.
# The load was previously commented out, which made `mnist.train` below raise
# NameError on a fresh kernel.
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 64
Z_dim = 100
X_dim = mnist.train.images.shape[1]  # flattened image width (784 for standard MNIST)
y_dim = mnist.train.labels.shape[1]  # one-hot label width (one_hot=True above)
h_dim = 128

In [None]:
# Start from an empty graph so re-running this cell does not duplicate ops/variables.
tf.reset_default_graph()

# Discriminator inputs: flattened real images and their one-hot condition labels.
# The width is X_dim (previously a hard-coded 784). The generator output also has
# width X_dim, and both feed the same shared "D" layers, so the widths must match.
X = tf.placeholder(tf.float32, shape=[None, X_dim])
y = tf.placeholder(tf.float32, shape=[None, y_dim])

def discriminator(x, y, reuse):
    """Conditional discriminator: score (image, label) pairs.

    Args:
        x: batch of flattened images.
        y: condition labels, concatenated onto `x` along axis 1.
        reuse: accepted but ignored. The "D" scope always uses tf.AUTO_REUSE,
            so the first call creates the variables and later calls share them.

    Returns:
        (D_prob, D_logit): sigmoid of the logit, and the raw single-unit logit.
    """
    joint = tf.concat(values=[x, y], axis=1)
    with tf.variable_scope("D", reuse=tf.AUTO_REUSE):
        # The inner (hidden) layer is built first, then the 1-unit logit layer.
        logit = dense(dense(joint, h_dim), 1, activation=None)
    return tf.nn.sigmoid(logit), logit


# Generator input: a batch of noise vectors of length Z_dim.
Z = tf.placeholder(dtype=tf.float32, shape=(None, Z_dim))


def generator(z, y):
    """Conditional generator: map noise plus labels to flattened images.

    Args:
        z: noise batch, concatenated with `y` along axis 1.
        y: condition labels.

    Returns:
        Sigmoid outputs of width X_dim, i.e. values in (0, 1).
    """
    inputs = tf.concat(axis=1, values=[z, y])

    # AUTO_REUSE (matching the discriminator) lets the generator be called again,
    # e.g. to build a separate sampling graph. Without it, a second call raises a
    # "variable already exists" error. The first call still creates the variables.
    with tf.variable_scope("G", reuse=tf.AUTO_REUSE):
        G_h1 = dense(inputs, h_dim)
        G_log_prob = dense(G_h1, X_dim, activation=None)
    G_prob = tf.nn.sigmoid(G_log_prob)

    return G_prob


def sample_Z(m, n):
    """Draw an (m, n) array of noise uniformly from [-1, 1).

    Uses NumPy's global RNG, so results follow np.random.seed.
    """
    shape = (m, n)
    return np.random.uniform(low=-1.0, high=1.0, size=shape)


# Wire the conditional GAN. G maps (Z, y) to images; D scores real (X, y) and
# generated (G_sample, y) pairs with the same "D" variables.
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y, False)
D_fake, D_logit_fake = discriminator(G_sample, y, True)

# Losses on logits: D labels real as 1 and fake as 0; G wants fakes labelled 1.
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

# Each optimizer updates only its own network's variables (the "D/" and "G/" scopes).
# The list is named `train_vars` so it no longer shadows the builtin `vars`.
train_vars = tf.trainable_variables()
d_params = [v for v in train_vars if v.name.startswith('D/')]
g_params = [v for v in train_vars if v.name.startswith('G/')]
assert d_params and g_params, "expected trainable variables under the 'D/' and 'G/' scopes"

D_solver = tf.train.AdamOptimizer(learning_rate=0.001).minimize(D_loss, var_list=d_params)
G_solver = tf.train.AdamOptimizer(learning_rate=0.001).minimize(G_loss, var_list=g_params)

# The session is left open on purpose so later cells can keep using it.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

train_writer = tf.summary.FileWriter('tb', sess.graph)

n_iters = 1000000
log_every = 1000       # iterations between loss printouts / preview images
n_sample = 16          # preview grid size (4x4 in plot())
sample_label = 7       # class index used to condition the preview samples
assert sample_label < y_dim, "sample_label must be a valid class index for y_dim"
out_dir = 'out'
os.makedirs(out_dir, exist_ok=True)  # savefig fails if the directory is missing

i = 0  # index of the next preview image file
try:
    for it in range(n_iters):
        X_mb, y_mb = mnist.train.next_batch(mb_size)

        # The same noise batch is used for the D step and the G step.
        Z_sample = sample_Z(mb_size, Z_dim)
        _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: Z_sample, y: y_mb})
        _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: Z_sample, y: y_mb})

        if it % log_every == 0:
            Z_sample = sample_Z(n_sample, Z_dim)
            y_sample = np.zeros(shape=[n_sample, y_dim])
            y_sample[:, sample_label] = 1

            samples = sess.run(G_sample, feed_dict={Z: Z_sample, y: y_sample})

            fig = plot(samples)
            fig.savefig(os.path.join(out_dir, '{}.png'.format(str(i).zfill(3))), bbox_inches='tight')
            plt.close(fig)  # avoid accumulating open figures over a long run
            i += 1

            print('Iter: {}'.format(it))
            print('D loss: {:.4}'.format(D_loss_curr))
            print('G_loss: {:.4}'.format(G_loss_curr))
            print()
finally:
    # Flush and close the event file even if the loop is interrupted.
    train_writer.close()