In [None]:
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.backend import *
from tensorflow.keras.optimizers import * 
import matplotlib.pyplot as plt

In [None]:
# flags
# --- image / model geometry ---
img_size = 28     # width == height of an MNIST digit, in pixels
seed_size = 100   # length of the random noise vector fed to the generator
kernel_size = 5   # side of every square (transposed) convolution kernel

# --- optimisation schedule ---
batch_size, epochs, learning_rate = 100, 1, 0.001

In [None]:
# prepare data
mnist = tf.keras.datasets.mnist

# Only the training images are used; labels and the test split are discarded.
(x_train, _), (_, _) = mnist.load_data()

# Add a channel axis and scale uint8 pixels from [0, 255] to [-1, 1] so real
# images share the range of the generator's tanh output (and so the
# `* 127.5 + 127.5` denormalisation used for display is the exact inverse).
# Dividing by 255 would give [-0.5, 0.5], a range mismatch the discriminator
# could exploit trivially.
x_train = (x_train.reshape(x_train.shape[0], img_size, img_size, 1).astype('float32') - 127.5) / 127.5


def next_batch():
    """Yield consecutive (unshuffled) slices of ``x_train`` of ``batch_size`` images.

    The final slice is shorter when ``len(x_train)`` is not a multiple of
    ``batch_size``.
    """
    for start in range(0, len(x_train), batch_size):
        yield x_train[start:start + batch_size]


