In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

### 读取（下载）MNIST数据集

In [2]:
# MNIST 数据集的均值和标准差为 0.1307 和 0.3081
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST(root='../datasets/mnist', train=True, download=True, transform=transform)  # download=True:如果没有, 下载数据集
test_dataset = datasets.MNIST(root='../datasets/mnist', train=False, download=True, transform=transform)  # train=True训练集，=False测试集

In [3]:
batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### 构建模型
**Generator 和 Discriminator**

[反卷积(Transposed conv deconv)实现原理（通俗易懂）](https://blog.csdn.net/weixin_39326879/article/details/120797857) : 在GAN中，生成器使用反卷积层将低维随机噪声转换为高分辨率图像

#### 反卷积层（Deconvolution Layer）

反卷积层，也称为**转置卷积层（Transposed Convolution Layer）**，是一种用于上采样的操作。它的作用是将低分辨率的特征图（feature map）转换为高分辨率的特征图。反卷积层在生成对抗网络（GAN）、图像分割、超分辨率等任务中非常常见。


| 特性                | 卷积层（Convolution）                          | 反卷积层（Transposed Convolution）            |
|---------------------|-----------------------------------------------|-----------------------------------------------|
| **目的**            | 下采样，提取特征                              | 上采样，生成高分辨率特征图                    |
| **输入与输出关系**  | 输入尺寸 > 输出尺寸                           | 输入尺寸 < 输出尺寸                           |
| **计算方式**        | 通过滑动窗口和卷积核计算输出                  | 通过填充和卷积核的转置计算输出                |
| **应用场景**        | 特征提取、分类、检测等                        | 图像生成、分割、超分辨率等                    |

#### Generator 
生成器的目标是从随机噪声（latent vector）和条件（condition）生成逼真的图像

In [4]:
class Generator(nn.Module):
    def __init__(self, input_dim=100, output_dim=1, class_num=10):
        '''
        初始化生成网络
        :param input_dim:输入随机噪声的维度，（随机噪声是为了增加输出多样性）
        :param output_dim:生成图像的通道数（灰度图为1，RGB图为3）
        :param class_num:图像种类
        '''
        super(Generator, self).__init__()
        """
         为什么需要拼接随机噪声和条件向量？
         拼接随机噪声和条件向量的目的是将两种信息结合起来，作为生成器的输入：
         随机噪声：提供生成数据的随机性。
         条件向量：提供生成数据的条件信息。
         通过拼接，生成器可以根据条件向量生成符合特定条件的数据, 同时确保每次生成的数据会有所不同
         """
        self.input_dim = input_dim + class_num # 生成器的输入维度是随机噪声的维度加上条件向量的维度
        self.output_dim = output_dim
        
        # 全连接层，将输入向量映射到高维空间，然后通过反卷积层生成图像
        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU(),
        )

        # 反卷积层（转置卷积层），用于将高维特征图逐步上采样为最终图像
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1), # 输入通道数：128, 输出通道数：64, 卷积核大小：4x4, 步幅：2, 填充：1
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Tanh(), # 激活函数，将输出值限制在 [-1, 1] 范围内，适合生成图像
        )
 
    def forward(self, input):
        x = self.fc(input)
        x = x.view(-1, 128, 7, 7) # 将全连接层的输出重塑为特征图的形式
        x = self.deconv(x) # 通过反卷积层生成图像
        return x


**Spectral Normalization GAN（SNGAN）的原理**

Spectral Normalization（谱归一化）是一种用于稳定 GAN 训练的技术，主要应用于判别器（Discriminator）。能够避免梯度爆炸或梯度消失问题。
** 为什么 SNGAN 效果好？**
- **稳定训练**：谱归一化有效避免了判别器过于强大导致的梯度消失或梯度爆炸问题。
- **无需额外超参数**：谱归一化不需要像 WGAN-GP 那样引入梯度惩罚，简化了训练过程。
- **通用性强**：谱归一化可以应用于各种 GAN 架构中。

#### Discriminator 
判别器的目标是区分输入图像是真实的还是生成的

In [5]:
import torch.nn.utils.spectral_norm as spectral_norm

