In [1]:
!pip install pillow opencv-python tqdm numpy




In [2]:
"""
TRIDENT Physics-Only Model - Enhanced Version
Underwater Image Restoration with Perceptual Losses, CBAM, and Active z_w
"""

import os, math, random, time, json, gc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageOps
from tqdm.auto import tqdm
from torchvision import transforms, models

# ============================================================================
# CONFIGURATION
# ============================================================================

CFG = {
    "seed": 42,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "amp": True,
    
    "ema": {"enabled": True, "decay": 0.9995},
    
    "data": {
        "root_cam": "/teamspace/studios/this_studio/data/DATA/Icam",
        "root_ref": "/teamspace/studios/this_studio/data/DATA/Iclean",
        "img_size": 256,
        "filename_pattern": "{:d}.png",
        "index_start": 1,
        "count": 21521,
        "splits": {"train": 18000, "val": 1760, "test": 1761},
        "loader": {
            "batch_size": 128,
            "num_workers": 16,
            "pin_memory": True
        },
        "aug": {
            "hflip_p": 0.5,
            "rotate_deg": 10,
            "paired_crop_240": True,
            "jitter_cam": {"brightness": 0.1, "contrast": 0.1, "prob": 0.3},
            "gamma_train_choices": [0.9, 1.0, 1.1],
            "gamma_train_probs": [0.25, 0.50, 0.25],
            "gamma_eval": 1.0,
            "underwater_beta": [0.05, 0.25],
            "underwater_A": [0.3, 0.6],
            "underwater_prob": 0.3
        }
    },
    
    "model": {
        "colorspace": {"gamma": 2.2},
        "physics": {
            "unet_channels": [32, 64, 128, 256, 512],
            "norm": "group",
            "act": "silu",
            "t_min": 0.02,
            "z_w_dim": 32
        }
    },
    
    "loss": {
        "weights": {
            "phys": 0.05,
            "tv_t": 2.0e-4,
            "A_prior": 2.0e-3,
            "hetero": 0.2,
            "t_bias": 5.0e-5,
            "perceptual": 0.05,
            "frequency": 0.02
        },
        "heteroscedastic": {
            "epsilon": 1.0e-6,
            "start_epoch": 12,
            "watchdog": {
                "mean_sigma2_thresh": 0.5,
                "over_thresh_required": 2,
                "cooldown_epochs": 2
            }
        }
    },
    
    "schedules": {
        "epochs": 30,
        "warmup_steps": 3000,
        "lr": {"base": 2.0e-4, "min": 1.0e-8, "decay": "cosine"},
        "hetero": {"start_epoch": 8, "weight": 0.2}
    },
    
    "optimizer": {
        "name": "adamw",
        "betas": [0.9, 0.99],
        "weight_decay": 1.0e-4,
        "grad_clip": 0.5
    },
    
    "logging": {
        "save_best": {"by_primary": True, "by_secondary": True},
        "export_test": {"tAzw": True, "save_images": True, "dir": "outputs/test_exports"}
    }
}

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(CFG["seed"])
DEVICE = CFG["device"]
print(f"Using device: {DEVICE}")

# ============================================================================
# COLOR SPACE UTILITIES
# ============================================================================

def to_linear(img_srgb: torch.Tensor, gamma: float = 2.2) -> torch.Tensor:
    return torch.clamp(img_srgb, 0.0, 1.0).pow(gamma)

def to_srgb(img_lin: torch.Tensor, gamma: float = 2.2) -> torch.Tensor:
    return torch.clamp(img_lin, 0.0, 1.0).pow(1.0 / gamma)

def clamp01(x: torch.Tensor) -> torch.Tensor:
    return torch.clamp(x, 0.0, 1.0)

# ============================================================================
# NEURAL NETWORK UTILITIES
# ============================================================================

def gn(num_channels, groups=8):
    g = min(groups, num_channels)
    return nn.GroupNorm(g, num_channels)

def act_fn(name="silu"):
    return nn.SiLU() if name == "silu" else nn.ReLU(inplace=True)

# ============================================================================
# DATASET
# ============================================================================

