In [None]:
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.backend import *
from tensorflow.keras.optimizers import * 
import matplotlib.pyplot as plt
import numpy as np
np.set_printoptions(threshold=np.inf)

In [None]:
# flags
img_size = 28  # MNIST image width/height in pixels

seed_size = 100  # length of the random latent vector fed to the generator
kernel_size = 5  # square convolution kernel size used by both networks

batch_size = 100
epochs = 2
learning_rate = 0.001

# Fix RNG seeds so Restart & Run All reproduces the same run: numpy drives the
# latent noise / batching, the TF graph-level seed applies to random ops
# (e.g. weight initializers, dropout) created in the default graph afterwards.
rng_seed = 42
np.random.seed(rng_seed)
tf.set_random_seed(rng_seed)

In [None]:
# prepare data
mnist = tf.keras.datasets.mnist

# Only the training images are used; labels and the test split are discarded.
(x_train, _), (_, _) = mnist.load_data()

# Add a channel axis -> (N, img_size, img_size, 1), convert to float32 and
# centre the pixels: (x - 127.5) / 255 maps [0, 255] onto [-0.5, 0.5].
x_train = (x_train.reshape(x_train.shape[0], img_size, img_size, 1).astype('float32') - 127.5) / 255.

# Number of minibatches per epoch as an integer; rounded up so a short final
# batch (when len(x_train) is not a multiple of batch_size) is counted too.
total_batch = int(np.ceil(len(x_train) / batch_size))

def next_batch(data=None, size=None, shuffle=False):
    """Yield consecutive minibatches along axis 0.

    Args:
        data: array to batch; defaults to the global ``x_train`` (read at call time).
        size: rows per batch; defaults to the global ``batch_size``.
        shuffle: if True, rows are visited in a fresh random order on every
            call (i.e. every epoch). The default False keeps the fixed order.

    Yields:
        Arrays of ``size`` rows; the last one is shorter when ``len(data)`` is
        not a multiple of ``size``.
    """
    if data is None:
        data = x_train
    if size is None:
        size = batch_size
    if shuffle:
        # Permute indices rather than the array itself so only one batch is
        # copied at a time instead of the whole dataset.
        order = np.random.permutation(len(data))
        for start in range(0, len(data), size):
            yield data[order[start:start + size]]
    else:
        for start in range(0, len(data), size):
            yield data[start:start + size]
        
def next_sample(batch_size):
    """Draw ``batch_size`` latent vectors for the generator.

    Returns a float32 array of shape (batch_size, seed_size) whose entries are
    sampled from a normal distribution with mean 0 and standard deviation 1.
    """
    noise = np.random.normal(loc=0, scale=1.0, size=(batch_size, seed_size))
    return noise.astype(np.float32)


In [None]:
# test next_batch and next_sample

# bs = [i for i in next_batch()][0]
# for i in bs[0:10]:
#     plt.figure()
#     plt.imshow(i[:,:,0]*255+127.5)
# plt.show()

# print(next_sample(10))

In [None]:
# Spatial size of the first feature map: two stride-2 transposed convs below
# upsample by 4x, so img_size must be divisible by 4 (28 -> 7).
assert img_size % 4 == 0, "img_size must be divisible by 4"
gen_base_size = img_size // 4
# Feature maps in the first (gen_base_size x gen_base_size) tensor. Kept
# independent of batch_size so changing the training batch size does not
# change the network architecture.
gen_base_channels = 100

# Generator: latent vector (seed_size,) -> dense projection -> reshaped feature
# map -> transposed convs (stride 1, 2, 2) -> single-channel img_size x img_size image.
generator = tf.keras.Sequential([
    Dense(gen_base_size * gen_base_size * gen_base_channels, use_bias=False, input_shape=(seed_size,)),
    BatchNormalization(), LeakyReLU(),
    Reshape((gen_base_size, gen_base_size, gen_base_channels)),
    Conv2DTranspose(filters=128,
                    kernel_size=(kernel_size, kernel_size),
                    strides=1,
                    padding='same',
                    use_bias=False),
    BatchNormalization(), LeakyReLU(),
    Conv2DTranspose(filters=64,
                    kernel_size=(kernel_size, kernel_size),
                    strides=2,
                    padding='same',
                    use_bias=False),
    BatchNormalization(), LeakyReLU(),
    # NOTE(review): no output activation, so generated pixels are unbounded,
    # while the real training images are scaled to [-0.5, 0.5].
    Conv2DTranspose(filters=1,
                    kernel_size=(kernel_size, kernel_size),
                    strides=2,
                    padding='same',
                    use_bias=False),
])


