In [3]:
import os
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import torchvision.transforms as transforms
from torchvision import datasets
from torchvision.models import inception_v3
from torch.utils.data import DataLoader, ConcatDataset
import torch.nn.functional as F
import numpy as np
import random
import numpy as np
import torch
import torch.nn.utils.spectral_norm as spectral_norm 
def set_random_seed(seed):
    """Seed every random number generator this notebook relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's CPU generator and PyTorch's CUDA
    generators. cuDNN is then switched to deterministic mode with
    auto-tuning disabled.

    Args:
        seed: Value forwarded unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        # Also seed every visible GPU when more than one is in use.
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Deterministic cuDNN kernels; this may reduce performance.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Set the random seed once, at the very start of the notebook
set_random_seed(42)

In [4]:
# Self-attention layer
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a 4-D feature map.

    Queries and keys are produced by 1x1 convolutions that reduce the channel
    count to ``in_dim // 8``; values keep all ``in_dim`` channels. The attended
    features are scaled by the learnable scalar ``gamma`` and added back to the
    input. ``gamma`` starts at zero, so a freshly built layer returns its input
    unchanged.

    Args:
        in_dim: Number of channels of the incoming feature map.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.channel_in = in_dim
        reduced_dim = in_dim // 8

        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        # Residual weight of the attention branch.
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply attention to ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor with four dimensions; the last two are flattened into a
                single axis of positions.
        """
        batch, channels, rows, cols = x.size()
        n_positions = rows * cols

        queries = self.query_conv(x).view(batch, -1, n_positions).transpose(1, 2)
        keys = self.key_conv(x).view(batch, -1, n_positions)
        # Row i holds the weights position i assigns to every position.
        weights = self.softmax(torch.bmm(queries, keys))

        values = self.value_conv(x).view(batch, -1, n_positions)
        attended = torch.bmm(values, weights.transpose(1, 2))
        attended = attended.view(batch, channels, rows, cols)

        return self.gamma * attended + x