class PairedDataset(Dataset):
    def __init__(self, split, cfg):
        self.cfg = cfg
        self.gamma = cfg["model"]["colorspace"]["gamma"]
        self.split = split
        self.train = split == "train"
        self.size = cfg["data"]["img_size"]
        self.aug_cfg = cfg["data"]["aug"]
        
        self.root_cam = cfg["data"]["root_cam"]
        self.root_ref = cfg["data"]["root_ref"]
        pat = cfg["data"]["filename_pattern"]
        start = int(cfg["data"]["index_start"])
        total = int(cfg["data"]["count"])
        
        ids = list(range(start, start + total))
        
        n_train = int(cfg["data"]["splits"]["train"])
        n_val = int(cfg["data"]["splits"]["val"])
        n_test = int(cfg["data"]["splits"]["test"])
        assert n_train + n_val + n_test == total
        
        rng = random.Random(cfg["seed"])
        rng.shuffle(ids)
        
        if split == "train":
            self.ids = ids[:n_train]
        elif split == "val":
            self.ids = ids[n_train:n_train+n_val]
        else:
            self.ids = ids[n_train+n_val:]
        
        keep = []
        for i in self.ids:
            fname = pat.format(i)
            p_cam = os.path.join(self.root_cam, fname)
            p_ref = os.path.join(self.root_ref, fname)
            if os.path.exists(p_cam) and os.path.exists(p_ref):
                keep.append((p_cam, p_ref))
        
        missing = len(self.ids) - len(keep)
        if missing > 0:
            print(f"[{split}] Warning: {missing} pairs missing; using {len(keep)}.")
        self.entries = keep
    
    def __len__(self):
        return len(self.entries)
    
    def gamma_corr(self, img, gamma=1.0):
        return clamp01(img ** gamma)
    
    def simulate_underwater(self, clean_img):
        if not self.train or random.random() > self.aug_cfg["underwater_prob"]:
            return clean_img
        
        beta = random.uniform(*self.aug_cfg["underwater_beta"])
        A_val = random.uniform(*self.aug_cfg["underwater_A"])
        
        C, H, W = clean_img.shape
        depth = torch.linspace(1.0, 5.0, H).view(1, H, 1).expand(C, H, W)
        depth = depth + torch.rand(C, H, W) * 0.5
        
        t = torch.exp(-beta * depth)
        degraded = clean_img * t + A_val * (1 - t)
        
        return torch.clamp(degraded, 0.0, 1.0)
    
    def paired_augs(self, cam, ref):
        HFLIP = self.aug_cfg.get("hflip_p", 0.5)
        ROT = int(self.aug_cfg.get("rotate_deg", 10))
        
        if self.train and random.random() < HFLIP:
            cam = ImageOps.mirror(cam)
            ref = ImageOps.mirror(ref)
        
        if self.train and ROT > 0:
            deg = random.uniform(-ROT, ROT)
            cam = cam.rotate(deg, resample=Image.BILINEAR)
            ref = ref.rotate(deg, resample=Image.BILINEAR)
        
        if self.train and self.aug_cfg.get("paired_crop_240", True):
            w, h = cam.size
            if w >= 256 and h >= 256:
                nw = nh = 240
                x = random.randint(0, w - nw)
                y = random.randint(0, h - nh)
                cam = cam.crop((x, y, x+nw, y+nh)).resize((self.size, self.size), Image.BILINEAR)
                ref = ref.crop((x, y, x+nw, y+nh)).resize((self.size, self.size), Image.BILINEAR)
        else:
            cam = cam.resize((self.size, self.size), Image.BILINEAR)
            ref = ref.resize((self.size, self.size), Image.BILINEAR)
        
        return cam, ref
    
    def __getitem__(self, idx):
        p_cam, p_ref = self.entries[idx]
        cam = Image.open(p_cam).convert("RGB")
        ref = Image.open(p_ref).convert("RGB")
        cam, ref = self.paired_augs(cam, ref)
        
        to_t = transforms.ToTensor()
        ref_l = to_linear(to_t(ref), self.gamma)
        
        if self.train and random.random() < self.aug_cfg["underwater_prob"]:
            cam_l = self.simulate_underwater(ref_l)
            cam_s = to_srgb(cam_l, self.gamma)
        else:
            cam_s = to_t(cam)
            cam_l = to_linear(cam_s, self.gamma)
        
        if self.train:
            choices = self.aug_cfg["gamma_train_choices"]
            probs = self.aug_cfg["gamma_train_probs"]
            g = random.choices(choices, probs)[0]
            if abs(g - 1.0) > 1e-6:
                cam_s = self.gamma_corr(cam_s, g)
                cam_l = to_linear(cam_s, self.gamma)
        else:
            cam_s = self.gamma_corr(cam_s, self.aug_cfg.get("gamma_eval", 1.0))
            cam_l = to_linear(cam_s, self.gamma)
        
        if self.train and random.random() < self.aug_cfg["jitter_cam"]["prob"]:
            b = self.aug_cfg["jitter_cam"]["brightness"]
            c = self.aug_cfg["jitter_cam"]["contrast"]
            cam_s = transforms.ColorJitter(brightness=b, contrast=c)(cam_s)
            cam_l = to_linear(cam_s, self.gamma)
        
        return {
            "I_cam_srgb": cam_s,
            "I_cam_lin": cam_l,
            "I_clean_lin": ref_l,
            "cam_path": p_cam,
            "ref_path": p_ref
        }

def get_loaders(cfg):
    ds_train = PairedDataset("train", cfg)
    ds_val = PairedDataset("val", cfg)
    ds_test = PairedDataset("test", cfg)
    
    bs = cfg["data"]["loader"]["batch_size"]
    nw = cfg["data"]["loader"]["num_workers"]
    pin = cfg["data"]["loader"]["pin_memory"]
    
    dl_train = DataLoader(ds_train, batch_size=bs, shuffle=True, num_workers=nw, pin_memory=pin, drop_last=True, persistent_workers = True)
    dl_val = DataLoader(ds_val, batch_size=bs, shuffle=False, num_workers=nw, pin_memory=pin, persistent_workers = True)
    dl_test = DataLoader(ds_test, batch_size=bs, shuffle=False, num_workers=nw, pin_memory=pin, persistent_workers = True)
    
    return dl_train, dl_val, dl_test

# ============================================================================
# MODEL COMPONENTS
# ============================================================================

class CBAM(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid()
        )
        self.sa = nn.Sequential(
            nn.Conv2d(2, 1, 7, padding=3),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        ch_attn = self.ca(x)
        x = x * ch_attn
        
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out = torch.max(x, dim=1, keepdim=True)[0]
        spatial = torch.cat([avg_out, max_out], dim=1)
        sp_attn = self.sa(spatial)
        return x * sp_attn

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, norm="group", act="silu"):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            gn(out_ch) if norm == "group" else nn.BatchNorm2d(out_ch),
            act_fn(act),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            gn(out_ch) if norm == "group" else nn.BatchNorm2d(out_ch),
            act_fn(act)
        )
    
    def forward(self, x):
        return self.conv(x)