class Discriminator(nn.Module):
    def __init__(self, input_dim=1, output_dim=1):
        '''
        初始化判别网络
        :param input_dim:输入通道数
        :param output_dim:输出通道数
        '''
        super(Discriminator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # 卷积层，用于提取图像特征（应用谱归一化）
        self.conv = nn.Sequential(
            spectral_norm(nn.Conv2d(self.input_dim, 64, 4, 2, 1)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),
            nn.LeakyReLU(0.2),
        )
        
        # 全连接层，将特征图映射为最终的判别结果（应用谱归一化）
        self.fc = nn.Sequential(
            spectral_norm(nn.Linear(128 * 7 * 7, 1024)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Linear(1024, self.output_dim)),
        )
 
    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * 7 * 7)
        x = self.fc(x)
        return x

In [6]:
# SNGAN 通常使用 Hinge Loss 作为损失函数
def hinge_loss_discriminator(real_scores, fake_scores):
    # 判别器损失
    real_loss = torch.mean(torch.relu(1 - real_scores))  # 真实图像的损失
    fake_loss = torch.mean(torch.relu(1 + fake_scores))  # 生成图像的损失
    return real_loss + fake_loss

def hinge_loss_generator(fake_scores):
    # 生成器损失
    return -torch.mean(fake_scores)

### 模型训练

In [7]:
def train(dataloader, generator, discriminator, optimizer_G, optimizer_D):
    generator.train()  # 设置生成器为训练模式
    discriminator.train()  # 设置判别器为训练模式

    running_loss_G = 0.0
    running_loss_D = 0.0

    # 使用 tqdm 包裹数据加载器，显示进度条
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for index, (images, labels) in enumerate(progress_bar):
        # 将数据移动到设备
        images = images.to(device)
        labels = F.one_hot(labels, num_classes=10).float() # 将 labels 转换为 one-hot 编码(对应十个类别)
        labels = labels.to(device)
        
        batch_size = images.size(0)

        # 生成随机噪声
        z = torch.randn(batch_size, generator.input_dim - labels.size(1)).to(device)
        z = torch.cat([z, labels], dim=1)  # 拼接噪声和条件向量

        # 每 50 个 batch 更新一次判别器
        if index % 100 == 0:
            ### 更新判别器 ###
            optimizer_D.zero_grad()
            # 真实图像输入判别器
            D_real = discriminator(images)
            # 生成虚假图像并输入判别器
            images_fake = generator(z)
            D_fake = discriminator(images_fake.detach())
            # 计算判别器损失
            D_loss = hinge_loss_discriminator(D_real, D_fake)
            D_loss.backward()
            optimizer_D.step()
            # 统计判别器损失
            running_loss_D += D_loss.item()

        ### 更新生成器 ###
        optimizer_G.zero_grad()
        # 生成虚假图像并输入判别器
        images_fake = generator(z)
        D_fake = discriminator(images_fake)
        # 生成器的目标是让判别器认为虚假图像是真实的
        G_loss = hinge_loss_generator(D_fake)
        # 更新生成器
        G_loss.backward()
        optimizer_G.step()

        # 统计生成器损失
        running_loss_G += G_loss.item()

        # 更新进度条描述
        progress_bar.set_postfix(epoch=epoch+1, loss_G=G_loss.item())

    avg_loss_G = running_loss_G / len(dataloader)
    avg_loss_D = running_loss_D / len(dataloader)
    return avg_loss_G, avg_loss_D

In [8]:
def evaluate(dataloader, generator, discriminator):
    generator.eval()  # 设置生成器为评估模式
    discriminator.eval()  # 设置判别器为评估模式

    running_loss_G = 0.0
    running_loss_D = 0.0

    with torch.no_grad():  # 关闭梯度计算
        progress_bar = tqdm(dataloader, desc="Evaluating", leave=False)
        for images, labels in progress_bar:
            # 将数据移动到设备
            images = images.to(device)
            labels = F.one_hot(labels, num_classes=10).float()
            labels = labels.to(device)
            batch_size = images.size(0)

            # 生成随机噪声
            z = torch.randn(batch_size, generator.input_dim - labels.size(1)).to(device)
            z = torch.cat([z, labels], dim=1)  # 拼接噪声和条件向量

            # 真实标签和虚假标签
            y_real = torch.ones(batch_size, 1).to(device)
            y_fake = torch.zeros(batch_size, 1).to(device)

            ### 计算判别器损失 ###
            # 真实图像输入判别器
            D_real = discriminator(images)
            # 生成虚假图像并输入判别器
            images_fake = generator(z)
            D_fake = discriminator(images_fake)
            D_loss = hinge_loss_discriminator(D_real, D_fake)

            ### 计算生成器损失 ###
            G_loss = hinge_loss_generator(D_fake)

            # 统计损失
            running_loss_G += G_loss.item()
            running_loss_D += D_loss.item()

            # 更新进度条描述
            progress_bar.set_postfix(epoch=epoch+1, loss_G=G_loss.item(), loss_D=D_loss.item())

    avg_loss_G = running_loss_G / len(dataloader)
    avg_loss_D = running_loss_D / len(dataloader)
    return avg_loss_G, avg_loss_D


