# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import vgg16
from PIL import Image
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim

# HDRLoss 
class HDRLoss(nn.Module):
    """Weighted sum of pixel, SSIM, adversarial, perceptual and contrast losses.

    ``forward`` returns::

        mse(Ihdr, Iref)
        + mef_ssim_weight    * calculate_mef_ssim(Ihdr, Iref)
        + adversarial_weight * g_loss   # generator half of calculate_adversarial_loss
        + perceptual_weight  * mse(vgg(Ihdr), vgg(Iref))
        + contrast_weight    * calculate_global_local_contrast_loss(...)

    Args:
        vgg_model: Feature extractor used for the perceptual and contrast terms.
            It is called directly on image tensors (assumed (B, C, H, W) -- confirm).
        discriminator: Network passed to ``calculate_adversarial_loss``.
        real_label: Target for "real" scores, forwarded to ``criterion``.
        fake_label: Target for "fake" scores, forwarded to ``criterion``.
        criterion: Adversarial criterion (e.g. ``nn.BCELoss`` -- assumed, since the
            file's Discriminator ends in a Sigmoid).
        mef_ssim_weight: Weight of the SSIM term (default 0.1).
        adversarial_weight: Weight of the generator adversarial term (default 0.1).
        perceptual_weight: Weight of the VGG-feature MSE term (default 0.1).
        contrast_weight: Weight of the global/local contrast term (default 0.1).

    NOTE: when ``vgg_model`` / ``discriminator`` are ``nn.Module`` instances,
    assigning them as attributes registers them as submodules. Their parameters
    therefore appear in ``self.parameters()``; keep them out of the generator's
    optimizer.
    """

    def __init__(self, vgg_model, discriminator, real_label, fake_label, criterion,
                 mef_ssim_weight=0.1, adversarial_weight=0.1,
                 perceptual_weight=0.1, contrast_weight=0.1):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.vgg = vgg_model
        self.discriminator = discriminator
        self.real_label = real_label
        self.fake_label = fake_label
        self.criterion = criterion
        self.mef_ssim_weight = mef_ssim_weight
        self.adversarial_weight = adversarial_weight
        self.perceptual_weight = perceptual_weight
        self.contrast_weight = contrast_weight

    def forward(self, Ihdr, Iref, Iu, Io):
        """Return the scalar total loss for a predicted HDR image.

        Args:
            Ihdr: Predicted image.
            Iref: Reference image.
            Iu: Extra image compared against in the contrast term (presumably the
                under-exposed input -- confirm).
            Io: Extra image compared against in the contrast term (presumably the
                over-exposed input -- confirm).
        """
        mse_loss = self.mse_loss(Ihdr, Iref)
        perceptual_loss = F.mse_loss(self.vgg(Ihdr), self.vgg(Iref))

        mef_ssim_loss = calculate_mef_ssim(Ihdr, Iref)
        # Only the generator term belongs in this objective; the discriminator
        # loss is discarded here (train the discriminator in its own step).
        _, g_loss = calculate_adversarial_loss(
            Ihdr, Iref, self.discriminator, self.real_label, self.fake_label, self.criterion
        )
        contrast_loss = calculate_global_local_contrast_loss(Ihdr, Iref, Iu, Io, self.vgg)

        return (
            mse_loss
            + self.mef_ssim_weight * mef_ssim_loss
            + self.adversarial_weight * g_loss
            + self.perceptual_weight * perceptual_loss
            + self.contrast_weight * contrast_loss
        )

