In [1]:
import torch
import os
import torch.nn as nn
from torchvision.utils import save_image
from torch.autograd import Variable
import numpy as np


# Generated images are square: `img_size` pixels per side, `channels` channels.
channels, img_size = 3, 64
# Shape of a single image tensor as (channels, height, width).
img_shape = (channels,) + (img_size,) * 2

# Define the generator
class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to images.

    Args:
        latent_dim: Size of the input noise vector. When omitted, the
            module-level ``latent_dim`` is looked up at construction time
            (this is what the original no-argument ``Generator()`` relied on).
        img_shape: Output image shape ``(channels, height, width)``. When
            omitted, the module-level ``img_shape`` is used.

    The layer layout, and therefore the ``state_dict`` keys
    (``model.0.weight`` ... ``model.11.weight``), is unchanged, so existing
    checkpoints still load.
    """

    def __init__(self, latent_dim=None, img_shape=None):
        super().__init__()

        # Resolve defaults from module globals at call time: the global
        # ``latent_dim`` is assigned later in the file than this class
        # definition, so it cannot be captured as a default argument value.
        if latent_dim is None:
            latent_dim = globals()["latent_dim"]
        if img_shape is None:
            img_shape = globals()["img_shape"]
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)

        def block(in_feat, out_feat, normalize=True):
            """Return [Linear, optional BatchNorm1d, LeakyReLU(0.2)] layers."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE: BatchNorm1d's second positional argument is ``eps``,
                # so this sets eps=0.8 (not momentum). Kept unchanged because
                # the saved checkpoint was presumably trained with it.
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),  # squashes outputs into [-1, 1]
        )

    def forward(self, z):
        """Generate a batch of images.

        Args:
            z: Latent tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, *img_shape)`` with values in [-1, 1].
        """
        img = self.model(z)
        return img.view(img.shape[0], *self.img_shape)
    
# Settings
latent_dim = 100  # size of the latent noise vector; must match the trained checkpoint

# Choose the device (CUDA if available, otherwise CPU) *before* loading, so the
# checkpoint tensors can be mapped directly onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained generator
generator = Generator()
model_path = "/home/yuchi/AI/WGAN-GP/model/generator_epoch_40.pth"  # replace with the actual model path
# map_location lets a checkpoint whose tensors were saved on a GPU be loaded on a
# CPU-only machine (otherwise torch.load tries to restore them to their saved device).
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints;
# on PyTorch versions that support it, consider passing weights_only=True.
generator.load_state_dict(torch.load(model_path, map_location=device))
generator.to(device)  # move parameters/buffers to the selected device (replaces the extra .cuda() call)
generator.eval()  # inference mode: BatchNorm uses its running statistics


# Generate 500 images
batch_size = 16  # number of images generated per batch
total_images = 500
generated_count = 0
# NOTE(review): `Tensor` is not used in this cell; kept only in case later code relies on it.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

output_dir = "/home/yuchi/AI/WGAN-GP/Result"
# save_image writes through PIL, which does not create missing directories.
os.makedirs(output_dir, exist_ok=True)
# NOTE(review): files already present in output_dir (e.g. from an earlier run) are
# not removed and would also be included in the FID computation that follows.

# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    while generated_count < total_images:
        # Number of images for this batch (the final batch may be smaller)
        current_batch_size = min(batch_size, total_images - generated_count)
        # Random latent vectors created directly on the model's device
        z = torch.randn(current_batch_size, latent_dim, device=device)
        gen_imgs = generator(z)

        # Save each image as "<1-based global index>.jpg"
        for i in range(current_batch_size):
            img_index = generated_count + i + 1
            save_path = os.path.join(output_dir, f"{img_index}.jpg")
            # NOTE(review): normalize=True without value_range rescales each image by
            # its own min/max instead of mapping the Tanh range [-1, 1] to [0, 1], and
            # JPEG is lossy; both perturb pixels and may bias FID. Kept unchanged so
            # results stay comparable with earlier runs.
            save_image(gen_imgs[i], save_path, normalize=True)

        generated_count += current_batch_size

# IPython shell escape: run the pytorch-fid command-line tool to compute FID between
# /home/yuchi/AI/anim (presumably the real reference images — confirm) and the
# generated images in /home/yuchi/AI/WGAN-GP/Result, with batch size 16.
!python -m pytorch_fid /home/yuchi/AI/anim /home/yuchi/AI/WGAN-GP/Result --batch-size 16


100%|███████████████████████████████████████| 3973/3973 [14:45<00:00,  4.49it/s]
100%|███████████████████████████████████████████| 32/32 [00:08<00:00,  4.00it/s]
FID:  183.4882343679278