In [9]:
# 定义生成器和判别器
generator = Generator(input_dim=100, output_dim=1, class_num=10).to(device)
discriminator = Discriminator(input_dim=1, output_dim=1).to(device)

# 定义优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# loss_fn = nn.BCELoss() # GAN 的损失函数通常使用二元交叉熵损失（Binary Cross Entropy Loss）

#### 为什么在 GAN 中使用 `betas=(0.5, 0.999)`？

在 GAN 的训练中，生成器和判别器是两个对抗的模型，训练过程非常不稳定。为了加快收敛并提高稳定性，通常会对优化器的超参数进行调整：

1. **`beta1=0.5`**：
   - 使优化器更关注当前梯度，减少历史梯度的影响。
   - 这有助于生成器和判别器更快地响应对方的更新，从而加快对抗训练的进程。

2. **`beta2=0.999`**：
   - 保持对梯度方差的平滑估计，避免优化器过于敏感。


In [None]:
# 开始训练
num_epochs = 30
train_loss_G = []
train_loss_D = []
test_loss_G = []
test_loss_D = []

for epoch in range(num_epochs):
    # print(f"Epoch {epoch+1}/{num_epochs}")

    # 训练
    epoch_train_loss_G, epoch_train_loss_D = train(train_loader, generator, discriminator, optimizer_G, optimizer_D)

    # 在测试集上评估
    epoch_test_loss_G, epoch_test_loss_D = evaluate(test_loader, generator, discriminator)

    # 记录损失
    train_loss_G.append(epoch_train_loss_G)
    train_loss_D.append(epoch_train_loss_D)
    test_loss_G.append(epoch_test_loss_G)
    test_loss_D.append(epoch_test_loss_D)

    # 打印训练和测试结果
    template = ('Epoch:{:2d}, Train_loss_G:{:.3f}, Train_loss_D:{:.3f}, Test_loss_G:{:.3f}, Test_loss_D:{:.3f}')
    print(template.format(epoch+1, epoch_train_loss_G, epoch_train_loss_D, epoch_test_loss_G, epoch_test_loss_D))

print("训练完成!")

                                                                                                                       

Epoch: 1, Train_loss_G:-0.338, Train_loss_D:0.019, Test_loss_G:0.006, Test_loss_D:1.301


                                                                                                                       

Epoch: 2, Train_loss_G:0.666, Train_loss_D:0.010, Test_loss_G:0.568, Test_loss_D:0.572


                                                                                                                       

Epoch: 3, Train_loss_G:1.140, Train_loss_D:0.005, Test_loss_G:0.507, Test_loss_D:0.583


                                                                                                                       

Epoch: 4, Train_loss_G:1.230, Train_loss_D:0.004, Test_loss_G:1.537, Test_loss_D:0.184


                                                                                                                       

Epoch: 5, Train_loss_G:1.223, Train_loss_D:0.002, Test_loss_G:1.060, Test_loss_D:0.087


                                                                                                                       

Epoch: 6, Train_loss_G:1.417, Train_loss_D:0.002, Test_loss_G:1.263, Test_loss_D:0.102


                                                                                                                       

Epoch: 7, Train_loss_G:1.348, Train_loss_D:0.002, Test_loss_G:1.113, Test_loss_D:0.063


                                                                                                                       

Epoch: 8, Train_loss_G:1.364, Train_loss_D:0.002, Test_loss_G:0.734, Test_loss_D:0.302


                                                                                                                       