class UNetS_Physics(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        chs = cfg["model"]["physics"]["unet_channels"]
        norm = cfg["model"]["physics"]["norm"]
        act = cfg["model"]["physics"]["act"]
        self.t_min = cfg["model"]["physics"]["t_min"]
        self.zw_dim = cfg["model"]["physics"]["z_w_dim"]
        
        self.e1 = nn.Sequential(ConvBlock(3, chs[0], norm, act), CBAM(chs[0]))
        self.p1 = nn.MaxPool2d(2)
        self.e2 = nn.Sequential(ConvBlock(chs[0], chs[1], norm, act), CBAM(chs[1]))
        self.p2 = nn.MaxPool2d(2)
        self.e3 = nn.Sequential(ConvBlock(chs[1], chs[2], norm, act), CBAM(chs[2]))
        self.p3 = nn.MaxPool2d(2)
        self.e4 = nn.Sequential(ConvBlock(chs[2], chs[3], norm, act), CBAM(chs[3]))
        self.p4 = nn.MaxPool2d(2)
        self.e5 = nn.Sequential(ConvBlock(chs[3], chs[4], norm, act), CBAM(chs[4]))
        
        self.t_head = nn.Sequential(
            nn.Conv2d(chs[4], 64, 3, padding=1),
            act_fn(act),
            nn.Conv2d(64, 3, 1)
        )
        
        self.A_pool = nn.AdaptiveAvgPool2d(1)
        self.A_mlp = nn.Sequential(
            nn.Conv2d(chs[4], 128, 1),
            act_fn(act),
            nn.Conv2d(128, 3, 1)
        )
        
        self.zw_pool = nn.AdaptiveAvgPool2d(1)
        self.zw_mlp = nn.Sequential(
            nn.Linear(chs[4], 256),
            nn.SiLU(),
            nn.Linear(256, self.zw_dim)
        )
        
        self.sigma_head = nn.Sequential(
            nn.Conv2d(chs[4], 16, 3, padding=1),
            act_fn(act),
            nn.Conv2d(16, 1, 1)
        )
    
    def forward(self, I_cam_lin):
        x1 = self.e1(I_cam_lin)
        x2 = self.e2(self.p1(x1))
        x3 = self.e3(self.p2(x2))
        x4 = self.e4(self.p3(x3))
        x5 = self.e5(self.p4(x4))
        
        t_logits = self.t_head(x5)
        t = torch.sigmoid(t_logits) * (1 - 2*self.t_min) + self.t_min
        
        A = self.A_mlp(self.A_pool(x5))
        zw = self.zw_mlp(self.zw_pool(x5).flatten(1))
        sigma2 = F.softplus(self.sigma_head(x5))
        
        return {"t": t, "A": A, "zw": zw, "sigma2": sigma2}

class TRIDENT(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.gamma = cfg["model"]["colorspace"]["gamma"]
        self.t_min = cfg["model"]["physics"]["t_min"]
        self.phys = UNetS_Physics(cfg)
        
        self.zw_to_t_bias = nn.Sequential(
            nn.Linear(cfg["model"]["physics"]["z_w_dim"], 64),
            nn.SiLU(),
            nn.Linear(64, 3)
        )
        
        with torch.no_grad():
            self.zw_to_t_bias[-1].weight.mul_(0.01)
            self.zw_to_t_bias[-1].bias.zero_()
    
    def forward(self, batch):
        I_cam_l = batch["I_cam_lin"].to(DEVICE)
        
        ph = self.phys(I_cam_l)
        t, A, zw, sigma2 = ph["t"], ph["A"], ph["zw"], ph["sigma2"]
        
        # In TRIDENT.forward():
        t_bias = self.zw_to_t_bias(zw).unsqueeze(-1).unsqueeze(-1)
        t_bias = torch.clamp(t_bias, -0.02, 0.02)  # Hard constraint
        t = t + t_bias
        t = torch.clamp(t, min=self.t_min, max=1.0)
        
        H, W = I_cam_l.shape[-2], I_cam_l.shape[-1]
        t = F.interpolate(t, size=(H, W), mode="bilinear", align_corners=False)
        sigma2 = F.interpolate(sigma2, size=(H, W), mode="bilinear", align_corners=False)
        
        A_b = A.expand(-1, -1, H, W)
        A_b = torch.clamp(A_b, 0.0, 1.0)

        # More stable restoration
        numerator = torch.clamp(I_cam_l - A_b * (1 - t), min=1e-8)
        denominator = torch.clamp(t, min=1e-8)

        # Avoid log(very small numbers) by capping the ratio
        ratio = numerator / denominator
        I_hat_l = torch.clamp(ratio, 0.0, 5.0)  # Prevent extreme values
        
        # Final clamping
        I_hat_l = torch.clamp(I_hat_l, 0.0, 1.0)
        
        I_phys_l = None
        if "I_clean_lin" in batch:
            I_ref_l = batch["I_clean_lin"].to(DEVICE)
            I_phys_l = I_ref_l * t + A_b * (1 - t)
            I_phys_l = torch.clamp(I_phys_l, 0.0, 1.0)
        
        return {
            "I_hat_lin": I_hat_l,
            "I_phys_lin": I_phys_l,
            "t": t,
            "A": A,
            "zw": zw,
            "sigma2": sigma2,
            "t_bias": t_bias
        }

# ============================================================================
# EMA
# ============================================================================

class EMA:
    def __init__(self, model, decay=0.999):
        self.ema = {k: v.detach().clone().float() for k, v in model.state_dict().items()}
        self.decay = decay
    
    @torch.no_grad()
    def update(self, model):
        sd = model.state_dict()
        for k in self.ema.keys():
            self.ema[k].mul_(self.decay).add_(sd[k].detach().float(), alpha=1 - self.decay)
    
    def copy_to(self, model):
        model.load_state_dict(self.ema, strict=True)

# ============================================================================
# LOSSES
# ============================================================================

class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features
        self.features = nn.Sequential(*list(vgg[:9])).eval()
        for p in self.features.parameters():
            p.requires_grad = False
        
        # Normalize features to [0,1] range
        self.register_buffer('vgg_mean', torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer('vgg_std', torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
    
    def forward(self, pred, target):
        # Normalize input
        pred = (pred - self.vgg_mean) / self.vgg_std
        target = (target - self.vgg_mean) / self.vgg_std
        
        pred_feat = self.features(pred)
        target_feat = self.features(target)
        
        # Scale down
        return F.l1_loss(pred_feat, target_feat)  # Let the weight handle it

def frequency_loss(pred, target):
    pred_fft = torch.fft.fft2(pred, dim=(-2, -1), norm='ortho')  # Orthonormal
    target_fft = torch.fft.fft2(target, dim=(-2, -1), norm='ortho')
    
    pred_mag = torch.log1p(torch.abs(pred_fft))
    target_mag = torch.log1p(torch.abs(target_fft))
    
    # Better high-frequency emphasis
    H, W = pred.shape[-2:]
    cy, cx = H // 2, W // 2
    y, x = torch.meshgrid(torch.arange(H, device=pred.device), 
                          torch.arange(W, device=pred.device), indexing='ij')
    dist = torch.sqrt((y - cy)**2 + (x - cx)**2)
    mask = (dist > min(H, W) * 0.2).float()  # High-freq mask
    
    freq_loss = F.l1_loss(pred_mag * mask, target_mag * mask)
    return torch.clamp(freq_loss, 0.0, 1.0)

PERCEPTUAL_LOSS = None

def ssim_srgb(x, y, C1=0.01**2, C2=0.03**2):
    mu_x = F.avg_pool2d(x, 11, stride=1, padding=5)
    mu_y = F.avg_pool2d(y, 11, stride=1, padding=5)
    sigma_x = F.avg_pool2d(x*x, 11, 1, 5) - mu_x*mu_x
    sigma_y = F.avg_pool2d(y*y, 11, 1, 5) - mu_y*mu_y
    sigma_xy = F.avg_pool2d(x*y, 11, 1, 5) - mu_x*mu_y
    ssim_n = (2*mu_x*mu_y + C1) * (2*sigma_xy + C2)
    ssim_d = (mu_x*mu_x + mu_y*mu_y + C1) * (sigma_x + sigma_y + C2)
    ssim_map = ssim_n / (ssim_d + 1e-8)
    return ssim_map.mean()

def psnr_srgb(x, y):
    mse = F.mse_loss(x, y)
    return 10 * torch.log10(1.0 / (mse + 1e-8))

def tv_loss_t(t):
    dx = t[:, :, :, 1:] - t[:, :, :, :-1]
    dy = t[:, :, 1:, :] - t[:, :, :-1, :]
    return (dx.abs().mean() + dy.abs().mean())

def A_prior_loss(A, I_cam_lin):
    mean_cam = I_cam_lin.mean(dim=(2, 3), keepdim=True)
    return F.mse_loss(A, mean_cam)

def heteroscedastic_l1(err_lin, sigma2):
    eps = CFG["loss"]["heteroscedastic"]["epsilon"]  # Use config value (1e-6)
    sigma2_safe = torch.clamp(sigma2, min=eps, max=10.0)
    # Add stability to log
    return ((err_lin / sigma2_safe) + torch.log(sigma2_safe + eps)).mean()

def compute_losses(batch, out, epoch, hetero_on):
    W = CFG["loss"]["weights"]
    
    I_hat_l = out["I_hat_lin"]
    I_phys_l = out["I_phys_lin"]
    I_cam_l = batch["I_cam_lin"].to(DEVICE)
    I_ref_l = batch["I_clean_lin"].to(DEVICE)

    L_recon = F.l1_loss(I_hat_l, I_ref_l)
    L_recon = torch.clamp(L_recon, 0.0, 1.0)

    L_phys  = torch.tensor(0., device=DEVICE)
    if I_phys_l is not None:
        L_phys = F.l1_loss(I_phys_l, I_cam_l)
        L_phys = torch.clamp(L_phys, 0.0, 1.0)

    L_tv    = tv_loss_t(out["t"])
    L_Ap    = A_prior_loss(out["A"], I_cam_l)

    I_hat_s = to_srgb(I_hat_l, CFG["model"]["colorspace"]["gamma"])
    I_ref_s = to_srgb(I_ref_l, CFG["model"]["colorspace"]["gamma"])

    L_perc  = PERCEPTUAL_LOSS(I_hat_s, I_ref_s)
    L_perc  = torch.clamp(L_perc, 0.0, 1.0)

    L_freq  = frequency_loss(I_hat_l, I_ref_l)
    L_freq  = torch.clamp(L_freq, 0.0, 1.0)

    L_t_bias= F.mse_loss(out["t_bias"], torch.zeros_like(out["t_bias"]))

    # FIXED: Compute heteroscedastic loss
    L_hetero = torch.tensor(0., device=DEVICE)
    if hetero_on:
        err_lin = (I_hat_l - I_ref_l).abs()
        L_hetero = heteroscedastic_l1(err_lin, out["sigma2"])
        L_hetero = torch.clamp(L_hetero, 0.0, 1.0)

    # Build the final loss
    loss = L_recon + W["phys"]*L_phys + W["tv_t"]*L_tv + W["A_prior"]*L_Ap + \
        W["t_bias"]*L_t_bias + W["perceptual"]*L_perc + W["frequency"]*L_freq

    if hetero_on:
        loss = loss + W["hetero"] * L_hetero

    loss = torch.clamp(loss, 0.0, 10.0) 
    
    if torch.isnan(loss) or torch.isinf(loss):
        print(f"\n{'='*50}")
        print(f"NaN/Inf detected in loss calculation!")
        print(f"Recon={L_recon.item():.4f}, Phys={L_phys.item():.4f}")
        print(f"Perc={L_perc.item():.4f}, Freq={L_freq.item():.4f}")
        print(f"t range: [{out['t'].min().item():.4f}, {out['t'].max().item():.4f}]")
        print(f"A range: [{out['A'].min().item():.4f}, {out['A'].max().item():.4f}]")
        print(f"I_hat range: [{I_hat_l.min().item():.4f}, {I_hat_l.max().item():.4f}]")
        print(f"Falling back to L_recon + L_phys only")
        print(f"{'='*50}\n")
        loss = L_recon + W["phys"] * L_phys
    
    train_ssim = ssim_srgb(I_hat_s, I_ref_s)
    
    logs = {
        "L_total": loss.item(),
        "Recon": L_recon.item(),
        "Phys": L_phys.item(),
        "TVt": L_tv.item(),
        "Apr": L_Ap.item(),
        "Hetero": L_hetero.item() if hetero_on else 0.0,
        "Perc": L_perc.item(),
        "Freq": L_freq.item(),
        "t_bias": out["t_bias"].abs().mean().item(),
        "SSIM": train_ssim.item()
    }
    
    return loss, logs

# ============================================================================
# TRAINING UTILITIES
# ============================================================================

class HeteroWatchdog:
    def __init__(self, cfg):
        wd = cfg["loss"]["heteroscedastic"]["watchdog"]
        self.th = wd["mean_sigma2_thresh"]
        self.req = wd["over_thresh_required"]
        self.cool = wd["cooldown_epochs"]
        self.count = 0
        self.cooldown = 0
    
    def update(self, mean_sigma2):
        if mean_sigma2 > self.th:
            self.count += 1
        else:
            self.count = 0
        
        triggered = False
        if self.count >= self.req:
            self.cooldown = self.cool
            self.count = 0
            triggered = True
        
        if self.cooldown > 0:
            self.cooldown -= 1
            return False, True
        
        return triggered, False

def cosine_lr(optimizer, step, total_steps, base_lr, min_lr, warmup_steps):
    if step < warmup_steps:
        lr = base_lr * step / max(1, warmup_steps)
    else:
        t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))
    
    for pg in optimizer.param_groups:
        pg["lr"] = lr
    
    return lr

def save_checkpoint(state_dict, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cpu_state = {k: v.cpu().clone() for k, v in state_dict.items()}
    torch.save(cpu_state, path)
    print(f"✓ Saved: {path}")

# ============================================================================
# TRAINING
# ============================================================================

def train_val_setup():
    global PERCEPTUAL_LOSS
    PERCEPTUAL_LOSS = VGGPerceptualLoss().to(DEVICE)
    
    dl_train, dl_val, dl_test = get_loaders(CFG)
    model = TRIDENT(CFG).to(DEVICE)
    
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=CFG["schedules"]["lr"]["base"],
        betas=tuple(CFG["optimizer"]["betas"]),
        weight_decay=CFG["optimizer"]["weight_decay"],
        eps=1e-8
    )
    
    scaler = torch.amp.GradScaler('cuda', enabled=(CFG["amp"] and DEVICE.startswith("cuda")))
    ema = EMA(model, decay=CFG["ema"]["decay"]) if CFG["ema"]["enabled"] else None
    
    return model, opt, scaler, ema, dl_train, dl_val, dl_test

def run_training():
    model, opt, scaler, ema, dl_train, dl_val, dl_test = train_val_setup()
    steps_per_epoch = len(dl_train)
    total_steps = steps_per_epoch * CFG["schedules"]["epochs"]
    hetero_wd = HeteroWatchdog(CFG)
    
    best_ssim = -1.0
    best_l1 = float('inf')
    step = 0
    
    for epoch in range(1, CFG["schedules"]["epochs"] + 1):
        torch.cuda.empty_cache()
        gc.collect()
        
        model.train()
        epoch_logs = []
        pbar = tqdm(dl_train, desc=f"Epoch {epoch}/{CFG['schedules']['epochs']} [train]", leave=True)
        
        for b, batch in enumerate(pbar):
            lr = cosine_lr(opt, step, total_steps,
                          CFG["schedules"]["lr"]["base"],
                          CFG["schedules"]["lr"]["min"],
                          CFG["schedules"]["warmup_steps"])
            
            opt.zero_grad(set_to_none=True)
            
            ctx = torch.amp.autocast('cuda', enabled=(CFG["amp"] and DEVICE.startswith("cuda")))
            with ctx:
                out = model(batch)
                hetero_on = (epoch >= CFG["schedules"]["hetero"]["start_epoch"]) and (hetero_wd.cooldown == 0)
                loss, logs = compute_losses(batch, out, epoch, hetero_on)
            
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"\nSkipping batch {b} due to NaN/Inf loss")
                opt.zero_grad(set_to_none=True)
                step += 1
                continue
            
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["optimizer"]["grad_clip"])
            scaler.step(opt)
            scaler.update()
            
            # CRITICAL FIX: Update EMA after each successful step
            if CFG["ema"]["enabled"] and ema is not None:
                ema.update(model)
            
            if step % 100 == 0:
                print(f"\n  Loss Scale: {scaler.get_scale():.2f}")
                if scaler.get_scale() < 1.0:
                    print("  WARNING: Loss scale dropped below 1.0!")
            
            step += 1
            pbar.set_postfix({
                "lr": f"{lr:.2e}",
                "L": f"{logs['L_total']:.3f}",
                "Recon": f"{logs['Recon']:.3f}",
                "SSIM": f"{logs['SSIM']:.3f}",
                "Perc": f"{logs['Perc']:.3f}"
            })
            
            if step % 200 == 0:
                torch.cuda.empty_cache()
        
        torch.cuda.empty_cache()
        gc.collect()
        
        # Validation phase
        model.eval()
        val_model = TRIDENT(CFG).to(DEVICE)
        if CFG["ema"]["enabled"] and ema is not None:
            ema.copy_to(val_model)
        else:
            val_model.load_state_dict(model.state_dict())
        val_model.eval()
        
        ssim_vals, l1_vals, sigma_means = [], [], []
        pbar_val = tqdm(dl_val, desc=f"Epoch {epoch} [val]", leave=False)
        
        with torch.no_grad():
            for batch in pbar_val:
                out = val_model(batch)
                I_hat_s = to_srgb(out["I_hat_lin"], CFG["model"]["colorspace"]["gamma"])
                I_ref_s = to_srgb(batch["I_clean_lin"].to(DEVICE), CFG["model"]["colorspace"]["gamma"])
                
                ssim_vals.append(ssim_srgb(I_hat_s, I_ref_s).item())
                l1_vals.append(F.l1_loss(out["I_hat_lin"], batch["I_clean_lin"].to(DEVICE)).item())
                sigma_means.append(out["sigma2"].mean().item())
                
                pbar_val.set_postfix({
                    "SSIM": f"{np.mean(ssim_vals):.3f}",
                    "L1": f"{np.mean(l1_vals):.3f}"
                })
        
        val_ssim = float(np.mean(ssim_vals))
        val_l1 = float(np.mean(l1_vals))
        mean_sigma = float(np.mean(sigma_means))
        
        del val_model
        torch.cuda.empty_cache()
        gc.collect()
        
        trig, disabled = hetero_wd.update(mean_sigma)
        
        if val_ssim > best_ssim:
            best_ssim = val_ssim
            print(f"  → New best SSIM: {val_ssim:.4f}")
            if CFG["ema"]["enabled"] and ema is not None:
                save_checkpoint(ema.ema, "checkpointsphy/best_ssim.pt")
            else:
                save_checkpoint(model.state_dict(), "checkpointsphy/best_ssim.pt")
        
        if val_l1 < best_l1:
            best_l1 = val_l1
            print(f"  → New best L1: {val_l1:.4f}")
            if CFG["ema"]["enabled"] and ema is not None:
                save_checkpoint(ema.ema, "checkpointsphy/best_l1.pt")
            else:
                save_checkpoint(model.state_dict(), "checkpointsphy/best_l1.pt")
        
        print(f"[Epoch {epoch}] SSIM={val_ssim:.4f}  L1={val_l1:.4f}  "
              f"meanσ²={mean_sigma:.3f}  heteroCooldown={hetero_wd.cooldown}")
    
    # Save final checkpoint (non-EMA)
    save_checkpoint(model.state_dict(), "checkpointsphy/last.pt")
    
    print("\n" + "="*70)
    print(f"Training complete!")
    print(f"Best SSIM: {best_ssim:.4f}")
    print(f"Best L1: {best_l1:.4f}")
    print("="*70 + "\n")

