# 09. Mini GAN

This notebook implements a minimal Generative Adversarial Network for synthetic data generation.

## Experiment Overview
- **Goal**: Train a Generative Adversarial Network (GAN) to generate synthetic samples that mimic a correlated 2D Gaussian distribution
- **Model**: Generator + Discriminator, each a small fully connected network (MLP)
- **Features**: Training dynamics, generated sample quality, loss curves
- **Learning**: Understanding adversarial training and generative models

## What You'll Learn
- GAN architecture and training
- Generator and discriminator design
- Adversarial training dynamics
- Generated sample quality assessment


In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

# Add scripts directory to path so the project-local `utils` module can be
# imported. The path is relative, so '../scripts' must exist relative to the
# directory this notebook is launched from.
sys.path.append('../scripts')
from utils import get_device, set_seed

# Set random seed for reproducibility
# NOTE(review): set_seed is defined in utils (not visible here); presumably it
# seeds the NumPy and PyTorch RNGs used below — confirm.
set_seed(42)

# Get device
# NOTE(review): get_device is defined in utils; presumably it returns a
# torch device (accelerator if available, otherwise CPU) — confirm.
device = get_device()
print(f"Using device: {device}")

# Generate real data (2D Gaussian)
def generate_real_data(n_samples=1000):
    """Draw samples from a zero-mean, correlated 2D Gaussian.

    Both components have unit variance and a covariance of 0.5. Sampling
    uses NumPy's global random state, so results depend on the current seed.

    Args:
        n_samples: Number of points to draw.

    Returns:
        Array of shape ``(n_samples, 2)``.
    """
    mean = [0, 0]
    covariance = [[1, 0.5], [0.5, 1]]
    samples = np.random.multivariate_normal(mean, covariance, size=n_samples)
    return samples

# Visualize real data
# `real_data` is kept as a module-level name: later cells reuse it for
# training and for the real-vs-generated comparison plots.
real_data = generate_real_data(1000)
plt.figure(figsize=(8, 6))
plt.scatter(real_data[:, 0], real_data[:, 1], alpha=0.6, c='blue')
plt.title('Real Data Distribution')
plt.xlabel('X1')
plt.ylabel('X2')
plt.grid(True)
# savefig does not create missing parent directories, so make sure the output
# folder exists before writing the figure.
os.makedirs('../results/plots', exist_ok=True)
plt.savefig('../results/plots/gan_real_data.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Define Generator and Discriminator
class Generator(nn.Module):
    """Fully connected generator that maps latent noise to data-space samples.

    The network is a three-layer MLP: two hidden layers of width
    ``hidden_dim`` with ReLU activations, followed by a linear output layer
    of width ``output_dim``.

    Args:
        noise_dim: Size of the input noise vector.
        output_dim: Size of each generated sample.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, noise_dim=2, output_dim=2, hidden_dim=64):
        super().__init__()
        layers = [
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        # The attribute name `net` and the layer order determine the
        # state_dict keys (net.0.weight, ...), so both are kept as-is.
        self.net = nn.Sequential(*layers)

    def forward(self, noise):
        """Return generated samples for a batch of noise vectors."""
        samples = self.net(noise)
        return samples

class Discriminator(nn.Module):
    """Fully connected discriminator that scores samples as real or fake.

    The network is a three-layer MLP: two hidden layers of width
    ``hidden_dim`` with ReLU activations, then a single output unit passed
    through a sigmoid, so each score lies in ``[0, 1]``.

    Args:
        input_dim: Size of each input sample.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, input_dim=2, hidden_dim=64):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        # The attribute name `net` and the layer order determine the
        # state_dict keys (net.0.weight, ...), so both are kept as-is.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return one sigmoid score per input row, shape ``(batch, 1)``."""
        scores = self.net(x)
        return scores

# Create models and move them to the selected device
generator = Generator().to(device)
discriminator = Discriminator().to(device)


def _print_model_summary(heading, model):
    """Print a heading, the module layout, and its total parameter count."""
    print(heading)
    print(model)
    param_total = 0
    for param in model.parameters():
        param_total += param.numel()
    print(f"Parameters: {param_total:,}")


_print_model_summary("Generator:", generator)
_print_model_summary("\nDiscriminator:", discriminator)

