In [0]:
# Colab line magic that selects the TensorFlow 2.x runtime (the captured output
# below shows it succeeded). In an IPython kernel where the magic is unknown it
# raises at run time and the error is deliberately ignored. NOTE(review): in
# plain Python (not IPython) the `%` line is a SyntaxError that this try/except
# cannot catch.
try:
  %tensorflow_version 2.x
except Exception:
  pass


TensorFlow 2.x selected.


In [0]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras import layers
from IPython import display
import time

In [0]:
# load_data() returns two (images, labels) pairs: train and test. Only the image
# arrays are kept; the labels go to `_` because the GAN is trained unsupervised.
# NOTE: binding `_` shadows IPython's "last output" variable, and `test` is not
# used anywhere else in this notebook.
(train,_) , (test,_)  = tf.keras.datasets.mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [0]:
# Flatten each 28x28 image into a 784-vector and rescale pixels from [0, 255]
# to [-1, 1]. The generator ends in `tanh`, so real and generated samples must
# share that range; otherwise the discriminator can separate them on scale alone.
# Cast to float32 (Keras' default dtype); dividing the uint8 array alone would
# give float64. `-1` infers the sample count instead of hard-coding 60000.
train_image = (train.reshape(-1, 28 * 28).astype('float32') - 127.5) / 127.5

In [0]:
# Training configuration:
#   EPOCH  - epoch count used by the training loop
#   bs     - mini-batch size for the tf.data pipeline
#   buffer - tf.data shuffle-buffer size (the full 60000-image training set)
EPOCH, bs, buffer = 1000, 128, 60000

In [0]:
# Wrap the flattened images in a tf.data pipeline: shuffle with a `buffer`-sized
# buffer, then group into mini-batches of `bs`.
# NOTE: this rebinds `train_image` from the NumPy array to a Dataset, so the cell
# is not idempotent. Re-run the preprocessing cell before re-running this one.
train_image = (
    tf.data.Dataset
    .from_tensor_slices(train_image)
    .shuffle(buffer)
    .batch(bs)
)