# ============================================================================
# EVALUATION
# ============================================================================

def evaluate_and_export():
    global PERCEPTUAL_LOSS
    if PERCEPTUAL_LOSS is None:
        PERCEPTUAL_LOSS = VGGPerceptualLoss().to(DEVICE)
    
    model = TRIDENT(CFG).to(DEVICE)
    
    checkpoint_path = "checkpointsphy/best_ssim.pt"
    if not os.path.exists(checkpoint_path):
        print(f"[ERROR] Checkpoint not found: {checkpoint_path}")
        return
    
    sd = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(sd, strict=True)
    model.eval()
    
    print(f"Loaded checkpoint: {checkpoint_path}")
    
    _, _, dl_test = get_loaders(CFG)
    
    base_export_dir = CFG["logging"]["export_test"]["dir"]
    os.makedirs(base_export_dir, exist_ok=True)
    
    if CFG["logging"]["export_test"].get("save_images", False):
        img_export_dir = os.path.join(base_export_dir, "images")
        os.makedirs(img_export_dir, exist_ok=True)
    
    SSIMs, PSNRs, L1s = [], [], []
    
    print("\n" + "="*70)
    print("EVALUATION ON TEST SET")
    print("="*70)
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dl_test, desc="Evaluating and Exporting")):
            out = model(batch)
            I_hat_s = to_srgb(out["I_hat_lin"], CFG["model"]["colorspace"]["gamma"])
            I_ref_s = to_srgb(batch["I_clean_lin"].to(DEVICE), CFG["model"]["colorspace"]["gamma"])
            
            if CFG["logging"]["export_test"].get("save_images", False):
                I_cam_s = batch["I_cam_srgb"].to(DEVICE)
                cam_paths = batch["cam_path"]
                
                for i in range(I_hat_s.shape[0]):
                    base_name = os.path.basename(cam_paths[i])
                    comparison_grid = torch.cat([I_cam_s[i], I_hat_s[i], I_ref_s[i]], dim=-1)
                    
                    save_path = os.path.join(img_export_dir, base_name)
                    from torchvision.utils import save_image
                    save_image(comparison_grid, save_path)
            
            SSIMs.append(ssim_srgb(I_hat_s, I_ref_s).item())
            PSNRs.append(psnr_srgb(I_hat_s, I_ref_s).item())
            L1s.append(F.l1_loss(out["I_hat_lin"], batch["I_clean_lin"].to(DEVICE)).item())
            
            if CFG["logging"]["export_test"]["tAzw"]:
                batch_id = os.path.basename(batch["cam_path"][0]).split('.')[0]
                np.savez_compressed(
                    os.path.join(base_export_dir, f"physics_{batch_id}.npz"),
                    t=out["t"].detach().cpu().numpy(),
                    A=out["A"].detach().cpu().numpy(),
                    zw=out["zw"].detach().cpu().numpy(),
                    sigma2=out["sigma2"].detach().cpu().numpy()
                )
    
    print("\n" + "="*70)
    print("TEST RESULTS")
    print("="*70)
    print(f"SSIM:  {np.mean(SSIMs):.4f} ± {np.std(SSIMs):.4f}")
    print(f"PSNR:  {np.mean(PSNRs):.2f} ± {np.std(PSNRs):.2f} dB")
    print(f"L1:    {np.mean(L1s):.4f} ± {np.std(L1s):.4f}")
    print("="*70)
    
    metrics = {
        "ssim_mean": float(np.mean(SSIMs)),
        "ssim_std": float(np.std(SSIMs)),
        "psnr_mean": float(np.mean(PSNRs)),
        "psnr_std": float(np.std(PSNRs)),
        "l1_mean": float(np.mean(L1s)),
        "l1_std": float(np.std(L1s))
    }
    
    metrics_path = os.path.join(base_export_dir, "test_metrics.json")
    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent=2)
    
    print(f"\n✓ Metrics saved to: {metrics_path}")
    
    if CFG["logging"]["export_test"].get("save_images", False):
        print(f"✓ Images saved to: {img_export_dir}")
    
    if CFG["logging"]["export_test"]["tAzw"]:
        print(f"✓ Physics parameters saved to: {base_export_dir}")

# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    import sys
    
    print("\n" + "="*70)
    print("TRIDENT PHYSICS-ONLY MODEL - ENHANCED")
    print("="*70 + "\n")
    
    if len(sys.argv) > 1 and sys.argv[1] == "eval":
        print("Running evaluation only...\n")
        evaluate_and_export()
    else:
        print("Starting training...\n")
        run_training()
        print("\nStarting evaluation...\n")
        evaluate_and_export()
    
    print("\n" + "="*70)
    print("ALL DONE!")
    print("="*70 + "\n")

Using device: cuda

TRIDENT PHYSICS-ONLY MODEL - ENHANCED

Starting training...



Epoch 1/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00

  Loss Scale: 65536.00


Epoch 1 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7471
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.1956
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 1] SSIM=0.7471  L1=0.1956  meanσ²=0.753  heteroCooldown=0


Epoch 2/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00


Epoch 2 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7497
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.1921
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 2] SSIM=0.7497  L1=0.1921  meanσ²=0.752  heteroCooldown=1


Epoch 3/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00

  Loss Scale: 65536.00


Epoch 3 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7538
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.1861
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 3] SSIM=0.7538  L1=0.1861  meanσ²=0.749  heteroCooldown=0


Epoch 4/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00


Epoch 4 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7600
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.1766
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 4] SSIM=0.7600  L1=0.1766  meanσ²=0.745  heteroCooldown=1