# GAN Training function
def train_gan(generator, discriminator, real_data, epochs=1000, lr=0.0002, batch_size=64, noise_dim=2):
    """Train a generator/discriminator pair with the standard GAN objective.

    Each iteration performs one discriminator update followed by one
    generator update on a freshly sampled mini-batch, so ``epochs`` counts
    optimisation steps rather than full passes over ``real_data``.

    Args:
        generator: Module mapping noise of shape ``(batch_size, noise_dim)``
            to samples with the same feature size as ``real_data``.
        discriminator: Module mapping samples to probabilities of shape
            ``(batch_size, 1)``; binary cross-entropy is applied directly to
            its output, so it must already lie in ``[0, 1]``.
        real_data: Array-like of real samples, indexed along axis 0.
        epochs: Number of training iterations.
        lr: Learning rate used for both Adam optimizers.
        batch_size: Number of real and of fake samples per update. Real
            samples are drawn with replacement.
        noise_dim: Size of the latent noise vector fed to the generator.
            Defaults to 2, the value that was previously hard-coded.

    Returns:
        Tuple ``(g_losses, d_losses)`` with one Python float per iteration.
    """
    g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    # Run on whatever device the generator already lives on instead of relying
    # on a module-level `device` variable being defined.
    train_device = next(generator.parameters()).device

    real_array = np.asarray(real_data)

    # The target labels never change, so build them once outside the loop.
    real_labels = torch.ones(batch_size, 1, device=train_device)
    fake_labels = torch.zeros(batch_size, 1, device=train_device)

    g_losses = []
    d_losses = []

    for epoch in range(epochs):
        # --- Discriminator step ---
        d_optimizer.zero_grad()

        # Real mini-batch, sampled with replacement via NumPy's global RNG.
        batch_idx = np.random.choice(len(real_array), batch_size)
        real_batch = torch.as_tensor(
            real_array[batch_idx], dtype=torch.float32, device=train_device
        )
        real_output = discriminator(real_batch)
        d_real_loss = F.binary_cross_entropy(real_output, real_labels)

        # Fake mini-batch; detach so this step does not backpropagate into
        # the generator.
        noise = torch.randn(batch_size, noise_dim).to(train_device)
        fake_batch = generator(noise)
        fake_output = discriminator(fake_batch.detach())
        d_fake_loss = F.binary_cross_entropy(fake_output, fake_labels)

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # --- Generator step ---
        g_optimizer.zero_grad()
        noise = torch.randn(batch_size, noise_dim).to(train_device)
        fake_batch = generator(noise)
        fake_output = discriminator(fake_batch)
        # Fakes are scored against the "real" labels: the generator is
        # rewarded when the discriminator is fooled.
        g_loss = F.binary_cross_entropy(fake_output, real_labels)
        g_loss.backward()
        g_optimizer.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

        if (epoch + 1) % 100 == 0:
            print(f'Epoch {epoch+1}/{epochs}, G Loss: {g_loss.item():.4f}, D Loss: {d_loss.item():.4f}')

    return g_losses, d_losses

# Train GAN
print("Training GAN...")
g_losses, d_losses = train_gan(generator, discriminator, real_data, epochs=1000)

# Plot training losses: linear scale on the left, log scale on the right
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(g_losses, label='Generator Loss')
plt.plot(d_losses, label='Discriminator Loss')
plt.title('GAN Training Losses')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(g_losses, label='Generator Loss')
plt.plot(d_losses, label='Discriminator Loss')
plt.title('GAN Training Losses (Log Scale)')
plt.xlabel('Epoch')
plt.ylabel('Loss (log)')
plt.yscale('log')
plt.legend()
plt.grid(True)

plt.tight_layout()
# savefig does not create missing parent directories, so make sure the output
# folder exists before writing the figure.
os.makedirs('../results/plots', exist_ok=True)
plt.savefig('../results/plots/gan_training.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Generate and visualize fake data
def generate_fake_data(generator, n_samples=1000, noise_dim=2):
    """Sample synthetic data points from a generator.

    The generator is switched to eval mode and is left in that mode when the
    function returns. Sampling runs without gradient tracking.

    Args:
        generator: Module mapping noise of shape ``(n_samples, noise_dim)``
            to samples.
        n_samples: Number of samples to generate.
        noise_dim: Size of the latent noise vector fed to the generator.
            Defaults to 2, the value that was previously hard-coded.

    Returns:
        NumPy array holding the generator output for ``n_samples`` noise
        vectors, moved to the CPU.
    """
    # Use the device the generator lives on instead of relying on a
    # module-level `device` variable; parameter-free modules fall back to CPU.
    first_param = next(generator.parameters(), None)
    sample_device = first_param.device if first_param is not None else torch.device('cpu')

    generator.eval()
    with torch.no_grad():
        noise = torch.randn(n_samples, noise_dim).to(sample_device)
        fake_data = generator(noise).cpu().numpy()
    return fake_data

# Generate fake data
fake_data = generate_fake_data(generator, 1000)

# Visualize real vs fake data: real only, generated only, then both overlaid
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.scatter(real_data[:, 0], real_data[:, 1], alpha=0.6, c='blue', label='Real')
plt.title('Real Data')
plt.xlabel('X1')
plt.ylabel('X2')
plt.legend()
plt.grid(True)

plt.subplot(1, 3, 2)
plt.scatter(fake_data[:, 0], fake_data[:, 1], alpha=0.6, c='red', label='Fake')
plt.title('Generated Data')
plt.xlabel('X1')
plt.ylabel('X2')
plt.legend()
plt.grid(True)

plt.subplot(1, 3, 3)
plt.scatter(real_data[:, 0], real_data[:, 1], alpha=0.6, c='blue', label='Real')
plt.scatter(fake_data[:, 0], fake_data[:, 1], alpha=0.6, c='red', label='Fake')
plt.title('Real vs Generated')
plt.xlabel('X1')
plt.ylabel('X2')
plt.legend()
plt.grid(True)

plt.tight_layout()
# savefig does not create missing parent directories, so make sure the output
# folder exists before writing the figure.
os.makedirs('../results/plots', exist_ok=True)
plt.savefig('../results/plots/gan_results.png', dpi=300, bbox_inches='tight')
plt.show()

# Save models
# torch.save does not create missing parent directories, so make sure the
# log folder exists before writing the checkpoints.
os.makedirs('../results/logs', exist_ok=True)
# Only the state_dicts (parameter tensors) are stored; loading them requires
# re-creating Generator/Discriminator with the same constructor arguments.
torch.save(generator.state_dict(), '../results/logs/generator.pth')
torch.save(discriminator.state_dict(), '../results/logs/discriminator.pth')

print("GAN training completed and models saved!")
