# 🎯 Generative Adversarial Networks (GANs) - Complete Hands-on Tutorial

## Based on Lecture 15: Generative Models - GAN

**Author**: Ho-min Park  
**Email**: homin.park@ghent.ac.kr  
**Notebook Version**: Interactive Learning Edition

---

## 📚 Learning Objectives

By the end of this notebook, you will:
1. Understand the mathematical foundations of GANs
2. Implement a basic GAN from scratch
3. Build and train DCGAN for image generation
4. Explore advanced techniques (WGAN, Conditional GAN)
5. Diagnose and fix common GAN training issues
6. Apply GANs to real-world problems

## 📖 Notebook Structure

- **Part 1**: Introduction and Setup
- **Part 2**: Core Concepts (Mathematical Foundations)
- **Part 3**: Basic GAN Implementation
- **Part 4**: DCGAN for Image Generation
- **Part 5**: Advanced Techniques and Applications


---
# Part 1: Introduction and Setup

## 1.1 Understanding GANs - The Counterfeiter Analogy

Imagine a counterfeiter (Generator) trying to create fake money and a police officer (Discriminator) trying to detect fakes:
- **Generator (G)**: Creates fake samples that look real
- **Discriminator (D)**: Distinguishes real from fake samples
- **Zero-Sum Game**: Each network's gain is the other's loss; ideally they compete until the fakes become indistinguishable from real samples


In [None]:
# Essential imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import clear_output, display, HTML
import warnings
warnings.filterwarnings('ignore')

# Deep Learning imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid

# For interactive visualizations
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
print(f'PyTorch version: {torch.__version__}')

---
## Exercise 1: Mathematical Foundation of GANs 🧮

### Concept

GANs optimize a minimax objective:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

Where:
- $D(x)$: Discriminator's estimated probability that $x$ is a real sample rather than a generated one
- $G(z)$: Generator output, i.e. noise $z$ mapped to a fake sample
- $p_{data}$: Real data distribution
- $p_z$: Noise distribution (usually a standard Gaussian or uniform)
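For a fixed generator, the discriminator that maximizes $V$ has a closed form: $D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}$, where $p_g$ is the distribution of $G(z)$. Substituting it back gives $V(D^*, G) = -\log 4 + 2\,\mathrm{JSD}(p_{data} \,\|\, p_g)$. The minimum, $-\log 4$, is therefore reached exactly when $p_g = p_{data}$. The sketch below checks this numerically for two 1D Gaussians using a grid integral. The Gaussian parameters and helper names are illustrative choices and do not come from the lecture.

```python
import numpy as np

def gaussian_pdf(x, mean, std):
    """Density of N(mean, std^2) evaluated at x."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

def value_at_optimal_d(mean_g, std_g, mean_data=4.0, std_data=1.5):
    """Approximate V(D*, G) for 1D Gaussians with a Riemann sum on a fine grid."""
    x = np.linspace(-20.0, 30.0, 200_001)
    dx = x[1] - x[0]
    p_data = gaussian_pdf(x, mean_data, std_data)
    p_g = gaussian_pdf(x, mean_g, std_g)
    tiny = 1e-300  # keeps log() finite if a density underflows to 0
    d_star = p_data / (p_data + p_g + tiny)  # optimal discriminator D*(x)
    # V = E_{p_data}[log D*(x)] + E_{p_g}[log(1 - D*(x))], written as integrals over x
    integrand = p_data * np.log(d_star + tiny) + p_g * np.log(1.0 - d_star + tiny)
    return float(np.sum(integrand) * dx)

print("p_g = p_data:", value_at_optimal_d(4.0, 1.5), " vs -log 4 =", -np.log(4))
print("p_g far off: ", value_at_optimal_d(0.0, 1.0))
```

When $p_g$ matches $p_{data}$, $D^*$ is $1/2$ everywhere and the value sits at $-\log 4$. A mismatched generator gives a larger value, which is what the generator's minimization pushes down.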


In [None]:
# Visualize the GAN objective function
def visualize_gan_objective():
    """Visualize how the GAN objective changes during training"""
    
    # Create synthetic discriminator outputs
    x = np.linspace(0, 1, 100)
    
    # Value function components
    real_term = np.log(x + 1e-7)  # log D(x) for real data
    fake_term = np.log(1 - x + 1e-7)  # log(1 - D(G(z))) for fake data
    
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=('Discriminator Terms', 'Training Dynamics')
    )
    
    # Plot value function components
    fig.add_trace(
        go.Scatter(x=x, y=real_term, name='Real: log D(x)', 
                   line=dict(color='blue', width=2)),
        row=1, col=1
    )
    fig.add_trace(
        go.Scatter(x=x, y=fake_term, name='Fake: log(1-D(G(z)))',
                   line=dict(color='red', width=2)),
        row=1, col=1
    )
    
    # Schematic training curves (synthetic damped oscillations), NOT losses from a real run
    epochs = np.arange(0, 100)
    d_loss = 0.5 * np.exp(-epochs/30) * np.sin(epochs/5) + 0.5
    g_loss = 0.5 * np.exp(-epochs/30) * np.cos(epochs/5) + 0.5
    
    fig.add_trace(
        go.Scatter(x=epochs, y=d_loss, name='D Loss',
                   line=dict(color='blue', width=2)),
        row=1, col=2
    )
    fig.add_trace(
        go.Scatter(x=epochs, y=g_loss, name='G Loss',
                   line=dict(color='red', width=2)),
        row=1, col=2
    )
    
    fig.update_xaxes(title_text='Discriminator output (D(x) or D(G(z)))', row=1, col=1)
    fig.update_xaxes(title_text='Training Epoch', row=1, col=2)
    fig.update_yaxes(title_text='Value', row=1, col=1)
    fig.update_yaxes(title_text='Loss', row=1, col=2)
    
    fig.update_layout(
        title='GAN Objective Function Visualization',
        height=400,
        showlegend=True
    )
    
    return fig

# Display the visualization
fig = visualize_gan_objective()
fig.show()

print("📊 Key Insights:")
print("1. The discriminator maximizes V: it pushes D(x) toward 1 on real data and D(G(z)) toward 0 on fakes")
print("2. The generator minimizes V by pushing D(G(z)) toward 1")
print("3. Losses typically oscillate as the networks compete (the right panel is a schematic, not real training data)")

---
## Exercise 2: Building a Simple 1D GAN 🔧

### Concept
Let's start with a simple 1D GAN that learns to generate samples from a Gaussian distribution, N(4, 1.5²).
Working in 1D isolates the core training mechanics from the complexity of images.


In [None]:
class Simple1D_Generator(nn.Module):
    """Simple generator for 1D data"""
    def __init__(self, noise_dim=10, hidden_dim=128, output_dim=1):
        super(Simple1D_Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, z):
        return self.net(z)

class Simple1D_Discriminator(nn.Module):
    """Simple discriminator for 1D data"""
    def __init__(self, input_dim=1, hidden_dim=128):
        super(Simple1D_Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.net(x)

# Training function for 1D GAN
def train_1d_gan(n_epochs=1000, batch_size=128):
    """Train a simple 1D GAN to learn a Gaussian distribution"""
    
    # Initialize networks
    G = Simple1D_Generator().to(device)
    D = Simple1D_Discriminator().to(device)
    
    # Optimizers
    g_optimizer = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    # Loss function
    criterion = nn.BCELoss()
    
    # Target distribution: Gaussian(mean=4, std=1.5)
    target_mean, target_std = 4.0, 1.5
    
    # Storage for visualization
    generated_samples = []
    d_losses, g_losses = [], []
    
    for epoch in range(n_epochs):  # each "epoch" here is a single mini-batch update
        # Generate real samples from target distribution
        real_data = torch.randn(batch_size, 1).to(device) * target_std + target_mean
        
        # === Train Discriminator ===
        d_optimizer.zero_grad()
        
        # Real data
        real_labels = torch.ones(batch_size, 1).to(device)
        real_output = D(real_data)
        d_loss_real = criterion(real_output, real_labels)
        
        # Fake data
        noise = torch.randn(batch_size, 10).to(device)
        fake_data = G(noise)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        fake_output = D(fake_data.detach())
        d_loss_fake = criterion(fake_output, fake_labels)
        
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()
        
        # === Train Generator ===
        g_optimizer.zero_grad()
        
        noise = torch.randn(batch_size, 10).to(device)
        fake_data = G(noise)
        fake_output = D(fake_data)
        g_loss = criterion(fake_output, real_labels)  # Generator wants D to think fake is real
        
        g_loss.backward()
        g_optimizer.step()
        
        # Store results
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())
        
        if epoch % 100 == 0:
            with torch.no_grad():
                test_noise = torch.randn(1000, 10).to(device)
                samples = G(test_noise).cpu().numpy()
                generated_samples.append(samples)
        
        if epoch % 200 == 0:
            print(f'Epoch [{epoch}/{n_epochs}] D_loss: {d_loss:.4f}, G_loss: {g_loss:.4f}')
    
    return G, D, generated_samples, d_losses, g_losses

# Train the 1D GAN
print("🚀 Training 1D GAN...")
G_1d, D_1d, samples_1d, d_losses_1d, g_losses_1d = train_1d_gan(n_epochs=2000)
print("✅ Training complete!")
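In the generator step, `criterion(fake_output, real_labels)` computes $-\log D(G(z))$. The minimax objective instead has the generator minimize $\log(1 - D(G(z)))$. Using $-\log D(G(z))$ is the "non-saturating" generator loss from the original GAN paper. Early in training, when $D(G(z))$ is close to 0, $\log(1 - D)$ is nearly flat, while $-\log D$ still produces a large gradient. The NumPy sketch below compares the two gradients with respect to the discriminator's pre-sigmoid logit $a$, where $D = \sigma(a)$. The helper names are illustrative.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def generator_grads(logit):
    """Gradient magnitudes of two generator losses w.r.t. D's logit a, where D = sigmoid(a)."""
    d = sigmoid(logit)
    # Minimax ("saturating") loss log(1 - D): d/da log(1 - sigmoid(a)) = -sigmoid(a)
    saturating = d
    # Non-saturating loss -log D: d/da [-log sigmoid(a)] = -(1 - sigmoid(a))
    non_saturating = 1.0 - d
    return saturating, non_saturating

# Logit -5 models a discriminator that confidently rejects the fake (D(G(z)) is about 0.0067)
for logit in [-5.0, 0.0, 5.0]:
    sat, non_sat = generator_grads(logit)
    print(f"D(G(z))={sigmoid(logit):.4f}  |grad| minimax={sat:.4f}  non-saturating={non_sat:.4f}")
```

When the discriminator wins easily (negative logit), the minimax gradient almost vanishes, while the non-saturating one stays close to 1.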

In [None]:
# Visualize 1D GAN results
def visualize_1d_gan_results(samples, d_losses, g_losses):
    """Visualize the training results of 1D GAN"""
    
    fig = plt.figure(figsize=(15, 5))
    
    # Plot 1: Generated distribution evolution
    ax1 = plt.subplot(1, 3, 1)
    target_samples = np.random.normal(4.0, 1.5, 1000)
    
    for i, sample in enumerate(samples[::2]):  # Show every other snapshot
        ax1.hist(sample, bins=30, alpha=0.3, density=True, label=f'Epoch {i*200}')
    
    ax1.hist(target_samples, bins=30, alpha=0.5, density=True, 
             color='red', label='Target', histtype='step', linewidth=2)
    ax1.set_xlabel('Value')
    ax1.set_ylabel('Density')
    ax1.set_title('Generated Distribution Evolution')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Loss curves
    ax2 = plt.subplot(1, 3, 2)
    ax2.plot(d_losses, label='Discriminator Loss', alpha=0.7)
    ax2.plot(g_losses, label='Generator Loss', alpha=0.7)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.set_title('Training Losses')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # Plot 3: Final comparison
    ax3 = plt.subplot(1, 3, 3)
    ax3.hist(samples[-1], bins=30, alpha=0.7, density=True, 
             color='blue', label='Generated')
    ax3.hist(target_samples, bins=30, alpha=0.7, density=True, 
             color='red', label='Target')
    ax3.set_xlabel('Value')
    ax3.set_ylabel('Density')
    ax3.set_title('Final Result Comparison')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Calculate statistics
    final_mean = np.mean(samples[-1])
    final_std = np.std(samples[-1])
    print(f"\n📊 Final Statistics:")
    print(f"Generated: mean={final_mean:.2f}, std={final_std:.2f}")
    print(f"Target:    mean=4.00, std=1.50")
    print(f"Error:     mean_diff={abs(final_mean-4.0):.3f}, std_diff={abs(final_std-1.5):.3f}")

visualize_1d_gan_results(samples_1d, d_losses_1d, g_losses_1d)

---
## Exercise 3: DCGAN for Image Generation 🖼️

### Concept
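Deep Convolutional GAN (DCGAN) introduced architectural guidelines that made GAN training considerably more stable:
- Replace pooling layers with strided convolutions (D) and fractionally-strided (transposed) convolutions (G)
- Use batch normalization in both G and D (except the G output layer and the D input layer)
- Remove fully connected hidden layers
- Use ReLU in G for all layers except the output, which uses Tanh
- Use LeakyReLU in D for all layers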
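Because pooling is removed, every resolution change happens inside the convolutions themselves. The spatial sizes in the `# State:` comments of the next cell follow from PyTorch's output-size formulas (for `dilation=1` and `output_padding=0`). For `Conv2d` the output size is $\lfloor (H + 2p - k)/s \rfloor + 1$, and for `ConvTranspose2d` it is $(H - 1)s - 2p + k$. The pure-Python sketch below traces them. The helper names are illustrative, and the layer settings are copied from the cell that follows.

```python
def conv2d_out(size, kernel, stride, padding):
    """Output height/width of nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose2d_out(size, kernel, stride, padding):
    """Output height/width of nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel

# Generator: (kernel, stride, padding) of each ConvTranspose2d, starting from a 1x1 latent
gen_layers = [(4, 1, 0)] + [(4, 2, 1)] * 4
size, gen_sizes = 1, []
for k, s, p in gen_layers:
    size = conv_transpose2d_out(size, k, s, p)
    gen_sizes.append(size)

# Discriminator: four stride-2 convolutions, then a 4x4 "valid" convolution to one score
disc_layers = [(4, 2, 1)] * 4 + [(4, 1, 0)]
size, disc_sizes = 64, []
for k, s, p in disc_layers:
    size = conv2d_out(size, k, s, p)
    disc_sizes.append(size)

print("Generator sizes:    ", gen_sizes)   # [4, 8, 16, 32, 64]
print("Discriminator sizes:", disc_sizes)  # [32, 16, 8, 4, 1]
```

Each `(4, 2, 1)` layer exactly doubles (G) or halves (D) the resolution, which is why this kernel/stride/padding triple recurs throughout DCGAN-style models.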


In [None]:
class DCGAN_Generator(nn.Module):
    """DCGAN Generator for 64x64 images"""
    def __init__(self, nz=100, ngf=64, nc=3):
        super(DCGAN_Generator, self).__init__()
        self.main = nn.Sequential(
            # Input: latent vector Z of shape (nz) x 1 x 1
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # State: (ngf*8) x 4 x 4
            
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # State: (ngf*4) x 8 x 8
            
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # State: (ngf*2) x 16 x 16
            
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # State: (ngf) x 32 x 32
            
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: (nc) x 64 x 64
        )
    
    def forward(self, input):
        return self.main(input)

class DCGAN_Discriminator(nn.Module):
    """DCGAN Discriminator for 64x64 images"""
    def __init__(self, nc=3, ndf=64):
        super(DCGAN_Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Input: (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (ndf) x 32 x 32
            
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (ndf*2) x 16 x 16
            
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (ndf*4) x 8 x 8
            
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (ndf*8) x 4 x 4
            
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # Output: 1 x 1 x 1 score per image, flattened to shape (batch,) in forward()
        )
    
    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)

# Initialize DCGAN
nz = 100  # Latent vector size
ngf = 64  # Generator feature map size
ndf = 64  # Discriminator feature map size
nc = 1    # Number of channels (1 for grayscale, 3 for RGB)

netG = DCGAN_Generator(nz, ngf, nc).to(device)
netD = DCGAN_Discriminator(nc, ndf).to(device)

# DCGAN-style weight initialization: conv weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02), shift = 0
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netG.apply(weights_init)
netD.apply(weights_init)

print("✅ DCGAN models initialized")
print(f"Generator parameters: {sum(p.numel() for p in netG.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in netD.parameters()):,}")
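As a sanity check on the printed counts, the parameters can be tallied by hand. A `Conv2d`/`ConvTranspose2d` with `bias=False` holds `in_ch * out_ch * k * k` weights. Each `BatchNorm2d(C)` adds `2 * C` learnable parameters (scale and shift). Its running statistics are buffers, so `parameters()` does not count them. The sketch below assumes the configuration used in this cell (`nz=100`, `ngf=ndf=64`, `nc=1`, 4x4 kernels). `count_params` is a hypothetical helper.

```python
def count_params(channels, batchnorm_after, kernel=4):
    """Sum conv weights (bias=False) plus BatchNorm scale/shift parameters.

    channels[i] -> channels[i + 1] is the i-th conv layer; batchnorm_after[i]
    says whether a BatchNorm2d follows that layer.
    """
    total = 0
    for i, has_bn in enumerate(batchnorm_after):
        c_in, c_out = channels[i], channels[i + 1]
        total += c_in * c_out * kernel * kernel  # conv weight tensor
        if has_bn:
            total += 2 * c_out  # BatchNorm2d weight + bias
    return total

nz, ngf, ndf, nc = 100, 64, 64, 1

# Generator: nz -> 8ngf -> 4ngf -> 2ngf -> ngf -> nc, BatchNorm after all but the last layer
g_total = count_params([nz, ngf * 8, ngf * 4, ngf * 2, ngf, nc],
                       [True, True, True, True, False])

# Discriminator: nc -> ndf -> 2ndf -> 4ndf -> 8ndf -> 1, no BatchNorm on the first or last layer
d_total = count_params([nc, ndf, ndf * 2, ndf * 4, ndf * 8, 1],
                       [False, True, True, True, False])

print(f"Generator parameters: {g_total:,}")      # 3,574,656
print(f"Discriminator parameters: {d_total:,}")  # 2,763,520
```

Most of the parameters sit in the widest layer pair (`8ngf` <-> `4ngf`), which alone accounts for over two million weights in each network.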

---
## Exercise 4: Training DCGAN on MNIST 🔢

### Concept
We'll train the DCGAN on MNIST digits (resized to 64x64 to match the architecture above) and observe the training dynamics.


In [None]:
# Load MNIST dataset
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download and load the training data
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                         shuffle=True, num_workers=2)

print(f"✅ Loaded MNIST dataset: {len(trainset)} images")
print(f"   Batch size: {dataloader.batch_size}")
print(f"   Number of batches: {len(dataloader)}")
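`transforms.Normalize((0.5,), (0.5,))` computes `(x - 0.5) / 0.5`. This maps the `[0, 1]` pixel range produced by `ToTensor()` onto `[-1, 1]`, the same range as the generator's final `Tanh`. Generated images therefore need the inverse mapping before they are shown as pixels. In the visualization cell, `make_grid(..., normalize=True)` handles this by min-max rescaling. A NumPy sketch of both directions follows. The helper names are illustrative, not torchvision functions.

```python
import numpy as np

def normalize(pixels, mean=0.5, std=0.5):
    """Same arithmetic as transforms.Normalize((mean,), (std,)) for one channel."""
    return (pixels - mean) / std

def denormalize(values, mean=0.5, std=0.5):
    """Inverse mapping, e.g. for turning Tanh outputs back into [0, 1] pixels."""
    return values * std + mean

pixels = np.array([0.0, 0.25, 0.5, 1.0])  # typical ToTensor() values
scaled = normalize(pixels)
print(scaled.tolist())               # [-1.0, -0.5, 0.0, 1.0]
print(denormalize(scaled).tolist())  # [0.0, 0.25, 0.5, 1.0]
```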

In [None]:
def train_dcgan(netG, netD, dataloader, num_epochs=5, nz=100):
    """Train DCGAN on image data"""
    
    # Setup optimizers
    lr = 0.0002
    beta1 = 0.5
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
    
    # Loss function
    criterion = nn.BCELoss()
    
    # Fixed noise for visualization
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)
    
    # Training stats
    G_losses = []
    D_losses = []
    img_list = []
    
    print("🚀 Starting DCGAN Training...")
    
    for epoch in range(num_epochs):
        for i, (data, _) in enumerate(dataloader):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            netD.zero_grad()
            
            # Train with real batch
            real_data = data.to(device)
            batch_size = real_data.size(0)
            label = torch.full((batch_size,), 1., dtype=torch.float, device=device)
            
            output = netD(real_data)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()
            
            # Train with fake batch
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(0.)
            
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            
            errD = errD_real + errD_fake
            optimizerD.step()
            
            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(1.)  # Fake labels are real for generator cost
            
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            
            optimizerG.step()
            
            # Save losses
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            
            # Print statistics
            if i % 100 == 0:
                print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] '
                      f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                      f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')
        
        # Check how the generator is doing by saving G's output on fixed_noise
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
            img_list.append(fake)
    
    print("✅ Training Complete!")
    return G_losses, D_losses, img_list

# Train the DCGAN
G_losses, D_losses, img_list = train_dcgan(netG, netD, dataloader, num_epochs=5)
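When reading the logs printed by `train_dcgan`, it helps to know the loss values for a maximally confused discriminator, one that outputs 0.5 for every input (the theoretical equilibrium). With `BCELoss`, `Loss_D` is the sum of $-\log D(x)$ and $-\log(1 - D(G(z)))$, each equal to $\log 2$ at that point. `Loss_G` is $-\log D(G(z)) = \log 2$. The sketch below reproduces these numbers with a NumPy re-implementation of binary cross-entropy (`bce` is a hypothetical helper). In practice the losses fluctuate around such values rather than settling on them, so treat them as a reference point, not a convergence criterion.

```python
import numpy as np

def bce(prediction, target):
    """Mean binary cross-entropy, matching nn.BCELoss for probabilities in (0, 1)."""
    prediction = np.asarray(prediction, dtype=float)
    target = np.asarray(target, dtype=float)
    return float(np.mean(-(target * np.log(prediction) + (1 - target) * np.log(1 - prediction))))

d_real = np.full(64, 0.5)  # D(x) for a batch of real images
d_fake = np.full(64, 0.5)  # D(G(z)) for a batch of generated images

loss_d = bce(d_real, np.ones(64)) + bce(d_fake, np.zeros(64))  # errD_real + errD_fake
loss_g = bce(d_fake, np.ones(64))                               # errG with "real" labels

print(f"Loss_D at equilibrium: {loss_d:.4f} (= 2 log 2)")
print(f"Loss_G at equilibrium: {loss_g:.4f} (= log 2)")
```

A `Loss_D` that collapses toward 0 while `Loss_G` grows is the usual sign that the discriminator is overpowering the generator.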

In [None]:
# Visualize DCGAN results
def visualize_dcgan_results(G_losses, D_losses, img_list):
    """Visualize DCGAN training results"""
    
    fig = plt.figure(figsize=(15, 6))
    
    # Plot losses
    ax1 = plt.subplot(1, 3, 1)
    ax1.plot(G_losses, label='Generator', alpha=0.7)
    ax1.plot(D_losses, label='Discriminator', alpha=0.7)
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Losses')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Show real images (draws a batch from the global `dataloader` defined above)
    ax2 = plt.subplot(1, 3, 2)
    real_batch = next(iter(dataloader))
    grid = make_grid(real_batch[0][:64], padding=2, normalize=True, nrow=8)
    ax2.imshow(np.transpose(grid, (1, 2, 0)), cmap='gray')
    ax2.set_title('Real Images')
    ax2.axis('off')
    
    # Show generated images
    ax3 = plt.subplot(1, 3, 3)
    grid = make_grid(img_list[-1], padding=2, normalize=True, nrow=8)
    ax3.imshow(np.transpose(grid, (1, 2, 0)), cmap='gray')
    ax3.set_title('Generated Images')
    ax3.axis('off')
    
    plt.tight_layout()
    plt.show()

visualize_dcgan_results(G_losses, D_losses, img_list)

---
## Exercise 5: Understanding Mode Collapse 🔍

### Concept
Mode collapse occurs when the generator produces a limited variety of samples and fails to capture the full data distribution.

Types:
- **Partial Collapse**: Missing some modes of the data distribution
- **Complete Collapse**: Generator produces nearly identical outputs


In [None]:
def detect_mode_collapse(generator, n_samples=1000, nz=100):
    """Detect mode collapse by analyzing generated sample diversity"""
    
    # Generate samples
    with torch.no_grad():
        noise = torch.randn(n_samples, nz, 1, 1, device=device)
        generated = generator(noise).cpu().numpy()
    
    # Flatten images for analysis
    generated_flat = generated.reshape(n_samples, -1)
    
    # Calculate pairwise distances
    from scipy.spatial.distance import pdist, squareform
    distances = pdist(generated_flat, metric='euclidean')
    dist_matrix = squareform(distances)
    
    # Analyze diversity metrics
    mean_distance = np.mean(distances)
    std_distance = np.std(distances)
    min_distance = np.min(distances)
    
    # Detect near-duplicates: pairs closer than 10% of the mean distance (a heuristic threshold)
    threshold = mean_distance * 0.1
    near_duplicates = np.sum(distances < threshold)
    total_pairs = len(distances)
    duplicate_ratio = near_duplicates / total_pairs
    
    # Visualization
    fig = plt.figure(figsize=(15, 5))
    
    # Distance distribution
    ax1 = plt.subplot(1, 3, 1)
    ax1.hist(distances, bins=50, edgecolor='black', alpha=0.7)
    ax1.axvline(mean_distance, color='red', linestyle='--', label=f'Mean: {mean_distance:.2f}')
    ax1.axvline(threshold, color='orange', linestyle='--', label=f'Threshold: {threshold:.2f}')
    ax1.set_xlabel('Pairwise Distance')
    ax1.set_ylabel('Frequency')
    ax1.set_title('Distance Distribution')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Distance heatmap (subsample for visualization)
    ax2 = plt.subplot(1, 3, 2)
    subsample_idx = np.random.choice(n_samples, 50, replace=False)
    sub_dist_matrix = dist_matrix[np.ix_(subsample_idx, subsample_idx)]
    im = ax2.imshow(sub_dist_matrix, cmap='viridis')
    ax2.set_title('Distance Matrix (50 samples)')
    ax2.set_xlabel('Sample Index')
    ax2.set_ylabel('Sample Index')
    plt.colorbar(im, ax=ax2)
    
    # Mode collapse indicators (bars rescaled by arbitrary factors to share one axis; labels show raw values)
    ax3 = plt.subplot(1, 3, 3)
    metrics = ['Mean Distance', 'Std Distance', 'Min Distance', 'Duplicate Ratio']
    values = [mean_distance/100, std_distance/100, min_distance/10, duplicate_ratio*10]
    colors = ['blue', 'green', 'orange', 'red']
    bars = ax3.bar(metrics, values, color=colors, alpha=0.7)
    ax3.set_ylabel('Rescaled Value')
    ax3.set_title('Diversity Metrics')
    ax3.grid(True, alpha=0.3, axis='y')
    
    # Add value labels on bars
    for bar, val in zip(bars, [mean_distance, std_distance, min_distance, duplicate_ratio]):
        height = bar.get_height()
        ax3.text(bar.get_x() + bar.get_width()/2., height,
                f'{val:.3f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()
    
    # Mode collapse assessment (heuristic thresholds on the near-duplicate ratio, not established standards)
    print("\n🔍 Mode Collapse Analysis:")
    print(f"Mean pairwise distance: {mean_distance:.4f}")
    print(f"Std of distances: {std_distance:.4f}")
    print(f"Minimum distance: {min_distance:.4f}")
    print(f"Near-duplicate ratio: {duplicate_ratio:.4%}")
    
    if duplicate_ratio > 0.1:
        print("⚠️ WARNING: High duplicate ratio indicates possible mode collapse!")
    elif duplicate_ratio > 0.05:
        print("⚠️ CAUTION: Moderate duplicate ratio - monitor for mode collapse")
    else:
        print("✅ Few near-duplicates - no sign of complete collapse (partial collapse may still go undetected)")
    
    return mean_distance, std_distance, duplicate_ratio

# Analyze mode collapse
mean_dist, std_dist, dup_ratio = detect_mode_collapse(netG, n_samples=500)

---
## Exercise 6: Wasserstein GAN (WGAN) Implementation 🌊

### Concept
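WGAN replaces the Jensen-Shannon (JS) divergence, which the standard GAN objective minimizes when D is optimal, with the Wasserstein-1 (earth mover's) distance:
- Typically more stable training
- A loss that correlates with sample quality, so it is meaningful to monitor
- Reduced (though not eliminated) mode collapse in practice
- Requires the critic to be Lipschitz-continuous, enforced via weight clipping (original WGAN) or a gradient penalty (WGAN-GP)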
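For two 1D samples of equal size, the Wasserstein-1 distance has a simple closed form. Sort both samples and average the absolute differences between matched order statistics. Once two distributions stop overlapping, the JS divergence saturates at $\log 2$. The Wasserstein distance instead keeps growing with the gap between them, which is the main motivation given for WGAN. The NumPy sketch below assumes equal sample sizes; `wasserstein_1d` is an illustrative helper. For the general 1D case, including unequal sizes or weights, SciPy provides `scipy.stats.wasserstein_distance`.

```python
import numpy as np

def wasserstein_1d(x, y):
    """W1 distance between two equal-size 1D empirical distributions."""
    x = np.sort(np.asarray(x, dtype=float).ravel())
    y = np.sort(np.asarray(y, dtype=float).ravel())
    if x.shape != y.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    return float(np.mean(np.abs(x - y)))

rng = np.random.default_rng(42)
target = rng.normal(4.0, 1.5, size=5000)

# Shifting every sample by d moves the distribution by exactly d units of "earth"
for shift in [0.5, 2.0, 10.0]:
    print(f"shift={shift:>4}: W1 = {wasserstein_1d(target, target + shift):.3f}")
```

The distance grows linearly with the shift, even at 10 units where the two distributions barely overlap and the JS divergence would be almost constant.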


In [None]:
class WGAN_Critic(nn.Module):
    """WGAN Critic (no sigmoid activation)"""
    def __init__(self, input_dim=1, hidden_dim=128):
        super(WGAN_Critic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1)  # No sigmoid!
        )
    
    def forward(self, x):
        return self.net(x)

def train_wgan_1d(n_epochs=2000, n_critic=5, clip_value=0.01):
    """Train WGAN with weight clipping"""
    
    # Initialize networks
    G = Simple1D_Generator().to(device)
    C = WGAN_Critic().to(device)  # Critic instead of Discriminator
    
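    # RMSprop with lr=5e-5, as in the original WGAN paper (momentum-based optimizers such as Adam were reported to be less stable with weight clipping)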
    g_optimizer = optim.RMSprop(G.parameters(), lr=0.00005)
    c_optimizer = optim.RMSprop(C.parameters(), lr=0.00005)
    
    # Target distribution
    target_mean, target_std = 4.0, 1.5
    
    # Storage
    c_losses, g_losses = [], []
    wasserstein_distances = []
    
    batch_size = 128
    
    for epoch in range(n_epochs):
        # Train Critic multiple times
        for _ in range(n_critic):
            c_optimizer.zero_grad()
            
            # Real data
            real_data = torch.randn(batch_size, 1).to(device) * target_std + target_mean
            real_score = C(real_data)
            
            # Fake data
            noise = torch.randn(batch_size, 10).to(device)
            fake_data = G(noise)
            fake_score = C(fake_data.detach())
            
            # Critic loss: negative of the Wasserstein estimate E[C(real)] - E[C(fake)]
            c_loss = -torch.mean(real_score) + torch.mean(fake_score)
            c_loss.backward()
            c_optimizer.step()
            
            # Clip critic weights (Lipschitz constraint)
            for p in C.parameters():
                p.data.clamp_(-clip_value, clip_value)
        
        # Train Generator
        g_optimizer.zero_grad()
        
        noise = torch.randn(batch_size, 10).to(device)
        fake_data = G(noise)
        fake_score = C(fake_data)
        
        g_loss = -torch.mean(fake_score)
        g_loss.backward()
        g_optimizer.step()
        
        # Store metrics
        c_losses.append(c_loss.item())
        g_losses.append(g_loss.item())
        wasserstein_distances.append(-c_loss.item())  # -c_loss (last critic step) estimates the W distance, up to the critic's Lipschitz scale
        
        if epoch % 400 == 0:
            print(f'Epoch [{epoch}/{n_epochs}] C_loss: {c_loss:.4f}, G_loss: {g_loss:.4f}, '
                  f'W_distance: {-c_loss.item():.4f}')
    
    return G, C, c_losses, g_losses, wasserstein_distances

# Train WGAN
print("🚀 Training WGAN with weight clipping...")
G_wgan, C_wgan, c_losses_wgan, g_losses_wgan, w_distances = train_wgan_1d()
print("✅ WGAN training complete!")
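Weight clipping enforces the Lipschitz constraint only loosely. It bounds the critic's Lipschitz constant by some $K$ that depends on `clip_value` and the layer shapes. As a result, the logged `W_distance` (that is, `-c_loss`) is at best $K$ times the Wasserstein distance, not the distance itself. The sketch below computes one such bound for the `WGAN_Critic` used here (layers 1→128→128→1, `clip_value=0.01`). It relies on two facts. LeakyReLU is 1-Lipschitz. A matrix with entries in $[-c, c]$ has spectral norm at most $c\sqrt{mn}$. The per-layer bounds therefore multiply into a bound on $K$. This is a loose upper bound, not the critic's actual Lipschitz constant, and the helper name is illustrative.

```python
import numpy as np

def clipped_lipschitz_bound(layer_shapes, clip_value):
    """Upper bound on the Lipschitz constant of an MLP with clipped weights.

    Each (fan_out, fan_in) weight matrix with entries in [-c, c] has spectral norm
    <= c * sqrt(fan_out * fan_in); 1-Lipschitz activations (LeakyReLU) and biases
    do not increase the bound, so the per-layer bounds multiply.
    """
    bound = 1.0
    for fan_out, fan_in in layer_shapes:
        bound *= clip_value * np.sqrt(fan_out * fan_in)
    return float(bound)

# Weight shapes (out, in) of WGAN_Critic: Linear(1, 128), Linear(128, 128), Linear(128, 1)
critic_shapes = [(128, 1), (128, 128), (1, 128)]
K = clipped_lipschitz_bound(critic_shapes, clip_value=0.01)
print(f"Lipschitz bound for the clipped critic: K <= {K:.6f}")

# Sanity check: randomly drawn clipped matrices never exceed the per-layer bound
rng = np.random.default_rng(0)
for fan_out, fan_in in critic_shapes:
    w = np.clip(rng.normal(0, 1, size=(fan_out, fan_in)), -0.01, 0.01)
    spectral_norm = np.linalg.norm(w, ord=2)  # largest singular value
    assert spectral_norm <= 0.01 * np.sqrt(fan_out * fan_in) + 1e-12
```

With $K \le 0.016$, the logged distances are much smaller than the true Wasserstein distance. Their trend is informative, but their absolute scale is not. Gradient-penalty variants (WGAN-GP) avoid this scaling by targeting a gradient norm of 1 directly.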

In [None]:
# Compare WGAN vs Standard GAN
def compare_gan_wgan():
    """Compare standard GAN and WGAN results"""
    
    fig = plt.figure(figsize=(15, 10))
    
    # Generate samples from both models
    with torch.no_grad():
        noise = torch.randn(1000, 10).to(device)
        samples_gan = G_1d(noise).cpu().numpy()
        samples_wgan = G_wgan(noise).cpu().numpy()
    
    target_samples = np.random.normal(4.0, 1.5, 1000)
    
    # Row 1: Distribution comparison
    ax1 = plt.subplot(2, 3, 1)
    ax1.hist(samples_gan, bins=30, alpha=0.7, density=True, color='blue', label='Standard GAN')
    ax1.hist(target_samples, bins=30, alpha=0.5, density=True, 
             color='red', label='Target', histtype='step', linewidth=2)
    ax1.set_title('Standard GAN Results')
    ax1.set_xlabel('Value')
    ax1.set_ylabel('Density')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    ax2 = plt.subplot(2, 3, 2)
    ax2.hist(samples_wgan, bins=30, alpha=0.7, density=True, color='green', label='WGAN')
    ax2.hist(target_samples, bins=30, alpha=0.5, density=True, 
             color='red', label='Target', histtype='step', linewidth=2)
    ax2.set_title('WGAN Results')
    ax2.set_xlabel('Value')
    ax2.set_ylabel('Density')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    ax3 = plt.subplot(2, 3, 3)
    ax3.hist(samples_gan, bins=30, alpha=0.5, density=True, color='blue', label='Standard GAN')
    ax3.hist(samples_wgan, bins=30, alpha=0.5, density=True, color='green', label='WGAN')
    ax3.hist(target_samples, bins=30, alpha=0.5, density=True, 
             color='red', label='Target', histtype='step', linewidth=2)
    ax3.set_title('All Distributions')
    ax3.set_xlabel('Value')
    ax3.set_ylabel('Density')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    # Row 2: Training dynamics
    ax4 = plt.subplot(2, 3, 4)
    ax4.plot(d_losses_1d[:500], alpha=0.7, label='D Loss')
    ax4.plot(g_losses_1d[:500], alpha=0.7, label='G Loss')
    ax4.set_title('Standard GAN Training')
    ax4.set_xlabel('Iteration')
    ax4.set_ylabel('Loss')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    ax5 = plt.subplot(2, 3, 5)
    ax5.plot(c_losses_wgan[:500], alpha=0.7, label='Critic Loss')
    ax5.plot(g_losses_wgan[:500], alpha=0.7, label='G Loss')
    ax5.set_title('WGAN Training')
    ax5.set_xlabel('Iteration')
    ax5.set_ylabel('Loss')
    ax5.legend()
    ax5.grid(True, alpha=0.3)
    
    ax6 = plt.subplot(2, 3, 6)
    ax6.plot(w_distances[:500], alpha=0.7, color='purple')
    ax6.set_title('Wasserstein Distance')
    ax6.set_xlabel('Iteration')
    ax6.set_ylabel('Distance')
    ax6.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Histogram-based estimate of KL(p_target || p_model) on the target's bins.
    # Multiplying by the bin width turns the sum of densities into an approximate integral;
    # model samples that fall outside the target's bin range are ignored.
    hist_target, bins = np.histogram(target_samples, bins=30, density=True)
    hist_gan, _ = np.histogram(samples_gan, bins=bins, density=True)
    hist_wgan, _ = np.histogram(samples_wgan, bins=bins, density=True)
    bin_width = np.diff(bins)
    
    # Add small epsilon to avoid log(0) and division by zero in empty bins
    eps = 1e-10
    kl_gan = np.sum(hist_target * np.log((hist_target + eps) / (hist_gan + eps)) * bin_width)
    kl_wgan = np.sum(hist_target * np.log((hist_target + eps) / (hist_wgan + eps)) * bin_width)
    
    print("\n📊 Comparison Results:")
    print(f"Standard GAN - Mean: {np.mean(samples_gan):.3f}, Std: {np.std(samples_gan):.3f}, KL: {kl_gan:.3f}")
    print(f"WGAN        - Mean: {np.mean(samples_wgan):.3f}, Std: {np.std(samples_wgan):.3f}, KL: {kl_wgan:.3f}")
    print(f"Target      - Mean: 4.000, Std: 1.500")

compare_gan_wgan()

---
## Exercise 7: Conditional GAN (cGAN) 🎨

### Concept
Conditional GANs allow us to control the generation process by providing additional information (labels, text, images).

Key modification: Both G and D receive the conditional information y:
- G(z, y): Generate samples conditioned on y
- D(x, y): Classify real/fake given condition y
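The simplest way to feed $y$ to both networks is concatenation, which is what the classes in the next cell do. The generator appends a label vector to $z$ along the channel axis. The discriminator appends a label-dependent 64x64 map to the image as an extra channel. The NumPy sketch below shows the tensor shapes involved. For illustration it uses a fixed one-hot vector for the generator and a constant-valued plane per class for the discriminator. The model code instead uses learned `nn.Embedding` tables, which play the same role but are trained along with the networks.

```python
import numpy as np

batch, nz, n_classes, img_size = 4, 100, 10, 64
rng = np.random.default_rng(0)

labels = np.array([3, 1, 4, 1])       # desired digit for each sample
one_hot = np.eye(n_classes)[labels]   # shape (batch, n_classes)

# Generator input: noise and label stacked along the channel axis, as (B, C, 1, 1)
z = rng.normal(size=(batch, nz, 1, 1))
g_input = np.concatenate([z, one_hot[:, :, None, None]], axis=1)
print("G input:", g_input.shape)   # (4, 110, 1, 1)

# Discriminator input: the label becomes an extra image-sized channel.
# Mapping each class to a constant plane is an illustrative stand-in for the learned embedding.
images = rng.normal(size=(batch, 1, img_size, img_size))
label_planes = np.broadcast_to(
    (labels / (n_classes - 1))[:, None, None, None], (batch, 1, img_size, img_size)
)
d_input = np.concatenate([images, label_planes], axis=1)
print("D input:", d_input.shape)   # (4, 2, 64, 64)
```

These shapes match the channel counts expected by the first layers below: `nz + n_classes` input channels for the generator and `n_channels + 1` for the discriminator.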


In [None]:
class ConditionalGenerator(nn.Module):
    """Conditional GAN Generator"""
    def __init__(self, n_classes=10, nz=100, n_channels=1, ngf=64):
        super(ConditionalGenerator, self).__init__()
        
        # Embedding for class labels
        self.label_emb = nn.Embedding(n_classes, n_classes)
        
        self.main = nn.Sequential(
            # Input: (nz + n_classes) x 1 x 1
            nn.ConvTranspose2d(nz + n_classes, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf, n_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    
    def forward(self, noise, labels):
        # Embed labels and concatenate with noise
        label_embedding = self.label_emb(labels).unsqueeze(2).unsqueeze(3)
        gen_input = torch.cat((noise, label_embedding), 1)
        return self.main(gen_input)

class ConditionalDiscriminator(nn.Module):
    """Conditional GAN Discriminator"""
    def __init__(self, n_classes=10, n_channels=1, ndf=64):
        super(ConditionalDiscriminator, self).__init__()
        
        # Embedding for class labels
        self.label_embedding = nn.Embedding(n_classes, 64*64)
        
        self.main = nn.Sequential(
            # Input: (n_channels + 1) x 64 x 64 (image + embedded label)
            nn.Conv2d(n_channels + 1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, img, labels):
        # Embed labels and reshape to image dimensions
        label_embedding = self.label_embedding(labels).view(-1, 1, 64, 64)
        # Concatenate image and label embedding
        d_input = torch.cat((img, label_embedding), 1)
        return self.main(d_input).view(-1, 1).squeeze(1)

# Initialize Conditional GAN
n_classes = 10  # For MNIST digits 0-9
cG = ConditionalGenerator(n_classes=n_classes).to(device)
cD = ConditionalDiscriminator(n_classes=n_classes).to(device)

# Initialize weights
cG.apply(weights_init)
cD.apply(weights_init)

print("✅ Conditional GAN initialized")
print(f"Classes: {n_classes} (digits 0-9)")
print(f"Generator parameters: {sum(p.numel() for p in cG.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in cD.parameters()):,}")
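
Most of the label conditioning in the two classes above is shape bookkeeping. The NumPy sketch below mirrors it without PyTorch. It assumes a batch of 4, and it uses random tables (`gen_table` and `disc_table`, hypothetical names) in place of the learned `self.label_emb` and `self.label_embedding`. The generator appends a label vector to the noise as extra channels. The discriminator appends a label "image" as one extra channel.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, nz, n_classes, img_size = 4, 100, 10, 64
labels = np.array([0, 3, 7, 9])  # one class index per sample

# Generator side: random (n_classes x n_classes) table standing in for nn.Embedding(n_classes, n_classes)
gen_table = rng.normal(size=(n_classes, n_classes))
noise = rng.normal(size=(batch, nz, 1, 1))
label_vec = gen_table[labels][:, :, None, None]          # (4, 10, 1, 1), like unsqueeze(2).unsqueeze(3)
gen_input = np.concatenate([noise, label_vec], axis=1)   # (4, 110, 1, 1): nz + n_classes channels

# Discriminator side: random (n_classes x 64*64) table standing in for nn.Embedding(n_classes, 64*64)
disc_table = rng.normal(size=(n_classes, img_size * img_size))
images = rng.normal(size=(batch, 1, img_size, img_size))
label_map = disc_table[labels].reshape(-1, 1, img_size, img_size)  # (4, 1, 64, 64)
d_input = np.concatenate([images, label_map], axis=1)               # (4, 2, 64, 64): image + label channel

print(gen_input.shape, d_input.shape)
```

The label enters as extra input channels, so the first layer of each network must accept them: `nz + n_classes` channels in G and `n_channels + 1` in D. This matches what the constructors above declare.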

In [None]:
def generate_conditional_samples(generator, n_samples=100, n_classes=10, nz=100):
    """Generate samples for each class using conditional GAN"""
    
    samples_per_class = n_samples // n_classes
    
    all_samples = []
    all_labels = []
    
    with torch.no_grad():
        for class_idx in range(n_classes):
            # Create labels for this class
            labels = torch.full((samples_per_class,), class_idx, dtype=torch.long, device=device)
            
            # Generate noise
            noise = torch.randn(samples_per_class, nz, 1, 1, device=device)
            
            # Generate samples
            fake_images = generator(noise, labels)
            
            all_samples.append(fake_images)
            all_labels.extend([class_idx] * samples_per_class)
    
    # Concatenate all samples
    all_samples = torch.cat(all_samples, dim=0)
    
    # Visualization
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    axes = axes.flatten()
    
    for class_idx in range(n_classes):
        # Get samples for this class
        class_samples = all_samples[class_idx * samples_per_class:(class_idx + 1) * samples_per_class]
        
        # Create grid for visualization
        grid = make_grid(class_samples[:8], nrow=4, normalize=True, padding=2)
        
        # Display
        axes[class_idx].imshow(grid.permute(1, 2, 0).cpu().numpy())  # CHW -> HWC; make_grid returns 3 channels
        axes[class_idx].set_title(f'Class {class_idx}')
        axes[class_idx].axis('off')
    
    plt.suptitle('Conditional GAN: Generated Samples by Class', fontsize=16)
    plt.tight_layout()
    plt.show()
    
    print("\n📊 Conditional Generation Summary:")
    print(f"Total samples generated: {len(all_samples)}")
    print(f"Samples per class: {samples_per_class}")
    print(f"Classes: {list(range(n_classes))}")

# Generate conditional samples. cG has not been trained yet, so expect noise-like images.
print("🎨 Generating conditional samples...")
generate_conditional_samples(cG, n_samples=80, n_classes=10)

---
## Exercise 8: GAN Evaluation Metrics 📏

### Concept
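Evaluating GANs is challenging because the adversarial losses do not directly measure how realistic or varied the samples are. Common metrics include:

1. **Inception Score (IS)**: Uses a pre-trained Inception-v3 classifier. It rewards samples that are classified confidently (quality) and that collectively cover many classes (diversity). Higher is better.
2. **Fréchet Inception Distance (FID)**: Compares the mean and covariance of Inception-v3 features for real and generated images. Lower is better.
3. **Precision & Recall**: Precision measures how many generated samples resemble real data (quality). Recall measures how much of the real distribution the generator covers (coverage).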
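
The two Inception-based metrics are defined below. Here $p(y \mid x)$ is the Inception-v3 class distribution for a generated image $x$, and $p(y)$ is its average over the generated images. The pairs $(\mu_r, \Sigma_r)$ and $(\mu_g, \Sigma_g)$ are the mean and covariance of Inception features for real and generated images.

$$\mathrm{IS} = \exp\Big(\mathbb{E}_{x \sim p_g}\big[ D_{\mathrm{KL}}\big(p(y \mid x)\,\|\,p(y)\big)\big]\Big)$$

$$\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\Big(\Sigma_r + \Sigma_g - 2\big(\Sigma_r \Sigma_g\big)^{1/2}\Big)$$

With $K$ classes, IS lies between 1 and $K$. It equals 1 when every image receives the same predicted distribution. It equals $K$ when every prediction is confident and the $K$ classes are used equally often. As a worked check on FID, consider one-dimensional Gaussians $\mathcal{N}(\mu_1, \sigma_1^2)$ and $\mathcal{N}(\mu_2, \sigma_2^2)$. For these, the formula reduces to

$$\mathrm{FID} = (\mu_1 - \mu_2)^2 + \sigma_1^2 + \sigma_2^2 - 2\sigma_1\sigma_2 = (\mu_1 - \mu_2)^2 + (\sigma_1 - \sigma_2)^2,$$

so it is zero only when both the means and the spreads match.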


In [None]:
def calculate_inception_score(images, splits=10, n_classes=10, seed=42):
    """Simulated Inception Score (demonstration only).

    A real IS feeds `images` through a pre-trained Inception-v3 model to obtain
    p(y|x). Here the class probabilities are synthetic and only len(images) is
    used, so the score does NOT reflect the actual image content.
    """
    N = len(images)
    rng = np.random.default_rng(seed)  # local RNG: leaves the global NumPy seed untouched

    # Synthetic p(y|x): one dominant class per image, with confidence in [0.5, 1.0)
    quality_factor = rng.uniform(0.5, 1.0, N)
    probs = np.zeros((N, n_classes))
    for i in range(N):
        dominant_class = rng.integers(0, n_classes)
        probs[i, dominant_class] = quality_factor[i]
        # Spread the remaining probability mass over the other classes
        other_classes = [j for j in range(n_classes) if j != dominant_class]
        probs[i, other_classes] = rng.dirichlet(np.ones(n_classes - 1)) * (1 - quality_factor[i])

    # IS = exp(E_x[KL(p(y|x) || p(y))]), computed separately on each split
    split_scores = []
    split_size = N // splits
    for k in range(splits):
        part = probs[k * split_size:(k + 1) * split_size]
        py = np.mean(part, axis=0)  # marginal p(y) over this split
        kl = np.sum(part * (np.log(part + 1e-10) - np.log(py + 1e-10)), axis=1)
        split_scores.append(np.exp(np.mean(kl)))  # exponentiate the MEAN KL, not each KL

    return np.mean(split_scores), np.std(split_scores)

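def calculate_fid_score(real_features, fake_features):
    """Calculate the Fréchet distance between two sets of feature vectors.

    FID = ||mu_r - mu_f||^2 + Tr(Sigma_r + Sigma_f - 2 (Sigma_r Sigma_f)^(1/2))
    For a genuine FID, the features must be Inception-v3 pool activations.
    """
    # Mean and covariance of each feature set (rows = samples)
    mu_real = np.mean(real_features, axis=0)
    mu_fake = np.mean(fake_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)
    sigma_fake = np.cov(fake_features, rowvar=False)

    diff = mu_real - mu_fake

    # Tr((Sigma_r Sigma_f)^(1/2)) needs a MATRIX square root; an element-wise
    # np.sqrt is wrong and can return NaN. Sigma_r^(1/2) Sigma_f Sigma_r^(1/2)
    # is symmetric PSD with the same eigenvalues as Sigma_r Sigma_f, so the
    # trace is the sum of the square roots of its eigenvalues.
    w, v = np.linalg.eigh(sigma_real)
    sqrt_real = (v * np.sqrt(np.clip(w, 0, None))) @ v.T
    eigvals = np.linalg.eigvalsh(sqrt_real @ sigma_fake @ sqrt_real)
    tr_covmean = np.sum(np.sqrt(np.clip(eigvals, 0, None)))  # clip tiny negative round-off

    fid = diff @ diff + np.trace(sigma_real) + np.trace(sigma_fake) - 2 * tr_covmean
    return float(fid)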

# Evaluate GAN metrics
def evaluate_gan_metrics(generator, n_samples=1000, nz=100):
    """Toy GAN evaluation for demonstration.

    The IS uses simulated class probabilities, and the FID compares raw pixels
    with Gaussian noise. Neither value is comparable to published IS/FID
    numbers, which require Inception-v3 predictions and features.
    """
    
    print("📊 Evaluating GAN Metrics...\n")
    
    # Generate samples in eval mode so BatchNorm uses its running statistics
    was_training = generator.training
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(n_samples, nz, 1, 1, device=device)
        fake_images = generator(noise).cpu().numpy()
    generator.train(was_training)  # restore the previous mode
    
    # Flattened pixels stand in for Inception features
    fake_features = fake_images.reshape(n_samples, -1)
    
    # Stand-in "real" features: standard Gaussian noise, NOT features of real images
    real_features = np.random.normal(0, 1, fake_features.shape)
    
    # Calculate (simulated) Inception Score
    is_mean, is_std = calculate_inception_score(fake_images)
    print(f"📈 Inception Score: {is_mean:.2f} ± {is_std:.2f}")
    print("   (Higher is better; with 10 classes the score ranges from 1 to 10)")
    
    # Calculate FID on the first 100 pixels only, to keep the covariance matrices small
    fid = calculate_fid_score(real_features[:, :100], fake_features[:, :100])
    print(f"\n📉 FID Score: {fid:.2f}")
    print("   (Lower is better; 0 means identical statistics, with no fixed upper bound)")
    
    # Visualize metrics
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    # IS visualization: our score (± std over splits) against the 10-class maximum
    is_labels = ['Our Model', 'Maximum\n(10 classes)']
    is_values = [is_mean, 10]
    colors = ['#FF9800', '#9E9E9E']
    
    bars = ax1.bar(is_labels, is_values, yerr=[is_std, 0], capsize=5, color=colors, alpha=0.7)
    ax1.set_ylabel('Score')
    ax1.set_title('Inception Score (simulated)')
    ax1.grid(True, alpha=0.3, axis='y')
    
    # Add value labels
    for bar, val in zip(bars, is_values):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height,
                f'{val:.2f}', ha='center', va='bottom')
    
    # FID visualization. The reference bars are rough, illustrative values for
    # Inception-feature FID on image benchmarks. They are not directly
    # comparable to this toy pixel-space FID.
    categories = ['FID\n(Our Model)', 'Good GAN\n(illustrative)', 'Poor GAN\n(illustrative)']
    fid_values = [fid, 20, 80]
    colors = ['#FF6B6B', '#4CAF50', '#FF9800']
    
    bars = ax2.bar(categories, fid_values, color=colors, alpha=0.7)
    ax2.set_ylabel('FID Score (lower is better)')
    ax2.set_title('FID Score Comparison')
    ax2.grid(True, alpha=0.3, axis='y')
    
    # Add value labels
    for bar, val in zip(bars, fid_values):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height,
                f'{val:.1f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()
    
    return is_mean, fid

# Evaluate the trained DCGAN
is_score, fid_score = evaluate_gan_metrics(netG, n_samples=500)

---
# Part 5: Summary and Practice Exercises 🎯

## Key Takeaways

### 1. **GAN Fundamentals**
- Minimax game between Generator and Discriminator
- Generator learns to map noise to data distribution
- Discriminator learns to distinguish real from fake

### 2. **Training Challenges**
- Mode collapse: Generator produces limited variety
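- Vanishing gradients: when D confidently rejects fakes, the original generator loss log(1 - D(G(z))) saturates and G receives almost no gradient
- Training instability: Oscillating losses
- No clear stopping criterion: the losses do not reliably track sample quality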
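
The vanishing-gradient problem can be seen numerically. Write $D(G(z)) = \sigma(a)$, where $a$ is the discriminator's logit for a generated sample. The saturating generator loss $\log(1 - D(G(z)))$ has gradient $-\sigma(a)$ with respect to $a$. The non-saturating alternative $-\log D(G(z))$ has gradient $-(1 - \sigma(a))$. The NumPy sketch below evaluates both as D grows more confident that a sample is fake. The logit values are chosen arbitrarily for illustration.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Discriminator logits for generated samples: very negative means "confidently fake"
logits = np.array([0.0, -2.0, -5.0, -10.0])
d_fake = sigmoid(logits)

# d/da of the saturating loss log(1 - sigmoid(a)) is -sigmoid(a)
grad_saturating = -d_fake
# d/da of the non-saturating loss -log(sigmoid(a)) is -(1 - sigmoid(a))
grad_non_saturating = -(1.0 - d_fake)

for a, d, gs, gn in zip(logits, d_fake, grad_saturating, grad_non_saturating):
    print(f"logit={a:6.1f}  D(G(z))={d:.5f}  saturating grad={gs:+.5f}  non-saturating grad={gn:+.5f}")
```

As D(G(z)) approaches 0, the saturating gradient shrinks toward zero. The non-saturating gradient instead approaches magnitude 1. This is why G is commonly trained to maximize log D(G(z)) instead, which is the change Exercise A asks for.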

### 3. **Improvements**
- **DCGAN**: Architectural guidelines for stability
- **WGAN**: Wasserstein distance, giving a critic loss that correlates with sample quality
- **Conditional GAN**: Controlled generation
- **Progressive GAN**: High-resolution generation

### 4. **Practical Tips**
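- Normalize real images to [-1, 1] to match the generator's Tanh output
- Use batch normalization, except in G's output layer and D's input layer
- ReLU in G (Tanh at the output), LeakyReLU (slope 0.2) in D
- Monitor loss curves and sample quality
- Use Adam with learning rate 0.0002 and β₁ = 0.5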
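
The first tip comes down to a linear rescaling. The sketch below assumes 8-bit pixels in [0, 255]. It maps images to [-1, 1] to match G's Tanh range, and it maps generated images back to [0, 1] for display. The helper names are hypothetical. In torchvision, `transforms.ToTensor()` followed by `transforms.Normalize((0.5,), (0.5,))` performs the same forward mapping.

```python
import numpy as np

def to_gan_range(images_uint8):
    """Map 8-bit pixels [0, 255] to [-1, 1], the range of a Tanh output (hypothetical helper)."""
    return images_uint8.astype(np.float32) / 127.5 - 1.0

def to_display_range(images):
    """Map [-1, 1] back to [0, 1] for plotting, clipping any overshoot (hypothetical helper)."""
    return np.clip((images + 1.0) / 2.0, 0.0, 1.0)

pixels = np.array([0, 64, 128, 255], dtype=np.uint8)
scaled = to_gan_range(pixels)
print(scaled)                     # endpoints map to -1 and 1
print(to_display_range(scaled))   # back in [0, 1]
```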


---
## 🏋️ Practice Exercises

### Exercise A: Implement Non-Saturating Loss
Modify the basic GAN to use non-saturating loss for the generator.

### Exercise B: Add Gradient Penalty
Implement WGAN-GP (Gradient Penalty) instead of weight clipping.
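
As a starting point, here is the WGAN-GP critic loss. It replaces weight clipping with a penalty on the norm of the critic's gradient at random interpolates $\hat{x}$ between a real sample $x \sim p_r$ and a generated sample $\tilde{x} \sim p_g$. Here $D$ is the critic, which outputs an unbounded score (no sigmoid):

$$L_D = \mathbb{E}_{\tilde{x} \sim p_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x \sim p_r}\big[D(x)\big] + \lambda\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big], \qquad \hat{x} = \epsilon x + (1 - \epsilon)\,\tilde{x}, \quad \epsilon \sim U[0, 1]$$

The WGAN-GP paper uses $\lambda = 10$. It also removes batch normalization from the critic, because the penalty is defined per input rather than per batch. The `gradient_penalty` stub below corresponds to the last term. In PyTorch, the gradient with respect to $\hat{x}$ can be obtained with `torch.autograd.grad(..., create_graph=True)`, so that the penalty itself can be backpropagated.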

### Exercise C: Build a Simple StyleGAN
Create a simplified version of StyleGAN with style injection.
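
The hint in the stub refers to AdaIN (Adaptive Instance Normalization). AdaIN normalizes each channel of each sample's feature map to zero mean and unit variance. It then applies a style-dependent scale $y_s$ and bias $y_b$: $\mathrm{AdaIN}(x, y) = y_s \frac{x - \mu(x)}{\sigma(x)} + y_b$. The NumPy sketch below assumes feature maps of shape `(N, C, H, W)` and style parameters of shape `(N, C)`. It leaves out how the style parameters are produced (in StyleGAN, a learned affine map of the intermediate latent).

```python
import numpy as np

def adain(x, style_scale, style_bias, eps=1e-5):
    """Adaptive Instance Normalization.

    x:           feature maps, shape (N, C, H, W)
    style_scale: per-sample, per-channel scale y_s, shape (N, C)
    style_bias:  per-sample, per-channel bias y_b, shape (N, C)
    """
    # Instance statistics: one mean/std per sample and channel, over H and W
    mean = x.mean(axis=(2, 3), keepdims=True)
    std = np.sqrt(x.var(axis=(2, 3), keepdims=True) + eps)
    normalized = (x - mean) / std
    # Broadcast the (N, C) style parameters over the spatial dimensions
    return style_scale[:, :, None, None] * normalized + style_bias[:, :, None, None]

rng = np.random.default_rng(0)
features = rng.normal(loc=3.0, scale=2.0, size=(2, 4, 8, 8))
scale = rng.uniform(0.5, 2.0, size=(2, 4))
bias = rng.normal(size=(2, 4))

out = adain(features, scale, bias)
# Each output channel now carries the style's statistics
print(np.allclose(out.mean(axis=(2, 3)), bias))
print(np.allclose(out.std(axis=(2, 3)), scale, atol=1e-3))
```

These checks confirm the defining property of AdaIN. After the operation, the style sets each channel's mean and standard deviation, whatever the statistics of the input features were.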

### Exercise D: Implement Spectral Normalization
Add spectral normalization to stabilize training.
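
Spectral normalization divides each discriminator weight matrix $W$ by its largest singular value $\sigma(W)$, which controls the discriminator's Lipschitz constant. $\sigma(W)$ is estimated by power iteration rather than a full SVD. The NumPy sketch below runs many iterations from a random start so the estimate can be compared with `np.linalg.svd`. During training, one iteration per step is typically used because the vector `u` is kept between steps. The random 64 × 32 matrix and the iteration count are arbitrary choices for illustration. In PyTorch, `torch.nn.utils.spectral_norm(module)` applies this normalization to a layer's weight.

```python
import numpy as np

def spectral_normalize(W, n_iters=500, seed=0):
    """Return (W / sigma, sigma), where sigma estimates the largest singular value of W."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=W.shape[0])  # random start vector in the output space
    for _ in range(n_iters):
        # One power-iteration step on W^T W, alternating between input and output space
        v = W.T @ u
        v /= np.linalg.norm(v)
        u = W @ v
        u /= np.linalg.norm(u)
    sigma = u @ W @ v  # equals ||W v||, which approaches sigma_max as v converges
    return W / sigma, sigma

rng = np.random.default_rng(1)
W = rng.normal(size=(64, 32))  # stand-in weight matrix (out_features x in_features)

W_sn, sigma_est = spectral_normalize(W)
sigma_true = np.linalg.svd(W, compute_uv=False)[0]
print(f"estimated sigma = {sigma_est:.4f}, exact sigma = {sigma_true:.4f}")
print(f"spectral norm after normalization = {np.linalg.svd(W_sn, compute_uv=False)[0]:.4f}")
```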


In [None]:
# Exercise A: Non-Saturating Loss
def train_gan_non_saturating():
    """
    TODO: Implement GAN with non-saturating loss
    Instead of minimizing log(1 - D(G(z))), maximize log(D(G(z)))
    """
    # Your code here
    pass

# Exercise B: Gradient Penalty
def gradient_penalty(discriminator, real_samples, fake_samples, device):
    """
    TODO: Calculate gradient penalty for WGAN-GP
    Hint: Interpolate between real and fake samples, calculate gradients
    """
    # Your code here
    pass

# Exercise C: Style Injection
class SimpleStyleGenerator(nn.Module):
    """
    TODO: Implement a generator with style injection
    Hint: Use AdaIN (Adaptive Instance Normalization)
    """
    def __init__(self):
        super().__init__()
        # Your code here
        pass
    
    def forward(self, z, style):
        # Your code here
        pass

print("📝 Complete the exercises above!")
print("Hints provided in the function docstrings.")

---
## 🎉 Congratulations!

You've completed the comprehensive GAN tutorial! You've learned:

✅ Mathematical foundations of GANs  
✅ Implementation from scratch  
✅ DCGAN architecture  
✅ WGAN improvements  
✅ Conditional generation  
✅ Evaluation metrics  
✅ Common problems and solutions  

### 🚀 Next Steps

1. **Experiment with different architectures**: Try ProGAN, StyleGAN, or BigGAN
2. **Apply to your domain**: Use GANs for your specific application
3. **Explore recent advances**: Diffusion models, DALL-E, etc.
4. **Read the original papers**: Deep dive into the theory

### 📚 Recommended Reading

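- [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661) (Goodfellow et al., 2014) - Original GAN paper
- [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434) (Radford et al., 2015) - DCGAN
- [Wasserstein GAN](https://arxiv.org/abs/1701.07875) (Arjovsky et al., 2017) - WGAN
- [Progressive Growing of GANs for Improved Quality, Stability, and Variation](https://arxiv.org/abs/1710.10196) (Karras et al., 2017) - Progressive GAN
- [A Style-Based Generator Architecture for Generative Adversarial Networks](https://arxiv.org/abs/1812.04948) (Karras et al., 2018) - StyleGAN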

---
**Thank you for learning with us!** 🙏

If you have questions or feedback, please reach out to:  
📧 homin.park@ghent.ac.kr