def calculate_mef_ssim(Ihdr, Iref, data_range=1.0, win_size=7):
    """Return the differentiable SSIM loss ``1 - SSIM(Ihdr, Iref)`` as a 0-d tensor.

    Despite the name, this is single-reference SSIM against ``Iref``, not the
    multi-exposure MEF-SSIM metric. The computation follows scikit-image's
    ``structural_similarity`` defaults:
    - a ``win_size`` x ``win_size`` uniform window;
    - K1=0.01 and K2=0.03;
    - sample covariance;
    - statistics averaged only over window positions fully inside the image.

    It is done in torch, so gradients flow back into ``Ihdr``; the previous
    NumPy round-trip raised on tensors that require grad. SSIM is averaged
    over batch, channels and pixels.

    Args:
        Ihdr: Predicted image, shape (B, C, H, W) or (C, H, W).
        Iref: Reference image with the same shape as ``Ihdr``.
        data_range: Value range of the images (default 1.0, i.e. assumes inputs
            scaled to [0, 1] -- confirm for your pipeline).
        win_size: Side length of the averaging window (default 7).

    Returns:
        0-d tensor; 0 for identical images, larger for dissimilar ones.

    Raises:
        ValueError: on mismatched shapes, unsupported rank, or images smaller
            than ``win_size`` in either spatial dimension.
    """
    if Ihdr.shape != Iref.shape:
        raise ValueError(f"Ihdr and Iref must have the same shape, got {tuple(Ihdr.shape)} and {tuple(Iref.shape)}")
    if Ihdr.dim() == 3:
        Ihdr = Ihdr.unsqueeze(0)
        Iref = Iref.unsqueeze(0)
    if Ihdr.dim() != 4:
        raise ValueError(f"expected a (B, C, H, W) or (C, H, W) tensor, got shape {tuple(Ihdr.shape)}")
    if min(Ihdr.shape[-2:]) < win_size:
        raise ValueError(f"spatial size {tuple(Ihdr.shape[-2:])} is smaller than win_size={win_size}")

    def local_mean(t):
        # 'valid' box filter: equivalent to skimage's uniform filter followed by
        # cropping (win_size - 1) // 2 pixels from each border.
        return F.avg_pool2d(t, kernel_size=win_size, stride=1)

    n_samples = win_size * win_size
    cov_norm = n_samples / (n_samples - 1)  # unbiased (sample) covariance

    mu_x = local_mean(Ihdr)
    mu_y = local_mean(Iref)
    var_x = cov_norm * (local_mean(Ihdr * Ihdr) - mu_x * mu_x)
    var_y = cov_norm * (local_mean(Iref * Iref) - mu_y * mu_y)
    cov_xy = cov_norm * (local_mean(Ihdr * Iref) - mu_x * mu_y)

    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    ssim_map = ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x * mu_x + mu_y * mu_y + c1) * (var_x + var_y + c2)
    )
    return 1 - ssim_map.mean()

