## Imports

In [None]:
import tensorflow as tf
import tensorflow.keras as keras

import numpy as np
import matplotlib.pyplot as plt
from IPython import display
import pandas as pd

## Load and Prepare the Dataset

In this experiment I use Colab's built-in sample dataset, "mnist_train_small.csv", located in /content/sample_data/.

In [None]:
MNIST_CSV_PATH = "/content/sample_data/mnist_train_small.csv"

# The Colab MNIST sample CSV has no header row (each line is label followed by the
# pixel values), so header=None keeps the first sample instead of turning it
# into column names.
data = pd.read_csv(MNIST_CSV_PATH, header=None)

In [None]:
# Keep every column except the first one (presumably the digit label in this
# MNIST CSV layout, which an unconditional GAN does not need).
X_train = data[data.columns[1:]]

In [None]:
# Each sample must carry exactly 28 * 28 = 784 pixel values for the reshape below.
# np.prod over the non-batch axes keeps this check valid even after the reshape.
assert np.prod(X_train.shape[1:]) == 28 * 28, f"expected 784 pixels per sample, got shape {X_train.shape}"
X_train.shape

In [None]:
# np.asarray accepts both the DataFrame from the previous cell and an ndarray,
# so re-running this cell is a harmless no-op instead of an AttributeError.
# Layout: (n_samples, height, width, channels) as expected by the Conv2D layers.
X_train = np.asarray(X_train).reshape(-1, 28, 28, 1)

In [None]:
# Map raw integer pixel intensities (0..255) to [-1, 1], the same range as the
# generator's tanh output. The dtype guard keeps the cell idempotent: once the data
# is float32 it has already been scaled and must not be rescaled again.
if np.issubdtype(X_train.dtype, np.integer):
    X_train = X_train.astype(np.float32) / 255 * 2. - 1.
assert X_train.min() >= -1. and X_train.max() <= 1., "pixel values outside [-1, 1]"

## Define a helper to plot image grids

In [None]:
def plot_results(images, n_cols=None):
    '''Visualize a batch of images as a grid of greyscale thumbnails.

    The current cell output is cleared first, so a repeated call replaces the
    previous grid rather than appending another one.

    Args:
        images: batch of images shaped (n, h, w) or (n, h, w, 1). Anything
            `np.asarray` accepts works, including eager TF tensors.
        n_cols: number of grid columns; falsy values (None/0) put every image
            in a single row.

    Raises:
        ValueError: if `images` is not a non-empty batch of images, or
            `n_cols` is negative.
    '''
    images = np.asarray(images)
    # Validate before touching the display so bad input fails with a clear message
    # instead of a ZeroDivisionError from the row computation below.
    if images.ndim < 3 or images.shape[0] == 0:
        raise ValueError(
            f"expected a non-empty batch of images, got shape {images.shape}")
    n_cols = n_cols or len(images)
    if n_cols < 1:
        raise ValueError(f"n_cols must be positive, got {n_cols}")

    display.clear_output(wait=False)

    # Ceiling division: enough rows to hold every image.
    n_rows = (len(images) - 1) // n_cols + 1

    # imshow expects (h, w) for single-channel images.
    if images.shape[-1] == 1:
        images = np.squeeze(images, axis=-1)

    plt.figure(figsize=(n_cols, n_rows))
    for index, image in enumerate(images):
        plt.subplot(n_rows, n_cols, index + 1)
        plt.imshow(image, cmap="binary")
        plt.axis("off")

### Show sample training images

In [None]:
# First 32 training digits in a 4x8 grid; the trailing channel axis is dropped for imshow.
plot_results(X_train[:32, ..., 0], n_cols=8)

In [None]:
BATCH_SIZE = 128

# tf.data only gives a uniform shuffle when the buffer holds the whole dataset;
# with a small buffer only rows that sit close together in the CSV get mixed.
# X_train is already fully in memory, so a full-size buffer is affordable.
SHUFFLE_BUFFER_SIZE = len(X_train)