Epoch 5/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00


Epoch 5 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7670
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.1654
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 5] SSIM=0.7670  L1=0.1654  meanσ²=0.743  heteroCooldown=0


Epoch 6/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00

  Loss Scale: 65536.00


Epoch 6 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7741
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.1532
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 6] SSIM=0.7741  L1=0.1532  meanσ²=0.740  heteroCooldown=1


Epoch 7/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00


Epoch 7 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7808
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.1409
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 7] SSIM=0.7808  L1=0.1409  meanσ²=0.738  heteroCooldown=0


Epoch 8/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00

  Loss Scale: 65536.00


Epoch 8 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7855
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.1288
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 8] SSIM=0.7855  L1=0.1288  meanσ²=0.737  heteroCooldown=1


Epoch 9/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00


Epoch 9 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7889
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.1182
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 9] SSIM=0.7889  L1=0.1182  meanσ²=0.737  heteroCooldown=0


Epoch 10/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00


Epoch 10 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7905
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.1103
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 10] SSIM=0.7905  L1=0.1103  meanσ²=0.738  heteroCooldown=1


Epoch 11/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00

  Loss Scale: 65536.00


Epoch 11 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7919
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.1045
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 11] SSIM=0.7919  L1=0.1045  meanσ²=0.739  heteroCooldown=0


Epoch 12/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00


Epoch 12 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7937
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.1003
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 12] SSIM=0.7937  L1=0.1003  meanσ²=0.740  heteroCooldown=1


Epoch 13/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00

  Loss Scale: 65536.00


Epoch 13 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7960
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0971
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 13] SSIM=0.7960  L1=0.0971  meanσ²=0.742  heteroCooldown=0


Epoch 14/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00


Epoch 14 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.7987
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0946
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 14] SSIM=0.7987  L1=0.0946  meanσ²=0.744  heteroCooldown=1


Epoch 15/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00


Epoch 15 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8016
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0925
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 15] SSIM=0.8016  L1=0.0925  meanσ²=0.746  heteroCooldown=0


Epoch 16/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00

  Loss Scale: 131072.00


Epoch 16 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8047
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0907
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 16] SSIM=0.8047  L1=0.0907  meanσ²=0.748  heteroCooldown=1


Epoch 17/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00


Epoch 17 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8078
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0889
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 17] SSIM=0.8078  L1=0.0889  meanσ²=0.751  heteroCooldown=0


Epoch 18/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00

  Loss Scale: 131072.00


Epoch 18 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8110
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0871
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 18] SSIM=0.8110  L1=0.0871  meanσ²=0.753  heteroCooldown=1


Epoch 19/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00


