
# DCGAN (CelebA preferred, CIFAR-10 fallback)

This notebook fixes the earlier CelebA DCGAN setup by automatically loading a usable dataset (CelebA via `tensorflow_datasets`, or CIFAR-10 if CelebA isn't available) and correcting the GAN training loop.


In [10]:
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Seed NumPy and TensorFlow with the same value so random draws are repeatable
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Hyperparameters shared by the cells below
IMG_SIZE = 64  # images are resized to IMG_SIZE x IMG_SIZE
LATENT_DIM = 128  # length of the generator's input noise vector
BATCH_SIZE = 128
BUFFER_SIZE = 50000  # size of the tf.data shuffle buffer

# Directory that receives the sample grids written after each epoch
OUTPUT_DIR = Path("dcgan_samples")
OUTPUT_DIR.mkdir(exist_ok=True)


In [None]:
# Dataset loader: try CelebA via tensorflow_datasets, fallback to CIFAR-10

def _scale_image(image):
    """Cast to float32, resize to IMG_SIZE x IMG_SIZE and rescale 0..255 pixels to [-1, 1]."""
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return (image - 127.5) / 127.5


def _build_pipeline(image_ds):
    """Turn a dataset of single raw images into shuffled, batched training data."""
    return (
        image_ds.map(_scale_image, num_parallel_calls=tf.data.AUTOTUNE)
        .shuffle(BUFFER_SIZE)
        .batch(BATCH_SIZE, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )


def get_dataset():
    """Load the training images, preferring CelebA and falling back to CIFAR-10.

    CelebA is requested through ``tensorflow_datasets``. If anything in that
    attempt raises (package missing, download failure, ...), the error is
    printed and the CIFAR-10 train and test images bundled with Keras are
    used instead.

    Returns:
        A tuple ``(dataset, dataset_name, total, sample_shape)`` where
        ``dataset`` is a ``tf.data.Dataset`` of float32 image batches scaled
        to [-1, 1], ``dataset_name`` is ``"CelebA"`` or ``"CIFAR-10"``,
        ``total`` is the number of source images before batching, and
        ``sample_shape`` is ``(IMG_SIZE, IMG_SIZE, 3)``.
    """
    try:
        import tensorflow_datasets as tfds

        ds, info = tfds.load(
            "celeb_a", split="train", with_info=True, shuffle_files=True
        )
        print("Using CelebA from tensorflow_datasets")

        # Each TFDS record is a feature dict; keep only the image tensor.
        image_ds = ds.map(
            lambda sample: sample["image"], num_parallel_calls=tf.data.AUTOTUNE
        )
        dataset_name = "CelebA"
        total = info.splits["train"].num_examples
    except Exception as exc:  # noqa: BLE001
        # Deliberately broad: any failure on the CelebA path triggers the fallback.
        print(f"CelebA unavailable ({exc}); falling back to CIFAR-10.")
        (x_train, _), (x_test, _) = keras.datasets.cifar10.load_data()
        # Labels are unused, so train and test images are pooled. The array is
        # left in its original dtype; the float32 cast happens per element in
        # _scale_image instead of materialising a 4x larger copy up front.
        images = np.concatenate([x_train, x_test], axis=0)
        image_ds = tf.data.Dataset.from_tensor_slices(images)
        dataset_name = "CIFAR-10"
        total = images.shape[0]

    sample_shape = (IMG_SIZE, IMG_SIZE, 3)
    return _build_pipeline(image_ds), dataset_name, total, sample_shape


# Build the input pipeline once. The names bound here are read by later cells:
# train_dataset feeds the preview and training cells, and image_shape is the
# default input shape of build_discriminator.
train_dataset, dataset_name, num_samples, image_shape = get_dataset()
print(f"Dataset: {dataset_name} | Samples: {num_samples} | Image shape: {image_shape}")


In [None]:
# Preview: draw the first eight images of one preprocessed batch in a 2x4 grid
for batch in train_dataset.take(1):  # take(1) yields at most one batch
    plt.figure(figsize=(8, 4))
    for slot in range(1, 9):
        # Map the [-1, 1] pixels back to 0..255 so imshow gets uint8 data
        picture = (batch[slot - 1] * 127.5 + 127.5).numpy().astype("uint8")
        plt.subplot(2, 4, slot)
        plt.imshow(picture)
        plt.axis("off")
    plt.suptitle(f"Samples from {dataset_name}")
    plt.tight_layout()


In [None]:
def build_discriminator(input_shape=image_shape):
    """Create the convolutional discriminator.

    The network stacks three stages of (stride-2 4x4 convolution, LeakyReLU,
    dropout) with 64, 128 and 256 filters, then flattens and projects to one
    unnormalised score per image.

    Args:
        input_shape: Shape of a single input image; defaults to the
            module-level ``image_shape`` as it was when this function was
            defined.

    Returns:
        A ``keras.Sequential`` model named ``"discriminator"`` that outputs
        logits (no sigmoid is applied).
    """
    stack = [layers.Input(shape=input_shape)]
    for filters in (64, 128, 256):
        # Each stage halves the spatial resolution (stride 2, "same" padding).
        stack += [
            layers.Conv2D(filters, kernel_size=4, strides=2, padding="same"),
            layers.LeakyReLU(0.2),
            layers.Dropout(0.3),
        ]
    stack.append(layers.Flatten())
    stack.append(layers.Dense(1))  # logits
    return keras.Sequential(stack, name="discriminator")


def build_generator(latent_dim=LATENT_DIM):
    """Create the transposed-convolution generator.

    A dense layer expands the latent vector to an 8x8x256 feature map, three
    stride-2 transposed convolutions (256, 128, 64 filters) each double the
    resolution, and a final 3-channel transposed convolution with ``tanh``
    produces the image.

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        A ``keras.Sequential`` model named ``"generator"`` whose outputs lie
        in [-1, 1] because of the ``tanh`` activation.
    """
    stack = [
        layers.Input(shape=(latent_dim,)),
        layers.Dense(8 * 8 * 256),
        layers.Reshape((8, 8, 256)),
    ]
    for filters in (256, 128, 64):
        # Upsample by 2 at every stage: 8 -> 16 -> 32 -> 64.
        stack += [
            layers.Conv2DTranspose(filters, 4, strides=2, padding="same"),
            layers.LeakyReLU(0.2),
        ]
    # Stride-1 output layer maps the features to 3 colour channels.
    stack.append(layers.Conv2DTranspose(3, 3, activation="tanh", padding="same"))
    return keras.Sequential(stack, name="generator")


# Instantiate both networks with their default arguments; the training cell
# passes these two objects to GAN(...).
discriminator = build_discriminator()
generator = build_generator()

# Print the layer-by-layer architecture of each network.
discriminator.summary()
generator.summary()


In [None]:
class GAN(keras.Model):
    """Keras model that trains a discriminator/generator pair adversarially.

    ``fit`` calls ``train_step`` once per batch of real images. Each step
    first updates the discriminator on a mix of generated and real images,
    then updates the generator so that its samples are scored as real.

    Args:
        discriminator: Model mapping images to one logit per image.
        generator: Model mapping latent vectors to images.
        latent_dim: Length of the latent vectors fed to the generator.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # Running means of the two losses, reported in the fit() logs.
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        """Return the loss trackers so Keras can reset them between epochs."""
        return [self.d_loss_metric, self.g_loss_metric]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store one optimizer per sub-model and the shared loss function.

        Args:
            d_optimizer: Optimizer applied to the discriminator weights.
            g_optimizer: Optimizer applied to the generator weights.
            loss_fn: Callable ``loss_fn(labels, predictions)`` used for both
                updates.
        """
        super().compile(run_eagerly=False)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: One batch of real training images.

        Returns:
            Dict with the running mean ``"d_loss"`` and ``"g_loss"``.
        """
        batch_size = tf.shape(real_images)[0]

        # -------------------
        # Train discriminator
        # -------------------
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Generated outside the tape: this update only touches discriminator
        # weights, so no gradient through the generator is needed here.
        generated_images = self.generator(random_latent_vectors, training=True)
        combined_images = tf.concat([generated_images, real_images], axis=0)
        # Label convention: 0 = generated, 1 = real.
        labels = tf.concat(
            [tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], axis=0
        )
        # Add a little uniform noise to the labels (not label smoothing).
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        with tf.GradientTape() as tape:
            # training=True must be explicit: sub-models called from a custom
            # train_step otherwise run in inference mode, which would leave
            # the discriminator's Dropout layers inactive.
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)

        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        # ----------------
        # Train generator
        # ----------------
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # The generator is rewarded when its samples are classified as real.
        misleading_labels = tf.ones((batch_size, 1))

        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            predictions = self.discriminator(fake_images, training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)

        # Only generator weights are updated; the discriminator is left as is.
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {"d_loss": self.d_loss_metric.result(), "g_loss": self.g_loss_metric.result()}


In [None]:
class GANMonitor(keras.callbacks.Callback):
    """Callback that saves a grid of generator samples after every epoch.

    The latent vectors are drawn once at construction and reused, so the
    grids from successive epochs show the same latent points.

    Args:
        output_dir: Directory for the PNG files; created if it is missing.
        num_images: Number of samples to draw per grid.
        latent_dim: Length of each latent vector fed to the generator.
    """

    def __init__(self, output_dir, num_images=9, latent_dim=LATENT_DIM):
        super().__init__()
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.num_images = num_images
        self.latent_dim = latent_dim
        self.seed = tf.random.normal(shape=(num_images, latent_dim))

    def on_epoch_end(self, epoch, logs=None):
        """Generate images from the fixed latent vectors and save them as a PNG grid.

        Args:
            epoch: Zero-based epoch index supplied by Keras; the file name
                uses ``epoch + 1``.
            logs: Unused.
        """
        # Sampling for inspection only, so request inference mode explicitly.
        generated_images = self.model.generator(self.seed, training=False)
        # Map tanh outputs in [-1, 1] back to 0..255 for display.
        generated_images = (generated_images * 127.5 + 127.5).numpy().astype("uint8")

        # Near-square grid large enough for num_images (3x3 for the default 9).
        cols = max(1, int(np.ceil(np.sqrt(self.num_images))))
        rows = max(1, -(-self.num_images // cols))  # ceiling division
        # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
        fig, axes = plt.subplots(
            rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False
        )
        for index, ax in enumerate(axes.flatten()):
            if index < len(generated_images):
                ax.imshow(generated_images[index])
            # Hide ticks and frame on every cell, including unused ones.
            ax.axis("off")
        fig.tight_layout()

        out_path = self.output_dir / f"epoch_{epoch + 1:03d}.png"
        fig.savefig(out_path)
        plt.close(fig)  # release the figure so memory does not grow per epoch
        print(f"Saved sample grid to {out_path}")


In [None]:
# Train
EPOCHS = 1  # raise for better sample quality
STEPS_PER_EPOCH = 300  # None means: iterate over the whole dataset every epoch

# Optionally cap the number of batches drawn per epoch.
train_data = (
    train_dataset if STEPS_PER_EPOCH is None else train_dataset.take(STEPS_PER_EPOCH)
)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=LATENT_DIM)

# Both networks use the same Adam settings but get separate optimizer instances.
adam_settings = {"learning_rate": 2e-4, "beta_1": 0.5}
gan.compile(
    d_optimizer=keras.optimizers.Adam(**adam_settings),
    g_optimizer=keras.optimizers.Adam(**adam_settings),
    # from_logits=True because the discriminator's last layer has no sigmoid.
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

# Writes a sample grid into OUTPUT_DIR at the end of each epoch.
sample_monitor = GANMonitor(OUTPUT_DIR, latent_dim=LATENT_DIM)
history = gan.fit(train_data, epochs=EPOCHS, callbacks=[sample_monitor])



Sample grids are saved to `dcgan_samples/`, one PNG per epoch. To train longer, raise `EPOCHS` and, optionally, set `STEPS_PER_EPOCH = None` so that each epoch iterates over the full dataset.
