# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import vgg16
from PIL import Image
import matplotlib.pyplot as plt

# DAM
class DualAttentionModule(nn.Module):
    """Dual (pixel + channel) attention block with a residual output.

    The input first goes through a 3x3 conv + LeakyReLU, giving ``x2``. Two
    attention gates are computed from ``x2``:

    * pixel attention: a single-channel sigmoid map that rescales every
      spatial location of ``x2``;
    * channel attention: a per-channel sigmoid weight computed from the
      globally average-pooled ``x2``.

    The two gated tensors are concatenated, fused by a 3x3 conv + tanh, and
    added back onto the input (see ``forward``). Output has the same shape as
    the input.

    Args:
        in_channels: Number of input channels (also the number of output
            channels).
    """

    def __init__(self, in_channels):
        super(DualAttentionModule, self).__init__()
        # Bottleneck width shared by both attention branches. ``in_channels // 8``
        # alone is 0 whenever in_channels < 8 (e.g. a 3-channel RGB input),
        # which would build zero-width 1x1 convs whose outputs cannot depend
        # on the input. Keep at least one channel; widths for >= 8 channels
        # are unchanged.
        reduced = max(in_channels // 8, 1)

        # Initial processing of the input image
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.initial_relu = nn.LeakyReLU(0.1)

        # Pixel attention: one gate value per spatial location
        self.pa_conv1 = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.pa_relu1 = nn.LeakyReLU(0.1)
        self.pa_bn1 = nn.BatchNorm2d(reduced)
        self.pa_conv2 = nn.Conv2d(reduced, 1, kernel_size=1)
        self.pa_sigmoid = nn.Sigmoid()

        # Channel attention: one gate value per channel, from global avg-pooling
        self.ca_gap = nn.AdaptiveAvgPool2d(1)
        self.ca_conv1 = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.ca_relu1 = nn.LeakyReLU(0.1)
        self.ca_conv2 = nn.Conv2d(reduced, in_channels, kernel_size=1)
        self.ca_sigmoid = nn.Sigmoid()

        # Fusion of the two gated tensors back to in_channels
        self.fe_conv = nn.Conv2d(in_channels * 2, in_channels, kernel_size=3, padding=1)
        self.fe_tanh = nn.Tanh()

    def forward(self, x):
        """Apply dual attention to ``x``.

        Args:
            x: Feature map, assumed (N, C, H, W) with C == in_channels (the
                concatenation along dim=1 expects a leading batch dimension).

        Returns:
            ``fe + x + x + x2``, same shape as ``x``, where ``x2`` is the
            initial conv output and ``fe`` the fused attention features.
        """
        # Initial processing
        x2 = self.initial_relu(self.initial_conv(x))

        # Pixel attention: (N, 1, H, W) gate broadcast over channels
        pa = self.pa_conv1(x2)
        pa = self.pa_relu1(pa)
        pa = self.pa_bn1(pa)
        pa = self.pa_conv2(pa)
        pa = self.pa_sigmoid(pa)
        pa_out = x2 * pa

        # Channel attention: (N, C, 1, 1) gate broadcast over space
        ca = self.ca_gap(x2)
        ca = self.ca_conv1(ca)
        ca = self.ca_relu1(ca)
        ca = self.ca_conv2(ca)
        ca = self.ca_sigmoid(ca)
        ca_out = x2 * ca

        # Concatenate the PA and CA outputs along the channel axis
        concat = torch.cat([pa_out, ca_out], dim=1)

        # Final two layers: fuse back to in_channels
        fe = self.fe_tanh(self.fe_conv(concat))

        # NOTE(review): the input is added twice here (``x + x``); kept as-is,
        # but confirm this is intended rather than a typo for ``fe + x + x2``.
        out = fe + x + x + x2

        return out

# FEM
class FeatureEnhancementModule(nn.Module):
    """Two-branch feature enhancement: detail supplement (FD) + colour correction (FC).

    Both branches have the same layout but separate weights. The layout is:

    * two conv/LeakyReLU/BN/max-pool blocks and a stride-2 conv;
    * two more such blocks and another stride-2 conv (a 64x spatial reduction
      for sides divisible by 64);
    * a 3x3 conv/LeakyReLU/BN and two dilated (dilation 2) conv blocks;
    * ``tanh`` at the end.

    The FC branch additionally applies a sigmoid. Because the branches work
    at a much lower resolution than the input, their outputs are bilinearly
    resized back to the input's spatial size before being combined with it:

    * FD: ``fd_out = resize(fd) + x``
    * FC: ``fc_out = resize(fc) / x``, where ``x`` is held at least ``eps``
      away from zero (sign kept, 0 treated as positive) so the division stays
      finite.

    ``forward`` returns ``fd_out + fc_out`` with the same shape as ``x``.
    Small inputs can shrink to zero spatial size in the max-pool layers and
    fail.

    NOTE(review): the FC branch divides the gate by the input (``fc / x``);
    confirm whether ``x / fc`` was intended.

    Args:
        in_channels: Number of input (and output) channels.
        eps: Minimum magnitude of the FC-branch divisor.
    """

    def __init__(self, in_channels, eps=1e-6):
        super(FeatureEnhancementModule, self).__init__()
        # Smallest |x| used as a divisor in the FC branch.
        self.eps = eps

        def conv_block(in_channels, out_channels, pool=True):
            """3x3 conv -> LeakyReLU -> BatchNorm, optionally followed by 2x2 max-pool."""
            layers = [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.LeakyReLU(0.1),
                nn.BatchNorm2d(out_channels)
            ]
            if pool:
                layers.append(nn.MaxPool2d(2))
            return nn.Sequential(*layers)

        def dilated_conv_block(in_channels, out_channels):
            """3x3 conv with dilation 2 (padding 2 keeps the size) -> LeakyReLU -> BatchNorm."""
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=2, dilation=2),
                nn.LeakyReLU(0.1),
                nn.BatchNorm2d(out_channels)
            )

        # Detail-supplement branch (FD)
        self.fd_conv1 = conv_block(in_channels, in_channels, pool=True)
        self.fd_conv2 = conv_block(in_channels, in_channels, pool=True)
        self.fd_downsample1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)

        self.fd_conv3 = conv_block(in_channels, in_channels, pool=True)
        self.fd_conv4 = conv_block(in_channels, in_channels, pool=True)
        self.fd_downsample2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)

        self.fd_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.fd_relu5 = nn.LeakyReLU(0.1)
        self.fd_bn5 = nn.BatchNorm2d(in_channels)

        self.fd_dilated_conv1 = dilated_conv_block(in_channels, in_channels)
        self.fd_dilated_conv2 = dilated_conv_block(in_channels, in_channels)

        self.fd_tanh = nn.Tanh()

        # Colour-correction branch (FC)
        self.fc_conv1 = conv_block(in_channels, in_channels, pool=True)
        self.fc_conv2 = conv_block(in_channels, in_channels, pool=True)
        self.fc_downsample1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)

        self.fc_conv3 = conv_block(in_channels, in_channels, pool=True)
        self.fc_conv4 = conv_block(in_channels, in_channels, pool=True)
        self.fc_downsample2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)

        self.fc_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.fc_relu5 = nn.LeakyReLU(0.1)
        self.fc_bn5 = nn.BatchNorm2d(in_channels)

        self.fc_dilated_conv1 = dilated_conv_block(in_channels, in_channels)
        self.fc_dilated_conv2 = dilated_conv_block(in_channels, in_channels)

        self.fc_tanh = nn.Tanh()
        self.fc_sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Enhance ``x`` (assumed (N, C, H, W)); returns a tensor shaped like ``x``."""
        # Detail-supplement branch (FD)
        fd = self.fd_conv1(x)
        fd = self.fd_conv2(fd)
        fd = self.fd_downsample1(fd)

        fd = self.fd_conv3(fd)
        fd = self.fd_conv4(fd)
        fd = self.fd_downsample2(fd)

        fd = self.fd_conv5(fd)
        fd = self.fd_relu5(fd)
        fd = self.fd_bn5(fd)

        fd = self.fd_dilated_conv1(fd)
        fd = self.fd_dilated_conv2(fd)

        fd = self.fd_tanh(fd)
        # The branch ran at reduced resolution; resize it back so the residual
        # addition with the full-resolution input is well-defined.
        fd_out = self._match_size(fd, x) + x

        # Colour-correction branch (FC)
        fc = self.fc_conv1(x)
        fc = self.fc_conv2(fc)
        fc = self.fc_downsample1(fc)

        fc = self.fc_conv3(fc)
        fc = self.fc_conv4(fc)
        fc = self.fc_downsample2(fc)

        fc = self.fc_conv5(fc)
        fc = self.fc_relu5(fc)
        fc = self.fc_bn5(fc)

        fc = self.fc_dilated_conv1(fc)
        fc = self.fc_dilated_conv2(fc)

        fc = self.fc_tanh(fc)
        fc = self.fc_sigmoid(fc)
        # Same resize as FD; the divisor is kept away from zero so pixels with
        # x == 0 do not produce inf/NaN.
        fc_out = self._match_size(fc, x) / self._safe_divisor(x)

        # Final fusion
        out = fd_out + fc_out
        return out

    @staticmethod
    def _match_size(t, ref):
        """Bilinearly resize ``t`` to the spatial size (last two dims) of ``ref``."""
        return F.interpolate(t, size=tuple(ref.shape[-2:]), mode='bilinear', align_corners=False)

    def _safe_divisor(self, x):
        """Return ``x`` with every |value| below ``self.eps`` raised to ``self.eps``, sign preserved (0 -> +eps)."""
        sign = torch.where(x < 0, -torch.ones_like(x), torch.ones_like(x))
        return sign * x.abs().clamp(min=self.eps)

# 定义 HDRNet 模型
class HDRNet(nn.Module):
    """Two-exposure fusion network.

    An under-exposed and an over-exposed image are each passed through the
    same (weight-shared) ``DualAttentionModule``. Each input is multiplied by
    its own attention output, and the two gated tensors are summed. The sum is
    refined by ``FeatureEnhancementModule`` and projected to ``out_channels``
    by a 1x1 convolution.

    Args:
        in_channels: Channels of each input image.
        out_channels: Channels of the fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.dam = DualAttentionModule(in_channels)
        self.fem = FeatureEnhancementModule(in_channels)
        self.final_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x1, x2):
        """Fuse ``x1`` (under-exposed) and ``x2`` (over-exposed) into one output."""
        # Both exposures share the same attention module; each gates its own input.
        gated_under = x1 * self.dam(x1)
        gated_over = x2 * self.dam(x2)

        # Preliminary fusion by summation, then enhancement and projection.
        refined = self.fem(gated_under + gated_over)
        return self.final_conv(refined)