Epoch 19 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8141
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0853
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 19] SSIM=0.8141  L1=0.0853  meanσ²=0.755  heteroCooldown=0


Epoch 20/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00


Epoch 20 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8169
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0835
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 20] SSIM=0.8169  L1=0.0835  meanσ²=0.757  heteroCooldown=1


Epoch 21/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00

  Loss Scale: 131072.00


Epoch 21 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8197
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0817
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 21] SSIM=0.8197  L1=0.0817  meanσ²=0.759  heteroCooldown=0


Epoch 22/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00


Epoch 22 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8222
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0800
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 22] SSIM=0.8222  L1=0.0800  meanσ²=0.761  heteroCooldown=1


Epoch 23/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00

  Loss Scale: 131072.00


Epoch 23 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8244
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0783
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 23] SSIM=0.8244  L1=0.0783  meanσ²=0.762  heteroCooldown=0


Epoch 24/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00


Epoch 24 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8262
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0768
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 24] SSIM=0.8262  L1=0.0768  meanσ²=0.763  heteroCooldown=1


Epoch 25/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00


Epoch 25 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8278
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0754
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 25] SSIM=0.8278  L1=0.0754  meanσ²=0.765  heteroCooldown=0


Epoch 26/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00

  Loss Scale: 131072.00


Epoch 26 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8292
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0741
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 26] SSIM=0.8292  L1=0.0741  meanσ²=0.766  heteroCooldown=1


Epoch 27/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00


Epoch 27 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8303
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0730
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 27] SSIM=0.8303  L1=0.0730  meanσ²=0.767  heteroCooldown=0


Epoch 28/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 131072.00

  Loss Scale: 131072.00


Epoch 28 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8312
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0719
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 28] SSIM=0.8312  L1=0.0719  meanσ²=0.767  heteroCooldown=1


Epoch 29/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 262144.00


Epoch 29 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8320
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0710
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 29] SSIM=0.8320  L1=0.0710  meanσ²=0.768  heteroCooldown=0


Epoch 30/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 262144.00


Epoch 30 [val]:   0%|          | 0/14 [00:00<?, ?it/s]

  → New best SSIM: 0.8326
✓ Saved: checkpointsphy/best_ssim.pt
  → New best L1: 0.0702
✓ Saved: checkpointsphy/best_l1.pt
[Epoch 30] SSIM=0.8326  L1=0.0702  meanσ²=0.769  heteroCooldown=1
✓ Saved: checkpointsphy/last.pt

Training complete!
Best SSIM: 0.8326
Best L1: 0.0702


Starting evaluation...

Loaded checkpoint: checkpointsphy/best_ssim.pt

EVALUATION ON TEST SET


Evaluating and Exporting:   0%|          | 0/14 [00:00<?, ?it/s]


TEST RESULTS
SSIM:  0.8377 ± 0.0077
PSNR:  18.20 ± 0.36 dB
L1:    0.0697 ± 0.0026

✓ Metrics saved to: outputs/test_exports/test_metrics.json
✓ Images saved to: outputs/test_exports/images
✓ Physics parameters saved to: outputs/test_exports

ALL DONE!



In [None]:
"""
Extended Training - Continue from Best SSIM Checkpoint
Loads the best_ssim.pt checkpoint and trains for another 30 epochs
"""

import os
import torch
import gc

# ============================================================================
# EXTENDED TRAINING CONFIGURATION
# ============================================================================

EXTENDED_CFG = {
    "checkpoint_path": "checkpointsphy/best_ssim.pt",
    "extended_epochs": 30,
    "save_dir": "checkpointsphy/extended",
    "lr_schedule": {
        "base": 1.0e-4,  # Lower learning rate for fine-tuning
        "min": 5.0e-9,
        "warmup_steps": 1000
    }
}

# ============================================================================
# EXTENDED TRAINING FUNCTION
# ============================================================================

