In [12]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

In [13]:
# Seed every RNG (Python `random`, NumPy, TensorFlow) so weight init, noise
# sampling and batch selection are reproducible under Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

latent_dim = 100          # length of the generator's input noise vector
img_shape = (28, 28, 1)   # 28x28 single-channel images with explicit channel axis

(X_train, _), (_, _) = mnist.load_data()

# Scale uint8 pixels [0, 255] to [-1, 1] to match the generator's tanh output,
# then append a channel axis: (N, 28, 28) -> (N, 28, 28, 1).
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=-1)

# Invariants the models below rely on.
assert X_train.shape[1:] == img_shape, X_train.shape
assert X_train.min() >= -1.0 and X_train.max() <= 1.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


In [14]:
def build_generator():
    """Build the generator: latent noise vector -> 28x28x1 image in [-1, 1].

    Upsampling path: Dense -> reshape to 7x7x128 -> two stride-2
    Conv2DTranspose layers (7 -> 14 -> 28) -> 3x3 Conv2D with tanh.
    The input length is the notebook-level ``latent_dim``.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    model = models.Sequential()

    # Declare the input with an explicit Input layer; passing `input_dim` to
    # Dense is deprecated in Keras 3 and emits a UserWarning.
    model.add(layers.Input(shape=(latent_dim,)))
    model.add(layers.Dense(128 * 7 * 7, activation="relu"))
    model.add(layers.Reshape((7, 7, 128)))

    # 7x7 -> 14x14
    model.add(layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same'))
    model.add(layers.ReLU())

    # 14x14 -> 28x28
    model.add(layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same'))
    model.add(layers.ReLU())

    # Single output channel; tanh matches the [-1, 1] scaling applied to X_train.
    model.add(layers.Conv2D(1, (3, 3), padding='same', activation='tanh'))

    return model


In [15]:
def build_discriminator():
    """Build the discriminator: image of shape ``img_shape`` -> probability it is real.

    Two stride-2 3x3 convolutions (64 then 128 filters), each followed by
    LeakyReLU(0.2) and Dropout(0.3), then Flatten and a single sigmoid unit.

    Returns:
        An uncompiled ``Sequential`` model with one sigmoid output.
    """
    model = models.Sequential()

    # Declare the input with an explicit Input layer; passing `input_shape` to
    # Conv2D is deprecated in Keras 3 and emits a UserWarning.
    model.add(layers.Input(shape=img_shape))
    model.add(layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    # Sigmoid output pairs with the binary_crossentropy loss used at compile time.
    model.add(layers.Dense(1, activation='sigmoid'))

    return model

In [16]:
# Build both networks (generator first), then compile the discriminator on its
# own so that its standalone train_on_batch is meant to update its weights.
generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# Freeze the discriminator *after* compiling it, then stack
# noise -> generator -> discriminator; training `gan` is intended to update
# only the generator's weights.
# NOTE(review): tf.keras 2.x keeps the trainable state captured at compile();
# Keras 3 does not, so code that trains `discriminator` directly should set
# `discriminator.trainable = True` first -- verify for the installed version.
discriminator.trainable = False

gan_input = layers.Input(shape=(latent_dim,))
generated_img = generator(gan_input)
validity = discriminator(generated_img)

gan = models.Model(inputs=gan_input, outputs=validity)
gan.compile(optimizer='adam', loss='binary_crossentropy')

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [23]:
def train_gan(epochs, batch_size=128, save_interval=100):
    """Train the GAN with alternating discriminator / generator updates.

    Each iteration (called an "epoch" here, but it is one training step, not a
    full pass over ``X_train``):
      1. Train the discriminator on ``batch_size // 2`` randomly sampled real
         images (label 1) and ``batch_size // 2`` generated images (label 0).
      2. Train the generator through the stacked ``gan`` model on
         ``batch_size`` noise vectors labelled 1, so its loss rewards fooling
         the discriminator.

    Every ``save_interval`` steps the losses are printed and a row of samples
    is saved with ``save_generated_images``.

    Args:
        epochs: Number of training steps to run.
        batch_size: Generator batch size; the discriminator sees half of it
            as real images and half as fake images per step.
        save_interval: Step period for logging and saving samples.
    """
    half_batch = batch_size // 2

    # Label arrays are identical every step, so build them once.
    real_labels = np.ones((half_batch, 1))
    fake_labels = np.zeros((half_batch, 1))
    misleading_labels = np.ones((batch_size, 1))

    for epoch in range(epochs):
        # --- Discriminator step ---
        idx = np.random.randint(0, X_train.shape[0], half_batch)
        real_imgs = X_train[idx]

        noise = np.random.normal(0, 1, (half_batch, latent_dim))
        # verbose=0: otherwise predict() prints a progress bar on every step.
        fake_imgs = generator.predict(noise, verbose=0)

        # Keras 3 does not freeze `trainable` at compile() time (tf.keras 2 did),
        # so the flag must be set explicitly before each kind of update.
        # Without this the discriminator never learns under Keras 3.
        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch(real_imgs, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_imgs, fake_labels)
        # Element-wise mean of [loss, accuracy] over the real and fake halves.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Generator step (discriminator frozen) ---
        discriminator.trainable = False
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, misleading_labels)

        if epoch % save_interval == 0:
            print(f"Epoch {epoch}, D Loss: {d_loss[0]:.4f}, D Acc: {100 * d_loss[1]:.2f}%, G Loss: {g_loss:.4f}")
            save_generated_images(epoch)

In [26]:
def save_generated_images(epoch, num_images=10):
    """Sample ``num_images`` images from the generator and save them in one row.

    Writes ``generated_images_epoch_{epoch}.png`` to the current working
    directory; nothing is displayed inline.

    Args:
        epoch: Training step number; used only in the output filename.
        num_images: Number of samples to draw and plot side by side.
    """
    noise = np.random.normal(0, 1, (num_images, latent_dim))
    # verbose=0: suppress predict()'s progress bar.
    generated_imgs = generator.predict(noise, verbose=0)

    # Map tanh output from [-1, 1] back to [0, 1] for display.
    generated_imgs = 0.5 * generated_imgs + 0.5

    # squeeze=False keeps `axs` 2-D, so indexing also works when num_images == 1.
    fig, axs = plt.subplots(1, num_images, figsize=(20, 2), squeeze=False)
    for i, ax in enumerate(axs[0]):
        # Fixed vmin/vmax: without them imshow stretches each image to its own
        # min/max, hiding how bright or faint the sample really is.
        ax.imshow(generated_imgs[i, :, :, 0], cmap='gray', vmin=0, vmax=1)
        ax.axis('off')

    fig.savefig(f"generated_images_epoch_{epoch}.png")
    # Close this specific figure to avoid accumulating figures across calls.
    plt.close(fig)