# In[ ]:
# Trains CGAN on MNIST using Keras
# Conditional GAN - z vector is conditioned by a one-hot label
# Based on DCGAN structure

from tensorflow.keras.layers import Activation, Dense, Input, Conv2D, Flatten, Reshape, Conv2DTranspose, LeakyReLU, BatchNormalization, concatenate
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np
import math
import matplotlib.pyplot as plt
import os

# Generator model
def build_generator(inputs, labels, image_size):
    """Build the conditional generator that maps (z, label) to an image.

    The latent vector and the one-hot label are joined, projected by a Dense
    layer to a small square feature map of side ``image_size // 4``, and then
    passed through four BatchNorm -> ReLU -> Conv2DTranspose stages. The
    result goes through a final sigmoid.

    Args:
        inputs: Keras ``Input`` carrying the latent vector z.
        labels: Keras ``Input`` carrying the one-hot class label.
        image_size: side length of the square output image. Presumably a
            multiple of 4 so the two stride-2 stages land exactly on it
            (true for 28px MNIST); TODO confirm for other datasets.

    Returns:
        A ``Model`` named ``'generator'`` taking ``[inputs, labels]``.
    """
    seed_side = image_size // 4
    kernel = 5
    # (filters, strides) per transposed-conv stage: the two stride-2 stages
    # upsample the seed map, the stride-1 stages refine at full size and the
    # last one emits a single channel.
    stages = [(128, 2), (64, 2), (32, 1), (1, 1)]
    first_filters = stages[0][0]

    # Condition z on the class by concatenating it with the one-hot label.
    net = concatenate([inputs, labels], axis=1)
    net = Dense(seed_side * seed_side * first_filters)(net)
    net = Reshape((seed_side, seed_side, first_filters))(net)

    for n_filters, stride in stages:
        net = BatchNormalization()(net)
        net = Activation('relu')(net)
        net = Conv2DTranspose(filters=n_filters, kernel_size=kernel, strides=stride, padding='same')(net)

    # Sigmoid keeps generated pixels in [0, 1], the same range main() scales
    # the real MNIST images to.
    net = Activation('sigmoid')(net)
    return Model([inputs, labels], net, name='generator')

# Discriminator model
def build_discriminator(inputs, labels, image_size):
    """Build the conditional discriminator scoring (image, label) pairs.

    The one-hot label is projected by a Dense layer to one image-sized plane
    and stacked onto the image as an extra channel. The result then passes
    through four LeakyReLU -> Conv2D stages and a single sigmoid output unit.

    Args:
        inputs: Keras ``Input`` of shape ``(image_size, image_size, 1)``.
        labels: Keras ``Input`` carrying the one-hot class label.
        image_size: side length of the square input image.

    Returns:
        A ``Model`` named ``'discriminator'`` taking ``[inputs, labels]``.
        Its output is trained towards 1 for real pairs and 0 for fakes.
    """
    kernel = 5
    # (filters, strides) per conv stage: stride 2 everywhere except the last.
    stages = [(32, 2), (64, 2), (128, 2), (256, 1)]

    # Turn the label into a (image_size, image_size, 1) plane so it can be
    # concatenated with the image along the channel axis.
    label_plane = Dense(image_size * image_size)(labels)
    label_plane = Reshape((image_size, image_size, 1))(label_plane)
    net = concatenate([inputs, label_plane])

    for n_filters, stride in stages:
        net = LeakyReLU(alpha=0.2)(net)
        net = Conv2D(filters=n_filters, kernel_size=kernel, strides=stride, padding='same')(net)

    net = Flatten()(net)
    net = Dense(1)(net)
    net = Activation('sigmoid')(net)
    return Model([inputs, labels], net, name='discriminator')

