# Module 13: Implementing Conditional GANs (cGANs)

**Learning Objective**: Understand and implement Conditional GANs that generate class-specific images on demand, with proper label conditioning in both Generator and Discriminator.

## What is a cGAN?

A **Conditional GAN** differs from a standard GAN in four ways:
1. **Generator receives a class label** - image generation is conditioned on the target class
2. **Discriminator receives a class label** - it judges whether an image is real *and* consistent with that class
3. **Label concatenation** - the generator concatenates noise with the embedded label at its input; the discriminator concatenates flattened image features with the embedded label before its final layers
4. **Class-specific synthesis** - any class can be generated on demand
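
The label-concatenation step in item 3 can be sketched as follows. This is a minimal illustration, not the code in `models/cgan.py`: it assumes labels arrive as integer class indices and are embedded with `nn.Embedding`, and it borrows the sizes used later in this module (100-dimensional noise, 50-dimensional label embedding).

```python
import torch
import torch.nn as nn

latent_dim, num_classes, label_dim = 100, 10, 50  # sizes used later in this module

# Assumption: integer class indices as input. nn.Embedding maps each index to a
# learned vector, which is the same as multiplying a one-hot vector by a matrix.
label_embedding = nn.Embedding(num_classes, label_dim)

z = torch.randn(4, latent_dim)        # a batch of 4 noise vectors
labels = torch.tensor([3, 3, 5, 9])   # cat, cat, dog, truck

embedded = label_embedding(labels)          # shape (4, 50)
g_input = torch.cat([z, embedded], dim=1)   # shape (4, 150), the generator's first-layer input

# Equivalence check: embedding lookup == one-hot vector times the embedding matrix
one_hot = nn.functional.one_hot(labels, num_classes).float()
assert torch.allclose(one_hot @ label_embedding.weight, embedded)
print(g_input.shape)
```

The first two rows share a label but not a noise vector, which is what lets the generator produce different images of the same class.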

Key Benefits:
- Generate specific classes of images (e.g., "generate a cat")
- Better control over generated content
- Useful for data augmentation in classification tasks
- Shows how class identity (the label) can be separated from style variation (the noise)

## Part 1: Setup and Imports

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# Import cGAN models from local files
from models.cgan import (
    ConditionalGenerator,
    ConditionalDiscriminator,
    create_cgan_models,
    print_model_summary,
)
from cgan_training import ConditionalGANTrainer, initialize_weights

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Determine device
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("All imports successful!")
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

# CIFAR-10 class names
CIFAR10_CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]


## Part 2: Load CIFAR-10 Dataset

In [None]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

print("Loading CIFAR-10 dataset...")
train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)

batch_size = 64
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,  # MPS compatibility
)

print(f" CIFAR-10 loaded: {len(train_dataset)} images, {len(train_loader)} batches")
print(f"  Classes: {CIFAR10_CLASSES}")

# Visualize sample images per class
fig, axes = plt.subplots(2, 5, figsize=(16, 6))
for class_idx in range(10):
    ax = axes[class_idx // 5, class_idx % 5]
    # Index of the first training image carrying this label
    idx = np.where(np.array(train_dataset.targets) == class_idx)[0][0]
    img = (train_dataset[idx][0] + 1) / 2  # undo Normalize: [-1, 1] -> [0, 1]
    ax.imshow(img.permute(1, 2, 0).numpy())  # CHW tensor -> HWC array for matplotlib
    ax.set_title(f"{CIFAR10_CLASSES[class_idx]} (class {class_idx})", fontweight="bold")
    ax.axis("off")

plt.suptitle("CIFAR-10 Sample Images (One per Class)", fontsize=13, fontweight="bold")
plt.tight_layout()
plt.show()
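
The `Normalize(mean=0.5, std=0.5)` transform and the `(x + 1) / 2` step used before plotting are inverses of each other. The arithmetic is sketched below with plain Python floats; `normalize` and `denormalize` are illustrative helper names, not torchvision functions. The resulting [-1, 1] range would suit a generator that ends in `tanh`, which is common but is an assumption about `models/cgan.py`.

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-channel computation done by transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std


def denormalize(y):
    """Inverse mapping applied before plotting: [-1, 1] back to [0, 1]."""
    return (y + 1) / 2


for pixel in (0.0, 0.25, 1.0):
    y = normalize(pixel)
    print(f"{pixel:.2f} -> {y:+.2f} -> {denormalize(y):.2f}")
```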


## Part 3: cGAN Architecture Explained

In [None]:
print(" CONDITIONING STRATEGIES IN cGAN")
print("=" * 80)
print("\n1. GENERATOR CONDITIONING:")
print("   Input: noise z (100,) + class label (one-hot)")
print("   Step 1: Embed class label: one-hot (10,) → continuous (50,)")
print("   Step 2: Concatenate: [z (100,) + embedded_label (50,)] → (150,)")
print("   Step 3: FC layers expand to 256×4×4 spatial features")
print("   Step 4: ConvTranspose layers upsample to 32×32 image")
print("   Result: Image that belongs to the target class\n")

print("2. DISCRIMINATOR CONDITIONING:")
print("   Input: image (3, 32, 32) + class label (one-hot)")
print("   Step 1: Conv layers extract image features → (256, 2, 2)")
print("   Step 2: Embed class label: one-hot (10,) → continuous (50,)")
print("   Step 3: Flatten image features: (256×2×2=1024,)")
print("   Step 4: Concatenate: [image_features (1024,) + embedded_label (50,)]")
print("   Step 5: FC layers classify as real/fake")
print("   Result: Ensures image matches the claimed class\n")

print("=" * 80)
print("\n KEY INSIGHT:")
print("Both G and D know the target class:")
print("  • G tries to generate images of that class")
print("  • D checks if image belongs to that class")
print("  • Together they learn to disentangle class from style")
print("=" * 80)
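
The spatial sizes quoted above (4×4 up to 32×32 in the generator, 32×32 down to 2×2 in the discriminator) follow from the standard convolution size formulas. The sketch below assumes DCGAN-style layers with `kernel=4, stride=2, padding=1`, three upsampling layers, and four downsampling layers. These hyperparameters are assumptions chosen to reproduce the stated shapes, not values read from `models/cgan.py`.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Output size of nn.Conv2d along one spatial dimension (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Output size of nn.ConvTranspose2d (output_padding=0, dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel


# Generator: a 4x4 feature map upsampled three times
size, g_sizes = 4, [4]
for _ in range(3):
    size = conv_transpose_out(size)
    g_sizes.append(size)
print("G:", g_sizes)  # [4, 8, 16, 32]

# Discriminator: a 32x32 image downsampled four times
size, d_sizes = 32, [32]
for _ in range(4):
    size = conv_out(size)
    d_sizes.append(size)
print("D:", d_sizes)  # [32, 16, 8, 4, 2]

# Flattened feature length, then the length after appending the 50-dim label embedding
flat = 256 * d_sizes[-1] ** 2
print(flat, flat + 50)  # 1024 1074
```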


## Part 4: Create and Initialize cGAN Models

In [None]:
latent_dim = 100
num_classes = 10
label_dim = 50
num_channels = 3

# Create cGAN models
generator, discriminator = create_cgan_models(
    latent_dim=latent_dim,
    num_classes=num_classes,
    label_dim=label_dim,
    num_channels=num_channels,
    device=device,
)

# Initialize weights
initialize_weights(generator)
initialize_weights(discriminator)

# Print model summary
print_model_summary(generator, discriminator, latent_dim, num_classes)

# Test forward passes (no gradients needed for a shape check)
print("\n✓ Testing Forward Passes:")
print("=" * 80)
test_batch_size = 4  # separate name so the DataLoader's batch_size is not overwritten
z = torch.randn(test_batch_size, latent_dim, device=device)
labels = torch.randint(0, num_classes, (test_batch_size,), device=device)

with torch.no_grad():
    fake_images = generator(z, labels)
    D_output = discriminator(fake_images, labels)

print(f"Generator Input (noise):     {z.shape}")
print(f"Generator Input (labels):    {labels.shape}")
print(f"Generator Output (images):   {fake_images.shape}")
print(f"Discriminator Input (images): {fake_images.shape}")
print(f"Discriminator Input (labels): {labels.shape}")
print(f"Discriminator Output (prob):  {D_output.shape}")

# Check the shapes explicitly so the success message below is earned
assert fake_images.shape == (test_batch_size, num_channels, 32, 32)
assert D_output.shape[0] == test_batch_size
print("=" * 80)
print("✓ All assertions passed!\n")
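
`initialize_weights` is imported from `cgan_training.py`, so its body is not visible in this notebook. A common choice for GANs is the DCGAN-style scheme sketched below: weights drawn from N(0, 0.02) and BatchNorm scales from N(1, 0.02). The function name `dcgan_init` is hypothetical, and whether the local helper does the same thing is an assumption.

```python
import torch.nn as nn


def dcgan_init(module: nn.Module) -> None:
    """Hypothetical stand-in for initialize_weights: DCGAN-style normal init."""
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.normal_(module.weight, mean=1.0, std=0.02)
        nn.init.zeros_(module.bias)


net = nn.Sequential(nn.Conv2d(3, 8, 4, 2, 1), nn.BatchNorm2d(8), nn.LeakyReLU(0.2))
net.apply(dcgan_init)  # Module.apply visits every submodule recursively
```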


## Part 5: Train the cGAN Model

⏱️ **Note:** Training will take 20-30 minutes depending on your hardware.

In [None]:
num_epochs = 20

# Create trainer
trainer = ConditionalGANTrainer(
    generator=generator,
    discriminator=discriminator,
    device=device,
    lr_g=0.0002,
    lr_d=0.0002,
    beta1=0.5,
    beta2=0.999,
)

# Train
results = trainer.train(
    train_loader=train_loader,
    num_epochs=num_epochs,
    latent_dim=latent_dim,
    num_classes=num_classes,
    log_interval=50,
)
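
`ConditionalGANTrainer.train` lives in `cgan_training.py` and is not shown here. As a rough sketch of what one conditional training step typically involves, the example below pairs every image with a label in both the discriminator and generator updates. `TinyG`, `TinyD`, and `train_step` are hypothetical stand-ins (small linear models so the block runs quickly), and the BCE loss is an assumption. Only the Adam settings mirror the values passed to the trainer above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, num_classes, img_dim = 100, 10, 3 * 32 * 32


class TinyG(nn.Module):  # hypothetical stand-in for ConditionalGenerator
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(num_classes, 50)
        self.net = nn.Sequential(nn.Linear(latent_dim + 50, img_dim), nn.Tanh())

    def forward(self, z, y):
        return self.net(torch.cat([z, self.embed(y)], dim=1)).view(-1, 3, 32, 32)


class TinyD(nn.Module):  # hypothetical stand-in for ConditionalDiscriminator
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(num_classes, 50)
        self.net = nn.Sequential(nn.Linear(img_dim + 50, 1), nn.Sigmoid())

    def forward(self, x, y):
        return self.net(torch.cat([x.flatten(1), self.embed(y)], dim=1))


G, D = TinyG(), TinyD()
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCELoss()


def train_step(real_images, real_labels):
    b = real_images.size(0)
    ones, zeros = torch.ones(b, 1), torch.zeros(b, 1)

    # Discriminator step: real (image, label) pairs -> 1, fake pairs -> 0
    z = torch.randn(b, latent_dim)
    fake_labels = torch.randint(0, num_classes, (b,))
    fake_images = G(z, fake_labels)
    d_loss = 0.5 * (
        bce(D(real_images, real_labels), ones)
        + bce(D(fake_images.detach(), fake_labels), zeros)
    )
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator step: make D accept the fake image under the label G was given
    g_loss = bce(D(fake_images, fake_labels), ones)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()


# Random tensors stand in for one CIFAR-10 batch
d_loss, g_loss = train_step(torch.rand(8, 3, 32, 32) * 2 - 1, torch.randint(0, 10, (8,)))
print(f"d_loss={d_loss:.3f}  g_loss={g_loss:.3f}")
```

The `detach()` in the discriminator step keeps the discriminator loss from sending gradients into the generator. The generator step then reuses the same fake batch without detaching it.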


## Part 6: Loss Curves and Training Analysis

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

ax = axes[0]
ax.plot(results["d_losses"], linewidth=1, color="navy", label="D Loss", alpha=0.7)
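# Assumption: D loss is a mean binary cross-entropy over the real and fake batches.
# At equilibrium D outputs 0.5, which gives a loss of ln 2 ≈ 0.693 (not 0.5).
# If the trainer sums the two terms instead, the reference value is 2·ln 2.
ax.axhline(y=np.log(2), color="gray", linestyle="--", alpha=0.5, label="Equilibrium (ln 2 ≈ 0.69)")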
ax.set_xlabel("Training Step", fontsize=11)
ax.set_ylabel("Loss", fontsize=11)
ax.set_title("Discriminator Loss Over Training", fontsize=12, fontweight="bold")
ax.grid(alpha=0.3)
ax.legend(fontsize=10)

ax = axes[1]
ax.plot(results["g_losses"], linewidth=1, color="darkgreen", label="G Loss", alpha=0.7)
ax.set_xlabel("Training Step", fontsize=11)
ax.set_ylabel("Loss", fontsize=11)
ax.set_title("Generator Loss Over Training", fontsize=12, fontweight="bold")
ax.grid(alpha=0.3)
ax.legend(fontsize=10)

plt.tight_layout()
plt.show()

d_losses = np.array(results["d_losses"])
g_losses = np.array(results["g_losses"])
print("\nTraining Statistics:")
print(f"  D Loss - Mean: {d_losses.mean():.4f}, Std: {d_losses.std():.4f}")
print(f"  G Loss - Mean: {g_losses.mean():.4f}, Std: {g_losses.std():.4f}")
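
When reading these curves it helps to separate the discriminator's equilibrium *output* (a probability of 0.5 for every input) from the *loss value* that output produces. The short calculation below assumes binary cross-entropy, which the "prob" output of the discriminator suggests but this notebook does not confirm. The `bce` helper is illustrative.

```python
import math


def bce(p, target):
    """Binary cross-entropy for a single prediction p in (0, 1)."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))


# At equilibrium D outputs 0.5 for every real and every fake pair
d_real = bce(0.5, 1)  # loss on a real sample
d_fake = bce(0.5, 0)  # loss on a fake sample
print(f"averaged D loss: {(d_real + d_fake) / 2:.4f}")  # ln 2   -> 0.6931
print(f"summed D loss:   {d_real + d_fake:.4f}")        # 2 ln 2 -> 1.3863
print(f"G loss:          {bce(0.5, 1):.4f}")            # ln 2   -> 0.6931
```

A discriminator loss far below these values suggests D is winning; in practice the curves oscillate around the reference rather than settling on it.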


## Part 7: Generate Class-Specific Images (10×10 Grid)

In [None]:
# Generate 10x10 class grid (same noise, different classes)
print("Generating 10×10 class grid (same noise z, different classes y)...")
print("This shows the cGAN's ability to disentangle class from other variations.\n")

grid_images = trainer.generate_all_classes_grid(
    num_classes=10,
    samples_per_class=10,
    latent_dim=latent_dim,
    shared_z=None,  # None: let the trainer create the fixed noise vector
)

# Denormalize
grid_images_cpu = (grid_images.cpu() + 1) / 2
grid_images_cpu = torch.clamp(grid_images_cpu, 0, 1)

# Visualize as grid
fig, axes = plt.subplots(10, 10, figsize=(16, 16))
for class_id in range(10):
    for sample_id in range(10):
        idx = class_id * 10 + sample_id
        ax = axes[class_id, sample_id]
        img = grid_images_cpu[idx].permute(1, 2, 0).numpy()
        ax.imshow(img)
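        # Hide ticks and frame but keep the axes on: ax.axis("off") would also
        # hide the y-label, so the class names in the first column would not render.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        if sample_id == 0:
            ax.set_ylabel(CIFAR10_CLASSES[class_id], fontsize=10, fontweight="bold")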

plt.suptitle(
    "cGAN Class Disentanglement: 10×10 Grid\n(Rows: classes 0-9 | Columns: noise vectors shared across all classes)",
    fontsize=14,
    fontweight="bold",
    y=0.995,
)
plt.tight_layout()
plt.show()

print("✓ 10×10 class grid generated successfully!")
print(f"  Grid shape: {grid_images.shape}")
print("  Each row shows one class")
print("  Each column reuses one noise vector across all class labels")
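
How `generate_all_classes_grid` builds its inputs is not shown in this notebook. One way to get the "same noise, different classes" layout, consistent with the `idx = class_id * 10 + sample_id` indexing above, is sketched below. The variable names are illustrative, and the real method may differ.

```python
import torch

num_classes, samples_per_class, latent_dim = 10, 10, 100

# One noise vector per column, reused for every class row
shared_z = torch.randn(samples_per_class, latent_dim)

# Row-major layout: entry class_id * samples_per_class + sample_id
z = shared_z.repeat(num_classes, 1)                                       # (100, 100)
labels = torch.arange(num_classes).repeat_interleave(samples_per_class)   # (100,)

# images = generator(z, labels) would then yield the 100 grid images

# Row 3, column 7 uses noise vector 7 with class label 3
assert torch.equal(z[3 * samples_per_class + 7], shared_z[7])
assert labels[3 * samples_per_class + 7].item() == 3
```

Reading down a column then shows how the label alone changes the image, while reading along a row shows how the noise changes it within one class.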


## Summary: Key Takeaways

### What You Learned:
1. **Conditional GANs (cGANs)** - Generate images of specific classes on demand
2. **Label Conditioning** - How to incorporate class information in G and D
3. **Class Disentanglement** - Separating class identity from style variation
4. **10×10 Grid Technique** - Visualizing class control vs noise variation
5. **Data Augmentation** - Using cGANs to generate synthetic training data
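
For item 5, a minimal sketch of turning a conditional generator into labeled synthetic data is shown below. `make_synthetic_dataset` and `dummy_generator` are hypothetical names: the dummy stands in for the trained generator so the block runs on its own, and the random "real" tensors stand in for CIFAR-10.

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset


@torch.no_grad()
def make_synthetic_dataset(generator, num_classes, per_class, latent_dim, device="cpu"):
    """Sample `per_class` images for every class and pair them with their labels."""
    labels = torch.arange(num_classes, device=device).repeat_interleave(per_class)
    z = torch.randn(labels.numel(), latent_dim, device=device)
    images = generator(z, labels).cpu()
    return TensorDataset(images, labels.cpu())


# Stand-in generator; swap in the trained cGAN generator (in eval mode) for real use
def dummy_generator(z, labels):
    return torch.tanh(torch.randn(z.size(0), 3, 32, 32))


synthetic = make_synthetic_dataset(dummy_generator, num_classes=10, per_class=5, latent_dim=100)
real = TensorDataset(torch.rand(20, 3, 32, 32) * 2 - 1, torch.randint(0, 10, (20,)))  # placeholder data
augmented = ConcatDataset([real, synthetic])
print(len(synthetic), len(augmented))  # 50 70
```

Because the labels are chosen before sampling, the synthetic set is class-balanced by construction. Whether it helps a downstream classifier depends on how good the generated images are.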

### Key Differences from Unconditional GANs:
- Generator takes class label → class-specific generation
- Discriminator takes class label → enforces class consistency
- Both use label embedding and concatenation
- Training ensures label-image correspondence

### Next Steps:
- Train longer for better quality (50+ epochs)
- Try with CIFAR-100 (100 classes)
- Use generated images to augment classification datasets
- Experiment with different label embedding dimensions
- Combine with other GAN tricks (spectral normalization, Wasserstein loss)
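
As a starting point for the last item, spectral normalization can be added by wrapping the discriminator's weight-bearing layers with `torch.nn.utils.spectral_norm`. The layer stack below is a hypothetical example, not the discriminator from `models/cgan.py`.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Hypothetical discriminator feature extractor with spectral normalization
# applied to each convolution's weight.
d_features = nn.Sequential(
    spectral_norm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)),
    nn.LeakyReLU(0.2),
    spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)),
    nn.LeakyReLU(0.2),
)

x = torch.randn(2, 3, 32, 32)
out = d_features(x)
print(out.shape)  # torch.Size([2, 128, 8, 8])
```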