# create batches of tensors to be fed into the model
dataset = tf.data.Dataset.from_tensor_slices(X_train)
dataset = dataset.shuffle(SHUFFLE_BUFFER_SIZE)
# drop_remainder=True gives every batch exactly BATCH_SIZE images (static batch shape).
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

## Build the Model

### Generator

In [None]:
random_normal_dimensions = 100

# Generator: project a noise vector onto a 7x7x128 feature map, then upsample it
# twice with stride-2 "SAME" transposed convolutions (7 -> 14 -> 28) down to one
# channel. The tanh output lies in [-1, 1], the range X_train was scaled to.
generator = keras.models.Sequential()
generator.add(keras.layers.Dense(7 * 7 * 128, input_shape=[random_normal_dimensions]))
generator.add(keras.layers.Reshape([7, 7, 128]))
generator.add(keras.layers.BatchNormalization())
generator.add(keras.layers.Conv2DTranspose(
    64, kernel_size=5, strides=2, padding="SAME", activation="selu"))
generator.add(keras.layers.BatchNormalization())
generator.add(keras.layers.Conv2DTranspose(
    1, kernel_size=5, strides=2, padding="SAME", activation="tanh"))

generator.summary()

#### Display fake images from the untrained generator

In [None]:
# Push a batch of 32 noise vectors through the untrained generator as a shape
# sanity check; the resulting images are expected to look like noise.
test_noise = tf.random.normal(shape=(32, random_normal_dimensions))
test_image = generator(test_noise)

plot_results(test_image, n_cols=8)
print(f'shape of the generated batch: {test_image.shape}')

### Discriminator

In [None]:
# Discriminator: two stride-2 "SAME" convolutions (28 -> 14 -> 7) with dropout,
# then a single sigmoid unit scoring how "real" the input image looks
# (the training loop labels real images 1 and fakes 0).
discriminator = keras.models.Sequential()
discriminator.add(keras.layers.Conv2D(
    64, kernel_size=5, strides=2, padding="SAME",
    activation=keras.layers.LeakyReLU(0.2),
    input_shape=[28, 28, 1]))
discriminator.add(keras.layers.Dropout(0.4))
discriminator.add(keras.layers.Conv2D(
    128, kernel_size=5, strides=2, padding="SAME",
    activation=keras.layers.LeakyReLU(0.2)))
discriminator.add(keras.layers.Dropout(0.4))
discriminator.add(keras.layers.Flatten())
discriminator.add(keras.layers.Dense(1, activation="sigmoid"))

discriminator.summary()

## Assemble the DCGAN (generator followed by discriminator)

In [None]:
# Chain generator -> discriminator: the combined model maps noise straight to the
# discriminator's score and is what the training loop uses to update the generator.
gan = keras.models.Sequential()
gan.add(generator)
gan.add(discriminator)

## Configure the Model for training

In [None]:
from tensorflow.keras import Model, losses
# NOTE(review): `Model` is not used in the visible part of this notebook.

# Compile the discriminator while it is still trainable, so that its own
# train_on_batch() updates its weights.
discriminator.compile(loss=losses.BinaryCrossentropy(), optimizer="rmsprop")
# Freeze the discriminator *before* compiling the combined model. Keras documents
# that changes to `trainable` take effect at compile() time, so gan.train_on_batch()
# should then update only the generator. Do not reorder these three statements.
discriminator.trainable = False
gan.compile(loss="binary_crossentropy", optimizer="rmsprop")

## Train the Model

