In [1]:
import os

#文件处理
def Preprocess(folder_path):
    file_names = os.listdir(folder_path)
    labels = [] 
    paths = []
    for file_name in file_names: 
        if os.path.isdir(file_name):
            continue

        file_path = os.path.join(folder_path, file_name) 
        paths.append(file_path)

        label = file_name.split("_")[0]
        labels.append(label)

    return paths,labels

In [2]:
import os
import torch

from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_tensor

from PIL import Image

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data_paths, labels = None,transform=None):
        self.data_paths = data_paths
        self.labels =labels
        self.transform = transform

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
       #读取图片
        img_path = self.data_paths[idx]
        img = Image.open(img_path).convert('RGB')
        #获取标签
        
        if self.transform is not None:
            img = self.transform(img)
            
        if self.labels is not None:       
            label = self.labels[idx]

            return img, label
        else:
            return img

In [3]:
import torchvision.transforms as transforms
train_transform = transforms.Compose([ 
    transforms.Resize((64, 64)),                                
    transforms.ToTensor(),
])

In [4]:
train_root="./all-dogs/"


data_paths_train, label_train = Preprocess(train_root)

# 创建自定义数据集实例
train_dataset = CustomDataset(data_paths_train, label_train,train_transform)

print(len(train_dataset))

# 创建数据加载器
batch_size=64
train_loader = DataLoader(train_dataset, batch_size, shuffle=True)

print(train_loader)


20579
<torch.utils.data.dataloader.DataLoader object at 0x00000264C6C104C0>


定义生成器和判别器

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt

In [6]:
# 训练参数

# 网络参数
latent_dim = 100
image_size = 64
hidden_size = 64

