In [None]:
import torch
import torch.nn as nn

# Define the residual dense block (a single RDB, not a full residual-in-residual RRDB)
class ResidualDenseBlock(nn.Module):
    """Five-convolution dense block with a scaled residual connection.

    Every convolution is 3x3 with padding 1, so the spatial size is preserved.
    The first four convolutions each emit ``growth_channel`` feature maps and
    are followed by ReLU; each one receives the block input concatenated (on
    the channel axis) with the outputs of all earlier convolutions. The fifth
    convolution maps the full concatenation back to ``num_filters`` channels
    without an activation; its output is scaled and added to the block input.

    Args:
        num_filters: Channel count of the block's input and output.
        growth_channel: Channel count produced by each of the first four
            convolutions.
        res_scale: Factor applied to the fifth convolution's output before it
            is added to the input. Defaults to 0.2, the value that used to be
            hard-coded.
    """

    def __init__(self, num_filters=64, growth_channel=32, res_scale=0.2):
        super(ResidualDenseBlock, self).__init__()
        self.res_scale = res_scale
        # conv_k consumes the input plus the (k - 1) earlier growth outputs.
        self.conv1 = nn.Conv2d(num_filters, growth_channel, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(num_filters + growth_channel, growth_channel, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(num_filters + 2 * growth_channel, growth_channel, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(num_filters + 3 * growth_channel, growth_channel, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(num_filters + 4 * growth_channel, num_filters, kernel_size=3, padding=1)

    def forward(self, x):
        """Apply the dense block.

        Args:
            x: Tensor of shape ``(N, num_filters, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # `features` grows by one tensor per convolution: [x, x1, x2, x3, x4].
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            dense_input = features[0] if len(features) == 1 else torch.cat(features, dim=1)
            features.append(torch.relu(conv(dense_input)))
        residual = self.conv5(torch.cat(features, dim=1))
        # Scaling the residual branch down stabilises training.
        return residual * self.res_scale + x

# 定义生成器
class Generator(nn.Module):
    """Image-to-image generator built from a trunk of residual dense blocks.

    Pipeline: 3x3 convolution lifting the 3-channel input to ``num_filters``
    channels -> ``num_blocks`` ``ResidualDenseBlock`` modules -> 3x3
    convolution whose output is added to the initial features (long skip
    connection) -> optional x2 upsampling stages -> 3x3 convolution back to
    3 channels.

    Args:
        num_filters: Width of the feature maps inside the network.
        num_blocks: Number of residual dense blocks in the trunk.
        scale_factor: Spatial upscaling factor; must be a positive power of
            two. The default of 1 adds no upsampling layers, so the output
            has the same height and width as the input and the module has
            exactly the parameters it had before this argument existed.

    Raises:
        ValueError: If ``scale_factor`` is not a positive power-of-two integer.
    """

    def __init__(self, num_filters=64, num_blocks=23, scale_factor=1):
        super(Generator, self).__init__()
        if (
            not isinstance(scale_factor, int)
            or scale_factor < 1
            or scale_factor & (scale_factor - 1) != 0
        ):
            raise ValueError(
                f"scale_factor must be a positive power of two, got {scale_factor!r}"
            )

        self.initial = nn.Conv2d(3, num_filters, kernel_size=3, padding=1)
        self.rrdb_blocks = nn.Sequential(*[ResidualDenseBlock(num_filters) for _ in range(num_blocks)])
        self.conv_hr = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)

        # One (nearest x2 -> conv -> LeakyReLU) stage per factor of two.
        # For scale_factor == 1 this is an empty Sequential, i.e. an identity
        # that contributes no parameters and no state_dict entries.
        upsample_layers = []
        for _ in range(scale_factor.bit_length() - 1):
            upsample_layers.append(nn.Upsample(scale_factor=2, mode="nearest"))
            upsample_layers.append(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1))
            upsample_layers.append(nn.LeakyReLU(0.2))
        self.upsample = nn.Sequential(*upsample_layers)

        self.conv_last = nn.Conv2d(num_filters, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Generate an image from ``x``.

        Args:
            x: Tensor of shape ``(N, 3, H, W)``.

        Returns:
            Tensor of shape ``(N, 3, H * scale_factor, W * scale_factor)``.
        """
        initial = self.initial(x)
        trunk = self.rrdb_blocks(initial)
        features = self.conv_hr(trunk) + initial  # long residual connection
        features = self.upsample(features)
        return self.conv_last(features)


In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that maps an RGB image to raw logits.

    Four downsampling stages (4x4 convolution with stride 2 and padding 1,
    then BatchNorm, then LeakyReLU(0.2)) widen the channels
    3 -> 64 -> 128 -> 256 -> 512. A final unpadded 4x4 convolution reduces
    them to a single logit channel; no sigmoid is applied.

    For a 64x64 input the result has shape ``(N, 1, 1, 1)``, i.e. one logit
    per image. Larger inputs produce a spatial grid of logits instead.
    """

    def __init__(self):
        super().__init__()
        # Channel plan: RGB input followed by the width of each stage.
        channel_plan = [3] + [64 * (2 ** stage) for stage in range(4)]
        stack = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            stack += [
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        # Logit head: collapses the widest feature map to one channel.
        stack.append(nn.Conv2d(channel_plan[-1], 1, 4, padding=0))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the discriminator logits for the image batch ``x``."""
        logits = self.model(x)
        return logits


In [None]:
import torchvision.models as models

# 使用预训练的VGG模型计算感知损失
class VGGFeatureExtractor(nn.Module):
    """Frozen, ImageNet-pretrained VGG19 prefix used as a feature extractor.

    Keeps the first 18 modules of ``vgg19.features`` (presumably everything up
    to and including relu3_4 -- confirm against torchvision's VGG19 layout).
    The weights are frozen (``requires_grad = False``) because the extractor
    only provides targets/features for a perceptual loss; gradients still flow
    through it to its *input*, which is what the generator needs.

    NOTE(review): torchvision's pretrained VGG weights are normally used with
    ImageNet mean/std-normalised inputs; this module applies no normalisation,
    so callers must handle it if required.
    """

    def __init__(self):
        super(VGGFeatureExtractor, self).__init__()
        # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`;
        # IMAGENET1K_V1 is the checkpoint that `pretrained=True` selects.
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is not None:
            vgg19 = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        else:
            # Older torchvision releases only understand the legacy flag.
            vgg19 = models.vgg19(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(vgg19.features)[:18])
        # Freeze the weights so backward passes do not compute or accumulate
        # gradients for parameters that are never optimised.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the VGG feature maps for the image batch ``x``."""
        return self.feature_extractor(x)


In [None]:
# Instantiate the generator, the discriminator and the perceptual-loss network.
generator = Generator()
discriminator = Discriminator()
feature_extractor = VGGFeatureExtractor().eval()  # VGG is only used for the perceptual loss
# The extractor is never optimised: freeze it so generator backward passes do
# not compute (and endlessly accumulate) gradients for the VGG weights.
feature_extractor.requires_grad_(False)

# Loss functions
pixel_loss = nn.MSELoss()
adversarial_loss = nn.BCEWithLogitsLoss()  # discriminator outputs raw logits

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

# NOTE(review): `num_epochs` and `dataloader` are not defined in the visible
# cells; they must exist in the kernel before this cell is run.
# NOTE(review): `Generator()` as constructed here keeps the input resolution,
# so `lr_img` must already have the same spatial size as `hr_img` for the
# pixel and perceptual losses below -- confirm against the dataset.
for epoch in range(num_epochs):
    for lr_img, hr_img in dataloader:
        # ---------------------
        # Train the discriminator
        # ---------------------
        optimizer_D.zero_grad()
        real_output = discriminator(hr_img)
        fake_img = generator(lr_img)
        # Detach so the discriminator update does not backpropagate into the generator.
        fake_output = discriminator(fake_img.detach())

        d_loss_real = adversarial_loss(real_output, torch.ones_like(real_output))
        d_loss_fake = adversarial_loss(fake_output, torch.zeros_like(fake_output))
        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        optimizer_D.step()

        # ---------------------
        # Train the generator
        # ---------------------
        optimizer_G.zero_grad()
        # Re-score the (non-detached) fake image with the freshly updated discriminator.
        fake_output = discriminator(fake_img)

        # Pixel loss
        pix_loss = pixel_loss(fake_img, hr_img)

        # Perceptual loss: the target features are constants, so no graph is needed for them.
        with torch.no_grad():
            real_features = feature_extractor(hr_img)
        fake_features = feature_extractor(fake_img)
        perceptual_loss = pixel_loss(fake_features, real_features)

        # Adversarial loss: the generator wants its fakes classified as real.
        adv_loss = adversarial_loss(fake_output, torch.ones_like(fake_output))

        # Weighted total generator loss
        g_loss = pix_loss + 0.006 * perceptual_loss + 0.001 * adv_loss
        g_loss.backward()
        optimizer_G.step()

    # Report the losses of the last batch; epochs are displayed 1-based so the
    # final line reads [num_epochs/num_epochs].
    print(f"Epoch [{epoch + 1}/{num_epochs}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
