# Module 11: Building and Training a DCGAN with PyTorch

**Learning Objective**: Understand and implement Deep Convolutional GANs (DCGANs) using convolutional layers instead of MLPs, and train on CIFAR-10 for higher-resolution image generation.

## DCGAN Overview
- **DCGAN** = Deep Convolutional Generative Adversarial Network
- Replaces MLP layers with **ConvTranspose2d** (Generator) and **Conv2d** (Discriminator)
- Adheres to architectural guidelines: Batch Normalization, LeakyReLU, strided convolutions
- This implementation targets **32×32 RGB images** (CIFAR-10)
- Produces **higher-quality, more structured images** compared to MLP GANs

Key Improvements:
+ Convolutional layers preserve spatial structure
+ Batch Normalization for stable training
+ Strided convolutions (Discriminator) and transposed convolutions (Generator) for learned downsampling/upsampling
+ Better gradient flow and faster convergence

## Part 1: Setup and Imports

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# Import DCGAN models from local files
from models.dcgan import (
    DCGANGenerator,
    DCGANDiscriminator,
    create_dcgan_models,
    print_model_summary,
)
from dcgan_training import DCGANTrainer, initialize_weights

# Fix the NumPy and PyTorch RNG seeds (same value for both) for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Pick the compute backend: MPS is checked first, then CUDA, with CPU as the fallback
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available() else "cpu"
)

# Environment report
print(f"✓ All imports successful!")
print(f"✓ Using device: {device}")
print(f"✓ PyTorch version: {torch.__version__}")


## Part 2: Load and Preprocess CIFAR-10 Dataset

In [None]:
# Preprocessing pipeline: convert to tensors, then shift/scale each channel
# with mean 0.5 and std 0.5 so pixel values land in [-1, 1] (for Tanh output)
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
transform = transforms.Compose([to_tensor, normalize])

print("Downloading and loading CIFAR-10 dataset...")
train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)

batch_size = 64

# MPS doesn't work well with multiprocessing, so loader workers are disabled there
if device.type != "mps":
    loader_workers = 2
else:
    loader_workers = 0

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=loader_workers,
)

print(f" CIFAR-10 dataset loaded")
print(f"  Total training images: {len(train_dataset)}")
print(f"  Batch size: {batch_size}")
print(f"  Batches per epoch: {len(train_loader)}")
print(f"  Image shape: (3, 32, 32) - RGB, 32×32 pixels")

# Show the first 16 images of one shuffled batch in a 2×8 grid
sample_batch, sample_labels = next(iter(train_loader))
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for idx, ax in enumerate(axes.flat):
    # Undo the normalization for display: [-1, 1] → [0, 1]
    img = (sample_batch[idx] + 1) / 2
    # imshow expects channels last, so move (C, H, W) to (H, W, C)
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.axis("off")

plt.suptitle("CIFAR-10 Sample Images (Training Set)", fontsize=12, fontweight="bold")
plt.tight_layout()
plt.show()

print(
    f"\nClass mapping: 0=airplane, 1=automobile, 2=bird, 3=cat, 4=deer, 5=dog, 6=frog, 7=horse, 8=ship, 9=truck"
)


## Part 3: Understanding DCGAN Architecture

