In [None]:
from tensorflow import keras
from keras import Model
from keras.layers import (
    Dense,
    Conv2D,
    LeakyReLU,
    Conv2DTranspose,
    BatchNormalization,
    Reshape,
    Dropout,
    Flatten,
    Input,
)
from keras.utils import plot_model
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets.mnist import load_data
import warnings

# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this blanket filter may also hide Keras/TensorFlow deprecation
# warnings; consider narrowing it with a `category=` argument.
warnings.filterwarnings("ignore")

In [None]:
LATENT_DIM: int = 128  # length of the random noise vector fed to the generator
IMAGE_SHAPE: tuple = (28, 28, 1)  # (height, width, channels) of one MNIST image

In [None]:
def load_mnist_data():
    """Return the MNIST training images as float32 with a trailing channel axis.

    Pixel values are divided by 255.0. The test split and all labels are discarded.
    """
    (train_images, _), (_, _) = load_data()
    # Add a channel axis: (N, H, W) -> (N, H, W, 1).
    images = train_images[..., np.newaxis]
    return images.astype("float32") / 255.0

In [None]:
def build_discriminator(image_size=(28, 28, 1), learning_rate=0.002, beta_1=0.5):
    """Build and compile the CNN discriminator that scores images as real or fake.

    Two stride-2 3x3 convolutions (64 filters each, LeakyReLU slope 0.2)
    downsample the input. The result is flattened into a single sigmoid unit.
    Training labels real images 1 and fakes 0.

    Args:
        image_size: Input image shape (height, width, channels).
        learning_rate: Adam learning rate. NOTE(review): DCGAN setups commonly
            use 2e-4; the 2e-3 default is kept to preserve current behaviour.
        beta_1: Adam first-moment decay rate.

    Returns:
        A compiled ``Model`` named "discriminator" that uses binary cross-entropy
        loss and reports accuracy.
    """
    inputs = Input(shape=image_size)

    conv_1 = Conv2D(64, (3, 3), strides=(2, 2), padding="same")(inputs)
    relu_1 = LeakyReLU(0.2)(conv_1)
    conv_2 = Conv2D(64, (3, 3), strides=(2, 2), padding="same")(relu_1)
    relu_2 = LeakyReLU(0.2)(conv_2)
    flatten_1 = Flatten()(relu_2)
    dense_1 = Dense(1, activation="sigmoid")(flatten_1)

    discriminator = Model(inputs=inputs, outputs=dense_1, name="discriminator")
    discriminator.compile(
        metrics=["accuracy"],
        loss="binary_crossentropy",
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate, beta_1=beta_1),
    )
    return discriminator

In [None]:
# Build the discriminator from the configured image shape and render its architecture.
# NOTE(review): plot_model needs pydot/graphviz installed; confirm in your environment.
discriminator = build_discriminator(image_size=IMAGE_SHAPE)
plot_model(discriminator, show_shapes=True)

In [None]:
def build_generator(latent_dim: int):
    """Build the (uncompiled) generator that maps a latent vector to a 28x28x1 image.

    Architecture:
    - A dense projection to 7*7*128 units, then LeakyReLU(0.2) and
      BatchNormalization, reshaped to 7x7x128.
    - Two stride-2 4x4 transposed convolutions upsample 7 -> 14 -> 28, with 128
      and then 512 filters. Each is followed by LeakyReLU(0.2) and
      BatchNormalization.
    - A final 3x3 single-filter convolution with sigmoid activation, so outputs
      lie in [0, 1]. This matches the /255-scaled training data.

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        An uncompiled ``Model`` named "generator". The name matches the
        discriminator's naming convention.
    """
    input_noise = Input(shape=(latent_dim,))

    dense_1 = Dense(7 * 7 * 128)(input_noise)
    relu_0 = LeakyReLU(0.2)(dense_1)
    bn_1 = BatchNormalization(momentum=0.8)(relu_0)
    reshape_1 = Reshape((7, 7, 128))(bn_1)

    convt_1 = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same")(reshape_1)
    relu_1 = LeakyReLU(0.2)(convt_1)
    bn_2 = BatchNormalization(momentum=0.8)(relu_1)

    convt_2 = Conv2DTranspose(512, (4, 4), strides=(2, 2), padding="same")(bn_2)
    relu_2 = LeakyReLU(0.2)(convt_2)
    bn_3 = BatchNormalization(momentum=0.8)(relu_2)

    conv_1 = Conv2D(1, (3, 3), padding="same", activation="sigmoid")(bn_3)

    generator = Model(inputs=input_noise, outputs=conv_1, name="generator")
    return generator