# Conditional batch normalization
class ConditionalBatchNorm2d(nn.Module):
    """Batch normalization whose scale and shift are looked up per class.

    The input is normalized by a non-affine ``BatchNorm2d``. A per-sample scale
    (gamma) and shift (beta) are then read from one embedding table holding
    ``2 * num_features`` values per class: the first half is gamma, the second
    half is beta.

    Args:
        num_features: Number of channels to normalize.
        num_classes: Number of rows in the embedding table.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.embed = nn.Embedding(num_classes, 2 * num_features)
        # Initialize gamma to 1 and beta to 0.
        table = self.embed.weight.data
        table[:, :num_features].fill_(1)
        table[:, num_features:].zero_()

    def forward(self, x, y):
        """Normalize ``x`` and modulate it with the gamma/beta selected by ``y``.

        Args:
            x: Feature map passed to ``BatchNorm2d``.
            y: Indices into the embedding table, one per sample.
        """
        normalized = self.bn(x)
        scale, shift = self.embed(y).chunk(2, 1)
        # Add two trailing singleton axes so the terms broadcast over space.
        return scale[:, :, None, None] * normalized + shift[:, :, None, None]

# Generator residual block
class ResBlockG(nn.Module):
    """Class-conditional residual block that doubles the spatial resolution.

    Main path: conditional BN -> ReLU -> 4x4 stride-2 transposed conv ->
    conditional BN -> ReLU -> 3x3 conv. The shortcut is a 4x4 stride-2
    transposed conv followed by batch norm, so both paths produce a map with
    ``out_channels`` channels at twice the input height and width.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
        num_classes: Number of classes for the conditional batch norms.
    """

    def __init__(self, in_channels, out_channels, num_classes):
        super(ResBlockG, self).__init__()
        self.bn1 = ConditionalBatchNorm2d(in_channels, num_classes)
        self.activation = nn.ReLU(inplace=True)
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1)
        self.bn2 = ConditionalBatchNorm2d(out_channels, num_classes)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)

        # The main path always upsamples 2x, so the shortcut has to upsample
        # as well, even when the channel counts match. It is therefore built
        # unconditionally (the old guard was a tautology).
        self.shortcut = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x, labels):
        """Return the upsampled features for ``x`` conditioned on ``labels``.

        Args:
            x: Feature map with ``in_channels`` channels.
            labels: Class indices forwarded to both conditional batch norms.
        """
        out = self.bn1(x, labels)
        out = self.activation(out)
        out = self.conv1(out)
        out = self.bn2(out, labels)
        out = self.activation(out)
        out = self.conv2(out)
        return out + self.shortcut(x)

# Discriminator residual block
class ResBlockD(nn.Module):
    """Spectrally-normalized residual block that halves the spatial resolution.

    Main path: 3x3 conv -> LeakyReLU(0.2) -> 4x4 stride-2 conv. The shortcut
    must downsample by the same factor so the two paths can be added:

    * ``in_channels != out_channels``: a spectrally-normalized 4x4 stride-2
      conv that also changes the channel count.
    * ``in_channels == out_channels``: a parameter-free 2x2 average pool.
      An identity shortcut here would keep the full resolution and make the
      residual addition fail with a shape mismatch.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
    """

    def __init__(self, in_channels, out_channels):
        super(ResBlockD, self).__init__()
        self.conv1 = spectral_norm(nn.Conv2d(in_channels, out_channels, 3, 1, 1))
        self.activation = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = spectral_norm(nn.Conv2d(out_channels, out_channels, 4, 2, 1))

        if in_channels != out_channels:
            self.shortcut = spectral_norm(nn.Conv2d(in_channels, out_channels, 4, 2, 1))
        else:
            # AvgPool2d(2) yields floor(H / 2), the same size as conv2's output,
            # and adds no parameters, so existing state_dicts stay loadable.
            self.shortcut = nn.AvgPool2d(2)

    def forward(self, x):
        """Return the downsampled residual features for ``x``.

        Args:
            x: Feature map with ``in_channels`` channels.
        """
        out = self.conv1(x)
        out = self.activation(out)
        out = self.conv2(out)
        return out + self.shortcut(x)

# Cross-attention mechanism
class CrossAttention(nn.Module):
    """Inject a class embedding into an image feature map via attention.

    Queries come from the image features (1x1 conv down to ``in_dim // 8``
    channels); the single key and value per sample come from linear
    projections of the class embedding. The attended result is scaled by the
    learnable scalar ``gamma`` (initialized to zero) and added to the input.

    Args:
        in_dim: Channel count of the feature map and size of the class
            embedding (they must match so the result can be added back).
    """

    def __init__(self, in_dim):
        super().__init__()
        reduced_dim = in_dim // 8
        # Image features -> queries.
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, 1)
        # Class embedding -> key and value.
        self.key_linear = nn.Linear(in_dim, reduced_dim)
        self.value_linear = nn.Linear(in_dim, in_dim)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, class_emb):
        """Fuse ``class_emb`` into ``x`` and return a tensor shaped like ``x``.

        Args:
            x: Feature map of shape (B, C, H, W).
            class_emb: Class embedding of shape (B, C).
        """
        batch, channels, height, width = x.size()
        queries = self.query_conv(x).view(batch, -1, height * width).transpose(1, 2)  # (B, HW, C')
        keys = self.key_linear(class_emb)[:, None, :]      # (B, 1, C')
        values = self.value_linear(class_emb)[:, None, :]  # (B, 1, C)

        # NOTE(review): there is exactly one key per sample, so this softmax
        # runs over an axis of size 1 and every weight is 1.0. Each position
        # thus receives the plain value vector, and query_conv / key_linear do
        # not affect the output. Confirm this is intended before changing it;
        # saved checkpoints were presumably trained with this behaviour.
        scores = torch.bmm(queries, keys.transpose(1, 2))  # (B, HW, 1)
        weights = self.softmax(scores)

        attended = torch.bmm(weights, values)  # (B, HW, C)
        attended = attended.transpose(1, 2).view(batch, channels, height, width)
        return self.gamma * attended + x

# Generator
class Generator(nn.Module):
    """Class-conditional generator producing 32x32 images.

    A linear layer maps the latent vector to a 512-channel 2x2 map. Four
    ``ResBlockG`` stages then upsample it to 32x32 while reducing the channel
    count to 64. The result goes through self-attention and cross-attention
    with a class embedding, is fused with the upsampled second-stage features,
    and is projected to ``channels`` output channels with a ``tanh``.

    Args:
        latent_dim: Size of the input noise vector.
        num_classes: Number of class labels.
        channels: Number of channels in the generated image.
    """

    def __init__(self, latent_dim, num_classes, channels):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Output resolution is fixed at 32; four 2x upsampling stages mean the
        # starting map is 32 // 16 = 2 pixels per side.
        img_size = 32
        self.init_size = img_size // 16
        self.l1 = nn.Linear(latent_dim, 512 * self.init_size ** 2)

        stage_widths = (512, 512, 256, 128, 64)
        self.res_blocks = nn.ModuleList([
            ResBlockG(width_in, width_out, num_classes)
            for width_in, width_out in zip(stage_widths[:-1], stage_widths[1:])
        ])

        self.attention = SelfAttention(64)

        # Class embedding for cross-attention (matches the 64 feature channels).
        self.class_embed_ca = nn.Embedding(num_classes, 64)
        self.cross_attention = CrossAttention(64)

        # 1x1 conv that fuses the upsampled 256-channel mid-level features
        # with the 64-channel final features.
        self.fusion_conv = nn.Conv2d(256 + 64, 64, 1, 1, 0)

        self.bn = ConditionalBatchNorm2d(64, num_classes)
        self.activation = nn.ReLU(inplace=True)
        self.conv_out = nn.Conv2d(64, channels, 3, 1, 1)

    def forward(self, noise, labels):
        """Generate images for ``noise`` conditioned on ``labels``.

        Args:
            noise: Latent vectors fed to the first linear layer.
            labels: Class indices used by the conditional norms and the
                cross-attention embedding.

        Returns:
            The ``tanh`` of the final convolution's output.
        """
        hidden = self.l1(noise)
        feat = hidden.view(hidden.size(0), 512, self.init_size, self.init_size)

        # Multi-level features: 512@4x4, 256@8x8, 128@16x16, 64@32x32.
        stage_feats = []
        for block in self.res_blocks:
            feat = block(feat, labels)
            stage_feats.append(feat)
        mid_feat = stage_feats[1]  # 256 channels, 8x8

        # Self-attention on the final-resolution features.
        feat = self.attention(feat)

        # Cross-attention injects the class embedding (B, 64).
        feat = self.cross_attention(feat, self.class_embed_ca(labels))

        # Feature fusion: upsample the 8x8 mid-level map to the current size,
        # concatenate along channels (256 + 64) and squeeze back to 64.
        mid_upsampled = F.interpolate(mid_feat, size=feat.shape[2:], mode='nearest')
        feat = self.fusion_conv(torch.cat([mid_upsampled, feat], dim=1))

        feat = self.activation(self.bn(feat, labels))
        return torch.tanh(self.conv_out(feat))

# Discriminator
class Discriminator(nn.Module):
    """Class-conditional discriminator with a projection-style class term.

    A spectrally-normalized stem conv is followed by three ``ResBlockD``
    stages (64 -> 128 -> 256 -> 512 channels), self-attention and global
    average pooling. The returned score is the sum of a linear head on the
    pooled features and the inner product of those features with a learned
    class embedding.

    Args:
        num_classes: Number of class labels.
        channels: Number of channels of the input image.
    """

    def __init__(self, num_classes, channels):
        super().__init__()
        self.num_classes = num_classes

        stem_conv = spectral_norm(nn.Conv2d(channels, 64, 3, 1, 1))
        self.initial = nn.Sequential(stem_conv, nn.LeakyReLU(0.2, inplace=True))

        stage_widths = (64, 128, 256, 512)
        self.res_blocks = nn.ModuleList([
            ResBlockD(width_in, width_out)
            for width_in, width_out in zip(stage_widths[:-1], stage_widths[1:])
        ])

        self.attention = SelfAttention(512)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = spectral_norm(nn.Linear(512, 1))
        self.embed = nn.Embedding(num_classes, 512)

    def forward(self, img, labels):
        """Score ``img`` given its class ``labels``.

        Args:
            img: Batch of images with ``channels`` channels.
            labels: Class indices, one per image.

        Returns:
            Tensor with one score per image (shape ``(B, 1)``).
        """
        feat = self.initial(img)
        for block in self.res_blocks:
            feat = block(feat)
        feat = self.attention(feat)

        # Global average pool, then drop the two singleton spatial axes.
        pooled = self.avg_pool(feat)
        pooled = pooled.view(pooled.size(0), -1)

        unconditional_score = self.fc(pooled)
        class_score = (pooled * self.embed(labels)).sum(dim=1, keepdim=True)
        return unconditional_score + class_score


In [3]:
import torch
from newtest import evaluate_generator  # import the evaluation function

# Parameters
img_size = 32
latent_dim = 128
channels = 3
num_classes = 100
generator_path = 'best_BIGGANmodel.pth'

# Set random seed for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Announce the run before the long-running evaluation starts, so the banner
# appears ahead of any download/progress output instead of after it.
print("Evaluating Generator...")
# NOTE(review): reports CUDA availability only; the device actually used is
# chosen inside evaluate_generator.
print(f"Running on device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

# Evaluate the generator
fid, is_mean, is_std, intra_fid = evaluate_generator(
    generator_path=generator_path,
    generator_class=Generator,  # Generator class defined earlier in this notebook
    dataset_root='images',
    img_size=img_size,
    latent_dim=latent_dim,
    num_classes=num_classes,
    channels=channels,
    #model_type='biggan'  # Specify model type
)

# Print results
print(f"Loaded Generator: {Generator.__name__}")
print(f"FID Score: {fid:.4f}")
print(f"Inception Score: {is_mean:.4f} ± {is_std:.4f}")
print(f"Intra-FID: {intra_fid:.4f}")


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to images/cifar-100-python.tar.gz


100%|███████████████████████████████████████████████████████████████████████████| 169001437/169001437 [00:11<00:00, 14776065.91it/s]


Extracting images/cifar-100-python.tar.gz to images




Evaluating Generator...
Loaded Generator: Generator
Running on device: CUDA
FID Score: 18.8052
Inception Score: 5.7888 ± 0.6531
Intra-FID: 65.3762


In [3]:
import torch
from newtest import evaluate_generator  # import the evaluation function

# Parameters
img_size = 32
latent_dim = 128
channels = 3
num_classes = 100
generator_path = 'two_model.pth'

# Set random seed for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Announce the run before the long-running evaluation starts, so the banner
# appears ahead of any download/progress output instead of after it.
print("Evaluating Generator...")
# NOTE(review): reports CUDA availability only; the device actually used is
# chosen inside evaluate_generator.
print(f"Running on device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

# Evaluate the generator
fid, is_mean, is_std, intra_fid = evaluate_generator(
    generator_path=generator_path,
    generator_class=Generator,  # Generator class defined earlier in this notebook
    dataset_root='images',
    img_size=img_size,
    latent_dim=latent_dim,
    num_classes=num_classes,
    channels=channels,
    #model_type='biggan'  # Specify model type
)

# Print results
print(f"Loaded Generator: {Generator.__name__}")
print(f"FID Score: {fid:.4f}")
print(f"Inception Score: {is_mean:.4f} ± {is_std:.4f}")
print(f"Intra-FID: {intra_fid:.4f}")


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to images/cifar-100-python.tar.gz


100%|███████████████████████████████████████████████████████████████████████████| 169001437/169001437 [00:11<00:00, 15225203.08it/s]


Extracting images/cifar-100-python.tar.gz to images




Evaluating Generator...
Loaded Generator: Generator
Running on device: CUDA
FID Score: 15.8565
Inception Score: 6.0625 ± 0.7847
Intra-FID: 51.2626


In [5]:
import torch
from newtest import evaluate_generator  # import the evaluation function

# Parameters
img_size = 32
latent_dim = 128
channels = 3
num_classes = 100
generator_path = 'three_model.pth'

# Set random seed for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Announce the run before the long-running evaluation starts, so the banner
# appears ahead of any download/progress output instead of after it.
print("Evaluating Generator...")
# NOTE(review): reports CUDA availability only; the device actually used is
# chosen inside evaluate_generator.
print(f"Running on device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

# Evaluate the generator
fid, is_mean, is_std, intra_fid = evaluate_generator(
    generator_path=generator_path,
    generator_class=Generator,  # Generator class defined earlier in this notebook
    dataset_root='images',
    img_size=img_size,
    latent_dim=latent_dim,
    num_classes=num_classes,
    channels=channels,
    #model_type='biggan'  # Specify model type
)

# Print results
print(f"Loaded Generator: {Generator.__name__}")
print(f"FID Score: {fid:.4f}")
print(f"Inception Score: {is_mean:.4f} ± {is_std:.4f}")
print(f"Intra-FID: {intra_fid:.4f}")


Files already downloaded and verified




Evaluating Generator...
Loaded Generator: Generator
Running on device: CUDA
FID Score: 15.0727
Inception Score: 5.9629 ± 0.8792
Intra-FID: 51.5842