In [7]:
# 生成器模型
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, hidden_size * 8, 4, 1, 0),
            nn.BatchNorm2d(hidden_size * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(hidden_size * 8, hidden_size * 4, 4, 2, 1),
            nn.BatchNorm2d(hidden_size * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(hidden_size * 4, hidden_size * 2, 4, 2, 1),
            nn.BatchNorm2d(hidden_size * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(hidden_size * 2, hidden_size, 4, 2, 1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(True),
            nn.ConvTranspose2d(hidden_size, 3, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = x.view(x.size(0), latent_dim, 1, 1)
        return self.model(x)

# 判别器模型
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, hidden_size, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden_size, hidden_size * 2, 4, 2, 1),
            nn.BatchNorm2d(hidden_size * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden_size * 2, hidden_size * 4, 4, 2, 1),
            nn.BatchNorm2d(hidden_size * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden_size * 4, hidden_size * 8, 4, 2, 1),
            nn.BatchNorm2d(hidden_size * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden_size * 8, 1, 4, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [17]:
# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()

# 将生成器和判别器移动到GPU（如果可用）
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

# 定义二元交叉熵损失函数和优化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [37]:
num_epochs = 50

In [38]:
from tqdm import tqdm
# 真实图像的标签为1，生成图像的标签为0
real_label = 1
fake_label = 0

# 训练GAN模型
total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        # real_label = torch.ones(batch_size, 1)
        # fake_label = torch.zeros(batch_size, 1)
        # 训练判别器
        discriminator.zero_grad()
        real_images = images.to(device)
        batch_size = real_images.size(0)
        label = torch.full((batch_size,), real_label, device=device,dtype=torch.float)
        
        # 判别器对真实图像的判别结果
        output_real = discriminator(real_images).view(-1)
        loss_D_real = criterion(output_real, label)
        loss_D_real.backward()
        D_x = output_real.mean().item()

        # 生成假图像并计算判别器对假图像的判别结果
        noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
        fake_images = generator(noise)
        label.fill_(fake_label)
        output_fake = discriminator(fake_images.detach()).view(-1)
        loss_D_fake = criterion(output_fake, label)
        loss_D_fake.backward()
        D_G_z1 = output_fake.mean().item()
        
        # 更新判别器的参数
        loss_D = loss_D_real + loss_D_fake
        optimizer_D.step()

        # 训练生成器
        generator.zero_grad()
        label.fill_(real_label)
        output_fake = discriminator(fake_images).view(-1)
        loss_G = criterion(output_fake, label)
        loss_G.backward()
        D_G_z2 = output_fake.mean().item()
        
        # 更新生成器的参数
        optimizer_G.step()

        # 输出训练信息
        if (i + 1) % 322 == 0:
            tqdm.write(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{total_steps}], "
                       f"Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}, "
                       f"D(x): {D_x:.4f}, D(G(z1)): {D_G_z1:.4f}, D(G(z2)): {D_G_z2:.4f}")

    # 每次生成10张图像并保存
    if epoch+1 in [1, 2, 3,5,10,20,30,50]:
        img=[]
        with torch.no_grad():
            for j in range(10):
                # 定义生成器的输入向量维度
                noise_dim = 100
                # 生成固定的随机噪声向量
                fixed_noise = torch.randn(1, noise_dim, 1, 1).to(device)
                fake_images = generator(fixed_noise).detach().cpu()
                img.append(fake_images)
                save_dir_epoch = os.path.join("./img", f"epoch_{epoch+1}")
                os.makedirs(save_dir_epoch, exist_ok=True)
                save_image(img[j], os.path.join(save_dir_epoch, f"generated_image_{j+1}.png"), normalize=True)

            print("save img.")

Epoch 1/50:   0%|          | 0/322 [00:00<?, ?it/s]

Epoch 1/50: 100%|██████████| 322/322 [01:43<00:00,  3.11it/s]


Epoch [1/50], Step [322/322], Loss D: 1.9860, Loss G: 0.3943, D(x): 0.2400, D(G(z1)): 0.0254, D(G(z2)): 0.6879
save img.


Epoch 2/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [2/50], Step [322/322], Loss D: 0.4178, Loss G: 3.0485, D(x): 0.7934, D(G(z1)): 0.1448, D(G(z2)): 0.0566
save img.


Epoch 3/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [3/50], Step [322/322], Loss D: 0.3170, Loss G: 3.0607, D(x): 0.8340, D(G(z1)): 0.1097, D(G(z2)): 0.0554
save img.


Epoch 4/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [4/50], Step [322/322], Loss D: 0.5176, Loss G: 5.5727, D(x): 0.8974, D(G(z1)): 0.3161, D(G(z2)): 0.0050


Epoch 5/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [5/50], Step [322/322], Loss D: 0.5353, Loss G: 2.2474, D(x): 0.7027, D(G(z1)): 0.1367, D(G(z2)): 0.1209
save img.


Epoch 6/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [6/50], Step [322/322], Loss D: 1.5001, Loss G: 2.2976, D(x): 0.2643, D(G(z1)): 0.0082, D(G(z2)): 0.1396


Epoch 7/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [7/50], Step [322/322], Loss D: 0.4938, Loss G: 5.4049, D(x): 0.9458, D(G(z1)): 0.3364, D(G(z2)): 0.0051


Epoch 8/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [8/50], Step [322/322], Loss D: 0.1466, Loss G: 4.4254, D(x): 0.9130, D(G(z1)): 0.0486, D(G(z2)): 0.0169


Epoch 9/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [9/50], Step [322/322], Loss D: 0.6198, Loss G: 2.4150, D(x): 0.6017, D(G(z1)): 0.0257, D(G(z2)): 0.1214


Epoch 10/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [10/50], Step [322/322], Loss D: 0.3263, Loss G: 4.1738, D(x): 0.9013, D(G(z1)): 0.1798, D(G(z2)): 0.0194
save img.


Epoch 11/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [11/50], Step [322/322], Loss D: 0.7313, Loss G: 2.0394, D(x): 0.5500, D(G(z1)): 0.0249, D(G(z2)): 0.1671


Epoch 12/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [12/50], Step [322/322], Loss D: 0.2977, Loss G: 3.7712, D(x): 0.8826, D(G(z1)): 0.1461, D(G(z2)): 0.0306


Epoch 13/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [13/50], Step [322/322], Loss D: 0.5694, Loss G: 3.3831, D(x): 0.8898, D(G(z1)): 0.3349, D(G(z2)): 0.0397


Epoch 14/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [14/50], Step [322/322], Loss D: 0.2883, Loss G: 3.1216, D(x): 0.8294, D(G(z1)): 0.0817, D(G(z2)): 0.0516


Epoch 15/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [15/50], Step [322/322], Loss D: 0.3562, Loss G: 5.4584, D(x): 0.9724, D(G(z1)): 0.2656, D(G(z2)): 0.0057


Epoch 16/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [16/50], Step [322/322], Loss D: 0.4381, Loss G: 3.7667, D(x): 0.8569, D(G(z1)): 0.2284, D(G(z2)): 0.0303


Epoch 17/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [17/50], Step [322/322], Loss D: 0.2089, Loss G: 3.5924, D(x): 0.8836, D(G(z1)): 0.0737, D(G(z2)): 0.0347


Epoch 18/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [18/50], Step [322/322], Loss D: 0.4248, Loss G: 5.5958, D(x): 0.9450, D(G(z1)): 0.2749, D(G(z2)): 0.0054


Epoch 19/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [19/50], Step [322/322], Loss D: 0.5959, Loss G: 2.8464, D(x): 0.6076, D(G(z1)): 0.0152, D(G(z2)): 0.0833


Epoch 20/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [20/50], Step [322/322], Loss D: 0.6578, Loss G: 6.4718, D(x): 0.8909, D(G(z1)): 0.3619, D(G(z2)): 0.0022
save img.


Epoch 21/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [21/50], Step [322/322], Loss D: 0.3218, Loss G: 1.5990, D(x): 0.7754, D(G(z1)): 0.0282, D(G(z2)): 0.2801


Epoch 22/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [22/50], Step [322/322], Loss D: 0.1889, Loss G: 4.1315, D(x): 0.9289, D(G(z1)): 0.0982, D(G(z2)): 0.0231


Epoch 23/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [23/50], Step [322/322], Loss D: 0.2884, Loss G: 2.9204, D(x): 0.8130, D(G(z1)): 0.0493, D(G(z2)): 0.0728


Epoch 24/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [24/50], Step [322/322], Loss D: 0.3599, Loss G: 2.1498, D(x): 0.7522, D(G(z1)): 0.0376, D(G(z2)): 0.1667


Epoch 25/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [25/50], Step [322/322], Loss D: 0.3210, Loss G: 5.3336, D(x): 0.9182, D(G(z1)): 0.1898, D(G(z2)): 0.0074


Epoch 26/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [26/50], Step [322/322], Loss D: 0.2698, Loss G: 3.6152, D(x): 0.8687, D(G(z1)): 0.0941, D(G(z2)): 0.0385


Epoch 27/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [27/50], Step [322/322], Loss D: 0.2524, Loss G: 4.2968, D(x): 0.8263, D(G(z1)): 0.0279, D(G(z2)): 0.0236


Epoch 28/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [28/50], Step [322/322], Loss D: 0.2201, Loss G: 4.3172, D(x): 0.8370, D(G(z1)): 0.0332, D(G(z2)): 0.0293


Epoch 29/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [29/50], Step [322/322], Loss D: 0.3954, Loss G: 2.5578, D(x): 0.7812, D(G(z1)): 0.0812, D(G(z2)): 0.1160


Epoch 30/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [30/50], Step [322/322], Loss D: 0.3071, Loss G: 3.6381, D(x): 0.8715, D(G(z1)): 0.1370, D(G(z2)): 0.0382
save img.


Epoch 31/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [31/50], Step [322/322], Loss D: 0.2330, Loss G: 3.0654, D(x): 0.8619, D(G(z1)): 0.0686, D(G(z2)): 0.0625


Epoch 32/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [32/50], Step [322/322], Loss D: 0.3383, Loss G: 3.2472, D(x): 0.8885, D(G(z1)): 0.1515, D(G(z2)): 0.0553


Epoch 33/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [33/50], Step [322/322], Loss D: 0.5870, Loss G: 0.5470, D(x): 0.6201, D(G(z1)): 0.0095, D(G(z2)): 0.6566


Epoch 34/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [34/50], Step [322/322], Loss D: 0.1330, Loss G: 3.1016, D(x): 0.9421, D(G(z1)): 0.0642, D(G(z2)): 0.0587


Epoch 35/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [35/50], Step [322/322], Loss D: 0.2916, Loss G: 3.0782, D(x): 0.8036, D(G(z1)): 0.0310, D(G(z2)): 0.0720


Epoch 36/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [36/50], Step [322/322], Loss D: 0.3155, Loss G: 4.2631, D(x): 0.7746, D(G(z1)): 0.0103, D(G(z2)): 0.0276


Epoch 37/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [37/50], Step [322/322], Loss D: 0.1602, Loss G: 3.7391, D(x): 0.9327, D(G(z1)): 0.0803, D(G(z2)): 0.0365


Epoch 38/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [38/50], Step [322/322], Loss D: 0.1326, Loss G: 4.6050, D(x): 0.9300, D(G(z1)): 0.0538, D(G(z2)): 0.0163


Epoch 39/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [39/50], Step [322/322], Loss D: 0.3446, Loss G: 4.1164, D(x): 0.8480, D(G(z1)): 0.1395, D(G(z2)): 0.0267


Epoch 40/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [40/50], Step [322/322], Loss D: 0.1504, Loss G: 4.4944, D(x): 0.8993, D(G(z1)): 0.0354, D(G(z2)): 0.0244


Epoch 41/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [41/50], Step [322/322], Loss D: 0.1485, Loss G: 4.9468, D(x): 0.9551, D(G(z1)): 0.0899, D(G(z2)): 0.0122


Epoch 42/50: 100%|██████████| 322/322 [01:29<00:00,  3.61it/s]


Epoch [42/50], Step [322/322], Loss D: 0.3658, Loss G: 4.7751, D(x): 0.8770, D(G(z1)): 0.1816, D(G(z2)): 0.0117


Epoch 43/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [43/50], Step [322/322], Loss D: 0.3460, Loss G: 3.6452, D(x): 0.8288, D(G(z1)): 0.1220, D(G(z2)): 0.0428


Epoch 44/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [44/50], Step [322/322], Loss D: 0.0798, Loss G: 4.4517, D(x): 0.9425, D(G(z1)): 0.0188, D(G(z2)): 0.0184


Epoch 45/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [45/50], Step [322/322], Loss D: 0.3753, Loss G: 2.5834, D(x): 0.7684, D(G(z1)): 0.0612, D(G(z2)): 0.1237


Epoch 46/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [46/50], Step [322/322], Loss D: 0.3508, Loss G: 1.9424, D(x): 0.7357, D(G(z1)): 0.0129, D(G(z2)): 0.2035


Epoch 47/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [47/50], Step [322/322], Loss D: 0.1248, Loss G: 4.2292, D(x): 0.9405, D(G(z1)): 0.0555, D(G(z2)): 0.0232


Epoch 48/50: 100%|██████████| 322/322 [01:28<00:00,  3.64it/s]


Epoch [48/50], Step [322/322], Loss D: 0.1167, Loss G: 4.9058, D(x): 0.9221, D(G(z1)): 0.0284, D(G(z2)): 0.0122


Epoch 49/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]


Epoch [49/50], Step [322/322], Loss D: 0.3628, Loss G: 2.3516, D(x): 0.7919, D(G(z1)): 0.0783, D(G(z2)): 0.1293


Epoch 50/50: 100%|██████████| 322/322 [01:28<00:00,  3.63it/s]

Epoch [50/50], Step [322/322], Loss D: 0.2129, Loss G: 3.4583, D(x): 0.8568, D(G(z1)): 0.0366, D(G(z2)): 0.0621
save img.





In [39]:
torch.save(generator,'./model/generator_3')
torch.save(discriminator,'./model/discriminator_3')

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
from scipy.linalg import sqrtm

加载Inception-v3模型和预训练参数：

In [10]:
inception_v3 = models.inception_v3(pretrained=True, transform_input=False)
inception_v3.eval()



Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stri

定义辅助函数，用于提取特征和计算统计信息：

In [14]:
def extract_features(images, transform):
    features = []
    for image in images:
        image = transform(image).unsqueeze(0)
        feature = inception_v3(image)
        features.append(feature.flatten())
    features = torch.stack(features)
    return features

def calculate_statistics(features):
    mu = torch.mean(features, dim=0)
    features_centered = features - mu
    sigma = torch.matmul(features_centered.t(), features_centered) / (features_centered.size(0) - 1)
    return mu, sigma

定义计算FID的函数：

In [17]:
def calculate_fid(real_images, generated_images):
    # 定义图片预处理转换
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
    ])

    # 提取真实图片和生成图片的特征向量
    real_features = extract_features(real_images, transform)
    generated_features = extract_features(generated_images, transform)

    # 计算真实图片和生成图片的均值和协方差
    mu_real, sigma_real = calculate_statistics(real_features)
    mu_generated, sigma_generated = calculate_statistics(generated_features)

    # 计算FID
    # mean squared difference between means
    mse = torch.sum((mu_real - mu_generated)**2)
    # trace covariance product
    trace_cov = torch.trace(torch.from_numpy(sigma_real.detach().numpy()) + torch.from_numpy(sigma_generated.detach().numpy()) - 2*sqrtm(torch.mm(torch.from_numpy(sigma_real.detach().numpy()), torch.from_numpy(sigma_real.detach().numpy()))))

    fid_score = mse + trace_cov
    return fid_score.item()

计算图片集合的

In [28]:
import os
from PIL import Image
import numpy as np

# 设置目标图像大小
target_size = (64, 64)

# 定义文件夹路径
ge_path = './img/epoch_50/'

def get_img(folder_path, max_images=10):
    # 初始化图片列表和计数器
    images = []
    count = 0

    # 遍历文件夹中的图片文件
    for filename in os.listdir(folder_path):
        # 检查图片数量是否已达到上限
        if count >= max_images:
            break

        # 构建图片文件的完整路径
        image_path = os.path.join(folder_path, filename)
        
        # 使用PIL库加载图片
        image = Image.open(image_path).convert('RGB')
        
        # 调整图像大小
        image = image.resize(target_size)
        
        # 将加载的图片添加到列表中
        images.append(image)
        count += 1

    # 将图片列表转换为numpy数组
    images_array = np.array(images)

    # 关闭所有图片对象
    for image in images:
        image.close()
    
    return images_array 

re_images = get_img("./all-dogs/")
ge_images = get_img(ge_path)

# 使用calculate_fid函数计算FID分数
fid_score = calculate_fid(re_images, ge_images)

# 打印FID分数
print('FID Score:', fid_score)


TypeError: Unexpected type <class 'numpy.ndarray'>

计算单个图片的

In [27]:
def extract_features(image, transform):
    image = transform(image).unsqueeze(0)
    feature = inception_v3(image)
    return feature.flatten()

def calculate_statistics(feature_vector):
    mu = torch.mean(feature_vector, dim=0)
    sigma = torch.matmul(feature_vector.T, feature_vector) / feature_vector.size(0)
    return mu, sigma

def calculate_fid(real_image, generated_image):
    # 定义图片预处理转换
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
    ])

    # 提取真实图片和生成图片的特征向量
    real_feature = extract_features(real_image, transform)
    generated_feature = extract_features(generated_image, transform)

    # 计算真实图片和生成图片的均值和协方差
    mu_real, sigma_real = calculate_statistics(real_feature.unsqueeze(0))
    mu_generated, sigma_generated = calculate_statistics(generated_feature.unsqueeze(0))

    # 计算FID
    # mean squared difference between means
    mse = torch.sum((mu_real - mu_generated)**2)
    # trace covariance product
    trace_cov = torch.trace(torch.from_numpy(sigma_real.detach().numpy()) + torch.from_numpy(sigma_generated.detach().numpy()) - 2*sqrtm(torch.mm(torch.from_numpy(sigma_real.detach().numpy()), torch.from_numpy(sigma_real.detach().numpy()))))

    fid_score = mse + trace_cov
    return fid_score.item()

# 示例
real_image = Image.open("./all-dogs/n02085620_473.jpg")
generated_image = Image.open("./img/300/generated_image_53.png")
fid_score = calculate_fid(real_image, generated_image)
print("FID score:", fid_score)

FID score: (2030.87744140625-2.2163403034210205j)


生成300张图片

In [18]:
generator=torch.load('./model/generator_3')
discriminator=torch.load('./model/discriminator_3')

In [20]:
img=[]
with torch.no_grad():
    for j in range(300):
        # 定义生成器的输入向量维度
        noise_dim = 100
        # 生成固定的随机噪声向量
        fixed_noise = torch.randn(1, noise_dim, 1, 1).to(device)
        fake_images = generator(fixed_noise).detach().cpu()
        img.append(fake_images)
        save_dir_epoch = os.path.join("./img/300")
        os.makedirs(save_dir_epoch, exist_ok=True)
        save_image(img[j], os.path.join(save_dir_epoch, f"generated_image_{j+1}.png"), normalize=True)

    print("save img.")

save img.
