<a href="https://colab.research.google.com/github/x-Kevin-Paul-x/GANwithMNIST/blob/main/MNIST_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input , Dense, Flatten, Reshape, BatchNormalization, LeakyReLU, Conv2DTranspose, Conv2D
from tensorflow.keras.models import Sequential, Model


In [29]:
# Load the MNIST training images; the labels and the test split are discarded.
(X_train, _), _ = tf.keras.datasets.mnist.load_data()
# Rescale so that 0 maps to -1 and 255 maps to +1 (the range of the generator's tanh
# output), then append a trailing channel axis.
X_train = np.expand_dims(X_train / 127.5 - 1.0, axis=-1)

In [30]:
# Shape of one image as consumed by the discriminator's Flatten layer and produced by
# the generator's Reshape layer; the trailing 1 corresponds to the channel axis that
# is appended to X_train with np.expand_dims.
img_shape = (28,28,1)

In [31]:
def build_generator():
    """Build the generator network.

    The network is a stack of three fully connected stages (256, 512 and 1024
    units), each followed by LeakyReLU(0.2) and batch normalisation, and a final
    tanh-activated dense layer whose output is reshaped to ``img_shape``.
    A summary of the inner Sequential model is printed as a side effect.

    Returns:
        A Keras ``Model`` mapping a noise vector of shape (100,) to an image of
        shape ``img_shape``.
    """
    latent_shape = (100,)
    hidden_widths = (256, 512, 1024)

    net = Sequential()
    for position, width in enumerate(hidden_widths):
        # Only the very first layer has to declare the input shape.
        first_layer_kwargs = {'input_shape': latent_shape} if position == 0 else {}
        net.add(Dense(width, **first_layer_kwargs))
        net.add(LeakyReLU(alpha=0.2))
        net.add(BatchNormalization(momentum=0.8))

    # One output unit per pixel; tanh keeps values in [-1, 1].
    net.add(Dense(np.prod(img_shape), activation='tanh'))
    net.add(Reshape(img_shape))

    net.summary()

    latent_input = Input(shape=latent_shape)
    return Model(latent_input, net(latent_input))


In [32]:
def build_discriminator():
    """Build the discriminator network.

    The image is flattened and passed through two fully connected layers
    (512 and 256 units, each followed by LeakyReLU(0.2)) and a single
    sigmoid output unit. A summary of the inner Sequential model is printed
    as a side effect.

    Returns:
        A Keras ``Model`` mapping an image of shape ``img_shape`` to one
        sigmoid-activated value.
    """
    classifier = Sequential([
        Flatten(input_shape=img_shape),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid'),
    ])
    classifier.summary()

    image_input = Input(shape=img_shape)
    return Model(image_input, classifier(image_input))

In [33]:
# Adam optimizer created with positional arguments 0.0002 and 0.5
# (learning rate and beta_1 in the tf.keras Adam signature).
# NOTE(review): no cell visible here references `optimizer`; every compile() call
# constructs its own Adam(0.0002, 0.5), so this instance appears to be unused -- confirm
# before relying on or removing it.
optimizer = tf.keras.optimizers.Adam(0.0002, 0.5)

In [34]:
# Build the discriminator and compile it as a stand-alone binary classifier that also
# reports accuracy, so train_on_batch returns [loss, accuracy].
discriminator = build_discriminator()
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    metrics=['accuracy'],
)

Model: "sequential_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_5 (Flatten)         (None, 784)               0         
                                                                 
 dense_39 (Dense)            (None, 512)               401920    
                                                                 
 leaky_re_lu_28 (LeakyReLU)  (None, 512)               0         
                                                                 
 dense_40 (Dense)            (None, 256)               131328    
                                                                 
 leaky_re_lu_29 (LeakyReLU)  (None, 256)               0         
                                                                 
 dense_41 (Dense)            (None, 1)                 257       
                                                                 
Total params: 533505 (2.04 MB)
Trainable params: 5335

In [35]:
# Assemble the stacked model: noise -> generator -> discriminator.
generator = build_generator()

# Mark the discriminator as non-trainable before the combined model is compiled, so
# that fitting `gan` is intended to update only the generator's weights.
discriminator.trainable = False

z = tf.keras.Input(shape=(100,))
img = generator(z)
real_or_fake = discriminator(img)

gan = Model(inputs=z, outputs=real_or_fake)
gan.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
)

Model: "sequential_12"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_42 (Dense)            (None, 256)               25856     
                                                                 
 leaky_re_lu_30 (LeakyReLU)  (None, 256)               0         
                                                                 
 batch_normalization_18 (Ba  (None, 256)               1024      
 tchNormalization)                                               
                                                                 
 dense_43 (Dense)            (None, 512)               131584    
                                                                 
 leaky_re_lu_31 (LeakyReLU)  (None, 512)               0         
                                                                 
 batch_normalization_19 (Ba  (None, 512)               2048      
 tchNormalization)                                   

