In [59]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from torchvision.utils import save_image
import os

In [60]:
# 设置随机种子
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [61]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 定义生成器
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))))
        self.img_shape = img_shape

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

In [62]:
# 定义判别器
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        self.img_shape = img_shape

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

In [63]:
# 定义训练函数
def train(num_epochs, generator, discriminator, dataloader, optimizer_G, optimizer_D, criterion, latent_dim, device):
    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(tqdm(dataloader)):

            # 训练判别器
            real_imgs = imgs.to(device)
            optimizer_D.zero_grad()

            real_labels = torch.ones((real_imgs.size(0), 1)).to(device)
            fake_labels = torch.zeros((real_imgs.size(0), 1)).to(device)

            # 计算真实图像的损失
            real_outputs = discriminator(real_imgs)
            real_loss = criterion(real_outputs, real_labels)

            # 生成假图像并计算损失
            z = torch.randn((real_imgs.size(0), latent_dim)).to(device)
            fake_imgs = generator(z)
            fake_outputs = discriminator(fake_imgs.detach())
            fake_loss = criterion(fake_outputs, fake_labels)

            # 合并真实图像和假图像的损失
            d_loss = real_loss + fake_loss

            # 反向传播和优化
            d_loss.backward()
            optimizer_D.step()

            # 训练生成器
            optimizer_G.zero_grad()

            # 生成假图像并计算生成器损失
            fake_outputs = discriminator(fake_imgs)
            g_loss = criterion(fake_outputs, real_labels)

            # 反向传播和优化
            g_loss.backward()
            optimizer_G.step()

            if i % 100 == 0:
                print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

        # 保存生成的图像
        os.makedirs("images", exist_ok=True)  # 创建目录，如果目录已存在则忽略
        save_image(fake_imgs.data[:25], f"images/{epoch}.jpg", nrow=5, normalize=True)

In [64]:
# 设置超参数
latent_dim = 100
img_shape = (3, 64, 64)
lr = 0.0002
batch_size = 64
num_epochs = 10

In [65]:
# 创建生成器和判别器
generator = Generator(latent_dim, img_shape).to(device)
discriminator = Discriminator(img_shape).to(device)

In [66]:
# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer_G = Adam(generator.parameters(), lr=lr)
optimizer_D = Adam(discriminator.parameters(), lr=lr)

