<a href="https://colab.research.google.com/github/srv96/AI-ML-TensorFlow/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys , os
import gc

In [None]:
# Load the MNIST train/test splits through the Keras dataset helper.
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Rescale pixel values with x / 255 * 2 - 1, which sends 0 to -1 and 255 to +1
# (the same range as the generator's tanh output).
X_train = X_train / 255.0 * 2 - 1
X_test = X_test / 255.0 * 2 - 1

print("X_train.shape : ",X_train.shape)

X_train.shape :  (60000, 28, 28)


In [None]:
# Flatten every H x W image into a single row of D = H * W values so it can be
# fed to the Dense layers.
# The image dimensions are read only while X_train is still 3-D: re-running this
# cell after the arrays are already flat would otherwise fail when unpacking a
# 2-element shape into N, H, W. On a re-run the previously computed N, H, W, D
# are kept and the reshapes below are no-ops.
if X_train.ndim == 3:
  N , H , W = X_train.shape
  D = H * W

X_train = X_train.reshape(-1, D)
X_test = X_test.reshape(-1, D)

print("X_train.shape : ",X_train.shape)

X_train.shape :  (60000, 784)


In [None]:
# Length of the random noise vector that is fed to the generator.
latent_dim = 100

In [None]:
def build_generator(latent_dim, img_size=None):
  """Build the fully connected generator network.

  Args:
    latent_dim: Length of the input noise vector.
    img_size: Length of the flattened output image. When omitted, the
      notebook-level `D` is used, which is what this function always did
      before the parameter existed.

  Returns:
    An uncompiled Keras model mapping a (batch, latent_dim) input to a
    (batch, img_size) output produced by a tanh layer.
  """
  if img_size is None:
    img_size = D  # backward-compatible fallback to the global flattened image size

  i = tf.keras.Input(shape=(latent_dim,))
  x = i
  # Three widening stages of Dense -> BatchNormalization. A fresh LeakyReLU
  # instance is created for every Dense layer, as in the unrolled version.
  for units in (256, 512, 1024):
    x = tf.keras.layers.Dense(units, activation=tf.keras.layers.LeakyReLU(negative_slope=0.2))(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.8)(x)
  # tanh keeps the output in [-1, 1].
  x = tf.keras.layers.Dense(img_size, activation='tanh')(x)

  model = tf.keras.models.Model(i, x)
  return model

In [None]:
def build_discriminator(img_size):
  """Assemble the dense discriminator: img_size -> 512 -> 256 -> 1.

  Args:
    img_size: Length of the flattened image vector the model receives.

  Returns:
    An uncompiled Keras model whose single output unit uses a sigmoid
    activation; both hidden layers use LeakyReLU with negative_slope=0.2.
  """
  inputs = tf.keras.Input(shape=(img_size,))
  hidden = inputs
  for width in (512, 256):
    hidden = tf.keras.layers.Dense(width, activation=tf.keras.layers.LeakyReLU(negative_slope=0.2))(hidden)
  score = tf.keras.layers.Dense(1, activation='sigmoid')(hidden)
  return tf.keras.models.Model(inputs, score)


In [None]:
# Standalone discriminator: compiled on its own so it can be trained directly
# on batches of real and generated images, reporting loss and accuracy.
discriminator = build_discriminator(D)
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.Adam(0.0002, 0.5),
    metrics=['accuracy'],
)

generator = build_generator(latent_dim)

# Stack generator -> discriminator into a single model that maps a noise
# vector to the discriminator's prediction on the generated image.
z = tf.keras.Input(shape=(latent_dim,))
img = generator(z)
# Mark the discriminator non-trainable before the stacked model is compiled;
# the intent is that training combined_model only updates the generator.
# NOTE(review): this flag is flipped AFTER discriminator.compile() above, so
# whether the standalone discriminator still learns depends on how the
# installed Keras version treats `trainable` changes made after compile —
# TODO confirm (watch for a "no trainable weights" warning during training).
discriminator.trainable = False
fake_pred = discriminator(img)
combined_model = tf.keras.models.Model(z , fake_pred)
combined_model.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.Adam(0.0002, 0.5),
)