class Discriminator(nn.Module):
    """Convolutional real/fake scorer with Sigmoid outputs in (0, 1).

    The architecture is:
    - four 4x4 stride-2 convolutions (64, 128, 256, 512 channels), each halving
      H and W with floor rounding;
    - BatchNorm after every convolution except the first;
    - LeakyReLU(0.2) after every convolution;
    - a final 4x4 'valid' convolution to one channel, followed by a Sigmoid.

    A 64x64 input therefore gives one score per image. Larger inputs give a
    grid of scores, which ``forward`` flattens together with the batch axis.
    """

    def __init__(self, in_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Downsampling stages that carry BatchNorm; built in the same order so
        # module indices (state_dict keys) and parameter init order are unchanged.
        for width_in, width_out in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv2d(width_in, width_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(width_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Score ``input`` and return all scores as a flat 1-D tensor."""
        score_map = self.main(input)
        return score_map.view(-1)

def _as_target(label, output):
    """Return ``label`` as a criterion target; Python numbers are broadcast to ``output``'s shape."""
    if isinstance(label, (int, float)):
        return torch.full_like(output, float(label))
    return label


def calculate_adversarial_loss(Ihdr, Iref, discriminator, real_label, fake_label, criterion):
    """Compute the discriminator and generator adversarial losses.

    Args:
        Ihdr: Generated image(s) scored as "fake".
        Iref: Reference image(s) scored as "real".
        discriminator: Callable returning scores compatible with ``criterion``.
        real_label: Target for real scores. Either a tensor matching the
            discriminator output, or a Python number broadcast to that shape.
        fake_label: Target for fake scores, in the same form as ``real_label``.
        criterion: Loss applied as ``criterion(scores, targets)``.

    Returns:
        ``(d_loss, g_loss)``:
        - ``d_loss`` is the mean of the real and fake discriminator losses. It
          is computed on ``Ihdr.detach()``, so backpropagating it does not
          produce generator gradients.
        - ``g_loss`` scores ``Ihdr`` against ``real_label`` and does
          backpropagate into the generator.

    NOTE: the discriminator is run three times (real, detached fake, fake).
    With BatchNorm in training mode, each pass updates the running statistics.
    """
    real_output = discriminator(Iref)
    # Detached copy for the discriminator objective: without it,
    # d_loss.backward() would also push gradients into the generator.
    fake_output_d = discriminator(Ihdr.detach())
    fake_output_g = discriminator(Ihdr)

    real_loss = criterion(real_output, _as_target(real_label, real_output))
    fake_loss = criterion(fake_output_d, _as_target(fake_label, fake_output_d))

    d_loss = (real_loss + fake_loss) / 2
    g_loss = criterion(fake_output_g, _as_target(real_label, fake_output_g))

    return d_loss, g_loss


def calculate_global_local_contrast_loss(Ihdr, Iref, Iu, Io, vgg_model, num_patches=4):
    """Return a feature-space loss over the whole image plus a grid of patches.

    For a group of images (hdr, ref, under, over) the feature distance is::

        mse(f(hdr), f(ref)) + mse(f(hdr), f(under)) + mse(f(hdr), f(over))

    where ``f`` is ``vgg_model``. The global term applies this distance to the
    full images. The local term splits every image into a ``num_patches`` x
    ``num_patches`` grid and averages the distance over the grid cells.

    Previously ``P`` was commented as the patch count but used as the patch
    *size*, so only the top-left ``4*P`` x ``4*P`` pixels reached the local
    term, as 4x4 crops.

    Args:
        Ihdr: Predicted image, rank-4 (B, C, H, W) -- the patch split unfolds
            dims 2 and 3.
        Iref: Reference image with the same spatial size as ``Ihdr``.
        Iu: Extra comparison image with the same spatial size (presumably the
            under-exposed input -- confirm).
        Io: Extra comparison image with the same spatial size (presumably the
            over-exposed input -- confirm).
        vgg_model: Callable feature extractor. It must accept both full images
            and patches of size (H // num_patches, W // num_patches).
        num_patches: Number of patches along each spatial axis (default 4).
            Trailing rows/columns that do not fill a whole patch are ignored by
            the local term.

    Returns:
        Scalar tensor ``global_loss + local_loss``.

    Raises:
        ValueError: if H or W is smaller than ``num_patches``.
    """
    def feature_distance(hdr, ref, under, over):
        hdr_features = vgg_model(hdr)
        return (
            F.mse_loss(hdr_features, vgg_model(ref))
            + F.mse_loss(hdr_features, vgg_model(under))
            + F.mse_loss(hdr_features, vgg_model(over))
        )

    global_loss = feature_distance(Ihdr, Iref, Iu, Io)

    height, width = Ihdr.shape[-2:]
    patch_h = height // num_patches
    patch_w = width // num_patches
    if patch_h == 0 or patch_w == 0:
        raise ValueError(
            f"image of size {height}x{width} is too small for a {num_patches}x{num_patches} patch grid"
        )

    def to_grid(image):
        # Crop to a whole number of patches so unfold yields exactly num_patches
        # windows per axis: (B, C, H, W) -> (B, C, num_patches, num_patches, patch_h, patch_w).
        cropped = image[..., : num_patches * patch_h, : num_patches * patch_w]
        return cropped.unfold(2, patch_h, patch_h).unfold(3, patch_w, patch_w)

    grids = [to_grid(image) for image in (Ihdr, Iref, Iu, Io)]

    local_loss = 0
    for i in range(num_patches):
        for j in range(num_patches):
            local_loss = local_loss + feature_distance(*(grid[:, :, i, j] for grid in grids))
    local_loss = local_loss / (num_patches * num_patches)

    return global_loss + local_loss