# Number of batches per epoch (integer, counting a trailing partial batch).
total_batch = -(-len(x_train) // batch_size)

In [None]:
# generator
# Maps a random noise vector of length seed_size to an img_size x img_size x 1
# image whose values lie in [-1, 1] (tanh output of the last layer).
#
# LeakyReLU activations are enabled: with no nonlinearity, the Dense and
# Conv2DTranspose layers compose into a single linear map of the noise, which
# severely limits what the generator can express.
# NOTE(review): BatchNormalization is still left out. In the Session-based
# training loop its moving-statistic update ops are presumably not run
# automatically, which would leave training=False output normalised with
# untrained statistics — confirm before re-enabling.
generator = tf.keras.Sequential([
    # Project the noise to a coarse (img_size//4) x (img_size//4) x 256 feature
    # map; the two stride-2 layers below upsample it back to img_size.
    Dense((img_size // 4) * (img_size // 4) * 256,
          use_bias=False, input_shape=(seed_size,)),
    LeakyReLU(),
    Reshape((img_size // 4, img_size // 4, 256)),

    Conv2DTranspose(filters=128,
                    kernel_size=(kernel_size, kernel_size),
                    strides=1,
                    padding='same',
                    use_bias=False),
    LeakyReLU(),

    Conv2DTranspose(filters=64,
                    kernel_size=(kernel_size, kernel_size),
                    strides=2,
                    padding='same',
                    use_bias=False),
    LeakyReLU(),

    Conv2DTranspose(filters=1,
                    kernel_size=(kernel_size, kernel_size),
                    strides=2,
                    padding='same',
                    use_bias=False,
                    activation='tanh'),
])

# Layer-by-layer output shapes and parameter counts.
generator.summary()

In [None]:
# discriminator
# Scores an image as real or fake. The final Dense(1) has no activation, so the
# output is a raw logit: the losses train it towards label 1 for real images
# and label 0 for generated (fake) ones.
#
# LeakyReLU + Dropout are enabled: without them every layer is linear and the
# whole network reduces to a single linear function of the pixels.
discriminator = tf.keras.Sequential([
    Conv2D(
        filters=64,
        kernel_size=(kernel_size, kernel_size),
        strides=2,
        padding='same',
        input_shape=(img_size, img_size, 1)),
    LeakyReLU(),
    Dropout(0.3),

    Conv2D(
        filters=128,
        kernel_size=(kernel_size, kernel_size),
        strides=2,
        padding='same'),
    LeakyReLU(),
    Dropout(0.3),

    Flatten(),
    Dense(1),  # logit; no sigmoid here — the loss applies it
])

# Layer-by-layer output shapes and parameter counts.
discriminator.summary()

In [None]:
# Sanity check: draw one image from the freshly initialised (untrained)
# generator in inference mode and display it.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    img = sess.run(generator(tf.random.normal([1, seed_size]), training=False))

    plt.imshow(img[0, :, :, 0], cmap='gray')

In [None]:
# loss
# The discriminator ends in Dense(1) with no activation, so both losses work on
# raw logits. sigmoid_cross_entropy_with_logits applies the sigmoid internally
# in a numerically stable way (plain binary_crossentropy with its default
# from_logits=False would treat the logits as probabilities).
def _mean_sigmoid_xent(logits, target_fn):
    """Mean sigmoid cross-entropy of ``logits`` against a constant target.

    Args:
        logits: discriminator logits.
        target_fn: ``tf.ones_like`` or ``tf.zeros_like``; builds the labels.
    Returns:
        Scalar tensor. Reducing to a scalar lets real and fake terms be summed
        even when their batch sizes differ.
    """
    return tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=target_fn(logits), logits=logits))


def generator_loss(fake):
    """Non-saturating generator loss.

    The generator wants its samples classified as real (label 1), so the
    target for the fake logits is all ones (not zeros).

    Args:
        fake: discriminator logits for generated images.
    Returns:
        Scalar loss tensor.
    """
    return _mean_sigmoid_xent(fake, tf.ones_like)


def discriminator_loss(fake, real):
    """Discriminator loss: fakes should score 0, real images should score 1.

    Args:
        fake: discriminator logits for generated images.
        real: discriminator logits for real images.
    Returns:
        Scalar loss tensor (sum of the fake and real mean cross-entropies).
    """
    return _mean_sigmoid_xent(fake, tf.zeros_like) + _mean_sigmoid_xent(real, tf.ones_like)


# optimizer
generator_optimizer = tf.train.AdagradOptimizer(learning_rate)
discriminator_optimizer = tf.train.AdagradOptimizer(learning_rate)

In [None]:
# train
# Build the training graph ONCE, before opening the session. In graph mode
# (tf.Session) calling generator()/discriminator()/apply_gradients() only adds
# ops to the graph; an update op changes nothing until it is passed to
# sess.run(). Building those ops per batch without running them leaves the
# weights untouched and keeps growing the graph.
real_images = tf.placeholder(tf.float32, [None, img_size, img_size, 1], name='real_images')

# One noise vector per real image, so real and fake batches always match in
# size (including a shorter trailing batch). New noise is drawn on every run.
train_noise = tf.random.normal([tf.shape(real_images)[0], seed_size])
generated = generator(train_noise, training=True)
real = discriminator(real_images, training=True)
fake = discriminator(generated, training=True)

gen_loss = generator_loss(fake)
disc_loss = discriminator_loss(fake, real)

# var_list restricts each optimizer to its own model's weights.
gen_train_op = generator_optimizer.minimize(gen_loss, var_list=generator.trainable_variables)
disc_train_op = discriminator_optimizer.minimize(disc_loss, var_list=discriminator.trainable_variables)

# Inference-mode preview graph, fed with one fixed noise batch so the grids
# shown at the end of each epoch are directly comparable.
preview_rows, preview_cols = 10, 10
preview_noise = tf.placeholder(tf.float32, [None, seed_size], name='preview_noise')
preview_images = generator(preview_noise, training=False)
sample_preview_noise = tf.random.normal([preview_rows * preview_cols, seed_size])

with tf.Session() as sess:
    # Run after minimize() so the optimizers' slot variables are initialised too.
    sess.run(tf.global_variables_initializer())
    fixed_noise = sess.run(sample_preview_noise)

    for epoch in range(epochs):
        for batch_index, batch in enumerate(next_batch()):
            if batch_index % 100 == 0:
                print('epoch', epoch, "batch", batch_index, '/', total_batch)

            # Alternate one discriminator step and one generator step per batch.
            sess.run(disc_train_op, feed_dict={real_images: batch})
            sess.run(gen_train_op, feed_dict={real_images: batch})

        # run image generation once at the end of each epoch
        images = sess.run(preview_images, feed_dict={preview_noise: fixed_noise})

        plt.figure(figsize=(10, 10))
        for k in range(images.shape[0]):
            plt.subplot(preview_rows, preview_cols, k + 1)
            # Generator output is tanh in [-1, 1]; pin the colour scale to that
            # range instead of autoscaling each image independently.
            plt.imshow(images[k, :, :, 0], cmap='gray', vmin=-1, vmax=1)
            plt.axis('off')
        plt.show()