In [3]:
"""
TRIDENT Physics-Only Model - Underwater Image Restoration
Simplified version using only the Physics-Head U-Net
"""

import os, math, random, time, json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageOps
import cv2
from tqdm.auto import tqdm
from typing import Dict, Any, Tuple

# ============================================================================
# CONFIGURATION (consolidated from YAML)
# ============================================================================

# Single source of truth for every tunable in this notebook.
# NOTE(review): several keys below are not read anywhere in the code visible in
# this notebook (flagged inline); they look like leftovers from the YAML this
# was consolidated from — confirm before relying on them.
CFG = {
    "seed": 42,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # Mixed precision is only enabled when the device string starts with "cuda".
    "amp": True,
    
    "ema": {
        "enabled": True,
        "decay": 0.9995
    },
    
    "data": {
        # Machine-specific absolute paths: edit these before running elsewhere.
        "root_cam": "D:/ML_works/TRIDENT/DATA/Icam",
        "root_ref": "D:/ML_works/TRIDENT/DATA/Iclean",
        "img_size": 256,
        # Files are expected to be named <index>.png, indices starting at index_start.
        "filename_pattern": "{:d}.png",
        "index_start": 1,
        "count": 21521,
        # PairedDataset asserts that train + val + test == count.
        "splits": {
            "train": 18000,
            "val": 1760,
            "test": 1761
        },
        "loader": {
            "batch_size": 32,  # Increased from 24 (more memory available)
            "num_workers": 0,
            "pin_memory": True
        },
        "aug": {
            "hflip_p": 0.5,
            "rotate_deg": 10,
            "paired_crop_240": True,
            # Brightness/contrast jitter is applied to the camera image only.
            "jitter_cam": {
                "brightness": 0.1,
                "contrast": 0.1,
                "prob": 0.3
            },
            # Gamma augmentation (camera image only); choices are sampled with these probabilities.
            "gamma_train_choices": [0.9, 1.0, 1.1],
            "gamma_train_probs": [0.25, 0.50, 0.25],
            "gamma_eval": 1.0
        }
    },
    
    "model": {
        "colorspace": {
            # Exponent of the power-law sRGB <-> linear conversion (to_linear / to_srgb).
            "gamma": 2.2
        },
        "physics": {
            "unet_channels": [32, 64, 128, 256],
            "norm": "group",
            "act": "silu",
            "t_min": 0.02,
            "z_w_dim": 32
        }
    },
    
    "loss": {
        "weights": {
            "phys": 0.1,       # Reduced: now just regularization
            "tv_t": 5.0e-4,    # Keep same
            "A_prior": 5.0e-3, # Keep same
            # NOTE(review): not read by compute_losses, which takes the hetero
            # weight from schedules.hetero.weight instead.
            "hetero": 0.2      # Keep same
        },
        "heteroscedastic": {
            # NOTE(review): not read in the visible code; heteroscedastic_l1 uses
            # its own 1e-6 rather than this key.
            "epsilon": 1.0e-6,
            # NOTE(review): not read in the visible code; run_training uses
            # schedules.hetero.start_epoch (8), which disagrees with this value.
            "start_epoch": 12,
            # Consumed by HeteroWatchdog.
            "watchdog": {
                "mean_sigma2_thresh": 0.5,
                "over_thresh_required": 2,
                "cooldown_epochs": 2
            }
        }
    },
    
    "schedules": {
        "epochs": 30,
        "warmup_steps": 3000,
        "lr": {
            "base": 3.5e-4,
            "min": 1.0e-8,
            # NOTE(review): not read in the visible code; cosine_lr is always used.
            "decay": "cosine"
        },
        # These are the hetero settings the training loop actually reads.
        "hetero": {
            "start_epoch": 8,
            "weight": 0.2
        }
    },
    
    "optimizer": {
        # NOTE(review): not read in the visible code; train_val_setup always builds AdamW.
        "name": "adamw",
        "betas": [0.9, 0.99],
        "weight_decay": 1.0e-4,
        # Max gradient norm passed to clip_grad_norm_.
        "grad_clip": 0.5
    },
    
    "logging": {
        # NOTE(review): not read in the visible code; run_training always saves
        # both the best-SSIM and the best-L1 checkpoint.
        "save_best": {
            "by_primary": True,
            "by_secondary": True
        },
        "export_test": {
            "tAzw": True,           # write one physics_<id>.npz per test batch
            "save_images": True,    # write side-by-side comparison images
            "dir": "outputs/test_exports"
        }
    }
}

# Set seed for reproducibility
def set_seed(seed: int):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

# Seed all RNGs once at import time, before any dataset or model is built.
set_seed(CFG["seed"])
# Module-wide device string; the model's forward pass and the losses move tensors to it.
DEVICE = CFG["device"]
print(f"Using device: {DEVICE}")

# ============================================================================
# COLOR SPACE UTILITIES
# ============================================================================

def to_linear(img_srgb: torch.Tensor, gamma: float = 2.2) -> torch.Tensor:
    """Map a display-encoded image to linear intensities.

    The input is clipped to [0, 1] and raised to the power `gamma`. This is a
    plain power law, not the piecewise sRGB transfer function.
    """
    clipped = img_srgb.clamp(0.0, 1.0)
    return torch.pow(clipped, gamma)

def to_srgb(img_lin: torch.Tensor, gamma: float = 2.2) -> torch.Tensor:
    """Map linear intensities back to the display encoding.

    Inverse of `to_linear`: the input is clipped to [0, 1] and raised to the
    power `1 / gamma`.
    """
    clipped = img_lin.clamp(0.0, 1.0)
    inverse_gamma = 1.0 / gamma
    return torch.pow(clipped, inverse_gamma)

def clamp01(x: torch.Tensor) -> torch.Tensor:
    """Clip every element of `x` to the closed interval [0, 1]."""
    return x.clamp(min=0.0, max=1.0)

