完成论文复现的完整模型后，下面是经过几天调参得到的最优结果，在lr设定为0.01时，模型在5000-6000轮处达到最优，后面效果开始变差，其中loss最小稳定在1.0，PSNR在14.6，SSIM在0.45，LPIPS在0.2，远远没达到论文里的要求  
接下来就要检查出了什么问题，可能是参数的问题，可能是代码的问题

In [135]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import vgg16
from PIL import Image
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from lpips import LPIPS
import random
import numpy as np

#from hyperparameters import hyperparameters  # 导入超参数

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False



# DAM
class DualAttentionModule(nn.Module):
    def __init__(self, in_channels):
        super(DualAttentionModule, self).__init__()
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.initial_relu = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        self.pa_conv1 = nn.Conv2d(in_channels, max(1, in_channels // 8), kernel_size=1)
        self.pa_relu1 = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        self.pa_bn1 = nn.BatchNorm2d(max(1, in_channels // 8))
        self.pa_conv2 = nn.Conv2d(max(1, in_channels // 8), 1, kernel_size=1)
        self.pa_sigmoid = nn.Sigmoid()
        self.ca_gap = nn.AdaptiveAvgPool2d(1)
        self.ca_conv1 = nn.Conv2d(in_channels, max(1, in_channels // 8), kernel_size=1)
        self.ca_relu1 = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        self.ca_conv2 = nn.Conv2d(max(1, in_channels // 8), in_channels, kernel_size=1)
        self.ca_sigmoid = nn.Sigmoid()
        self.fe_conv = nn.Conv2d(in_channels * 2, in_channels, kernel_size=3, padding=1)
        self.fe_tanh = nn.Tanh()

    def forward(self, x):
        x2 = self.initial_conv(x)
        x2 = self.initial_relu(x2)
        pa = self.pa_conv1(x2)
        pa = self.pa_relu1(pa)
        pa = self.pa_bn1(pa)
        pa = self.pa_conv2(pa)
        pa = self.pa_sigmoid(pa)
        pa_out = x2 * pa
        ca = self.ca_gap(x2)
        ca = self.ca_conv1(ca)
        ca = self.ca_relu1(ca)
        ca = self.ca_conv2(ca)
        ca = self.ca_sigmoid(ca)
        ca_out = x2 * ca
        concat = torch.cat([pa_out, ca_out], dim=1)
        fe = self.fe_conv(concat)
        fe = self.fe_tanh(fe)
        out = fe + x + x + x2
        return out

class FeatureEnhancementModule(nn.Module):
    def __init__(self, in_channels):
        super(FeatureEnhancementModule, self).__init__()

        def conv_block(in_channels, out_channels, pool=True):
            layers = [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope']),
                nn.BatchNorm2d(out_channels)
            ]
            if pool:
                layers.append(nn.MaxPool2d(2))
            return nn.Sequential(*layers)

        def dilated_conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=2, dilation=2),
                nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope']),
                nn.BatchNorm2d(out_channels)
            )

        self.fd_conv1 = conv_block(in_channels, in_channels, pool=False)
        self.fd_conv2 = conv_block(in_channels, in_channels, pool=False)
        self.fd_downsample1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)
        self.fd_conv3 = conv_block(in_channels, in_channels, pool=False)
        self.fd_conv4 = conv_block(in_channels, in_channels, pool=False)
        self.fd_downsample2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)
        self.fd_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.fd_relu5 = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        self.fd_bn5 = nn.BatchNorm2d(in_channels)
        self.fd_dilated_conv1 = dilated_conv_block(in_channels, in_channels)
        self.fd_dilated_conv2 = dilated_conv_block(in_channels, in_channels)
        self.fd_upsample1 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.fd_upsample2 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.fd_tanh = nn.Tanh()

        self.fc_conv1 = conv_block(in_channels, in_channels, pool=False)
        self.fc_conv2 = conv_block(in_channels, in_channels, pool=False)
        self.fc_downsample1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)
        self.fc_conv3 = conv_block(in_channels, in_channels, pool=False)
        self.fc_conv4 = conv_block(in_channels, in_channels, pool=False)
        self.fc_downsample2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)
        self.fc_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.fc_relu5 = nn.LeakyReLU(hyperparameters['leaky_relu_negative_slope'])
        self.fc_bn5 = nn.BatchNorm2d(in_channels)
        self.fc_dilated_conv1 = dilated_conv_block(in_channels, in_channels)
        self.fc_dilated_conv2 = dilated_conv_block(in_channels, in_channels)
        self.fc_upsample1 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.fc_upsample2 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.fc_tanh = nn.Tanh()
        self.fc_sigmoid = nn.Sigmoid()

    def forward(self, x):
        fd = self.fd_conv1(x)
        fd = self.fd_conv2(fd)
        fd = self.fd_downsample1(fd)
        fd = self.fd_conv3(fd)
        fd = self.fd_conv4(fd)
        fd = self.fd_downsample2(fd)
        fd = self.fd_conv5(fd)
        fd = self.fd_relu5(fd)
        fd = self.fd_bn5(fd)
        fd = self.fd_dilated_conv1(fd)
        fd = self.fd_dilated_conv2(fd)
        fd = self.fd_upsample1(fd)
        fd = self.fd_upsample2(fd)
        fd = self.fd_tanh(fd)
        fd_out = fd + x

        fc = self.fc_conv1(x)
        fc = self.fc_conv2(fc)
        fc = self.fc_downsample1(fc)
        fc = self.fc_conv3(fc)
        fc = self.fc_conv4(fc)
        fc = self.fc_downsample2(fc)
        fc = self.fc_conv5(fc)
        fc = self.fc_relu5(fc)
        fc = self.fc_bn5(fc)
        fc = self.fc_dilated_conv1(fc)
        fc = self.fc_dilated_conv2(fc)
        fc = self.fc_upsample1(fc)
        fc = self.fc_upsample2(fc)
        fc = self.fc_tanh(fc)
        fc = self.fc_sigmoid(fc)
        fc_out = fc / x

        out = fd_out + fc_out
        return out

# 定义 HDRNet 模型
class HDRNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(HDRNet, self).__init__()
        self.dam = DualAttentionModule(in_channels)
        self.fem = FeatureEnhancementModule(in_channels)
        self.final_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x1, x2):
        x1_dam = self.dam(x1)
        x2_dam = self.dam(x2)
        x3 = x1 * x1_dam
        x4 = x2 * x2_dam
        x_fusion = x3 + x4
        x_enhanced = self.fem(x_fusion)
        out = self.final_conv(x_enhanced)
        return out

# HDRLoss
class HDRLoss(nn.Module):
    def __init__(self, vgg_model, discriminator, real_label, fake_label, criterion):
        super(HDRLoss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.vgg = vgg_model
        self.discriminator = discriminator
        self.real_label = real_label
        self.fake_label = fake_label
        self.criterion = criterion

    def forward(self, Ihdr, Iref, Iu, Io):
        mse_loss = self.mse_loss(Ihdr, Iref)
        Ihdr_vgg = self.vgg(Ihdr)
        Iref_vgg = self.vgg(Iref)
        perceptual_loss = F.mse_loss(Ihdr_vgg, Iref_vgg)
        mef_ssim_loss = calculate_mef_ssim(Ihdr, Iref)
        d_loss, g_loss = calculate_adversarial_loss(Ihdr, Iref, self.discriminator, self.real_label, self.fake_label, self.criterion)
        global_local_contrast_loss = calculate_global_local_contrast_loss(Ihdr, Iref, Iu, Io, self.vgg)
        total_loss = (mse_loss + 
                      hyperparameters['mef_ssim_loss_weight'] * mef_ssim_loss + 
                      hyperparameters['adversarial_loss_weight'] * g_loss + 
                      hyperparameters['perceptual_loss_weight'] * perceptual_loss + 
                      hyperparameters['global_local_contrast_loss_weight'] * global_local_contrast_loss)
        return total_loss

def calculate_mef_ssim(Ihdr, Iref):
    Ihdr_np = Ihdr.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    Iref_np = Iref.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    ssim_value = ssim(Ihdr_np, Iref_np, win_size=hyperparameters['win_size'], channel_axis=2, data_range=hyperparameters['ssim_data_range'])
    return 1 - ssim_value

class Discriminator(nn.Module):
    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1)

def calculate_adversarial_loss(Ihdr, Iref, discriminator, real_label, fake_label, criterion):
    real_output = discriminator(Iref)
    fake_output = discriminator(Ihdr)
    
    # 创建与输出大小匹配的标签
    real_label_resized = real_label.expand_as(real_output)
    fake_label_resized = fake_label.expand_as(fake_output)
    
    real_loss = criterion(real_output, real_label_resized)
    fake_loss = criterion(fake_output, fake_label_resized)
    d_loss = (real_loss + fake_loss) / 2
    g_loss = criterion(fake_output, real_label_resized)
    return d_loss, g_loss

def calculate_global_local_contrast_loss(Ihdr, Iref, Iu, Io, vgg_model):
    def extract_features(image):
        features = vgg_model(image)
        return features
    Ihdr_features = extract_features(Ihdr)
    Iref_features = extract_features(Iref)
    Iu_features = extract_features(Iu)
    Io_features = extract_features(Io)
    global_loss = F.mse_loss(Ihdr_features, Iref_features) + F.mse_loss(Ihdr_features, Iu_features) + F.mse_loss(Ihdr_features, Io_features)
    def calculate_local_loss(Ihdr, Iref, Iu, Io, vgg_model):
        P = 4
        Ihdr_patches = Ihdr.unfold(2, P, P).unfold(3, P, P)
        Iref_patches = Iref.unfold(2, P, P).unfold(3, P, P)
        Iu_patches = Iu.unfold(2, P, P).unfold(3, P, P)
        Io_patches = Io.unfold(2, P, P).unfold(3, P, P)
        local_loss = 0
        for i in range(P):
            for j in range(P):
                Ihdr_patch = Ihdr_patches[:, :, i, j, :, :]
                Iref_patch = Iref_patches[:, :, i, j, :, :]
                Iu_patch = Iu_patches[:, :, i, j, :, :]
                Io_patch = Io_patches[:, :, i, j, :, :]
                Ihdr_features = extract_features(Ihdr_patch)
                Iref_features = extract_features(Iref_patch)
                Iu_features = extract_features(Iu_patch)
                Io_features = extract_features(Io_patch)
                local_loss += F.mse_loss(Ihdr_features, Iref_features) + F.mse_loss(Ihdr_features, Iu_features) + F.mse_loss(Ihdr_features, Io_features)
        return local_loss / (P * P)
    local_loss = calculate_local_loss(Ihdr, Iref, Iu, Io, vgg_model)
    return global_loss + local_loss

def load_image(image_path, transform):
    image = Image.open(image_path).convert('RGB')
    return transform(image).unsqueeze(0)

def calculate_psnr(img1, img2):
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        return float('inf')
    return 20 * torch.log10(1.0 / torch.sqrt(mse))

def train_model(model, loss_fn, optimizer, dataloader, num_epochs, device):
    lpips_fn = LPIPS(net='alex').to(device)
    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        avg_psnr = 0
        avg_ssim = 0
        avg_lpips = 0
        for batch in dataloader:
            input_image1, input_image2, ref_image = [item.to(device) for item in batch]
            output_image = model(input_image1, input_image2)
            loss = loss_fn(output_image, ref_image, input_image1, input_image2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 计算评价指标
            psnr_value = calculate_psnr(output_image, ref_image).item()
            ssim_value = ssim(output_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy(), 
                              ref_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy(), 
                              win_size=hyperparameters['win_size'], channel_axis=2, data_range=hyperparameters['ssim_data_range'])
            lpips_value = lpips_fn(output_image, ref_image).item()
            avg_psnr += psnr_value
            avg_ssim += ssim_value
            avg_lpips += lpips_value

        avg_psnr /= len(dataloader)
        avg_ssim /= len(dataloader)
        avg_lpips /= len(dataloader)

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}, PSNR: {avg_psnr:.4f}, SSIM: {avg_ssim:.4f}, LPIPS: {avg_lpips:.4f}')

# 创建模型和损失函数实例
vgg_model = vgg16(pretrained=True).features[:16].eval().to('cuda')
for param in vgg_model.parameters():
    param.requires_grad = False

discriminator = Discriminator(in_channels=3).to('cuda')
criterion = nn.BCELoss()
real_label = torch.ones(1, dtype=torch.float32).to('cuda')
fake_label = torch.zeros(1, dtype=torch.float32).to('cuda')

model = HDRNet(in_channels=3, out_channels=3).to('cuda')
loss_fn = HDRLoss(vgg_model=vgg_model, discriminator=discriminator, real_label=real_label, fake_label=fake_label, criterion=criterion)
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparameters['lr'])

transform = transforms.Compose([
    transforms.Resize(hyperparameters['image_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=hyperparameters['mean'], std=hyperparameters['std'])
])

hyperparameters = {
    'lr': 0.01,
    'num_epochs': 8000,
    'batch_size': 1,
    'image_size': (256, 256),
    'mean': [0.5, 0.5, 0.5],
    'std': [0.5, 0.5, 0.5],
    'leaky_relu_negative_slope': 0.1,
    'win_size': 11,
    'ssim_data_range': 1.0,
    'perceptual_loss_weight': 0.1,
    'mef_ssim_loss_weight': 0.1,
    'adversarial_loss_weight': 0.1,
    'global_local_contrast_loss_weight': 0.1,
}

set_seed(44)  
#44
#45best
#53

image_path1 = 'test1/1.JPG'
image_path2 = 'test1/9.JPG'
ref_image_path = 'test1/221.PNG'

input_image1 = load_image(image_path1, transform).to('cuda')
input_image2 = load_image(image_path2, transform).to('cuda')
ref_image = load_image(ref_image_path, transform).to('cuda')

dataloader = [(input_image1, input_image2, ref_image)]

train_model(model, loss_fn, optimizer, dataloader, num_epochs=hyperparameters['num_epochs'], device='cuda')

torch.save(model.state_dict(), 'hdrnet_model.pth')













Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /opt/conda/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
Epoch [1/8000], Loss: 3.9480175971984863, PSNR: 1.7638, SSIM: 0.0759, LPIPS: 0.8026
Epoch [2/8000], Loss: 3.411799430847168, PSNR: 2.8851, SSIM: 0.0598, LPIPS: 0.7801
Epoch [3/8000], Loss: 3.023883819580078, PSNR: 4.2117, SSIM: 0.0332, LPIPS: 0.7636
Epoch [4/8000], Loss: 2.7901253700256348, PSNR: 5.2131, SSIM: 0.0331, LPIPS: 0.7144
Epoch [5/8000], Loss: 2.7034807205200195, PSNR: 6.0118, SSIM: 0.0537, LPIPS: 0.6255
Epoch [6/8000], Loss: 2.704489231109619, PSNR: 6.4672, SSIM: 0.0922, LPIPS: 0.5461
Epoch [7/8000], Loss: 2.66869854927063, PSNR: 6.5960, SSIM: 0.1063, LPIPS: 0.5147
Epoch [8/8000], Loss: 2.617672920227051, PSNR: 6.5193, SSIM: 0.1013, LPIPS: 0.5103
Epoch [9/8000], Loss: 2.5771703720092773, PSNR: 6.3185, SSIM: 0.0881, LPIPS: 0.5179
Epoch [10/8000], Loss: 2.5575568675994873, PSNR: 6.0436, SSIM: 0.0766, LPIPS: 0.5296
Epoch