Epoch: 9, Train_loss_G:1.511, Train_loss_D:0.002, Test_loss_G:0.863, Test_loss_D:0.170


                                                                                                                       

Epoch:10, Train_loss_G:1.590, Train_loss_D:0.002, Test_loss_G:1.373, Test_loss_D:0.048


                                                                                                                       

Epoch:11, Train_loss_G:1.448, Train_loss_D:0.001, Test_loss_G:1.559, Test_loss_D:0.101


                                                                                                                       

Epoch:12, Train_loss_G:1.449, Train_loss_D:0.001, Test_loss_G:1.904, Test_loss_D:0.159


                                                                                                                       

Epoch:13, Train_loss_G:1.454, Train_loss_D:0.001, Test_loss_G:1.265, Test_loss_D:0.033


                                                                                                                       

Epoch:14, Train_loss_G:1.337, Train_loss_D:0.001, Test_loss_G:0.901, Test_loss_D:0.124


                                                                                                                       

Epoch:15, Train_loss_G:1.556, Train_loss_D:0.001, Test_loss_G:1.968, Test_loss_D:0.157


                                                                                                                       

Epoch:16, Train_loss_G:1.468, Train_loss_D:0.001, Test_loss_G:1.727, Test_loss_D:0.078


                                                                                                                       

Epoch:17, Train_loss_G:1.466, Train_loss_D:0.001, Test_loss_G:1.406, Test_loss_D:0.046


                                                                                                                       

Epoch:18, Train_loss_G:1.452, Train_loss_D:0.000, Test_loss_G:1.121, Test_loss_D:0.022


Training:  76%|███████████████████████████████████▏          | 1433/1875 [00:25<00:07, 56.67it/s, epoch=19, loss_G=1.2]

### 结果可视化

In [None]:
epochs_range = range(num_epochs)

plt.figure(figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_loss_G, label='Training Generator Loss')
plt.plot(epochs_range, train_loss_D, label='Training Discriminator Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, test_loss_G, label='Test Generator Loss')
plt.plot(epochs_range, test_loss_D, label='Test Discriminator Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()

### 效果展示

In [None]:
def generate_digit_image(generator, digit, noise_dim=100):
    # 检查输入的数字是否合法
    if digit < 0 or digit > 9:
        raise ValueError("输入的数字必须在 0 到 9 之间")

    # 生成随机噪声
    z = torch.randn(1, noise_dim)  # 生成一个随机噪声

    # 生成条件向量（one-hot 编码）
    class_num = 10  # 类别数
    label = torch.tensor([digit])  # 将数字转换为张量
    label_one_hot = torch.nn.functional.one_hot(label, num_classes=class_num).float()  # 转换为 one-hot 编码

    # 拼接随机噪声和条件向量
    generator_input = torch.cat([z, label_one_hot], dim=1).to(device)

    # 生成图像
    generator.eval()  # 设置生成器为评估模式
    with torch.no_grad():  # 关闭梯度计算
        generated_image = generator(generator_input)

    # 将生成的图像从张量转换为 numpy 数组
    generated_image = generated_image.cpu().numpy()

    # 将图像从 [-1, 1] 范围转换到 [0, 1] 范围
    generated_image = (generated_image + 1) / 2

    # 去掉通道维度（灰度图）
    generated_image = generated_image.squeeze()

    return generated_image

In [None]:
# 创建一个 4x3 的子图网格
fig, axes = plt.subplots(4, 3, figsize=(8, 10))
fig.suptitle('Generated Digits (0-9)', fontsize=16)

# 生成并显示每个数字的图像
for digit in range(10):
    generated_image = generate_digit_image(generator, digit)
    # 计算子图的位置
    row = digit // 3  # 行索引
    col = digit % 3   # 列索引

    # 显示图像
    ax = axes[row, col]
    ax.imshow(generated_image, cmap='gray')
    ax.set_title(f'Digit: {digit}')
    ax.axis('off')  # 关闭坐标轴

    # 隐藏多余的子图（因为 10 个数字无法完全填满 4x3 的网格）
    for i in range(10, 12):
        row = i // 3
        col = i % 3
        axes[row, col].axis('off')

# 调整布局并显示
plt.tight_layout()
plt.show()