# 定义损失函数
class HDRLoss(nn.Module):
    """Weighted sum of pixel, perceptual and auxiliary losses.

    ``total = MSE(Ihdr, Iref)
              + 0.1 * (1 - calculate_mef_ssim(Ihdr, Iref))
              + 0.1 * calculate_adversarial_loss(Ihdr, Iref)
              + 0.1 * MSE(vgg(Ihdr), vgg(Iref))
              + 0.1 * calculate_global_local_contrast_loss(Ihdr, Iref, Iu, Io)``

    Args:
        vgg_model: Feature extractor applied to both prediction and reference
            for the perceptual term.
    """

    def __init__(self, vgg_model):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.vgg = vgg_model

    def forward(self, Ihdr, Iref, Iu, Io):
        """Return the scalar training loss for prediction ``Ihdr`` vs reference ``Iref``.

        ``Iu`` / ``Io`` (the two exposure inputs) are only passed to the
        global/local contrast term.
        """
        pixel_term = self.mse_loss(Ihdr, Iref)
        pred_features, ref_features = self.vgg(Ihdr), self.vgg(Iref)
        perceptual_term = F.mse_loss(pred_features, ref_features)

        ssim_term = 1 - calculate_mef_ssim(Ihdr, Iref)
        adversarial_term = calculate_adversarial_loss(Ihdr, Iref)
        contrast_term = calculate_global_local_contrast_loss(Ihdr, Iref, Iu, Io)

        w = 0.1
        return (
            pixel_term
            + w * ssim_term
            + w * adversarial_term
            + w * perceptual_term
            + w * contrast_term
        )

