In [1]:
"""
Underwater Image Denoising Model Comparison - Part 1: Calculate Metrics
Compares 4 base models (attention, classical_physics, phy_head_v2, phy_head_v3)
and 2 post-processed variants (attention+CLAHE+RedBoost, phy_head_v3+CLAHE+RedBoost)
"""

import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import json
from pathlib import Path

# ============================================================================
# CONFIGURATION
# ============================================================================

# Pick the compute device once: CUDA when PyTorch reports a usable GPU, CPU otherwise.
if torch.cuda.is_available():
    _DEVICE_NAME = "cuda"
else:
    _DEVICE_NAME = "cpu"

# Single source of truth for dataset locations, checkpoint files and run settings.
CFG = {
    # Paired image folders and the square side length every image is resized to.
    "data": {
        "root_cam": "D:/ML_works/TRIDENT/DATA/Icam",
        "root_ref": "D:/ML_works/TRIDENT/DATA/Iclean",
        "img_size": 256,
    },
    # Weight files handed to torch.load(), one per learned model.
    "checkpoints": {
        "attention": "Attention/best_weights_epoch_041.pth",
        "phy_v2": "physics_encoder/checkpointsphy/best_v2.pt",
        "phy_v3": "physics_encoder/checkpointsphy/best_v3.pt",
    },
    "device": _DEVICE_NAME,
    "batch_size": 16,  # Reduced for 8GB VRAM
    # Folder that receives the metrics JSON and the comparison figures.
    "results_dir": "comparison_results",
}

# Create the output folder up front; an already existing folder is fine.
Path(CFG["results_dir"]).mkdir(parents=True, exist_ok=True)

# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def to_linear(img_srgb: torch.Tensor, gamma: float = 2.2) -> torch.Tensor:
    """Approximate sRGB -> linear conversion: clamp to [0, 1], then raise to ``gamma``."""
    bounded = img_srgb.clamp(0.0, 1.0)
    return torch.pow(bounded, gamma)

def to_srgb(img_lin: torch.Tensor, gamma: float = 2.2) -> torch.Tensor:
    """Approximate linear -> sRGB conversion: clamp to [0, 1], then raise to ``1 / gamma``."""
    bounded = img_lin.clamp(0.0, 1.0)
    return torch.pow(bounded, 1.0 / gamma)

