# Phase 10: Generative Models
## Part 2: Generative Adversarial Networks (GAN)

### In this notebook:

1. **GAN architecture** - Generator vs Discriminator
2. **Adversarial training** - the minimax game
3. **Mode collapse** and remedies
4. **DCGAN** - Deep Convolutional GAN
5. **Practical tips**

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
import warnings
warnings.filterwarnings('ignore')

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

## 1. GAN Concept

### Two players:

- **Generator (G)**: maps noise z to fake samples
- **Discriminator (D)**: estimates the probability that its input is real rather than generated

### Minimax Game:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$
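For a fixed generator, the inner maximization has a closed-form solution, $D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}$, where $p_g$ is the distribution of $G(z)$. Substituting it back gives $V(D^*, G) = -\log 4 + 2\,\mathrm{JSD}(p_{data} \,\|\, p_g)$. The global optimum is therefore reached when $p_g = p_{data}$, and at that point $D^* \equiv 1/2$. The sketch below checks this numerically on small discrete distributions. The function names are illustrative and not part of the notebook.

```python
import math

def optimal_discriminator(p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)) on each point of a discrete support."""
    return [pd / (pd + pg) if pd + pg > 0 else 0.5 for pd, pg in zip(p_data, p_g)]

def value_function(p_data, p_g, d):
    """V(D, G) = E_{x~p_data}[log D(x)] + E_{x~p_g}[log(1 - D(x))]."""
    real_term = sum(pd * math.log(dx) for pd, dx in zip(p_data, d) if pd > 0)
    fake_term = sum(pg * math.log(1 - dx) for pg, dx in zip(p_g, d) if pg > 0)
    return real_term + fake_term

def jsd(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]

    def kl(a, b):
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p_data = [0.25, 0.25, 0.25, 0.25]
for p_g in ([0.25, 0.25, 0.25, 0.25], [0.7, 0.1, 0.1, 0.1]):
    d_star = optimal_discriminator(p_data, p_g)
    v = value_function(p_data, p_g, d_star)
    print(f"V(D*, G) = {v:.4f}, -log 4 + 2*JSD = {-math.log(4) + 2 * jsd(p_data, p_g):.4f}")
```

When $p_g = p_{data}$, both columns equal $-\log 4 \approx -1.3863$. Any mismatch between the distributions raises the value above that.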

## 2. Data

In [None]:
def generate_patterns(n_samples=2000):
    """Generate simple 8x8 patterns (4 types) with small Gaussian noise, clipped to [0, 1]"""
    patterns = []
    
    for _ in range(n_samples):
        img = np.zeros((8, 8))
        pattern_type = np.random.randint(4)
        
        if pattern_type == 0:  # Horizontal line
            row = np.random.randint(1, 7)
            img[row, 1:7] = 1
        elif pattern_type == 1:  # Vertical line
            col = np.random.randint(1, 7)
            img[1:7, col] = 1
        elif pattern_type == 2:  # Filled square with its top-left corner on the main diagonal
            size = np.random.randint(2, 4)
            start = np.random.randint(1, 6-size)
            img[start:start+size, start:start+size] = 1
        else:  # Main diagonal (always the same 6 pixels)
            for i in range(6):
                img[i+1, i+1] = 1
        
        img += np.random.normal(0, 0.05, (8, 8))
        img = np.clip(img, 0, 1)
        patterns.append(img)
    
    return np.array(patterns)

# Generate the dataset as a (N, 1, 8, 8) float32 tensor
X = generate_patterns(2000)
X = torch.tensor(X, dtype=torch.float32).view(-1, 1, 8, 8)

dataset = TensorDataset(X)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

print(f'Data shape: {X.shape}')

## 3. Generator and Discriminator

In [None]:
class Generator(nn.Module):
    """Generator: z -> image"""
    
    def __init__(self, latent_dim=16, img_size=64):
        super().__init__()
        # img_size is the number of pixels: 64 -> an 8x8 image
        self.side = int(img_size ** 0.5)
        
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 32),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(32),
            
            nn.Linear(32, 64),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(64),
            
            nn.Linear(64, img_size),
            nn.Sigmoid()  # outputs in [0, 1], the same range as the training data
        )
    
    def forward(self, z):
        img = self.model(z)
        return img.view(-1, 1, self.side, self.side)

class Discriminator(nn.Module):
    """Discriminator: image -> real/fake probability"""
    
    def __init__(self, img_size=64):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(img_size, 64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(64, 32),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(32, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img):
        return self.model(img)

# Shape check on a small batch
latent_dim = 16
G = Generator(latent_dim)
D = Discriminator()

z = torch.randn(4, latent_dim)
fake_img = G(z)
pred = D(fake_img)

print(f'Noise: {z.shape}')
print(f'Generated: {fake_img.shape}')
print(f'Discriminator output: {pred.shape}')
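Both networks are small. The count below is a worked check of how many trainable parameters each has with `latent_dim=16` and `img_size=64`. A `Linear(n_in, n_out)` layer has `n_in * n_out` weights plus `n_out` biases. A `BatchNorm1d(n)` layer has `n` scale and `n` shift parameters; its running statistics are buffers, not parameters. The helper names are illustrative. With the real modules, the same totals should come from `sum(p.numel() for p in G.parameters())`.

```python
def linear_params(n_in, n_out):
    # weight matrix (n_out x n_in) plus bias vector (n_out)
    return n_in * n_out + n_out

def batchnorm_params(n):
    # learnable gamma and beta; running_mean/running_var are buffers
    return 2 * n

latent_dim, img_size = 16, 64

g_params = (linear_params(latent_dim, 32) + batchnorm_params(32)
            + linear_params(32, 64) + batchnorm_params(64)
            + linear_params(64, img_size))
# Flatten, LeakyReLU, Dropout and Sigmoid have no parameters
d_params = (linear_params(img_size, 64)
            + linear_params(64, 32)
            + linear_params(32, 1))

print(f"Generator: {g_params} parameters")      # Generator: 7008 parameters
print(f"Discriminator: {d_params} parameters")  # Discriminator: 6273 parameters
```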

## 4. GAN Training Loop

In [None]:
def train_gan(G, D, dataloader, epochs=200, latent_dim=16, lr=2e-4):
    """Train a GAN with alternating D and G updates; returns per-epoch metrics"""
    
    G = G.to(device)
    D = D.to(device)
    
    optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    
    criterion = nn.BCELoss()
    
    history = {'G_loss': [], 'D_loss': [], 'D_real': [], 'D_fake': []}
    
    for epoch in range(epochs):
        G_losses = []
        D_losses = []
        D_reals = []
        D_fakes = []
        
        for batch in dataloader:
            real_imgs = batch[0].to(device)
            batch_size = real_imgs.size(0)
            
            # Labels
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)
            
            # ---------------------
            # Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            
            # Real images
            real_pred = D(real_imgs)
            d_loss_real = criterion(real_pred, real_labels)
            
            # Fake images
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_imgs = G(z)
            fake_pred = D(fake_imgs.detach())
            d_loss_fake = criterion(fake_pred, fake_labels)
            
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_D.step()
            
            # ---------------------
            # Train Generator
            # ---------------------
            optimizer_G.zero_grad()
            
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_imgs = G(z)
            fake_pred = D(fake_imgs)
            
            # Non-saturating generator loss: G wants D to label fakes as real,
            # i.e. it minimizes -log D(G(z)) instead of log(1 - D(G(z)))
            g_loss = criterion(fake_pred, real_labels)
            g_loss.backward()
            optimizer_G.step()
            
            # Metrics
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())
            D_reals.append(real_pred.mean().item())
            D_fakes.append(fake_pred.mean().item())
        
        history['G_loss'].append(np.mean(G_losses))
        history['D_loss'].append(np.mean(D_losses))
        history['D_real'].append(np.mean(D_reals))
        history['D_fake'].append(np.mean(D_fakes))
        
        if (epoch + 1) % 50 == 0:
            print(f'Epoch {epoch+1}, G: {np.mean(G_losses):.4f}, '
                  f'D: {np.mean(D_losses):.4f}, '
                  f'D(real): {np.mean(D_reals):.2f}, D(fake): {np.mean(D_fakes):.2f}')
    
    return history

# Training
latent_dim = 16
G = Generator(latent_dim)
D = Discriminator()

print('Training GAN...\n')
history = train_gan(G, D, dataloader, epochs=200, latent_dim=latent_dim)
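The printed losses are easier to read against reference values. With `nn.BCELoss`, the per-sample discriminator loss in `train_gan` is $-\log D(x) - \log(1 - D(G(z)))$, and the generator loss is $-\log D(G(z))$. If D outputs 0.5 everywhere (the theoretical equilibrium), these equal $2\log 2 \approx 1.386$ and $\log 2 \approx 0.693$. The sketch evaluates both losses for single probabilities. `gan_losses` is a hypothetical helper, not part of the notebook.

```python
import math

def bce(p, target):
    """Binary cross-entropy for one prediction p in (0, 1), as nn.BCELoss computes it per element."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

def gan_losses(d_real, d_fake):
    """Per-sample losses mirroring train_gan: d_loss_real + d_loss_fake, and g_loss."""
    d_loss = bce(d_real, 1.0) + bce(d_fake, 0.0)
    g_loss = bce(d_fake, 1.0)  # the generator is trained with "real" labels on fakes
    return d_loss, g_loss

# Equilibrium: D cannot tell real from fake
print([round(x, 4) for x in gan_losses(0.5, 0.5)])  # [1.3863, 0.6931]
# D is winning: confident on both real and fake samples
print([round(x, 4) for x in gan_losses(0.9, 0.1)])  # [0.2107, 2.3026]
```

A D loss drifting toward 0 together with a growing G loss matches the second case: the discriminator is overpowering the generator.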

In [None]:
# Training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Losses
axes[0].plot(history['G_loss'], label='Generator')
axes[0].plot(history['D_loss'], label='Discriminator')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('GAN Training Losses')
axes[0].legend()

# Discriminator outputs
axes[1].plot(history['D_real'], label='D(real)')
axes[1].plot(history['D_fake'], label='D(fake)')
axes[1].axhline(y=0.5, color='gray', linestyle='--')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Probability')
axes[1].set_title('Discriminator Outputs')
axes[1].legend()

plt.tight_layout()
plt.show()

## 5. Generating Samples

In [None]:
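# Sample 16 latent vectors and decode them
G.eval()  # BatchNorm uses its running statistics at inference
with torch.no_grad():
    z = torch.randn(16, latent_dim).to(device)
    generated = G(z).cpu()

# Plot a 2x8 grid; fixed vmin/vmax so every image uses the same [0, 1] scale
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for i in range(16):
    ax = axes[i//8, i%8]
    ax.imshow(generated[i, 0], cmap='gray', vmin=0, vmax=1)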
    ax.axis('off')

plt.suptitle('GAN Generated Samples')
plt.show()

## 6. Real vs Generated

In [None]:
# Side-by-side comparison
fig, axes = plt.subplots(2, 8, figsize=(12, 3))

# Real (first row)
for i in range(8):
    axes[0, i].imshow(X[i, 0], cmap='gray', vmin=0, vmax=1)
    # Hide ticks but keep the axes: axis('off') would also hide the row label
    axes[0, i].set_xticks([])
    axes[0, i].set_yticks([])
axes[0, 0].set_ylabel('Real', fontsize=12)

# Generated (second row)
G.eval()
with torch.no_grad():
    z = torch.randn(8, latent_dim).to(device)
    fake = G(z).cpu()

for i in range(8):
    axes[1, i].imshow(fake[i, 0], cmap='gray', vmin=0, vmax=1)
    axes[1, i].set_xticks([])
    axes[1, i].set_yticks([])
axes[1, 0].set_ylabel('Fake', fontsize=12)

plt.suptitle('Real vs Generated')
plt.tight_layout()
plt.show()
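Side-by-side plots are subjective. Below are two cheap numeric checks, sketched in NumPy. The sketch assumes real and generated images are flattened into arrays of shape `(n, 64)`.

- **Memorization:** the distance from each generated sample to its nearest training image. Values near 0 suggest the generator copies training data.
- **Diversity:** the mean pairwise distance within a batch. A value much lower for generated samples than for real ones hints at mode collapse.

The function names are hypothetical. With the notebook's tensors, `X.view(len(X), -1).numpy()` and `fake.view(len(fake), -1).numpy()` should provide the inputs.

```python
import numpy as np

def pairwise_sq_dists(a, b):
    """Squared Euclidean distances between rows of a (n, d) and rows of b (m, d)."""
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)

def nearest_real_distance(fake, real):
    """For each generated sample, Euclidean distance to the closest real sample."""
    return np.sqrt(pairwise_sq_dists(fake, real).min(axis=1))

def mean_pairwise_distance(samples):
    """Average Euclidean distance over all distinct pairs within one batch."""
    d = np.sqrt(pairwise_sq_dists(samples, samples))
    n = len(samples)
    return d.sum() / (n * (n - 1))  # the zero diagonal is excluded by the divisor

# Toy arrays standing in for flattened 8x8 images
rng = np.random.default_rng(0)
real = rng.random((100, 64))
diverse_fake = rng.random((16, 64))
collapsed_fake = np.repeat(real[:1], 16, axis=0) + rng.normal(0, 0.01, (16, 64))

print(f"diversity: real={mean_pairwise_distance(real):.3f}, "
      f"diverse={mean_pairwise_distance(diverse_fake):.3f}, "
      f"collapsed={mean_pairwise_distance(collapsed_fake):.3f}")
# Collapsed samples are tiny perturbations of real[0], so they sit right next to it
print(f"max nearest-real distance (collapsed): {nearest_real_distance(collapsed_fake, real).max():.3f}")
```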

## 7. Latent Space Interpolation

In [None]:
# Linear interpolation between two random points in latent space
G.eval()
with torch.no_grad():
    z1 = torch.randn(1, latent_dim).to(device)
    z2 = torch.randn(1, latent_dim).to(device)
    
    n_steps = 8
    interpolations = []
    
    for alpha in np.linspace(0, 1, n_steps):
        z = (1 - alpha) * z1 + alpha * z2
        img = G(z)
        interpolations.append(img[0, 0].cpu())

# Plot the interpolation path
fig, axes = plt.subplots(1, n_steps, figsize=(12, 2))
for i, img in enumerate(interpolations):
    axes[i].imshow(img, cmap='gray', vmin=0, vmax=1)
    axes[i].axis('off')

plt.suptitle('Latent Space Interpolation')
plt.show()
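The loop above uses linear interpolation. With a Gaussian latent prior, the midpoint of two independent samples has a norm of roughly $1/\sqrt{2}$ of a typical sample's norm in high dimensions. Spherical linear interpolation (slerp) is a commonly used alternative that keeps intermediate points closer to the norm of the endpoints. Below is a NumPy sketch that assumes latent vectors are 1-D arrays. To use it in the loop, each `z` would become `torch.from_numpy(slerp(z1_np, z2_np, alpha)).float().unsqueeze(0).to(device)`, where `z1_np` and `z2_np` are hypothetical NumPy copies of `z1` and `z2`.

```python
import numpy as np

def slerp(z1, z2, t):
    """Spherical linear interpolation between latent vectors z1 and z2, t in [0, 1]."""
    cos_omega = np.dot(z1, z2) / (np.linalg.norm(z1) * np.linalg.norm(z2))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    if np.isclose(np.sin(omega), 0.0):
        # Nearly (anti)parallel vectors: fall back to linear interpolation
        return (1 - t) * z1 + t * z2
    return (np.sin((1 - t) * omega) * z1 + np.sin(t * omega) * z2) / np.sin(omega)

def lerp(z1, z2, t):
    """Linear interpolation, as in the notebook's loop."""
    return (1 - t) * z1 + t * z2

rng = np.random.default_rng(0)
z1, z2 = rng.standard_normal(16), rng.standard_normal(16)
for t in np.linspace(0, 1, 5):
    print(f"t={t:.2f}  |lerp|={np.linalg.norm(lerp(z1, z2, t)):.2f}  "
          f"|slerp|={np.linalg.norm(slerp(z1, z2, t)):.2f}")
```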

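## 8. Practical Tips

### Common GAN problems:

1. **Mode collapse** - the generator maps many different z to a few similar outputs and covers only part of the data distribution (here, for example, only some of the four pattern types)
2. **Training instability** - G and D must stay roughly balanced; if one dominates, training oscillates instead of converging
3. **Vanishing gradients** - if D becomes too accurate, $D(G(z)) \approx 0$ and the minimax term $\log(1 - D(G(z)))$ gives G almost no gradient

### Remedies: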
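Item 3 follows from the form of the minimax generator loss. Write $D(G(z)) = \sigma(a)$ with logit $a$.

- The minimax loss $\log(1 - \sigma(a))$ has derivative $-\sigma(a)$ with respect to $a$. This vanishes when D confidently rejects fakes ($\sigma(a) \to 0$).
- The non-saturating loss $-\log \sigma(a)$ has derivative $-(1 - \sigma(a))$, which stays close to $-1$ in the same regime. This is the loss `train_gan` uses, by training G with `real_labels`.

A pure-Python sketch comparing the two closed-form derivatives:

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_minimax(a):
    """d/da of log(1 - sigmoid(a)): the generator term of the minimax objective."""
    return -sigmoid(a)

def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)): the loss train_gan optimizes for G."""
    return -(1.0 - sigmoid(a))

for a in [-8.0, -4.0, 0.0]:  # a very negative logit means D is sure the sample is fake
    print(f"D(G(z))={sigmoid(a):.4f}  minimax grad={grad_minimax(a):+.4f}  "
          f"non-saturating grad={grad_non_saturating(a):+.4f}")
```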

| Problem | Remedy |
|----------|----------|
| Mode Collapse | Mini-batch discrimination, Unrolled GAN |
| Instability | Spectral normalization, WGAN-GP |
| Vanishing gradients | Wasserstein loss |
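Spectral normalization divides each weight matrix by its largest singular value $\sigma(W)$, so every normalized linear map has spectral norm 1. The value $\sigma(W)$ is estimated cheaply by power iteration. Below is a NumPy sketch of that estimate, tested on a matrix built with known singular values. In PyTorch, `torch.nn.utils.spectral_norm` (or `torch.nn.utils.parametrizations.spectral_norm` in newer versions) wraps a layer with this normalization. By default it runs one power-iteration step per forward pass and reuses the vector from the previous step.

```python
import numpy as np

def spectral_norm_estimate(W, n_iters=100, seed=0):
    """Estimate the largest singular value of W by power iteration on W^T W."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(W.shape[0])
    for _ in range(n_iters):
        v = W.T @ u
        v /= np.linalg.norm(v)
        u = W @ v
        u /= np.linalg.norm(u)
    return float(u @ W @ v)

# Build a 6x4 matrix with known singular values 5, 2, 1 (and 0)
rng = np.random.default_rng(1)
U, _ = np.linalg.qr(rng.standard_normal((6, 3)))
V, _ = np.linalg.qr(rng.standard_normal((4, 3)))
W = U @ np.diag([5.0, 2.0, 1.0]) @ V.T

sigma = spectral_norm_estimate(W)
W_sn = W / sigma  # spectrally normalized weights
print(round(sigma, 4), round(np.linalg.norm(W_sn, 2), 4))  # 5.0 1.0
```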

### Best Practices:

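- LeakyReLU in D; ReLU in G with Tanh at the output (DCGAN guidelines; here the data lies in [0, 1], so G ends with Sigmoid instead)
- BatchNorm in G and D, except the output layer of G and the input layer of D
- Adam with beta1=0.5
- One-sided label smoothing: target 0.9 instead of 1.0 for real samples in the D loss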
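With a BCE target $y$, the loss $-(y \log p + (1-y)\log(1-p))$ is minimized at $p = y$. Smoothing the real target to 0.9 therefore stops rewarding D for pushing $D(x)$ all the way to 1. A pure-Python check of that minimizer on a grid:

```python
import math

def bce(p, target):
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

grid = [i / 1000 for i in range(1, 1000)]  # p in (0, 1), endpoints excluded
for target in (1.0, 0.9):
    best_p = min(grid, key=lambda p: bce(p, target))
    print(f"target={target}: loss minimized at p={best_p}, "
          f"loss at p=0.999: {bce(0.999, target):.3f}")
```

In `train_gan`, this would amount to building the discriminator's real targets as `torch.full((batch_size, 1), 0.9, device=device)`. The all-ones `real_labels` would stay in the generator loss, since only D's targets are smoothed.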

## Summary

### What we covered:

1. **GAN architecture** - Generator and Discriminator
2. **Adversarial training** - the minimax game
3. **Training dynamics** - balancing G and D
4. **Latent space** - interpolation

### Key formula:

$$\min_G \max_D \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

### GAN variants:

- **DCGAN** - convolutional GAN
- **WGAN** - Wasserstein distance
- **StyleGAN** - style-based generator
- **CycleGAN** - unpaired image translation
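The motivation for WGAN shows up on two point masses: real data at 0 and generated data at $\theta$. For any $\theta \neq 0$ the supports are disjoint. With an optimal D, the original objective amounts to minimizing the Jensen-Shannon divergence, which stays at $\log 2$ no matter how far apart the masses are. The Wasserstein-1 distance, by contrast, equals $|\theta|$. Below is a pure-Python sketch; for 1-D empirical samples of equal size, $W_1$ is the mean absolute difference of the sorted samples.

```python
import math

def wasserstein_1d(a, b):
    """W1 between two 1-D empirical distributions with equal sample counts."""
    if len(a) != len(b):
        raise ValueError("this shortcut assumes equal sample sizes")
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between discrete distributions."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]

    def kl(u, v):
        return sum(ui * math.log(ui / vi) for ui, vi in zip(u, v) if ui > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def js_point_masses(theta):
    """JSD between a point mass at 0 and a point mass at theta."""
    if theta == 0:
        return js_divergence([1.0], [1.0])        # identical distributions
    return js_divergence([1.0, 0.0], [0.0, 1.0])  # disjoint support {0, theta}

real = [0.0] * 100
for theta in [0.5, 1.0, 2.0, 4.0]:
    fake = [theta] * 100
    print(f"theta={theta}: W1={wasserstein_1d(real, fake):.2f}, JSD={js_point_masses(theta):.4f}")
```

W1 grows with the distance between the distributions and so carries a useful training signal. JSD stays flat at about 0.6931 for every nonzero $\theta$.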

### Next step:

In notebook 03 we study Diffusion Models, the current state of the art in generation.