# Image Generation with Conditional DCGAN

This notebook demonstrates **conditional image generation** using a DCGAN (Deep Convolutional GAN) trained on MNIST.

**What makes this fun:**
- Train a GAN that generates handwritten digits
- Control which digit (0-9) to generate
- Fast training on GPU (~10 minutes)
- Watch quality improve in real time

**Why Conditional DCGAN?** Stable architecture, controllable generation, and impressive results with minimal training time.

In [None]:
# Installation
# !pip install torch torchvision matplotlib numpy tqdm

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Pick the compute device: CUDA when PyTorch reports one, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Seed NumPy and PyTorch (plus CUDA when present) so reruns draw the same
# random numbers
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

## Part 1: Data Preparation

In [None]:
# Load MNIST dataset
# Normalize with mean 0.5 / std 0.5 rescales the tensor images to [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

print(f'Training samples: {len(train_dataset)}')
print(f'Batches per epoch: {len(train_loader)}')

# Visualize real samples: the first 64 images of one shuffled batch, 8 per row
samples = next(iter(train_loader))[0][:64]
grid = make_grid(samples, nrow=8, normalize=True, value_range=(-1, 1))
plt.figure(figsize=(12, 12))
# make_grid returns channels-first; imshow expects channels-last
plt.imshow(grid.cpu().permute(1, 2, 0), cmap='gray')
plt.title('Real MNIST Images', fontsize=16)
plt.axis('off')
plt.show()

## Part 2: Build Conditional DCGAN

The key idea: both Generator and Discriminator receive the class label as input, allowing controlled generation.