# Training function
def train(models, data, params, save_interval=500):
    """Train the CGAN by alternating discriminator and adversarial updates.

    Each step first trains the discriminator on a batch that is half real
    images (target 1) and half generated images (target 0), each paired with
    its one-hot label. It then trains the adversarial model on fresh noise
    and random labels with target 1, so the generator is pushed to fool the
    discriminator. The generator is saved to ``<model_name>.h5`` when
    training finishes.

    Args:
        models: ``(generator, discriminator, adversarial)``. The generator is
            called as ``predict([noise, one_hot_labels])``; the other two as
            ``train_on_batch([inputs, one_hot_labels], targets)`` returning
            ``(loss, accuracy)``.
        data: ``(x_train, y_train)``: training images and their one-hot
            labels, sampled together along axis 0.
        params: ``(batch_size, latent_size, train_steps, num_labels,
            model_name)``.
        save_interval: write a sample grid via ``plot_images`` every this
            many steps (default 500, the previously hard-coded value). Pass
            ``0`` or ``None`` to disable the periodic grids.
    """
    generator, discriminator, adversarial = models
    x_train, y_train = data
    batch_size, latent_size, train_steps, num_labels, model_name = params

    # Fixed inputs for the periodic sample grids so successive grids are
    # comparable; the 16 labels cycle through the classes in order.
    noise_input = np.random.uniform(-1.0, 1.0, size=[16, latent_size])
    noise_class = np.eye(num_labels)[np.arange(0, 16) % num_labels]
    train_size = x_train.shape[0]

    for i in range(train_steps):
        # Discriminator step: real images (target 1) + generated ones (target 0).
        rand_indexes = np.random.randint(0, train_size, size=batch_size)
        real_images = x_train[rand_indexes]
        real_labels = y_train[rand_indexes]

        noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
        # verbose=0: recent Keras versions otherwise print a progress bar on
        # every predict() call, i.e. once per training step.
        fake_images = generator.predict([noise, fake_labels], verbose=0)

        x = np.concatenate((real_images, fake_images))
        labels = np.concatenate((real_labels, fake_labels))
        y = np.ones([2 * batch_size, 1])
        y[batch_size:, :] = 0.0

        d_loss, d_acc = discriminator.train_on_batch([x, labels], y)

        # Adversarial step: fresh fakes labelled "real" (target 1). Assumes the
        # caller built `adversarial` with the discriminator frozen, so only the
        # generator's weights move here.
        noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
        y = np.ones([batch_size, 1])
        a_loss, a_acc = adversarial.train_on_batch([noise, fake_labels], y)

        print(f"{i}: [D loss: {d_loss:.4f}, acc: {d_acc:.4f}] [A loss: {a_loss:.4f}, acc: {a_acc:.4f}]")

        # A falsy save_interval disables the grids (and avoids modulo by zero).
        if save_interval and (i + 1) % save_interval == 0:
            plot_images(generator, noise_input, noise_class, step=(i + 1), model_name=model_name)

    generator.save(model_name + ".h5")

# Plot generated images
def plot_images(generator, noise_input, noise_class, step=0, model_name="cgan"):
    """Generate samples for fixed inputs and save them as a PNG grid.

    The image is written to ``<model_name>/<step:05d>.png``, creating the
    directory if needed.

    Args:
        generator: model called as ``predict([noise_input, noise_class])``.
            Its output is indexed as ``(N, H, W, ...)``, with each sample
            reshaped to ``(H, W)`` for display.
        noise_input: latent vectors, one row per sample.
        noise_class: one-hot labels matching ``noise_input`` row for row.
        step: training step, used to name the output file.
        model_name: output directory.
    """
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, f"{step:05d}.png")
    # verbose=0 keeps predict() from printing a progress bar into the log.
    images = generator.predict([noise_input, noise_class], verbose=0)
    num_images = images.shape[0]
    # Round up so a non-square sample count (e.g. 10) still fits the grid;
    # int(sqrt(n)) would make plt.subplot() fail with an out-of-range index.
    rows = math.ceil(math.sqrt(num_images))

    fig = plt.figure(figsize=(2.2, 2.2))
    try:
        for i in range(num_images):
            plt.subplot(rows, rows, i + 1)
            plt.imshow(images[i].reshape(images.shape[1], images.shape[2]), cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(filename)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)

# Main logic
def main():
    """Load MNIST, build the CGAN models and run training.

    Only the MNIST training split is used. Images are reshaped to
    ``(N, H, W, 1)`` and scaled to [0, 1]; labels are one-hot encoded.
    """
    (x_train, y_train), (_, _) = mnist.load_data()
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1]).astype('float32') / 255
    y_train = to_categorical(y_train)
    num_labels = y_train.shape[1]

    # Hyper-parameters.
    latent_size = 100
    batch_size = 64
    train_steps = 10000
    model_name = "cgan_mnist"

    # Generator: (z, label) -> image.
    z_in = Input(shape=(latent_size,))
    label_in = Input(shape=(num_labels,))
    generator = build_generator(z_in, label_in, image_size)

    # Discriminator: (image, label) -> probability of being real.
    image_in = Input(shape=(image_size, image_size, 1))
    disc_label_in = Input(shape=(num_labels,))
    discriminator = build_discriminator(image_in, disc_label_in, image_size)
    discriminator.compile(loss='binary_crossentropy', optimizer=RMSprop(learning_rate=0.0002), metrics=['accuracy'])

    # Stack generator + frozen discriminator into the adversarial model.
    # NOTE(review): this relies on compile() capturing the `trainable` flag at
    # compile time (documented tf.keras behaviour), so the already-compiled
    # discriminator still updates when trained directly. Confirm this holds
    # for the Keras version in use (e.g. Keras 3).
    discriminator.trainable = False
    validity = discriminator([generator([z_in, label_in]), label_in])
    adversarial = Model([z_in, label_in], validity)
    adversarial.compile(loss='binary_crossentropy', optimizer=RMSprop(learning_rate=0.0002), metrics=['accuracy'])

    train(
        (generator, discriminator, adversarial),
        (x_train, y_train),
        (batch_size, latent_size, train_steps, num_labels, model_name),
    )

if __name__ == '__main__':
    # Run training only when executed as a script, not on import.
    main()