In [None]:
# Print the summary of the stacked generator + discriminator model.
combined_model.summary()

In [None]:
# Training settings. Note that the training loop draws one mini-batch per
# "epoch", so `epochs` is really the number of training iterations.
batch_size = 32
epochs =  30000
sample_period = 200 # save a grid of generated samples every `sample_period` iterations

# Target labels reused on every step: ones for the "real" class, zeros for "fake".
ones = np.ones(batch_size)
zeros = np.zeros(batch_size)

# Per-iteration loss history for the discriminator and the generator.
d_losses = []
g_losses = []

# exist_ok=True replaces the separate exists() check: no check-then-create
# race, and the cell can be re-run safely.
os.makedirs('gan_images', exist_ok=True)

In [None]:
def sample_images(epoch):
  """Save a 5x5 grid of freshly generated images to gan_images/<epoch>.png.

  Uses the notebook-level `generator`, `latent_dim`, `H` and `W`.

  Args:
    epoch: Integer iteration index; only used to name the output file.
  """
  rows,cols = 5,5
  noise = np.random.randn(rows * cols , latent_dim)
  # verbose=0 keeps Keras from printing a progress bar for this tiny batch.
  imgs = generator.predict(noise, verbose=0)

  #rescale images from the tanh range [-1, 1] to [0, 1]
  imgs = 0.5 * imgs + 0.5

  fig , axs = plt.subplots(rows , cols)
  # axs.flat walks the grid row by row, matching the order of the samples.
  for ax, img in zip(axs.flat, imgs):
    # Pin the gray scale to [0, 1]: without vmin/vmax imshow stretches every
    # panel to its own min/max, which cancels the rescale above and makes the
    # panels (and snapshots from different iterations) incomparable.
    ax.imshow(img.reshape(H,W) , cmap='gray', vmin=0, vmax=1)
    ax.axis('off')
  fig.savefig("gan_images/%d.png" % epoch)
  # Close this specific figure instead of whatever pyplot considers current.
  plt.close(fig)

In [None]:
# Adversarial training loop. Each "epoch" is a single mini-batch step: one
# discriminator update on real images, one on generated images, then one
# generator update through combined_model.
for epoch in range(epochs):

  # Random mini-batch of real images (indices drawn with replacement).
  idx = np.random.randint(0,X_train.shape[0],batch_size)
  real_imgs = X_train[idx]

  # Generate a batch of fake images. verbose=0 stops predict() from printing a
  # progress bar on every iteration, which would flood the cell output.
  noise = np.random.randn(batch_size , latent_dim)
  fake_imgs = generator.predict(noise, verbose=0)

  # Discriminator step: real images are labelled 1, generated images 0.
  d_loss_real,d_acc_real = discriminator.train_on_batch(real_imgs, ones)
  d_loss_fake,d_acc_fake = discriminator.train_on_batch(fake_imgs, zeros)
  d_loss = 0.5 * (d_loss_real + d_loss_fake)
  d_acc = 0.5 * (d_acc_real + d_acc_fake)

  # Generator step: fresh noise, labelled 1 so the generator is pushed toward
  # images the discriminator classifies as real.
  noise = np.random.randn(batch_size , latent_dim)
  g_loss = combined_model.train_on_batch(noise , ones)

  d_losses.append(d_loss)
  g_losses.append(g_loss)

  if epoch % 100 == 0:
    # g_loss is indexed with [-1], i.e. train_on_batch is expected to return a
    # sequence for combined_model here.
    print(f"epoch: {epoch+1}/{epochs}, d_loss: {d_loss:.2f}, d_acc: {d_acc:.2f}, g_loss: {g_loss[-1]:.2f}")

  if epoch % sample_period == 0:
    sample_images(epoch)

  # Drop per-iteration temporaries and force a collection to limit memory
  # growth over the long run.
  del real_imgs, fake_imgs, d_loss_real, d_loss_fake
  del g_loss, d_acc_real, d_acc_fake
  del noise
  gc.collect()

  # Periodically reset Keras' global state as an additional memory workaround.
  if epoch % 500 == 0:
    tf.keras.backend.clear_session()