# ============================================================================
# NEURAL NETWORK UTILITIES
# ============================================================================

def gn(num_channels, groups=8):
    """Build a GroupNorm layer for `num_channels` channels.

    `nn.GroupNorm` requires the channel count to be divisible by the number of
    groups, so this uses the largest divisor of `num_channels` that does not
    exceed `groups`. For channel counts that `groups` already divides (or that
    are smaller than `groups`) the result is the same as simply capping the
    group count at `num_channels`.

    Args:
        num_channels: number of channels to normalise.
        groups: preferred (maximum) number of groups.

    Returns:
        An `nn.GroupNorm` module.
    """
    g = min(groups, num_channels)
    # Walk down to the nearest valid divisor; g == 1 always divides.
    while g > 1 and num_channels % g != 0:
        g -= 1
    return nn.GroupNorm(g, num_channels)

def act_fn(name="silu"):
    """Return a fresh activation module: SiLU for "silu", in-place ReLU for any other name."""
    if name == "silu":
        return nn.SiLU()
    return nn.ReLU(inplace=True)

# ============================================================================
# DATASET
# ============================================================================

class PairedDataset(Dataset):
    """Paired (camera, reference) image dataset with a seeded train/val/test split.

    Image ids `index_start .. index_start + count - 1` are shuffled with
    `random.Random(cfg["seed"])` and cut into the three splits by the fixed
    counts in `cfg["data"]["splits"]`, so every instance built from the same
    config sees the same partition. Pairs whose camera or reference file is
    missing on disk are dropped (with a warning).

    Each item is a dict with:
        I_cam_srgb  - camera image after augmentation, display-encoded, (3, S, S)
        I_cam_lin   - the same image passed through `to_linear`
        I_clean_lin - reference image passed through `to_linear`
        cam_path / ref_path - source file paths
    where S is `cfg["data"]["img_size"]`.
    """

    def __init__(self, split, cfg):
        """
        Args:
            split: "train", "val", or anything else (treated as the test split).
            cfg: configuration dict with the layout of the notebook's CFG.
        """
        self.cfg = cfg
        self.gamma = cfg["model"]["colorspace"]["gamma"]
        self.split = split
        self.train = split == "train"
        self.size = cfg["data"]["img_size"]
        self.aug_cfg = cfg["data"]["aug"]

        # Roots & filename pattern
        self.root_cam = cfg["data"]["root_cam"]
        self.root_ref = cfg["data"]["root_ref"]
        pat = cfg["data"]["filename_pattern"]
        start = int(cfg["data"]["index_start"])
        total = int(cfg["data"]["count"])

        # Deterministic ID list
        ids = list(range(start, start + total))

        # Fixed split counts
        n_train = int(cfg["data"]["splits"]["train"])
        n_val = int(cfg["data"]["splits"]["val"])
        n_test = int(cfg["data"]["splits"]["test"])
        assert n_train + n_val + n_test == total, "splits must sum to data.count"

        # A private RNG keeps the split independent of the global random state.
        rng = random.Random(cfg["seed"])
        rng.shuffle(ids)

        if split == "train":
            self.ids = ids[:n_train]
        elif split == "val":
            self.ids = ids[n_train:n_train + n_val]
        else:
            self.ids = ids[n_train + n_val:]

        # Keep only pairs for which both files exist.
        keep = []
        for i in self.ids:
            fname = pat.format(i)
            p_cam = os.path.join(self.root_cam, fname)
            p_ref = os.path.join(self.root_ref, fname)
            if os.path.exists(p_cam) and os.path.exists(p_ref):
                keep.append((p_cam, p_ref))

        missing = len(self.ids) - len(keep)
        if missing > 0:
            print(f"[{split}] Warning: {missing} pairs missing; using {len(keep)}.")
        self.entries = keep

    def __len__(self):
        """Number of usable (camera, reference) pairs in this split."""
        return len(self.entries)

    def gamma_corr(self, img, gamma=1.0):
        """Raise `img` to the power `gamma` and clip the result to [0, 1]."""
        return clamp01(img ** gamma)

    def paired_augs(self, cam, ref):
        """Apply identical geometric augmentations to both images of a pair.

        In training mode: random horizontal flip, random rotation within
        +/- `rotate_deg`, and (if `paired_crop_240` is set and both sides are
        at least 256 px) a random 240x240 crop. In every mode the pair is then
        resized to `(self.size, self.size)`.

        The final resize is unconditional: previously a training image with a
        side shorter than 256 px skipped both the crop and the resize and came
        back at its original size, which breaks batch collation.

        Args:
            cam, ref: PIL images of the same size.

        Returns:
            (cam, ref) as PIL images of size `(self.size, self.size)`.
        """
        hflip_p = self.aug_cfg.get("hflip_p", 0.5)
        max_rot = int(self.aug_cfg.get("rotate_deg", 10))
        out_size = (self.size, self.size)

        if self.train and random.random() < hflip_p:
            cam = ImageOps.mirror(cam)
            ref = ImageOps.mirror(ref)

        if self.train and max_rot > 0:
            # Same angle for both images so they stay pixel-aligned.
            deg = random.uniform(-max_rot, max_rot)
            cam = cam.rotate(deg, resample=Image.BILINEAR, fillcolor=None)
            ref = ref.rotate(deg, resample=Image.BILINEAR, fillcolor=None)

        if self.train and self.aug_cfg.get("paired_crop_240", True):
            w, h = cam.size
            if w >= 256 and h >= 256:
                nw = nh = 240
                x = random.randint(0, w - nw)
                y = random.randint(0, h - nh)
                box = (x, y, x + nw, y + nh)
                cam = cam.crop(box)
                ref = ref.crop(box)

        # Cropped or not, always end at the configured square size.
        cam = cam.resize(out_size, Image.BILINEAR)
        ref = ref.resize(out_size, Image.BILINEAR)

        return cam, ref

    def __getitem__(self, idx):
        """Load, augment and convert the pair at position `idx` (see class docstring)."""
        from torchvision import transforms

        p_cam, p_ref = self.entries[idx]
        # convert() returns a loaded copy, so the file handles can be closed right away.
        with Image.open(p_cam) as im_cam:
            cam = im_cam.convert("RGB")
        with Image.open(p_ref) as im_ref:
            ref = im_ref.convert("RGB")
        cam, ref = self.paired_augs(cam, ref)

        to_t = transforms.ToTensor()
        cam_s = to_t(cam)
        ref_s = to_t(ref)

        # Gamma augmentation (camera image only; the reference stays untouched)
        if self.train:
            choices = self.aug_cfg["gamma_train_choices"]
            probs = self.aug_cfg["gamma_train_probs"]
            g = random.choices(choices, probs)[0]
            if abs(g - 1.0) > 1e-6:
                cam_s = self.gamma_corr(cam_s, g)
        else:
            cam_s = self.gamma_corr(cam_s, self.aug_cfg.get("gamma_eval", 1.0))

        # Color jitter (camera image only)
        if self.train and random.random() < self.aug_cfg["jitter_cam"]["prob"]:
            b = self.aug_cfg["jitter_cam"]["brightness"]
            c = self.aug_cfg["jitter_cam"]["contrast"]
            cam_s = transforms.ColorJitter(brightness=b, contrast=c)(cam_s)

        # Convert to linear
        cam_l = to_linear(cam_s, self.gamma)
        ref_l = to_linear(ref_s, self.gamma)

        return {
            "I_cam_srgb": cam_s,
            "I_cam_lin": cam_l,
            "I_clean_lin": ref_l,
            "cam_path": p_cam,
            "ref_path": p_ref
        }

