In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    Input,
    LeakyReLU,
    Reshape,
    UpSampling2D,
    ZeroPadding2D,
)
from tensorflow.keras.losses import mean_squared_error, mean_absolute_error


In [None]:
def normalize(input_image, label):
    """Scale an image to float32 in [-1, 1] via ``x / 127.5 - 1``.

    The [-1, 1] range matches the generator's final ``tanh`` activation, so
    real and generated images live in the same value range.

    Args:
        input_image: Image tensor castable to float32. Presumably uint8 pixels
            in [0, 255]; the 127.5 divisor assumes that range.
        label: Accepted so the function fits an ``(image, label)`` mapping
            signature, but it is ignored and NOT returned.

    Returns:
        The rescaled float32 image tensor (image only, no label).
    """
    input_image = tf.cast(input_image, tf.float32)
    input_image = (input_image / 127.5) - 1
    return input_image


def create_dataset(buffer_size, batch_size):
    """Build the MNIST training pipeline consumed by ``DCGAN.train``.

    Images are normalized to [-1, 1] with ``normalize``. Labels are passed
    through unchanged because the training loop unpacks each element as
    ``image, _``.

    Args:
        buffer_size: Number of examples held in the shuffle buffer.
        batch_size: Examples per batch. The incomplete final batch is dropped
            so every training step sees exactly ``batch_size`` real images.

    Returns:
        A ``tf.data.Dataset`` of shuffled ``(image, label)`` batches.
    """
    dataset = tfds.load("mnist", split="train", as_supervised=True)
    return (
        dataset
        .map(lambda image, label: (normalize(image, label), label),
             num_parallel_calls=tf.data.AUTOTUNE)
        .shuffle(buffer_size)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )


In [None]:
class DCGAN():
    """Deep Convolutional GAN: a generator and a discriminator trained with a custom loop.

    Args:
        rows, cols, channels: Shape of the images the discriminator accepts and
            the generator produces. ``rows`` and ``cols`` must be divisible by 4
            because the generator upsamples its seed feature map twice by 2x.
        epochs: Number of passes over the dataset performed by ``train``.
        batch_size: Stored for reference. ``train_step`` sizes its noise batch
            from the real batch it receives.
        z: Dimensionality of the latent noise vector.
    """

    def __init__(self, rows, cols, channels, epochs, batch_size, z=100):
        # Input shape
        self.img_rows = rows
        self.img_cols = cols
        self.channels = channels
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = z

        self.epochs = epochs
        # When True, train() compiles train_step with tf.function.
        self.enable_function = True
        self.batch_size = batch_size

        # from_logits=True: the discriminator's final Dense(1) has no activation.
        self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        self.generator_optimizer = tf.keras.optimizers.Adam(1e-4)
        self.discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()
        self.checkpoint = tf.train.Checkpoint(
            generator_optimizer=self.generator_optimizer,
            discriminator_optimizer=self.discriminator_optimizer,
            generator=self.generator,
            discriminator=self.discriminator)

    def build_generator(self):
        """Build the generator: latent vector -> image with values in [-1, 1] (tanh).

        Starts from a (rows/4, cols/4, 128) feature map and upsamples twice, so
        the output spatial size equals ``self.img_shape``.

        Raises:
            ValueError: If ``img_rows`` or ``img_cols`` is not divisible by 4.
        """
        if self.img_rows % 4 or self.img_cols % 4:
            raise ValueError(
                "rows and cols must be divisible by 4, got %dx%d"
                % (self.img_rows, self.img_cols))
        init_rows, init_cols = self.img_rows // 4, self.img_cols // 4

        model = Sequential()

        model.add(Dense(128 * init_rows * init_cols, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((init_rows, init_cols, 128)))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
        model.add(Activation("tanh"))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Build the discriminator: image -> single unnormalized logit (real vs. fake)."""
        model = Sequential()

        model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        # Pad one row/column at the bottom/right (e.g. 7x7 -> 8x8 for 28x28 input).
        model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        # No activation: loss_object is configured with from_logits=True.
        model.add(Dense(1))

        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, dataset, checkpoint_pr, save_interval=50):
        """Run ``self.epochs`` epochs of adversarial training.

        Args:
            dataset: Iterable of ``(image, label)`` batches; labels are ignored.
            checkpoint_pr: File prefix passed to ``tf.train.Checkpoint.save``.
            save_interval: Checkpoint and save a sample grid every this many
                epochs, starting at epoch 0.

        Returns:
            List of per-epoch wall-clock durations in seconds.

        Raises:
            ValueError: If ``dataset`` yields no batches in an epoch.
        """
        time_list = []
        step_fn = self.train_step
        if self.enable_function:
            # Compile once and reuse on later calls. Re-wrapping on every call
            # would nest tf.functions and throw away the trace cache.
            if getattr(self, "_compiled_train_step", None) is None:
                self._compiled_train_step = tf.function(self.train_step)
            step_fn = self._compiled_train_step

        for epoch in range(self.epochs):
            start_time = time.time()
            gen_loss = disc_loss = None
            for image, _ in dataset:
                gen_loss, disc_loss = step_fn(image)
            if gen_loss is None:
                raise ValueError("dataset yielded no batches; nothing to train on")

            wall_time_sec = time.time() - start_time
            time_list.append(wall_time_sec)

            # Losses reported are those of the epoch's last batch.
            str_template = 'Epoch {}, Generator loss {}, Discriminator Loss {}'
            print(str_template.format(epoch, float(gen_loss), float(disc_loss)))

            if epoch % save_interval == 0:
                # Checkpoint models and optimizers every `save_interval` epochs.
                self.checkpoint.save(file_prefix=checkpoint_pr)
                # Save a grid of generated samples for visual inspection.
                self.save_imgs(epoch)

        self.save_imgs(self.epochs)

        return time_list

    def generator_loss(self, generated_output):
        """Non-saturating generator loss: BCE of fake logits against "real" (1) targets."""
        return self.loss_object(tf.ones_like(generated_output), generated_output)

    def discriminator_loss(self, real_output, generated_output):
        """Discriminator loss: BCE(real logits, 1) + BCE(fake logits, 0)."""
        real_loss = self.loss_object(tf.ones_like(real_output), real_output)
        generated_loss = self.loss_object(
            tf.zeros_like(generated_output), generated_output)

        total_loss = real_loss + generated_loss

        return total_loss

    def train_step(self, image):
        """Run one simultaneous generator/discriminator update on a batch of real images.

        Args:
            image: Batch of real images, presumably scaled to [-1, 1] to match
                the generator's tanh output.

        Returns:
            ``(gen_loss, disc_loss)`` scalar tensors for this batch.
        """
        # Size the noise from the real batch so a short final batch still gives
        # the discriminator equal numbers of real and fake images.
        noise = tf.random.normal([tf.shape(image)[0], self.latent_dim])

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = self.generator(noise, training=True)

            real_output = self.discriminator(image, training=True)
            generated_output = self.discriminator(generated_images, training=True)

            gen_loss = self.generator_loss(generated_output)
            disc_loss = self.discriminator_loss(real_output, generated_output)

        gradients_of_generator = gen_tape.gradient(
            gen_loss, self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(
            disc_loss, self.discriminator.trainable_variables)

        self.generator_optimizer.apply_gradients(zip(
            gradients_of_generator, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(zip(
            gradients_of_discriminator, self.discriminator.trainable_variables))

        return gen_loss, disc_loss

    def save_imgs(self, epoch):
        """Save a 5x5 grid of generator samples to ``images/dcgan_mnist_<epoch>.png``.

        Creates the ``images`` directory if it does not already exist.
        """
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise, verbose=0)

        # Rescale tanh output from [-1, 1] to [0, 1] for imshow.
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        for ax, img in zip(axs.flat, gen_imgs):
            if self.channels == 1:
                ax.imshow(img[:, :, 0], cmap='gray')
            else:
                ax.imshow(img)
            ax.axis('off')
        os.makedirs("images", exist_ok=True)
        fig.savefig("images/dcgan_mnist_%d.png" % epoch)
        plt.close(fig)


In [None]:
buffer_size = 10000
batch_size = 64
epochs = 10
checkpoint_pr = 'ckpt'
SEED = 42

# Seed TensorFlow's global RNG and NumPy's legacy global RNG before building
# the dataset and models, so reruns are comparable. GPU kernels may still be
# nondeterministic.
tf.random.set_seed(SEED)
np.random.seed(SEED)

train_dataset = create_dataset(buffer_size, batch_size)

dcgan = DCGAN(28, 28, 1, epochs, batch_size)
print('Training ...')
# Per-epoch wall-clock seconds, displayed as the cell's output.
epoch_times = dcgan.train(train_dataset, checkpoint_pr, save_interval=2)
epoch_times

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 6272)              633472    
                                                                 
 reshape (Reshape)           (None, 7, 7, 128)         0         
                                                                 
 up_sampling2d (UpSampling2D  (None, 14, 14, 128)      0         
 )                                                               
                                                                 
 conv2d (Conv2D)             (None, 14, 14, 128)       147584    
                                                                 
 batch_normalization (BatchN  (None, 14, 14, 128)      512       
 ormalization)                                                   
                                                                 
 activation (Activation)     (None, 14, 14, 128)       0

[45.848536252975464,
 35.64983105659485,
 35.62977480888367,
 35.4211049079895,
 35.38595175743103,
 35.44247651100159,
 35.377036809921265,
 35.366722106933594,
 35.4345383644104,
 35.5413556098938]