# Zero-DCE Improved: Development Notebook

**Authors:** Imrose Batterywala (314540010), Shahzeb Mohammed (314540021)

This notebook implements the improvements outlined in the survey:
1. **Bright/Dark Balance Loss** - Dual-histogram regularization
2. **Texture-Aware Lighting Maps** - Gradient-respecting kernels
3. **Hybrid Exposure Fusion** - Multi-exposure fusion
4. **Perceptual Co-training** - NIMA/NIQE-based losses

We'll implement and test each improvement incrementally.


In [None]:
# Setup and Imports
import sys
import os
from pathlib import Path

# Add parent directory to path to import baseline modules
current_dir = Path.cwd()
if 'improv' in str(current_dir):
    parent_dir = current_dir.parent
else:
    parent_dir = current_dir
sys.path.insert(0, str(parent_dir))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import time

# Import baseline modules
import model
import Myloss
import dataloader

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set paths - adjust based on notebook location
PROJECT_ROOT = parent_dir
print(f"Project root: {PROJECT_ROOT}")


## 1. Bright/Dark Balance Loss

This loss addresses the dual problem of persistent dark pixels and missing bright pixels by:
- Penalizing high fraction of pixels < 0.2 (dark regions)
- Encouraging presence of pixels > 0.9 (bright regions)
- Balancing with existing exposure control loss
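
Both fractions above are defined by hard thresholds. A hard threshold is piecewise constant, so autograd returns a zero gradient for every pixel, and a loss built only from `(luma < 0.2).float()` cannot steer the network. A common workaround is a sigmoid relaxation with a temperature `T`. The NumPy sketch below compares per-pixel derivatives of the two versions, using central finite differences in place of autograd. `hard_dark`, `soft_dark` and `T = 0.02` are assumptions chosen for illustration.

```python
import numpy as np

def hard_dark(l, t=0.2):
    # 1 for pixels below the threshold, 0 otherwise (piecewise constant)
    return (l < t).astype(float)

def soft_dark(l, t=0.2, T=0.02):
    # Sigmoid relaxation: close to 1 well below t, close to 0 well above, smooth in between
    return 1.0 / (1.0 + np.exp(-(t - l) / T))

def finite_diff(f, l, h=1e-6):
    # Central difference: per-pixel derivative of f with respect to luminance
    return (f(l + h) - f(l - h)) / (2 * h)

luma = np.array([0.05, 0.15, 0.19, 0.25, 0.5])
print(finite_diff(hard_dark, luma))  # zero everywhere: no learning signal
print(finite_diff(soft_dark, luma))  # negative: brightening a pixel lowers its "dark" score
```

The soft derivative is largest near the threshold and fades far from it. The temperature therefore trades off how closely the estimate matches the true fraction against how many pixels receive a useful gradient.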


In [None]:
class BrightDarkBalanceLoss(nn.Module):
    """
    Dual-histogram regularization loss to address:
    - Persistent dark pixels (luminance < dark_threshold)
    - Missing bright pixels (luminance > bright_threshold)

    Hard masks such as (luma < t).float() have zero gradient, so the fractions
    used in the loss are estimated with sigmoid "soft" masks instead.
    """
    def __init__(self, dark_threshold=0.2, bright_threshold=0.9,
                 dark_target=0.05, bright_target=0.01, temperature=0.02):
        super(BrightDarkBalanceLoss, self).__init__()
        self.dark_threshold = dark_threshold
        self.bright_threshold = bright_threshold
        self.dark_target = dark_target      # Maximum acceptable fraction of dark pixels
        self.bright_target = bright_target  # Minimum desired fraction of bright pixels
        self.temperature = temperature      # Soft-mask sharpness (smaller = closer to a hard threshold)

    def compute_luminance(self, x):
        """Convert RGB (B, 3, H, W) to luminance (B, 1, H, W)."""
        # Standard luminance weights: 0.299*R + 0.587*G + 0.114*B
        if x.shape[1] == 3:
            weights = torch.tensor([0.299, 0.587, 0.114], device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
            luma = torch.sum(x * weights, dim=1, keepdim=True)
        else:
            luma = x
        return luma

    def forward(self, enhanced_image):
        """
        Args:
            enhanced_image: Tensor of shape (B, C, H, W) in range [0, 1]
        Returns:
            loss: Combined dark and bright pixel balance loss (scalar tensor)
            info: Dictionary of statistics for logging (Python floats)
        """
        luma = self.compute_luminance(enhanced_image)

        # Differentiable estimates of the dark and bright pixel fractions
        soft_dark = torch.sigmoid((self.dark_threshold - luma) / self.temperature).mean()
        soft_bright = torch.sigmoid((luma - self.bright_threshold) / self.temperature).mean()

        # Dark pixel reduction: penalize only the excess above the target fraction
        dark_loss = torch.relu(soft_dark - self.dark_target) ** 2

        # Bright pixel promotion: full penalty when bright pixels are scarce,
        # a small (0.1x) penalty when there are more than the target
        bright_gap = soft_bright - self.bright_target
        bright_loss = torch.where(bright_gap < 0, bright_gap ** 2, 0.1 * bright_gap ** 2)

        total_loss = dark_loss + bright_loss

        # Exact (hard-threshold) fractions, reported for logging only
        with torch.no_grad():
            dark_fraction = (luma < self.dark_threshold).float().mean().item()
            bright_fraction = (luma > self.bright_threshold).float().mean().item()

        return total_loss, {
            'dark_fraction': dark_fraction,
            'bright_fraction': bright_fraction,
            'dark_loss': dark_loss.item(),
            'bright_loss': bright_loss.item()
        }

# Report the default parameters
bd = BrightDarkBalanceLoss()
print("Bright/Dark Balance Loss implementation complete!")
print("\nLoss function parameters:")
print(f"  Dark threshold: {bd.dark_threshold}")
print(f"  Bright threshold: {bd.bright_threshold}")
print(f"  Dark target fraction: {bd.dark_target}")
print(f"  Bright target fraction: {bd.bright_target}")


## 2. Texture-Aware Lighting Maps

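This improvement adds a gradient-aware smoothness loss on the curve-parameter maps. The illumination smoothness penalty is weighted by the inverse of the input image's Sobel gradient magnitude. Curves therefore stay smooth in flat regions but may vary across texture edges.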
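
The weighting can be inspected in isolation. The NumPy sketch below reproduces the Sobel magnitude used in the next cell, including zero padding as with `F.conv2d(..., padding=1)`. It then applies min-max normalization and the `1 / (g + 0.1)` weight to a synthetic step-edge image. `correlate3x3` and `edge_aware_weights` are hypothetical helper names. Flat pixels get a weight of about 10 and pixels on the edge a weight close to 1, so the smoothness penalty is roughly ten times weaker across the boundary. Zero padding also makes the image border look like an edge, so the check below uses interior pixels only.

```python
import numpy as np

SOBEL_X = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=float)
SOBEL_Y = SOBEL_X.T  # [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]

def correlate3x3(img, k):
    # 'same'-size cross-correlation with zero padding (what F.conv2d computes)
    p = np.pad(img, 1)
    h, w = img.shape
    out = np.zeros_like(img, dtype=float)
    for dy in range(3):
        for dx in range(3):
            out += k[dy, dx] * p[dy:dy + h, dx:dx + w]
    return out

def edge_aware_weights(gray, floor=0.1):
    mag = np.sqrt(correlate3x3(gray, SOBEL_X) ** 2 + correlate3x3(gray, SOBEL_Y) ** 2 + 1e-6)
    norm = (mag - mag.min()) / (mag.max() - mag.min() + 1e-6)
    # About 1/floor in flat regions, about 1/(1 + floor) on the strongest edges
    return 1.0 / (norm + floor)

gray = np.full((8, 8), 0.1)
gray[:, 4:] = 0.8  # vertical step edge between columns 3 and 4
w = edge_aware_weights(gray)
print(w[4, 1], w[4, 3])  # flat interior pixel vs. pixel next to the edge
```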


In [None]:
class TextureAwareSmoothnessLoss(nn.Module):
    """
    Gradient-aware illumination smoothness loss that respects texture boundaries.
    This prevents harsh contrast jumps while maintaining smooth illumination maps.
    """
    def __init__(self):
        super(TextureAwareSmoothnessLoss, self).__init__()
        
        # Sobel kernels for gradient computation
        sobel_x = torch.FloatTensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).view(1, 1, 3, 3)
        sobel_y = torch.FloatTensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).view(1, 1, 3, 3)
        
        self.register_buffer('sobel_x', sobel_x)
        self.register_buffer('sobel_y', sobel_y)
        
    def compute_gradient_magnitude(self, x):
        """Compute gradient magnitude map"""
        # Convert to grayscale if RGB
        if x.shape[1] == 3:
            weights = torch.tensor([0.299, 0.587, 0.114], device=x.device).view(1, 3, 1, 1)
            gray = torch.sum(x * weights, dim=1, keepdim=True)
        else:
            gray = x
            
        # Compute gradients
        grad_x = F.conv2d(gray, self.sobel_x, padding=1)
        grad_y = F.conv2d(gray, self.sobel_y, padding=1)
        
        # Gradient magnitude
        grad_mag = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6)
        return grad_mag
    
    def forward(self, illumination_map, input_image):
        """
        Args:
            illumination_map: Curve parameter maps (B, 24, H, W) or similar
            input_image: Original input image (B, 3, H, W) for gradient computation
        Returns:
            loss: Gradient-aware smoothness loss
        """
        # Compute gradient magnitude of input image
        input_grad = self.compute_gradient_magnitude(input_image)
        
        # Normalize the gradient magnitude to [0, 1] per image, so that one
        # high-contrast image in the batch does not rescale the others
        g_min = input_grad.amin(dim=(2, 3), keepdim=True)
        g_max = input_grad.amax(dim=(2, 3), keepdim=True)
        input_grad_norm = (input_grad - g_min) / (g_max - g_min + 1e-6)

        # Horizontal and vertical finite differences of all curve-parameter channels at once
        grad_x = torch.abs(illumination_map[:, :, :, 1:] - illumination_map[:, :, :, :-1])
        grad_y = torch.abs(illumination_map[:, :, 1:, :] - illumination_map[:, :, :-1, :])

        # Gradient-aware weighting (the single-channel weights broadcast over all channels):
        # - high-gradient regions (texture boundaries): weight near 1/1.1, allowing more variation
        # - low-gradient regions (smooth areas): weight up to 1/0.1 = 10, enforcing smoothness
        weight_x = 1.0 / (input_grad_norm[:, :, :, :-1] + 0.1)
        weight_y = 1.0 / (input_grad_norm[:, :, :-1, :] + 0.1)

        # Every channel has the same number of elements, so the mean over all channels
        # equals the per-channel losses averaged over channels
        return torch.mean(weight_x * grad_x) + torch.mean(weight_y * grad_y)

print("Texture-Aware Smoothness Loss implementation complete!")


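## 3. Hybrid Exposure Fusion

This inference-time step fuses the Zero-DCE output with synthetic exposure brackets, aiming to preserve highlights while lifting shadows.


In [None]: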
class ExposureFusion:
    """
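    Exposure fusion that combines the Zero-DCE output with synthetic exposure brackets
    (currently a single-scale weighted average; see multi_scale_fusion).
    This helps preserve highlights while lifting shadows.
    """
    def __init__(self, num_scales=5, exposure_values=(-2, -1, 0, 1, 2)):
        self.num_scales = num_scales  # Pyramid depth for a future Laplacian-pyramid fusion (currently unused)
        self.exposure_values = tuple(exposure_values)  # Exposure values in EV (tuple avoids a mutable default)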
        
    def generate_exposure_brackets(self, image, base_exposure=0):
        """
        Generate synthetic exposure brackets from an image.
        
        Args:
            image: Tensor (B, C, H, W) in range [0, 1]
            base_exposure: Base exposure value (0 = no change)
        Returns:
            brackets: List of exposure-bracketed images
        """
        brackets = []
        for ev in self.exposure_values:
            # Exposure adjustment: multiply by 2^EV
            exposure_factor = 2.0 ** ev
            exposed = image * exposure_factor
            # Clip to valid range
            exposed = torch.clamp(exposed, 0.0, 1.0)
            brackets.append(exposed)
        return brackets
    
    def compute_weights(self, image):
        """
        Compute fusion weights based on:
        - Well-exposedness (closeness to 0.5)
        - Saturation (colorfulness)
        - Contrast (local variance)
        """
        B, C, H, W = image.shape
        
        # Well-exposedness: Gaussian centered at 0.5
        well_exposed = torch.exp(-0.5 * torch.pow(image - 0.5, 2) / (0.2 ** 2))
        well_exposed = torch.mean(well_exposed, dim=1, keepdim=True)  # Average across channels
        
        # Saturation: standard deviation across color channels
        saturation = torch.std(image, dim=1, keepdim=True)
        
        # Contrast: local variance using a simple kernel
        kernel_size = 5
        kernel = torch.ones(1, 1, kernel_size, kernel_size, device=image.device) / (kernel_size ** 2)
        mean = F.conv2d(torch.mean(image, dim=1, keepdim=True), kernel, padding=kernel_size//2)
        variance = F.conv2d(torch.pow(torch.mean(image, dim=1, keepdim=True) - mean, 2), 
                          kernel, padding=kernel_size//2)
        contrast = variance
        
        # Combine the raw measures. Fusion compares these values across brackets at
        # each pixel, so they are deliberately NOT min-max normalized per image:
        # doing so would stretch a poorly exposed bracket's weights to [0, 1] and
        # erase the differences the fusion relies on.
        weight = well_exposed * saturation * (contrast + 0.1)
        return weight
    
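    def multi_scale_fusion(self, images, weights, eps=1e-12):
        """
        Fuse images with per-pixel normalized weights.

        Note: despite the name, this is currently a single-scale weighted average;
        a Laplacian-pyramid blend (using self.num_scales levels) is not implemented yet.

        Args:
            images: List of images to fuse, each (B, C, H, W)
            weights: List of weight maps, each (B, 1, H, W)
            eps: Added to every weight so that pixels where all weights are zero
                 (e.g. gray pixels with zero saturation) fall back to a plain
                 average instead of collapsing to black
        Returns:
            fused: Fused image (B, C, H, W)
        """
        weights = [w + eps for w in weights]
        weights_sum = sum(weights)
        fused = sum(img * w for img, w in zip(images, weights)) / weights_sum
        return fused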
    
    def fuse(self, zero_dce_output, input_image):
        """
        Fuse Zero-DCE output with exposure brackets.
        
        Args:
            zero_dce_output: Enhanced image from Zero-DCE (B, C, H, W)
            input_image: Original input image (B, C, H, W); currently unused
        Returns:
            fused_image: Fused result
        """
        # Generate exposure brackets from Zero-DCE output
        brackets = self.generate_exposure_brackets(zero_dce_output)
        
        # Include the Zero-DCE output itself as well. Note: the EV 0 bracket is an
        # identical copy, so this image effectively receives double weight.
        all_images = [zero_dce_output] + brackets
        
        # Compute weights for each image
        weights = [self.compute_weights(img) for img in all_images]
        
        # Multi-scale fusion
        fused = self.multi_scale_fusion(all_images, weights)
        
        # Adaptive blending with Zero-DCE output based on local luminance
        luma_weights = torch.tensor([0.299, 0.587, 0.114], device=zero_dce_output.device).view(1, 3, 1, 1)
        luma = torch.sum(zero_dce_output * luma_weights, dim=1, keepdim=True)
        
        # In dark regions, use more of the fused result; in bright regions, keep more of Zero-DCE.
        # blend_weight is 1 at luma = 0 and falls linearly to 0 at luma >= 0.5.
        blend_weight = 1.0 - torch.clamp(luma * 2.0, 0.0, 1.0)
        final = blend_weight * fused + (1 - blend_weight) * zero_dce_output
        
        return torch.clamp(final, 0.0, 1.0)

print("Exposure Fusion implementation complete!")
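
In Mertens-style exposure fusion, quality weights are compared across brackets at each pixel and normalized to sum to one there. Adding a tiny epsilon to every weight means a pixel where all weights vanish falls back to a plain average of the brackets, rather than being divided by a clamped denominator and driven toward zero. The NumPy sketch below is a minimal illustration under stated assumptions: it uses only the well-exposedness term (saturation and contrast omitted), works in HWC layout instead of the notebook's BCHW tensors, and the function names are hypothetical.

```python
import numpy as np

def exposure_brackets(img, evs=(-2, -1, 0, 1, 2)):
    # Synthetic brackets: scale by 2**EV, then clip back to [0, 1]
    return [np.clip(img * 2.0 ** ev, 0.0, 1.0) for ev in evs]

def well_exposedness(img, sigma=0.2):
    # Gaussian centered at mid-gray, averaged over the color channels (last axis)
    return np.exp(-0.5 * (img - 0.5) ** 2 / sigma ** 2).mean(axis=-1)

def fuse(images, weights, eps=1e-12):
    # Per-pixel normalization across brackets; eps rules out 0/0
    w = np.stack(weights) + eps            # (N, H, W)
    w = w / w.sum(axis=0, keepdims=True)   # weights sum to 1 at every pixel
    return np.sum(w[..., None] * np.stack(images), axis=0)

rng = np.random.default_rng(0)
img = rng.uniform(0.0, 0.3, size=(16, 16, 3))  # dim synthetic RGB image
brackets = exposure_brackets(img)
fused = fuse(brackets, [well_exposedness(b) for b in brackets])
zero_w = fuse(brackets, [np.zeros((16, 16)) for _ in brackets])  # degenerate case
print(img.mean(), fused.mean())
```

Because the fused value is a convex combination of the brackets, it stays in [0, 1]. For a dim image the brighter, better-exposed brackets dominate.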


## 4. Perceptual Co-training Loss

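The survey proposes NIMA/NIQE-based losses. Full integration would need a pretrained NIMA model and a differentiable approximation of NIQE. This notebook therefore currently uses a VGG16 feature-based perceptual loss as a proxy.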
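
Whatever objective is placed on top, torchvision's pretrained VGG16 expects inputs standardized with ImageNet statistics rather than raw [0, 1] pixels. A minimal NumPy sketch of that preprocessing step follows; `imagenet_normalize` is a hypothetical helper name. Inside a loss module, the same broadcasted expression works on torch tensors when the mean and std are stored as `(1, 3, 1, 1)` buffers.

```python
import numpy as np

# ImageNet channel statistics used by torchvision's pretrained classification models
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)

def imagenet_normalize(batch):
    """batch: (B, 3, H, W) array in [0, 1] -> standardized array of the same shape."""
    return (batch - IMAGENET_MEAN) / IMAGENET_STD

batch = np.full((2, 3, 4, 4), 0.5)  # mid-gray batch
out = imagenet_normalize(batch)
print(out[0, :, 0, 0])
```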


In [None]:
class PerceptualLoss(nn.Module):
    """
    Perceptual loss using VGG16 features (similar to the existing perception_loss in Myloss.py).
    Serves as a stand-in for NIMA/NIQE, which would require their own pretrained models.
    """
    def __init__(self):
        super(PerceptualLoss, self).__init__()
        # ImageNet statistics expected by torchvision's pretrained VGG16
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        try:
            # The weights enum API requires torchvision >= 0.13 ('pretrained=True' is deprecated)
            from torchvision.models import vgg16, VGG16_Weights
            vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
            self.feature_extractor = nn.Sequential(*list(vgg.children())[:23])  # Up to relu4_3
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
            self.feature_extractor.eval()
        except Exception as e:
            print(f"Warning: Could not load VGG16 ({e}). Perceptual loss disabled.")
            self.feature_extractor = None

    def forward(self, enhanced_image, target_image=None):
        """
        Compute perceptual loss.
        If target_image is None, a no-reference regularizer is used instead.
        """
        if self.feature_extractor is None:
            # Fallback: contribute nothing. (An L2 penalty on raw pixels, for example,
            # would reward darker outputs, the opposite of what enhancement needs.)
            return enhanced_image.new_zeros(())

        # Extract features from ImageNet-normalized inputs
        enhanced_features = self.feature_extractor((enhanced_image - self.mean) / self.std)

        if target_image is not None:
            # Reference-based perceptual loss
            target_features = self.feature_extractor((target_image - self.mean) / self.std)
            loss = F.mse_loss(enhanced_features, target_features)
        else:
            # No-reference heuristic: pulls relu4_3 activations toward 0.5. This is an
            # ad hoc regularizer, not a calibrated quality score; its effect on
            # over- and underexposure should be checked empirically.
            loss = torch.mean(torch.pow(enhanced_features - 0.5, 2))

        return loss

# Note: Full NIMA/NIQE integration would require:
# 1. Loading pre-trained NIMA model for aesthetic scoring
# 2. Computing NIQE score (typically non-differentiable, may need approximation)
# For now, we use VGG-based perceptual loss as a proxy

print("Perceptual Loss implementation complete!")
print("Note: Full NIMA/NIQE integration requires their pre-trained models.")


## 5. Integrated Training Function

Combine the baseline Zero-DCE losses and the new losses into one weighted objective for the training loop.
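
Stripped of the tensors, the combined objective is a weighted sum over whichever terms are enabled. The pure-Python sketch below makes that bookkeeping explicit, using the default weights from `create_improved_loss_function`. The `combine_losses` helper is hypothetical, and the raw loss values are made up purely for illustration. Disabling a term simply removes it from the sum.

```python
DEFAULT_WEIGHTS = {
    'tv': 200.0, 'spa': 1.0, 'color': 5.0, 'exp': 10.0,
    'bright_dark': 2.0, 'texture_aware': 1.0, 'perceptual': 0.5,
}

def combine_losses(raw_losses, weights=None):
    """Weighted sum over the terms present in raw_losses; disabled terms are absent."""
    weights = DEFAULT_WEIGHTS if weights is None else weights
    weighted = {name: weights[name] * value for name, value in raw_losses.items()}
    weighted['total'] = sum(weighted.values())
    return weighted

# Illustrative raw values with the texture-aware and perceptual terms disabled
raw = {'tv': 0.001, 'spa': 0.2, 'color': 0.01, 'exp': 0.05, 'bright_dark': 0.1}
print(combine_losses(raw))
```

With these made-up numbers, the 200x TV weight turns a raw value of 0.001 into a contribution as large as the spatial term. That relative scale is worth watching when fine-tuning the weights.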


In [None]:
def create_improved_loss_function(use_bright_dark=True, use_texture_aware=True, 
                                   use_perceptual=True, device='cuda'):
    """
    Create a combined loss function with all improvements.

    Returns:
        compute_loss: Function (enhanced_image_1, enhanced_image, A, img_lowlight,
                      loss_weights=None) -> (total_loss, loss_dict)
    """
    # Baseline losses
    L_color = Myloss.L_color().to(device)
    L_spa = Myloss.L_spa().to(device)
    L_exp = Myloss.L_exp(16, 0.6).to(device)
    L_TV = Myloss.L_TV().to(device)
    
    # New losses
    bright_dark_loss = BrightDarkBalanceLoss().to(device) if use_bright_dark else None
    texture_aware_loss = TextureAwareSmoothnessLoss().to(device) if use_texture_aware else None
    perceptual_loss = PerceptualLoss().to(device) if use_perceptual else None
    
    def compute_loss(enhanced_image_1, enhanced_image, A, img_lowlight, 
                     loss_weights=None):
        """
        Compute combined loss with all improvements.
        
        Args:
            enhanced_image_1: Intermediate enhanced image (currently unused)
            enhanced_image: Final enhanced image
            A: Illumination adjustment curves (24 channels)
            img_lowlight: Original low-light input
            loss_weights: Dictionary of loss weights (optional)
        Returns:
            total_loss: Combined loss
            loss_dict: Dictionary of individual losses
        """
        if loss_weights is None:
            loss_weights = {
                'tv': 200.0,
                'spa': 1.0,
                'color': 5.0,
                'exp': 10.0,
                'bright_dark': 2.0,
                'texture_aware': 1.0,
                'perceptual': 0.5
            }
        
        loss_dict = {}
        
        # Baseline losses
        loss_tv = loss_weights['tv'] * L_TV(A)
        loss_spa = loss_weights['spa'] * torch.mean(L_spa(enhanced_image, img_lowlight))
        loss_col = loss_weights['color'] * torch.mean(L_color(enhanced_image))
        loss_exp = loss_weights['exp'] * torch.mean(L_exp(enhanced_image))
        
        loss_dict['tv'] = loss_tv.item()
        loss_dict['spa'] = loss_spa.item()
        loss_dict['color'] = loss_col.item()
        loss_dict['exp'] = loss_exp.item()
        
        total_loss = loss_tv + loss_spa + loss_col + loss_exp
        
        # New losses (loss_dict stores weighted values, like the baseline terms above)
        if bright_dark_loss is not None:
            bd_loss, bd_info = bright_dark_loss(enhanced_image)
            loss_bd = loss_weights['bright_dark'] * bd_loss
            total_loss = total_loss + loss_bd
            loss_dict['bright_dark'] = loss_bd.item()
            loss_dict.update({f'bd_{k}': v for k, v in bd_info.items()})
        
        if texture_aware_loss is not None:
            loss_ta = loss_weights['texture_aware'] * texture_aware_loss(A, img_lowlight)
            total_loss = total_loss + loss_ta
            loss_dict['texture_aware'] = loss_ta.item()
        
        if perceptual_loss is not None:
            loss_perc = loss_weights['perceptual'] * perceptual_loss(enhanced_image)
            total_loss = total_loss + loss_perc
            loss_dict['perceptual'] = loss_perc.item()
        
        loss_dict['total'] = total_loss.item()
        
        return total_loss, loss_dict
    
    return compute_loss

print("Integrated loss function created!")


## 6. Testing and Evaluation

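Load a pretrained Zero-DCE checkpoint and compare its output with and without exposure fusion on a sample image. The loss-based improvements only take effect after retraining.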
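
`test_improvements` converts between PIL's HWC `uint8` layout and the NCHW float layout the network expects. The NumPy sketch below shows that round trip; `to_nchw` and `to_hwc` are hypothetical helper names. It corresponds to `torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0)` and to the inverse performed by `tensor_to_numpy`.

```python
import numpy as np

def to_nchw(rgb_uint8):
    """(H, W, 3) uint8 -> (1, 3, H, W) float32 in [0, 1]."""
    x = rgb_uint8.astype(np.float32) / 255.0
    return np.transpose(x, (2, 0, 1))[None, ...]

def to_hwc(nchw):
    """(1, 3, H, W) float -> (H, W, 3) float clipped to [0, 1], ready for plt.imshow."""
    return np.clip(np.transpose(nchw[0], (1, 2, 0)), 0.0, 1.0)

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(5, 7, 3), dtype=np.uint8)  # synthetic 5x7 RGB image
x = to_nchw(img)
print(x.shape)
```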


In [None]:
# Load pre-trained model
def load_model(weights_path, device):
    """Load Zero-DCE model"""
    net = model.enhance_net_nopool().to(device)
    state_dict = torch.load(weights_path, map_location=device)
    net.load_state_dict(state_dict)
    net.eval()
    return net

# Test on a sample image
def test_improvements(image_path, model_path, device='cuda', use_fusion=True):
    """
    Test all improvements on a single image.
    """
    # Load model
    net = load_model(model_path, device)
    
    # Load and preprocess image
    img = Image.open(image_path).convert('RGB')
    img_array = np.asarray(img).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).to(device)
    
    # Baseline enhancement
    with torch.no_grad():
        enhanced_1, enhanced, curves = net(img_tensor)
    
    # Apply exposure fusion if enabled
    if use_fusion:
        fusion = ExposureFusion()
        enhanced_fused = fusion.fuse(enhanced, img_tensor)
    else:
        enhanced_fused = enhanced
    
    # Convert to numpy for visualization
    def tensor_to_numpy(t):
        t = t.squeeze(0).cpu().permute(1, 2, 0).numpy()
        return np.clip(t, 0, 1)
    
    original = tensor_to_numpy(img_tensor)
    enhanced_np = tensor_to_numpy(enhanced)
    enhanced_fused_np = tensor_to_numpy(enhanced_fused) if use_fusion else None
    
    return {
        'original': original,
        'enhanced': enhanced_np,
        'enhanced_fused': enhanced_fused_np,
        'curves': curves
    }

print("Testing functions ready!")
print("\nTo test, run:")
print("results = test_improvements('path/to/image.jpg', 'snapshots/Epoch99.pth', device)")


## 7. Visualization Helper

Function to visualize results and compare baseline vs improved.
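
The metrics computed by `compute_metrics` in the next cell are easy to sanity-check on a synthetic image where the answers can be worked out by hand. The sketch below re-implements the luminance, dark/bright fraction, and 4x4 patch-contrast logic in NumPy; the helper names are hypothetical. With alternating rows at 0.1 and 0.95, exactly half of the pixels are dark and half are bright. Each 2x2 patch holds two values 0.425 below and two values 0.425 above their mean, so the patch contrast is 0.425.

```python
import numpy as np

def luminance(rgb):
    # Rec. 601 weights, as used throughout the notebook
    return 0.299 * rgb[..., 0] + 0.587 * rgb[..., 1] + 0.114 * rgb[..., 2]

def patch_contrast(luma, grid=4):
    # Mean of per-patch standard deviations over a grid x grid tiling
    # (trailing rows/columns that do not fill a whole patch are ignored)
    ph, pw = luma.shape[0] // grid, luma.shape[1] // grid
    stds = [luma[i * ph:(i + 1) * ph, j * pw:(j + 1) * pw].std()
            for i in range(grid) for j in range(grid)]
    return float(np.mean(stds))

img = np.full((8, 8, 3), 0.1)  # dark gray image
img[::2, :, :] = 0.95          # every other row bright
luma = luminance(img)
print(np.mean(luma < 0.2), np.mean(luma > 0.9), patch_contrast(luma))
```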


In [None]:
def visualize_results(results, save_path=None):
    """
    Visualize original, enhanced, and fused results.
    """
    fig, axes = plt.subplots(1, 3 if results['enhanced_fused'] is not None else 2, 
                            figsize=(15, 5))
    
    axes[0].imshow(results['original'])
    axes[0].set_title('Original Low-Light Image')
    axes[0].axis('off')
    
    axes[1].imshow(results['enhanced'])
    axes[1].set_title('Zero-DCE Enhanced')
    axes[1].axis('off')
    
    if results['enhanced_fused'] is not None:
        axes[2].imshow(results['enhanced_fused'])
        axes[2].set_title('With Exposure Fusion')
        axes[2].axis('off')
    
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

def compute_metrics(image):
    """
    Compute evaluation metrics for an image.
    """
    # Convert to luminance
    if len(image.shape) == 3:
        luma = 0.299 * image[:, :, 0] + 0.587 * image[:, :, 1] + 0.114 * image[:, :, 2]
    else:
        luma = image
    
    # Dark pixel fraction
    dark_fraction = np.mean(luma < 0.2)
    
    # Bright pixel fraction
    bright_fraction = np.mean(luma > 0.9)
    
    # Patch contrast (4x4 grid)
    h, w = luma.shape
    patch_h, patch_w = h // 4, w // 4
    patches = []
    for i in range(4):
        for j in range(4):
            patch = luma[i*patch_h:(i+1)*patch_h, j*patch_w:(j+1)*patch_w]
            patches.append(np.std(patch))
    patch_contrast = np.mean(patches)
    
    return {
        'dark_fraction': dark_fraction,
        'bright_fraction': bright_fraction,
        'patch_contrast': patch_contrast,
        'mean_luminance': np.mean(luma)
    }

print("Visualization and metrics functions ready!")


## 8. Example Usage

Test the improvements on a sample image from the dataset.


In [None]:
# Example: Test on a sample image
# Uncomment and modify paths as needed

# # Set paths
# test_image_path = PROJECT_ROOT / "data/test_data/DICM/01.jpg"
# model_path = PROJECT_ROOT / "snapshots/Epoch99.pth"
# 
# if test_image_path.exists() and model_path.exists():
#     print("Testing improvements on sample image...")
#     results = test_improvements(str(test_image_path), str(model_path), device)
#     
#     # Compute metrics
#     orig_metrics = compute_metrics(results['original'])
#     enh_metrics = compute_metrics(results['enhanced'])
#     
#     print("\n=== Metrics Comparison ===")
#     print(f"Dark Pixel Fraction: {orig_metrics['dark_fraction']:.4f} -> {enh_metrics['dark_fraction']:.4f}")
#     print(f"Bright Pixel Fraction: {orig_metrics['bright_fraction']:.6f} -> {enh_metrics['bright_fraction']:.6f}")
#     print(f"Patch Contrast: {orig_metrics['patch_contrast']:.4f} -> {enh_metrics['patch_contrast']:.4f}")
#     
#     # Visualize
#     visualize_results(results)
# else:
#     print("Test image or model not found. Please check paths.")

print("Ready to test! Uncomment the code above and adjust paths as needed.")


## 9. Training Script Integration

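The function below follows the baseline training loop but swaps in the combined loss from Section 5. It expects a `config` object that exposes the attributes it reads: data paths, batch size, learning rate, and logging and snapshot intervals.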
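
One way to build such a `config` is a dataclass whose field names are exactly the attributes `train_with_improvements` reads. This is a sketch under assumptions. `TrainConfig` is a hypothetical name. The default values are believed to mirror the baseline `lowlight_train.py` argument defaults, but they should be checked against that script before use.

```python
from dataclasses import dataclass

@dataclass
class TrainConfig:
    # Field names match the attributes read by train_with_improvements();
    # default values are assumptions and should be verified against lowlight_train.py
    lowlight_images_path: str = "data/train_data/"
    lr: float = 1e-4
    weight_decay: float = 1e-4
    grad_clip_norm: float = 0.1
    num_epochs: int = 200
    train_batch_size: int = 8
    num_workers: int = 4
    display_iter: int = 10
    snapshot_iter: int = 10
    snapshots_folder: str = "snapshots/"
    load_pretrain: bool = False
    pretrain_dir: str = "snapshots/Epoch99.pth"

config = TrainConfig(num_epochs=1)  # e.g. a one-epoch smoke test
print(config)
```

Following the usage tip at the end of this notebook, a first run might enable only the bright/dark term: `DCE_net = train_with_improvements(config, use_texture_aware=False, use_perceptual=False)`. Assigning the result to a name other than `model` keeps the imported `model` module usable.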


In [None]:
# Example training loop with improvements
# This can be integrated into lowlight_train.py

def train_with_improvements(config, use_bright_dark=True, use_texture_aware=True, 
                            use_perceptual=True, use_fusion=False):
    """
    Training function with all improvements integrated.
    Note: use_fusion is accepted but ignored here; exposure fusion is applied only at inference.
    """
    # Initialize model
    DCE_net = model.enhance_net_nopool().to(device)
    DCE_net.apply(lambda m: m.weight.data.normal_(0.0, 0.02) if isinstance(m, nn.Conv2d) else None)
    
    if config.load_pretrain:
        DCE_net.load_state_dict(torch.load(config.pretrain_dir, map_location=device))
    
    # Data loader
    train_dataset = dataloader.lowlight_loader(config.lowlight_images_path)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.train_batch_size, 
        shuffle=True, num_workers=config.num_workers, pin_memory=True
    )
    
    # Loss function with improvements
    compute_loss = create_improved_loss_function(
        use_bright_dark=use_bright_dark,
        use_texture_aware=use_texture_aware,
        use_perceptual=use_perceptual,
        device=device
    )
    
    # Optimizer
    optimizer = optim.Adam(DCE_net.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    
    DCE_net.train()
    
    print("Starting training with improvements...")
    print(f"  Bright/Dark Balance: {use_bright_dark}")
    print(f"  Texture-Aware: {use_texture_aware}")
    print(f"  Perceptual: {use_perceptual}")
    
    for epoch in range(config.num_epochs):
        for iteration, img_lowlight in enumerate(train_loader):
            img_lowlight = img_lowlight.to(device)
            
            # Forward pass
            enhanced_image_1, enhanced_image, A = DCE_net(img_lowlight)
            
            # Compute loss
            loss, loss_dict = compute_loss(enhanced_image_1, enhanced_image, A, img_lowlight)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(DCE_net.parameters(), config.grad_clip_norm)
            optimizer.step()
            
            # Logging
            if (iteration + 1) % config.display_iter == 0:
                print(f"Epoch {epoch}, Iteration {iteration+1}")
                print(f"  Total Loss: {loss_dict['total']:.4f}")
                if 'bright_dark' in loss_dict:
                    print(f"  Dark Fraction: {loss_dict.get('bd_dark_fraction', 0):.4f}, "
                          f"Bright Fraction: {loss_dict.get('bd_bright_fraction', 0):.6f}")
            
            # Save checkpoint
            if (iteration + 1) % config.snapshot_iter == 0:
                save_path = os.path.join(config.snapshots_folder, f"Epoch{epoch}_improved.pth")
                torch.save(DCE_net.state_dict(), save_path)
                print(f"  Saved checkpoint: {save_path}")
    
    return DCE_net

print("Training function with improvements ready!")
print("\nTo use, create a config object and call:")
print("  DCE_net = train_with_improvements(config)  # avoid shadowing the imported 'model' module")


## Summary

This notebook implements all four improvements from the survey, two of them in simplified form:

1. ✅ **Bright/Dark Balance Loss** - Dual-histogram regularization
2. ✅ **Texture-Aware Lighting Maps** - Gradient-respecting smoothness
3. ✅ **Hybrid Exposure Fusion** - Single-scale multi-exposure fusion (inference only; the Laplacian-pyramid version is not implemented yet)
4. ✅ **Perceptual Co-training** - VGG-based perceptual loss as a proxy for NIMA/NIQE

### Next Steps:

1. **Test individual components** on sample images
2. **Fine-tune loss weights** for optimal balance
3. **Train model** with improvements (start with one improvement at a time)
4. **Evaluate** on test datasets and compare with baseline
5. **Iterate** based on results

### Usage Tips:

- Start by testing with `use_bright_dark=True` only, then add others incrementally
- Adjust loss weights in `create_improved_loss_function()` based on results
- Use exposure fusion only during inference (not training) to save computation
- Monitor dark/bright pixel fractions during training to track improvement
