
# DCGAN with CIFAR-10

A clean, runnable DCGAN example. CIFAR-10 is downloaded automatically on first run (roughly 170 MB), so no manual CelebA files are needed.


In [8]:
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# --- Tunable constants -------------------------------------------------------
IMAGE_SHAPE = (32, 32, 3)   # (height, width, channels) fed to the discriminator
LATENT_DIM = 128            # size of the generator's input noise vector
BATCH_SIZE = 128
BUFFER_SIZE = 50000         # shuffle buffer used by the tf.data pipeline

# --- Seed both RNGs used in this notebook ------------------------------------
np.random.seed(42)
tf.random.set_seed(42)

# --- Where the per-epoch sample grids are written ----------------------------
OUTPUT_DIR = Path("dcgan_samples")
OUTPUT_DIR.mkdir(exist_ok=True)


In [9]:
# Load CIFAR-10 (train and test splits merged, labels discarded), scale the
# pixel values to [-1, 1] to match the generator's tanh output, and build a
# tf.data pipeline.
(x_train, _), (x_test, _) = keras.datasets.cifar10.load_data()
images = np.concatenate([x_train, x_test], axis=0).astype("float32")
images = (images - 127.5) / 127.5  # to [-1, 1]

# NOTE(review): BUFFER_SIZE may be smaller than len(images) now that the train
# and test splits are concatenated; if so the shuffle is only approximate.
# Confirm, and consider shuffling with a buffer of len(images).
train_dataset = (
    tf.data.Dataset.from_tensor_slices(images)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# With drop_remainder=True the batch count is an exact floor division, so there
# is no need to iterate (and materialize) the whole dataset just to count it.
batches_per_epoch = len(images) // BATCH_SIZE
print(f"Dataset ready: {len(images)} images, batches per epoch: {batches_per_epoch}")

# Quick peek at a handful of examples
fig, axes = plt.subplots(2, 4, figsize=(8, 4))
for ax, sample in zip(axes.flat, images[:8]):
    # Undo the normalization. Round before casting so float32 error cannot
    # truncate a pixel value down by one.
    pixels = np.clip(np.rint(sample * 127.5 + 127.5), 0, 255).astype("uint8")
    ax.imshow(pixels)
    ax.axis("off")
fig.tight_layout()
plt.show()


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m  5283840/170498071[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m9:43[0m 4us/step

KeyboardInterrupt: 

In [None]:
def build_discriminator(input_shape=IMAGE_SHAPE):
    """Assemble the convolutional discriminator.

    Three stride-2 convolution stages (64, 128, 256 filters), each followed by
    LeakyReLU(0.2) and Dropout(0.3), then a flatten and a single-unit dense
    head with no activation, so the model emits a raw logit.

    Args:
        input_shape: Shape of one input image, without the batch dimension.

    Returns:
        A ``keras.Sequential`` model named ``"discriminator"``.
    """
    stack = [layers.Input(shape=input_shape)]
    for filters in (64, 128, 256):
        stack += [
            layers.Conv2D(filters, kernel_size=4, strides=2, padding="same"),
            layers.LeakyReLU(0.2),
            layers.Dropout(0.3),
        ]
    stack.append(layers.Flatten())
    # No sigmoid here: the loss is expected to be configured with from_logits=True.
    stack.append(layers.Dense(1))
    return keras.Sequential(stack, name="discriminator")


def build_generator(latent_dim=LATENT_DIM):
    """Assemble the transposed-convolution generator.

    A dense projection is reshaped to a 4x4x256 feature map, passed through
    three stride-2 transposed convolutions (256, 128, 64 filters, each with
    LeakyReLU(0.2)), and finished by a 3-channel transposed convolution with
    a tanh activation.

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        A ``keras.Sequential`` model named ``"generator"``.
    """
    blocks = [
        layers.Input(shape=(latent_dim,)),
        layers.Dense(4 * 4 * 256),
        layers.Reshape((4, 4, 256)),
    ]
    for filters in (256, 128, 64):
        blocks += [
            layers.Conv2DTranspose(filters, 4, strides=2, padding="same"),
            layers.LeakyReLU(0.2),
        ]
    # tanh keeps outputs in [-1, 1], the same range the training images are scaled to.
    blocks.append(layers.Conv2DTranspose(3, 3, activation="tanh", padding="same"))
    return keras.Sequential(blocks, name="generator")


# Instantiate both networks with their default shapes, then print each
# architecture (discriminator first, generator second).
discriminator = build_discriminator()
generator = build_generator()

for _net in (discriminator, generator):
    _net.summary()


In [None]:
class GAN(keras.Model):
    """Couples a generator and a discriminator and trains them adversarially.

    Label convention used by ``train_step``: generated images are labelled 0
    and real images are labelled 1.

    Args:
        discriminator: Model scoring a batch of images; its output is compared
            against labels of shape ``(batch, 1)``.
        generator: Model turning latent vectors of length ``latent_dim`` into
            images.
        latent_dim: Length of the noise vectors sampled for the generator.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # Running means reported by fit(); exposed through `metrics`.
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        """Metrics tracked by this model: discriminator and generator loss."""
        return [self.d_loss_metric, self.g_loss_metric]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store one optimizer per sub-model and the shared loss function.

        Args:
            d_optimizer: Optimizer applied to the discriminator's weights.
            g_optimizer: Optimizer applied to the generator's weights.
            loss_fn: Callable ``loss_fn(labels, predictions)`` used for both steps.
        """
        super().compile(run_eagerly=False)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: Batch of real images; its leading dimension sets how
                many fake images are sampled.

        Returns:
            Dict with the running mean losses under ``"d_loss"`` and ``"g_loss"``.
        """
        batch_size = tf.shape(real_images)[0]

        # Discriminator step
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # training=True is passed explicitly on every sub-model call below:
        # inside a custom train_step nothing else puts the layers in training
        # mode, so training-only layers such as Dropout would stay inactive.
        generated_images = self.generator(random_latent_vectors, training=True)
        combined_images = tf.concat([generated_images, real_images], axis=0)
        labels = tf.concat(
            [tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], axis=0
        )
        # Add uniform noise in [0, 0.05) to the labels (label noise, not smoothing).
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)

        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        # Generator step: fresh noise, labelled as "real" so the generator is
        # rewarded when the discriminator is fooled.
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        misleading_labels = tf.ones((batch_size, 1))

        with tf.GradientTape() as tape:
            generated_images = self.generator(random_latent_vectors, training=True)
            predictions = self.discriminator(generated_images, training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)

        # Only the generator's weights are updated here; the discriminator is
        # used purely to produce the training signal.
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {"d_loss": self.d_loss_metric.result(), "g_loss": self.g_loss_metric.result()}


In [None]:
class GANMonitor(keras.callbacks.Callback):
    """Callback that saves a grid of generator samples at the end of each epoch.

    The same latent vectors are reused every epoch so successive grids show
    how the generator's output for fixed inputs changes over training.

    Args:
        output_dir: Directory the PNG grids are written to (created if missing).
        num_images: Number of samples drawn per grid.
        latent_dim: Length of each latent vector fed to the generator.
    """

    def __init__(self, output_dir, num_images=9, latent_dim=LATENT_DIM):
        super().__init__()
        self.output_dir = Path(output_dir)
        # Make sure savefig has somewhere to write, even for nested paths.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.num_images = num_images
        self.latent_dim = latent_dim
        # Fixed noise, sampled once, so grids are comparable across epochs.
        self.seed = tf.random.normal(shape=(num_images, latent_dim))

    def on_epoch_end(self, epoch, logs=None):
        """Generate samples from the fixed seed and save them as one PNG grid.

        Args:
            epoch: Zero-based epoch index supplied by Keras; file names use
                ``epoch + 1``.
            logs: Unused.
        """
        # Inference mode: this is sampling, not a training step.
        generated_images = self.model.generator(self.seed, training=False)
        # Map the generator's [-1, 1] output range back to 8-bit pixel values.
        generated_images = (generated_images * 127.5 + 127.5).numpy().astype("uint8")

        # Near-square grid sized from the actual sample count (3x3 for the
        # default of 9 images) instead of a hard-coded layout.
        count = len(generated_images)
        cols = max(1, int(np.ceil(np.sqrt(count))))
        rows = max(1, int(np.ceil(count / cols)))
        # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
        fig, axes = plt.subplots(
            rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False
        )
        for ax, img in zip(axes.flat, generated_images):
            ax.imshow(img)
        # Hide the frames of every cell, including any left empty.
        for ax in axes.flat:
            ax.axis("off")
        fig.tight_layout()

        out_path = self.output_dir / f"epoch_{epoch + 1:03d}.png"
        fig.savefig(out_path)
        # Close explicitly so figures do not pile up in memory across epochs.
        plt.close(fig)
        print(f"Saved sample grid to {out_path}")


In [None]:
# Training run configuration
EPOCHS = 5
STEPS_PER_EPOCH = 400  # caps the batches consumed per epoch; None = whole dataset

# Optionally truncate the pipeline so each epoch stays short.
train_data = (
    train_dataset if STEPS_PER_EPOCH is None else train_dataset.take(STEPS_PER_EPOCH)
)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=LATENT_DIM)

# Both networks use identical Adam settings.
adam_settings = dict(learning_rate=2e-4, beta_1=0.5)
gan.compile(
    d_optimizer=keras.optimizers.Adam(**adam_settings),
    g_optimizer=keras.optimizers.Adam(**adam_settings),
    # The discriminator outputs raw logits, hence from_logits=True.
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

# Writes one sample grid into OUTPUT_DIR at the end of each epoch.
sample_monitor = GANMonitor(OUTPUT_DIR, latent_dim=LATENT_DIM)

history = gan.fit(train_data, epochs=EPOCHS, callbacks=[sample_monitor])



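During training, a grid of generated samples is saved to `dcgan_samples/` at the end of every epoch (`epoch_001.png`, `epoch_002.png`, …). Increase `EPOCHS` for a longer run, or set `STEPS_PER_EPOCH = None` to use every batch of the dataset in each epoch.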