def clamp01(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` with every element limited to the closed interval [0, 1]."""
    return x.clamp(min=0.0, max=1.0)

def calculate_psnr(img1, img2):
    """Return the PSNR in dB between two tensors, taking 1.0 as the peak value.

    A mean squared error below 1e-10 is treated as "identical" and reported as
    a fixed 100.0 dB instead of dividing by (almost) zero.
    """
    err = F.mse_loss(img1, img2)
    if err < 1e-10:
        return 100.0
    rmse = torch.sqrt(err)
    peak_over_rmse = 1.0 / rmse
    return 20 * torch.log10(peak_over_rmse).item()

def calculate_ssim(img1, img2, window_size=11):
    """Return the mean SSIM of two image tensors as a Python float.

    Local statistics come from a uniform ``window_size`` x ``window_size``
    average pool (stride 1, padding ``window_size // 2``). The stabilising
    constants are ``0.01 ** 2`` and ``0.03 ** 2``.
    """
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2
    pad = window_size // 2

    def local_mean(t):
        # Box-filter average over the sliding window.
        return F.avg_pool2d(t, window_size, stride=1, padding=pad)

    mean_a = local_mean(img1)
    mean_b = local_mean(img2)

    mean_a_sq = mean_a ** 2
    mean_b_sq = mean_b ** 2
    mean_ab = mean_a * mean_b

    # Variances / covariance as E[xy] - E[x]E[y] over the same window.
    var_a = local_mean(img1 * img1) - mean_a_sq
    var_b = local_mean(img2 * img2) - mean_b_sq
    cov_ab = local_mean(img1 * img2) - mean_ab

    numerator = (2 * mean_ab + c1) * (2 * cov_ab + c2)
    denominator = (mean_a_sq + mean_b_sq + c1) * (var_a + var_b + c2)

    return (numerator / denominator).mean().item()

# ============================================================================
# POST-PROCESSING FUNCTIONS (CLAHE + Red Boost)
# ============================================================================

def clahe_enhancement(img_tensor):
    """Apply CLAHE to the lightness channel of a batch of RGB images.

    Args:
        img_tensor: tensor indexed as (B, C, H, W) with RGB values nominally
            in [0, 1]. Values outside that range are saturated before the
            8-bit conversion.

    Returns:
        Float tensor of the same shape on the same device as the input, with
        values in [0, 1] (quantised to 8 bits by the OpenCV round trip).
    """
    device = img_tensor.device
    # Clip before the uint8 cast: an unclipped cast wraps out-of-range values
    # (e.g. 1.5 * 255) instead of saturating them. detach() keeps .numpy()
    # usable for tensors that track gradients.
    img_np = (img_tensor.detach().cpu().numpy().clip(0.0, 1.0) * 255).astype(np.uint8)

    # The CLAHE operator does not depend on the image, so build it once.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    # Pre-allocating the output also makes an empty batch a no-op instead of
    # an np.stack() error.
    enhanced = np.empty_like(img_np)

    for i, chw in enumerate(img_np):
        img = np.ascontiguousarray(chw.transpose(1, 2, 0))  # CHW -> HWC
        lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)

        # Equalise lightness only; the a/b colour channels pass through unchanged.
        l_clahe = clahe.apply(l)

        lab_clahe = cv2.merge((l_clahe, a, b))
        rgb_clahe = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)

        enhanced[i] = rgb_clahe.transpose(2, 0, 1)  # HWC -> CHW

    return torch.from_numpy(enhanced / 255.0).float().to(device)

def adaptive_red_boost(img_tensor):
    """Scale the red channel of each image by a factor derived from its channel means.

    Per image, on the 0-255 scale: ``factor = (mean_G + mean_B) / (2 * mean_R)``,
    or 2.0 when the red mean is below 1. The factor is then limited to
    [1.2, 3.0], so red is always boosted by at least 20%. The boosted red
    channel saturates at 255.

    Args:
        img_tensor: tensor indexed as (B, C, H, W), channel 0 being red, with
            values nominally in [0, 1]. Out-of-range values are saturated
            before the 8-bit conversion.

    Returns:
        Float tensor of the same shape on the same device as the input, with
        values in [0, 1] (quantised to 8 bits).
    """
    device = img_tensor.device
    # Clip before the uint8 cast: an unclipped cast wraps out-of-range values
    # instead of saturating them. detach() keeps .numpy() usable for tensors
    # that track gradients.
    img_np = (img_tensor.detach().cpu().numpy().clip(0.0, 1.0) * 255).astype(np.uint8)

    # Pre-allocating the output makes an empty batch a no-op instead of an
    # np.stack() error.
    boosted = np.empty_like(img_np)

    for i, chw in enumerate(img_np):
        r_mean = chw[0].mean()
        g_mean = chw[1].mean()
        b_mean = chw[2].mean()

        if r_mean < 1:
            # Red is essentially absent; avoid dividing by (almost) zero.
            factor = 2.0
        else:
            factor = (g_mean + b_mean) / (2 * r_mean)
        factor = np.clip(factor, 1.2, 3.0)

        planes = chw.astype(np.float32)
        planes[0] = np.clip(planes[0] * factor, 0, 255)

        boosted[i] = planes.astype(np.uint8)

    return torch.from_numpy(boosted / 255.0).float().to(device)

def compute_classical(I_cam_srgb):
    """Run the classical baseline: CLAHE first, then the adaptive red boost.

    The input is passed straight to ``clahe_enhancement`` and its result to
    ``adaptive_red_boost``; the return value is whatever the latter produces.
    """
    contrast_enhanced = clahe_enhancement(I_cam_srgb)
    return adaptive_red_boost(contrast_enhanced)

# ============================================================================
# DATASET
# ============================================================================

class FullDataset(Dataset):
    """Paired dataset of camera images and their clean reference images.

    A pair is formed for every ``.png`` file name present in both folders.
    Each item holds both images as sRGB tensors and as gamma-2.2 linearised
    tensors, plus the shared file name.
    """

    def __init__(self, root_cam, root_ref, img_size=256):
        self.root_cam = root_cam
        self.root_ref = root_ref
        self.img_size = img_size

        # File names (not paths) of the PNGs found in each folder.
        cam_names = {name for name in os.listdir(root_cam) if name.endswith('.png')}
        ref_names = {name for name in os.listdir(root_ref) if name.endswith('.png')}

        # Keep names present on both sides, sorted for a deterministic order.
        candidates = (
            (os.path.join(root_cam, name), os.path.join(root_ref, name), name)
            for name in sorted(cam_names & ref_names)
        )
        self.entries = [
            (cam_path, ref_path, name)
            for cam_path, ref_path, name in candidates
            if os.path.exists(cam_path) and os.path.exists(ref_path)
        ]

        print(f"Found {len(self.entries)} image pairs")

    def __len__(self):
        return len(self.entries)

    def _load_rgb(self, path):
        """Open ``path``, convert to RGB and resize to ``img_size`` x ``img_size``."""
        side = self.img_size
        return Image.open(path).convert("RGB").resize((side, side), Image.BILINEAR)

    def __getitem__(self, idx):
        cam_path, ref_path, fname = self.entries[idx]

        cam_img = self._load_rgb(cam_path)
        ref_img = self._load_rgb(ref_path)

        # torchvision is imported lazily, as in the rest of this file.
        from torchvision import transforms
        to_tensor = transforms.ToTensor()

        cam_srgb = to_tensor(cam_img)
        ref_srgb = to_tensor(ref_img)

        return {
            "I_cam_srgb": cam_srgb,
            "I_cam_lin": to_linear(cam_srgb, 2.2),
            "I_clean_srgb": ref_srgb,
            "I_clean_lin": to_linear(ref_srgb, 2.2),
            "filename": fname
        }

# ============================================================================
# ATTENTION MODEL (ResNet34 U-Net with CBAM)
# ============================================================================

class ChannelAttention_v1(nn.Module):
    """Channel gate used by the attention model's CBAM.

    Global average- and max-pooled descriptors go through one shared
    bias-free 1x1-conv bottleneck; their sum, after a sigmoid, rescales the
    input channels. Attribute names determine the state_dict keys and are
    therefore left untouched.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        squeezed = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, squeezed, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(squeezed, channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        from_avg = self.fc(self.avg_pool(x))
        from_max = self.fc(self.max_pool(x))
        gate = self.sigmoid(from_avg + from_max)
        return x * gate

class SpatialAttention_v1(nn.Module):
    """Spatial gate used by the attention model's CBAM.

    The channel-wise mean and max maps are stacked, passed through a single
    bias-free convolution and a sigmoid, and the resulting one-channel map
    rescales the input.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat([mean_map, max_map], dim=1)
        gate = self.sigmoid(self.conv(descriptor))
        return x * gate
class CBAM_v1(nn.Module):
    """CBAM block for the attention model: channel gate followed by spatial gate."""

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        # The "_v1" variants are the bias-free gates defined for this model.
        self.ca = ChannelAttention_v1(channels, reduction)
        self.sa = SpatialAttention_v1(kernel_size)

    def forward(self, x):
        channel_refined = self.ca(x)
        return self.sa(channel_refined)

class ResNet34_UNet_CBAM(nn.Module):
    """U-Net-style network: ResNet-34 encoder, CBAM bottleneck, convolutional decoder.

    The encoder stages come from a torchvision ``resnet34`` built with
    ``weights=None``. Each decoder step upsamples by 2, concatenates the
    matching encoder feature map and applies a conv block; a final upsample,
    conv block, 1x1 convolution and sigmoid produce a 3-channel output.
    Attribute names are kept unchanged because they determine the state_dict
    keys matched by ``load_state_dict``.
    """

    @staticmethod
    def _conv_block(in_c, out_c):
        """3x3 conv -> batch norm -> ReLU -> 2D dropout (p=0.1)."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.1)
        )

    @staticmethod
    def _upsample2x(x):
        """Bilinear 2x upsampling shared by every decoder step."""
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

    def __init__(self):
        super().__init__()
        from torchvision import models
        backbone = models.resnet34(weights=None)

        # Re-register the backbone stages under their original names.
        for stage in ("conv1", "bn1", "relu", "maxpool",
                      "layer1", "layer2", "layer3", "layer4"):
            setattr(self, stage, getattr(backbone, stage))

        self.cbam = CBAM_v1(channels=512, reduction=16, kernel_size=7)

        # Input widths are "upsampled features + skip connection".
        self.up3 = self._conv_block(512 + 256, 256)
        self.up2 = self._conv_block(256 + 128, 128)
        self.up1 = self._conv_block(128 + 64, 64)
        self.up0 = self._conv_block(64 + 64, 64)
        self.up_final = self._conv_block(64, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=1)

    def forward(self, x):
        # Encoder: keep every stage output for the skip connections.
        x0 = self.relu(self.bn1(self.conv1(x)))
        x1 = self.layer1(self.maxpool(x0))
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        # Attention is applied to the deepest feature map only.
        y = self.cbam(x4)

        # Decoder: upsample, fuse with the skip, refine - deepest first.
        for block, skip in ((self.up3, x3), (self.up2, x2), (self.up1, x1), (self.up0, x0)):
            y = block(torch.cat([self._upsample2x(y), skip], dim=1))

        y = self.up_final(self._upsample2x(y))

        return torch.sigmoid(self.final_conv(y))

# Cell-completion marker: printed once every Part 1 definition above has been executed.
print("✓ Part 1 script loaded - Continue to Part 2 for Physics models")

✓ Part 1 script loaded - Continue to Part 2 for Physics models


In [2]:
"""
Part 2: Physics Models and Metrics Calculation
"""

# ============================================================================
# V3-COMPATIBLE CBAM MODULES (Final version)
# These definitions force layer names to be indices to match phy_v3.pt
# ============================================================================

class ChannelAttention(nn.Module):
    """Channel gate whose submodules are registered under the names "0".."5".

    Per this file's section comment, the index names exist so the state_dict
    keys ("1.weight", "1.bias", "3.weight", "3.bias") line up with the v3
    checkpoint. The registration names and order must therefore not change.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        squeezed = channels // reduction
        indexed_layers = (
            nn.AdaptiveAvgPool2d(1),                       # "0": average pooling
            nn.Conv2d(channels, squeezed, 1, bias=True),   # "1": squeeze
            nn.ReLU(),                                     # "2"
            nn.Conv2d(squeezed, channels, 1, bias=True),   # "3": excite
            nn.Sigmoid(),                                  # "4": gate
            nn.AdaptiveMaxPool2d(1),                       # "5": max pooling
        )
        for index, layer in enumerate(indexed_layers):
            self.add_module(str(index), layer)

    def forward(self, x):
        layers = self._modules

        def shared_mlp(pooled):
            # Same squeeze -> ReLU -> excite stack for both pooled descriptors.
            return layers["3"](layers["2"](layers["1"](pooled)))

        avg_out = shared_mlp(layers["0"](x))
        max_out = shared_mlp(layers["5"](x))
        return x * layers["4"](avg_out + max_out)

class SpatialAttention(nn.Module):
    """Spatial gate whose convolution and sigmoid are registered as "0" and "1".

    The index names fix the state_dict keys to "0.weight" / "0.bias"; keep
    them as they are so existing checkpoints still load.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        half = kernel_size // 2
        self.add_module("0", nn.Conv2d(2, 1, kernel_size, padding=half, bias=True))
        self.add_module("1", nn.Sigmoid())

    def forward(self, x):
        conv, gate = self._modules["0"], self._modules["1"]
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat([mean_map, max_map], dim=1)
        return x * gate(conv(descriptor))

class CBAM(nn.Module):
    """CBAM block for the v3 physics encoder: channel gate, then spatial gate."""

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        # Both gates are the index-named variants defined above.
        self.ca = ChannelAttention(channels, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        channel_refined = self.ca(x)
        return self.sa(channel_refined)

# ============================================================================
# PHYSICS ENCODER MODELS (v2 and v3)
# ============================================================================

def gn(num_channels, groups=8):
    """Build a ``GroupNorm`` layer with at most ``groups`` groups.

    ``nn.GroupNorm`` requires the group count to divide the channel count,
    so the largest divisor of ``num_channels`` that does not exceed
    ``groups`` is used. When ``groups`` already divides ``num_channels`` the
    result is ``GroupNorm(groups, num_channels)``, exactly as before.

    Args:
        num_channels: number of channels to normalise.
        groups: upper bound on the number of groups.

    Returns:
        The configured ``nn.GroupNorm`` module.
    """
    g = max(1, min(groups, num_channels))
    # Step down to the nearest divisor; terminates at 1 at the latest.
    while num_channels % g:
        g -= 1
    return nn.GroupNorm(g, num_channels)

def act_fn(name="silu"):
    """Return a new activation module: SiLU for ``"silu"``, in-place ReLU for anything else."""
    if name == "silu":
        return nn.SiLU()
    return nn.ReLU(inplace=True)

class ConvBlock(nn.Module):
    """Two stacked (3x3 conv -> norm -> activation) stages.

    ``norm="group"`` selects ``gn``; any other value selects ``BatchNorm2d``.
    The layers live in ``self.conv`` at indices 0-5, which fixes their
    state_dict keys.
    """

    def __init__(self, in_ch, out_ch, norm="group", act="silu"):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            stages.append(gn(out_ch) if norm == "group" else nn.BatchNorm2d(out_ch))
            stages.append(act_fn(act))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class UNetS_Physics_v2(nn.Module):
    """Physics-head encoder, version 2.

    Four ``ConvBlock`` stages (32, 64, 128, 256 channels) separated by 2x
    max-pooling feed four heads on the deepest feature map:

    * ``t``      - 3-channel map, sigmoid-squashed into [t_min, 1 - t_min]
    * ``A``      - 3 values per image from globally pooled features
    * ``zw``     - ``zw_dim``-dimensional vector per image
    * ``sigma2`` - 1-channel map passed through softplus
    """

    def __init__(self, t_min=0.02, zw_dim=32):
        super().__init__()
        c1, c2, c3, c4 = 32, 64, 128, 256
        self.t_min = t_min
        self.zw_dim = zw_dim

        # Encoder stages with 2x down-sampling in between.
        self.e1 = ConvBlock(3, c1)
        self.p1 = nn.MaxPool2d(2)
        self.e2 = ConvBlock(c1, c2)
        self.p2 = nn.MaxPool2d(2)
        self.e3 = ConvBlock(c2, c3)
        self.p3 = nn.MaxPool2d(2)
        self.e4 = ConvBlock(c3, c4)

        # Head producing the per-pixel "t" logits.
        self.t_head = nn.Sequential(
            nn.Conv2d(c4, 64, 3, padding=1),
            act_fn("silu"),
            nn.Conv2d(64, 3, 1)
        )

        # Head producing the per-image "A" values.
        self.A_pool = nn.AdaptiveAvgPool2d(1)
        self.A_mlp = nn.Sequential(
            nn.Conv2d(c4, 128, 1),
            act_fn("silu"),
            nn.Conv2d(128, 3, 1)
        )

        # Head producing the per-image "zw" vector.
        self.zw_pool = nn.AdaptiveAvgPool2d(1)
        self.zw_mlp = nn.Sequential(
            nn.Linear(c4, 256),
            nn.SiLU(),
            nn.Linear(256, self.zw_dim)
        )

        # Head producing the per-pixel "sigma2" map.
        self.sigma_head = nn.Sequential(
            nn.Conv2d(c4, 16, 3, padding=1),
            act_fn("silu"),
            nn.Conv2d(16, 1, 1)
        )

    def forward(self, I_cam_lin):
        feats = self.e1(I_cam_lin)
        feats = self.e2(self.p1(feats))
        feats = self.e3(self.p2(feats))
        feats = self.e4(self.p3(feats))

        # Affine rescale of the sigmoid so t stays inside [t_min, 1 - t_min].
        t = torch.sigmoid(self.t_head(feats)) * (1 - 2*self.t_min) + self.t_min

        A = self.A_mlp(self.A_pool(feats))
        zw = self.zw_mlp(self.zw_pool(feats).flatten(1))
        sigma2 = F.softplus(self.sigma_head(feats))

        return {"t": t, "A": A, "zw": zw, "sigma2": sigma2}

class TRIDENT_v2(nn.Module):
    """Restoration model v2: predicts ``t`` and ``A``, then inverts the image-formation model.

    ``forward`` returns ``clamp((I - A * (1 - t)) / max(t, t_min), 0, 1)``
    for a linear-space input ``I``.
    """

    def __init__(self, t_min=0.02):
        super().__init__()
        self.gamma = 2.2
        self.t_min = t_min
        self.phys = UNetS_Physics_v2(t_min=t_min)

    def forward(self, I_cam_l):
        heads = self.phys(I_cam_l)

        height, width = I_cam_l.shape[-2:]
        # Bring the coarse t map back to input resolution.
        t = F.interpolate(heads["t"], size=(height, width),
                          mode="bilinear", align_corners=False)
        # Broadcast the per-image A values over every pixel.
        airlight = heads["A"].expand(-1, -1, height, width)

        floor = torch.tensor(self.t_min, device=t.device)
        restored = (I_cam_l - airlight * (1 - t)) / torch.maximum(t, floor)

        return torch.clamp(restored, 0.0, 1.0)




class UNetS_Physics_v3(nn.Module):
    """Physics-head encoder, version 3 (with CBAM).

    Five stages (32, 64, 128, 256, 512 channels), each a ``ConvBlock``
    followed by a ``CBAM`` block, separated by 2x max-pooling. The deepest
    feature map feeds four heads:

    * ``t``      - 3-channel map, sigmoid-squashed into [t_min, 1 - t_min]
    * ``A``      - 3 values per image from globally pooled features
    * ``zw``     - ``zw_dim``-dimensional vector per image
    * ``sigma2`` - 1-channel map passed through softplus
    """

    def __init__(self, t_min=0.02, zw_dim=32):
        super().__init__()
        c1, c2, c3, c4, c5 = 32, 64, 128, 256, 512
        self.t_min = t_min
        self.zw_dim = zw_dim

        # Encoder: every stage is ConvBlock -> CBAM, wrapped in a Sequential
        # so the state_dict keys read "eN.0..." / "eN.1...".
        self.e1 = nn.Sequential(ConvBlock(3, c1), CBAM(c1))
        self.p1 = nn.MaxPool2d(2)
        self.e2 = nn.Sequential(ConvBlock(c1, c2), CBAM(c2))
        self.p2 = nn.MaxPool2d(2)
        self.e3 = nn.Sequential(ConvBlock(c2, c3), CBAM(c3))
        self.p3 = nn.MaxPool2d(2)
        self.e4 = nn.Sequential(ConvBlock(c3, c4), CBAM(c4))
        self.p4 = nn.MaxPool2d(2)
        self.e5 = nn.Sequential(ConvBlock(c4, c5), CBAM(c5))

        # Head producing the per-pixel "t" logits.
        self.t_head = nn.Sequential(
            nn.Conv2d(c5, 64, 3, padding=1),
            act_fn("silu"),
            nn.Conv2d(64, 3, 1)
        )

        # Head producing the per-image "A" values.
        self.A_pool = nn.AdaptiveAvgPool2d(1)
        self.A_mlp = nn.Sequential(
            nn.Conv2d(c5, 128, 1),
            act_fn("silu"),
            nn.Conv2d(128, 3, 1)
        )

        # Head producing the per-image "zw" vector.
        self.zw_pool = nn.AdaptiveAvgPool2d(1)
        self.zw_mlp = nn.Sequential(
            nn.Linear(c5, 256),
            nn.SiLU(),
            nn.Linear(256, self.zw_dim)
        )

        # Head producing the per-pixel "sigma2" map.
        self.sigma_head = nn.Sequential(
            nn.Conv2d(c5, 16, 3, padding=1),
            act_fn("silu"),
            nn.Conv2d(16, 1, 1)
        )

    def forward(self, I_cam_lin):
        feats = self.e1(I_cam_lin)
        feats = self.e2(self.p1(feats))
        feats = self.e3(self.p2(feats))
        feats = self.e4(self.p3(feats))
        feats = self.e5(self.p4(feats))

        # Affine rescale of the sigmoid so t stays inside [t_min, 1 - t_min].
        t = torch.sigmoid(self.t_head(feats)) * (1 - 2*self.t_min) + self.t_min

        A = self.A_mlp(self.A_pool(feats))
        zw = self.zw_mlp(self.zw_pool(feats).flatten(1))
        sigma2 = F.softplus(self.sigma_head(feats))

        return {"t": t, "A": A, "zw": zw, "sigma2": sigma2}

class TRIDENT_v3(nn.Module):
    """Restoration model v3: v2-style inversion plus a small ``zw``-driven bias on ``t``.

    The bias MLP takes a 32-dimensional input, i.e. it relies on the encoder
    being built with its default ``zw_dim=32``.
    """

    def __init__(self, t_min=0.02):
        super().__init__()
        self.gamma = 2.2
        self.t_min = t_min
        self.phys = UNetS_Physics_v3(t_min=t_min)

        # Maps the per-image zw vector to one bias value per t channel.
        self.zw_to_t_bias = nn.Sequential(
            nn.Linear(32, 64),
            nn.SiLU(),
            nn.Linear(64, 3)
        )

    def forward(self, I_cam_l):
        heads = self.phys(I_cam_l)

        # Per-image, per-channel bias limited to +/-0.02, broadcast over H x W.
        bias = self.zw_to_t_bias(heads["zw"])[:, :, None, None].clamp(-0.02, 0.02)
        t = (heads["t"] + bias).clamp(min=self.t_min, max=1.0)

        height, width = I_cam_l.shape[-2:]
        t = F.interpolate(t, size=(height, width), mode="bilinear", align_corners=False)
        airlight = heads["A"].expand(-1, -1, height, width).clamp(0.0, 1.0)

        # Both terms get a tiny positive floor before the division.
        numerator = (I_cam_l - airlight * (1 - t)).clamp(min=1e-8)
        denominator = t.clamp(min=1e-8)

        # The [0, 5] clamp is subsumed by the [0, 1] clamp that follows; both
        # are kept so the computation matches the original step for step.
        return (numerator / denominator).clamp(0.0, 5.0).clamp(0.0, 1.0)

# ============================================================================
# LOAD ALL MODELS
# ============================================================================

def load_all_models(device):
    """Build the three learned models, load their checkpoints and switch them to eval mode.

    Checkpoint paths come from ``CFG["checkpoints"]``. Each checkpoint is fed
    directly to ``load_state_dict``, so it must be a plain state dict.

    Args:
        device: device the models are moved to and checkpoints are mapped to.

    Returns:
        dict with the keys ``"attention"``, ``"phy_v2"`` and ``"phy_v3"``.
    """
    # (result key / checkpoint key, progress message, model class)
    specs = (
        ("attention", "Loading Attention model...", ResNet34_UNet_CBAM),
        ("phy_v2", "Loading Physics v2 model...", TRIDENT_v2),
        ("phy_v3", "Loading Physics v3 model...", TRIDENT_v3),
    )

    models = {}
    for key, message, model_cls in specs:
        print(message)
        net = model_cls().to(device)
        # NOTE(review): torch.load unpickles the file - only load trusted
        # checkpoints, or consider weights_only=True where supported.
        state = torch.load(CFG["checkpoints"][key], map_location=device)
        net.load_state_dict(state)
        net.eval()
        models[key] = net

    print("✓ All models loaded successfully")
    return models

# ============================================================================
# METRICS CALCULATION (FOR METHODS WITHOUT PHY_V3)
# ============================================================================

def calculate_metrics_batch1(dataset, models, device):
    """Compute per-image SSIM/PSNR for attention, classical, phy_v2 and attention+post.

    All metrics are measured in sRGB space against ``I_clean_srgb``.

    Args:
        dataset: dataset whose items provide ``I_cam_srgb``, ``I_cam_lin``
            and ``I_clean_srgb``.
        models: dict providing the ``"attention"`` and ``"phy_v2"`` models.
        device: device the batches are moved to.

    Returns:
        ``(results, summary)``: ``results[method]`` holds the per-image
        ``"ssim"`` and ``"psnr"`` lists in dataset order; ``summary[method]``
        holds their mean and standard deviation.
    """
    # Only ask for pinned host memory when batches actually go to a GPU.
    use_pinned = torch.device(device).type == "cuda"
    dataloader = DataLoader(dataset, batch_size=CFG["batch_size"],
                            shuffle=False, num_workers=0, pin_memory=use_pinned)

    results = {
        "attention": {"ssim": [], "psnr": []},
        "classical": {"ssim": [], "psnr": []},
        "phy_v2": {"ssim": [], "psnr": []},
        "attention_post": {"ssim": [], "psnr": []}
    }

    def record(method, pred, ref):
        """Append per-image SSIM and PSNR of ``pred`` against ``ref`` to ``results[method]``."""
        bucket = results[method]
        for i in range(pred.shape[0]):
            bucket["ssim"].append(calculate_ssim(pred[i:i+1], ref[i:i+1]))
            bucket["psnr"].append(calculate_psnr(pred[i:i+1], ref[i:i+1]))

    print("\n" + "="*60)
    print("CALCULATING METRICS - BATCH 1")
    print("Methods: attention, classical_physics, phy_v2, attention+post")
    print("="*60)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Processing Batch 1"):
            # Only the tensors that are actually used are moved to the device.
            I_cam_srgb = batch["I_cam_srgb"].to(device)
            I_cam_lin = batch["I_cam_lin"].to(device)
            I_ref_srgb = batch["I_clean_srgb"].to(device)

            # 1. Attention model (consumes and produces sRGB)
            pred_attention = models["attention"](I_cam_srgb)
            record("attention", pred_attention, I_ref_srgb)

            # 2. Classical enhancement, computed on the fly
            record("classical", compute_classical(I_cam_srgb), I_ref_srgb)

            # 3. Physics v2 (works in linear space; convert back for scoring)
            pred_v2_srgb = to_srgb(models["phy_v2"](I_cam_lin), 2.2)
            record("phy_v2", pred_v2_srgb, I_ref_srgb)

            # 4. Attention output followed by CLAHE + red boost
            pred_attention_post = adaptive_red_boost(clahe_enhancement(pred_attention))
            record("attention_post", pred_attention_post, I_ref_srgb)

    # Calculate averages
    summary = {}
    for method, metrics in results.items():
        summary[method] = {
            "avg_ssim": np.mean(metrics["ssim"]),
            "avg_psnr": np.mean(metrics["psnr"]),
            "std_ssim": np.std(metrics["ssim"]),
            "std_psnr": np.std(metrics["psnr"])
        }

    print("\n" + "="*60)
    print("BATCH 1 RESULTS:")
    print("="*60)
    for method, stats in summary.items():
        print(f"{method:20s} | SSIM: {stats['avg_ssim']:.4f} ± {stats['std_ssim']:.4f} | "
              f"PSNR: {stats['avg_psnr']:.2f} ± {stats['std_psnr']:.2f} dB")
    print("="*60)

    return results, summary

# Cell-completion marker: printed once every Part 2 definition above has been executed.
print("✓ Part 2 loaded - Continue to Part 3 for remaining metrics")

✓ Part 2 loaded - Continue to Part 3 for remaining metrics


In [3]:
"""
Part 3: Calculate v3 metrics and Generate comparison visualizations
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# ============================================================================
# METRICS CALCULATION (FOR V3 METHODS)
# ============================================================================

def calculate_metrics_batch2(dataset, models, device):
    """Compute per-image SSIM/PSNR for phy_v3 and phy_v3 followed by CLAHE + red boost.

    Metrics are measured in sRGB space against ``I_clean_srgb``. Returns
    ``(results, summary)`` where ``results[method]`` holds the per-image
    ``"ssim"``/``"psnr"`` lists and ``summary[method]`` their mean and
    standard deviation.
    """
    loader = DataLoader(dataset, batch_size=CFG["batch_size"],
                        shuffle=False, num_workers=0, pin_memory=True)

    results = {
        "phy_v3": {"ssim": [], "psnr": []},
        "phy_v3_post": {"ssim": [], "psnr": []}
    }

    banner = "="*60
    print("\n" + banner)
    print("CALCULATING METRICS - BATCH 2")
    print("Methods: phy_v3, phy_v3+post")
    print(banner)

    with torch.no_grad():
        for batch in tqdm(loader, desc="Processing Batch 2"):
            cam_lin = batch["I_cam_lin"].to(device)
            ref_srgb = batch["I_clean_srgb"].to(device)

            # Physics v3 works in linear space; convert back to sRGB for scoring.
            v3_srgb = to_srgb(models["phy_v3"](cam_lin), 2.2)
            # Same prediction after CLAHE + red boost.
            v3_post = adaptive_red_boost(clahe_enhancement(v3_srgb))

            for method, pred in (("phy_v3", v3_srgb), ("phy_v3_post", v3_post)):
                scores = results[method]
                # split(1) yields one (1, C, H, W) slice per image.
                for single_pred, single_ref in zip(pred.split(1), ref_srgb.split(1)):
                    scores["ssim"].append(calculate_ssim(single_pred, single_ref))
                    scores["psnr"].append(calculate_psnr(single_pred, single_ref))

    # Mean and spread per method.
    summary = {
        method: {
            "avg_ssim": np.mean(metrics["ssim"]),
            "avg_psnr": np.mean(metrics["psnr"]),
            "std_ssim": np.std(metrics["ssim"]),
            "std_psnr": np.std(metrics["psnr"])
        }
        for method, metrics in results.items()
    }

    print("\n" + banner)
    print("BATCH 2 RESULTS:")
    print(banner)
    for method, stats in summary.items():
        print(f"{method:20s} | SSIM: {stats['avg_ssim']:.4f} ± {stats['std_ssim']:.4f} | "
              f"PSNR: {stats['avg_psnr']:.2f} ± {stats['std_psnr']:.2f} dB")
    print(banner)

    return results, summary

# ============================================================================
# IMAGE COMPARISON VISUALIZATION
# ============================================================================

def generate_comparison_images(dataset, models, all_results, device):
    """Save side-by-side comparison figures (8 columns each) for selected images.

    Selection: up to 6 random images (fixed seed 42), plus, for each of the
    6 methods, the image with the highest and the lowest SSIM - 18 figures
    when the dataset holds at least 6 images. Columns are: camera input, the
    6 method outputs (titled with SSIM/PSNR), clean reference.

    Args:
        dataset: dataset to draw images from. The best/worst indices are
            positions in the ``all_results`` lists, so those lists must have
            been computed over this same dataset in the same order.
        models: dict providing the ``"attention"``, ``"phy_v2"`` and
            ``"phy_v3"`` models.
        all_results: ``{method: {"ssim": [...], ...}}`` for all 6 methods.
        device: device used for inference.

    Raises:
        KeyError: if ``all_results`` lacks one of the 6 methods.
        ValueError: if ``dataset`` is empty.
    """
    import warnings

    import matplotlib.pyplot as plt
    from matplotlib.gridspec import GridSpec

    method_keys = ["attention", "classical", "phy_v2", "phy_v3", "attention_post", "phy_v3_post"]
    method_names = ["Attention", "Classical", "Phy-v2", "Phy-v3", "Att+Post", "V3+Post"]

    # Fail early and clearly instead of part-way through the selection.
    missing = [key for key in method_keys if key not in all_results]
    if missing:
        raise KeyError(f"all_results has no metrics for: {', '.join(missing)}")

    num_images = len(dataset)
    if num_images == 0:
        raise ValueError("Cannot generate comparison images from an empty dataset")

    all_ssim = {key: all_results[key]["ssim"] for key in method_keys}
    for key, scores in all_ssim.items():
        if len(scores) != num_images:
            warnings.warn(
                f"'{key}' has {len(scores)} SSIM values but the dataset has "
                f"{num_images} images; best/worst selections may not line up"
            )

    def tensor_to_numpy(t):
        """(1, C, H, W) tensor -> (H, W, C) numpy array for imshow."""
        return t.squeeze(0).cpu().permute(1, 2, 0).numpy()

    # Build the full work list as (dataset index, output file name) pairs.
    # 1. Random picks (capped at the dataset size so tiny datasets still work).
    np.random.seed(42)
    random_indices = np.random.choice(num_images, min(6, num_images), replace=False)
    jobs = [
        (int(idx), f"comparison_random_{n + 1:02d}.png")
        for n, idx in enumerate(random_indices)
    ]
    # 2. Best image per method.
    jobs += [
        (int(np.argmax(all_ssim[key])), f"comparison_best_{name}.png")
        for key, name in zip(method_keys, method_names)
    ]
    # 3. Worst image per method.
    jobs += [
        (int(np.argmin(all_ssim[key])), f"comparison_worst_{name}.png")
        for key, name in zip(method_keys, method_names)
    ]

    print("\n" + "="*60)
    print("GENERATING COMPARISON IMAGES")
    print(f"Total: {len(jobs)} images × 8 columns")
    print("="*60)

    output_dir = os.path.join(CFG["results_dir"], "comparison_images")
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for img_idx, save_filename in tqdm(jobs, desc="Generating comparisons"):
            # Load data and add the batch dimension the models expect.
            data = dataset[img_idx]
            I_cam_srgb = data["I_cam_srgb"].unsqueeze(0).to(device)
            I_cam_lin = data["I_cam_lin"].unsqueeze(0).to(device)
            I_ref_srgb = data["I_clean_srgb"].unsqueeze(0).to(device)

            # Generate every method's prediction in sRGB space.
            predictions = {}
            predictions["attention"] = models["attention"](I_cam_srgb)
            predictions["classical"] = compute_classical(I_cam_srgb)
            predictions["phy_v2"] = to_srgb(models["phy_v2"](I_cam_lin), 2.2)
            predictions["phy_v3"] = to_srgb(models["phy_v3"](I_cam_lin), 2.2)
            predictions["attention_post"] = adaptive_red_boost(
                clahe_enhancement(predictions["attention"]))
            predictions["phy_v3_post"] = adaptive_red_boost(
                clahe_enhancement(predictions["phy_v3"]))

            # Assemble the 8 panels: input, 6 scored methods, reference.
            panels = [("Icam", tensor_to_numpy(I_cam_srgb))]
            for key, name in zip(method_keys, method_names):
                ssim = calculate_ssim(predictions[key], I_ref_srgb)
                psnr = calculate_psnr(predictions[key], I_ref_srgb)
                panels.append((f"{name}\nSSIM:{ssim:.3f}\nPSNR:{psnr:.1f}",
                               tensor_to_numpy(predictions[key])))
            panels.append(("Iclean", tensor_to_numpy(I_ref_srgb)))

            # Create visualization
            fig = plt.figure(figsize=(24, 4))
            gs = GridSpec(1, 8, figure=fig, wspace=0.02, hspace=0)

            for col_idx, (title, img) in enumerate(panels):
                ax = fig.add_subplot(gs[0, col_idx])
                ax.imshow(img)
                ax.set_title(title, fontsize=8)
                ax.axis('off')

            save_path = os.path.join(output_dir, save_filename)
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            # Close this specific figure so figures do not pile up in memory.
            plt.close(fig)

    print(f"✓ Saved {len(jobs)} comparison images to: {output_dir}")

# ============================================================================
# SAVE RESULTS TO JSON
# ============================================================================

def save_results_to_json(all_summaries, all_results):
    """Write the summary statistics and per-image metrics to a JSON file.

    Args:
        all_summaries: Summary object stored unchanged under the
            ``"summary"`` key of the output.
        all_results: Mapping of method name -> mapping holding ``"ssim"``
            and ``"psnr"`` iterables of per-image scores. Every score is
            converted with ``float()`` before being serialized.

    The file is written to ``<CFG["results_dir"]>/metrics_summary.json``
    and its path is printed once the dump completes.
    """
    def as_floats(values):
        # Convert each score with float() so the list holds plain floats.
        return list(map(float, values))

    detailed = {}
    for method, metrics in all_results.items():
        detailed[method] = {
            "ssim": as_floats(metrics["ssim"]),
            "psnr": as_floats(metrics["psnr"]),
        }

    payload = {"summary": all_summaries, "detailed_results": detailed}

    output_path = os.path.join(CFG["results_dir"], "metrics_summary.json")
    with open(output_path, 'w') as handle:
        json.dump(payload, handle, indent=2)

    print(f"✓ Saved metrics to: {output_path}")

# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main():
    """Run the full model-comparison pipeline end to end.

    Steps, in order:
      1. Build the full dataset from the paths and image size in ``CFG``.
      2. Draw a reproducible random 15% subset (seeded with 42).
      3. Load all models onto the configured device.
      4. Compute metrics in two passes and merge their results.
      5. Print a per-method SSIM / PSNR summary table.
      6. Generate comparison images and dump the metrics to JSON.
    """
    rule = "=" * 70

    print("\n" + rule)
    print("UNDERWATER IMAGE DENOISING - MODEL COMPARISON")
    print(rule)

    device = torch.device(CFG["device"])
    print(f"Using device: {device}")

    # Build the full paired dataset.
    print("\nLoading dataset...")
    data_cfg = CFG["data"]
    full_dataset = FullDataset(
        data_cfg["root_cam"],
        data_cfg["root_ref"],
        data_cfg["img_size"],
    )

    # Evaluate on a random 15% subset only.
    num_total = len(full_dataset)
    print(f"Full dataset size: {num_total}")
    num_sample = int(num_total * 0.15)
    if num_total > 0:
        # A non-empty dataset always contributes at least one image.
        num_sample = max(num_sample, 1)

    # Fixed seed so the same subset is drawn on every run.
    np.random.seed(42)
    shuffled = list(range(num_total))
    np.random.shuffle(shuffled)
    sample_indices = shuffled[:num_sample]

    sample_dataset = torch.utils.data.Subset(full_dataset, sample_indices)
    print(f"Using a random 15% sample: {len(sample_dataset)} images")

    models = load_all_models(device)

    # Metrics are computed in two passes over the same sampled subset.
    results_batch1, summary_batch1 = calculate_metrics_batch1(sample_dataset, models, device)
    results_batch2, summary_batch2 = calculate_metrics_batch2(sample_dataset, models, device)

    # Combine both passes; on a key clash the second pass wins.
    all_results = dict(results_batch1)
    all_results.update(results_batch2)
    all_summaries = dict(summary_batch1)
    all_summaries.update(summary_batch2)

    print("\n" + rule)
    print("FINAL SUMMARY - ALL METHODS")
    print(rule)

    # (summary key, display label) pairs, in the order they are reported.
    labelled_methods = [
        ("attention", "Attention"),
        ("classical", "Classical Physics"),
        ("phy_v2", "Physics v2"),
        ("phy_v3", "Physics v3"),
        ("attention_post", "Attention + Post"),
        ("phy_v3_post", "Physics v3 + Post"),
    ]
    for key, label in labelled_methods:
        stats = all_summaries[key]
        print(f"{label:25s} | SSIM: {stats['avg_ssim']:.4f} ± {stats['std_ssim']:.4f} | "
              f"PSNR: {stats['avg_psnr']:.2f} ± {stats['std_psnr']:.2f} dB")
    print(rule)

    # Visual side-by-side comparisons, then the numeric dump.
    generate_comparison_images(sample_dataset, models, all_results, device)
    save_results_to_json(all_summaries, all_results)

    print("\n" + rule)
    print("✓ ALL DONE!")
    print(f"Results saved to: {CFG['results_dir']}")
    print(rule + "\n")

# Run the comparison pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


UNDERWATER IMAGE DENOISING - MODEL COMPARISON
Using device: cuda

Loading dataset...
Found 21521 image pairs
Full dataset size: 21521
Using a random 15% sample: 3228 images
Loading Attention model...
Loading Physics v2 model...
Loading Physics v3 model...
✓ All models loaded successfully

CALCULATING METRICS - BATCH 1
Methods: attention, classical_physics, phy_v2, attention+post


Processing Batch 1: 100%|██████████| 202/202 [02:15<00:00,  1.49it/s]



BATCH 1 RESULTS:
attention            | SSIM: 0.7406 ± 0.0880 | PSNR: 19.44 ± 2.88 dB
classical            | SSIM: 0.7129 ± 0.1069 | PSNR: 14.95 ± 2.81 dB
phy_v2               | SSIM: 0.8392 ± 0.0896 | PSNR: 22.03 ± 3.59 dB
attention_post       | SSIM: 0.8272 ± 0.0894 | PSNR: 19.18 ± 3.38 dB

CALCULATING METRICS - BATCH 2
Methods: phy_v3, phy_v3+post


Processing Batch 2: 100%|██████████| 202/202 [01:03<00:00,  3.17it/s]



BATCH 2 RESULTS:
phy_v3               | SSIM: 0.8322 ± 0.0989 | PSNR: 19.23 ± 3.90 dB
phy_v3_post          | SSIM: 0.7077 ± 0.1055 | PSNR: 15.37 ± 2.51 dB

FINAL SUMMARY - ALL METHODS
Attention                 | SSIM: 0.7406 ± 0.0880 | PSNR: 19.44 ± 2.88 dB
Classical Physics         | SSIM: 0.7129 ± 0.1069 | PSNR: 14.95 ± 2.81 dB
Physics v2                | SSIM: 0.8392 ± 0.0896 | PSNR: 22.03 ± 3.59 dB
Physics v3                | SSIM: 0.8322 ± 0.0989 | PSNR: 19.23 ± 3.90 dB
Attention + Post          | SSIM: 0.8272 ± 0.0894 | PSNR: 19.18 ± 3.38 dB
Physics v3 + Post         | SSIM: 0.7077 ± 0.1055 | PSNR: 15.37 ± 2.51 dB

GENERATING COMPARISON IMAGES
Total: 18 images × 8 columns


Generating comparisons: 100%|██████████| 18/18 [00:12<00:00,  1.40it/s]

✓ Saved 18 comparison images to: comparison_results\comparison_images
✓ Saved metrics to: comparison_results\metrics_summary.json

✓ ALL DONE!
Results saved to: comparison_results




