一直难以解决指标很低的问题，我们重新看论文，并通过逐个板块去掉做消融实验看指标，发现在模型去掉FEM这个板块后表现尤其好，在PSNR和LPIPS两个指标上竟然超过了论文中的最终结果，这说明我们FEM板块很可能出现了问题  
接下来需要重新对FEM板块进行构建调整  

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import vgg16
from PIL import Image
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from lpips import LPIPS
import random
import numpy as np
import cv2
import OpenEXR
import Imath

hyperparameters = {
    'lr': 0.01,
    'num_epochs': 5000,
    'batch_size': 1,
    'image_size': (256, 256),
    'mean': [0.5, 0.5, 0.5],
    'std': [1, 1, 1],
    'leaky_relu_negative_slope': 0.1,
    'win_size': 11,
    'ssim_data_range': 1.0,
    'perceptual_loss_weight': 0.1,
    'mef_ssim_loss_weight': 0.1,
    'adversarial_loss_weight': 0.1,
    'global_local_contrast_loss_weight': 0.1,
}

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

class DualAttentionModule(nn.Module):
    def __init__(self, in_channels):
        super(DualAttentionModule, self).__init__()
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.initial_relu = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        self.pa_conv1 = nn.Conv2d(in_channels, max(1, in_channels // 8), kernel_size=1)
        self.pa_relu1 = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        self.pa_bn1 = nn.BatchNorm2d(max(1, in_channels // 8))
        self.pa_conv2 = nn.Conv2d(max(1, in_channels // 8), 1, kernel_size=1)
        self.pa_sigmoid = nn.Sigmoid()
        self.ca_gap = nn.AdaptiveAvgPool2d(1)
        self.ca_conv1 = nn.Conv2d(in_channels, max(1, in_channels // 8), kernel_size=1)
        self.ca_relu1 = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        self.ca_conv2 = nn.Conv2d(max(1, in_channels // 8), in_channels, kernel_size=1)
        self.ca_sigmoid = nn.Sigmoid()
        self.fe_conv = nn.Conv2d(in_channels * 2, in_channels, kernel_size=3, padding=1)
        self.fe_tanh = nn.Tanh()

    def forward(self, x):
        x2 = self.initial_conv(x)
        x2 = self.initial_relu(x2)
        pa = self.pa_conv1(x2)
        pa = self.pa_relu1(pa)
        pa = self.pa_bn1(pa)
        pa = self.pa_conv2(pa)
        pa = self.pa_sigmoid(pa)
        pa_out = x2 * pa
        ca = self.ca_gap(x2)
        ca = self.ca_conv1(ca)
        ca = self.ca_relu1(ca)
        ca = self.ca_conv2(ca)
        ca = self.ca_sigmoid(ca)
        ca_out = x2 * ca
        concat = torch.cat([pa_out, ca_out], dim=1)
        fe = self.fe_conv(concat)
        fe = self.fe_tanh(fe)
        out = fe + x + x + x2
        return out

class HDRNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(HDRNet, self).__init__()
        self.dam = DualAttentionModule(in_channels)
        self.final_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x1, x2):
        x1_dam = self.dam(x1)
        x2_dam = self.dam(x2)
        x3 = x1 * x1_dam
        x4 = x2 * x2_dam
        x_fusion = x3 + x4
        out = self.final_conv(x_fusion)
        return out

class HDRLoss(nn.Module):
    def __init__(self, vgg_model, discriminator, real_label, fake_label, criterion):
        super(HDRLoss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.vgg = vgg_model
        self.discriminator = discriminator
        self.real_label = real_label
        self.fake_label = fake_label
        self.criterion = criterion

    def forward(self, Ihdr, Iref, Iu, Io):
        mse_loss = self.mse_loss(Ihdr, Iref)
        Ihdr_vgg = self.vgg(Ihdr)
        Iref_vgg = self.vgg(Iref)
        perceptual_loss = F.mse_loss(Ihdr_vgg, Iref_vgg)
        mef_ssim_loss = calculate_mef_ssim(Ihdr, Iref)
        d_loss, g_loss = calculate_adversarial_loss(Ihdr, Iref, self.discriminator, self.real_label, self.fake_label, self.criterion)
        global_local_contrast_loss = calculate_global_local_contrast_loss(Ihdr, Iref, Iu, Io, self.vgg)
        total_loss = (mse_loss + 
                      hyperparameters['mef_ssim_loss_weight'] * mef_ssim_loss + 
                      hyperparameters['adversarial_loss_weight'] * g_loss + 
                      hyperparameters['perceptual_loss_weight'] * perceptual_loss + 
                      hyperparameters['global_local_contrast_loss_weight'] * global_local_contrast_loss)
        return total_loss

def calculate_mef_ssim(Ihdr, Iref):
    Ihdr_np = Ihdr.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    Iref_np = Iref.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    ssim_value = ssim(Ihdr_np, Iref_np, win_size=hyperparameters['win_size'], channel_axis=2, data_range=hyperparameters['ssim_data_range'])
    return 1 - ssim_value

class Discriminator(nn.Module):
    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1)

def calculate_adversarial_loss(Ihdr, Iref, discriminator, real_label, fake_label, criterion):
    real_output = discriminator(Iref)
    fake_output = discriminator(Ihdr)
    
    # 创建与输出大小匹配的标签
    real_label_resized = real_label.expand_as(real_output)
    fake_label_resized = fake_label.expand_as(fake_output)
    
    real_loss = criterion(real_output, real_label_resized)
    fake_loss = criterion(fake_output, fake_label_resized)
    d_loss = (real_loss + fake_loss) / 2
    g_loss = criterion(fake_output, real_label_resized)
    return d_loss, g_loss

def calculate_global_local_contrast_loss(Ihdr, Iref, Iu, Io, vgg_model):
    def extract_features(image):
        features = vgg_model(image)
        return features
    Ihdr_features = extract_features(Ihdr)
    Iref_features = extract_features(Iref)
    Iu_features = extract_features(Iu)
    Io_features = extract_features(Io)
    global_loss = F.mse_loss(Ihdr_features, Iref_features) + F.mse_loss(Ihdr_features, Iu_features) + F.mse_loss(Ihdr_features, Io_features)
    def calculate_local_loss(Ihdr, Iref, Iu, Io, vgg_model):
        P = 4
        Ihdr_patches = Ihdr.unfold(2, P, P).unfold(3, P, P)
        Iref_patches = Iref.unfold(2, P, P).unfold(3, P, P)
        Iu_patches = Iu.unfold(2, P, P).unfold(3, P, P)
        Io_patches = Io.unfold(2, P, P).unfold(3, P, P)
        local_loss = 0
        for i in range(P):
            for j in range(P):
                Ihdr_patch = Ihdr_patches[:, :, i, j, :, :]
                Iref_patch = Iref_patches[:, :, i, j, :, :]
                Iu_patch = Iu_patches[:, :, i, j, :, :]
                Io_patch = Io_patches[:, :, i, j, :, :]
                Ihdr_features = extract_features(Ihdr_patch)
                Iref_features = extract_features(Iref_patch)
                Iu_features = extract_features(Iu_patch)
                Io_features = extract_features(Io_patch)
                local_loss += F.mse_loss(Ihdr_features, Iref_features) + F.mse_loss(Ihdr_features, Iu_features) + F.mse_loss(Ihdr_features, Io_features)
        return local_loss / (P * P)
    local_loss = calculate_local_loss(Ihdr, Iref, Iu, Io, vgg_model)
    return global_loss + local_loss

def load_image(image_path, transform):
    image = Image.open(image_path).convert('RGB')
    return transform(image).unsqueeze(0)

def calculate_psnr(img1, img2):
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        return float('inf')
    return 20 * torch.log10(1.0 / torch.sqrt(mse))

def train_model(model, loss_fn, optimizer, dataloader, num_epochs, device):
    lpips_fn = LPIPS(net='alex').to(device)
    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        avg_psnr = 0
        avg_ssim = 0
        avg_lpips = 0
        for batch in dataloader:
            input_image1, input_image2, ref_image = [item.to(device) for item in batch]
            output_image = model(input_image1, input_image2)
            loss = loss_fn(output_image, ref_image, input_image1, input_image2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 计算评价指标
            psnr_value = calculate_psnr(output_image, ref_image).item()
            ssim_value = ssim(output_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy(), 
                              ref_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy(), 
                              win_size=hyperparameters['win_size'], channel_axis=2, data_range=hyperparameters['ssim_data_range'])
            lpips_value = lpips_fn(output_image, ref_image).item()
            avg_psnr += psnr_value
            avg_ssim += ssim_value
            avg_lpips += lpips_value

        avg_psnr /= len(dataloader)
        avg_ssim /= len(dataloader)
        avg_lpips /= len(dataloader)

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}, PSNR: {avg_psnr:.4f}, SSIM: {avg_ssim:.4f}, LPIPS: {avg_lpips:.4f}')

def save_hdr_image_from_png(png_image_path, hdr_image_path, target_size):
    # 读取PNG图像
    png_image = cv2.imread(png_image_path, cv2.IMREAD_UNCHANGED)
    
    # 检查图像是否成功读取
    if png_image is None:
        raise ValueError(f"无法读取图像文件: {png_image_path}")
    
    # 调整图像大小
    png_image = cv2.resize(png_image, target_size, interpolation=cv2.INTER_AREA)
    
    # 归一化到0-1范围
    png_image = png_image.astype(np.float32) / 255.0
    
    # 保存为Radiance HDR格式
    cv2.imwrite(hdr_image_path, png_image)

# 创建模型和损失函数实例
vgg_model = vgg16(pretrained=True).features[:16].eval().to('cuda')
for param in vgg_model.parameters():
    param.requires_grad = False

discriminator = Discriminator(in_channels=3).to('cuda')
criterion = nn.BCELoss()
real_label = torch.ones(1, dtype=torch.float32).to('cuda')
fake_label = torch.zeros(1, dtype=torch.float32).to('cuda')

model = HDRNet(in_channels=3, out_channels=3).to('cuda')
loss_fn = HDRLoss(vgg_model=vgg_model, discriminator=discriminator, real_label=real_label, fake_label=fake_label, criterion=criterion)
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparameters['lr'])

transform = transforms.Compose([
    transforms.Resize(hyperparameters['image_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=hyperparameters['mean'], std=hyperparameters['std'])
])

set_seed(44)  

image_path1 = 'test1/1.JPG'
image_path2 = 'test1/9.JPG'
ref_image_path = 'test1/221.PNG'

input_image1 = load_image(image_path1, transform).to('cuda')
input_image2 = load_image(image_path2, transform).to('cuda')
ref_image = load_image(ref_image_path, transform).to('cuda')

dataloader = [(input_image1, input_image2, ref_image)]

train_model(model, loss_fn, optimizer, dataloader, num_epochs=hyperparameters['num_epochs'], device='cuda')

output_image = model(input_image1, input_image2)
save_hdr_image_from_png(output_image, 'generated_hdr_image.hdr', (512, 512))

torch.save(model.state_dict(), 'hdrnet_model.pth')



  from .autonotebook import tqdm as notebook_tqdm


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /opt/conda/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
Epoch [1/5000], Loss: 1.6212725639343262, PSNR: 11.8842, SSIM: 0.1234, LPIPS: 0.5739
Epoch [2/5000], Loss: 1.577733039855957, PSNR: 10.5143, SSIM: 0.1859, LPIPS: 0.5688
Epoch [3/5000], Loss: 1.5577796697616577, PSNR: 10.1860, SSIM: 0.1759, LPIPS: 0.5682
Epoch [4/5000], Loss: 1.5378882884979248, PSNR: 10.3694, SSIM: 0.1565, LPIPS: 0.5664
Epoch [5/5000], Loss: 1.5134963989257812, PSNR: 10.8854, SSIM: 0.1267, LPIPS: 0.5587
Epoch [6/5000], Loss: 1.4891283512115479, PSNR: 11.4356, SSIM: 0.0748, LPIPS: 0.5470
Epoch [7/5000], Loss: 1.455182433128357, PSNR: 11.8965, SSIM: 0.0571, LPIPS: 0.5347
Epoch [8/5000], Loss: 1.4209847450256348, PSNR: 12.1532, SSIM: 0.0416, LPIPS: 0.5269
Epoch [9/5000], Loss: 1.3990824222564697, PSNR: 12.1760, SSIM: 0.0266, LPIPS: 0.5167
Epoch [10/5000], Loss: 1.3765590190887451, PSNR: 12.0644, SSIM: 0.0298, LPIPS: 0.5070
Epoch [11/5000], Loss: 1.3653417825698853, PSNR: 11.8761, SSIM:

TypeError: Can't convert object to 'str' for 'filename'