In [None]:
def train_gan(gan, dataset, random_normal_dimensions, n_epochs=50):
    '''Train a generator/discriminator pair with alternating per-batch updates.

    Each batch runs two phases:
      1. Discriminator: fake images from fresh noise (label 0) are concatenated
         with the real batch (label 1) and the discriminator takes one step.
      2. Generator: the combined `gan` model is trained on new noise with every
         label set to 1. This relies on the discriminator having been frozen
         when `gan` was compiled, so only the generator's weights move.

    Args:
        gan: Sequential model whose two layers are (generator, discriminator);
            both the discriminator and `gan` must already be compiled.
        dataset: iterable of batches of real images.
        random_normal_dimensions: length of each noise vector fed to the generator.
        n_epochs: number of full passes over `dataset`.

    Returns:
        (generator, dloss, gloss): the trained generator plus, per epoch, the mean
        discriminator and generator losses returned by `train_on_batch`.

    Raises:
        ValueError: if `dataset` yields no batches.
    '''
    generator, discriminator = gan.layers
    dloss = []
    gloss = []

    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))
        epoch_d_losses = []
        epoch_g_losses = []

        for real_images in dataset:
            # infer batch size from the training batch
            batch_size = real_images.shape[0]

            # Train the discriminator - PHASE 1: fakes labelled 0, reals labelled 1.
            noise = tf.random.normal(shape=[batch_size, random_normal_dimensions])
            fake_images = generator(noise)
            mixed_images = tf.concat([fake_images, real_images], axis=0)
            discriminator_labels = tf.concat(
                [tf.zeros([batch_size, 1]), tf.ones([batch_size, 1])], axis=0)

            # NOTE(review): with compile-time capture of `trainable` these toggles are
            # likely no-ops; kept to match the original training recipe.
            discriminator.trainable = True
            # train_on_batch returns the loss it just optimized, so no extra
            # forward passes are needed to report it.
            d_loss = discriminator.train_on_batch(mixed_images, discriminator_labels)

            # Train the generator - PHASE 2: ask the frozen discriminator to call
            # fresh fakes "real" (label 1); only the generator is updated.
            noise = tf.random.normal(shape=[batch_size, random_normal_dimensions])
            generator_labels = tf.ones([batch_size, 1])
            discriminator.trainable = False
            g_loss = gan.train_on_batch(noise, generator_labels)

            epoch_d_losses.append(float(d_loss))
            epoch_g_losses.append(float(g_loss))

        if not epoch_d_losses:
            raise ValueError("dataset yielded no batches; cannot train the GAN")

        # Report the epoch average rather than whichever batch happened to come last.
        discriminator_loss = float(np.mean(epoch_d_losses))
        generator_loss = float(np.mean(epoch_g_losses))
        print(f"Epoch: {epoch+1}, Discriminator Loss: {discriminator_loss:.4f}, Generator Loss: {generator_loss:.4f}")
        dloss.append(discriminator_loss)
        gloss.append(generator_loss)

    return generator, dloss, gloss


In [None]:
# Fit the GAN; returns the trained generator and one loss value per epoch for
# each of the two networks.
n_epochs = 100
generator, dloss, gloss = train_gan(
    gan, dataset, random_normal_dimensions, n_epochs=n_epochs)

In [None]:
# Loss curves: one point per recorded epoch, numbered from 1 to match the
# "Epoch: N" training log. Deriving the axis from the history itself keeps
# the plot valid even if n_epochs was changed after training.
loss_epochs = range(1, len(dloss) + 1)
fig, ax = plt.subplots()
ax.plot(loss_epochs, dloss, label="discriminator loss")
ax.plot(loss_epochs, gloss, label="generator loss")
ax.set_title("GAN training losses")
ax.set_xlabel("epochs")
ax.set_ylabel("loss")
ax.legend()
plt.show()

### Display fake images from the trained generator

In [None]:
def sythetic_data(generator, num_dt, random_normal_dimensinos):
    '''Sample `num_dt` images from `generator`.

    Args:
        generator: model mapping noise vectors to images.
        num_dt: number of images to generate.
        random_normal_dimensinos: length of each standard-normal noise vector.

    Returns:
        The generator's output for a fresh batch of standard-normal noise.
    '''
    latent_batch = tf.random.normal(shape=(num_dt, random_normal_dimensinos))
    return generator(latent_batch)

In [None]:
# Draw 32 samples from the trained generator and show them in a 4x8 grid.
sythetic_images = sythetic_data(generator, 32, random_normal_dimensions)
plot_results(sythetic_images, n_cols=8)