In [0]:
class GAN(tf.keras.Model):
  """Fully-connected GAN on flattened 28x28 (784-value) images.

  Holds a generator `g_model` (noise -> 784 values in [-1, 1] via tanh), a
  discriminator `d_model` (784 values -> sigmoid probability of "real"), and one
  Adam optimizer per network. `train` performs one adversarial update.

  Args:
    latent_dim: length of the noise vector fed to the generator. The default
      of 100 matches the value used elsewhere in this notebook.
  """

  def __init__(self, latent_dim=100):
    super(GAN, self).__init__()
    self.latent_dim = latent_dim

    # Binary cross-entropy loss, built once instead of on every `loss` call.
    # The discriminator ends in a sigmoid, so it outputs probabilities, not
    # logits. from_logits must therefore be False; with True, the sigmoid
    # would effectively be applied twice.
    self.cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

    # Generator: noise -> image vector.
    self.g_model = tf.keras.Sequential()
    self.g_model.add(layers.Dense(256, activation='relu', input_shape=(latent_dim,)))
    self.g_model.add(layers.Dropout(0.1))
    self.g_model.add(layers.Dense(256, activation='relu'))
    self.g_model.add(layers.Dropout(0.3))
    self.g_model.add(layers.Dense(784, activation='tanh'))
    self.g_optimizer = tf.keras.optimizers.Adam(1e-4)

    # Discriminator: image vector -> P(real). Its Dropout layers must be added
    # to d_model. Adding them to g_model would put dropout after the
    # generator's tanh output and leave the discriminator unregularised.
    self.d_model = tf.keras.Sequential()
    self.d_model.add(layers.Dense(256, activation='relu', input_shape=(784,)))
    self.d_model.add(layers.Dropout(0.2))
    self.d_model.add(layers.Dense(256, activation='relu'))
    self.d_model.add(layers.Dropout(0.2))
    self.d_model.add(layers.Dense(1, activation='sigmoid'))
    self.d_optimizer = tf.keras.optimizers.Adam(1e-4)

  def forward(self, x, training=True):
    """Run the generator and discriminator on one batch.

    Samples one noise vector per row of `x`, generates fakes, and scores both
    the fakes and the real batch with the discriminator. The scores are stored
    on `self.fake_output` / `self.real_output` (read by `loss`) and returned.

    Args:
      x: batch of real images, shape (batch, 784) to match d_model's input.
      training: passed to both models so Dropout is active during training.
        In a custom loop, Keras models called without it run Dropout in
        inference mode.

    Returns:
      (fake_output, real_output): discriminator probabilities.
    """
    # Size the noise batch from the real batch. The final batch of an epoch
    # can be smaller than the configured batch size.
    z = tf.random.normal([tf.shape(x)[0], self.latent_dim])

    generator_output = self.g_model(z, training=training)

    self.fake_output = self.d_model(generator_output, training=training)
    self.real_output = self.d_model(x, training=training)
    return self.fake_output, self.real_output

  def loss(self, x):
    """Compute generator and discriminator losses from the last `forward` pass.

    `x` is unused and kept only for interface compatibility. `forward` must
    run first, because this reads `self.fake_output` and `self.real_output`.

    Returns:
      (g_loss, d_loss), also stored on `self.g_loss` / `self.d_loss`.
    """
    # Generator objective: have its fakes classified as real (label 1).
    self.g_loss = self.cross_entropy(tf.ones_like(self.fake_output), self.fake_output)

    # Discriminator objective: fakes -> 0, reals -> 1.
    d_fake = self.cross_entropy(tf.zeros_like(self.fake_output), self.fake_output)
    d_real = self.cross_entropy(tf.ones_like(self.real_output), self.real_output)
    self.d_loss = d_fake + d_real

    return self.g_loss, self.d_loss

  @tf.function
  def train(self, x):
    """Perform one adversarial update of both networks on the real batch `x`.

    Both networks are updated from a single forward pass, each with its own
    gradient tape and optimizer.

    Returns:
      (g_loss, d_loss) for logging. Callers that ignore it are unaffected.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      self.forward(x, training=True)
      g_loss, d_loss = self.loss(x)

    g_gradients = gen_tape.gradient(g_loss, self.g_model.trainable_variables)
    d_gradients = disc_tape.gradient(d_loss, self.d_model.trainable_variables)

    self.g_optimizer.apply_gradients(zip(g_gradients, self.g_model.trainable_variables))
    self.d_optimizer.apply_gradients(zip(d_gradients, self.d_model.trainable_variables))
    return g_loss, d_loss
      

     

  

In [0]:
def generate_and_save_images(model, epoch, test_input=None, save=False):
  """Plot a 4x4 grid of generator samples for `epoch`.

  Args:
    model: the generator. It is called as `model(noise, training=False)` and
      must return 784 values per row (reshaped to 28x28 here).
    epoch: epoch number, used in the saved file name.
    test_input: optional noise batch with 100 columns; only the first 16 rows
      are plotted. Passing the same tensor every epoch shows how the *same*
      latent points evolve. When None, 16 fresh noise vectors are drawn.
    save: when True, also write the grid to 'image_at_epoch_XXXX.png' every
      10th epoch. Defaults to False (display only).
  """
  if test_input is None:
    test_input = tf.random.normal([16, 100])
  # training=False runs layers such as Dropout in inference mode.
  predictions = model(test_input, training=False).numpy().reshape(-1, 28, 28, 1)

  fig, axes = plt.subplots(4, 4, figsize=(4, 4))
  for ax in axes.flat:
    ax.axis('off')

  # zip stops at 16 images even if more noise rows were supplied.
  for ax, image in zip(axes.flat, predictions[..., 0]):
    # Zero everything below 10 after scaling (this covers the whole negative
    # half of tanh's range) so the background renders black. imshow then
    # auto-scales the remaining values.
    pixels = image * 127.5
    pixels[pixels < 10] = 0
    ax.imshow(pixels, cmap='gray')

  if save and epoch % 10 == 0:
    fig.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()
  plt.close(fig)

In [0]:
# Build the model. GAN.__init__ creates the generator, the discriminator, and
# one Adam optimizer for each.
gan = GAN()

In [0]:
# Train for exactly EPOCH epochs, numbered 1..EPOCH so the final epoch also
# lands on a multiple of 10. `range(0, EPOCH + 1)` would run EPOCH + 1 epochs.
# After each epoch, redraw a grid of generator samples and report the time.
for epoch in range(1, EPOCH + 1):
  start = time.time()
  for image_batch in train_image:
    gan.train(image_batch)

  # Replace the previous epoch's figure instead of stacking outputs.
  display.clear_output(wait=True)
  generate_and_save_images(gan.g_model, epoch)

  print('Time for epoch {} is {} sec'.format(epoch, time.time() - start))
    
    
    
  