def calculate_mef_ssim(Ihdr, Iref):
    """Placeholder for the MEF-SSIM score between fused and reference images.

    Not implemented: both arguments are ignored and the example constant 0.9
    is returned. Being a constant, it contributes no gradient to a loss.
    """
    placeholder_score = 0.9
    return placeholder_score

def calculate_adversarial_loss(Ihdr, Iref):
    """Placeholder for an adversarial loss term.

    Not implemented: both arguments are ignored and the example constant 0.1
    is returned. Being a constant, it contributes no gradient to a loss.
    """
    placeholder_loss = 0.1
    return placeholder_loss

def calculate_global_local_contrast_loss(Ihdr, Iref, Iu, Io):
    """Placeholder for a global/local contrast loss term.

    Not implemented: all arguments are ignored and the example constant 0.05
    is returned. Being a constant, it contributes no gradient to a loss.
    """
    placeholder_loss = 0.05
    return placeholder_loss

def load_image(image_path, transform):
    """Load an image file as a batched tensor.

    The file is opened in a ``with`` block so the underlying file handle is
    released deterministically. The image is converted to 3-channel RGB and
    passed through ``transform``, and a leading batch dimension of size 1 is
    added.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        transform: Callable applied to the RGB ``PIL.Image``. It must return
            a tensor (presumably (C, H, W), e.g. a ``transforms.Compose``
            ending in ``ToTensor``; TODO confirm for other callers).

    Returns:
        ``transform(rgb_image).unsqueeze(0)``.
    """
    with Image.open(image_path) as img:
        # convert() returns a new, fully loaded image, so it stays valid after
        # the source file is closed.
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0)