In [36]:
def save_imgs(epoch, num_samples=25):
    """Generate sample digits with the global ``generator`` and save them as a PNG grid.

    Args:
        epoch: Value embedded in the output file name (``mnist_{epoch}.png``).
        num_samples: Number of images to generate and plot. The grid is sized to
            fit them (a near-square layout; the default of 25 gives 5 x 5).

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.

    Side effects:
        Creates ``/content/Photos`` if it does not exist and writes the figure there.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    noise = np.random.normal(0, 1, (num_samples, 100))
    # verbose=0: no progress bar for this small, frequently repeated prediction.
    gen_imgs = generator.predict(noise, verbose=0)
    # Map the generator's tanh output from [-1, 1] to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    # Derive the grid from num_samples instead of assuming exactly 25 images.
    cols = math.ceil(math.sqrt(num_samples))
    rows = math.ceil(num_samples / cols)
    # squeeze=False keeps `axs` two-dimensional even for a 1 x 1 grid.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    for index, ax in enumerate(axs.flat):
        ax.axis('off')
        # Cells beyond num_samples (non-square counts) are left blank.
        if index < num_samples:
            ax.imshow(gen_imgs[index, :, :, 0], cmap='gray')

    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs("/content/Photos", exist_ok=True)
    fig.savefig(f"/content/Photos/mnist_{epoch}.png")
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)



In [37]:
def train_gan(epochs, batch_size=128, save_interval=1000):
    """Alternately train the global ``discriminator`` and ``generator`` (via ``gan``).

    Note that each "epoch" here is a single training iteration on one randomly
    sampled batch, not a full pass over the dataset.

    Args:
        epochs: Number of training iterations to run.
        batch_size: Number of real images (and of generated images) per iteration.
        save_interval: A sample grid is written with ``save_imgs`` whenever the
            iteration index is a multiple of this value (including iteration 0).

    Side effects:
        Loads MNIST, updates the weights of ``discriminator`` and ``generator``,
        prints one progress line per iteration and saves sample images.
    """
    (X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
    # Map pixel values 0..255 to -1..1 and append a trailing channel axis.
    X_train = X_train / 127.5 - 1.0
    X_train = np.expand_dims(X_train, axis=-1)

    # Target labels: 1 for real images, 0 for generated ones.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # --- Discriminator step: one real batch and one generated batch ---
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]

        noise = np.random.normal(0, 1, (batch_size, 100))
        # verbose=0 suppresses the per-call progress bar, which would otherwise be
        # emitted on every iteration on top of the status line printed below.
        gen_imgs = generator.predict(noise, verbose=0)

        # Assumes the discriminator was compiled with a single 'accuracy' metric, so
        # each call returns [loss, accuracy].
        d_loss_real = discriminator.train_on_batch(imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Generator step: train the stacked model to make fakes score as real ---
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan.train_on_batch(noise, valid)

        print(f"{epoch} [D loss: {d_loss[0]} | D accuracy: {100*d_loss[1]}] [G loss: {g_loss}]")

        if epoch % save_interval == 0:
            save_imgs(epoch)

In [26]:
# Run 10,000 training iterations with batches of 128 images, saving a grid of
# generated samples every 100 iterations.
train_gan(epochs=10000, batch_size=128, save_interval=100)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
7512 [D loss: 0.6643241047859192 | D accuracy: 56.640625] [G loss: 0.8363254070281982]
7513 [D loss: 0.6460631787776947 | D accuracy: 63.671875] [G loss: 0.8545323014259338]
7514 [D loss: 0.6578714847564697 | D accuracy: 59.765625] [G loss: 0.8597559332847595]
7515 [D loss: 0.661029577255249 | D accuracy: 59.765625] [G loss: 0.8329169750213623]
7516 [D loss: 0.6832337379455566 | D accuracy: 58.59375] [G loss: 0.8792902231216431]
7517 [D loss: 0.6823822259902954 | D accuracy: 57.8125] [G loss: 0.84511399269104]
7518 [D loss: 0.6867980360984802 | D accuracy: 53.90625] [G loss: 0.8523104786872864]
7519 [D loss: 0.6771742701530457 | D accuracy: 56.640625] [G loss: 0.864106297492981]
7520 [D loss: 0.6794019341468811 | D accuracy: 59.765625] [G loss: 0.8254757523536682]
7521 [D loss: 0.6281854808330536 | D accuracy: 68.359375] [G loss: 0.8663041591644287]
7522 [D loss: 0.6655555665493011 | D accuracy: 60.9375] [G loss: 0.846800