In [9]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
from matplotlib import pyplot as plt
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm  # 导入tqdm
from torch.autograd import Variable
import torch.autograd as autograd
import numpy as np

# Custom dataset class
class AnimeDataset(Dataset):
    """Unlabelled image dataset that serves every image file found in one directory.

    Args:
        image_dir: Directory scanned (non-recursively) for image files.
        transform: Optional callable applied to each loaded RGB PIL image.

    Each item is only the (optionally transformed) image; there are no labels.
    """

    # Compared against the lower-cased file name, so "IMG_01.JPG" is also picked up.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # Sorted so the index -> file mapping is stable across runs and platforms
        # (os.listdir returns entries in arbitrary order).
        self.image_files = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        # convert() returns a fully loaded copy, so the file can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
    
# Where the training images live and how they are batched
image_dir = '/home/yuchi/AI/anim'
image_size = 64
batch_size = 32

# Preprocessing: square resize, convert to a [0, 1] tensor,
# then shift/scale every channel to [-1, 1] via (x - 0.5) / 0.5
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Dataset plus a shuffling loader over it
dataset = AnimeDataset(image_dir=image_dir, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [10]:
# Training configuration
latent_dim = 100        # length of the generator's input noise vector z
img_size = 64           # generated images are img_size x img_size
channels = 3            # RGB
n_epochs = 40
lr = 0.0002             # Adam learning rate (both networks)
b1 = 0.5                # Adam beta1
b2 = 0.999              # Adam beta2
lambda_gp = 10          # weight of the gradient-penalty term in the critic loss
sample_interval = 400   # NOTE(review): not referenced by any visible cell
data_path = "/home/yuchi/AI/anim"  # dataset path (same directory as image_dir)
seed = 42               # RNG seed for weight init, noise sampling and DataLoader shuffling

# Shape of one image tensor: (C, H, W)
img_shape = (channels, img_size, img_size)

# Seed torch's CPU and CUDA generators before the models are built and before the
# loader shuffles, so runs are repeatable (bitwise CUDA determinism would
# additionally require deterministic algorithms).
torch.manual_seed(seed)

# Use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Generator definition
class Generator(nn.Module):
    """MLP generator mapping a noise vector z to an image of shape ``img_shape``.

    Relies on the module-level ``latent_dim`` (input width, read at construction)
    and ``img_shape`` (read at construction and in ``forward``). Pixels are in
    [-1, 1] because the last layer is ``tanh``.
    """

    def __init__(self):
        super().__init__()

        def dense(n_in, n_out, normalize=True):
            # Linear -> optional BatchNorm1d -> LeakyReLU(0.2)
            stage = [nn.Linear(n_in, n_out)]
            if normalize:
                # The second positional argument of BatchNorm1d is eps (not momentum),
                # so this layer really uses eps=0.8; kept so trained weights behave the same.
                stage.append(nn.BatchNorm1d(n_out, eps=0.8))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        widths = [latent_dim, 128, 256, 512, 1024]
        hidden = []
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Only the first hidden layer skips batch normalisation.
            hidden.extend(dense(n_in, n_out, normalize=idx > 0))

        self.model = nn.Sequential(
            *hidden,
            nn.Linear(widths[-1], int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        flat = self.model(z)
        # (N, C*H*W) -> (N, C, H, W)
        return flat.view(flat.shape[0], *img_shape)


# Discriminator (critic) definition
class Discriminator(nn.Module):
    """MLP critic: flattens each sample and returns one raw score per sample, shape (N, 1).

    There is no sigmoid; the training loop uses the scores directly in the
    Wasserstein loss. Relies on the module-level ``img_shape`` for the input width.
    """

    def __init__(self):
        super().__init__()
        in_features = int(np.prod(img_shape))
        layers = []
        for n_in, n_out in ((in_features, 512), (512, 256)):
            layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2, inplace=True)]
        layers.append(nn.Linear(256, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        # Flatten everything after the batch dimension, then score.
        return self.model(img.view(img.shape[0], -1))


# Build both networks directly on the selected device
# (generator first, so parameter initialisation consumes the RNG in the same order)
generator, discriminator = Generator().to(device), Discriminator().to(device)

# One Adam optimizer per network, sharing learning rate and betas
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Gradient penalty function
def compute_gradient_penalty(D, real_samples, fake_samples):
    """WGAN-GP gradient penalty.

    Samples random points on the straight lines between paired real and fake
    samples and penalises the critic when the L2 norm of its input-gradient at
    those points deviates from 1.

    Args:
        D: Critic callable returning one value per sample (shape (N,) or (N, 1)).
            Assumes D scores samples independently (e.g. no batch norm), so the
            gradient of the summed output equals each sample's own gradient.
        real_samples: Real batch, batch dimension first (any rank >= 1).
        fake_samples: Fake batch with the same shape/device as ``real_samples``;
            should not carry the generator's graph.

    Returns:
        Scalar tensor ``mean((||grad_x D(x_hat)||_2 - 1) ** 2)``. It stays
        differentiable (``create_graph=True``) so it can be back-propagated into D.
    """
    batch = real_samples.size(0)
    # One mixing coefficient per sample, broadcast over every non-batch dimension,
    # created on the inputs' own device rather than a global one.
    alpha = torch.rand(
        (batch,) + (1,) * (real_samples.dim() - 1),
        device=real_samples.device,
        dtype=real_samples.dtype,
    )
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        # Seed gradient matching whatever shape D returns.
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(batch, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

cuda


In [11]:
# Training loop (WGAN-GP): the critic is updated on every batch and the
# generator once every `n_critic` batches.
n_critic = 5  # critic updates per generator update
model_dir = "/home/yuchi/AI/WGAN-GP/model"
os.makedirs(model_dir, exist_ok=True)  # torch.save fails if the directory is missing

for epoch in range(n_epochs):
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch [{epoch+1}/{n_epochs}]")
    for i, imgs in progress_bar:
        real_imgs = imgs.to(device)
        # Size of this batch (the last batch of an epoch can be smaller). Deliberately
        # not called `batch_size`: that would overwrite the global loader setting,
        # which the image-generation cell reuses.
        cur_batch_size = real_imgs.size(0)

        # ---- Critic step ----
        optimizer_D.zero_grad()
        z = torch.randn(cur_batch_size, latent_dim, device=device)
        # The critic loss must not back-propagate into the generator.
        with torch.no_grad():
            fake_imgs = generator(z)
        real_validity = discriminator(real_imgs)
        fake_validity = discriminator(fake_imgs)
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs)
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step, every n_critic batches ----
        if i % n_critic == 0:
            optimizer_G.zero_grad()
            # Same z as the critic step, but now with a graph through the generator.
            fake_imgs = generator(z)
            g_loss = -torch.mean(discriminator(fake_imgs))
            g_loss.backward()
            optimizer_G.step()
            # Show the latest losses; .item() synchronises with the GPU, so only do it here.
            progress_bar.set_postfix(d_loss=f"{d_loss.item():.4f}", g_loss=f"{g_loss.item():.4f}")

    # Save the generator weights after every epoch
    torch.save(generator.state_dict(), os.path.join(model_dir, f"generator_epoch_{epoch+1}.pth"))


Epoch 1/40:   0%|          | 0/1987 [00:00<?, ?it/s]

Epoch [1/40]: 100%|██████████| 1987/1987 [01:42<00:00, 19.32it/s]
Epoch [2/40]: 100%|██████████| 1987/1987 [01:46<00:00, 18.67it/s]
Epoch [3/40]: 100%|██████████| 1987/1987 [01:51<00:00, 17.80it/s]
Epoch [4/40]: 100%|██████████| 1987/1987 [01:53<00:00, 17.57it/s]
Epoch [5/40]: 100%|██████████| 1987/1987 [02:05<00:00, 15.86it/s]
Epoch [6/40]: 100%|██████████| 1987/1987 [02:03<00:00, 16.06it/s]
Epoch [7/40]: 100%|██████████| 1987/1987 [01:56<00:00, 17.10it/s]
Epoch [8/40]: 100%|██████████| 1987/1987 [02:14<00:00, 14.80it/s]
Epoch [9/40]: 100%|██████████| 1987/1987 [02:19<00:00, 14.21it/s]
Epoch [10/40]: 100%|██████████| 1987/1987 [02:14<00:00, 14.80it/s]
Epoch [11/40]: 100%|██████████| 1987/1987 [02:27<00:00, 13.44it/s]
Epoch [12/40]: 100%|██████████| 1987/1987 [05:52<00:00,  5.64it/s]
Epoch [13/40]: 100%|██████████| 1987/1987 [05:51<00:00,  5.66it/s]
Epoch [14/40]: 100%|██████████| 1987/1987 [04:55<00:00,  6.72it/s]
Epoch [15/40]: 100%|██████████| 1987/1987 [05:31<00:00,  5.99it/s]
Epoc

In [12]:
import torch
import os
import torch.nn as nn
from torchvision.utils import save_image
from torch.autograd import Variable
import numpy as np


# Image geometry; must match the training configuration so the checkpoint's
# final Linear layer has the same output size.
channels, img_size = 3, 64
img_shape = (channels, img_size, img_size)  # (C, H, W)

# Generator definition.
# NOTE(review): duplicate of the training cell's Generator; the two must stay
# identical for the saved state_dict to load (consider moving it to a .py module).
class Generator(nn.Module):
    """MLP generator mapping a noise vector z to an image of shape ``img_shape``.

    Relies on the module-level ``latent_dim`` (read at construction) and
    ``img_shape`` (read at construction and in ``forward``). Output is in [-1, 1]
    because the last layer is ``tanh``.
    """

    def __init__(self):
        super().__init__()

        def dense(n_in, n_out, normalize=True):
            # Linear -> optional BatchNorm1d -> LeakyReLU(0.2)
            stage = [nn.Linear(n_in, n_out)]
            if normalize:
                # Second positional argument of BatchNorm1d is eps, i.e. eps=0.8.
                stage.append(nn.BatchNorm1d(n_out, eps=0.8))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        widths = [latent_dim, 128, 256, 512, 1024]
        hidden = []
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Only the first hidden layer skips batch normalisation.
            hidden.extend(dense(n_in, n_out, normalize=idx > 0))

        self.model = nn.Sequential(
            *hidden,
            nn.Linear(widths[-1], int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        flat = self.model(z)
        # (N, C*H*W) -> (N, C, H, W)
        return flat.view(flat.shape[0], *img_shape)
    
# Sampling parameters
latent_dim = 100  # latent vector size; must equal the value used in training

# Pick the device first so the checkpoint can be mapped straight onto it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained generator
generator = Generator()
model_path = "/home/yuchi/AI/WGAN-GP/model/generator_epoch_40.pth"  # replace with the actual checkpoint path
# map_location lets a checkpoint saved from a CUDA run load on a CPU-only machine;
# weights_only=True limits unpickling to tensors/containers, which is all a state_dict needs.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
generator.load_state_dict(state_dict)
generator.to(device)
generator.eval()  # BatchNorm uses its running statistics instead of batch statistics


# Generate `total_images` samples and save each one as its own file (input for FID).
# NOTE(review): pytorch-fid recommends at least 2048 samples (the Inception feature
# dimension) for a full-rank covariance estimate; with 500 the score is unreliable.
# NOTE(review): the FID command below compares against the raw dataset directory;
# if those files are larger than 64x64, resolution alone will inflate the score.
total_images = 500
gen_batch_size = 32  # images per forward pass; independent of any training-time global
result_dir = "/home/yuchi/AI/WGAN-GP/Result"
os.makedirs(result_dir, exist_ok=True)

generated_count = 0
with torch.no_grad():  # inference only: no autograd graph needed
    while generated_count < total_images:
        # The last batch may be smaller so exactly total_images are written
        current_batch_size = min(gen_batch_size, total_images - generated_count)
        z = torch.randn(current_batch_size, latent_dim, device=device)
        gen_imgs = generator(z)

        for i in range(current_batch_size):
            img_index = generated_count + i + 1  # files are numbered from 1
            save_path = os.path.join(result_dir, f"{img_index}.jpg")
            # The generator ends in tanh, so map the fixed range [-1, 1] to [0, 1].
            # Per-image min/max normalisation would stretch each image's contrast
            # differently and skew the comparison with the real images.
            save_image(gen_imgs[i], save_path, normalize=True, value_range=(-1, 1))

        generated_count += current_batch_size

!python -m pytorch_fid /home/yuchi/AI/anim /home/yuchi/AI/WGAN-GP/Result --batch-size 32


  1%|▎                                        | 15/1987 [00:08<14:11,  2.32it/s]^C
  1%|▎                                        | 15/1987 [00:08<19:03,  1.72it/s]
Traceback (most recent call last):
  File "/home/yuchi/anaconda3/envs/torch230/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/yuchi/anaconda3/envs/torch230/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/yuchi/anaconda3/envs/torch230/lib/python3.10/site-packages/pytorch_fid/__main__.py", line 3, in <module>
    pytorch_fid.fid_score.main()
  File "/home/yuchi/anaconda3/envs/torch230/lib/python3.10/site-packages/pytorch_fid/fid_score.py", line 313, in main
    fid_value = calculate_fid_given_paths(args.path,
  File "/home/yuchi/anaconda3/envs/torch230/lib/python3.10/site-packages/pytorch_fid/fid_score.py", line 259, in calculate_fid_given_paths
    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