def train_model(model, loss_fn, optimizer, dataloader, num_epochs):
    """Train ``model`` for ``num_epochs`` passes over ``dataloader``.

    Each batch must unpack to ``(input_image1, input_image2, ref_image)``.
    The model is called as ``model(input_image1, input_image2)`` and the loss
    as ``loss_fn(output, ref_image, input_image1, input_image2)``. After every
    epoch, the loss of that epoch's last batch is printed.

    Args:
        model: Module taking two inputs.
        loss_fn: Callable returning a scalar tensor loss.
        optimizer: Optimizer over ``model``'s parameters.
        dataloader: Re-iterable collection of batches (e.g. a list or a
            ``DataLoader``); it is iterated once per epoch.
        num_epochs: Number of passes over ``dataloader``.

    Raises:
        ValueError: If an epoch yields no batches. This covers an empty
            ``dataloader``, which would otherwise fail on an unset loss. It
            also covers a one-shot iterator (e.g. a generator) that is
            exhausted after the first epoch, which would otherwise skip
            training and re-print a stale loss.
    """
    for epoch in range(num_epochs):
        model.train()
        loss = None
        for input_image1, input_image2, ref_image in dataloader:
            output_image = model(input_image1, input_image2)
            loss = loss_fn(output_image, ref_image, input_image1, input_image2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if loss is None:
            raise ValueError(
                f"dataloader yielded no batches in epoch {epoch + 1}; "
                "pass a non-empty, re-iterable collection of batches"
            )
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

# Load a pretrained VGG-16 and keep its first 16 feature layers, frozen and in
# eval mode, as the feature extractor for the perceptual loss.
# NOTE(review): ``pretrained=True`` is deprecated in newer torchvision releases
# in favour of ``weights=...``; kept here for compatibility with older versions.
# NOTE(review): the pretrained VGG presumably expects ImageNet-normalised
# inputs, while ``transform`` below maps images to [-1, 1] — confirm intended.
vgg_model = vgg16(pretrained=True).features[:16].eval()
for param in vgg_model.parameters():
    param.requires_grad = False

model = HDRNet(in_channels=3, out_channels=3)
loss_fn = HDRLoss(vgg_model=vgg_model)
# Use the fully qualified ``torch.optim`` so this line does not depend on a
# separate ``optim`` import (which this file does not have).
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Resize to 256x256, convert to a tensor, and normalise each channel with
# mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1]).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

image_path1 = 'path_to_underexposed_image.jpg'
image_path2 = 'path_to_overexposed_image.jpg'
ref_image_path = 'path_to_reference_image.jpg'

input_image1 = load_image(image_path1, transform)
input_image2 = load_image(image_path2, transform)
ref_image = load_image(ref_image_path, transform)

# A single (under-exposed, over-exposed, reference) training triple; a list so
# it can be re-iterated every epoch.
dataloader = [(input_image1, input_image2, ref_image)]

train_model(model, loss_fn, optimizer, dataloader, num_epochs=200)

torch.save(model.state_dict(), 'hdrnet_model.pth')
