## 1. Generative Adversarial Network (GAN)

Proposed in 2014 by Ian Goodfellow et al. Paper: [GAN](https://arxiv.org/pdf/1406.2661)

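Framed as a game, a GAN trains two neural networks, a generator and a discriminator, against each other. The setup mimics a contest between a forger and an appraiser, and it drives the generator toward producing realistic data.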

In [3]:
from IPython.display import Image, display
url = 'https://production-media.paperswithcode.com/methods/gan.jpeg'
display(Image(url=url, width=600))

## 2. Mathematical formulation

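GAN training alternates between the two networks: the discriminator is updated while the generator is held fixed, and then the roles swap.  
The approach is probabilistic throughout, with a latent space $z$ from which the generator's input is drawn.  
The KL divergence is not symmetric in its two arguments, so it is symmetrized into the Jensen-Shannon (JS) divergence, defined as: $D_{JS}=\frac{1}{2}(D_{KL}(P||\frac{P+Q}{2})+D_{KL}(Q||\frac{P+Q}{2}))$   
However, when the two distributions are far apart and their supports barely overlap, the JS divergence saturates at the constant $\log 2$, which means it provides no gradient.  
The Wasserstein distance, also known as the earth mover's distance, avoids this problem. It is the optimal value of a transport planning problem, defined as:
$$W_p(\mu,\nu):=\left(\mathop{\inf}_{\gamma\in\Gamma(\mu,\nu)}\int d(x,y)^p d\gamma(x,y)\right)^{1/p}$$
where $\Gamma(\mu,\nu)$ is the set of joint distributions whose marginals are $\mu$ and $\nu$.  
Computing this distance exactly in a high-dimensional space is still hard, because it amounts to a very large linear program. The alternative is to approximate it with a neural network.  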
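
To make the contrast between the two measures concrete, here is a minimal sketch under simplifying assumptions: both distributions are point masses on a unit-spaced 1-D grid, and the helper names `kl`, `js`, `w1`, and `point_mass` are illustrative and do not come from the notebook. As the second point mass moves away, the JS divergence stays at $\log 2$ while the Wasserstein-1 distance grows with the shift.

```python
import math

def kl(p, q):
    """KL divergence D_KL(p || q) for discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js(p, q):
    """Jensen-Shannon divergence: average KL of p and q to their mixture."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def w1(p, q):
    """Wasserstein-1 distance on a unit-spaced 1-D grid: sum of |CDF_p - CDF_q|."""
    total, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for pi, qi in zip(p, q):
        cdf_p += pi
        cdf_q += qi
        total += abs(cdf_p - cdf_q)
    return total

def point_mass(pos, n=10):
    """All probability on grid cell `pos` (illustrative helper)."""
    return [1.0 if i == pos else 0.0 for i in range(n)]

p = point_mass(0)
for shift in (1, 3, 7):
    q = point_mass(shift)
    # JS is log(2) ~ 0.6931 for every shift; W1 equals the shift itself.
    print(shift, round(js(p, q), 4), w1(p, q))
```

Because the JS value does not change with the shift, its gradient with respect to the shift is zero, whereas the Wasserstein distance still indicates which direction brings the distributions closer.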

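With the Wasserstein distance, the GAN objective becomes:
$$\mathop{\min}_G \mathop{\max}_{D\in 1-Lipschitz}\mathbb{E}_{x\sim p_{data}}[D(x)]-\mathbb{E}_{z\sim p_z}[D(G(z))]$$
The discriminator $D(x)$ (called the critic in the WGAN setting) must be a 1-Lipschitz function, that is, $$|D(x_2)-D(x_1)|\leqslant 1\cdot |x_2-x_1|$$
This condition bounds how fast $D$ can change with its input, so the norm of its gradient with respect to the input is at most 1. Bounding gradients to avoid exploding values is also the idea behind gradient clipping; see d2l [8.5.5. Gradient Clipping](https://zh.d2l.ai/chapter_recurrent-neural-networks/rnn-scratch.html)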
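
One common way to enforce the 1-Lipschitz condition softly is the gradient penalty of WGAN-GP, which is what the `compute_gradient_penalty` function and the `lambda_gp` hyperparameter in Section 3 implement. As a sketch in the notation above, and assuming interpolated points $\hat{x}=\alpha x+(1-\alpha)G(z)$ with $\alpha\sim U[0,1]$ and $\lambda$ standing for `lambda_gp`, the discriminator minimizes:

$$L_D=\mathbb{E}_{z\sim p_z}[D(G(z))]-\mathbb{E}_{x\sim p_{data}}[D(x)]+\lambda\,\mathbb{E}_{\hat{x}}\left[\left(\|\nabla_{\hat{x}}D(\hat{x})\|_2-1\right)^2\right]$$

and the generator minimizes:

$$L_G=-\mathbb{E}_{z\sim p_z}[D(G(z))]$$

The first two terms of $L_D$ are the negated objective from above, so minimizing $L_D$ over $D$ corresponds to the inner maximization, with the penalty replacing the hard constraint.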

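Later variants:

DCGAN: an early, influential architecture that brought deep convolutional networks to GANs and improved image quality. [DCGAN](https://arxiv.org/pdf/1511.06434)

Conditional GAN (cGAN): controls the generated content with a condition label (for example, asking for an image of a cat). [Conditional GAN (cGAN)](https://arxiv.org/pdf/1411.1784)

ProGAN: progressive training that starts at low resolution and grows step by step to high-resolution images. [ProGAN](https://arxiv.org/pdf/1710.10196)

BigGAN: large-scale training that produces diverse, high-resolution images. [BigGAN](https://arxiv.org/pdf/1809.11096)

StyleGAN: style-based control that yields finely detailed images (for example, skin pores and strands of hair on faces). [StyleGAN](https://arxiv.org/pdf/1812.04948)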

## 3. A simple implementation

In [27]:
import os
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import Dataset, DataLoader, Subset
from torchinfo import summary

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [28]:
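# Hyperparameters
batch_size = 64
latent_dim = 128  # not used below; the models read nz
epochs = 30
lr_D = 1e-4
lr_G = 1e-4
lambda_gp = 10  # weight of the gradient penalty term
n_critic = 5    # discriminator updates per generator update
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
nz = 100  # dimension of the noise vector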

In [29]:
# Custom dataset class that loads CelebA images from a folder
class CelebAFromFolder(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_filenames = sorted([
            fname for fname in os.listdir(image_dir)
            if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
        ])
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, 0  # return 0 as a placeholder label

In [30]:
# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
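    # Map pixels to [-1, 1], the range of the generator's Tanh output;
    # per-channel dataset statistics would push real pixels outside that range
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])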
])
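
`Normalize` computes `(x - mean) / std` per channel, and the generator defined later ends in `Tanh`, so its outputs lie in [-1, 1]. The sketch below, which uses an illustrative helper `normalized_range` that is not part of the notebook, compares the value range of real images under a symmetric 0.5/0.5 normalization with the range under per-channel dataset statistics such as mean 0.5068 and std 0.2971.

```python
def normalized_range(mean, std):
    """Range of (x - mean) / std when the input pixel x spans [0, 1]."""
    return (0.0 - mean) / std, (1.0 - mean) / std

# Symmetric choice: [0, 1] maps exactly onto [-1, 1], the range of Tanh.
print(normalized_range(0.5, 0.5))  # (-1.0, 1.0)

# Per-channel dataset statistics: the ranges extend beyond [-1, 1],
# so a Tanh output could never reproduce the darkest or brightest pixels.
means = [0.5068, 0.4254, 0.3832]
stds = [0.2971, 0.2789, 0.2815]
for mean, std in zip(means, stds):
    lo, hi = normalized_range(mean, std)
    print(f"[{lo:.2f}, {hi:.2f}]")
```

Whichever statistics are used, real and generated images should occupy the same value range, since the discriminator compares the two directly.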

In [31]:
# Instantiate the dataset
dataset = CelebAFromFolder('./data/img_align_celeba', transform=transform)
# Keep only the first 10000 images
subset_dataset = Subset(dataset, range(10000))

In [32]:
# Data loader
train_loader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=True)
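
As a quick arithmetic check on what this loader yields, the sketch below derives the number of batches per epoch and the size of the final batch from the values used above (10000 images, `batch_size = 64`, `n_critic = 5`). It assumes the `DataLoader` default of keeping the last partial batch; the variable names are illustrative.

```python
import math

num_images = 10_000  # size of the Subset above
batch_size = 64
n_critic = 5

# The last partial batch is kept by default, hence the ceiling.
batches_per_epoch = math.ceil(num_images / batch_size)
last_batch_size = num_images - (batches_per_epoch - 1) * batch_size
critic_steps_per_epoch = batches_per_epoch * n_critic

print(batches_per_epoch, last_batch_size, critic_steps_per_epoch)  # 157 16 785
```

The 157 batches match the `/157` counter in the training log, and the smaller final batch is why the training loop reads the batch size from `real_images.size(0)` instead of assuming 64.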

In [33]:
# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),  # [B, 64, 32, 32]
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1),  # [B, 128, 16, 16]
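            nn.InstanceNorm2d(128, affine=True),  # per-sample norm; batch norm would couple samples under the per-sample gradient penalty
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1),  # [B, 256, 8, 8]
            nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1),  # [B, 512, 4, 4]
            nn.InstanceNorm2d(512, affine=True),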
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(512, 1, 4, 1, 0) # [B, 1, 1, 1]
        )

    def forward(self, x):
        return self.model(x).view(-1)
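
The shape comments in the discriminator can be checked with the standard output-size formula for a convolution, `floor((in + 2*padding - kernel) / stride) + 1`. The sketch below assumes dilation 1 and uses an illustrative helper `conv2d_out`; the `(kernel, stride, padding)` triples are read off the `nn.Conv2d` calls above.

```python
def conv2d_out(size, kernel, stride, padding):
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) for the five Conv2d layers of the discriminator
layers = [(4, 2, 1)] * 4 + [(4, 1, 0)]

size = 64
sizes = [size]
for kernel, stride, padding in layers:
    size = conv2d_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [64, 32, 16, 8, 4, 1]
```

Each stride-2 layer halves the resolution, and the final 4x4 convolution without padding collapses the 4x4 feature map into a single score per image.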

In [34]:
# Generator
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(nz, 512, 4, 1, 0),  # [B, 512, 4, 4]
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            nn.ConvTranspose2d(512, 256, 4, 2, 1),  # [B, 256, 8, 8]
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # [B, 128, 16, 16]
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # [B, 64, 32, 32]
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.ConvTranspose2d(64, 3, 4, 2, 1),  # [B, 3, 64, 64]
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)
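
The generator mirrors the discriminator, and its shape comments follow from the transposed-convolution formula `(in - 1) * stride - 2*padding + kernel`. The sketch below assumes dilation 1 and no output padding, and uses an illustrative helper `conv_transpose2d_out`; the `(kernel, stride, padding)` triples are read off the `nn.ConvTranspose2d` calls above.

```python
def conv_transpose2d_out(size, kernel, stride, padding):
    """Spatial output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) for the five ConvTranspose2d layers of the generator
layers = [(4, 1, 0)] + [(4, 2, 1)] * 4

size = 1  # the noise tensor has shape [B, nz, 1, 1]
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_transpose2d_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [1, 4, 8, 16, 32, 64]
```

The first layer expands the 1x1 noise into a 4x4 map, and each stride-2 layer then doubles the resolution until it reaches the 64x64 size of the training images.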

In [35]:
# Initialize the models and optimizers
netG = Generator().to(device)
netD = Discriminator().to(device)

optimizerD = optim.Adam(netD.parameters(), lr=lr_D, betas=(0.0, 0.9))
optimizerG = optim.Adam(netG.parameters(), lr=lr_G, betas=(0.0, 0.9))

In [36]:
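def compute_gradient_penalty(D, real_samples, fake_samples):
    """WGAN-GP penalty: push the norm of dD/dx toward 1 on points between real and fake samples."""
    # One random interpolation coefficient per sample, broadcast over C, H, W
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)

    d_interpolates = D(interpolates)
    # Vector of ones, so that grad() differentiates the sum of the per-sample scores
    grad_outputs = torch.ones_like(d_interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,  # the penalty itself has to be differentiable
        retain_graph=True,
    )[0]

    # Flatten each sample's gradient and penalize the deviation of its L2 norm from 1
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty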
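
A quick way to sanity-check a gradient penalty is to feed it a critic whose input gradient is known in closed form. The sketch below is self-contained, so it restates the penalty as `gradient_penalty` (same computation as above, but using `reshape` and taking the device from its inputs). `LinearCritic` is a hypothetical toy model and not part of the notebook: for $D(x)=s\sum_i x_i$ over $n$ input elements the gradient norm is $s\sqrt{n}$ everywhere, so the penalty should be about 0 when $s=1/\sqrt{n}$ and about 1 when $s=2/\sqrt{n}$.

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    """Mean squared deviation of the critic's input-gradient norm from 1."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]
    grads = grads.reshape(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

class LinearCritic(nn.Module):
    """Hypothetical critic D(x) = scale * sum(x); its input gradient equals scale everywhere."""
    def __init__(self, scale):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(float(scale)))

    def forward(self, x):
        return self.scale * x.flatten(1).sum(dim=1)

torch.manual_seed(0)
real = torch.randn(8, 3, 4, 4)
fake = torch.randn(8, 3, 4, 4)
n = 3 * 4 * 4  # input elements per sample

gp_unit = gradient_penalty(LinearCritic(1 / n ** 0.5), real, fake)   # gradient norm 1
gp_steep = gradient_penalty(LinearCritic(2 / n ** 0.5), real, fake)  # gradient norm 2
print(gp_unit.item(), gp_steep.item())  # close to 0 and close to 1
```

Because the toy critic is linear, the result does not depend on where the interpolated points fall, which makes the expected values easy to verify by hand.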

In [37]:
# Fixed noise used for visualization
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

In [38]:
for epoch in range(epochs):
    for i, (real_images, _) in enumerate(train_loader):
        real_images = real_images.to(device)
        b_size = real_images.size(0)

        # Train the discriminator (n_critic updates per generator update)
        for _ in range(n_critic):
            netD.zero_grad()
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake_images = netG(noise).detach()

            D_real = netD(real_images)
            D_fake = netD(fake_images)

            gp = compute_gradient_penalty(netD, real_images.data, fake_images.data)
            lossD = -torch.mean(D_real) + torch.mean(D_fake) + lambda_gp * gp
            lossD.backward()
            optimizerD.step()

        # Train the generator
        netG.zero_grad()
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake_images = netG(noise)
        lossG = -torch.mean(netD(fake_images))
        lossG.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f"[{epoch}/{epochs}][{i}/{len(train_loader)}] Loss_D: {lossD.item():.4f}, Loss_G: {lossG.item():.4f}, GP: {gp.item():.4f}")

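    # Save a grid of samples from the fixed noise after every epoch
    os.makedirs("./data/GAN", exist_ok=True)  # save_image does not create missing directories
    with torch.no_grad():
        fake = netG(fixed_noise).cpu()
        utils.save_image(fake, f"./data/GAN/epoch_{epoch:03d}.png", normalize=True, nrow=8)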

[0/30][0/157] Loss_D: -11.7737, Loss_G: 6.7276, GP: 0.0091
[0/30][100/157] Loss_D: -319.7731, Loss_G: 161.3455, GP: 12.8314
[1/30][0/157] Loss_D: -627.5335, Loss_G: 278.2659, GP: 4.1834
[1/30][100/157] Loss_D: -1252.0129, Loss_G: 644.4818, GP: 0.9079
[2/30][0/157] Loss_D: -43.8690, Loss_G: -455.1148, GP: 2.3239
[2/30][100/157] Loss_D: -70.7683, Loss_G: -335.0363, GP: 2.1133
[3/30][0/157] Loss_D: -72.4328, Loss_G: -304.3444, GP: 1.3387
[3/30][100/157] Loss_D: -287.6983, Loss_G: -25.4743, GP: 5.2499
[4/30][0/157] Loss_D: -260.5536, Loss_G: -5.6636, GP: 2.5074
[4/30][100/157] Loss_D: -745.1209, Loss_G: 304.5760, GP: 6.3459
[5/30][0/157] Loss_D: -458.0331, Loss_G: 24.6500, GP: 11.1673
[5/30][100/157] Loss_D: -514.1838, Loss_G: 193.6291, GP: 41.4942
[6/30][0/157] Loss_D: -905.1974, Loss_G: 277.9176, GP: 8.7006
[6/30][100/157] Loss_D: -1939.9958, Loss_G: 979.3066, GP: 6.4402
[7/30][0/157] Loss_D: -1672.5925, Loss_G: 524.4237, GP: 1.8691
[7/30][100/157] Loss_D: -2453.7329, Loss_G: 1032.5045, 
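
The logged `Loss_D` mixes two things: the critic's estimate of the Wasserstein gap and the penalty term. Since the training loop computes `lossD = -mean(D_real) + mean(D_fake) + lambda_gp * gp`, the gap `mean(D_real) - mean(D_fake)` can be recovered from a log line as `lambda_gp * GP - Loss_D`. The sketch below does this with a regular expression; the helper `critic_gap` and the pattern are illustrative, not part of the notebook, and a line that lacks a `GP` value yields `None`.

```python
import re

LAMBDA_GP = 10  # same value as lambda_gp in the hyperparameter cell

LOG_PATTERN = re.compile(
    r"\[(\d+)/\d+\]\[(\d+)/\d+\] "
    r"Loss_D: (-?[\d.]+), Loss_G: (-?[\d.]+), GP: (-?[\d.]+)"
)

def critic_gap(line, lambda_gp=LAMBDA_GP):
    """Recover mean(D_real) - mean(D_fake) from one log line, or None if it cannot be parsed."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    loss_d = float(match.group(3))
    gp = float(match.group(5))
    return lambda_gp * gp - loss_d

sample = "[0/30][0/157] Loss_D: -11.7737, Loss_G: 6.7276, GP: 0.0091"
print(f"{critic_gap(sample):.4f}")  # 11.8647
```

Tracking this gap separately from the penalty makes it easier to tell whether a large negative `Loss_D` reflects a growing critic gap or a change in the penalty term.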