<a href="https://colab.research.google.com/github/rrrovalle/basic_gan/blob/main/basic_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import keras
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np

# Module-level MNIST load. The arrays bound to x and y are not referenced by
# the code shown in this script (train() loads the dataset again itself).
x, y = mnist.load_data()

# MNIST digits are 28x28 grayscale images, i.e. a single channel.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)

def build_generator():
  """Create the fully-connected generator network.

  A latent noise vector of length 100 goes through three hidden blocks of
  widths 256, 512 and 1024. Each block is Dense -> LeakyReLU(0.2) ->
  BatchNormalization(momentum=0.8). A final tanh Dense layer emits
  ``np.prod(img_shape)`` values, which are reshaped to ``img_shape``. The
  summary of the inner Sequential model is printed.

  Returns:
    A functional ``Model`` mapping noise of shape (batch, 100) to images of
    shape (batch, *img_shape).
  """
  latent_shape = (100,)  # length of the noise (latent) vector

  net = Sequential()
  for block_idx, units in enumerate((256, 512, 1024)):
    # Only the first Dense layer declares the input shape.
    extra = {'input_shape': latent_shape} if block_idx == 0 else {}
    net.add(Dense(units, **extra))
    net.add(LeakyReLU(alpha=0.2))
    net.add(BatchNormalization(momentum=0.8))

  # Output layer: one tanh unit per pixel, then fold back into image form.
  net.add(Dense(np.prod(img_shape), activation='tanh'))
  net.add(Reshape(img_shape))

  net.summary()

  z_in = Input(shape=latent_shape)
  fake_img = net(z_in)
  return Model(z_in, fake_img)


def build_discriminator():
  """Create the fully-connected discriminator network.

  The network flattens an ``img_shape`` image and scores it with
  Dense(512) -> LeakyReLU(0.2) -> Dense(256) -> LeakyReLU(0.2) ->
  Dense(1, sigmoid).

  Returns:
    A functional ``Model`` mapping images of shape (batch, *img_shape) to a
    (batch, 1) sigmoid score. The score is the discriminator's estimate that
    the input is a real image.
  """
  net = Sequential()
  net.add(Flatten(input_shape=img_shape))
  for units in (512, 256):
    net.add(Dense(units))
    net.add(LeakyReLU(alpha=0.2))
  net.add(Dense(1, activation='sigmoid'))

  image_in = Input(shape=img_shape)
  score = net(image_in)
  return Model(image_in, score)


def train(epochs, batch_size=128, save_interval=50):
  """Train the GAN on MNIST.

  Each iteration does two things:

  1. One discriminator update on a random half batch of real images and a
     half batch of generated images.
  2. One generator update through the ``combined`` model, with every sample
     labelled "real".

  The function uses the module-level ``generator``, ``discriminator`` and
  ``combined`` models and the ``save_imgs`` helper.

  Args:
    epochs: number of training iterations. Each iteration draws one random
      batch; it is not a full pass over the dataset.
    batch_size: number of noise samples per generator update. The
      discriminator sees ``batch_size // 2`` real and ``batch_size // 2``
      fake images per iteration.
    save_interval: save a grid of generated samples every ``save_interval``
      iterations, including iteration 0.

  Raises:
    ValueError: if ``batch_size`` is less than 2, which would make the
      discriminator's half batches empty.
  """
  if batch_size < 2:
    raise ValueError("batch_size must be at least 2, got %d" % batch_size)

  # Load the dataset; only the training images are needed.
  (X_train, _), (_, _) = mnist.load_data()

  # Convert to float and rescale from [0, 255] to [-1, 1], the range of the
  # generator's tanh output.
  X_train = (X_train.astype(np.float32) - 127.5) / 127.5

  # Add a channel axis so the images match img_shape (28x28x1).
  X_train = np.expand_dims(X_train, axis=3)

  half_batch = batch_size // 2

  # The labels are constant across iterations, so build them once. The
  # (n, 1) shape matches the discriminator's (batch, 1) sigmoid output.
  real_y = np.ones((half_batch, 1))
  fake_y = np.zeros((half_batch, 1))
  # The generator wants the discriminator to call its samples real (label 1).
  valid_y = np.ones((batch_size, 1))

  for epoch in range(epochs):
    # ---- Train the discriminator ----
    # Select a random half batch of real images.
    idx = np.random.randint(0, X_train.shape[0], half_batch)
    imgs = X_train[idx]

    noise = np.random.normal(0, 1, (half_batch, 100))
    # verbose=0: without it, recent Keras versions print a progress bar on
    # every call, i.e. once per iteration.
    gen_imgs = generator.predict(noise, verbose=0)

    d_loss_real = discriminator.train_on_batch(imgs, real_y)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_y)

    # Average [loss, accuracy] over the real and fake half batches.
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ---- Train the generator ----
    # `combined` chains generator -> (frozen) discriminator. Training it
    # towards label 1 pushes the generator to fool the discriminator.
    noise = np.random.normal(0, 1, (batch_size, 100))
    g_loss = combined.train_on_batch(noise, valid_y)

    print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]"
          % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

    # Periodically save a grid of generated samples.
    if epoch % save_interval == 0:
      save_imgs(epoch)


def save_imgs(epoch):
  """Save a 5x5 grid of generator samples to ``images/mnits_<epoch>.png``.

  The function draws 25 noise vectors of length 100 and generates images with
  the module-level ``generator``. It maps the images from [-1, 1] to [0, 1]
  and plots them in grayscale without axes. The ``images`` directory is
  created if it does not exist.

  Args:
    epoch: training iteration number, used in the output filename.
  """
  import os

  r, c = 5, 5
  noise = np.random.normal(0, 1, (r * c, 100))
  # verbose=0 suppresses the per-call progress bar printed by recent Keras.
  gen_imgs = generator.predict(noise, verbose=0)

  # Rescale the tanh output from [-1, 1] to [0, 1] for imshow.
  gen_imgs = 0.5 * gen_imgs + 0.5

  fig, axs = plt.subplots(r, c)
  # axs.flat walks the grid row by row, matching sample order 0..r*c-1.
  for cnt, ax in enumerate(axs.flat):
    ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
    ax.axis('off')

  # savefig does not create missing directories.
  os.makedirs("images", exist_ok=True)
  fig.savefig("images/mnits_%d.png" % epoch)
  # Close this specific figure; it may not be the "current" one.
  plt.close(fig)


# --- Build and compile the models ---------------------------------------------
# Adam(learning_rate=0.0002, beta_1=0.5). Each compiled model gets its own
# optimizer instance. The discriminator and the combined model update
# different variable sets, and the optimizer API introduced in TF 2.11 raises
# a KeyError when a single instance is applied to a second set of variables.
optimizer = Adam(0.0002, 0.5)
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

generator = build_generator()
# train() never updates the generator directly (only through `combined`).
generator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

# Stack generator -> discriminator: noise in, real/fake score out.
z = Input(shape=(100,))
img = generator(z)

# Freeze the discriminator before compiling `combined`, so that
# combined.train_on_batch updates only the generator. In Keras 2.x the
# `trainable` flag is captured at compile time, so `discriminator` (compiled
# above while trainable) is still updated by its own train_on_batch calls.
discriminator.trainable = False
valid = discriminator(img)

combined = Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
train(epochs=30000, batch_size=32, save_interval=200)

generator.save('generator_model_test.h5')