In [None]:
# Visualize DCGAN architecture
def _draw_layer_stack(ax, title, layers, colors):
    """Draw a top-to-bottom stack of labelled boxes joined by arrows.

    Args:
        ax: Matplotlib axes to draw on; its axis decorations are hidden.
        title: Title shown above the stack.
        layers: Box labels, listed from top to bottom.
        colors: Box face colors, one per entry of ``layers``.

    Raises:
        ValueError: If ``layers`` and ``colors`` differ in length.
    """
    if len(layers) != len(colors):
        raise ValueError("layers and colors must have the same length")

    ax.set_title(title, fontsize=13, fontweight="bold")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis("off")

    # Evenly spaced vertical positions (axes fraction), first layer at the top
    y_positions = np.linspace(0.95, 0.05, len(layers))
    for i, (layer, y, color) in enumerate(zip(layers, y_positions, colors)):
        ax.text(
            0.5,
            y,
            layer,
            ha="center",
            va="center",
            fontsize=10,
            bbox=dict(boxstyle="round", facecolor=color, alpha=0.8, pad=0.8),
            transform=ax.transAxes,
        )
        if i < len(layers) - 1:
            # annotate() chooses its coordinate systems through xycoords /
            # textcoords; a `transform` entry (in arrowprops or as a keyword)
            # is not the supported way to place the arrow, so both endpoints
            # are declared explicitly as axes fractions instead.
            ax.annotate(
                "",
                xy=(0.5, y_positions[i + 1] + 0.02),  # arrow head: next box
                xytext=(0.5, y - 0.02),  # arrow tail: current box
                xycoords="axes fraction",
                textcoords="axes fraction",
                arrowprops=dict(arrowstyle="->", lw=2),
            )


fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Generator architecture
gen_layers = [
    "Input: z (100)",
    "FC: 100 → 256×4×4",
    "ConvTranspose: 256→128 (4→8)",
    "ConvTranspose: 128→64 (8→16)",
    "ConvTranspose: 64→32 (16→32)",
    "ConvTranspose: 32→3 (32→32)",
    "Output: (3, 32, 32)",
]
colors_gen = [
    "lightblue",
    "lightgreen",
    "lightyellow",
    "lightyellow",
    "lightyellow",
    "lightyellow",
    "lightcoral",
]
_draw_layer_stack(
    axes[0], "DCGAN Generator\n(Noise → Image)", gen_layers, colors_gen
)

# Discriminator architecture
disc_layers = [
    "Input: (3, 32, 32)",
    "Conv: 3→32 (32→16)",
    "Conv: 32→64 (16→8)",
    "Conv: 64→128 (8→4)",
    "Conv: 128→256 (4→2)",
    "Flatten: 256×2×2",
    "FC: 1024 → 1",
    "Output: probability",
]
colors_disc = [
    "lightcoral",
    "lightyellow",
    "lightyellow",
    "lightyellow",
    "lightyellow",
    "lightgreen",
    "lightblue",
    "lightgreen",
]
_draw_layer_stack(
    axes[1],
    "DCGAN Discriminator\n(Image → Classification)",
    disc_layers,
    colors_disc,
)

plt.tight_layout()
plt.show()

print("\n Key DCGAN Principles:")
print("━" * 80)
print("1. ConvTranspose2d (Generator): Fractional-strided convolutions for upsampling")
print("   - Transforms (100,) latent vector → (256, 4, 4) → (3, 32, 32)")
print("   - Stride=2 doubles spatial dimensions at each layer")
print("")
print("2. Conv2d (Discriminator): Strided convolutions for downsampling")
print("   - Transforms (3, 32, 32) → (256, 2, 2) → (1,) probability")
print("   - Stride=2 halves spatial dimensions at each layer")
print("")
print("3. Batch Normalization: Stabilizes training")
print("   - Applied to all layers except output")
print("   - Helps with gradient flow and convergence")
print("")
print("4. Activation Functions:")
print("   - Generator: ReLU (hidden), Tanh (output)")
print("   - Discriminator: LeakyReLU(0.2) for stable gradients")
print("━" * 80)


## Part 4: Create and Initialize DCGAN Models

In [None]:
# Create DCGAN models
latent_dim = 100
num_channels = 3

generator, discriminator = create_dcgan_models(
    latent_dim=latent_dim,
    num_channels=num_channels,
    device=device,
)

# Initialize weights (important for DCGAN!)
initialize_weights(generator)
initialize_weights(discriminator)

# Print model summary
print_model_summary(generator, discriminator, latent_dim)

# Smoke-test both networks with a tiny batch of random latent vectors
print("\n Testing Forward Passes:")
print("━" * 80)
# Use a dedicated name here: rebinding `batch_size` would overwrite the
# DataLoader batch size (64) that the training-configuration cell reports.
test_batch_size = 4
z = torch.randn(test_batch_size, latent_dim, device=device)
# No gradients are needed for a shape check, so skip building the autograd graph
with torch.no_grad():
    fake_images = generator(z)