def get_loaders(cfg):
    """Build the train, validation and test DataLoaders described by `cfg`.

    Only the training loader shuffles and drops the last incomplete batch.

    Returns:
        (dl_train, dl_val, dl_test)
    """
    datasets = {name: PairedDataset(name, cfg) for name in ("train", "val", "test")}

    loader_cfg = cfg["data"]["loader"]
    shared = dict(
        batch_size=loader_cfg["batch_size"],
        num_workers=loader_cfg["num_workers"],
        pin_memory=loader_cfg["pin_memory"],
    )

    dl_train = DataLoader(datasets["train"], shuffle=True, drop_last=True, **shared)
    dl_val = DataLoader(datasets["val"], shuffle=False, **shared)
    dl_test = DataLoader(datasets["test"], shuffle=False, **shared)

    return dl_train, dl_val, dl_test

# ============================================================================
# MODEL COMPONENTS
# ============================================================================

class ConvBlock(nn.Module):
    """Two stacked (3x3 conv -> norm -> activation) stages.

    `norm == "group"` selects GroupNorm via `gn`; any other value selects
    BatchNorm2d. Spatial size is preserved (padding=1).
    """
    def __init__(self, in_ch, out_ch, norm="group", act="silu"):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            if norm == "group":
                stages.append(gn(out_ch))
            else:
                stages.append(nn.BatchNorm2d(out_ch))
            stages.append(act_fn(act))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class UNetS_Physics(nn.Module):
    """Physics head: a four-stage convolutional encoder predicting t, A, z_w and sigma2.

    Despite the name there is no decoder and there are no skip connections: all
    four heads read the bottleneck feature map, which sits at 1/8 of the input
    resolution (three 2x max-pools).

    forward() returns a dict with:
        t      - (B, 3, H/8, W/8) transmission, strictly inside (t_min, 1 - t_min)
        A      - (B, 3, 1, 1) global light estimate, unbounded (no output activation)
        zw     - (B, z_w_dim) embedding
        sigma2 - (B, 1, H/8, W/8) positive variance map (softplus)
    """
    def __init__(self, cfg):
        super().__init__()
        chs = cfg["model"]["physics"]["unet_channels"]  # encoder widths; needs at least 4 entries
        norm = cfg["model"]["physics"]["norm"]
        act = cfg["model"]["physics"]["act"]
        self.t_min = cfg["model"]["physics"]["t_min"]
        self.zw_dim = cfg["model"]["physics"]["z_w_dim"]
        
        # Encoder: the first block runs at input resolution; each pool halves H and W.
        self.e1 = ConvBlock(3, chs[0], norm, act)
        self.p1 = nn.MaxPool2d(2)
        self.e2 = ConvBlock(chs[0], chs[1], norm, act)
        self.p2 = nn.MaxPool2d(2)
        self.e3 = ConvBlock(chs[1], chs[2], norm, act)
        self.p3 = nn.MaxPool2d(2)
        self.e4 = ConvBlock(chs[2], chs[3], norm, act)
        
        # Transmission map head (one channel per colour channel)
        self.t_head = nn.Sequential(
            nn.Conv2d(chs[3], 64, 3, padding=1),
            act_fn(act),
            nn.Conv2d(64, 3, 1)
        )
        
        # Atmospheric light head: global average pool, then 1x1 convs acting as an MLP
        self.A_pool = nn.AdaptiveAvgPool2d(1)
        self.A_mlp = nn.Sequential(
            nn.Conv2d(chs[3], 128, 1),
            act_fn(act),
            nn.Conv2d(128, 3, 1)
        )
        
        # Water properties embedding (always uses SiLU, independent of `act`)
        self.zw_pool = nn.AdaptiveAvgPool2d(1)
        self.zw_mlp = nn.Sequential(
            nn.Linear(chs[3], 256),
            nn.SiLU(),
            nn.Linear(256, self.zw_dim)
        )
        
        # Heteroscedastic uncertainty head
        self.sigma_head = nn.Sequential(
            nn.Conv2d(chs[3], 16, 3, padding=1),
            act_fn(act),
            nn.Conv2d(16, 1, 1)
        )
    
    def forward(self, I_cam_lin):
        # Encode (channel counts are for the default unet_channels = [32, 64, 128, 256])
        x1 = self.e1(I_cam_lin)      # 32 ch, H   (no pooling yet)
        x2 = self.e2(self.p1(x1))    # 64 ch, H/2
        x3 = self.e3(self.p2(x2))    # 128 ch, H/4
        x4 = self.e4(self.p3(x3))    # 256 ch, H/8 (bottleneck)
        
        # Transmission map: the sigmoid is rescaled so t lies in (t_min, 1 - t_min)
        t_logits = self.t_head(x4)
        t = torch.sigmoid(t_logits) * (1 - 2*self.t_min) + self.t_min
        
        # Atmospheric light (global, RGB); raw network output, not constrained to [0, 1]
        A = self.A_mlp(self.A_pool(x4))  # (B, 3, 1, 1)
        
        # Water properties embedding
        zw = self.zw_mlp(self.zw_pool(x4).flatten(1))  # (B, zw_dim)
        
        # Heteroscedastic uncertainty; softplus keeps it positive
        sigma2 = F.softplus(self.sigma_head(x4))  # (B, 1, H/8, W/8)
        
        return {
            "t": t,
            "A": A,
            "zw": zw,
            "sigma2": sigma2
        }