def run_extended_training():
    """
    Continue training from the best SSIM checkpoint for another 30 epochs
    """
    global PERCEPTUAL_LOSS
    
    # Initialize perceptual loss if needed
    if PERCEPTUAL_LOSS is None:
        PERCEPTUAL_LOSS = VGGPerceptualLoss().to(DEVICE)
    
    # Create save directory
    os.makedirs(EXTENDED_CFG["save_dir"], exist_ok=True)
    
    # Load checkpoint
    checkpoint_path = EXTENDED_CFG["checkpoint_path"]
    if not os.path.exists(checkpoint_path):
        print(f"[ERROR] Checkpoint not found: {checkpoint_path}")
        print("Please run training first to generate the best_ssim.pt checkpoint")
        return
    
    print("\n" + "="*70)
    print("EXTENDED TRAINING FROM BEST SSIM CHECKPOINT")
    print("="*70)
    print(f"Loading checkpoint: {checkpoint_path}")
    
    # Initialize model and load weights
    model = TRIDENT(CFG).to(DEVICE)
    sd = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(sd, strict=True)
    print(f"✓ Checkpoint loaded successfully")
    
    # Setup optimizer
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=EXTENDED_CFG["lr_schedule"]["base"],
        betas=tuple(CFG["optimizer"]["betas"]),
        weight_decay=CFG["optimizer"]["weight_decay"],
        eps=1e-8
    )
    
    # Setup gradient scaler and EMA
    scaler = torch.amp.GradScaler('cuda', enabled=(CFG["amp"] and DEVICE.startswith("cuda")))
    ema = EMA(model, decay=CFG["ema"]["decay"]) if CFG["ema"]["enabled"] else None
    
    # Get data loaders
    dl_train, dl_val, dl_test = get_loaders(CFG)
    
    # Training setup
    steps_per_epoch = len(dl_train)
    total_steps = steps_per_epoch * EXTENDED_CFG["extended_epochs"]
    hetero_wd = HeteroWatchdog(CFG)
    
    best_ssim = -1.0
    best_l1 = float('inf')
    step = 0
    
    print(f"\nExtended training configuration:")
    print(f"  Epochs: {EXTENDED_CFG['extended_epochs']}")
    print(f"  Base LR: {EXTENDED_CFG['lr_schedule']['base']:.2e}")
    print(f"  Steps per epoch: {steps_per_epoch}")
    print(f"  Total steps: {total_steps}")
    print("="*70 + "\n")
    
    # Training loop
    for epoch in range(1, EXTENDED_CFG["extended_epochs"] + 1):
        torch.cuda.empty_cache()
        gc.collect()
        
        model.train()
        epoch_logs = []
        pbar = tqdm(dl_train, desc=f"Extended Epoch {epoch}/{EXTENDED_CFG['extended_epochs']} [train]", leave=True)
        
        for b, batch in enumerate(pbar):
            # Cosine learning rate schedule
            lr = cosine_lr(opt, step, total_steps,
                          EXTENDED_CFG["lr_schedule"]["base"],
                          EXTENDED_CFG["lr_schedule"]["min"],
                          EXTENDED_CFG["lr_schedule"]["warmup_steps"])
            
            opt.zero_grad(set_to_none=True)
            
            # Forward pass with mixed precision
            ctx = torch.amp.autocast('cuda', enabled=(CFG["amp"] and DEVICE.startswith("cuda")))
            with ctx:
                out = model(batch)
                hetero_on = (hetero_wd.cooldown == 0)
                loss, logs = compute_losses(batch, out, epoch, hetero_on)
            
            # Skip batch if loss is invalid
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"\nSkipping batch {b} due to NaN/Inf loss")
                opt.zero_grad(set_to_none=True)
                step += 1
                continue
            
            # Backward pass
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["optimizer"]["grad_clip"])
            scaler.step(opt)
            scaler.update()
            
            # Update EMA
            if CFG["ema"]["enabled"] and ema is not None:
                ema.update(model)
            
            # Monitor loss scale
            if step % 100 == 0:
                print(f"\n  Loss Scale: {scaler.get_scale():.2f}")
                if scaler.get_scale() < 1.0:
                    print("  WARNING: Loss scale dropped below 1.0!")
            
            step += 1
            pbar.set_postfix({
                "lr": f"{lr:.2e}",
                "L": f"{logs['L_total']:.3f}",
                "Recon": f"{logs['Recon']:.3f}",
                "SSIM": f"{logs['SSIM']:.3f}",
                "Perc": f"{logs['Perc']:.3f}"
            })
            
            # Periodic memory cleanup
            if step % 200 == 0:
                torch.cuda.empty_cache()
        
        torch.cuda.empty_cache()
        gc.collect()
        
        # Validation phase
        model.eval()
        val_model = TRIDENT(CFG).to(DEVICE)
        if CFG["ema"]["enabled"] and ema is not None:
            ema.copy_to(val_model)
        else:
            val_model.load_state_dict(model.state_dict())
        val_model.eval()
        
        ssim_vals, l1_vals, sigma_means = [], [], []
        pbar_val = tqdm(dl_val, desc=f"Extended Epoch {epoch} [val]", leave=False)
        
        with torch.no_grad():
            for batch in pbar_val:
                out = val_model(batch)
                I_hat_s = to_srgb(out["I_hat_lin"], CFG["model"]["colorspace"]["gamma"])
                I_ref_s = to_srgb(batch["I_clean_lin"].to(DEVICE), CFG["model"]["colorspace"]["gamma"])
                
                ssim_vals.append(ssim_srgb(I_hat_s, I_ref_s).item())
                l1_vals.append(F.l1_loss(out["I_hat_lin"], batch["I_clean_lin"].to(DEVICE)).item())
                sigma_means.append(out["sigma2"].mean().item())
                
                pbar_val.set_postfix({
                    "SSIM": f"{np.mean(ssim_vals):.3f}",
                    "L1": f"{np.mean(l1_vals):.3f}"
                })
        
        val_ssim = float(np.mean(ssim_vals))
        val_l1 = float(np.mean(l1_vals))
        mean_sigma = float(np.mean(sigma_means))
        
        del val_model
        torch.cuda.empty_cache()
        gc.collect()
        
        # Update heteroscedastic watchdog
        trig, disabled = hetero_wd.update(mean_sigma)
        
        # Save best checkpoints
        if val_ssim > best_ssim:
            best_ssim = val_ssim
            print(f"  → New best SSIM: {val_ssim:.4f}")
            if CFG["ema"]["enabled"] and ema is not None:
                save_checkpoint(ema.ema, os.path.join(EXTENDED_CFG["save_dir"], "best_ssim.pt"))
            else:
                save_checkpoint(model.state_dict(), os.path.join(EXTENDED_CFG["save_dir"], "best_ssim.pt"))
        
        if val_l1 < best_l1:
            best_l1 = val_l1
            print(f"  → New best L1: {val_l1:.4f}")
            if CFG["ema"]["enabled"] and ema is not None:
                save_checkpoint(ema.ema, os.path.join(EXTENDED_CFG["save_dir"], "best_l1.pt"))
            else:
                save_checkpoint(model.state_dict(), os.path.join(EXTENDED_CFG["save_dir"], "best_l1.pt"))
        
        # Print epoch summary
        print(f"[Extended Epoch {epoch}] SSIM={val_ssim:.4f}  L1={val_l1:.4f}  "
              f"meanσ²={mean_sigma:.3f}  heteroCooldown={hetero_wd.cooldown}")
        
        # Save checkpoint every 5 epochs
        if epoch % 5 == 0:
            checkpoint_name = f"extended_epoch_{epoch}.pt"
            if CFG["ema"]["enabled"] and ema is not None:
                save_checkpoint(ema.ema, os.path.join(EXTENDED_CFG["save_dir"], checkpoint_name))
            else:
                save_checkpoint(model.state_dict(), os.path.join(EXTENDED_CFG["save_dir"], checkpoint_name))
    
    # Save final checkpoint
    save_checkpoint(model.state_dict(), os.path.join(EXTENDED_CFG["save_dir"], "last.pt"))
    
    print("\n" + "="*70)
    print(f"Extended training complete!")
    print(f"Best SSIM: {best_ssim:.4f}")
    print(f"Best L1: {best_l1:.4f}")
    print(f"Checkpoints saved to: {EXTENDED_CFG['save_dir']}")
    print("="*70 + "\n")
    
    return best_ssim, best_l1


# ============================================================================
# RUN EXTENDED TRAINING
# ============================================================================

if __name__ == "__main__":
    print("\n" + "="*70)
    print("STARTING EXTENDED TRAINING")
    print("="*70 + "\n")
    
    best_ssim, best_l1 = run_extended_training()
    
    print("\n" + "="*70)
    print("EXTENDED TRAINING COMPLETE!")
    print(f"Final Best SSIM: {best_ssim:.4f}")
    print(f"Final Best L1: {best_l1:.4f}")
    print("="*70 + "\n")


STARTING EXTENDED TRAINING


EXTENDED TRAINING FROM BEST SSIM CHECKPOINT
Loading checkpoint: checkpointsphy/best_ssim.pt
✓ Checkpoint loaded successfully

Extended training configuration:
  Epochs: 30
  Base LR: 1.00e-04
  Steps per epoch: 140
  Total steps: 4200



Extended Epoch 1/30 [train]:   0%|          | 0/140 [00:00<?, ?it/s]


  Loss Scale: 65536.00