print(f"Generator Input (noise):     {z.shape} → {fake_images.shape}")

with torch.no_grad():
    D_output = discriminator(fake_images)
print(f"Discriminator Input (images): {fake_images.shape} → {D_output.shape}")
print("━" * 80)

# Actually check the shapes before announcing that the checks passed
expected_image_shape = (test_batch_size, num_channels, 32, 32)
assert fake_images.shape == expected_image_shape, (
    f"Generator produced {tuple(fake_images.shape)}, expected {expected_image_shape}"
)
assert D_output.shape[0] == test_batch_size, (
    f"Discriminator returned {D_output.shape[0]} outputs for {test_batch_size} images"
)
print(" All shape assertions passed!\n")


## Part 5: Training Setup and Configuration

In [None]:
# Training configuration
num_epochs = 20
lr_g = 0.0002  # generator learning rate
lr_d = 0.0002  # discriminator learning rate
beta1 = 0.5  # reported below as "Adam Beta1"
beta2 = 0.999  # reported below as "Adam Beta2"

# Create trainer
trainer = DCGANTrainer(
    generator=generator,
    discriminator=discriminator,
    device=device,
    lr_g=lr_g,
    lr_d=lr_d,
    beta1=beta1,
    beta2=beta2,
)

print("DCGAN Training Configuration:")
print("━" * 80)
print(f"Number of Epochs:        {num_epochs}")
print(f"Generator LR:            {lr_g}")
print(f"Discriminator LR:        {lr_d}")
print(f"Adam Beta1:              {beta1}")
print(f"Adam Beta2:              {beta2}")
# Read the batch size from the loader that training actually uses; the
# notebook-level `batch_size` name can be rebound by other cells (the
# forward-pass test sets it to 4), which would make this report wrong.
print(f"Batch Size:              {train_loader.batch_size}")
print(f"Latent Dimension:        {latent_dim}")
print(f"Dataset Size:            {len(train_dataset)}")
print(f"Total Batches:           {len(train_loader)}")
print(f"Device:                  {device}")
print("━" * 80)
print("\n Note: Training will take 15-30 minutes depending on your hardware.")


## Part 6: Train the DCGAN Model

In [None]:
# Run DCGAN training (slow). The returned `results` is indexed later with
# the keys "d_losses" and "g_losses".
train_kwargs = dict(
    train_loader=train_loader,
    num_epochs=num_epochs,
    latent_dim=latent_dim,
    log_interval=50,  # presumably the number of batches between log lines — confirm in DCGANTrainer
)
results = trainer.train(**train_kwargs)

print("\n Training complete!")


## Part 7: Visualize Training Loss Curves

In [None]:
# Plot the discriminator and generator loss histories side by side
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# (axes, results key, line color, legend label, panel title) for each panel
loss_panels = [
    (axes[0], "d_losses", "navy", "D Loss", "Discriminator Loss Over Training"),
    (axes[1], "g_losses", "darkgreen", "G Loss", "Generator Loss Over Training"),
]
for ax, key, line_color, legend_label, panel_title in loss_panels:
    ax.plot(results[key], linewidth=1.5, color=line_color, label=legend_label)
    if key == "d_losses":
        # Reference line on the discriminator panel only.
        # NOTE(review): 0.5 looks like the ideal discriminator *output*, not
        # an ideal loss value — confirm against how DCGANTrainer computes d_loss.
        ax.axhline(y=0.5, color="gray", linestyle="--", alpha=0.5, label="Ideal (0.5)")
    ax.set_xlabel("Training Step", fontsize=11)
    ax.set_ylabel("Loss", fontsize=11)
    ax.set_title(panel_title, fontsize=12, fontweight="bold")
    ax.grid(alpha=0.3)
    ax.legend(fontsize=10)

plt.tight_layout()
plt.show()

# Summary statistics over the recorded loss values
d_losses = np.array(results["d_losses"])
g_losses = np.array(results["g_losses"])