class TRIDENT(nn.Module):
    """Physics-only restoration model.

    The physics head sees the camera image only. The clean estimate is obtained
    by inverting the image-formation model

        I_cam = I_clean * t + A * (1 - t)

    The reference image, when present in the batch, is used solely to build the
    forward projection `I_phys_lin` for the physics-consistency loss.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.gamma = cfg["model"]["colorspace"]["gamma"]
        self.t_min = cfg["model"]["physics"]["t_min"]
        self.phys = UNetS_Physics(cfg)

    def forward(self, batch):
        """Run the model on `batch` (a dict with "I_cam_lin" and, optionally, "I_clean_lin").

        Returns a dict with the restored image `I_hat_lin`, the forward
        projection `I_phys_lin` (None without a reference), and `t`, `A`, `zw`,
        `sigma2`; `t` and `sigma2` are upsampled to the input resolution.
        """
        cam_lin = batch["I_cam_lin"].to(DEVICE)

        # Physics head forward (only uses I_cam)
        heads = self.phys(cam_lin)
        A = heads["A"]
        zw = heads["zw"]

        # Bring the low-resolution maps up to the input size
        H, W = cam_lin.shape[-2], cam_lin.shape[-1]
        t = F.interpolate(heads["t"], size=(H, W), mode="bilinear", align_corners=False)
        sigma2 = F.interpolate(heads["sigma2"], size=(H, W), mode="bilinear", align_corners=False)

        # Backscatter term A * (1 - t), with A broadcast over all pixels
        A_full = A.expand(-1, -1, H, W)
        backscatter = A_full * (1 - t)

        # Invert the imaging model: I_clean = (I_cam - A * (1 - t)) / t
        t_floor = torch.tensor(self.t_min, device=t.device)
        restored = (cam_lin - backscatter) / torch.maximum(t, t_floor)
        restored = clamp01(restored)

        # Forward projection, only possible when a reference image is supplied
        projected = None
        if "I_clean_lin" in batch:
            ref_lin = batch["I_clean_lin"].to(DEVICE)
            projected = ref_lin * t + backscatter

        return {
            "I_hat_lin": restored,      # Predicted clean image (from inversion)
            "I_phys_lin": projected,    # Forward projection (for regularization)
            "t": t,
            "A": A,
            "zw": zw,
            "sigma2": sigma2
        }


# ============================================================================
# EMA
# ============================================================================

class EMA:
    """Exponential moving average of a model's state_dict.

    Floating-point entries are tracked as float32 shadows and blended with
    `decay`. Non-floating entries (for example BatchNorm's integer
    `num_batches_tracked`) cannot be averaged meaningfully, so they keep their
    dtype and simply mirror the model's latest value.

    Attributes:
        ema: dict mapping state_dict key -> shadow tensor (loadable with
            `load_state_dict`).
        decay: weight given to the previous shadow on every update.
    """
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.ema = {}
        for key, value in model.state_dict().items():
            shadow = value.detach().clone()
            self.ema[key] = shadow.float() if shadow.is_floating_point() else shadow

    @torch.no_grad()
    def update(self, model):
        """Fold the model's current state into the shadows (call once per optimizer step)."""
        current = model.state_dict()
        for key, shadow in self.ema.items():
            value = current[key].detach()
            if shadow.is_floating_point():
                shadow.mul_(self.decay).add_(value.float(), alpha=1 - self.decay)
            else:
                # Integer/bool buffers: copy instead of averaging.
                shadow.copy_(value)

    def copy_to(self, model):
        """Load the shadow weights into `model` (keys must match exactly)."""
        model.load_state_dict(self.ema, strict=True)

# ============================================================================
# LOSSES
# ============================================================================

def ssim_srgb(x, y, C1=0.01**2, C2=0.03**2):
    """Mean SSIM between two image batches using an 11x11 uniform (box) window.

    Local statistics come from stride-1 average pooling with 5 px of zero
    padding, so the output map has the same spatial size as the input. Returns
    a scalar tensor (mean over batch, channels and pixels).
    """
    def local_mean(z):
        return F.avg_pool2d(z, 11, stride=1, padding=5)

    mu_x = local_mean(x)
    mu_y = local_mean(y)
    var_x = local_mean(x * x) - mu_x * mu_x
    var_y = local_mean(y * y) - mu_y * mu_y
    cov_xy = local_mean(x * y) - mu_x * mu_y

    numerator = (2 * mu_x * mu_y + C1) * (2 * cov_xy + C2)
    denominator = (mu_x * mu_x + mu_y * mu_y + C1) * (var_x + var_y + C2)
    # The extra 1e-8 guards the division.
    return (numerator / (denominator + 1e-8)).mean()

