In [6]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

# Load and preprocess the MNIST dataset
(x_train, y_train), (_, _) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_train = np.expand_dims(x_train, axis=-1)
y_train = tf.keras.utils.to_categorical(y_train, 10)

# Set random seed for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Constants
img_shape = (28, 28, 1)
num_classes = 10
latent_dim = 100

# Conditional GAN Generator
def build_generator():
    model = models.Sequential()

    # Label input
    label_input = layers.Input(shape=(num_classes,))

    # Noise input
    noise_input = layers.Input(shape=(latent_dim,))

    # Merge label and noise
    merged_input = layers.Concatenate()([noise_input, label_input])

    # Build the generator network
    x = layers.Dense(256)(merged_input)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Dense(512)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Dense(1024)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Dense(np.prod(img_shape), activation='tanh')(x)
    img = layers.Reshape(img_shape)(x)

    model = models.Model([noise_input, label_input], img)
    return model

# Conditional GAN Discriminator
def build_discriminator():
    model = models.Sequential()

    # Image input
    img_input = layers.Input(shape=img_shape)

    # Label input
    label_input = layers.Input(shape=(num_classes,))
    label_embedding = layers.Dense(np.prod(img_shape))(label_input)
    label_embedding = layers.Reshape(img_shape)(label_embedding)

    # Merge image and label
    merged_input = layers.Concatenate()([img_input, label_embedding])

    # Build the discriminator network
    x = layers.Flatten()(merged_input)
    x = layers.Dense(512)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)

    x = layers.Dense(256)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)

    x = layers.Dense(1, activation='sigmoid')(x)

    model = models.Model([img_input, label_input], x)
    return model

# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(0.0002, 0.5), metrics=['accuracy'])

# Build the generator
generator = build_generator()

# The generator takes noise and the target label as input and generates an image
noise = layers.Input(shape=(latent_dim,))
label = layers.Input(shape=(num_classes,))
img = generator([noise, label])

# For the combined model, only the generator is trained
discriminator.trainable = False

# The discriminator takes generated image and the label as input and determines validity
validity = discriminator([img, label])

# The combined model (stacked generator and discriminator)
cgan = models.Model([noise, label], validity)
cgan.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(0.0002, 0.5))

# Training the CGAN
def train(epochs, batch_size=128, save_interval=1000):
    half_batch = batch_size // 2

    for epoch in range(epochs):
        # Train Discriminator
        idx = np.random.randint(0, x_train.shape[0], half_batch)
        imgs = x_train[idx]
        labels = y_train[idx]

        noise = np.random.normal(0, 1, (half_batch, latent_dim))
        gen_labels = np.random.randint(0, num_classes, half_batch)
        gen_labels = tf.keras.utils.to_categorical(gen_labels, num_classes)
        gen_imgs = generator.predict([noise, gen_labels])

        d_loss_real = discriminator.train_on_batch([imgs, labels], np.ones((half_batch, 1)))
        d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_labels], np.zeros((half_batch, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train Generator
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        valid_y = np.ones((batch_size, 1))

        sampled_labels = np.random.randint(0, num_classes, batch_size)
        sampled_labels = tf.keras.utils.to_categorical(sampled_labels, num_classes)

        g_loss = cgan.train_on_batch([noise, sampled_labels], valid_y)

        # Print the progress
        if epoch % save_interval == 0:
            print(f"{epoch} [D loss: {d_loss[0]} | D accuracy: {d_loss[1]*100}%] [G loss: {g_loss}]")
            save_imgs(epoch)

def save_imgs(epoch):
    r, c = 2, 5
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    sampled_labels = np.arange(0, 10).reshape(-1, 1)
    sampled_labels = tf.keras.utils.to_categorical(sampled_labels, num_classes)

    gen_imgs = generator.predict([noise, sampled_labels])

    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].set_title(f"Digit: {cnt}")
            axs[i, j].axis('off')
            cnt += 1
    fig.savefig(f"mnist_{epoch}.png")
    plt.close()

# Train the CGAN for 10,000 epochs with a batch size of 64 and save every 200 intervals
train(epochs=10000, batch_size=64, save_interval=500)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
5500 [D loss: 0.2519706189632416 | D accuracy: 89.0625%] [G loss: 5.498242378234863]
6000 [D loss: 0.09826977550983429 | D accuracy: 100.0%] [G loss: 6.747367858886719]
6500 [D loss: 0.22424907982349396 | D accuracy: 93.75%] [G loss: 5.403250217437744]
7000 [D loss: 0.13148233294487 | D accuracy: 96.875%] [G loss: 7.453749179840088]
7500 [D loss: 0.1160978190600872 | D accuracy: 95.3125%] [G loss: 7.5801191329956055]
8000 [D loss: 0.12926102057099342 | D accuracy: 95.3125%] [G loss: 11.599443435668945]
8500 [D loss: 0.21744228899478912 | D accuracy: 92.1875%] [G loss: 7.799016952514648]
9000 [D loss: 0.1725998818874359 | D accuracy: 95.3125%] [G loss: 7.300196647644043]
9500 [D loss: 0.07825219631195068 | D accuracy: 96.875%] [G loss: 10.886626243591309]