print("\nTraining Statistics:")
print("━" * 80)
stat_sections = (("Discriminator Loss:", d_losses), ("Generator Loss:", g_losses))
for position, (header, series) in enumerate(stat_sections):
    if position:
        # Blank line between the two sections
        print()
    print(header)
    print(f"  Initial: {series[0]:.4f}")
    print(f"  Final:   {series[-1]:.4f}")
    print(f"  Mean:    {series.mean():.4f}")
    print(f"  Std:     {series.std():.4f}")
print("━" * 80)


## Part 8: Generate and Visualize Synthetic Images

In [None]:
# Generate synthetic images
num_samples = 32
# Derive the grid from num_samples so changing it cannot index past the
# generated batch or leave generated images undisplayed.
grid_cols = 8
grid_rows = (num_samples + grid_cols - 1) // grid_cols  # ceiling division

# Inference only: switch the generator to eval mode and disable gradient tracking
generator.eval()
with torch.no_grad():
    z = torch.randn(num_samples, latent_dim, device=device)
    synthetic_images = generator(z)

# Denormalize synthetic images to [0, 1]: [-1, 1] → [0, 1], then clamp
synthetic_images_cpu = torch.clamp((synthetic_images.cpu() + 1) / 2, 0, 1)

# Visualize generated images; squeeze=False keeps `axes` 2-D for any grid size
fig, axes = plt.subplots(
    grid_rows, grid_cols, figsize=(16, 2 * grid_rows), squeeze=False
)
for i, ax in enumerate(axes.flat):
    if i < num_samples:
        # imshow expects channels last, so move (C, H, W) to (H, W, C)
        ax.imshow(synthetic_images_cpu[i].permute(1, 2, 0).numpy())
    # Cells past num_samples (when it is not a multiple of grid_cols) stay blank
    ax.axis("off")

plt.suptitle(
    "DCGAN Generated Synthetic Images from CIFAR-10", fontsize=14, fontweight="bold"
)
plt.tight_layout()
plt.show()

print("\n Generated Images Summary:")
print("━" * 80)
print(f"Number of Samples Generated: {num_samples}")
print(f"Output Resolution: 32×32 pixels (RGB)")
print(f"Image Range: [0, 1] (normalized)")
print("━" * 80)


## Part 9: Compare Real vs Generated Images

In [None]:
# Compare real and generated images: rows 0-1 are real, rows 2-3 are generated.
# The grid is 4×8 because exactly 16 real and 16 generated images are shown;
# a wider grid would leave the extra columns as empty framed axes.
images_per_group = 16
grid_cols = 8
fig, axes = plt.subplots(4, grid_cols, figsize=(16, 9))

# Real images: one shuffled training batch, denormalized [-1, 1] → [0, 1].
# The batch must hold at least `images_per_group` images.
real_batch, _ = next(iter(train_loader))
real_batch = torch.clamp((real_batch + 1) / 2, 0, 1)

# Generated images: set eval mode here as well so this cell does not depend
# on an earlier cell having called generator.eval()
generator.eval()
with torch.no_grad():
    z = torch.randn(images_per_group, latent_dim, device=device)
    fake_batch = generator(z)

fake_batch_cpu = torch.clamp((fake_batch.cpu() + 1) / 2, 0, 1)

# (images, title, extra title kwargs) for each two-row group
groups = [
    (real_batch, "Real", {}),
    (fake_batch_cpu, "Generated", {"color": "red"}),
]
for group_index, (images, label, title_kwargs) in enumerate(groups):
    for i in range(images_per_group):
        ax = axes[2 * group_index + i // grid_cols, i % grid_cols]
        # imshow expects channels last, so move (C, H, W) to (H, W, C)
        ax.imshow(images[i].permute(1, 2, 0).numpy())
        ax.set_title(label, fontsize=9, **title_kwargs)
        ax.axis("off")

fig.suptitle(
    "Real CIFAR-10 Images vs DCGAN Generated Images", fontsize=13, fontweight="bold"
)
plt.tight_layout()
plt.show()

print("\n Side-by-side comparison complete!")