In [67]:
# 加载CelebA数据集
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [68]:
dataset = datasets.ImageFolder(root="../test10_AIGC/", transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [70]:
# 训练GAN
train(1,generator, discriminator, dataloader, optimizer_G, optimizer_D, criterion, latent_dim, device)

  0%|          | 1/3166 [00:04<4:06:20,  4.67s/it]

[Epoch 0/1] [Batch 0/3166] [D loss: 0.2367042899131775] [G loss: 14.179183959960938]


  3%|▎         | 101/3166 [00:44<14:23,  3.55it/s]

[Epoch 0/1] [Batch 100/3166] [D loss: 0.36556166410446167] [G loss: 14.532881736755371]


  6%|▋         | 201/3166 [01:11<12:42,  3.89it/s]

[Epoch 0/1] [Batch 200/3166] [D loss: 0.422393798828125] [G loss: 14.058670043945312]


 10%|▉         | 301/3166 [01:36<11:30,  4.15it/s]

[Epoch 0/1] [Batch 300/3166] [D loss: 0.4934724271297455] [G loss: 16.494155883789062]


 13%|█▎        | 401/3166 [02:01<13:31,  3.41it/s]

[Epoch 0/1] [Batch 400/3166] [D loss: 0.35111963748931885] [G loss: 10.548120498657227]


 16%|█▌        | 501/3166 [02:27<11:05,  4.00it/s]

[Epoch 0/1] [Batch 500/3166] [D loss: 0.3185058832168579] [G loss: 11.490039825439453]


 19%|█▉        | 601/3166 [02:54<10:53,  3.92it/s]

[Epoch 0/1] [Batch 600/3166] [D loss: 0.33766865730285645] [G loss: 17.56198501586914]


 22%|██▏       | 701/3166 [03:20<10:36,  3.87it/s]

[Epoch 0/1] [Batch 700/3166] [D loss: 0.21791285276412964] [G loss: 10.593469619750977]


 25%|██▌       | 801/3166 [03:46<09:59,  3.94it/s]

[Epoch 0/1] [Batch 800/3166] [D loss: 0.140237957239151] [G loss: 12.04999828338623]


 28%|██▊       | 901/3166 [04:12<09:30,  3.97it/s]

[Epoch 0/1] [Batch 900/3166] [D loss: 0.37353527545928955] [G loss: 19.202289581298828]


 32%|███▏      | 1001/3166 [04:37<09:12,  3.92it/s]

[Epoch 0/1] [Batch 1000/3166] [D loss: 0.22060436010360718] [G loss: 9.209976196289062]


 35%|███▍      | 1101/3166 [05:03<09:46,  3.52it/s]

[Epoch 0/1] [Batch 1100/3166] [D loss: 0.5125187635421753] [G loss: 20.069950103759766]


 38%|███▊      | 1201/3166 [05:30<07:58,  4.11it/s]

[Epoch 0/1] [Batch 1200/3166] [D loss: 0.5096933841705322] [G loss: 13.129371643066406]


 41%|████      | 1301/3166 [05:55<07:36,  4.08it/s]

[Epoch 0/1] [Batch 1300/3166] [D loss: 0.5306791067123413] [G loss: 12.942316055297852]


 44%|████▍     | 1401/3166 [06:20<07:18,  4.02it/s]

[Epoch 0/1] [Batch 1400/3166] [D loss: 0.5307974815368652] [G loss: 8.526413917541504]


 47%|████▋     | 1501/3166 [06:45<07:03,  3.93it/s]

[Epoch 0/1] [Batch 1500/3166] [D loss: 0.5450019240379333] [G loss: 13.565179824829102]


 51%|█████     | 1601/3166 [07:10<06:31,  4.00it/s]

[Epoch 0/1] [Batch 1600/3166] [D loss: 0.4508327841758728] [G loss: 14.26915168762207]


 54%|█████▎    | 1701/3166 [07:38<06:41,  3.65it/s]

[Epoch 0/1] [Batch 1700/3166] [D loss: 0.3850928246974945] [G loss: 6.295385360717773]


 57%|█████▋    | 1801/3166 [08:07<07:13,  3.15it/s]

[Epoch 0/1] [Batch 1800/3166] [D loss: 0.2690119445323944] [G loss: 10.844212532043457]


 60%|██████    | 1901/3166 [08:37<06:00,  3.51it/s]

[Epoch 0/1] [Batch 1900/3166] [D loss: 0.28141558170318604] [G loss: 10.409432411193848]


 63%|██████▎   | 2001/3166 [09:05<05:56,  3.27it/s]

[Epoch 0/1] [Batch 2000/3166] [D loss: 0.20820081233978271] [G loss: 8.345808029174805]


 66%|██████▋   | 2101/3166 [09:32<04:55,  3.60it/s]

[Epoch 0/1] [Batch 2100/3166] [D loss: 0.4939730167388916] [G loss: 8.935524940490723]


 70%|██████▉   | 2201/3166 [10:00<04:07,  3.90it/s]

[Epoch 0/1] [Batch 2200/3166] [D loss: 0.5049639940261841] [G loss: 8.89568042755127]


 73%|███████▎  | 2301/3166 [10:25<03:28,  4.14it/s]

[Epoch 0/1] [Batch 2300/3166] [D loss: 0.4491713047027588] [G loss: 11.953725814819336]


 76%|███████▌  | 2401/3166 [10:50<03:17,  3.87it/s]

[Epoch 0/1] [Batch 2400/3166] [D loss: 0.3108789920806885] [G loss: 11.298907279968262]


 79%|███████▉  | 2501/3166 [11:15<02:50,  3.90it/s]

[Epoch 0/1] [Batch 2500/3166] [D loss: 0.8584447503089905] [G loss: 8.73831558227539]


 82%|████████▏ | 2601/3166 [11:40<02:25,  3.89it/s]

[Epoch 0/1] [Batch 2600/3166] [D loss: 0.37153470516204834] [G loss: 6.591784954071045]


 85%|████████▌ | 2701/3166 [12:06<01:59,  3.90it/s]

[Epoch 0/1] [Batch 2700/3166] [D loss: 0.49931108951568604] [G loss: 8.802799224853516]


 88%|████████▊ | 2801/3166 [12:31<01:29,  4.08it/s]

[Epoch 0/1] [Batch 2800/3166] [D loss: 0.6733227968215942] [G loss: 7.492887496948242]


 92%|█████████▏| 2901/3166 [12:56<01:06,  3.98it/s]

[Epoch 0/1] [Batch 2900/3166] [D loss: 0.4403234124183655] [G loss: 5.627740383148193]


 95%|█████████▍| 3001/3166 [13:21<00:39,  4.17it/s]

[Epoch 0/1] [Batch 3000/3166] [D loss: 0.5627950429916382] [G loss: 7.987661838531494]


 98%|█████████▊| 3101/3166 [13:46<00:16,  3.87it/s]

[Epoch 0/1] [Batch 3100/3166] [D loss: 0.5750430822372437] [G loss: 4.632171154022217]


100%|██████████| 3166/3166 [14:03<00:00,  3.75it/s]