In [None]:
# Build the generator with the configured latent size (must match the latent_dim used in training).
generator = build_generator(LATENT_DIM)
plot_model(generator, show_shapes=True)

In [None]:
def build_gan(generator, discriminator):
    """Chain generator -> discriminator into the combined model used to train the generator.

    Side effect: sets ``discriminator.trainable = False`` on the object passed in,
    before the combined model is compiled. As a result, ``gan.train_on_batch``
    updates only the generator's weights.

    NOTE(review): the discriminator from ``build_discriminator`` was compiled
    while still trainable. Continuing to train it with its own
    ``train_on_batch`` relies on Keras honouring ``trainable`` as it was at
    compile time. Confirm this for the installed Keras version.

    Returns:
        A compiled ``keras.Sequential`` that uses binary cross-entropy and
        Adam(learning_rate=0.002, beta_1=0.5).
    """
    discriminator.trainable = False
    combined = keras.Sequential([generator, discriminator])
    combined.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.002, beta_1=0.5),
        loss="binary_crossentropy",
    )
    return combined

In [None]:
def get_real_samples(data: np.ndarray, n_samples: int) -> np.ndarray:
    """Draw ``n_samples`` items from ``data`` uniformly at random, with replacement.

    Args:
        data: Array indexed along its first axis (e.g. the image stack from
            ``load_mnist_data``). The old ``int`` annotation was wrong.
        n_samples: Number of items to draw.

    Returns:
        ``data[idx]`` for ``n_samples`` random indices ``idx``.

    Raises:
        ValueError: If ``data`` is empty, since there is nothing to sample from.
    """
    if len(data) == 0:
        raise ValueError("get_real_samples: cannot sample from empty data")
    idx = np.random.randint(0, len(data), n_samples)
    return data[idx]

In [None]:
def generate_latent_points(latent_dim: int, n_samples: int):
    """Sample a batch of standard-normal latent vectors.

    Returns:
        An array of shape ``(n_samples, latent_dim)`` drawn with
        ``np.random.randn`` (global NumPy RNG).
    """
    # randn fills the requested shape in row-major order, so this consumes the
    # RNG stream exactly like drawing a flat vector and reshaping it.
    return np.random.randn(n_samples, latent_dim)

In [None]:
def generate_fake_samples(generator, latent_dim: int, n_samples: int):
    """Generate ``n_samples`` images from freshly sampled latent points.

    Args:
        generator: Model whose ``predict`` maps ``(n_samples, latent_dim)`` noise to images.
        latent_dim: Length of each latent vector.
        n_samples: Number of images to generate.

    Returns:
        The array returned by ``generator.predict``.
    """
    noise = generate_latent_points(latent_dim, n_samples)
    # verbose=0: this runs once per training step. Without it, predict may print
    # a progress bar on every call and flood the notebook output.
    return generator.predict(noise, verbose=0)