def psnr_srgb(x, y):
    """PSNR in dB between `x` and `y`, assuming a peak value of 1.0.

    The mean squared error is taken over the whole batch; 1e-8 is added to it
    so identical inputs give a finite result.
    """
    batch_mse = F.mse_loss(x, y)
    peak_over_mse = 1.0 / (batch_mse + 1e-8)
    return 10 * torch.log10(peak_over_mse)

def tv_loss_t(t):
    """Anisotropic total variation of a 4-D map.

    Sum of the mean absolute forward differences along the last axis (dim 3)
    and along dim 2.
    """
    horizontal = torch.diff(t, dim=3)
    vertical = torch.diff(t, dim=2)
    return horizontal.abs().mean() + vertical.abs().mean()

def A_prior_loss(A, I_cam_lin):
    """MSE between `A` and the per-image, per-channel spatial mean of the camera input.

    The mean is taken over dims 2 and 3 with keepdim, giving a (B, C, 1, 1)
    target to compare `A` against.
    """
    target = I_cam_lin.mean(dim=(2, 3), keepdim=True)
    return F.mse_loss(A, target)

def heteroscedastic_l1(err_lin, sigma2, eps=1e-6):
    """Uncertainty-weighted L1 loss: mean of err / (sigma2 + eps) + log(sigma2 + eps).

    Args:
        err_lin: absolute error tensor.
        sigma2: positive scale map, broadcastable against `err_lin`.
        eps: stabiliser added to `sigma2` before dividing and taking the log.
            Defaults to the previously hard-coded 1e-6.

    Returns:
        Scalar tensor.
    """
    scale = sigma2 + eps
    return ((err_lin / scale) + torch.log(scale)).mean()

def compute_losses(batch, out, epoch, hetero_on):
    """Assemble the training objective for one batch.

    total = L1(I_hat, I_clean)
          + w_phys    * L1(I_phys, I_cam)      (0 when I_phys is None)
          + w_tv_t    * TV(t)
          + w_A_prior * A_prior_loss(A, I_cam)
          + w_hetero  * heteroscedastic_l1(|I_hat - I_clean|, sigma2)   (only if hetero_on)

    The first three weights come from CFG["loss"]["weights"]; the hetero weight
    comes from CFG["schedules"]["hetero"]["weight"]. `epoch` is accepted but
    not used.

    Returns:
        (loss tensor, dict of Python floats for logging)
    """
    weights = CFG["loss"]["weights"]

    pred_lin = out["I_hat_lin"]       # restored image
    proj_lin = out["I_phys_lin"]      # forward projection, may be None
    cam_lin = batch["I_cam_lin"].to(DEVICE)
    ref_lin = batch["I_clean_lin"].to(DEVICE)

    # Main term: restored image against the reference
    recon = F.l1_loss(pred_lin, ref_lin)

    # Physics consistency: re-degraded reference should match the camera input
    if proj_lin is None:
        phys = torch.tensor(0., device=DEVICE)
    else:
        phys = F.l1_loss(proj_lin, cam_lin)

    # Regularisers on the physical parameters
    tv = tv_loss_t(out["t"])
    a_prior = A_prior_loss(out["A"], cam_lin)

    total = recon + weights["phys"] * phys + weights["tv_t"] * tv + weights["A_prior"] * a_prior

    # Optional uncertainty term
    hetero_value = 0.0
    if hetero_on:
        abs_err = (pred_lin - ref_lin).abs()
        hetero = heteroscedastic_l1(abs_err, out["sigma2"])
        total = total + CFG["schedules"]["hetero"]["weight"] * hetero
        hetero_value = hetero.item()

    logs = {
        "L_total": total.item(),
        "Recon": recon.item(),
        "Phys": phys.item(),
        "TVt": tv.item(),
        "Apr": a_prior.item(),
        "Hetero": hetero_value
    }

    return total, logs
# ============================================================================
# TRAINING UTILITIES
# ============================================================================

class HeteroWatchdog:
    """Temporarily disables the heteroscedastic loss when the mean sigma2 stays too high.

    After `over_thresh_required` consecutive readings above
    `mean_sigma2_thresh`, a cooldown of `cooldown_epochs` is started. Callers
    check `cooldown == 0` to decide whether the loss is active.

    The cooldown is no longer decremented in the same call that starts it.
    Before, that consumed one cooldown epoch immediately (so the loss was off
    for `cooldown_epochs - 1` epochs) and the `triggered` flag was never
    reported as True.
    """
    def __init__(self, cfg):
        wd = cfg["loss"]["heteroscedastic"]["watchdog"]
        self.th = wd["mean_sigma2_thresh"]
        self.req = wd["over_thresh_required"]
        self.cool = wd["cooldown_epochs"]
        self.count = 0      # consecutive over-threshold readings
        self.cooldown = 0   # remaining epochs with the hetero loss disabled

    def update(self, mean_sigma2):
        """Feed one reading (call once per epoch, after validation).

        Returns:
            (triggered, disabled): `triggered` is True only on the call that
            starts a cooldown; `disabled` is True when the upcoming epoch
            should run without the hetero loss (`cooldown > 0` after this call).
        """
        if self.cooldown > 0:
            # An epoch just ran with the loss disabled: consume it. Readings
            # taken during cooldown do not count towards the next trigger.
            self.cooldown -= 1
            self.count = 0
            return False, self.cooldown > 0

        if mean_sigma2 > self.th:
            self.count += 1
        else:
            self.count = 0

        if self.count >= self.req:
            self.cooldown = self.cool
            self.count = 0
            return True, self.cooldown > 0

        return False, False

def cosine_lr(optimizer, step, total_steps, base_lr, min_lr, warmup_steps):
    """Set and return the learning rate for `step`.

    Linear warm-up from 0 to `base_lr` over the first `warmup_steps` steps,
    then a half-cosine decay from `base_lr` to `min_lr` ending at
    `total_steps`. The value is written to every param group of `optimizer`.
    """
    if step >= warmup_steps:
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine_term = 1 + math.cos(math.pi * progress)
        lr = min_lr + 0.5 * (base_lr - min_lr) * cosine_term
    else:
        lr = base_lr * step / max(1, warmup_steps)

    for group in optimizer.param_groups:
        group["lr"] = lr

    return lr