In [None]:
# test image generation

# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
    
#     z = next_sample(1)
#     image = generator(z, training=False).eval(session=sess)    
#     plt.figure()
#     plt.imshow(image[0,:,:,0], cmap='gray')

In [None]:
# Classify an image as real (1) or generated (0). The final Dense(1) has no
# activation, so the output is a single logit per image; the losses below
# apply the sigmoid via from_logits=True.
discriminator = tf.keras.Sequential([
    Conv2D(filters=64,
           kernel_size=(kernel_size, kernel_size),
           strides=2,
           padding='same',
           input_shape=(img_size, img_size, 1)),
    LeakyReLU(),
    Dropout(0.3),
    Conv2D(filters=128,
           kernel_size=(kernel_size, kernel_size),
           strides=2,
           padding='same'),
    LeakyReLU(),
    Dropout(0.3),
    Flatten(),
    Dense(1)
])

def generator_loss(fake):
    """Generator loss: push the discriminator to label generated images real.

    Args:
        fake: discriminator logits for generated images.

    Returns:
        Per-example sigmoid cross-entropy against an all-ones target (not
        reduced to a scalar here).
    """
    all_real = tf.ones_like(fake)
    return binary_crossentropy(all_real, fake, from_logits=True)

def discriminator_loss(fake, real):
    """Discriminator loss: label generated images 0 and real images 1.

    Args:
        fake: discriminator logits for generated images.
        real: discriminator logits for real images.

    Returns:
        Element-wise sum of the two per-example sigmoid cross-entropies, so
        ``fake`` and ``real`` must have matching (broadcastable) shapes.
    """
    fake_term = binary_crossentropy(tf.zeros_like(fake), fake, from_logits=True)
    real_term = binary_crossentropy(tf.ones_like(real), real, from_logits=True)
    return fake_term + real_term


In [None]:
# test g and d
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
    
#     samples = tf.placeholder(tf.float32,[None,100])
#     generated = generator(samples)
#     fake = discriminator(generated)
    
#     images = tf.placeholder(tf.float32,[None,28,28,1])
#     real = discriminator(images) 
    
#     g_loss = generator_loss(fake)
#     d_loss = discriminator_loss(fake,real)
    
#     z = next_sample(10)
#     f,r,g,d = sess.run([fake,real, g_loss, d_loss], feed_dict={samples:z, images:x_train[0:10]})
#     print(f,'\n')
#     print(r,'\n')
#     print(g,'\n')
#     print(d,'\n')


In [None]:
# model
# Build the TF1 graph. training=True is passed explicitly: in a raw tf.Session
# the Keras learning phase defaults to inference, which would silently disable
# Dropout in the discriminator and make BatchNormalization use its moving
# statistics (never updated here) instead of per-batch statistics.
samples = tf.placeholder(tf.float32, [None, seed_size])
generated = generator(samples, training=True)
fake = discriminator(generated, training=True)

images = tf.placeholder(tf.float32, [None, img_size, img_size, 1])
real = discriminator(images, training=True)

g_loss = generator_loss(fake)
d_loss = discriminator_loss(fake, real)

g_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=generator.trainable_variables)
# NOTE(review): the discriminator uses Adagrad while the generator uses Adam;
# confirm this asymmetry is intentional.
d_opt = tf.train.AdagradOptimizer(learning_rate).minimize(d_loss, var_list=discriminator.trainable_variables)

# train
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)

    # Fixed latent seeds for the preview grid, drawn once before training.
    grid = 4
    preview_count = grid * grid
    test_sample = next_sample(preview_count)

    for epoch in range(epochs):
        print('epoch', epoch)

        for x in next_batch():
            # Match z to the real batch: discriminator_loss adds the fake and
            # real terms element-wise, so their batch sizes must agree.
            z = next_sample(len(x))
            _, g_l = sess.run([g_opt, g_loss], feed_dict={
                samples: z,
            })
            _, d_l = sess.run([d_opt, d_loss], feed_dict={
                samples: z,
                images: x,
            })

        # Report the mean loss of the epoch's last batch.
        print('g_loss %.4f  d_loss %.4f' % (np.mean(g_l), np.mean(d_l)))

    # Generate previews and undo the (x - 127.5) / 255 scaling used for training.
    image = sess.run(generated, feed_dict={samples: test_sample})
    image = image * 255 + 127.5
    plt.figure(figsize=(10, 10))
    for i in range(preview_count):
        plt.subplot(grid, grid, i + 1)
        plt.imshow(image[i, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.show()