In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import vgg16
from PIL import Image
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from lpips import LPIPS
import random
import numpy as np
import cv2

# 设置随机种子以确保实验的可重复性
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# 双注意力模块
class DualAttentionModule(nn.Module):
    def __init__(self, in_channels):
        super(DualAttentionModule, self).__init__()
        # 初始化卷积和激活层
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.initial_relu = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        # 像素注意力机制的卷积、激活和批归一化层
        self.pa_conv1 = nn.Conv2d(in_channels, max(1, in_channels // 8), kernel_size=1)
        self.pa_relu1 = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        self.pa_bn1 = nn.BatchNorm2d(max(1, in_channels // 8))
        self.pa_conv2 = nn.Conv2d(max(1, in_channels // 8), 1, kernel_size=1)
        self.pa_sigmoid = nn.Sigmoid()
        # 通道注意力机制的全局平均池化和卷积、激活层
        self.ca_gap = nn.AdaptiveAvgPool2d(1)
        self.ca_conv1 = nn.Conv2d(in_channels, max(1, in_channels // 8), kernel_size=1)
        self.ca_relu1 = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        self.ca_conv2 = nn.Conv2d(max(1, in_channels // 8), in_channels, kernel_size=1)
        self.ca_sigmoid = nn.Sigmoid()
        # 特征融合的卷积和激活层
        self.fe_conv = nn.Conv2d(in_channels * 2, in_channels, kernel_size=3, padding=1)
        self.fe_tanh = nn.Tanh()

    def forward(self, x):
        # 初始卷积和激活
        x2 = self.initial_conv(x)
        x2 = self.initial_relu(x2)
        # 像素注意力
        pa = self.pa_conv1(x2)
        pa = self.pa_relu1(pa)
        pa = self.pa_bn1(pa)
        pa = self.pa_conv2(pa)
        pa = self.pa_sigmoid(pa)
        pa_out = x2 * pa
        # 通道注意力
        ca = self.ca_gap(x2)
        ca = self.ca_conv1(ca)
        ca = self.ca_relu1(ca)
        ca = self.ca_conv2(ca)
        ca = self.ca_sigmoid(ca)
        ca_out = x2 * ca
        # 特征融合
        concat = torch.cat([pa_out, ca_out], dim=1)
        fe = self.fe_conv(concat)
        fe = self.fe_tanh(fe)
        out = fe + x + x + x2  # 特征增强后的输出
        return out

# 特征增强模块
class FeatureEnhancementModule(nn.Module):
    def __init__(self, in_channels):
        super(FeatureEnhancementModule, self).__init__()

        # 定义卷积块
        def conv_block(in_channels, out_channels, pool=True):
            layers = [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope']),
                nn.BatchNorm2d(out_channels)
            ]
            if pool:
                layers.append(nn.MaxPool2d(2))
            return nn.Sequential(*layers)

        # 定义膨胀卷积块
        def dilated_conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=2, dilation=2),
                nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope']),
                nn.BatchNorm2d(out_channels)
            )

        # 细节补充分支（FD）
        self.fd_conv1 = conv_block(in_channels, in_channels, pool=False)
        self.fd_conv2 = conv_block(in_channels, in_channels, pool=False)
        self.fd_downsample1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)
        self.fd_conv3 = conv_block(in_channels, in_channels, pool=False)
        self.fd_conv4 = conv_block(in_channels, in_channels, pool=False)
        self.fd_downsample2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)
        self.fd_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.fd_relu5 = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        self.fd_bn5 = nn.BatchNorm2d(in_channels)
        self.fd_dilated_conv1 = dilated_conv_block(in_channels, in_channels)
        self.fd_dilated_conv2 = dilated_conv_block(in_channels, in_channels)
        self.fd_upsample1 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.fd_upsample2 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.fd_tanh = nn.Tanh()

        # 颜色校正分支（FC）
        self.fc_conv1 = conv_block(in_channels, in_channels, pool=False)
        self.fc_conv2 = conv_block(in_channels, in_channels, pool=False)
        self.fc_downsample1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)
        self.fc_conv3 = conv_block(in_channels, in_channels, pool=False)
        self.fc_conv4 = conv_block(in_channels, in_channels, pool=False)
        self.fc_downsample2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)
        self.fc_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.fc_relu5 = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        self.fc_bn5 = nn.BatchNorm2d(in_channels)
        self.fc_dilated_conv1 = dilated_conv_block(in_channels, in_channels)
        self.fc_dilated_conv2 = dilated_conv_block(in_channels, in_channels)
        self.fc_upsample1 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.fc_upsample2 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.fc_tanh = nn.Tanh()
        self.fc_sigmoid = nn.Sigmoid()

    def forward(self, x):
        # 细节补充分支（FD）
        fd = self.fd_conv1(x)
        fd = self.fd_conv2(fd)
        fd = self.fd_downsample1(fd)
        fd = self.fd_conv3(fd)
        fd = self.fd_conv4(fd)
        fd = self.fd_downsample2(fd)
        fd = self.fd_conv5(fd)
        fd = self.fd_relu5(fd)
        fd = self.fd_bn5(fd)

        # 对原图像进行两次下采样
        x_downsample1 = F.avg_pool2d(x, kernel_size=2, stride=2)
        x_downsample2 = F.avg_pool2d(x_downsample1, kernel_size=2, stride=2)

        # 与两次下采样的原图像相乘
        fd = fd * x_downsample2

        # 上采样
        fd = self.fd_upsample1(fd)
        fd = self.fd_dilated_conv1(fd)
        fd = self.fd_dilated_conv2(fd)

        # 与一次下采样的原图像相乘
        fd = fd * x_downsample1

        # 上采样
        fd = self.fd_upsample2(fd)
        fd = self.fd_dilated_conv1(fd)
        fd = self.fd_dilated_conv2(fd)

        # 与原图像相乘
        fd_out = fd * x
        fd_out = self.fd_tanh(fd_out)

        # 颜色校正分支（FC）
        fc = self.fc_conv1(x)
        fc = self.fc_conv2(fc)
        fc = self.fc_downsample1(fc)
        fc = self.fc_conv3(fc)
        fc = self.fc_conv4(fc)
        fc = self.fc_downsample2(fc)
        fc = self.fc_conv5(fc)
        fc = self.fc_relu5(fc)
        fc = self.fc_bn5(fc)

        # 与两次下采样的原图像相乘
        fc = fc * x_downsample2

        # 上采样
        fc = self.fc_upsample1(fc)
        fc = self.fc_dilated_conv1(fc)
        fc = self.fc_dilated_conv2(fc)

        # 与一次下采样的原图像相乘
        fc = fc * x_downsample1

        # 上采样
        fc = self.fc_upsample2(fc)
        fc = self.fc_dilated_conv1(fc)
        fc = self.fc_dilated_conv2(fc)

        # 与原图像相乘
        fc_out = fc * x
        fc_out = self.fc_tanh(fc_out)
        fc_out = self.fc_sigmoid(fc_out)

        # 最终融合
        out = fd_out + fc_out
        return out

# 高动态范围（HDR）网络
class HDRNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(HDRNet, self).__init__()
        self.dam = DualAttentionModule(in_channels)  # 双注意力模块
        self.fem = FeatureEnhancementModule(in_channels)  # 特征增强模块
        self.final_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # 最终卷积层

    def forward(self, x1, x2):
        x1_dam = self.dam(x1)  # 对输入1应用双注意力模块
        x2_dam = self.dam(x2)  # 对输入2应用双注意力模块
        x3 = x1 * x1_dam  # 输入1与其双注意力结果相乘
        x4 = x2 * x2_dam  # 输入2与其双注意力结果相乘
        x_fusion = x3 + x4  # 融合两个输入的结果
        x_enhanced = self.fem(x_fusion)  # 对融合结果应用特征增强模块
        out = self.final_conv(x_enhanced)  # 最终卷积生成输出
        return out

# HDR 损失函数
class HDRLoss(nn.Module):
    def __init__(self, vgg_model, discriminator, real_label, fake_label, criterion):
        super(HDRLoss, self).__init__()
        self.mse_loss = nn.MSELoss()  # 均方误差损失
        self.vgg = vgg_model  # 预训练的 VGG 模型
        self.discriminator = discriminator  # 判别器
        self.real_label = real_label  # 真实标签
        self.fake_label = fake_label  # 假标签
        self.criterion = criterion  # 损失准则

    def forward(self, Ihdr, Iref, Iu, Io):
        mse_loss = self.mse_loss(Ihdr, Iref)  # 计算均方误差损失
        Ihdr_vgg = self.vgg(Ihdr)  # 提取 HDR 图像的 VGG 特征
        Iref_vgg = self.vgg(Iref)  # 提取参考图像的 VGG 特征
        perceptual_loss = F.mse_loss(Ihdr_vgg, Iref_vgg)  # 计算感知损失
        mef_ssim_loss = calculate_mef_ssim(Ihdr, Iref)  # 计算 MEF-SSIM 损失
        d_loss, g_loss = calculate_adversarial_loss(Ihdr, Iref, self.discriminator, self.real_label, self.fake_label, self.criterion)  # 计算对抗损失
        global_local_contrast_loss = calculate_global_local_contrast_loss(Ihdr, Iref, Iu, Io, self.vgg)  # 计算全局局部对比损失
        total_loss = (mse_loss + 
                      hyperparameters['mef_ssim_loss_weight'] * mef_ssim_loss + 
                      hyperparameters['adversarial_loss_weight'] * g_loss + 
                      hyperparameters['perceptual_loss_weight'] * perceptual_loss + 
                      hyperparameters['global_local_contrast_loss_weight'] * global_local_contrast_loss)  # 总损失
        return total_loss

# 计算 MEF-SSIM 损失
def calculate_mef_ssim(Ihdr, Iref):
    Ihdr_np = Ihdr.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()  # 转换 HDR 图像为 numpy 格式
    Iref_np = Iref.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()  # 转换参考图像为 numpy 格式
    ssim_value = ssim(Ihdr_np, Iref_np, win_size=hyperparameters['win_size'], channel_axis=2, data_range=hyperparameters['ssim_data_range'])  # 计算 SSIM 值
    return 1 - ssim_value  # 返回 1 减去 SSIM 值作为损失

# 判别器模型
class Discriminator(nn.Module):
    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1)  # 判别器输出

# 计算对抗损失
def calculate_adversarial_loss(Ihdr, Iref, discriminator, real_label, fake_label, criterion):
    real_output = discriminator(Iref)  # 判别器对真实图像的输出
    fake_output = discriminator(Ihdr)  # 判别器对假图像的输出
    
    # 创建与输出大小匹配的标签
    real_label_resized = real_label.expand_as(real_output)
    fake_label_resized = fake_label.expand_as(fake_output)
    
    real_loss = criterion(real_output, real_label_resized)  # 计算真实图像的损失
    fake_loss = criterion(fake_output, fake_label_resized)  # 计算假图像的损失
    d_loss = (real_loss + fake_loss) / 2  # 判别器损失
    g_loss = criterion(fake_output, real_label_resized)  # 生成器损失
    return d_loss, g_loss

# 计算全局局部对比损失
def calculate_global_local_contrast_loss(Ihdr, Iref, Iu, Io, vgg_model):
    def extract_features(image):
        features = vgg_model(image)  # 提取图像特征
        return features
    Ihdr_features = extract_features(Ihdr)  # 提取 HDR 图像的特征
    Iref_features = extract_features(Iref)  # 提取参考图像的特征
    Iu_features = extract_features(Iu)  # 提取未增强图像的特征
    Io_features = extract_features(Io)  # 提取输出图像的特征
    global_loss = F.mse_loss(Ihdr_features, Iref_features) + F.mse_loss(Ihdr_features, Iu_features) + F.mse_loss(Ihdr_features, Io_features)  # 计算全局损失
    def calculate_local_loss(Ihdr, Iref, Iu, Io, vgg_model):
        P = 4
        Ihdr_patches = Ihdr.unfold(2, P, P).unfold(3, P, P)  # 切分 HDR 图像为补丁
        Iref_patches = Iref.unfold(2, P, P).unfold(3, P, P)  # 切分参考图像为补丁
        Iu_patches = Iu.unfold(2, P, P).unfold(3, P, P)  # 切分未增强图像为补丁
        Io_patches = Io.unfold(2, P, P).unfold(3, P, P)  # 切分输出图像为补丁
        local_loss = 0
        for i in range(P):
            for j in range(P):
                Ihdr_patch = Ihdr_patches[:, :, i, j, :, :]
                Iref_patch = Iref_patches[:, :, i, j, :, :]
                Iu_patch = Iu_patches[:, :, i, j, :, :]
                Io_patch = Io_patches[:, :, i, j, :, :]
                Ihdr_features = extract_features(Ihdr_patch)  # 提取 HDR 补丁的特征
                Iref_features = extract_features(Iref_patch)  # 提取参考补丁的特征
                Iu_features = extract_features(Iu_patch)  # 提取未增强补丁的特征
                Io_features = extract_features(Io_patch)  # 提取输出补丁的特征
                local_loss += F.mse_loss(Ihdr_features, Iref_features) + F.mse_loss(Ihdr_features, Iu_features) + F.mse_loss(Ihdr_features, Io_features)  # 计算局部损失
        return local_loss / (P * P)
    local_loss = calculate_local_loss(Ihdr, Iref, Iu, Io, vgg_model)  # 计算局部损失
    return global_loss + local_loss  # 返回总损失

# 加载图像并应用预处理
def load_image(image_path, transform):
    image = Image.open(image_path).convert('RGB')  # 打开图像并转换为 RGB 格式
    return transform(image).unsqueeze(0)  # 应用变换并添加批次维度

# 计算 PSNR
def calculate_psnr(img1, img2):
    mse = F.mse_loss(img1, img2)  # 计算均方误差
    if mse == 0:
        return float('inf')
    return 20 * torch.log10(1.0 / torch.sqrt(mse))  # 计算 PSNR 值

# 训练模型
def train_model(model, loss_fn, optimizer, dataloader, num_epochs, device):
    lpips_fn = LPIPS(net='alex').to(device)  # LPIPS 评估指标
    model.to(device)  # 将模型移动到设备（GPU 或 CPU）
    for epoch in range(num_epochs):
        model.train()  # 设定模型为训练模式
        avg_psnr = 0
        avg_ssim = 0
        avg_lpips = 0
        for batch in dataloader:
            input_image1, input_image2, ref_image = [item.to(device) for item in batch]  # 获取输入和参考图像
            output_image = model(input_image1, input_image2)  # 生成输出图像
            loss = loss_fn(output_image, ref_image, input_image1, input_image2)  # 计算损失
            optimizer.zero_grad()  # 清零梯度
            loss.backward()  # 反向传播
            optimizer.step()  # 更新权重

            # 计算评价指标
            psnr_value = calculate_psnr(output_image, ref_image).item()  # 计算 PSNR
            ssim_value = ssim(output_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy(), 
                              ref_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy(), 
                              win_size=hyperparameters['win_size'], channel_axis=2, data_range=hyperparameters['ssim_data_range'])  # 计算 SSIM
            lpips_value = lpips_fn(output_image, ref_image).item()  # 计算 LPIPS
            avg_psnr += psnr_value
            avg_ssim += ssim_value
            avg_lpips += lpips_value

        avg_psnr /= len(dataloader)
        avg_ssim /= len(dataloader)
        avg_lpips /= len(dataloader)

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}, PSNR: {avg_psnr:.4f}, SSIM: {avg_ssim:.4f}, LPIPS: {avg_lpips:.4f}')

# 创建模型和损失函数实例
vgg_model = vgg16(pretrained=True).features[:16].eval().to('cuda')  # 加载预训练的 VGG 模型
for param in vgg_model.parameters():
    param.requires_grad = False

discriminator = Discriminator(in_channels=3).to('cuda')  # 创建判别器
criterion = nn.BCELoss()  # 定义二分类交叉熵损失
real_label = torch.ones(1, dtype=torch.float32).to('cuda')  # 定义真实标签
fake_label = torch.zeros(1, dtype=torch.float32).to('cuda')  # 定义假标签

model = HDRNet(in_channels=3, out_channels=3).to('cuda')  # 创建 HDR 网络
loss_fn = HDRLoss(vgg_model=vgg_model, discriminator=discriminator, real_label=real_label, fake_label=fake_label, criterion=criterion)  # 定义损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparameters['lr'])  # 定义优化器

# 定义超参数
hyperparameters = {
    'lr': 0.01,  # 学习率
    'num_epochs': 5000,  # 训练轮数
    'batch_size': 1,  # 批处理大小
    'image_size': (256, 256),  # 图像尺寸
    'mean': [0.5, 0.5, 0.5],  # 图像均值
    'std': [1, 1, 1],  # 图像标准差
    'leaky_relu_negative_slope': 0.1,  # Leaky ReLU 的负斜率
    'win_size': 11,  # SSIM 计算的窗口大小
    'ssim_data_range': 1.0,  # SSIM 计算的数据范围
    'perceptual_loss_weight': 0.1,  # 感知损失的权重
    'mef_ssim_loss_weight': 0.1,  # MEF-SSIM 损失的权重
    'adversarial_loss_weight': 0.1,  # 对抗性损失的权重
    'global_local_contrast_loss_weight': 0.1,  # 全局局部对比损失的权重
}

# 定义图像变换
transform = transforms.Compose([
    transforms.Resize(hyperparameters['image_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=hyperparameters['mean'], std=hyperparameters['std'])
])

set_seed(44)  # 设置随机种子

# 加载图像
image_path1 = 'test1/1.JPG'
image_path2 = 'test1/9.JPG'
ref_image_path = 'test1/221.PNG'

input_image1 = load_image(image_path1, transform).to('cuda')
input_image2 = load_image(image_path2, transform).to('cuda')
ref_image = load_image(ref_image_path, transform).to('cuda')

# 创建数据加载器
dataloader = [(input_image1, input_image2, ref_image)]

# 训练模型
train_model(model, loss_fn, optimizer, dataloader, num_epochs=hyperparameters['num_epochs'], device='cuda')

# 保存模型参数
torch.save(model.state_dict(), 'hdrnet_model.pth')





Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /opt/conda/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
Epoch [1/5000], Loss: 2.937023639678955, PSNR: 4.5649, SSIM: 0.2299, LPIPS: 0.8335
Epoch [2/5000], Loss: 2.7329211235046387, PSNR: 5.0404, SSIM: 0.2282, LPIPS: 0.8324
Epoch [3/5000], Loss: 2.572195291519165, PSNR: 5.4584, SSIM: 0.2262, LPIPS: 0.8290
Epoch [4/5000], Loss: 2.4299521446228027, PSNR: 5.8592, SSIM: 0.2214, LPIPS: 0.8240
Epoch [5/5000], Loss: 2.2906250953674316, PSNR: 6.2861, SSIM: 0.2096, LPIPS: 0.8167
Epoch [6/5000], Loss: 2.144871473312378, PSNR: 6.7517, SSIM: 0.1913, LPIPS: 0.8039
Epoch [7/5000], Loss: 2.011521816253662, PSNR: 7.2443, SSIM: 0.1688, LPIPS: 0.7845
Epoch [8/5000], Loss: 1.8643925189971924, PSNR: 7.7529, SSIM: 0.1445, LPIPS: 0.7598
Epoch [9/5000], Loss: 1.7182223796844482, PSNR: 8.2729, SSIM: 0.1217, LPIPS: 0.7319
Epoch [10/5000], Loss: 1.5933759212493896, PSNR: 8.7843, SSIM: 0.1088, LPIPS: 0.7038
Epoch [11/5000], Loss: 1.489457368850708, PSNR: 9.2680, SSIM: 0.1087, LPIPS