def save_checkpoint(state_dict, path):
    """Write a CPU copy of `state_dict` to `path` with `torch.save`.

    Parent directories are created as needed. A bare file name (no directory
    part) is saved into the current working directory; previously that case
    failed because `os.makedirs("")` raises.

    Args:
        state_dict: mapping of name -> tensor.
        path: destination file path.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Move to CPU before saving so the file loads on machines without a GPU.
    cpu_state = {k: v.cpu().clone() for k, v in state_dict.items()}
    torch.save(cpu_state, path)
    print(f"✓ Saved: {path}")

# ============================================================================
# TRAINING
# ============================================================================

def train_val_setup():
    """Create everything the training loop needs from the global CFG.

    Returns:
        (model, optimizer, grad_scaler, ema_or_None, dl_train, dl_val, dl_test)
    """
    dl_train, dl_val, dl_test = get_loaders(CFG)
    model = TRIDENT(CFG).to(DEVICE)

    optim_cfg = CFG["optimizer"]
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=CFG["schedules"]["lr"]["base"],
        betas=tuple(optim_cfg["betas"]),
        weight_decay=optim_cfg["weight_decay"]
    )

    # Loss scaling is only active for AMP on a CUDA device.
    amp_on_cuda = CFG["amp"] and DEVICE.startswith("cuda")
    scaler = torch.amp.GradScaler('cuda', enabled=amp_on_cuda)

    ema = None
    if CFG["ema"]["enabled"]:
        ema = EMA(model, decay=CFG["ema"]["decay"])

    return model, opt, scaler, ema, dl_train, dl_val, dl_test

def run_training():
    """Train the model, validate after every epoch, and checkpoint the best weights.

    Per epoch: one pass over the training loader (AMP on CUDA, warm-up + cosine
    LR, gradient clipping, EMA update), then validation with the EMA weights
    (or the raw weights when EMA is disabled). The best-SSIM and best-L1
    weights are written to `checkpointsphy/`, and the final raw model weights
    to `checkpointsphy/last.pt`.

    Changes from the earlier version:
      * Gradients are unscaled (`scaler.unscale_`) before clipping. Clipping
        the still-scaled gradients compared the loss-scaled norm against
        `grad_clip`, which is not the threshold the config intends.
      * The sigma2 watchdog is only fed once the heteroscedastic loss is
        scheduled to run. Before that the sigma head receives no gradient, so
        its readings could start cooldowns before the loss had ever been on.
    """
    import gc

    model, opt, scaler, ema, dl_train, dl_val, dl_test = train_val_setup()
    n_epochs = CFG["schedules"]["epochs"]
    steps_per_epoch = len(dl_train)
    total_steps = steps_per_epoch * n_epochs
    hetero_start = CFG["schedules"]["hetero"]["start_epoch"]
    use_amp = CFG["amp"] and DEVICE.startswith("cuda")
    use_ema = CFG["ema"]["enabled"]
    gamma = CFG["model"]["colorspace"]["gamma"]
    hetero_wd = HeteroWatchdog(CFG)

    best_ssim = -1.0
    best_l1 = float('inf')
    step = 0

    for epoch in range(1, n_epochs + 1):
        # Memory cleanup
        torch.cuda.empty_cache()
        gc.collect()

        # The watchdog state only changes between epochs, so decide once per epoch.
        hetero_on = (epoch >= hetero_start) and (hetero_wd.cooldown == 0)

        # Training phase
        model.train()
        pbar = tqdm(dl_train, desc=f"Epoch {epoch}/{CFG['schedules']['epochs']} [train]", leave=True)

        for batch in pbar:
            lr = cosine_lr(opt, step, total_steps,
                          CFG["schedules"]["lr"]["base"],
                          CFG["schedules"]["lr"]["min"],
                          CFG["schedules"]["warmup_steps"])

            opt.zero_grad(set_to_none=True)

            # Forward pass with AMP
            with torch.amp.autocast('cuda', enabled=use_amp):
                out = model(batch)
                loss, logs = compute_losses(batch, out, epoch, hetero_on)

            # Backward pass
            scaler.scale(loss).backward()
            # Unscale first so the clip threshold applies to the true gradient norm.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["optimizer"]["grad_clip"])
            scaler.step(opt)
            scaler.update()

            if use_ema:
                ema.update(model)

            step += 1

            pbar.set_postfix({
                "lr": f"{lr:.2e}",
                "L": f"{logs['L_total']:.3f}",
                "Recon": f"{logs['Recon']:.3f}",
                "Phys": f"{logs['Phys']:.3f}"
            })

            # Periodic memory cleanup
            if step % 200 == 0:
                torch.cuda.empty_cache()

        # Memory cleanup before validation
        torch.cuda.empty_cache()
        gc.collect()

        # Validation phase: evaluate a separate copy so the EMA weights never
        # overwrite the weights being trained.
        model.eval()
        val_model = TRIDENT(CFG).to(DEVICE)
        if use_ema:
            ema.copy_to(val_model)
        else:
            val_model.load_state_dict(model.state_dict())
        val_model.eval()

        ssim_vals, l1_vals, sigma_means = [], [], []
        pbar_val = tqdm(dl_val, desc=f"Epoch {epoch} [val]", leave=False)

        with torch.no_grad():
            for batch in pbar_val:
                out = val_model(batch)
                ref_lin = batch["I_clean_lin"].to(DEVICE)
                I_hat_s = to_srgb(out["I_hat_lin"], gamma)
                I_ref_s = to_srgb(ref_lin, gamma)

                ssim_vals.append(ssim_srgb(I_hat_s, I_ref_s).item())
                l1_vals.append(F.l1_loss(out["I_hat_lin"], ref_lin).item())
                sigma_means.append(out["sigma2"].mean().item())

                pbar_val.set_postfix({
                    "SSIM": f"{np.mean(ssim_vals):.3f}",
                    "L1": f"{np.mean(l1_vals):.3f}"
                })

        val_ssim = float(np.mean(ssim_vals))
        val_l1 = float(np.mean(l1_vals))
        mean_sigma = float(np.mean(sigma_means))

        # Delete validation model
        del val_model
        torch.cuda.empty_cache()
        gc.collect()

        # Watchdog update (skipped until the hetero loss is scheduled)
        if epoch >= hetero_start:
            trig, disabled = hetero_wd.update(mean_sigma)
            if trig:
                print(f"  → Hetero watchdog triggered (meanσ²={mean_sigma:.3f}); "
                      f"cooldown={hetero_wd.cooldown}")

        # Save best models immediately (EMA weights when EMA is enabled)
        best_state = ema.ema if use_ema else model.state_dict()

        if val_ssim > best_ssim:
            best_ssim = val_ssim
            print(f"  → New best SSIM: {val_ssim:.4f}")
            save_checkpoint(best_state, "checkpointsphy/best_ssim.pt")

        if val_l1 < best_l1:
            best_l1 = val_l1
            print(f"  → New best L1: {val_l1:.4f}")
            save_checkpoint(best_state, "checkpointsphy/best_l1.pt")

        print(f"[Epoch {epoch}] SSIM={val_ssim:.4f}  L1={val_l1:.4f}  "
              f"meanσ²={mean_sigma:.3f}  heteroCooldown={hetero_wd.cooldown}")

    # Save final checkpoint (raw training weights, not the EMA shadow)
    save_checkpoint(model.state_dict(), "checkpointsphy/last.pt")

    print("\n" + "="*70)
    print(f"Training complete!")
    print(f"Best SSIM: {best_ssim:.4f}")
    print(f"Best L1: {best_l1:.4f}")
    print("="*70 + "\n")

# ============================================================================
# VERIFICATION: Test that no data leak exists
# ============================================================================

def verify_no_data_leak():
    """Smoke-test that the model runs without a reference image.

    A freshly initialised model is run on a random batch containing only
    "I_cam_lin". The check passes when the forward pass succeeds and
    `I_phys_lin` (which can only be built from a reference) is None.
    Previously the success message was printed even when `I_phys_lin` was not
    None.

    Returns:
        True when the check passes, False otherwise. Failures are reported on
        stdout and never raised.
    """
    print("\n" + "="*70)
    print("Verifying No Data Leak")
    print("="*70)

    # Create model
    model = TRIDENT(CFG).to(DEVICE)
    model.eval()

    # Create dummy input (NO GROUND TRUTH)
    dummy_cam_srgb = torch.rand(2, 3, 256, 256)
    dummy_cam_lin = to_linear(dummy_cam_srgb, CFG["model"]["colorspace"]["gamma"])

    dummy_batch = {
        "I_cam_lin": dummy_cam_lin
        # Note: NO "I_clean_lin" key!
    }

    passed = False

    # Test forward pass
    try:
        with torch.no_grad():
            output = model(dummy_batch)

        print("✓ Model runs without ground truth!")
        print(f"  Output keys: {list(output.keys())}")
        print(f"  I_hat shape: {output['I_hat_lin'].shape}")
        print(f"  t shape: {output['t'].shape}")
        print(f"  A shape: {output['A'].shape}")

        # Without a reference there is nothing to build I_phys_lin from.
        phys_is_none = output['I_phys_lin'] is None
        if phys_is_none:
            print("✓ I_phys_lin is None (expected, no ground truth)")
        else:
            print("✗ I_phys_lin is not None although no ground truth was supplied")

        # Check value ranges
        print(f"\nValue ranges:")
        print(f"  I_hat: [{output['I_hat_lin'].min():.3f}, {output['I_hat_lin'].max():.3f}]")
        print(f"  t: [{output['t'].min():.3f}, {output['t'].max():.3f}]")
        print(f"  A: [{output['A'].min():.3f}, {output['A'].max():.3f}]")

        if phys_is_none:
            print("\n✓ Verification passed! No data leak detected.")
            passed = True
        else:
            print("\n✗ Verification failed: forward pass produced I_phys_lin without ground truth.")

    except Exception as e:
        # Deliberately best-effort: report the problem and let the caller carry on.
        print(f"✗ Verification failed: {e}")
        print("  Make sure you've replaced the TRIDENT class with the corrected version.")

    print("="*70 + "\n")
    return passed


# ============================================================================
# EVALUATION & EXPORT
# ============================================================================

def evaluate_and_export(checkpoint_path=None):
    """Evaluate a checkpoint on the test split and export metrics, images and physics maps.

    Args:
        checkpoint_path: state_dict file to load. When None, the first existing
            file among `checkpointsphy/last2.pt` (the path this function used
            to hard-code), `checkpointsphy/best_ssim.pt` and
            `checkpointsphy/last.pt` is used. The last two are the files
            `run_training` writes, so a train-then-evaluate run no longer fails
            to find a checkpoint.

    Outputs, all under CFG["logging"]["export_test"]["dir"]:
        test_metrics.json   - mean/std of per-batch SSIM, PSNR and L1
        images/<name>       - [input | output | reference] strips (if save_images)
        physics_<id>.npz    - t, A, zw, sigma2 for each batch (if tAzw), named
                              after the first image of the batch

    Returns None. Prints an error and returns early when no checkpoint is found.
    """
    if checkpoint_path is None:
        candidates = [
            "checkpointsphy/last2.pt",
            "checkpointsphy/best_ssim.pt",
            "checkpointsphy/last.pt",
        ]
    else:
        candidates = [checkpoint_path]

    checkpoint_path = next((p for p in candidates if os.path.exists(p)), None)
    if checkpoint_path is None:
        print(f"[ERROR] Checkpoint not found: {', '.join(candidates)}")
        return

    export_cfg = CFG["logging"]["export_test"]
    save_images = export_cfg.get("save_images", False)
    gamma = CFG["model"]["colorspace"]["gamma"]

    model = TRIDENT(CFG).to(DEVICE)
    # weights_only=True: the file is a plain dict of tensors, so nothing else
    # needs to be unpickled.
    sd = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(sd, strict=True)
    model.eval()

    print(f"Loaded checkpoint: {checkpoint_path}")

    _, _, dl_test = get_loaders(CFG)

    # Create export directories
    base_export_dir = export_cfg["dir"]
    os.makedirs(base_export_dir, exist_ok=True)

    img_export_dir = None
    if save_images:
        # Imported once here rather than inside the per-image loop.
        from torchvision.utils import save_image
        img_export_dir = os.path.join(base_export_dir, "images")
        os.makedirs(img_export_dir, exist_ok=True)

    SSIMs, PSNRs, L1s = [], [], []

    print("\n" + "="*70)
    print("EVALUATION ON TEST SET")
    print("="*70)

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dl_test, desc="Evaluating and Exporting")):
            out = model(batch)
            ref_lin = batch["I_clean_lin"].to(DEVICE)
            I_hat_s = to_srgb(out["I_hat_lin"], gamma)
            I_ref_s = to_srgb(ref_lin, gamma)

            # --- Image Saving Logic ---
            if save_images:
                I_cam_s = batch["I_cam_srgb"].to(DEVICE)
                cam_paths = batch["cam_path"]

                # Save each image in the batch individually
                for i in range(I_hat_s.shape[0]):
                    base_name = os.path.basename(cam_paths[i])

                    # Side-by-side comparison: [Input | Output | Ground Truth]
                    comparison_grid = torch.cat([I_cam_s[i], I_hat_s[i], I_ref_s[i]], dim=-1)
                    save_image(comparison_grid, os.path.join(img_export_dir, base_name))

            # --- Metric Calculation (one value per batch) ---
            SSIMs.append(ssim_srgb(I_hat_s, I_ref_s).item())
            PSNRs.append(psnr_srgb(I_hat_s, I_ref_s).item())
            L1s.append(F.l1_loss(out["I_hat_lin"], ref_lin).item())

            # --- Physics Export Logic ---
            if export_cfg["tAzw"]:
                # One .npz file per batch, named after the batch's first image
                batch_id = os.path.basename(batch["cam_path"][0]).split('.')[0]
                np.savez_compressed(
                    os.path.join(base_export_dir, f"physics_{batch_id}.npz"),
                    t=out["t"].detach().cpu().numpy(),
                    A=out["A"].detach().cpu().numpy(),
                    zw=out["zw"].detach().cpu().numpy(),
                    sigma2=out["sigma2"].detach().cpu().numpy()
                )

    # Print results
    print("\n" + "="*70)
    print("TEST RESULTS")
    print("="*70)
    print(f"SSIM:  {np.mean(SSIMs):.4f} ± {np.std(SSIMs):.4f}")
    print(f"PSNR:  {np.mean(PSNRs):.2f} ± {np.std(PSNRs):.2f} dB")
    print(f"L1:    {np.mean(L1s):.4f} ± {np.std(L1s):.4f}")
    print("="*70)

    # Save metrics to JSON
    metrics = {
        "ssim_mean": float(np.mean(SSIMs)),
        "ssim_std": float(np.std(SSIMs)),
        "psnr_mean": float(np.mean(PSNRs)),
        "psnr_std": float(np.std(PSNRs)),
        "l1_mean": float(np.mean(L1s)),
        "l1_std": float(np.std(L1s))
    }

    metrics_path = os.path.join(base_export_dir, "test_metrics.json")
    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent=2)

    print(f"\n✓ Metrics saved to: {metrics_path}")

    if save_images:
        print(f"✓ Images saved to: {img_export_dir}")

    if export_cfg["tAzw"]:
        print(f"✓ Physics parameters saved to: {base_export_dir}")

# ============================================================================
# MAIN EXECUTION
# ============================================================================



Using device: cuda


In [None]:
# Entry point: always runs the no-ground-truth smoke test, then either
# evaluation only (first CLI argument == "eval") or training followed by evaluation.
# NOTE(review): when this runs as a notebook cell, sys.argv presumably holds the
# kernel launcher's arguments rather than user input, so the "eval" branch would
# not be taken and the cell would start a full training run — confirm, or call
# evaluate_and_export() directly instead.
if __name__ == "__main__":
    import sys
    
    print("\n" + "="*70)
    print("TRIDENT PHYSICS-ONLY MODEL")
    print("Underwater Image Restoration via Physics-Based Approach")
    print("="*70 + "\n")
    verify_no_data_leak()
    # Check if we should run training or just evaluation
    if len(sys.argv) > 1 and sys.argv[1] == "eval":
        print("Running evaluation only...\n")
        evaluate_and_export()
    else:
        print("Starting training...\n")
        run_training()
        print("\nStarting evaluation...\n")
        evaluate_and_export()
    
    print("\n" + "="*70)
    print("ALL DONE!")
    print("="*70 + "\n")

In [4]:
# Evaluation-only run on the test split, using a checkpoint already on disk (no training).
evaluate_and_export()

Loaded checkpoint: checkpointsphy/last2.pt

EVALUATION ON TEST SET


Evaluating and Exporting:   0%|          | 0/56 [00:00<?, ?it/s]


TEST RESULTS
SSIM:  0.8385 ± 0.0181
PSNR:  20.40 ± 0.95 dB
L1:    0.0528 ± 0.0068

✓ Metrics saved to: outputs/test_exports\test_metrics.json
✓ Images saved to: outputs/test_exports\images
✓ Physics parameters saved to: outputs/test_exports