In [None]:
class Generator(nn.Module):
    """Conditional DCGAN generator: (noise, class label) -> 1 x 28 x 28 image.

    The label is looked up in a learned embedding of width ``num_classes`` and
    concatenated with the noise vector along the channel axis; the result is
    treated as a 1 x 1 feature map and upsampled by transposed convolutions.

    Args:
        latent_dim: Length of each noise vector passed to ``forward``.
        num_classes: Number of distinct labels the generator can be
            conditioned on.

    Attributes:
        latent_dim: Copy of the constructor argument, for sampling code.
        num_classes: Copy of the constructor argument, so callers can query
            the valid label range instead of hard-coding it.
    """

    def __init__(self, latent_dim=100, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Embedding for class labels
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        # Generator architecture
        self.model = nn.Sequential(
            # Input: (latent_dim + num_classes) x 1 x 1
            nn.ConvTranspose2d(latent_dim + num_classes, 256, 7, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 256 x 7 x 7
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 128 x 14 x 14
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),
            nn.Tanh()
            # 1 x 28 x 28
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: Tensor of shape (batch, latent_dim).
            labels: Integer tensor of shape (batch,) with values in
                ``[0, num_classes)``.

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [-1, 1] (Tanh).
        """
        # Embed labels and concatenate with noise; both become
        # (batch, channels, 1, 1) so they can be joined on the channel axis
        label_embedding = self.label_embedding(labels).unsqueeze(2).unsqueeze(3)
        noise = noise.unsqueeze(2).unsqueeze(3)
        gen_input = torch.cat([noise, label_embedding], dim=1)
        return self.model(gen_input)


class Discriminator(nn.Module):
    """Conditional DCGAN discriminator for 28 x 28 single-channel images.

    Each class label is embedded as a learned 28 x 28 plane that is stacked
    onto the image as a second channel before the convolutional stack.

    Args:
        num_classes: Number of distinct labels that can be embedded.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # One learned 784-value (28 * 28) vector per class
        self.label_embedding = nn.Embedding(num_embeddings=num_classes, embedding_dim=784)

        # Spatial size per stage: 28 -> 14 -> 7 -> 1
        layers = [
            # 2 x 28 x 28 (image + label plane) -> 128 x 14 x 14
            nn.Conv2d(2, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 x 14 x 14 -> 256 x 7 x 7
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 x 7 x 7 -> 1 x 1 x 1, squashed to (0, 1)
            nn.Conv2d(256, 1, kernel_size=7, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, image, labels):
        """Score each (image, label) pair.

        Args:
            image: Tensor of shape (batch, 1, 28, 28).
            labels: Integer tensor of class indices, one per image.

        Returns:
            Tensor of shape (batch, 1) holding the sigmoid output per pair.
        """
        label_plane = self.label_embedding(labels)
        label_plane = label_plane.view(-1, 1, 28, 28)
        stacked = torch.cat((image, label_plane), dim=1)
        score = self.model(stacked)
        return score.view(-1, 1)


# Initialize models
# Module.to moves parameters in place and returns the same module, so the
# names below end up bound to networks living on `device`
generator = Generator()
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

# Initialize weights
def weights_init(m):
    """Re-initialise one module in place; intended for ``Module.apply``.

    BatchNorm2d layers get weights drawn from N(1.0, 0.02) and zero bias;
    Conv2d / ConvTranspose2d layers get weights drawn from N(0.0, 0.02).
    Any other module type is left untouched.
    """
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)
        return
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)

# Module.apply visits every submodule, so each conv and batch-norm layer of
# both networks is re-initialised by weights_init
generator.apply(fn=weights_init)
discriminator.apply(fn=weights_init)

# Report the total number of parameter elements per network
print(f'Generator parameters: {sum(map(torch.numel, generator.parameters())):,}')
print(f'Discriminator parameters: {sum(map(torch.numel, discriminator.parameters())):,}')

## Part 3: Training Setup

In [None]:
# Loss and optimizers
# Binary cross-entropy loss; each network gets its own Adam optimizer with the
# same learning rate and betas
criterion = nn.BCELoss()
optimizer_g = optim.Adam(params=generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_d = optim.Adam(params=discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Fixed noise for visualization: 100 latent vectors whose labels cycle
# 0, 1, ..., 9, reused at every preview so epochs can be compared
fixed_noise = torch.randn((100, generator.latent_dim)).to(device)
fixed_labels = (torch.arange(100) % 10).to(device)

print('Training setup complete!')

In [None]:
# Training function
def train_epoch(generator, discriminator, loader, optimizer_g, optimizer_d, criterion, device):
    """Run one epoch of alternating discriminator / generator updates.

    Args:
        generator: Module called as ``generator(noise, labels)``; must expose
            ``latent_dim``. If it has a ``label_embedding`` layer, the number
            of classes is read from it.
        discriminator: Module called as ``discriminator(images, labels)``; its
            output is compared against ``(batch_size, 1)`` targets.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer_g: Optimizer stepping the generator's parameters.
        optimizer_d: Optimizer stepping the discriminator's parameters.
        criterion: Loss called as ``criterion(prediction, target)``.
        device: Device on which every batch and target tensor is placed.

    Returns:
        Tuple ``(mean_d_loss, mean_g_loss)`` averaged over the epoch's batches.
    """
    generator.train()
    discriminator.train()

    # Sample fake labels from the generator's own label range rather than a
    # hard-coded 10; fall back to 10 for generators without a label_embedding
    label_embedding = getattr(generator, 'label_embedding', None)
    num_classes = label_embedding.num_embeddings if label_embedding is not None else 10

    d_losses = []
    g_losses = []

    for real_images, real_labels in tqdm(loader, desc='Training'):
        batch_size = real_images.size(0)
        # non_blocking is harmless for ordinary tensors and presumably lets the
        # copy overlap with compute when the loader pins memory — see the
        # torch.Tensor.to documentation
        real_images = real_images.to(device, non_blocking=True)
        real_labels = real_labels.to(device, non_blocking=True)

        # Labels for real and fake, allocated directly on the target device
        real = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # ============================================
        # Train Discriminator
        # ============================================
        optimizer_d.zero_grad()

        # Real images
        real_output = discriminator(real_images, real_labels)
        d_loss_real = criterion(real_output, real)

        # Fake images; detach() keeps this backward pass out of the generator
        noise = torch.randn(batch_size, generator.latent_dim).to(device)
        fake_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
        fake_images = generator(noise, fake_labels)
        fake_output = discriminator(fake_images.detach(), fake_labels)
        d_loss_fake = criterion(fake_output, fake)

        # Total discriminator loss
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # ============================================
        # Train Generator
        # ============================================
        optimizer_g.zero_grad()

        # Generate a fresh batch of fakes, scored by the just-updated discriminator
        noise = torch.randn(batch_size, generator.latent_dim).to(device)
        fake_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
        fake_images = generator(noise, fake_labels)
        fake_output = discriminator(fake_images, fake_labels)

        # Generator wants discriminator to think fakes are real
        g_loss = criterion(fake_output, real)
        g_loss.backward()
        optimizer_g.step()

        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    return np.mean(d_losses), np.mean(g_losses)

## Part 4: Train the GAN

Watch the generated images improve over epochs!

In [None]:
epochs = 15
sample_interval = 3  # Show samples every N epochs

# Mean per-epoch losses, appended once per epoch
history = {'d_loss': [], 'g_loss': []}

for epoch in range(epochs):
    print(f'\n=== Epoch {epoch+1}/{epochs} ===')

    d_loss, g_loss = train_epoch(generator, discriminator, train_loader,
                                  optimizer_g, optimizer_d, criterion, device)

    history['d_loss'].append(d_loss)
    history['g_loss'].append(g_loss)

    print(f'D Loss: {d_loss:.4f} | G Loss: {g_loss:.4f}')

    # Generate samples after the first epoch and then every sample_interval epochs
    if (epoch + 1) % sample_interval == 0 or epoch == 0:
        # train_epoch switches the generator back to train mode on the next epoch
        generator.eval()
        with torch.no_grad():
            fake_images = generator(fixed_noise, fixed_labels)

        # fixed_labels is built as i % 10 over 100 samples, so tiling all of
        # them 10 per row puts one digit class in each column
        grid = make_grid(fake_images, nrow=10, normalize=True, value_range=(-1, 1))
        plt.figure(figsize=(10, 10))
        plt.imshow(grid.permute(1, 2, 0).cpu(), cmap='gray')
        plt.title(f'Generated Images - Epoch {epoch+1}', fontsize=16)
        plt.axis('off')
        plt.show()

print('\n✓ Training complete!')

In [None]:
# Plot training curves
# history holds one mean loss per epoch; number the x-axis from 1 so it lines
# up with the "Epoch k/N" lines printed during training
epoch_numbers = range(1, len(history['d_loss']) + 1)
plt.figure(figsize=(10, 5))
plt.plot(epoch_numbers, history['d_loss'], label='Discriminator Loss', linewidth=2)
plt.plot(epoch_numbers, history['g_loss'], label='Generator Loss', linewidth=2)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('Training Loss Over Time', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.show()

## Part 5: Generate Specific Digits

The fun part - control which digit to generate!

In [None]:
def generate_digit(digit, num_samples=16):
    """
    Generate samples of a specific digit and display them in a grid.

    Uses the module-level ``generator`` and ``device`` and leaves the
    generator in eval mode.

    Args:
        digit: Which digit to generate (0-9 for the default generator)
        num_samples: Number of samples to generate

    Raises:
        ValueError: If ``digit`` is outside the generator's label range.
    """
    # Checked up front so a bad label fails with a clear message instead of an
    # index error inside the embedding lookup
    num_classes = generator.label_embedding.num_embeddings
    if not 0 <= digit < num_classes:
        raise ValueError(f'digit must be in [0, {num_classes - 1}], got {digit}')

    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_samples, generator.latent_dim).to(device)
        # Every sample is conditioned on the same label
        labels = torch.full((num_samples,), digit, dtype=torch.long, device=device)
        generated = generator(noise, labels)

    grid = make_grid(generated, nrow=4, normalize=True, value_range=(-1, 1))
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0).cpu(), cmap='gray')
    plt.title(f'Generated Digit: {digit}', fontsize=16)
    plt.axis('off')
    plt.show()

# Generate each digit: one 16-sample figure per class, 0 through 9
print('Generating all digits (0-9)...')
for digit in range(0, 10):
    generate_digit(digit=digit, num_samples=16)

## Part 6: Generate a Grid of All Digits

In [None]:
# Generate 10 samples for each digit
generator.eval()
all_generated = []

with torch.no_grad():
    for digit in range(10):
        # Fresh noise per digit; all ten samples share the same label
        noise = torch.randn(10, generator.latent_dim).to(device)
        labels = torch.tensor([digit] * 10).to(device)
        generated = generator(noise, labels)
        all_generated.append(generated)

# Join the ten per-digit batches in digit order; at 10 images per grid row,
# each row then shows a single digit
all_generated = torch.cat(all_generated, dim=0)

grid = make_grid(all_generated, nrow=10, normalize=True, value_range=(-1, 1))
plt.figure(figsize=(15, 15))
plt.imshow(grid.cpu().permute(1, 2, 0), cmap='gray')
plt.title('All Generated Digits (0-9, 10 samples each)', fontsize=18)
plt.axis('off')
plt.show()

## Part 7: Latent Space Exploration

Interpolate between two random points in latent space for the same digit.

In [None]:
def interpolate_latent(digit, num_steps=10):
    """
    Interpolate between two random latent vectors for a given digit.

    Draws two random latent points, blends them linearly at ``num_steps``
    evenly spaced weights from 0 to 1, generates all frames in one batch with
    the module-level ``generator``, and shows them in a single row.

    Args:
        digit: Class label every frame is conditioned on.
        num_steps: Number of interpolation frames, endpoints included.

    Raises:
        ValueError: If ``num_steps`` is less than 1 or ``digit`` is outside
            the generator's label range.
    """
    if num_steps < 1:
        raise ValueError(f'num_steps must be at least 1, got {num_steps}')
    num_classes = generator.label_embedding.num_embeddings
    if not 0 <= digit < num_classes:
        raise ValueError(f'digit must be in [0, {num_classes - 1}], got {digit}')

    generator.eval()

    # Two random starting points
    z1 = torch.randn(1, generator.latent_dim).to(device)
    z2 = torch.randn(1, generator.latent_dim).to(device)

    # A (num_steps, 1) column of blend weights broadcasts against the
    # (1, latent_dim) endpoints, giving every intermediate latent at once
    alphas = torch.linspace(0, 1, num_steps).unsqueeze(1).to(device)
    latents = (1 - alphas) * z1 + alphas * z2
    labels = torch.full((num_steps,), digit, dtype=torch.long, device=device)

    # One batched forward pass instead of one pass per step
    with torch.no_grad():
        interpolations = generator(latents, labels)

    grid = make_grid(interpolations, nrow=num_steps, normalize=True, value_range=(-1, 1))

    plt.figure(figsize=(15, 3))
    plt.imshow(grid.permute(1, 2, 0).cpu(), cmap='gray')
    plt.title(f'Latent Space Interpolation - Digit {digit}', fontsize=14)
    plt.axis('off')
    plt.show()

# Show interpolations for a few digits, ten frames each
for digit in (0, 3, 7):
    interpolate_latent(digit=digit, num_steps=10)

## Summary

**What we accomplished:**
1. Built a Conditional DCGAN that generates handwritten digits
2. Trained efficiently on GPU (~10 minutes for 15 epochs)
3. Generated specific digits on demand (conditional generation)
4. Explored the latent space with smooth interpolations

**Key Takeaways:**
- **Conditional GANs** allow controlled generation by providing class labels
- **DCGAN architecture** with BatchNorm and LeakyReLU provides stable training
- GANs learn to generate realistic samples by adversarial training
- The latent space is continuous and smooth (interpolations work well)

**Why This Matters:**
- Data augmentation: Generate synthetic training data
- Creative applications: Generate art, designs, faces
- Controllable generation: Specify what you want to create
- Research tool: Understand data distributions

**Next Steps:**
- Try on Fashion-MNIST or CIFAR-10
- Experiment with different architectures (Progressive GAN, StyleGAN)
- Add more conditioning (multiple attributes)
- Try other generative model families (Diffusion Models, VAEs)