In [None]:
def save_plot(X, epoch: int, n=5):
    """Save an ``n`` x ``n`` grid of images to ``gen_epoch_{epoch}.png``.

    Args:
        X: Image batch indexed as ``X[i, :, :, 0]``. Presumably shaped
            (batch, height, width, channels); only channel 0 is drawn.
        epoch: Used only to build the output filename.
        n: Grid side length. Up to ``n * n`` images are drawn. If ``X`` has
            fewer, the remaining panels are left blank.
    """
    # Use a dedicated figure. Otherwise we would draw onto, then close, whatever
    # pyplot figure happens to be current in the notebook.
    fig, axes = plt.subplots(n, n, squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        if i < len(X):
            ax.imshow(X[i, :, :, 0], cmap="gray_r")

    filename = f"gen_epoch_{epoch}.png"
    fig.savefig(filename)
    plt.close(fig)

In [None]:
def show_images(latent_dim: int, epoch: int, model=None):
    """Generate 25 images and save them as a 5x5 grid to ``epoch_{epoch}_result.png``.

    Args:
        latent_dim: Length of each latent vector. Must match the generator's input.
        epoch: Used only to build the output filename.
        model: Generator to sample from. Defaults to the notebook-level
            ``generator`` global, which keeps existing two-argument calls working.
    """
    if model is None:
        model = generator  # notebook-global generator
    r, c = 5, 5
    # Training samples latents with np.random.randn (N(0, 1)), so use the same
    # prior here. The previous np.random.rand drew from U[0, 1).
    noise = np.random.randn(r * c, latent_dim)
    generated_images = model.predict(noise, verbose=0)

    fig, axes = plt.subplots(r, c)
    # axes.flat walks the grid row by row, so panel k shows image k. The old
    # loop never advanced its counter, so all panels showed image 0.
    for count, ax in enumerate(axes.flat):
        ax.imshow(generated_images[count, :, :, 0], cmap="gray_r")
        ax.axis("off")
    fig.savefig(f"epoch_{epoch}_result.png")
    plt.close(fig)

In [None]:
def summarize_performance(
    epoch, generator, discriminator, data, latent_dim, sample_size=100
):
    """Report discriminator accuracy on real vs. fake samples and checkpoint outputs.

    Args:
        epoch: 0-based epoch index (``train`` passes its loop counter). Output
            files use the 1-based epoch number ``epoch + 1``. This applies to
            both the image grid and the model file, so their numbering matches.
        generator: Generator model; saved to ``generator_model_XXX.h5``.
        discriminator: Compiled discriminator evaluated with labels 1 (real) / 0 (fake).
        data: Real image array to sample from.
        latent_dim: Length of the generator's latent vectors.
        sample_size: Number of real and of fake samples to evaluate.
    """
    X_real = get_real_samples(data, sample_size)
    y_real = np.ones((sample_size, 1))
    X_fake = generate_fake_samples(generator, latent_dim, sample_size)
    y_fake = np.zeros((sample_size, 1))
    _, acc_real = discriminator.evaluate(X_real, y_real, verbose=0)
    _, acc_fake = discriminator.evaluate(X_fake, y_fake, verbose=0)

    epoch_number = epoch + 1
    print(
        f">Epoch {epoch_number}: discriminator accuracy "
        f"real={acc_real * 100:.0f}%, fake={acc_fake * 100:.0f}%"
    )
    save_plot(X_fake, epoch_number)
    filename = "generator_model_%03d.h5" % epoch_number
    generator.save(filename)

In [None]:
def train(generator, discriminator, gan, data, latent_dim, epochs=250, batch_size=128):
    """Alternately train the discriminator and the generator (through ``gan``).

    Each step does two updates:
    - The discriminator gets one update on half a batch of real images
      (label 1) plus half a batch of generated images (label 0).
    - The generator gets one update through the combined model on ``batch_size``
      latent points, all labelled 1.

    ``summarize_performance`` runs every 20 epochs and after the final epoch,
    so the last generator is always checkpointed.

    Args:
        generator: Generator model.
        discriminator: Compiled discriminator.
        gan: Compiled combined model from ``build_gan``.
        data: Real image array.
        latent_dim: Length of the generator's latent vectors.
        epochs: Number of passes over ``data``.
        batch_size: Images per step. Each step uses ``batch_size // 2`` real and
            ``batch_size // 2`` fake images for the discriminator.
    """
    batch_per_epoch = data.shape[0] // batch_size
    half_batch = batch_size // 2

    # Labels never change between steps, so build them once.
    y_real = np.ones((half_batch, 1))
    y_fake = np.zeros((half_batch, 1))
    y_disc = np.vstack((y_real, y_fake))
    y_gan = np.ones((batch_size, 1))

    for i in range(epochs):
        for j in range(batch_per_epoch):
            X_real = get_real_samples(data, half_batch)
            X_fake = generate_fake_samples(generator, latent_dim, half_batch)

            X = np.vstack((X_real, X_fake))
            d_loss, _ = discriminator.train_on_batch(X, y_disc)

            X_gan = generate_latent_points(latent_dim, batch_size)
            g_loss = gan.train_on_batch(X_gan, y_gan)

            print(i + 1, j + 1, batch_per_epoch, d_loss, g_loss)

        is_last_epoch = i + 1 == epochs
        if (i + 1) % 20 == 0 or is_last_epoch:
            summarize_performance(i, generator, discriminator, data, latent_dim)

In [None]:
# Assemble the combined model, load MNIST, and run the full training loop.
gan = build_gan(generator, discriminator)
data = load_mnist_data()
train(generator, discriminator, gan, data, latent_dim=LATENT_DIM)

In [None]:
# Runs a second full training session on the same model objects, continuing from
# their current weights. NOTE(review): epoch numbering restarts at 1, so this
# overwrites the image and checkpoint files written by the previous session.
train(generator, discriminator, gan, data, latent_dim=LATENT_DIM)