## Main verifier code 

In [2]:
# === All imports and global toggles ===
import os, sys, time, csv, itertools, math, traceback
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

# Global toggle, currently False.
# NOTE(review): presumably True would select a stub verifier instead of the
# real bound-propagation code defined below — confirm against the code that reads it.
DUMMY_VERIFIER = False   # <- keep False now that the real bound-prop code is defined below
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("PyTorch device:", DEVICE)
# ---- ONE-CELL HOTFIX: force legacy concat to build correct paths ----
import os, traceback

# Paths are assembled as OSCD_PATH_TOP + city + <suffix> by plain string
# concatenation, so the slashes must live in the constants themselves.
# trailing slash REQUIRED here (because call-sites do string concatenation)
OSCD_PATH_TOP = "../onera/OSCD/"        

# leading slash REQUIRED on all three suffixes (same reason); the two image
# directories also end in a slash, while the change-map suffix is a file path.
OSCD_PATH_BOTTOM1 = "/imgs_1/"
OSCD_PATH_BOTTOM2 = "/imgs_2/"
OSCD_PATH_CM      = "/cm/cm.png"

def _dbg_show(city):
    """Print the three OSCD paths that string concatenation produces for ``city``."""
    base = OSCD_PATH_TOP + city
    for label, tail in (
        ("imgs_1 dir ->", OSCD_PATH_BOTTOM1),
        ("imgs_2 dir ->", OSCD_PATH_BOTTOM2),
        ("cm path    ->", OSCD_PATH_CM),
    ):
        print(label, base + tail)

# Show the concatenated paths for one city (edit the name to probe another).
_dbg_show("paris")

# Fail fast: every expected location for that city must exist on disk.
for p in [
    OSCD_PATH_TOP + "paris" + tail
    for tail in (OSCD_PATH_BOTTOM1, OSCD_PATH_BOTTOM2, OSCD_PATH_CM)
]:
    if os.path.exists(p):
        continue
    print("[fatal path check] missing:", p)
    raise FileNotFoundError(p)
# === Data helpers ===
def _load_gt_mask_from_disk(city):
    """Load the ground-truth change mask for ``city`` as a boolean array.

    The mask path is the third element returned by ``oscd_paths(city)``.
    Every pixel value greater than zero is treated as True.

    Raises:
        FileNotFoundError: if the mask path does not exist.
    """
    _, _, cm = oscd_paths(city)
    if not os.path.exists(cm):
        raise FileNotFoundError(f"GT mask path missing for {city}: {cm}")
    # Image.open keeps the underlying file open; the context manager releases
    # the handle as soon as the pixels have been copied into the array.
    with Image.open(cm) as img:
        m = np.array(img)
    return (m > 0).astype(np.bool_)

def make_dummy_input(city, H=128, W=128, C=13):
    """Build a synthetic (1, C, H, W) tensor on DEVICE without touching disk.

    The tensor is zero everywhere except a centred rectangle (middle half of
    each spatial axis) that is set to one, so downstream logits are not
    constant. ``city`` is accepted but not used.
    """
    top, bottom = H // 4, 3 * H // 4
    left, right = W // 4, 3 * W // 4
    dummy = torch.zeros((1, C, H, W), device=DEVICE)
    dummy[:, :, top:bottom, left:right] = 1.0
    return dummy

# All imports and definitions
import itertools
import csv
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from rasterio.enums import Resampling
import rasterio.warp
import numpy as np
import torch
from scipy.ndimage import label as ndlabel
from skimage.measure import label as sklabel
from PIL import Image
import traceback   # <-- fixes: name 'traceback' is not defined

from models import AttU_Net, EncDec_LiRPA, FALCONetMHA_LiRPA
# Constants
# Band identifiers in the order listed here (note that "B8A" comes last,
# after "B12").
# NOTE(review): presumably Sentinel-2 band names matching the OSCD image
# files — confirm against the code that consumes this list.
BANDS = [
    "B01", "B02", "B03", "B04", "B05", "B06", "B07",
    "B08", "B09", "B10", "B11", "B12", "B8A"
]


# --- One-time compatibility shim for OSCD path constants ---
import os, traceback  # traceback for earlier errors

# Base Monolithic verifier routines  (anon stable BN version)
# ============================================================
# -----------------------------
# Zonotope: accept scalar or tensor epsilon (broadcast-safe)
# -----------------------------
class Zonotope:
    """Box-shaped input region ``[center - epsilon, center + epsilon]``.

    Only the elementwise lower/upper bounds are stored (no generator
    matrix), so this behaves as an interval box.

    Args:
        center: nominal input tensor; it is detached and cloned.
        epsilon: perturbation radius. Either a Python scalar or a tensor that
            broadcasts against ``center`` under the usual PyTorch rules
            (e.g. shape ``(C, 1, 1)`` for a per-channel radius).

    Attributes:
        center, epsilon, lower, upper: tensors as described above; ``epsilon``
            is stored with the same number of dimensions as ``center``.
    """

    def __init__(self, center: torch.Tensor, epsilon):
        self.center = center.detach().clone()

        if torch.is_tensor(epsilon):
            eps = epsilon.to(dtype=center.dtype, device=center.device)
        else:
            eps = torch.tensor(float(epsilon), dtype=center.dtype, device=center.device)

        # Left-pad with singleton dims so eps has the same rank as center.
        # (The previous view-based loop only worked for one-element tensors
        # and raised for e.g. a (C, 1, 1) per-channel epsilon.)
        missing = self.center.ndim - eps.ndim
        if missing > 0:
            eps = eps.reshape((1,) * missing + tuple(eps.shape))

        self.epsilon = eps
        self.lower = self.center - eps
        self.upper = self.center + eps

    def bound_box(self):
        """Return the ``(lower, upper)`` bound tensors."""
        return self.lower, self.upper

# -----------------------------
# Utilities
# -----------------------------
def flatten_hw(x):
    """Collapse the two spatial axes of a (B, C, H, W) tensor into one.

    Returns:
        A ``(B, C, H*W)`` tensor and the original ``(B, C, H, W)`` shape tuple,
        which can be passed to ``unflatten_hw`` to undo the operation.
    """
    B, C, H, W = x.shape
    # reshape (not view) so inputs whose H/W strides cannot be merged in
    # place, e.g. after a spatial transpose, are copied instead of raising.
    return x.reshape(B, C, H * W), (B, C, H, W)

def unflatten_hw(x, shape):
    """View ``x`` as a 4-D tensor with the given ``(B, C, H, W)`` shape."""
    batch, channels, height, width = shape
    return x.view(batch, channels, height, width)

# -----------------------------
# Safe BatchNorm2d interval (eval-mode running stats)
# Handles sign of scale and prevents numeric blow-ups.
# -----------------------------
def interval_batchnorm(x_l, x_u, bn: nn.BatchNorm2d, max_scale=4.0, min_var=5e-2):
    """
    Interval bounds through a BatchNorm2d layer using its running statistics.

    The layer is treated as the per-channel affine map
        y = s * x + t,   s = gamma / sqrt(var + eps),   t = beta - s * mu
    with two safeguards applied to the scale:
      - ``var + eps`` is floored at ``min_var`` to avoid a huge 1/sqrt(var)
      - ``s`` is clamped to [-max_scale, max_scale]
    When either safeguard is active, the map differs from the layer's exact
    eval-mode transform.

    Layers built with ``affine=False`` (no gamma/beta) are supported by using
    gamma = 1 and beta = 0.

    Args:
        x_l, x_u: lower/upper bound tensors, broadcastable with (1, C, 1, 1).
        bn: the BatchNorm2d module to propagate through.
        max_scale: absolute cap on the per-channel scale.
        min_var: floor applied to ``running_var + eps``.

    Returns:
        ``(y_l, y_u)`` with ``y_l <= y_u`` elementwise.

    Raises:
        ValueError: if ``bn`` has no running statistics.
    """
    if bn.running_mean is None or bn.running_var is None:
        raise ValueError(
            "interval_batchnorm requires running statistics "
            "(BatchNorm2d built with track_running_stats=True)"
        )

    # (1,C,1,1) buffers
    mu = bn.running_mean.view(1, -1, 1, 1)
    var = bn.running_var.view(1, -1, 1, 1)
    # affine=False layers carry no learnable scale/shift
    w = torch.ones_like(var) if bn.weight is None else bn.weight.view(1, -1, 1, 1)
    b = torch.zeros_like(mu) if bn.bias is None else bn.bias.view(1, -1, 1, 1)

    # guard tiny variance and clamp scale
    inv_std = torch.rsqrt(torch.clamp(var + bn.eps, min=min_var))
    s = torch.clamp(w * inv_std, min=-max_scale, max=max_scale)  # per-channel scale
    t = b - s * mu                                               # per-channel bias

    y_l = s * x_l + t
    y_u = s * x_u + t
    # a negative scale flips the interval, so order the endpoints explicitly
    return torch.minimum(y_l, y_u), torch.maximum(y_l, y_u)


# -----------------------------
# Interval forward for common modules (EncDec/FALCONet/AttUNet)
# -----------------------------
def _interval_conv2d(x_l, x_u, conv: nn.Conv2d):
    W, b = conv.weight, conv.bias
    mu = 0.5 * (x_u + x_l)
    r  = 0.5 * (x_u - x_l)
    mu = F.conv2d(mu, W, b, stride=conv.stride, padding=conv.padding,
                  dilation=conv.dilation, groups=conv.groups)
    r  = F.conv2d(r,  W.abs(), None, stride=conv.stride, padding=conv.padding,
                  dilation=conv.dilation, groups=conv.groups)
    return mu - r, mu + r

def _interval_conv1d(x_l, x_u, conv: nn.Conv1d):
    W, b = conv.weight, conv.bias
    mu = 0.5 * (x_u + x_l)
    r  = 0.5 * (x_u - x_l)

    # Ensure (B, C_in, L)
    if mu.dim() == 3 and mu.shape[1] != W.shape[1] * conv.groups:
        mu = mu.transpose(1, 2)  # (B,L,C) -> (B,C,L)
        r  = r.transpose(1, 2)

    mu = F.conv1d(mu, W, b, stride=conv.stride, padding=conv.padding,
                  dilation=conv.dilation, groups=conv.groups)
    r  = F.conv1d(r,  W.abs(), None, stride=conv.stride, padding=conv.padding,
                  dilation=conv.dilation, groups=conv.groups)
    return mu - r, mu + r

def _interval_linear(x_l, x_u, linear: nn.Linear):
    W, b = linear.weight, linear.bias
    mu = 0.5 * (x_u + x_l)
    r  = 0.5 * (x_u - x_l)

    orig = mu.shape
    mu = mu.view(-1, mu.shape[-1])
    r  = r.view(-1,  r.shape[-1])

    mu = F.linear(mu, W, b)
    r  = F.linear(r,  W.abs(), None)

    y_l = (mu - r).view(*orig[:-1], -1)
    y_u = (mu + r).view(*orig[:-1], -1)
    return y_l, y_u

def _resolve_final_1x1(model: nn.Module) -> nn.Conv2d:
    # Common names first
    for name in ("final", "outc", "Conv_1x1", "conv_last", "classifier", "head"):
        if hasattr(model, name):
            mod = getattr(model, name)
            if isinstance(mod, nn.Conv2d):
                return mod
            # If it's a container, grab its last Conv2d
            for m in reversed(list(mod.modules())):
                if isinstance(m, nn.Conv2d):
                    return m
    # Fallback: last 1×1 Conv2d anywhere in the model
    for m in reversed(list(model.modules())):
        if isinstance(m, nn.Conv2d) and m.kernel_size == (1, 1):
            return m
    raise AttributeError(f"Could not resolve final 1×1 head for {model.__class__.__name__}")


# -----------------------------
# FALCONet module-wise interval forward
# -----------------------------
def interval_forward_falconet(x_l, x_u, module):
    """Propagate the interval ``[x_l, x_u]`` through ``module``.

    Dispatches on the module type:
      - Conv2d / Conv1d / Linear: centre-radius propagation via the
        ``_interval_*`` helpers;
      - ReLU / LeakyReLU: applied to each bound;
      - BatchNorm2d: ``interval_batchnorm`` with variance floor 1e-2 and
        scale clamp 16;
      - Upsample / MaxPool2d / AvgPool2d: the module itself is applied to
        each bound;
      - Identity: pass-through;
      - Sequential, or any module with children: sub-modules are applied
        one after another.

    Returns:
        ``(y_l, y_u)`` bound tensors.

    Raises:
        NotImplementedError: for leaf modules of an unsupported type and for
            bicubic up-sampling.
    """
    if isinstance(module, nn.Conv2d):
        return _interval_conv2d(x_l, x_u, module)

    elif isinstance(module, nn.Conv1d):
        return _interval_conv1d(x_l, x_u, module)

    elif isinstance(module, nn.Sequential):
        for sub in module:
            x_l, x_u = interval_forward_falconet(x_l, x_u, sub)
        return x_l, x_u

    elif isinstance(module, nn.ReLU):
        return F.relu(x_l), F.relu(x_u)

    elif isinstance(module, nn.LeakyReLU):
        # Monotone increasing → apply elementwise
        return F.leaky_relu(x_l, negative_slope=module.negative_slope), \
               F.leaky_relu(x_u, negative_slope=module.negative_slope)

    elif isinstance(module, nn.BatchNorm2d):
        # Same arithmetic as the shared helper, with this path's constants
        # (variance floor 1e-2, scale clamp ±16).
        return interval_batchnorm(x_l, x_u, module, max_scale=16.0, min_var=1e-2)

    elif isinstance(module, nn.Upsample):
        if module.mode == "bicubic":
            # Bicubic kernels have negative lobes, so mapping each bound
            # separately would not give valid bounds.
            raise NotImplementedError("[FALCONet] bicubic Upsample is not monotone; unsupported")
        # Calling the module honours size=, scale_factor=, mode and
        # align_corners exactly as the network's own forward pass does.
        return module(x_l), module(x_u)

    elif isinstance(module, (nn.MaxPool2d, nn.AvgPool2d)):
        # Both poolings are monotone in every input element.
        return module(x_l), module(x_u)

    elif isinstance(module, nn.Identity):
        return x_l, x_u

    elif isinstance(module, nn.Linear):
        return _interval_linear(x_l, x_u, module)

    elif isinstance(module, nn.Module) and len(list(module.children())) > 0:
        # Generic container: children are applied in registration order,
        # which matches the real forward pass only when that forward is a
        # plain sequential composition of its children.
        for sub in module.children():
            x_l, x_u = interval_forward_falconet(x_l, x_u, sub)
        return x_l, x_u

    else:
        raise NotImplementedError(f"[FALCONet] Unsupported module type: {type(module)}")

# -----------------------------
# Specialized interval forward for MultiHeadConvAttention (safe fallback)
# -----------------------------
def interval_forward_tokenmixer(x_l, x_u, attn_module):
    """
    Token mixer fallback: propagate through each head's depthwise conv and
    V projection, then through the shared output projection.
    Attention weights are ignored (the per-head output is the V bounds), so
    the Q and K projections are not evaluated. Shapes: (B,C,L) in/out.
    """
    assert x_l.ndim == 3 and x_u.ndim == 3, f"[token_mixer] Expected (B,C,L), got {x_l.shape}"
    B, C, L = x_l.shape
    head_dim  = attn_module.head_dim
    num_heads = attn_module.num_heads
    assert C == head_dim * num_heads, f"[token_mixer] C={C} != {head_dim}x{num_heads}"

    outs_l, outs_u = [], []
    for i, head in enumerate(attn_module.heads):
        s, e = i * head_dim, (i + 1) * head_dim
        # channel slice of this head: (B, head_dim, L)
        xl = x_l[:, s:e, :]
        xu = x_u[:, s:e, :]

        # Depthwise conv (1D)
        xl, xu = interval_forward_falconet(xl, xu, head.depthwise_conv)

        # Only V feeds the output in this fallback; bounds for Q and K were
        # previously computed and thrown away, so they are skipped here.
        vl, vu = interval_forward_falconet(xl, xu, head.v_proj)

        outs_l.append(vl.transpose(1, 2))  # (B,L,head_dim)
        outs_u.append(vu.transpose(1, 2))  # (B,L,head_dim)

    # Concat heads on channel axis (B,L,C)
    out_l = torch.cat(outs_l, dim=2)
    out_u = torch.cat(outs_u, dim=2)

    # Final linear projection (treat as Linear over last dim)
    B, L, C = out_l.shape
    out_l = out_l.reshape(B * L, C)
    out_u = out_u.reshape(B * L, C)
    out_l, out_u = interval_forward_falconet(out_l, out_u, attn_module.out_proj)
    out_l = out_l.view(B, L, -1).transpose(1, 2)  # (B,C_out,L)
    out_u = out_u.view(B, L, -1).transpose(1, 2)  # (B,C_out,L)
    return out_l, out_u

# -----------------------------
# Top-level verifier for Simple Encoder-Decoder
# -----------------------------
def propagate_bounds_encdecnet(model, zonotope):
    """Interval bounds on the logits of the simple encoder-decoder model.

    Walks ``inc``, ``down1..down4``, then ``up1..up4`` (each: up-sample,
    concatenate the skip connection in front, run ``up.conv``) and finally
    ``outc``, starting from ``zonotope.lower`` / ``zonotope.upper``.

    Returns:
        ``(logits_lower, logits_upper)``.
    """
    def up_bounds(up_mod, lo, hi):
        # A transposed convolution may carry negative weights, in which case
        # mapping each bound separately is not valid; use centre/radius form.
        if isinstance(up_mod, nn.ConvTranspose2d):
            geometry = dict(stride=up_mod.stride, padding=up_mod.padding,
                            output_padding=up_mod.output_padding,
                            groups=up_mod.groups, dilation=up_mod.dilation)
            mid = F.conv_transpose2d(0.5 * (hi + lo), up_mod.weight, up_mod.bias, **geometry)
            rad = F.conv_transpose2d(0.5 * (hi - lo), up_mod.weight.abs(), None, **geometry)
            return mid - rad, mid + rad
        # Any other up-sampler is applied to each bound as before.
        # NOTE(review): this assumes it is monotone (e.g. nearest/bilinear
        # Upsample) — confirm against the model definition.
        return up_mod(lo), up_mod(hi)

    def up_step(up, xl, xu, sk_l, sk_u):
        xl, xu = up_bounds(up.up, xl, xu)
        xl = torch.cat([sk_l, xl], dim=1)
        xu = torch.cat([sk_u, xu], dim=1)
        return interval_forward_falconet(xl, xu, up.conv)

    x_l, x_u = zonotope.lower, zonotope.upper

    # Encoder
    x1_l, x1_u = interval_forward_falconet(x_l, x_u, model.inc)
    x2_l, x2_u = interval_forward_falconet(x1_l, x1_u, model.down1)
    x3_l, x3_u = interval_forward_falconet(x2_l, x2_u, model.down2)
    x4_l, x4_u = interval_forward_falconet(x3_l, x3_u, model.down3)
    x5_l, x5_u = interval_forward_falconet(x4_l, x4_u, model.down4)

    # Decoder
    x_l, x_u = up_step(model.up1, x5_l, x5_u, x4_l, x4_u)
    x_l, x_u = up_step(model.up2, x_l,  x_u,  x3_l, x3_u)
    x_l, x_u = up_step(model.up3, x_l,  x_u,  x2_l, x2_u)
    x_l, x_u = up_step(model.up4, x_l,  x_u,  x1_l, x1_u)
    x_l, x_u = interval_forward_falconet(x_l, x_u, model.outc)
    return x_l, x_u

# -----------------------------
# Top-level verifier for FALCONet + token mixers
# -----------------------------
def propagate_bounds_falconetmha_lirpa(model, zonotope):
    """Interval bounds on the logits of the FALCONet model with token mixers.

    Encoder: ``inc``, ``down1``, then ``down2``/``down3``/``down4`` each
    followed by its token mixer (``token_mixer_2/3/4``) applied on the
    spatially flattened feature map. Decoder: ``up1..up4`` with skip
    concatenation, then ``outc``.

    Returns:
        ``(logits_lower, logits_upper)``.
    """
    def up_bounds(up_mod, lo, hi):
        # A transposed convolution may carry negative weights, in which case
        # mapping each bound separately is not valid; use centre/radius form.
        if isinstance(up_mod, nn.ConvTranspose2d):
            geometry = dict(stride=up_mod.stride, padding=up_mod.padding,
                            output_padding=up_mod.output_padding,
                            groups=up_mod.groups, dilation=up_mod.dilation)
            mid = F.conv_transpose2d(0.5 * (hi + lo), up_mod.weight, up_mod.bias, **geometry)
            rad = F.conv_transpose2d(0.5 * (hi - lo), up_mod.weight.abs(), None, **geometry)
            return mid - rad, mid + rad
        # Any other up-sampler is applied to each bound as before.
        # NOTE(review): this assumes it is monotone (e.g. nearest/bilinear
        # Upsample) — confirm against the model definition.
        return up_mod(lo), up_mod(hi)

    def up_step(up, xl, xu, sk_l, sk_u):
        xl, xu = up_bounds(up.up, xl, xu)
        xl = torch.cat([sk_l, xl], dim=1)
        xu = torch.cat([sk_u, xu], dim=1)
        return interval_forward_falconet(xl, xu, up.conv)

    def mix(xl, xu, mixer):
        # flatten (B,C,H,W) -> (B,C,H*W), run the mixer, restore the shape
        xl, shape = flatten_hw(xl)
        xu, _ = flatten_hw(xu)
        xl, xu = interval_forward_tokenmixer(xl, xu, mixer)
        return unflatten_hw(xl, shape), unflatten_hw(xu, shape)

    x_l, x_u = zonotope.lower, zonotope.upper

    # Encoder
    x1_l, x1_u = interval_forward_falconet(x_l, x_u, model.inc)
    x2_l, x2_u = interval_forward_falconet(x1_l, x1_u, model.down1)

    x3_l, x3_u = interval_forward_falconet(x2_l, x2_u, model.down2)
    x3_l, x3_u = mix(x3_l, x3_u, model.token_mixer_2)

    x4_l, x4_u = interval_forward_falconet(x3_l, x3_u, model.down3)
    x4_l, x4_u = mix(x4_l, x4_u, model.token_mixer_3)

    x5_l, x5_u = interval_forward_falconet(x4_l, x4_u, model.down4)
    x5_l, x5_u = mix(x5_l, x5_u, model.token_mixer_4)

    # Decoder
    x_l, x_u = up_step(model.up1, x5_l, x5_u, x4_l, x4_u)
    x_l, x_u = up_step(model.up2, x_l,  x_u,  x3_l, x3_u)
    x_l, x_u = up_step(model.up3, x_l,  x_u,  x2_l, x2_u)
    x_l, x_u = up_step(model.up4, x_l,  x_u,  x1_l, x1_u)
    x_l, x_u = interval_forward_falconet(x_l, x_u, model.outc)
    return x_l, x_u

# -----------------------------
# Interval forward for Attention U-Net (modules)
# -----------------------------
def interval_forward_attunet(x_l, x_u, module):
    """Propagate the interval ``[x_l, x_u]`` through ``module`` (Attention U-Net).

    Dispatches on the module type:
      - Conv2d: centre-radius propagation via ``_interval_conv2d``;
      - ReLU / LeakyReLU / Sigmoid: applied to each bound;
      - BatchNorm2d: ``interval_batchnorm`` with variance floor 1e-2 and
        scale clamp 16;
      - Upsample / MaxPool2d / AvgPool2d: the module itself is applied to
        each bound;
      - Identity: pass-through;
      - Sequential, or any module with children: sub-modules are applied
        one after another.

    Returns:
        ``(y_l, y_u)`` bound tensors.

    Raises:
        NotImplementedError: for leaf modules of an unsupported type and for
            bicubic up-sampling.
    """
    if isinstance(module, nn.Conv2d):
        return _interval_conv2d(x_l, x_u, module)

    elif isinstance(module, nn.Sequential):
        for sub in module:
            x_l, x_u = interval_forward_attunet(x_l, x_u, sub)
        return x_l, x_u

    elif isinstance(module, nn.ReLU):
        return F.relu(x_l), F.relu(x_u)

    elif isinstance(module, nn.Sigmoid):
        # Sigmoid is monotone, so the endpoints map to the endpoints.
        # No input clamp: clamping to [-10, 10] raised the lower bound
        # (and lowered the upper bound) whenever the pre-activation lay
        # outside that range, and torch.sigmoid saturates to 0/1 safely.
        return torch.sigmoid(x_l), torch.sigmoid(x_u)

    elif isinstance(module, nn.LeakyReLU):
        return F.leaky_relu(x_l, negative_slope=module.negative_slope), \
               F.leaky_relu(x_u, negative_slope=module.negative_slope)

    elif isinstance(module, nn.BatchNorm2d):
        # Same arithmetic as the shared helper, with this path's constants
        # (variance floor 1e-2, scale clamp ±16).
        return interval_batchnorm(x_l, x_u, module, max_scale=16.0, min_var=1e-2)

    elif isinstance(module, nn.Upsample):
        if module.mode == "bicubic":
            # Bicubic kernels have negative lobes, so mapping each bound
            # separately would not give valid bounds.
            raise NotImplementedError("[AttU_Net] bicubic Upsample is not monotone; unsupported")
        # Calling the module honours size=, scale_factor=, mode and
        # align_corners exactly as the network's own forward pass does.
        return module(x_l), module(x_u)

    elif isinstance(module, (nn.MaxPool2d, nn.AvgPool2d)):
        # Both poolings are monotone in every input element.
        return module(x_l), module(x_u)

    elif isinstance(module, nn.Identity):
        return x_l, x_u

    elif isinstance(module, nn.Module) and len(list(module.children())) > 0:
        # Generic container: children are applied in registration order,
        # which matches the real forward pass only when that forward is a
        # plain sequential composition of its children.
        for sub in module.children():
            x_l, x_u = interval_forward_attunet(x_l, x_u, sub)
        return x_l, x_u

    else:
        raise NotImplementedError(f"[AttU_Net] Unsupported module type: {type(module)}")

# -----------------------------
# Attention gate (psi sigmoid handled safely)
# -----------------------------
def interval_mul_pos(x_l, x_u, p_l, p_u):
    """
    Bounds on the elementwise product x * p for x in [x_l, x_u] and
    p in [p_l, p_u], intended for 0 <= p_l <= p_u <= 1 (sigmoid outputs).
    The four endpoint products are formed and their elementwise minimum and
    maximum returned.
    """
    products = [x * p for x in (x_l, x_u) for p in (p_l, p_u)]
    lo = hi = products[0]
    for candidate in products[1:]:
        lo = torch.minimum(lo, candidate)
        hi = torch.maximum(hi, candidate)
    return lo, hi
    
def interval_forward_attention_gate(attn_block, g_l, g_u, x_l, x_u):
    """Interval bounds through an attention gate.

    Computes bounds for ``x * psi`` where
    ``psi = attn_block.psi(relu(W_g(g) + W_x(x)))``, with a sigmoid applied
    exactly once to obtain the gating coefficient.

    Args:
        attn_block: module exposing ``W_g``, ``W_x`` and ``psi`` sub-modules.
        g_l, g_u: bounds on the gating signal.
        x_l, x_u: bounds on the skip-connection features being gated.

    Returns:
        ``(lower, upper)`` bounds on the gated features.
    """
    g1_l, g1_u = interval_forward_attunet(g_l, g_u, attn_block.W_g)
    x1_l, x1_u = interval_forward_attunet(x_l, x_u, attn_block.W_x)

    psi_l = F.relu(g1_l + x1_l)
    psi_u = F.relu(g1_u + x1_u)
    psi_l, psi_u = interval_forward_attunet(psi_l, psi_u, attn_block.psi)

    # If psi already ends in its own nn.Sigmoid, the call above has applied
    # it; a second sigmoid would squash the gate into roughly [0.5, 0.73].
    psi_has_sigmoid = any(isinstance(m, nn.Sigmoid) for m in attn_block.psi.modules())
    if not psi_has_sigmoid:
        # Sigmoid is monotone; no input clamp, since clamping to [-10, 10]
        # would tighten the bounds beyond what the true sigmoid guarantees.
        psi_l = torch.sigmoid(psi_l)
        psi_u = torch.sigmoid(psi_u)

    return interval_mul_pos(x_l, x_u, psi_l, psi_u)

# -----------------------------
# Top-level verifier for Attention U-Net
# -----------------------------
def propagate_bounds_attunet_lirpa(model, zonotope):
    """Interval bounds on the logits of the Attention U-Net.

    Encoder: ``Conv1`` followed by four (2x2 max-pool, ``ConvK``) stages.
    Decoder: for each level, up-sample (``UpK``), gate the matching encoder
    features (``AttK``), concatenate gated features in front of the
    up-sampled ones and fuse with ``Up_convK``. ``Conv_1x1`` maps the result
    to logits.

    Returns:
        ``(logits_lower, logits_upper)``.
    """
    lo, hi = interval_forward_attunet(zonotope.lower, zonotope.upper, model.Conv1)

    # Encoder: remember every stage output, the last one is the bottleneck.
    stage_bounds = [(lo, hi)]
    for stage in (model.Conv2, model.Conv3, model.Conv4, model.Conv5):
        lo = F.max_pool2d(lo, kernel_size=2, stride=2)
        hi = F.max_pool2d(hi, kernel_size=2, stride=2)
        lo, hi = interval_forward_attunet(lo, hi, stage)
        stage_bounds.append((lo, hi))
    stage_bounds.pop()  # bottleneck is already held in (lo, hi)

    # Decoder with attention, deepest skip connection first.
    decoder = (
        (model.Up5, model.Att5, model.Up_conv5),
        (model.Up4, model.Att4, model.Up_conv4),
        (model.Up3, model.Att3, model.Up_conv3),
        (model.Up2, model.Att2, model.Up_conv2),
    )
    for upsample, gate, fuse in decoder:
        skip_l, skip_u = stage_bounds.pop()
        lo, hi = interval_forward_attunet(lo, hi, upsample)
        att_l, att_u = interval_forward_attention_gate(gate, lo, hi, skip_l, skip_u)
        lo = torch.cat((att_l, lo), dim=1)
        hi = torch.cat((att_u, hi), dim=1)
        lo, hi = interval_forward_attunet(lo, hi, fuse)

    out_l, out_u = interval_forward_attunet(lo, hi, model.Conv_1x1)
    return out_l, out_u

# ---------- Exact affine bounds for final Conv2d / margin ----------

@torch.no_grad()
def affine_bounds_conv2d(l, u, conv: nn.Conv2d):
    """
    Interval bounds for y = Conv2d(x; W, b) given l <= x <= u, obtained by
    splitting W into its positive and negative parts:
        y_lb = conv(l, W+) + conv(u, W-) + b
        y_ub = conv(u, W+) + conv(l, W-) + b
    The layer's stride/padding/dilation/groups are reused. Prints a one-time
    notice the first time it is called. Returns (y_lb, y_ub).
    """
    if not hasattr(affine_bounds_conv2d, "_dbg_printed"):
        print("[path] using affine_bounds_conv2d for final logits")
        affine_bounds_conv2d._dbg_printed = True
    assert isinstance(conv, nn.Conv2d)

    geometry = dict(stride=conv.stride, padding=conv.padding,
                    dilation=conv.dilation, groups=conv.groups)
    w_pos = torch.clamp(conv.weight, min=0)
    w_neg = torch.clamp(conv.weight, max=0)

    def apply(x, w):
        # bias is added once, after the positive/negative parts are summed
        return F.conv2d(x, w, bias=None, **geometry)

    y_lb = apply(l, w_pos) + apply(u, w_neg)
    if conv.bias is not None:
        y_lb = y_lb + conv.bias.view(1, -1, 1, 1)

    y_ub = apply(u, w_pos) + apply(l, w_neg)
    if conv.bias is not None:
        y_ub = y_ub + conv.bias.view(1, -1, 1, 1)

    return y_lb, y_ub


@torch.no_grad()
def affine_margin_bounds_conv2d(l, u, conv: nn.Conv2d, chg_idx=1, nchg_idx=0):
    """
    Lower/upper bounds on the logit margin m = z_chg - z_nchg, computed with
    one combined affine map instead of differencing two logit intervals:
       Wm := W[chg] - W[nchg],  bm := b[chg] - b[nchg]
       m_lb = conv2d(l, Wm+) + conv2d(u, Wm-) + bm
       m_ub = conv2d(u, Wm+) + conv2d(l, Wm-) + bm
    The channel axis is squeezed, so both results have shape (N, H, W).
    """
    weight, bias = conv.weight, conv.bias
    margin_w = weight[chg_idx:chg_idx + 1] - weight[nchg_idx:nchg_idx + 1]  # (1, Cin, kH, kW)
    w_pos = torch.clamp(margin_w, min=0)
    w_neg = torch.clamp(margin_w, max=0)

    geometry = dict(stride=conv.stride, padding=conv.padding,
                    dilation=conv.dilation, groups=conv.groups)

    def apply(x, w):
        return F.conv2d(x, w, bias=None, **geometry)

    m_lb = apply(l, w_pos) + apply(u, w_neg)
    m_ub = apply(u, w_pos) + apply(l, w_neg)

    if bias is not None:
        margin_b = (bias[chg_idx] - bias[nchg_idx]).view(1, 1, 1, 1)
        m_lb = m_lb + margin_b
        m_ub = m_ub + margin_b

    # drop the singleton channel axis -> (N,H,W)
    return m_lb.squeeze(1), m_ub.squeeze(1)

def _span_mean_max(l, u):
    w = (u - l).abs()
    return float(w.mean()), float(w.max())

# --- Encoder-Decoder (stop before model.outc) ---
def propagate_prelogits_encdecnet(model, z):
    """Interval bounds on the pre-logit features of the encoder-decoder model.

    Same walk as the full propagation but stops before the final head: the
    returned bounds are the output of ``model.up4.conv``.

    Side effect: if a module-level ``CROWN_TAIL`` dict exists, the bounds on
    the input of ``model.up4.conv`` are stored under 'l0'/'u0', that block
    under 'block' and the final head (``outc``, else ``final``) under
    'final'; a confirmation line is printed once.

    Returns:
        ``(prelogits_lower, prelogits_upper)``.
    """
    def up_bounds(up_mod, lo, hi):
        # A transposed convolution may carry negative weights, in which case
        # mapping each bound separately is not valid; use centre/radius form.
        if isinstance(up_mod, nn.ConvTranspose2d):
            geometry = dict(stride=up_mod.stride, padding=up_mod.padding,
                            output_padding=up_mod.output_padding,
                            groups=up_mod.groups, dilation=up_mod.dilation)
            mid = F.conv_transpose2d(0.5 * (hi + lo), up_mod.weight, up_mod.bias, **geometry)
            rad = F.conv_transpose2d(0.5 * (hi - lo), up_mod.weight.abs(), None, **geometry)
            return mid - rad, mid + rad
        # Any other up-sampler is applied to each bound as before.
        # NOTE(review): this assumes it is monotone (e.g. nearest/bilinear
        # Upsample) — confirm against the model definition.
        return up_mod(lo), up_mod(hi)

    # helper for the first three up blocks
    def up_step(up, xl, xu, sk_l, sk_u):
        xl, xu = up_bounds(up.up, xl, xu)
        xl = torch.cat([sk_l, xl], dim=1)
        xu = torch.cat([sk_u, xu], dim=1)
        return interval_forward_falconet(xl, xu, up.conv)

    x_l, x_u = z.lower, z.upper

    # encoder
    x1_l, x1_u = interval_forward_falconet(x_l,  x_u,  model.inc)
    x2_l, x2_u = interval_forward_falconet(x1_l, x1_u, model.down1)
    x3_l, x3_u = interval_forward_falconet(x2_l, x2_u, model.down2)
    x4_l, x4_u = interval_forward_falconet(x3_l, x3_u, model.down3)
    x5_l, x5_u = interval_forward_falconet(x4_l, x4_u, model.down4)

    # decoder up1..up3
    x_l, x_u = up_step(model.up1, x5_l, x5_u, x4_l, x4_u)
    x_l, x_u = up_step(model.up2, x_l,  x_u,  x3_l, x3_u)
    x_l, x_u = up_step(model.up3, x_l,  x_u,  x2_l, x2_u)

    # ---- up4 expanded so we can TAP right before its DoubleConv ----
    xup_l, xup_u = up_bounds(model.up4.up, x_l, x_u)
    inp_l = torch.cat([x1_l, xup_l], dim=1)   # <-- input to final DoubleConv
    inp_u = torch.cat([x1_u, xup_u], dim=1)

    # CROWN-tail tap: start at DoubleConv(input) and end at final 1×1 head
    if 'CROWN_TAIL' in globals():
        # Explicit None test: `a or b` would depend on the module's truthiness
        # (containers define __len__, so an empty one would be skipped).
        final_head = getattr(model, 'outc', None)
        if final_head is None:
            final_head = getattr(model, 'final')
        CROWN_TAIL['l0']    = inp_l.detach().clone()
        CROWN_TAIL['u0']    = inp_u.detach().clone()
        CROWN_TAIL['block'] = getattr(model.up4, 'conv')
        CROWN_TAIL['final'] = final_head
        if not CROWN_TAIL.get('_encdec_tap_printed', False):
            print("[tail] EncDec tap @ up4: DoubleConv + final head set.")
            CROWN_TAIL['_encdec_tap_printed'] = True

    # run the last DoubleConv to produce pre-logits
    x_l, x_u = interval_forward_falconet(inp_l, inp_u, model.up4.conv)
    return x_l, x_u   # <--- pre-logits (input to model.outc / model.final)




# # --- Encoder-Decoder (stop before model.outc) ---
# def propagate_prelogits_encdecnet(model, z):
#     x_l, x_u = z.lower, z.upper
#     x1_l, x1_u = interval_forward_falconet(x_l, x_u, model.inc)
#     x2_l, x2_u = interval_forward_falconet(x1_l, x1_u, model.down1)
#     x3_l, x3_u = interval_forward_falconet(x2_l, x2_u, model.down2)
#     x4_l, x4_u = interval_forward_falconet(x3_l, x3_u, model.down3)
#     x5_l, x5_u = interval_forward_falconet(x4_l, x4_u, model.down4)
#     def up_step(up, xl, xu, sk_l, sk_u):
#         xl = up.up(xl); xu = up.up(xu)
#         xl = torch.cat([sk_l, xl], dim=1)
#         xu = torch.cat([sk_u, xu], dim=1)
#         return interval_forward_falconet(xl, xu, up.conv)
#     x_l, x_u = up_step(model.up1, x5_l, x5_u, x4_l, x4_u)
#     x_l, x_u = up_step(model.up2, x_l,  x_u,  x3_l, x3_u)
#     x_l, x_u = up_step(model.up3, x_l,  x_u,  x2_l, x2_u)
#     x_l, x_u = up_step(model.up4, x_l,  x_u,  x1_l, x1_u)
#     return x_l, x_u   # <--- pre-logits; apply model.outc via affine bounds


# --- FALCONet + token mixer (stop before model.outc) ---
def propagate_prelogits_falconetmha(model, z):
    """Interval bounds on the pre-logit features of FALCONet with token mixers.

    Same walk as the full propagation (encoder stages with
    ``token_mixer_2/3/4`` after ``down2/3/4``, then ``up1..up4``) but stops
    before the final head: the returned bounds are the output of
    ``model.up4.conv``.

    Side effect: if a module-level ``CROWN_TAIL`` dict exists, the bounds on
    the input of ``model.up4.conv`` are stored under 'l0'/'u0', that block
    under 'block' and the final head (``outc``, else ``final``) under
    'final'; a confirmation line is printed once.

    Returns:
        ``(prelogits_lower, prelogits_upper)``.
    """
    def up_bounds(up_mod, lo, hi):
        # A transposed convolution may carry negative weights, in which case
        # mapping each bound separately is not valid; use centre/radius form.
        if isinstance(up_mod, nn.ConvTranspose2d):
            geometry = dict(stride=up_mod.stride, padding=up_mod.padding,
                            output_padding=up_mod.output_padding,
                            groups=up_mod.groups, dilation=up_mod.dilation)
            mid = F.conv_transpose2d(0.5 * (hi + lo), up_mod.weight, up_mod.bias, **geometry)
            rad = F.conv_transpose2d(0.5 * (hi - lo), up_mod.weight.abs(), None, **geometry)
            return mid - rad, mid + rad
        # Any other up-sampler is applied to each bound as before.
        # NOTE(review): this assumes it is monotone (e.g. nearest/bilinear
        # Upsample) — confirm against the model definition.
        return up_mod(lo), up_mod(hi)

    # decoder helper for the first three up blocks
    def up_step(up, xl, xu, sk_l, sk_u):
        xl, xu = up_bounds(up.up, xl, xu)
        xl = torch.cat([sk_l, xl], dim=1)
        xu = torch.cat([sk_u, xu], dim=1)
        return interval_forward_falconet(xl, xu, up.conv)

    def mix(xl, xu, mixer):
        # flatten (B,C,H,W) -> (B,C,H*W), run the mixer, restore the shape
        xl, shape = flatten_hw(xl)
        xu, _ = flatten_hw(xu)
        xl, xu = interval_forward_tokenmixer(xl, xu, mixer)
        return unflatten_hw(xl, shape), unflatten_hw(xu, shape)

    x_l, x_u = z.lower, z.upper

    # encoder + mixers
    x1_l, x1_u = interval_forward_falconet(x_l, x_u, model.inc)
    x2_l, x2_u = interval_forward_falconet(x1_l, x1_u, model.down1)

    x3_l, x3_u = interval_forward_falconet(x2_l, x2_u, model.down2)
    x3_l, x3_u = mix(x3_l, x3_u, model.token_mixer_2)

    x4_l, x4_u = interval_forward_falconet(x3_l, x3_u, model.down3)
    x4_l, x4_u = mix(x4_l, x4_u, model.token_mixer_3)

    x5_l, x5_u = interval_forward_falconet(x4_l, x4_u, model.down4)
    x5_l, x5_u = mix(x5_l, x5_u, model.token_mixer_4)

    # up1..up3
    x_l, x_u = up_step(model.up1, x5_l, x5_u, x4_l, x4_u)
    x_l, x_u = up_step(model.up2, x_l,  x_u,  x3_l, x3_u)
    x_l, x_u = up_step(model.up3, x_l,  x_u,  x2_l, x2_u)

    # ---- up4 expanded so we can TAP right before its DoubleConv ----
    xup_l, xup_u = up_bounds(model.up4.up, x_l, x_u)
    inp_l = torch.cat([x1_l, xup_l], dim=1)          # <-- input to last DoubleConv
    inp_u = torch.cat([x1_u, xup_u], dim=1)

    # CROWN-tail tap: start at DoubleConv(input) and end at the final 1×1 head
    if 'CROWN_TAIL' in globals():
        # Explicit None test: `a or b` would depend on the module's truthiness
        # (containers define __len__, so an empty one would be skipped).
        final_head = getattr(model, 'outc', None)
        if final_head is None:
            final_head = getattr(model, 'final')
        CROWN_TAIL['l0']    = inp_l.detach().clone()
        CROWN_TAIL['u0']    = inp_u.detach().clone()
        CROWN_TAIL['block'] = getattr(model.up4, 'conv')       # last DoubleConv module
        CROWN_TAIL['final'] = final_head
        # print once to confirm the tap is active
        if not CROWN_TAIL.get('_falco_tap_printed', False):
            print("[tail] FALCONet tap @ up4: DoubleConv + final head set.")
            CROWN_TAIL['_falco_tap_printed'] = True

    # run the last DoubleConv to produce pre-logits
    x_l, x_u = interval_forward_falconet(inp_l, inp_u, model.up4.conv)
    return x_l, x_u   # pre-logits (input to model.outc)



# # --- FALCONet + token mixer (stop before model.outc) ---
# def propagate_prelogits_falconetmha(model, z):
#     x_l, x_u = z.lower, z.upper
#     x1_l, x1_u = interval_forward_falconet(x_l, x_u, model.inc)
#     x2_l, x2_u = interval_forward_falconet(x1_l, x1_u, model.down1)
#     x3_l, x3_u = interval_forward_falconet(x2_l, x2_u, model.down2)
#     x3_l, sh3 = flatten_hw(x3_l); x3_u, _ = flatten_hw(x3_u)
#     x3_l, x3_u = interval_forward_tokenmixer(x3_l, x3_u, model.token_mixer_2)
#     x3_l, x3_u = unflatten_hw(x3_l, sh3), unflatten_hw(x3_u, sh3)
#     x4_l, x4_u = interval_forward_falconet(x3_l, x3_u, model.down3)
#     x4_l, sh4  = flatten_hw(x4_l); x4_u, _ = flatten_hw(x4_u)
#     x4_l, x4_u = interval_forward_tokenmixer(x4_l, x4_u, model.token_mixer_3)
#     x4_l, x4_u = unflatten_hw(x4_l, sh4), unflatten_hw(x4_u, sh4)
#     x5_l, x5_u = interval_forward_falconet(x4_l, x4_u, model.down4)
#     x5_l, sh5  = flatten_hw(x5_l); x5_u, _ = flatten_hw(x5_u)
#     x5_l, x5_u = interval_forward_tokenmixer(x5_l, x5_u, model.token_mixer_4)
#     x5_l, x5_u = unflatten_hw(x5_l, sh5), unflatten_hw(x5_u, sh5)
#     # decoder
#     def up_step(up, xl, xu, sk_l, sk_u):
#         xl = up.up(xl); xu = up.up(xu)
#         xl = torch.cat([sk_l, xl], dim=1)
#         xu = torch.cat([sk_u, xu], dim=1)
#         return interval_forward_falconet(xl, xu, up.conv)
#     x_l, x_u = up_step(model.up1, x5_l, x5_u, x4_l, x4_u)
#     x_l, x_u = up_step(model.up2, x_l,  x_u,  x3_l, x3_u)
#     x_l, x_u = up_step(model.up3, x_l,  x_u,  x2_l, x2_u)
#     x_l, x_u = up_step(model.up4, x_l,  x_u,  x1_l, x1_u)
#     return x_l, x_u   # <--- pre-logits

# --- Attention U-Net (stop before Conv_1x1) ---
def propagate_prelogits_attunet(model, z):
    """Propagate interval bounds through the Attention U-Net up to its pre-logits.

    Runs the encoder (``Conv1`` .. ``Conv5`` with 2x2 max-pooling in between)
    and the attention-gated decoder (``Up5``/``Att5``/``Up_conv5`` down to
    ``Up2``/``Att2``/``Up_conv2``) on a lower/upper bound pair, stopping before
    the model's final 1x1 convolution. Interval widths are printed at a few
    stages via ``_span_mean_max``.

    Side effect: when a module-level ``CROWN_TAIL`` dict exists, the bounds
    entering ``model.Up_conv2``, that block itself and the resolved final 1x1
    layer are stored in it for use by code outside this function.

    Args:
        model: Attention U-Net exposing the sub-modules named above.
        z: object whose ``lower`` / ``upper`` attributes hold the input bounds.

    Returns:
        ``(lower, upper)`` bounds on the output of ``model.Up_conv2``.
    """
    x_l, x_u = z.lower, z.upper

    # ---- encoder ----
    # Max-pooling is applied to each bound separately; it is monotone, so the
    # pooled lower/upper tensors are still valid bounds.
    x1_l, x1_u = interval_forward_attunet(x_l, x_u, model.Conv1)
    print("[w] after Conv1 :", _span_mean_max(x1_l, x1_u))
    x2_l = F.max_pool2d(x1_l, 2, 2); x2_u = F.max_pool2d(x1_u, 2, 2)
    x2_l, x2_u = interval_forward_attunet(x2_l, x2_u, model.Conv2)
    x3_l = F.max_pool2d(x2_l, 2, 2); x3_u = F.max_pool2d(x2_u, 2, 2)
    x3_l, x3_u = interval_forward_attunet(x3_l, x3_u, model.Conv3)
    print("[w] after Conv3 :", _span_mean_max(x3_l, x3_u))
    x4_l = F.max_pool2d(x3_l, 2, 2); x4_u = F.max_pool2d(x3_u, 2, 2)
    x4_l, x4_u = interval_forward_attunet(x4_l, x4_u, model.Conv4)
    x5_l = F.max_pool2d(x4_l, 2, 2); x5_u = F.max_pool2d(x4_u, 2, 2)
    x5_l, x5_u = interval_forward_attunet(x5_l, x5_u, model.Conv5)
    print("[w] after Conv5 :", _span_mean_max(x5_l, x5_u))

    # ---- decoder + attention ----
    # Each stage: upsample block -> attention gate on the skip -> concat
    # (gated skip first, decoder features second) -> conv block.
    d5_l, d5_u = interval_forward_attunet(x5_l, x5_u, model.Up5)
    x4_l_att, x4_u_att = interval_forward_attention_gate(model.Att5, d5_l, d5_u, x4_l, x4_u)
    d5_l, d5_u = torch.cat((x4_l_att, d5_l), 1), torch.cat((x4_u_att, d5_u), 1)
    d5_l, d5_u = interval_forward_attunet(d5_l, d5_u, model.Up_conv5)

    d4_l, d4_u = interval_forward_attunet(d5_l, d5_u, model.Up4)
    x3_l_att, x3_u_att = interval_forward_attention_gate(model.Att4, d4_l, d4_u, x3_l, x3_u)
    d4_l, d4_u = torch.cat((x3_l_att, d4_l), 1), torch.cat((x3_u_att, d4_u), 1)
    d4_l, d4_u = interval_forward_attunet(d4_l, d4_u, model.Up_conv4)

    d3_l, d3_u = interval_forward_attunet(d4_l, d4_u, model.Up3)
    x2_l_att, x2_u_att = interval_forward_attention_gate(model.Att3, d3_l, d3_u, x2_l, x2_u)
    d3_l, d3_u = torch.cat((x2_l_att, d3_l), 1), torch.cat((x2_u_att, d3_u), 1)
    d3_l, d3_u = interval_forward_attunet(d3_l, d3_u, model.Up_conv3)
    print("[w] after Up_conv3:", _span_mean_max(d3_l, d3_u))

    d2_l, d2_u = interval_forward_attunet(d3_l, d3_u, model.Up2)
    x1_l_att, x1_u_att = interval_forward_attention_gate(model.Att2, d2_l, d2_u, x1_l, x1_u)
    d2_l, d2_u = torch.cat((x1_l_att, d2_l), 1), torch.cat((x1_u_att, d2_u), 1)

    # === tap tail input for CROWN ===
    # The stored tail block is Up_conv2, so the stored bounds must be that
    # block's actual input: the concatenation built just above. The Up_conv3
    # output (d3_*) still has to pass through Up2 and Att2 before reaching
    # Up_conv2, so it is not a valid starting point for this tail.
    if 'CROWN_TAIL' in globals():
        CROWN_TAIL['l0']    = d2_l.detach().clone()    # bounds entering Up_conv2
        CROWN_TAIL['u0']    = d2_u.detach().clone()
        CROWN_TAIL['block'] = model.Up_conv2           # last conv block
        CROWN_TAIL['final'] = _resolve_final_1x1(model)
        if not CROWN_TAIL.get('_att_printed', False):
            # Announce the tap only once per CROWN_TAIL dict.
            print("[tail] AttU-Net tap @ Up_conv2 + final 1×1 set.")
            CROWN_TAIL['_att_printed'] = True
    # === end tap ===

    d2_l, d2_u = interval_forward_attunet(d2_l, d2_u, model.Up_conv2)
    print("[w] after Up_conv2:", _span_mean_max(d2_l, d2_u))
    return d2_l, d2_u  # <--- pre-logits (before model.Conv_1x1)


# --- Drop-in patch for "final logits layer" so affine_bounds_conv2d accepts containers ---
import torch.nn as nn

def _unwrap_to_last_conv2d(mod):
    """Resolve ``mod`` to a single ``nn.Conv2d``.

    Accepted inputs:
      * an ``nn.Conv2d`` -> returned unchanged;
      * a 2-tuple ``(W, b)`` -> a fresh 1x1 ``nn.Conv2d`` whose weight (and
        bias, when ``b`` is not None) are copied from the tuple;
      * anything exposing ``named_modules()`` -> the last ``nn.Conv2d`` that
        ``named_modules()`` yields.

    Raises:
        AssertionError: if no ``nn.Conv2d`` is found inside ``mod``.
    """
    if isinstance(mod, nn.Conv2d):
        return mod

    if isinstance(mod, tuple) and len(mod) == 2:
        weight, bias = mod
        has_bias = bias is not None
        surrogate = nn.Conv2d(weight.shape[1], weight.shape[0], kernel_size=1, bias=has_bias)
        with torch.no_grad():
            surrogate.weight.copy_(weight)
            if has_bias:
                surrogate.bias.copy_(bias)
        return surrogate

    # named_modules() also yields ``mod`` itself; keep the final Conv2d seen.
    convs = [sub for _name, sub in mod.named_modules() if isinstance(sub, nn.Conv2d)]
    last = convs[-1] if convs else None
    assert last is not None, f"No nn.Conv2d found inside {type(mod)}"
    return last

# keep original around
# Remember whatever ``affine_bounds_conv2d`` is bound to right now. The wrapper
# below resolves this name at call time, so it must keep pointing at a function
# that does not call back into the wrapper.
_affine_bounds_conv2d_orig = affine_bounds_conv2d

def affine_bounds_conv2d(feat_l, feat_u, conv_like):
    """Unwrap ``conv_like`` to a Conv2d, check it is a plain 1x1 conv, then delegate.

    Args:
        feat_l, feat_u: bounds forwarded unchanged to the captured original.
        conv_like: anything ``_unwrap_to_last_conv2d`` can resolve.

    Returns:
        Whatever ``_affine_bounds_conv2d_orig`` returns for the unwrapped conv.

    Raises:
        AssertionError: if the conv is not kernel 1x1, stride 1, padding 0,
            groups 1.
    """
    conv = _unwrap_to_last_conv2d(conv_like)

    def _pair(value):
        # Scalars become (v, v); tuples pass through untouched.
        return value if isinstance(value, tuple) else (value, value)

    k = _pair(conv.kernel_size)
    s = _pair(conv.stride)
    p = _pair(conv.padding)
    g = conv.groups

    # sanity: the last affine layer is expected to be 1x1, stride 1, unpadded, ungrouped
    is_plain_1x1 = k == (1, 1) and s == (1, 1) and p == (0, 0) and g == 1
    assert is_plain_1x1, \
        f"Expected final 1x1 conv; got k={k}, s={s}, p={p}, groups={g}"

    return _affine_bounds_conv2d_orig(feat_l, feat_u, conv)

# --- Drop-in shim: make both affine_*bounds* funcs accept OutConv/Sequential/(W,b) ---
import torch
import torch.nn as nn

def _unwrap_to_last_conv2d(mod):
    """Resolve ``mod`` to a single ``nn.Conv2d``.

    Resolution order:
      1. ``mod`` is already an ``nn.Conv2d`` -> returned unchanged.
      2. ``mod.conv`` exists and is an ``nn.Conv2d`` -> ``mod.conv``.
      3. ``mod`` is a 2-tuple ``(W, b)`` whose first item is a tensor -> a new
         1x1 ``nn.Conv2d`` carrying copies of ``W`` (and ``b`` when it is not
         None). The new layer is placed on ``W``'s device and, when ``W`` is
         floating point, uses ``W``'s dtype, so it can be applied to features
         that live next to ``W`` without a device/dtype mismatch.
      4. Otherwise the last ``nn.Conv2d`` yielded by ``mod.named_modules()``.

    Raises:
        AssertionError: if step 4 finds no ``nn.Conv2d`` inside ``mod``.
    """
    # Already a Conv2d?
    if isinstance(mod, nn.Conv2d):
        return mod

    # Wrapper whose .conv attribute is the actual conv
    if hasattr(mod, "conv") and isinstance(mod.conv, nn.Conv2d):
        return mod.conv

    # Passed as raw (weight, bias) tuple?
    if isinstance(mod, tuple) and len(mod) == 2 and torch.is_tensor(mod[0]):
        W, b = mod
        fake = nn.Conv2d(W.shape[1], W.shape[0], kernel_size=1, bias=b is not None)
        # Module.to() only accepts floating/complex dtypes; keep the default
        # parameter dtype for any other kind of weight tensor.
        target_dtype = W.dtype if W.is_floating_point() else fake.weight.dtype
        fake = fake.to(device=W.device, dtype=target_dtype)
        with torch.no_grad():
            fake.weight.copy_(W)
            if b is not None:
                fake.bias.copy_(b)
        return fake

    # Otherwise: keep the last Conv2d that named_modules() yields
    last = None
    for _, m in mod.named_modules():
        if isinstance(m, nn.Conv2d):
            last = m
    if last is None:
        # Raised explicitly (same exception type as before) so the check is
        # not stripped when Python runs with -O.
        raise AssertionError(f"No nn.Conv2d found inside {type(mod)}")
    return last

# Keep originals
# Capture each original only ONCE. The wrappers in this file look up
# ``_affine_bounds_conv2d_orig`` / ``_affine_margin_bounds_conv2d_orig`` at call
# time, so overwriting an already-captured original with the current (already
# wrapped) function would make a wrapper call itself and recurse without bound.
if "_affine_bounds_conv2d_orig" not in globals():
    try:
        _affine_bounds_conv2d_orig = affine_bounds_conv2d
    except NameError:
        # No affine_bounds_conv2d defined yet; nothing to capture.
        pass

if "_affine_margin_bounds_conv2d_orig" not in globals():
    try:
        _affine_margin_bounds_conv2d_orig = affine_margin_bounds_conv2d
    except NameError:
        # No affine_margin_bounds_conv2d defined yet; nothing to capture.
        pass

# Wrap: affine_bounds_conv2d
def affine_bounds_conv2d(feat_l, feat_u, conv_like):
    """Validate the final layer as a plain 1x1 conv and delegate to the captured original.

    ``conv_like`` is first resolved with ``_unwrap_to_last_conv2d``. The
    resulting conv must have kernel (1, 1), stride (1, 1), padding (0, 0) and
    groups 1, otherwise an ``AssertionError`` is raised. Returns whatever
    ``_affine_bounds_conv2d_orig(feat_l, feat_u, conv)`` returns.
    """
    conv = _unwrap_to_last_conv2d(conv_like)

    # Normalise kernel/stride/padding to 2-tuples (scalars are duplicated).
    k, s, p = (
        value if isinstance(value, tuple) else (value, value)
        for value in (conv.kernel_size, conv.stride, conv.padding)
    )
    g = conv.groups

    shape_ok = k == (1, 1) and s == (1, 1) and p == (0, 0)
    assert shape_ok and g == 1, \
        f"Expected final 1x1 conv; got k={k}, s={s}, p={p}, groups={g}"

    return _affine_bounds_conv2d_orig(feat_l, feat_u, conv)

# Wrap: affine_margin_bounds_conv2d
def affine_margin_bounds_conv2d(feat_l, feat_u, conv_like, chg_idx, nchg_idx):
    """Validate the final layer as a plain 1x1 conv and delegate the margin computation.

    ``conv_like`` is resolved with ``_unwrap_to_last_conv2d`` and must be a
    kernel-1, stride-1, unpadded, ungrouped conv (``AssertionError`` otherwise).
    Returns whatever ``_affine_margin_bounds_conv2d_orig`` returns for
    ``(feat_l, feat_u, conv, chg_idx, nchg_idx)``.
    """
    conv = _unwrap_to_last_conv2d(conv_like)

    def _as_pair(value):
        # Tuples pass through; a scalar v becomes (v, v).
        if isinstance(value, tuple):
            return value
        return (value, value)

    k = _as_pair(conv.kernel_size)
    s = _as_pair(conv.stride)
    p = _as_pair(conv.padding)
    g = conv.groups

    looks_like_1x1 = k == (1, 1) and s == (1, 1) and p == (0, 0) and g == 1
    assert looks_like_1x1, \
        f"Expected final 1x1 conv; got k={k}, s={s}, p={p}, groups={g}"

    return _affine_margin_bounds_conv2d_orig(feat_l, feat_u, conv, chg_idx, nchg_idx)


# CERTIFICATION - structural predicates & plots 
# ==================================================
# 1. CERTIFICATION DIAGNOSTICS FUNCTION
# ==================================================
def save_maps(lower, upper, model_logits, out_prefix="diagnostic", channel=0):
    """Save lower/upper/midpoint/prediction maps as ``.npy`` files plus a 3-panel PNG.

    Args:
        lower, upper: bound tensors; they are squeezed, detached and moved to
            CPU before use.
        model_logits: tensor of model outputs, treated the same way.
        out_prefix: prefix for the written files
            (``_lower.npy``, ``_upper.npy``, ``_mid.npy``, ``_pred.npy``,
            ``_maps.png``).
        channel: index of the channel to keep when a squeezed array is still
            3-D. Defaults to 0, which is what this function always used; pass
            the change-channel index to plot that channel instead.

    The figure shows lower, midpoint and prediction (the upper map is saved to
    disk but not plotted).
    """
    l = lower.squeeze().detach().cpu().numpy()
    u = upper.squeeze().detach().cpu().numpy()
    pred = model_logits.squeeze().detach().cpu().numpy()

    def _select(arr):
        # Each array is checked on its own rank, so a 2-D prediction next to
        # 3-D bounds is no longer sliced down to a single row.
        return arr[channel] if arr.ndim == 3 else arr

    l, u, pred = _select(l), _select(u), _select(pred)
    m = (l + u) / 2

    np.save(f"{out_prefix}_lower.npy", l)
    np.save(f"{out_prefix}_upper.npy", u)
    np.save(f"{out_prefix}_mid.npy", m)
    np.save(f"{out_prefix}_pred.npy", pred)

    plt.figure(figsize=(12,4))
    plt.subplot(1,3,1); plt.title("lower"); plt.imshow(l, cmap='viridis'); plt.colorbar()
    plt.subplot(1,3,2); plt.title("mid");   plt.imshow(m, cmap='viridis'); plt.colorbar()
    plt.subplot(1,3,3); plt.title("pred");  plt.imshow(pred, cmap='viridis'); plt.colorbar()
    plt.tight_layout()
    plt.savefig(f"{out_prefix}_maps.png")
    print("Saved maps to", f"{out_prefix}_maps.png")

def margins_and_hist(lower, tau=0.9, out_prefix="margin"):
    """Plot ``lower - tau`` as an image and as a histogram, saved to ``<out_prefix>.png``.

    ``lower`` is squeezed, detached and moved to CPU; if the result is still
    3-D only its first channel is used.
    """
    lower_np = lower.squeeze().detach().cpu().numpy()
    if lower_np.ndim == 3:
        lower_np = lower_np[0]
    margin = lower_np - tau

    plt.figure(figsize=(10,4))

    # Left panel: margin map on a fixed +/-0.2 colour scale.
    plt.subplot(1,2,1)
    plt.imshow(margin, cmap='coolwarm', vmin=-0.2, vmax=0.2)
    plt.colorbar()
    plt.title('lower - tau')

    # Right panel: distribution of margin values.
    plt.subplot(1,2,2)
    plt.hist(margin.ravel(), bins=100)
    plt.title('margin histogram')

    plt.savefig(out_prefix + ".png")

def run_certification_diagnostics(city, model, mapsave=False, model_type=0, eps=8/255., tau=0.9, k=30, s_min=20, out_prefix="diag", gt_mask=None):
    """Bound the model's logits for one OSCD city under an eps-perturbation and print diagnostics.

    Steps performed:
      1. Load both timestamps of ``city`` from disk (paths built from the
         module-level ``OSCD_PATH_*`` strings), stack them along the channel
         axis and zero-pad H/W to a multiple of 16.
      2. Normalise each timestamp with its own scalar mean/std, clamp the
         result to ``[-K, K]`` (K = 4.0) and build a box of half-width
         ``max(eps / sig1, eps / sig2)`` around it, clamped to the same range.
      3. Propagate the box to pre-logits with the ``propagate_prelogits_*``
         function selected by ``model_type`` and bound the final layer with
         ``affine_bounds_conv2d`` / ``affine_margin_bounds_conv2d``
         (the legacy branch is disabled by the local flag).
      4. Run the model on the clean normalised input, optionally save plots,
         and print range/tightness diagnostics.

    Args:
        city: OSCD city folder name.
        model: network to analyse; it is switched to ``eval()`` mode here.
        mapsave: when True, call ``save_maps`` / ``margins_and_hist`` and, if
            ``gt_mask`` is given, save a plot of it.
        model_type: 0 -> ``propagate_prelogits_encdecnet`` + ``model.outc``,
            1 -> ``propagate_prelogits_falconetmha`` + ``model.outc``,
            anything else -> ``propagate_prelogits_attunet`` + ``model.Conv_1x1``.
        eps: perturbation radius in raw (un-normalised) units.
        tau: only forwarded to ``margins_and_hist``.
        k, s_min: accepted but not used inside this function.
        out_prefix: file prefix for saved diagnostics.
        gt_mask: optional array; only plotted when ``mapsave`` is True.

    Returns:
        ``(lower, upper, logits)``. ``lower``/``upper`` are the logit bounds,
        replaced by clones of ``logits`` when ``eps == 0``.

    NOTE(review): relies on ``rasterio``, ``Resampling``, ``BANDS`` and ``plt``
    being available at module scope; they are not defined in this function.
    """
    def resample_band_to_shape(band, src_transform, src_crs, target_shape, target_transform, target_crs):
        # Bilinearly reproject ``band`` into a new float32 array of ``target_shape``.
        dst = np.zeros(target_shape, dtype=np.float32)
        rasterio.warp.reproject(
            source=band,
            destination=dst,
            src_transform=src_transform,
            src_crs=src_crs,
            dst_transform=target_transform,
            dst_crs=target_crs,
            resampling=Resampling.bilinear
        )
        return dst

    def load_oscd_patch(imgs_dir):
        # Read every band listed in BANDS from ``imgs_dir`` and stack them as
        # (H, W, n_bands). The first file whose name contains "B02" supplies
        # the target transform and CRS.
        bands = []
        reference_path = [f for f in os.listdir(imgs_dir) if "B02" in f][0]
        reference_path = os.path.join(imgs_dir, reference_path)

        with rasterio.open(reference_path) as ref:
            # Output grid is fixed at 256x256.
            # NOTE(review): the reference transform is reused unchanged for this
            # fixed shape - confirm that is the intended georeferencing.
            target_shape = (256,256)
            target_transform = ref.transform
            target_crs = ref.crs

        for band_name in BANDS:
            # First file ending in "<band>.tif"; raises IndexError if none exists.
            candidates = [f for f in os.listdir(imgs_dir) if f.endswith(f"{band_name}.tif")]
            band_path = os.path.join(imgs_dir, candidates[0])
            with rasterio.open(band_path) as src:
                # Scale raw values by 1/10000.
                band = src.read(1).astype(np.float32) / 10000.0
                # Bands whose raster shape differs from the target are resampled.
                if src.shape != target_shape:
                    band = resample_band_to_shape(band, src.transform, src.crs,
                                                  target_shape, target_transform, target_crs)
                bands.append(band)
        return np.stack(bands, axis=-1)

    def pad_to_multiple(x, multiple=16):
        # Zero-pad the last two dims on the right/bottom up to the next
        # multiple; returns the padded tensor and the pad tuple used.
        _, _, h, w = x.shape
        pad_h = (multiple - h % multiple) % multiple
        pad_w = (multiple - w % multiple) % multiple
        return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h)), (0, pad_w, 0, pad_h)
    
    # --- RAW tensors (no normalization here) ---
    img1 = load_oscd_patch(OSCD_PATH_TOP + city + OSCD_PATH_BOTTOM1)  # H,W,13  RAW
    img2 = load_oscd_patch(OSCD_PATH_TOP + city + OSCD_PATH_BOTTOM2)  # H,W,13  RAW
    x_np = np.concatenate([img1, img2], axis=-1)                         # H,W,26  RAW
    # Number of bands per timestamp; used below to split the channel axis.
    C1 = img1.shape[2]  # 13 bands per timestamp

    # Channel-first batch of one, padded so H and W are multiples of 16.
    # NOTE(review): ``pad`` is not used afterwards, so outputs keep the padded size.
    x_raw = torch.tensor(x_np.transpose(2, 0, 1)).unsqueeze(0).float()   # (1,26,H,W) RAW
    x_raw, pad = pad_to_multiple(x_raw, multiple=16)
    # Per-timestamp scalar statistics, computed on the un-padded images.
    mu1  = float(img1.mean());  sig1 = float(img1.std() + 1e-6)
    mu2  = float(img2.mean());  sig2 = float(img2.std() + 1e-6)
    

    
    # z-score each timestamp with its own scalar mean/std, then re-concatenate.
    x1 = (x_raw[:, :C1] - mu1) / sig1    # (1, 13, H, W)
    x2 = (x_raw[:, C1:] - mu2) / sig2    # (1, 13, H, W)
    x  = torch.cat([x1, x2], dim=1)      # (1, 26, H, W)  <-- this is what the model sees
    
    # per-half epsilon scaling (raw eps divided by each timestamp's std)
    eps1 = float(eps / sig1)
    eps2 = float(eps / sig2)
    # NOTE(review): eps_vec is built but not used below; the box uses the
    # scalar max(eps1, eps2) instead.
    eps_vec = torch.tensor([eps1]*C1 + [eps2]*C1, dtype=x.dtype, device=x.device).view(1, 2*C1, 1, 1)
    
    # # Zonotope centered at the normalized x
    # #zonotope = Zonotope(center=x, epsilon=eps_vec)
    # zonotope = Zonotope(center=x, epsilon=float(max(eps1, eps2)))  # scalar ε

    # zl, zu = zonotope.lower, zonotope.upper
    # print("zono lower/upper:", float(zl.min()), float(zl.max()),
    #       float(zu.min()), float(zu.max()))

    # Replace Zonotope(...) with a dead-simple interval container
    #K = 8.0  # safe z-score clamp (tweakable: 6.0–8.0 are typical)
    K = 4.0   # → tighten
    # The clean input itself is clamped, so the model below sees clamped values.
    x = x.clamp(-K, K)
    
    # Box of half-width max(eps1, eps2) around x, clamped to the same range.
    eps_scalar = float(max(eps1, eps2))
    zl = (x - eps_scalar).clamp(-K, K)
    zu = (x + eps_scalar).clamp(-K, K)
    
    class _SimpleZ:
        # Minimal container exposing .lower / .upper / .center, which is the
        # interface the propagate_* functions read from.
        def __init__(self, l, u):
            self.lower = l; self.upper = u; self.center = 0.5*(l+u)
    
    zonotope = _SimpleZ(zl, zu)

    print("zono lower/upper:", 
          float(zl.min()), float(zl.max()),
          float(zu.min()), float(zu.max()))
    
    # # Bound propagation WITH THE ORIGINAL MODEL (not a wrapper)
    # if model_type == 0:
    #     lower, upper = propagate_bounds_encdecnet(model, zonotope)
    # elif model_type == 1:
    #     lower, upper = propagate_bounds_falconetmha_lirpa(model, zonotope)
    # else:
    #     lower, upper = propagate_bounds_attunet_lirpa(model, zonotope)
    # pre-logits
    # Hard-wired switch: True selects the pre-logit + affine-last-layer path.
    USE_AFFINE_LAST_LAYER = True  # <-- flip this flag on

    # --- choose propagation path ---
    if USE_AFFINE_LAST_LAYER:
        # Bound everything up to the final layer, then treat that layer separately.
        if model_type == 0:
            feat_l, feat_u = propagate_prelogits_encdecnet(model, zonotope)
            final = model.outc
        elif model_type == 1:
            feat_l, feat_u = propagate_prelogits_falconetmha(model, zonotope)
            final = model.outc
        else:
            feat_l, feat_u = propagate_prelogits_attunet(model, zonotope)
            final = model.Conv_1x1
    
        # optional: check pre-logit width to confirm we're here
        pre_w_mean = (feat_u - feat_l).abs().mean().item()
        print("[path] affine-last-layer ON  | prelogit width mean:", f"{pre_w_mean:.4f}")
    
        # Logit bounds and change-vs-no-change margin bounds from the final layer.
        # Channel 1 is taken as "change", channel 0 as "no change".
        lower, upper = affine_bounds_conv2d(feat_l, feat_u, final)
        chg_idx, nchg_idx = 1, 0
        margin_lb, margin_ub = affine_margin_bounds_conv2d(feat_l, feat_u, final, chg_idx, nchg_idx)
    
    else:
        # (legacy full IBP path — will blow up at eps>0)
        # NOTE(review): ``final`` is assigned here but not used on this branch.
        if model_type == 0:
            lower, upper = propagate_bounds_encdecnet(model, zonotope)
            final = model.outc
        elif model_type == 1:
            lower, upper = propagate_bounds_falconetmha_lirpa(model, zonotope)
            final = model.outc
        else:
            lower, upper = propagate_bounds_attunet_lirpa(model, zonotope)
            final = model.Conv_1x1
        # very loose margin from loose logits:
        chg_idx, nchg_idx = 1, 0
        margin_lb = lower[:, chg_idx] - upper[:, nchg_idx]
        margin_ub = upper[:, chg_idx] - lower[:, nchg_idx]
    
    # --- certified set from the margin lower bound ---
    # NOTE(review): mask_cert (and margin_ub) are computed but neither used
    # nor returned by this function.
    mask_cert = (margin_lb > 0)
    
    # Clean logits through the same normalized x
    # NOTE(review): eval() is only called here, after bound propagation -
    # confirm the propagate_* helpers do not depend on the module's train/eval mode.
    model.eval()
    with torch.no_grad():
        logits = model(x)

    if mapsave:
        # NOTE(review): maps are written before the eps == 0 override below,
        # so they show the propagated bounds even when eps is 0.
        save_maps(lower, upper, logits, out_prefix)
        margins_and_hist(lower, tau, out_prefix + "_margin")

        if gt_mask is not None:
            plt.figure(figsize=(4,4))
            plt.imshow(gt_mask.astype(np.float32), cmap="gray")
            plt.title("GT Change Mask")
            plt.savefig(out_prefix + "_gtmask.png")

    # If eps == 0, return the clean logits as both bounds (zero-width interval).
    if float(eps) == 0.0:
        lower = logits.clone()
        upper = logits.clone()

    # scale & range sanity
    print("σ1, σ2, eps1, eps2:", sig1, sig2, float(eps/sig1), float(eps/sig2))
    print("ranges:",
          "lower", float(lower.min()), float(lower.max()),
          "upper", float(upper.min()), float(upper.max()),
          "logits", float(logits.min()), float(logits.max()))

    # for check only
    # NOTE(review): this compares ``lower``/``upper`` (logit bounds) against
    # softmax probabilities, i.e. quantities on different scales.
    logits_prob = torch.softmax(logits, dim=1)  # or sigmoid if that's your head
    print("tightness (prob space):",
          torch.abs(lower - logits_prob).mean().item(),
          torch.abs(upper - logits_prob).mean().item())

    return lower, upper, logits


# ==================================================
# 2. PREDICATE VERIFICATION (CURRENT OR UPDATED VERSION)
# ==================================================
# Verification predicates - non-semantic
def predicate_count_based(pred_map, tau=0.5, k=20):
    """Return whether at least ``k`` entries of ``pred_map`` are ``>= tau``."""
    n_above = (pred_map >= tau).sum()
    return n_above >= k

def predicate_connected_components(pred_map, tau=0.5, min_area=20):
    """Return True if some connected component of ``pred_map >= tau`` has at least ``min_area`` pixels.

    Components are found with ``ndlabel`` on the thresholded uint8 mask.
    """
    mask = (pred_map >= tau).astype(np.uint8)
    components, n_components = ndlabel(mask)
    # Labels start at 1; stop at the first component that is large enough.
    return any(
        (components == label_id).sum() >= min_area
        for label_id in range(1, n_components + 1)
    )

def predicate_hybrid(pred_map, tau=0.5, k=20, min_area=10):
    """Combine the count and component tests on ``pred_map >= tau``.

    Returns False when fewer than ``k`` pixels pass the threshold; otherwise
    returns True iff some connected component (per ``ndlabel``) has at least
    ``min_area`` pixels.
    """
    mask = (pred_map >= tau).astype(np.uint8)

    # Global count gate first.
    if mask.sum() < k:
        return False

    components, n_components = ndlabel(mask)
    return any(
        (components == label_id).sum() >= min_area
        for label_id in range(1, n_components + 1)
    )

def verify_predicates_diagnostic(lower, upper, tau, k, s_min):
    """Evaluate the three non-semantic predicates on the lower-bound map.

    Only ``lower`` is used (squeezed, detached, moved to CPU); ``upper`` is
    accepted but not read.

    Returns:
        ``(global_ok, island_ok, hybrid_ok)`` from ``predicate_count_based``,
        ``predicate_connected_components`` and ``predicate_hybrid``.
    """
    lower_map = lower.squeeze().detach().cpu().numpy()
    return (
        predicate_count_based(lower_map, tau, k),
        predicate_connected_components(lower_map, tau, s_min),
        predicate_hybrid(lower_map, tau, k, s_min),
    )


# ==================================================
# 3. AGGREGATION UTILITY
# ==================================================
def aggregate_predicate_results(all_results, csv_path="predicate_summary.csv"):
    """Write ``all_results`` (an iterable of dict rows) to ``csv_path``, overwriting it.

    The file always starts with the fixed header below; each row is written
    with ``csv.DictWriter``, so a row containing a key outside the header
    raises ``ValueError``.
    """
    columns = [
        "Model", "City", "Eps", "Tau", "K", "S_min",
        "Global_OK", "Island_OK", "Hybrid_OK", "Hybrid_Midpoint_OK",
    ]
    with open(csv_path, "w", newline="") as handle:
        out = csv.DictWriter(handle, fieldnames=columns)
        out.writeheader()
        out.writerows(all_results)
    print(f"Saved CSV results to {csv_path}")

# === anon v3 utils (non-destructive; unique names) ===
import numpy as _np
try:
    from skimage.measure import label as _sklabel
except Exception:
    _sklabel = None
try:
    from scipy.ndimage import label as _ndlabel
except Exception:
    _ndlabel = None

def _to_numpy(x):
    """Convert ``x`` to a NumPy array.

    Torch tensors are detached and moved to CPU first. If torch cannot be
    imported, or the tensor conversion raises, ``x`` is handed to
    ``numpy.asarray`` instead.
    """
    converted = None
    try:
        import torch
        if isinstance(x, torch.Tensor):
            converted = x.detach().cpu().numpy()
    except Exception:
        # Best effort: fall through to the generic conversion.
        converted = None
    if converted is not None:
        return converted
    return _np.asarray(x)

# === anon v3: margin-based certification for change detection ===
def anon_choose_change_channel(clean_logits, gt_mask):
    """Pick the ``(change, no_change)`` channel pair whose prediction best matches ``gt_mask``.

    A softmax over axis 0 of ``clean_logits`` (expected shape ``(C, H, W)``) is
    thresholded at 0.5 for each candidate's first channel and scored by IoU
    against ``gt_mask > 0``. For ``C == 2`` the candidates are ``(0, 1)`` and
    ``(1, 0)``; otherwise every pair starting with the channel of highest mean
    probability.

    Raises:
        ValueError: if ``clean_logits`` is 2-D (single-channel head).
    """
    logits = _to_numpy(clean_logits)
    if logits.ndim == 2:
        raise ValueError("Single-channel head detected. Prefer two-logit head for sound margins.")
    n_channels, _height, _width = logits.shape

    # Numerically stable softmax over the channel axis.
    shifted = _np.exp(logits - logits.max(axis=0, keepdims=True))
    softmax = shifted / shifted.sum(axis=0, keepdims=True)
    gt_bin = (_to_numpy(gt_mask) > 0).astype('uint8')

    def _iou_of(pair):
        pred = (softmax[pair[0]]>0.5).astype('uint8')
        intersection = int((pred & gt_bin).sum())
        union = int((pred | gt_bin).sum())
        return intersection / max(union, 1)

    if n_channels == 2:
        candidates = [(0,1), (1,0)]
    else:
        dominant = int(_np.argmax(softmax.mean((1,2))))
        candidates = [(dominant, other) for other in range(n_channels) if other != dominant]

    return max(candidates, key=_iou_of)

def anon_certified_change_mask(lower, upper, chg_idx=0, nchg_idx=1):
    """Return the pixels where the change logit provably exceeds the no-change logit.

    The margin lower bound is ``lower[chg_idx] - upper[nchg_idx]``; a pixel is
    certified when that value is strictly positive.

    Returns:
        ``(mask, margin_lb)`` with ``mask`` a uint8 array.

    Raises:
        ValueError: if either bound is not 3-D.
    """
    l, u = _to_numpy(lower), _to_numpy(upper)
    if not (l.ndim == 3 and u.ndim == 3):
        raise ValueError(f"expected 3D bounds (C,H,W), got {l.shape} and {u.shape}")
    margin_lb = l[chg_idx] - u[nchg_idx]
    return (margin_lb > 0).astype('uint8'), margin_lb

# === anon v3: predicates (keys match your summary schema) ===
import json as _json

def anon_predicate_overlap(mask_cert, clean_pred, rho=0.3):
    """Check that the certified mask covers at least a fraction ``rho`` of the clean prediction.

    Returns:
        ``(ok, ratio)`` where ``ratio`` is ``|mask_cert & clean| / max(|clean|, 1)``.
    """
    clean_bin = (_to_numpy(clean_pred)>0).astype('uint8')
    n_clean = max(int(clean_bin.sum()), 1)   # avoid division by zero
    n_common = int((mask_cert & clean_bin).sum())
    ratio = n_common / n_clean
    return (ratio >= rho), float(ratio)

def anon_predicate_fp(mask_cert, gt_mask, gamma=0.3):
    """Check that at most a fraction ``gamma`` of certified pixels fall outside the ground truth.

    Returns:
        ``(ok, fp_ratio)``; an empty certified mask yields ``(True, 0.0)``.
    """
    gt_bin = (_to_numpy(gt_mask)>0).astype('uint8')
    n_certified = int(mask_cert.sum())
    if n_certified == 0:
        return True, 0.0
    n_false_pos = int(((mask_cert==1) & (gt_bin==0)).sum())
    ratio = n_false_pos / float(n_certified)
    return (ratio <= gamma), float(ratio)

def anon_predicate_pattern(mask_cert, s_min=16, connectivity=1):
    """Check that every connected component of ``mask_cert`` has at least ``s_min`` pixels.

    Components are labelled with skimage when ``_sklabel`` is available and
    with SciPy (``_ndlabel``) otherwise. ``connectivity`` is honoured on both
    paths: for SciPy it is translated into the matching structuring element,
    so the two backends agree (previously the SciPy path always used
    1-connectivity).

    Args:
        mask_cert: binary array of certified pixels.
        s_min: minimum allowed component size.
        connectivity: neighbourhood order, 1 .. ``mask_cert.ndim``; ``None``
            means full connectivity, as in skimage.

    Returns:
        ``(ok, stats)`` with ``stats = {"num": <components>, "sizes": [...]}``.
        An empty mask yields ``(True, {"num": 0, "sizes": []})``.

    Raises:
        RuntimeError: if neither labelling backend is available.
    """
    if int(mask_cert.sum()) == 0:
        return True, {"num": 0, "sizes": []}

    if _sklabel is not None:
        lbl, num = _sklabel(mask_cert, connectivity=connectivity, return_num=True)
    elif _ndlabel is not None:
        # Imported here because SciPy is an optional dependency of this module.
        from scipy.ndimage import generate_binary_structure
        order = mask_cert.ndim if connectivity is None else connectivity
        lbl, num = _ndlabel(mask_cert, structure=generate_binary_structure(mask_cert.ndim, order))
    else:
        raise RuntimeError("No connected-components function available (skimage/scipy missing)")

    # One pass over the label image; index 0 is background and is dropped.
    counts = _np.bincount(_np.asarray(lbl).ravel(), minlength=int(num) + 1)
    sizes = [int(c) for c in counts[1:]]
    ok = all(sz >= s_min for sz in sizes)
    return ok, {"num": int(num), "sizes": sizes}

def anon_verify_predicates_semantic(lower, upper, clean_logits, clean_pred, gt_mask,
                                    rho=0.3, gamma=0.3, s_min=16,
                                    chg_idx=None, nchg_idx=None):
    """Run the overlap, false-positive and pattern predicates on the certified change mask.

    When either channel index is missing, both are chosen with
    ``anon_choose_change_channel``. The certified mask comes from
    ``anon_certified_change_mask``.

    Returns:
        ``(res, mask_cert, margin_lb)``. ``res`` holds every metric twice,
        once under a lower-case key and once under a capitalised key, plus the
        channel indices. ``Certified_strict`` is 1 only when all three
        predicates pass.
    """
    if chg_idx is None or nchg_idx is None:
        chg_idx, nchg_idx = anon_choose_change_channel(clean_logits, gt_mask)
    mask_cert, margin_lb = anon_certified_change_mask(lower, upper, chg_idx, nchg_idx)

    overlap_pass, overlap_val = anon_predicate_overlap(mask_cert, clean_pred, rho=rho)
    fp_pass, fp_val = anon_predicate_fp(mask_cert, gt_mask, gamma=gamma)
    pattern_pass, pattern_info = anon_predicate_pattern(mask_cert, s_min=s_min, connectivity=1)

    # Values shared by the lower-case and capitalised key sets.
    strict_flag = int(bool(overlap_pass and fp_pass and pattern_pass))
    overlap_flag, overlap_ratio = int(overlap_pass), float(overlap_val)
    fp_flag, fp_ratio = int(fp_pass), float(fp_val)
    pattern_flag = int(pattern_pass)
    pattern_json = _json.dumps(pattern_info, separators=(',',':'))
    smallest_margin = float(margin_lb.min()) if margin_lb.size>0 else 0.0

    res = {
        "Certified_strict": strict_flag,
        "overlap_ok": overlap_flag, "overlap_ratio": overlap_ratio,
        "fp_ok": fp_flag, "fp_ratio": fp_ratio,
        "min_margin": smallest_margin,
        "pattern_ok": pattern_flag, "pattern_stats": pattern_json,
        "certified_strict": strict_flag,
        "Overlap_ok": overlap_flag, "Overlap_ratio": overlap_ratio,
        "Fp_ok": fp_flag, "Fp_ratio": fp_ratio,
        "Min_margin": smallest_margin,
        "Pattern_ok": pattern_flag, "Pattern_stats": pattern_json,
        "chg_idx": int(chg_idx), "nchg_idx": int(nchg_idx),
    }
    return res, mask_cert, margin_lb

# === anon v3: CSV writer aligned to your schema ===
import csv, os

def anon_write_row(csv_path, row, header=None):
    """Append one dict row to ``csv_path``, writing the header first when needed.

    The header is written when the file does not exist yet or exists but is
    empty (zero bytes), so a pre-created empty file still ends up with a
    header line.

    Args:
        csv_path: destination CSV file (opened in append mode).
        row: mapping of column name -> value; keys outside ``header`` make
            ``csv.DictWriter`` raise ``ValueError``.
        header: column names; falls back to the default schema below when
            omitted or empty.
    """
    default_header = [
        "Model","City","Eps","Tau","K","S_min",
        "Certified_strict",
        "Overlap_ok","Overlap_ratio",
        "Fp_ok","Fp_ratio",
        "Min_margin",
        "Pattern_ok","Pattern_stats",
        "chg_idx","nchg_idx"
    ]
    header = header or default_header
    # Decide before opening: append mode creates the file if it is missing.
    needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, "a", newline="") as f:
        w = csv.DictWriter(f, fieldnames=header)
        if needs_header:
            w.writeheader()
        w.writerow(row)

# channel choice 
# 1) helper: to numpy
def _to_numpy(x):
    """Return ``x`` as a NumPy array.

    A torch tensor is detached and copied to CPU; anything else (or any
    failure while handling torch) goes through ``numpy.asarray``.
    """
    result = None
    try:
        import torch
        if isinstance(x, torch.Tensor):
            result = x.detach().cpu().numpy()
    except Exception:
        # torch missing or conversion failed: use the generic path below.
        result = None
    return _np.asarray(x) if result is None else result
 
def anon_choose_change_channel(clean_logits, gt_mask):
    """Select the ``(change, no_change)`` channel pair with the best IoU against ``gt_mask``.

    ``clean_logits`` must unpack as ``(C, H, W)``. Softmax probabilities over
    axis 0 are thresholded at 0.5 on each candidate's first channel and
    compared with ``gt_mask > 0``.

    Raises:
        ValueError: for a 2-D (single-channel) input.
    """
    logits = _to_numpy(clean_logits)
    if logits.ndim == 2:
        raise ValueError("Single-channel head; two-logit head recommended.")
    n_channels, _h, _w = logits.shape

    # Stable softmax along the channel axis.
    exp_shifted = _np.exp(logits - logits.max(axis=0, keepdims=True))
    softmax = exp_shifted / exp_shifted.sum(axis=0, keepdims=True)
    truth = (_to_numpy(gt_mask) > 0).astype('uint8')

    def _score(pair):
        guess = (softmax[pair[0]] > 0.5).astype('uint8')
        overlap = int((guess & truth).sum())
        combined = int((guess | truth).sum())
        return overlap / max(combined, 1)

    if n_channels == 2:
        pairs = [(0,1), (1,0)]
    else:
        # More than two channels: anchor on the channel with the highest mean probability.
        anchor = int(_np.argmax(softmax.mean((1,2))))
        pairs = [(anchor, other) for other in range(n_channels) if other != anchor]

    return max(pairs, key=_score)
 
# 3) helper: logits -> clean_pred for the chosen change channel
def logits_to_clean_pred(clean_logits, chg_idx, thresh=0.5):
    """Return a uint8 mask of pixels whose softmax probability for ``chg_idx`` exceeds ``thresh``.

    The softmax is taken over axis 0 of ``clean_logits``.
    """
    logits = _to_numpy(clean_logits)
    # Subtract the per-pixel max before exponentiating for numerical stability.
    unnormalised = _np.exp(logits - logits.max(axis=0, keepdims=True))
    change_prob = (unnormalised / unnormalised.sum(axis=0, keepdims=True))[chg_idx]
    return (change_prob > thresh).astype('uint8')
 

 
# 5) NEW wrapper with the SAME name: now returns (lower, upper, clean_logits, clean_pred, gt_mask)
def rcd(model, model_type, city, eps):
    """Run ``run_certification_diagnostics`` and add the clean prediction and GT mask.

    The three tensors returned by the diagnostics call have their leading
    dimension squeezed. The ground-truth mask is read from the module-level
    ``GT_MASKS`` dict, binarised, and resized with nearest-neighbour sampling
    when its shape differs from the logits' spatial size.

    Returns:
        ``(lower, upper, clean_logits, clean_pred, gt_mask)``.

    Raises:
        KeyError: if ``city`` has no entry in ``GT_MASKS``.
    """
    lower, upper, clean_logits = run_certification_diagnostics(
        model=model, model_type=model_type, city=city, eps=eps
    )
    lower = lower.squeeze(0)
    upper = upper.squeeze(0)
    clean_logits = clean_logits.squeeze(0)

    H, W = clean_logits.shape[-2:]

    # Fail clearly when the mask for this city was never registered.
    if city not in GT_MASKS:
        raise KeyError(f"GT_MASKS['{city}'] not set. Please do: GT_MASKS['{city}']=<HxW binary array>")

    gt_mask = np.asarray(GT_MASKS[city])
    if gt_mask.ndim == 3:
        gt_mask = gt_mask.squeeze()

    needs_resize = gt_mask.shape != (H,W)
    gt_mask = (gt_mask > 0).astype('uint8')
    if needs_resize:
        # PIL takes (width, height); NEAREST keeps the mask binary.
        resized = Image.fromarray(gt_mask * 255).resize((W,H), resample=Image.NEAREST)
        gt_mask = (np.array(resized) > 0).astype('uint8')

    # Choose channels from the clean logits and derive the clean prediction.
    chg_idx, nchg_idx = anon_choose_change_channel(clean_logits, gt_mask)
    clean_pred = logits_to_clean_pred(clean_logits, chg_idx, thresh=0.5)

    return lower, upper, clean_logits, clean_pred, gt_mask

# === anon v3: Sanity suite (eps=0 and 2/255) ===
def sanity_suite_one_tile(model, model_type, city, eps_list=(0.0, 2/255), rho=0.2, gamma=0.3, s_min=16,
                          csv_path="predicate_summary_semantic_v3.csv"):
    """Run the semantic predicates on one city for each eps and append the results to a CSV.

    For every ``eps`` in ``eps_list`` this calls ``rcd``, evaluates
    ``anon_verify_predicates_semantic``, prints a one-line summary and appends
    one row via ``anon_write_row``.

    The "Model" column holds the model's class name. ``str(model)`` would put
    the complete multi-line module description into a single CSV cell.

    Args:
        model: network under test (forwarded to ``rcd``).
        model_type: architecture selector forwarded to ``rcd``.
        city: city identifier forwarded to ``rcd``.
        eps_list: perturbation radii to evaluate.
        rho, gamma, s_min: predicate thresholds.
        csv_path: CSV file the rows are appended to.

    Raises:
        ValueError: if ``rcd`` does not return a 5-item list/tuple.
    """
    print(f"[sanity] model={model_type} city={city} eps={eps_list} rho={rho} gamma={gamma} s_min={s_min}")
    model_label = type(model).__name__
    # Result keys copied verbatim into the CSV row.
    result_columns = (
        "Certified_strict",
        "Overlap_ok", "Overlap_ratio",
        "Fp_ok", "Fp_ratio",
        "Min_margin",
        "Pattern_ok", "Pattern_stats",
        "chg_idx", "nchg_idx",
    )
    for eps in eps_list:
        out = rcd(model=model, city=city, model_type=model_type , eps=eps)
        if not (isinstance(out, (list, tuple)) and len(out) == 5):
            # rcd is the function actually called here, so name it in the message.
            raise ValueError("rcd must return (lower, upper, clean_logits, clean_pred, gt_mask)")
        lower, upper, clean_logits, clean_pred, gt_mask = out
        results, mask_cert, margin_lb = anon_verify_predicates_semantic(
            lower, upper, clean_logits, clean_pred, gt_mask,
            rho=rho, gamma=gamma, s_min=s_min
        )
        print(f"  eps={eps:.5f} -> overlap={results['Overlap_ratio']:.3f}, fp={results['Fp_ratio']:.3f}, "
              f"pattern_ok={results['Pattern_ok']}, strict={results['Certified_strict']}, "
              f"min_margin={results['Min_margin']:.4f} (chg,nchg)=({results['chg_idx']},{results['nchg_idx']})")
        row = {
            "Model": model_label, "City": str(city), "Eps": eps,
            "Tau": "-", "K": "-", "S_min": s_min,
        }
        row.update({column: results[column] for column in result_columns})
        anon_write_row(csv_path, row)
    print(f"[sanity] wrote -> {csv_path}")

# === anon v3: Normalization check ===
def anon_debug_normalization(lower, upper, clean_logits, eps):
    """Print bound/logit shapes and, for ``eps == 0``, how far the bounds are from the clean logits.

    A warning line is printed when either mean absolute difference exceeds 1e-3.
    """
    l, u, z = _to_numpy(lower), _to_numpy(upper), _to_numpy(clean_logits)
    print(f"[norm] eps={eps} -> lower/upper shapes={l.shape}/{u.shape}, clean_logits={z.shape}")

    # The tightness check only makes sense without perturbation.
    if not eps == 0.0:
        return

    import numpy as _np
    diff_l = _np.abs(l - z).mean()
    diff_u = _np.abs(u - z).mean()
    print(f"[norm] eps=0 diffs -> mean|lower-clean|={diff_l:.4g}, mean|upper-clean|={diff_u:.4g}")
    if any(diff > 1e-3 for diff in (diff_l, diff_u)):
        print("[norm][WARN] bounds not tight at eps=0; check normalization and verification graph.")


# channel choice 
# 1) helper: to numpy
def _to_numpy(x):
    """Coerce ``x`` into a NumPy array.

    Torch tensors take the ``detach().cpu().numpy()`` route; every other input,
    and any error raised on the torch route, ends in ``numpy.asarray(x)``.
    """
    as_array = None
    try:
        import torch
        if isinstance(x, torch.Tensor):
            as_array = x.detach().cpu().numpy()
    except Exception:
        # Deliberate best effort: ignore and use numpy.asarray below.
        as_array = None
    if as_array is None:
        as_array = _np.asarray(x)
    return as_array
 
def anon_choose_change_channel(clean_logits, gt_mask):
    """Return the ``(change, no_change)`` channel pair that best reproduces ``gt_mask``.

    Candidates are scored by the IoU between ``softmax(clean_logits)[a] > 0.5``
    and ``gt_mask > 0``, where ``a`` is the candidate's first index. With two
    channels both orderings are tried; with any other count every pair starts
    at the channel of largest mean probability.

    Raises:
        ValueError: if ``clean_logits`` is 2-D.
    """
    logits = _to_numpy(clean_logits)
    if logits.ndim == 2:
        raise ValueError("Single-channel head; two-logit head recommended.")
    channel_count, _rows, _cols = logits.shape

    stabilised = _np.exp(logits - logits.max(axis=0, keepdims=True))
    probabilities = stabilised / stabilised.sum(axis=0, keepdims=True)
    reference = (_to_numpy(gt_mask) > 0).astype('uint8')

    def _pair_iou(pair):
        predicted = (probabilities[pair[0]] > 0.5).astype('uint8')
        inter = int((predicted & reference).sum())
        union = int((predicted | reference).sum())
        return inter / max(union, 1)   # empty union counts as IoU 0

    if channel_count == 2:
        options = [(0,1), (1,0)]
    else:
        lead = int(_np.argmax(probabilities.mean((1,2))))
        options = [(lead, idx) for idx in range(channel_count) if idx != lead]

    return max(options, key=_pair_iou)
 
# 3) helper: logits -> clean_pred for the chosen change channel
def logits_to_clean_pred(clean_logits, chg_idx, thresh=0.5):
    """Threshold the softmax probability of channel ``chg_idx`` into a uint8 mask.

    Softmax runs over axis 0 of ``clean_logits``; pixels strictly above
    ``thresh`` become 1.
    """
    logits = _to_numpy(clean_logits)
    peak = logits.max(axis=0, keepdims=True)   # per-pixel max, for stability
    weights = _np.exp(logits - peak)
    probs = weights / weights.sum(axis=0, keepdims=True)
    selected = probs[chg_idx]
    return (selected > thresh).astype('uint8')
 

 
# 5) NEW wrapper with the SAME name: now returns (lower, upper, clean_logits, clean_pred, gt_mask)
def rcd(model, model_type, city, eps):
    """Wrap ``run_certification_diagnostics`` and return bounds, logits, clean prediction and GT mask.

    The leading dimension of the three returned tensors is squeezed. The GT
    mask comes from the module-level ``GT_MASKS`` dict; it is binarised and,
    if its shape differs from the logits' last two dimensions, resized with
    nearest-neighbour sampling.

    Returns:
        ``(lower, upper, clean_logits, clean_pred, gt_mask)``.

    Raises:
        KeyError: when ``GT_MASKS`` has no entry for ``city``.
    """
    diagnostics = run_certification_diagnostics(model=model, model_type=model_type, city=city, eps=eps)
    lower, upper, clean_logits = diagnostics
    lower, upper, clean_logits = lower.squeeze(0), upper.squeeze(0), clean_logits.squeeze(0)

    H = clean_logits.shape[-2]
    W = clean_logits.shape[-1]

    if city not in GT_MASKS:
        raise KeyError(f"GT_MASKS['{city}'] not set. Please do: GT_MASKS['{city}']=<HxW binary array>")

    gt_mask = np.asarray(GT_MASKS[city])
    if gt_mask.ndim == 3:
        gt_mask = gt_mask.squeeze()

    binary = (gt_mask > 0).astype('uint8')
    if gt_mask.shape == (H,W):
        gt_mask = binary
    else:
        # Go through an 8-bit image so PIL can resize it; size is (width, height).
        as_image = Image.fromarray(binary * 255)
        as_image = as_image.resize((W,H), resample=Image.NEAREST)
        gt_mask = (np.array(as_image) > 0).astype('uint8')

    # The change channel drives the clean prediction; the other index is not needed here.
    chg_idx, nchg_idx = anon_choose_change_channel(clean_logits, gt_mask)
    clean_pred = logits_to_clean_pred(clean_logits, chg_idx, thresh=0.5)

    return lower, upper, clean_logits, clean_pred, gt_mask

# === SHIM: stop recursion in affine_bounds_conv2d & friends, no edits elsewhere ===
import torch.nn as nn
import types

# ---- 1) Locate the "true originals" from whatever state the notebook is in ----
def _first_callable_global(*names):
    """Return the first module-level object among ``names`` that exists and is callable, else None."""
    scope = globals()
    for name in names:
        if name in scope and callable(scope[name]):
            return scope[name]
    return None

# Candidate names are tried in order: explicit stashes first, then the
# underscore-prefixed capture, and finally the currently bound function.
# NOTE(review): the last two candidates may already be wrappers if earlier
# patch cells rebound them; this lookup cannot tell the difference.
AB_ORIG = _first_callable_global(
    'AFFINE_BOUNDS_CONV2D_ORIG',
    '_affine_bounds_conv2d_orig',
    'affine_bounds_conv2d',
)

# The margin function is optional; MB_ORIG stays None when nothing is found.
MB_ORIG = _first_callable_global(
    'AFFINE_MARGIN_CONV2D_ORIG',
    '_affine_margin_bounds_conv2d_orig',
    'affine_margin_bounds_conv2d',
)

if AB_ORIG is None:
    raise RuntimeError("Could not locate the original affine_bounds_conv2d in this notebook state.")

# ---- 2) Robust unwrapping: get the actual Conv2d even if wrapped (OutConv, Sequential, etc.) ----
def _unwrap_last_conv2d(m):
    """Return the ``nn.Conv2d`` hidden inside ``m``, or ``None`` if there is none.

    Search order:
      1. ``m`` itself when it already is a ``Conv2d``;
      2. the well-known wrapper attributes, in the order listed below;
      3. the children of ``m``, scanned from the last one to the first.
    Each candidate is searched recursively with the same rules.
    """
    if isinstance(m, nn.Conv2d):
        return m

    # Single-attribute wrappers (e.g. U-Net output heads).
    wrapper_attrs = ('conv', 'final', 'out_conv', 'out', 'project', 'proj')
    for attr_name in wrapper_attrs:
        if not hasattr(m, attr_name):
            continue
        found = _unwrap_last_conv2d(getattr(m, attr_name))
        if found is not None:
            return found

    # Objects without a children() method cannot be descended into.
    if not hasattr(m, 'children'):
        return None

    # Sequential / nested modules: prefer the conv closest to the output.
    for sub_module in reversed(list(m.children())):
        found = _unwrap_last_conv2d(sub_module)
        if found is not None:
            return found
    return None

# ---- 3) Safe versions that delegate to the originals on a real Conv2d ----
def affine_bounds_conv2d_safe(feat_l, feat_u, conv_like):
    """Call ``AB_ORIG`` on the real Conv2d found inside ``conv_like``.

    Raises:
        TypeError: if ``_unwrap_last_conv2d`` finds no Conv2d in ``conv_like``.
    """
    target = _unwrap_last_conv2d(conv_like)
    if target is not None:
        return AB_ORIG(feat_l, feat_u, target)
    raise TypeError(f"Expected Conv2d or wrapper; got {type(conv_like).__name__}")

def affine_margin_bounds_conv2d_safe(feat_l, feat_u, conv_like, pos_idx, neg_idx):
    """Call ``MB_ORIG`` on the real Conv2d found inside ``conv_like``.

    Raises:
        RuntimeError: if no original margin-bounds function was located.
        TypeError: if ``_unwrap_last_conv2d`` finds no Conv2d in ``conv_like``.
    """
    # The availability check comes first so the error does not depend on conv_like.
    if MB_ORIG is None:
        raise RuntimeError("Margin bounds function not available in this notebook.")
    target = _unwrap_last_conv2d(conv_like)
    if target is not None:
        return MB_ORIG(feat_l, feat_u, target, pos_idx, neg_idx)
    raise TypeError(f"Expected Conv2d or wrapper; got {type(conv_like).__name__}")

# ---- 4) Wrap run_certification_diagnostics so it uses the *safe* functions only during its run ----
if 'run_certification_diagnostics' not in globals() or not callable(run_certification_diagnostics):
    raise RuntimeError("run_certification_diagnostics not found; run the cell that defines it first.")

# Capture the real implementation. When this cell is re-run, the global already
# points at the shim defined below; reuse the implementation that shim recorded
# instead of wrapping the shim in yet another shim.
_RCD_ORIG = getattr(run_certification_diagnostics, '_rcd_shim_target', run_certification_diagnostics)

# Sentinel meaning "this global did not exist before the shim rebound it".
_RCD_UNSET = object()


def run_certification_diagnostics(*args, **kwargs):
    """Run the wrapped diagnostics with the *safe* bound functions installed.

    For the duration of the call the globals ``affine_bounds_conv2d`` and (only
    when ``MB_ORIG`` is available) ``affine_margin_bounds_conv2d`` are rebound to
    whatever ``affine_bounds_conv2d_safe`` / ``affine_margin_bounds_conv2d_safe``
    currently name. Afterwards every rebound name is put back exactly as it
    was: restored to its previous value, or removed if it did not exist before.

    All positional and keyword arguments are forwarded unchanged, and the
    wrapped function's return value (or exception) is passed through.
    """
    g = globals()
    patched = {'affine_bounds_conv2d': affine_bounds_conv2d_safe}
    if MB_ORIG is not None:
        patched['affine_margin_bounds_conv2d'] = affine_margin_bounds_conv2d_safe

    # Remember the prior state of exactly the names we are about to touch.
    saved = {name: g.get(name, _RCD_UNSET) for name in patched}
    g.update(patched)
    try:
        return _RCD_ORIG(*args, **kwargs)
    finally:
        for name, previous in saved.items():
            if previous is _RCD_UNSET:
                g.pop(name, None)
            else:
                g[name] = previous


# Record what the shim wraps so that a re-run of this cell can find it again.
run_certification_diagnostics._rcd_shim_target = _RCD_ORIG

print("[shim] installed: safe affine bounds + RCD shim; recursion should be gone.")
# === FINAL BINDINGS PATCH (v9) ===
import torch, torch.nn as nn, torch.nn.functional as F
import os

# --- 1) Provide oscd_paths() so GT loader works even if older cells expect it ---
def oscd_paths(city: str):
    """Return ``(imgs_1 dir, imgs_2 dir, change-mask path)`` for ``city``.

    The paths are built by plain string concatenation of the notebook-level
    ``OSCD_PATH_*`` constants, so those constants must already carry the
    slashes needed between the pieces.
    """
    city_root = OSCD_PATH_TOP + city
    return (
        city_root + OSCD_PATH_BOTTOM1,
        city_root + OSCD_PATH_BOTTOM2,
        city_root + OSCD_PATH_CM,
    )

# --- 2) Helper to unwrap to the last real Conv2d (handles OutConv/Sequential/etc.) ---
def _last_conv2d(module: nn.Module):
    """Return the last ``nn.Conv2d`` in ``module.modules()`` traversal order.

    Raises:
        RuntimeError: if ``module`` contains no Conv2d at all.
    """
    convs = [sub for sub in module.modules() if isinstance(sub, nn.Conv2d)]
    if not convs:
        raise RuntimeError("Could not find a Conv2d inside final module.")
    return convs[-1]

# --- 3) General (any k/s/p/dilation/groups) exact interval bounds for the FINAL Conv2d ---
def _conv2d_interval_bounds(xL, xU, W, b, stride, padding, dilation, groups):
    """Interval bounds of ``conv2d(x, W) + b`` for ``x`` in the box ``[xL, xU]``.

    The kernel is split into its non-negative and non-positive parts; the
    lower bound pairs the positive part with ``xL`` and the negative part with
    ``xU``, and the upper bound does the opposite. ``b`` may be ``None``;
    otherwise it is reshaped to ``(1, -1, 1, 1)`` and added to both bounds.
    ``W`` and ``b`` are moved to the device of ``xL``.

    Returns:
        ``(lower, upper)`` tensors.
    """
    dev = xL.device
    w_plus = torch.clamp(W, min=0).to(device=dev)
    w_minus = torch.clamp(W, max=0).to(device=dev)

    def _apply(x, kernel):
        # Bias is handled separately so that it is added exactly once per bound.
        return F.conv2d(x, kernel, None, stride, padding, dilation, groups)

    lower = _apply(xL, w_plus) + _apply(xU, w_minus)
    upper = _apply(xU, w_plus) + _apply(xL, w_minus)
    if b is None:
        return lower, upper

    shift = b.to(device=dev).view(1, -1, 1, 1)
    return lower + shift, upper + shift

def affine_bounds_conv2d_final(feat_l, feat_u, final_module):
    """Interval bounds of the last Conv2d in ``final_module`` over ``[feat_l, feat_u]``.

    The conv is located with ``_last_conv2d`` and its own weight, bias and
    geometry (stride, padding, dilation, groups) are handed to
    ``_conv2d_interval_bounds``. A one-line notice is printed on every call.

    Returns:
        ``(lower, upper)`` bounds.
    """
    head = _last_conv2d(final_module)
    geometry = (head.stride, head.padding, head.dilation, head.groups)
    lower, upper = _conv2d_interval_bounds(feat_l, feat_u, head.weight, head.bias, *geometry)
    print("[bind] using non-recursive final Conv2d bounds")
    return lower, upper

# --- 4) Margin bounds for (logit[c] - logit[nc]) at the FINAL Conv2d ---
def affine_margin_bounds_conv2d_final(feat_l, feat_u, final_module, chg_idx, nchg_idx):
    """Interval bounds of ``logit[chg_idx] - logit[nchg_idx]`` at the last Conv2d.

    A single-output "difference" kernel (and bias, when the conv has one) is
    built from the two selected output channels and pushed through
    ``_conv2d_interval_bounds`` with the conv's own geometry. A one-line notice
    is printed on every call.

    NOTE(review): ``conv.groups`` is forwarded unchanged together with a
    one-row kernel; this looks intended for ``groups == 1`` only — confirm.

    Returns:
        ``(lower, upper)`` bounds of the margin.
    """
    head = _last_conv2d(final_module)
    dev = feat_l.device
    pos, neg = int(chg_idx), int(nchg_idx)

    # Difference kernel: one output channel, W[pos] - W[neg].
    kernel_diff = (head.weight[pos].unsqueeze(0) - head.weight[neg].unsqueeze(0)).to(device=dev)
    bias_diff = None if head.bias is None else (head.bias[pos] - head.bias[neg]).to(device=dev)

    lower, upper = _conv2d_interval_bounds(
        feat_l, feat_u, kernel_diff, bias_diff,
        head.stride, head.padding, head.dilation, head.groups,
    )
    print("[bind] using non-recursive final Conv2d margin bounds")
    return lower, upper

# --- 5) HARD REBIND: point *all* names the verifier might call to these safe versions ---
# Plain module-level assignments; the public names and the "*_safe" aliases
# used by earlier wrapper cells all end up on the final-conv implementations.
affine_bounds_conv2d = affine_bounds_conv2d_final
affine_margin_bounds_conv2d = affine_margin_bounds_conv2d_final
affine_bounds_conv2d_safe = affine_bounds_conv2d_final
affine_margin_bounds_conv2d_safe = affine_margin_bounds_conv2d_final

# Sanity output: show which implementation each public name now resolves to.
print("affine_bounds_conv2d ->", affine_bounds_conv2d.__name__)
print("affine_margin_bounds_conv2d ->", affine_margin_bounds_conv2d.__name__)
print("[final-patch] oscd_paths + final Conv2d bound fns installed.")
# === CROWN-TAIL (margin) override — drop-in, no extra deps ===
# Put this cell right after your "final-conv patch" cell.

import torch
import torch.nn as nn
import torch.nn.functional as F

# --- keep a handle to whatever margin-bounds function is bound right now (for fallback) ---
# NOTE(review): this snapshots the *current* global; if the cell is re-run after
# the rebind further down, the snapshot would be the tail-aware wrapper itself.
try:
    _AFFINE_MARGIN_ORIG = globals()['affine_margin_bounds_conv2d']
except KeyError:
    _AFFINE_MARGIN_ORIG = None

# Small constant added to a denominator in the ReLU relaxation below.
_EPS = 1e-12

# ---- helpers: flatten the last block into a simple forward-ordered layer list ----
def _flatten_layers(mod: nn.Module):
    """Flatten ``mod`` into a forward-ordered list of leaf layers.

    ``Conv2d``, ``BatchNorm2d`` and ``ReLU`` modules are emitted as they are
    met. Any other module is expanded into its children (in order); a module
    of another type that has no children is emitted as-is so that callers can
    treat it as a pass-through.
    """
    atomic_types = (nn.Conv2d, nn.BatchNorm2d, nn.ReLU)
    ordered = []
    # Explicit stack for a depth-first, pre-order walk; children are pushed in
    # reverse so that the first child is the next one popped.
    pending = [mod]
    while pending:
        node = pending.pop()
        if isinstance(node, atomic_types):
            ordered.append(node)
            continue
        subs = list(node.children())
        if subs:
            pending.extend(reversed(subs))
        else:
            # Unknown leaf (e.g. Dropout, Identity): keep it in the list.
            ordered.append(node)
    return ordered

# ---- interval forward (IBP) through the block, record pre-activation (l,u) at each ReLU ----
def _ibp_collect_relu_bounds(l, u, layers):
    """Propagate the box ``[l, u]`` through ``layers`` and record ReLU inputs.

    Args:
        l, u: lower / upper bound tensors at the input of the first layer.
        layers: forward-ordered layer list (see ``_flatten_layers``).

    Returns:
        A list with one ``(l_pre, u_pre)`` pair per ``nn.ReLU`` in ``layers``,
        in forward order, holding cloned pre-activation bounds.

    Layer handling:
        * ``Conv2d``: weights are split into non-negative / non-positive parts.
        * ``BatchNorm2d``: treated as the per-channel affine map built from the
          running statistics; the module must be in eval mode.
        * ``ReLU``: bounds are recorded, then clamped at zero.
        * anything else: bounds are passed through unchanged.

    Both new bounds of an affine layer are computed from the *incoming*
    ``l`` and ``u`` before either is overwritten; computing the upper bound
    from an already-updated lower bound would give incorrect intervals.
    """
    relu_lu = []  # forward order; one (l_pre, u_pre) per ReLU
    for m in layers:
        if isinstance(m, nn.Conv2d):
            w_pos = torch.clamp(m.weight, min=0)
            w_neg = torch.clamp(m.weight, max=0)
            conv_args = (None, m.stride, m.padding, m.dilation, m.groups)
            new_l = F.conv2d(l, w_pos, *conv_args) + F.conv2d(u, w_neg, *conv_args)
            new_u = F.conv2d(u, w_pos, *conv_args) + F.conv2d(l, w_neg, *conv_args)
            if m.bias is not None:
                bias = m.bias.view(1, -1, 1, 1)
                new_l = new_l + bias
                new_u = new_u + bias
            l, u = new_l, new_u
        elif isinstance(m, nn.BatchNorm2d):
            # eval BN: y = s * x + q
            assert not m.training, "Expected BatchNorm in eval() for verification"
            scale = m.weight / torch.sqrt(m.running_var + m.eps)
            s = scale.view(1, -1, 1, 1)
            q = (m.bias - m.running_mean * scale).view(1, -1, 1, 1)
            s_pos = torch.clamp(s, min=0)
            s_neg = torch.clamp(s, max=0)
            # Tuple assignment: both right-hand sides see the old l and u.
            l, u = s_pos * l + s_neg * u + q, s_pos * u + s_neg * l + q
        elif isinstance(m, nn.ReLU):
            # Record pre-activation bounds for this ReLU, then push IBP through it.
            relu_lu.append((l.clone(), u.clone()))
            l = torch.clamp(l, min=0)
            u = torch.clamp(u, min=0)
        # Any other layer type is treated as identity: bounds are left untouched.
    return relu_lu

# ---- backprop the linear form through the block with CROWN-style (fixed-slope) linear ReLU relaxations ----
def _backprop_linear_form(C, t, layers, relu_lu, mode: str):
    """Back-substitute the linear form ``C·y + t`` through ``layers``.

    Args:
        C: coefficient tensor on the output of the last layer; channels on dim 1.
        t: accumulated constant term; channel contributions are summed over
           dim 1 with ``keepdim=True`` and added to it.
        layers: forward-ordered layer list; it is traversed in reverse.
        relu_lu: pre-activation bounds ``(L, U)`` for every ``nn.ReLU`` in
           ``layers``, in forward order (consumed from the end).
        mode: ``'upper'`` to keep an upper bound of the form; any other value
           is treated as ``'lower'``.

    Returns:
        ``(C, t)`` expressed on the input of the first layer.

    Raises:
        IndexError: if ``relu_lu`` holds fewer entries than there are ReLUs.

    ReLU handling (per entry, from the recorded bounds):
        * ``U <= 0`` (inactive): output is 0, the coefficient is dropped.
        * ``L >= 0`` (active): ReLU is the identity, so the coefficient passes
          through unchanged and nothing is added to ``t``.
        * ``L < 0 < U`` (unstable): coefficients whose sign requires an upper
          bound on the ReLU output use the line ``slope * (x - L)`` with
          ``slope = U / (U - L)``; the remaining coefficients use the lower
          bound ``relu(x) >= 0`` and are dropped.

    NOTE(review): ``t`` is kept per spatial position while the Conv2d step
    redistributes ``C`` spatially via ``conv_transpose2d``; for kernels larger
    than 1x1 confirm that the caller's per-pixel reading of the result holds.
    """
    want_upper = (mode == 'upper')
    # Traverse in reverse; consume ReLU bounds from the end.
    relu_idx = len(relu_lu) - 1
    for m in reversed(layers):
        if isinstance(m, nn.ReLU):
            if relu_idx < 0:
                # Without this check a negative index would silently wrap around.
                raise IndexError("relu_lu has fewer entries than there are ReLU layers")
            L, U = relu_lu[relu_idx]
            relu_idx -= 1

            inactive = (U <= 0)
            active = (L >= 0) & ~inactive
            unstable = (L < 0) & (U > 0)

            # U - L > 0 on unstable entries; elsewhere the denominator is a dummy 1.
            span = torch.where(unstable, U - L, torch.ones_like(U))
            slope = torch.where(unstable, U / span, torch.zeros_like(U))
            # The intercept belongs to the unstable relaxation only.
            intercept = torch.where(unstable, -slope * L, torch.zeros_like(U))

            # Coefficients that need the *upper* line of the ReLU relaxation.
            uses_upper_line = (C >= 0) if want_upper else (C < 0)
            C_line = torch.where(uses_upper_line, C, torch.zeros_like(C))

            t = t + (C_line * intercept).sum(dim=1, keepdim=True)
            # active: exact identity; unstable: relaxed slope; inactive: slope 0.
            C = torch.where(active, C, C_line * slope)
        elif isinstance(m, nn.BatchNorm2d):
            # y = s*x + q  =>  C·y = (C*s)·x + C·q
            scale = m.weight / torch.sqrt(m.running_var + m.eps)
            s = scale.view(1, -1, 1, 1)
            q = (m.bias - m.running_mean * scale).view(1, -1, 1, 1)
            t = t + (C * q).sum(dim=1, keepdim=True)
            C = C * s
        elif isinstance(m, nn.Conv2d):
            # C·(conv(x) + b) = conv_transpose(C, W)·x + C·b
            if m.bias is not None:
                t = t + (C * m.bias.view(1, -1, 1, 1)).sum(dim=1, keepdim=True)
            C = F.conv_transpose2d(C, m.weight, bias=None,
                                   stride=m.stride, padding=m.padding,
                                   dilation=m.dilation, groups=m.groups)
        # Any other layer type is treated as identity: C and t are left untouched.
    return C, t

def _affine_margin_bounds_conv2d_crown_tail(feat_l, feat_u, final_module, chg_idx, nchg_idx):
    """
    CROWN tail for the margin: backprop v = W[c]-W[nc] through last DoubleConv block
    and bound v^T h + (b_c - b_nc) over h ∈ post-DoubleConv(z), z∈[l0,u0].

    Inputs are taken from the global ``CROWN_TAIL`` dict (keys ``'l0'``, ``'u0'``,
    ``'block'``, ``'final'``). Of the feature arguments only the last two
    dimensions of ``feat_l`` are used (as the spatial size H, W); ``feat_u`` is
    not read in this function.

    Returns:
        ``(Ll, Lu)``: lower and upper maps of the margin.

    Raises:
        RuntimeError: if ``CROWN_TAIL`` is not a global, or if the tapped
            ``'final'`` object is not the very same object as ``final_module``.
        KeyError: if a required key is missing from ``CROWN_TAIL``.

    NOTE(review): ``head.weight`` / ``head.bias`` are read directly from the
    tapped ``'final'`` object; if that object is a wrapper around the Conv2d
    rather than the Conv2d itself this would raise AttributeError — confirm.
    NOTE(review): the returned maps are read per pixel by the caller; since
    ``_backprop_linear_form`` moves coefficients spatially for kernels larger
    than 1x1, confirm the per-pixel values are valid bounds in that case.
    """
    # Activate only if taps are present and final matches the tapped head
    if 'CROWN_TAIL' not in globals():
        raise RuntimeError("CROWN_TAIL not set")
    taps = CROWN_TAIL
    l0 = taps['l0']; u0 = taps['u0']
    block = taps['block']
    final = taps['final']
    # Identity comparison: the tap must have been recorded for this exact head object.
    if final is not final_module:
        # different head (unlikely), refuse to proceed
        raise RuntimeError("CROWN_TAIL['final'] != final_module")

    device = l0.device
    # 1) Prepare layers & ReLU (pre-activation) bounds inside the block
    layers = _flatten_layers(block)
    relu_lu = _ibp_collect_relu_bounds(l0, u0, layers)

    # 2) Initialize the linear form with the "margin head": v = W[c]-W[nc]
    head = final
    Wc  = head.weight[int(chg_idx)].unsqueeze(0)   # one row of the head kernel, leading dim added
    Wnc = head.weight[int(nchg_idx)].unsqueeze(0)
    v = (Wc - Wnc)                                 # difference kernel, same shape as Wc
    b_margin = None
    if head.bias is not None:
        b_margin = (head.bias[int(chg_idx)] - head.bias[int(nchg_idx)]).to(device=device)

    # Coefficient map C starts as v broadcast to the spatial size of feat_l.
    # expand() only broadcasts size-1 dims, so this presumes a 1x1 head kernel
    # (or kernel dims equal to H, W) — TODO confirm.
    H, W = feat_l.shape[-2:]
    C0 = v.expand(-1, -1, H, W).contiguous()       # (1, C_in, H, W)
    t0 = torch.zeros((1,1,H,W), device=device)     # constant map accumulator (default dtype)

    # 3) Backprop the **upper** linear bound through the block
    C_u, t_u = _backprop_linear_form(C0.clone(), t0.clone(), layers, relu_lu, mode='upper')
    # 4) Backprop the **lower** linear bound through the block
    C_l, t_l = _backprop_linear_form(C0.clone(), t0.clone(), layers, relu_lu, mode='lower')

    # 5) Max/min over z ∈ [l0,u0] for the linear forms C·z + t
    # upper: C_pos*u0 + C_neg*l0 + t
    Cu_pos = torch.clamp(C_u, min=0); Cu_neg = torch.clamp(C_u, max=0)
    Lu = (Cu_pos * u0 + Cu_neg * l0).sum(dim=1, keepdim=True) + t_u
    # lower: C_pos*l0 + C_neg*u0 + t
    Cl_pos = torch.clamp(C_l, min=0); Cl_neg = torch.clamp(C_l, max=0)
    Ll = (Cl_pos * l0 + Cl_neg * u0).sum(dim=1, keepdim=True) + t_l

    # 6) Add the margin bias (passes straight through)
    if b_margin is not None:
        Ll = Ll + b_margin.view(1,1,1,1)
        Lu = Lu + b_margin.view(1,1,1,1)

    # Both maps keep a singleton channel dim (sum over dim 1 with keepdim=True).
    return Ll, Lu

# --- soft wrapper that uses CROWN tail when taps are available; else fallback ---
def affine_margin_bounds_conv2d_tailaware(feat_l, feat_u, final_module, chg_idx, nchg_idx):
    """Margin bounds via the CROWN tail when possible, else via the saved original.

    When a global ``CROWN_TAIL`` exists the tail computation is attempted; any
    exception it raises is reported on stdout and swallowed (best effort).
    The function captured in ``_AFFINE_MARGIN_ORIG`` is then used instead.

    Raises:
        RuntimeError: if the fallback is needed but no original was captured.
    """
    call_args = (feat_l, feat_u, final_module, chg_idx, nchg_idx)

    if 'CROWN_TAIL' in globals():
        try:
            return _affine_margin_bounds_conv2d_crown_tail(*call_args)
        except Exception as err:
            print(f"[tail] CROWN-tail unavailable ({err}); falling back to original margin bounds.")

    # Fallback to whatever was installed before this override.
    if _AFFINE_MARGIN_ORIG is None:
        raise RuntimeError("No original affine_margin_bounds_conv2d found for fallback.")
    return _AFFINE_MARGIN_ORIG(*call_args)

# ---- hard rebind the name your verifier uses for margin bounds ----
# Plain module-level assignment: later lookups of the public name resolve to the
# tail-aware wrapper defined above.
affine_margin_bounds_conv2d = affine_margin_bounds_conv2d_tailaware

print("[tail] α-CROWN margin override armed (uses taps when CROWN_TAIL is set).")
# === CROWN tail (one DoubleConv + final 1x1) — drop-in override ===
# Put this cell AFTER your "final-conv bound patch" cell in v9.

import torch, torch.nn as nn, torch.nn.functional as F

# Global "tap": a dict that the propagation code fills with the tail input bounds
# and the tail modules. It starts out empty, and an empty tap makes the glue
# functions below use the final-conv-only path.
# Keys once populated:
#   'l0', 'u0' : lower / upper bound tensors at the input of the tail block
#   'block'    : the tail block module
#   'final'    : the final head module
CROWN_TAIL = dict()

def _bn_affine_params(bn: nn.BatchNorm2d):
    """Return ``(a, c)`` such that the running-statistics form of ``bn`` is ``a * x + c``.

    Both tensors are shaped ``(1, C, 1, 1)``:
    ``a = weight / sqrt(running_var + eps)`` and ``c = bias - a * running_mean``.
    """
    assert isinstance(bn, nn.BatchNorm2d)

    def _per_channel(vec):
        # Broadcastable against (N, C, H, W) tensors.
        return vec.reshape(1, -1, 1, 1)

    scale = _per_channel(bn.weight) / torch.sqrt(_per_channel(bn.running_var) + bn.eps)
    shift = _per_channel(bn.bias) - scale * _per_channel(bn.running_mean)
    return scale, shift

def _interval_affine(l, u, a, b):
    """Bounds of the elementwise map ``a * x + b`` for ``x`` in ``[l, u]``.

    ``a`` may have either sign: the two end-point images are computed and the
    elementwise minimum / maximum is taken before ``b`` is added.
    """
    image_at_l = a * l
    image_at_u = a * u
    lower = torch.minimum(image_at_l, image_at_u) + b
    upper = torch.maximum(image_at_l, image_at_u) + b
    return lower, upper

def _relu_lin(l, u, mode: str):
    """Return ``(alpha, beta)`` of a linear bound ``alpha * x + beta`` on ReLU.

    ``mode`` must be ``"upper"`` or ``"lower"`` (checked with an assert).

    Per element, from the pre-activation bounds ``l`` and ``u``:
      * ``u <= 0``            -> alpha = 0, beta = 0
      * ``l >= 0`` (and u > 0) -> alpha = 1, beta = 0
      * otherwise (unstable):
          - upper: alpha = u / (u - l), beta = -l * u / (u - l)
          - lower: alpha = 0, beta = 0
    The denominator ``u - l`` is clamped from below at 1e-12.
    """
    assert mode in ("upper","lower")
    stable_on = (l >= 0)
    stable_off = (u <= 0)

    # Slope 1 only where the unit is active and not also flagged inactive.
    slope = torch.where(stable_on & ~stable_off, torch.ones_like(l), torch.zeros_like(l))
    offset = torch.zeros_like(l)

    if mode == "upper":
        crossing = ~stable_on & ~stable_off
        # Guard against a vanishing gap between the two bounds.
        gap = (u - l).clamp_min(1e-12)
        slope = torch.where(crossing, u / gap, slope)
        offset = torch.where(crossing, -l * u / gap, offset)
    # In "lower" mode unstable entries keep slope 0 and offset 0.
    return slope, offset

def _conv_weight_flat(conv: nn.Conv2d):
    """Return the conv kernel as a contiguous 2-D matrix with one row per output channel."""
    kernel = conv.weight
    n_out = kernel.shape[0]
    # All remaining kernel dimensions are merged into the second axis.
    return kernel.view(n_out, -1).contiguous()

def _dual_back_through_conv(C_map, conv: nn.Conv2d):
    """Back-substitute dual coefficients through ``conv``.

    Args:
        C_map: coefficients on the conv output, shape ``[1, Cout, Hout, Wout]``
            (a batch size other than 1 fails an assert).
        conv: the convolution to go back through. Its bias is not handled
            here; that is left to the caller.

    Returns:
        A 4-tuple ``(C_in_map, C_in_col, (Hin, Win), C_out_col)``:
          * ``C_in_map``  - coefficients folded onto the conv input grid,
          * ``C_in_col``  - the same coefficients in column form, before folding,
          * ``(Hin, Win)``- input size derived from the output size and the
                            conv's stride / padding / dilation / kernel size,
          * ``C_out_col`` - ``C_map`` viewed as ``[1, Cout, Hout*Wout]``.

    NOTE(review): ``conv.groups`` is not consulted here — presumably only
    ``groups == 1`` is intended; confirm.
    """
    assert isinstance(conv, nn.Conv2d)
    batch, n_out, h_out, w_out = C_map.shape
    assert batch == 1

    # Column form on the output grid: one column per output location.
    n_sites = h_out * w_out
    out_cols = C_map.view(1, n_out, n_sites)

    # Multiply by the transposed flattened kernel to reach the input patches.
    in_cols = torch.matmul(_conv_weight_flat(conv).t(), out_cols)

    def _input_extent(out_len, axis):
        # Input length along one axis for the given output length.
        return ((out_len - 1) * conv.stride[axis]
                - 2 * conv.padding[axis]
                + conv.dilation[axis] * (conv.kernel_size[axis] - 1)
                + 1)

    in_hw = (_input_extent(h_out, 0), _input_extent(w_out, 1))

    # fold() adds up the overlapping patch contributions on the input grid.
    in_map = F.fold(
        in_cols,
        output_size=in_hw,
        kernel_size=conv.kernel_size,
        dilation=conv.dilation,
        padding=conv.padding,
        stride=conv.stride,
    )
    return in_map, in_cols, in_hw, out_cols

def _unfold_like_conv(x, conv: nn.Conv2d):
    """Extract sliding patches of ``x`` using the geometry of ``conv``.

    The kernel size, dilation, padding and stride are taken from ``conv``;
    the result is what ``F.unfold`` returns for them (patch columns).
    """
    geometry = {
        "kernel_size": conv.kernel_size,
        "dilation": conv.dilation,
        "padding": conv.padding,
        "stride": conv.stride,
    }
    return F.unfold(x, **geometry)

def _interval_conv(l, u, conv: nn.Conv2d):
    """Interval bounds of ``conv(x)`` for ``x`` in the box ``[l, u]``.

    The kernel is split by sign; the bias (if any) is added to both bounds.

    Returns:
        ``(lower, upper)`` tensors.
    """
    kernel_pos = torch.clamp(conv.weight, min=0)
    kernel_neg = torch.clamp(conv.weight, max=0)

    def _apply(x, kernel):
        # Bias-free convolution with the layer's own geometry.
        return F.conv2d(x, kernel, None, conv.stride, conv.padding, conv.dilation, conv.groups)

    lower = _apply(l, kernel_pos) + _apply(u, kernel_neg)
    upper = _apply(u, kernel_pos) + _apply(l, kernel_neg)

    if conv.bias is None:
        return lower, upper
    return lower + conv.bias.view(1,-1,1,1), upper + conv.bias.view(1,-1,1,1)

def _doubleconv_parts(block: nn.Module):
    """Return ``(conv1, bn1, relu1, conv2, bn2, relu2)`` picked from ``block``.

    For each of the three layer kinds (Conv2d, BatchNorm2d, ReLU/ReLU6) the
    last two instances in ``block.modules()`` traversal order are used.

    Raises:
        RuntimeError: if fewer than two instances of any kind are found.
    """
    def _of_kind(kinds):
        return [sub for sub in block.modules() if isinstance(sub, kinds)]

    convs = _of_kind(nn.Conv2d)
    norms = _of_kind(nn.BatchNorm2d)
    acts = _of_kind((nn.ReLU, nn.ReLU6))

    # Conservative requirement: two of each kind must be present.
    if min(len(convs), len(norms), len(acts)) < 2:
        raise RuntimeError("CROWN tail: expected DoubleConv with 2x (Conv+BN+ReLU).")

    # Heuristic: the trailing pair of each kind forms the two stages.
    return convs[-2], norms[-2], acts[-2], convs[-1], norms[-1], acts[-1]

def _last_conv2d(module: nn.Module):
    """Return the Conv2d that comes last in ``module.modules()`` traversal order.

    Raises:
        RuntimeError: if ``module`` contains no Conv2d.
    """
    # Scan from the end so that the first hit is the last conv.
    for candidate in reversed(list(module.modules())):
        if isinstance(candidate, nn.Conv2d):
            return candidate
    raise RuntimeError("CROWN tail: final 1x1 conv not found.")

@torch.no_grad()
def crown_tail_bounds_oneblock(l0, u0, block: nn.Module, final_module: nn.Module,
                               chg_idx: int|torch.Tensor, nchg_idx: int|torch.Tensor):
    """
    Bound the margin ``logit[chg_idx] - logit[nchg_idx]`` over a tail made of one
    DoubleConv block followed by the final conv, for tail inputs in ``[l0, u0]``.

    l0,u0: interval bounds at *input* of the DoubleConv tail (shape [1,C,H,W]).
    block:  DoubleConv (conv1->bn1->relu1->conv2->bn2->relu2); the parts are
            picked by ``_doubleconv_parts``.
    final_module: module containing the final 1x1 conv (located with
            ``_last_conv2d``).
    chg_idx, nchg_idx: class indices for margin (converted with ``int()``).
    Returns: (LB_map, UB_map) each [1,1,H,W].

    Runs under ``torch.no_grad()``.

    NOTE(review): in the back-substitution below ``conv1.weight`` is never
    applied — the coefficients that live on conv1's *output* are unfolded with
    conv1's geometry and multiplied against the unfolded *input* box. Confirm
    this is intended and that the channel counts of the two column tensors
    agree for the blocks this is used on.
    NOTE(review): ``_dual_back_through_conv`` folds the coefficients of all
    output pixels into one map; for kernels larger than 1x1 confirm that the
    per-pixel maps returned here are valid per-pixel bounds.
    """
    device = l0.device
    # relu1 / relu2 are unpacked but not used further in this function.
    conv1, bn1, relu1, conv2, bn2, relu2 = _doubleconv_parts(block)
    final = _last_conv2d(final_module)

    # ---- forward intervals through the tail to get (l,u) at ReLU inputs ----
    # z1 = BN1(Conv1(l0,u0))
    y1L, y1U = _interval_conv(l0, u0, conv1)
    a1, b1 = _bn_affine_params(bn1)
    z1L, z1U = _interval_affine(y1L, y1U, a1, b1)             # ReLU1 input
    # post ReLU1
    x1L, x1U = torch.clamp_min(z1L, 0), torch.clamp_min(z1U, 0)

    # z2 = BN2(Conv2(x1))
    y2L, y2U = _interval_conv(x1L, x1U, conv2)
    a2, b2 = _bn_affine_params(bn2)
    z2L, z2U = _interval_affine(y2L, y2U, a2, b2)             # ReLU2 input
    # post ReLU2 -> tail output (pre-final); only the shape of x2L is used below
    x2L, x2U = torch.clamp_min(z2L, 0), torch.clamp_min(z2U, 0)

    # ---- set dual at logits: +1 for chg, -1 for non-chg, per-pixel ----
    H, W = x2L.shape[-2], x2L.shape[-1]
    C_out = final.out_channels
    Cy = torch.zeros(1, C_out, H, W, device=device)   # default dtype
    Cy[:, int(chg_idx)]  =  1.0
    Cy[:, int(nchg_idx)] = -1.0

    # d_map accumulates constant terms per output pixel
    d_map = torch.zeros(1, 1, H, W, device=device)

    # ---- back through final 1x1 conv ----
    if final.bias is not None:
        d_map = d_map + (Cy * final.bias.view(1,-1,1,1)).sum(dim=1, keepdim=True)
    # Only C_map is used afterwards; C_col, H2 and W2 are not read again.
    C_map, C_col, (H2,W2), _ = _dual_back_through_conv(Cy, final)  # now at tail output (x2 domain)

    # ---- ReLU2 linear relaxations ----
    aU2, bU2 = _relu_lin(z2L, z2U, mode="upper")
    aL2, bL2 = _relu_lin(z2L, z2U, mode="lower")
    # We'll compute both UB and LB by two passes differing only in (a,b)
    # NOTE(review): the same (a,b) is applied to every coefficient regardless
    # of its sign — confirm this yields valid bounds when coefficients are negative.
    def _pass(a_map, b_map):
        # Back-substitute through ReLU2 -> BN2 -> Conv2 with the given relaxation.
        C = C_map.clone()
        d = d_map.clone()
        d = d + (C * b_map).sum(dim=1, keepdim=True)
        C = C * a_map
        # BN2
        d = d + (C * b2).sum(dim=1, keepdim=True)
        C = C * a2
        # Conv2
        if conv2.bias is not None:
            d = d + (C * conv2.bias.view(1,-1,1,1)).sum(dim=1, keepdim=True)
        C, Ccol, (H1,W1), _ = _dual_back_through_conv(C, conv2)  # now at x1 (post ReLU1)
        # ReLU1 is handled by the caller of _pass
        return C, Ccol, d

    # Upper chain (Ccol_u is returned by _pass but not used afterwards):
    C_u, Ccol_u, d_u = _pass(aU2, bU2)
    aU1, bU1 = _relu_lin(z1L, z1U, mode="upper")
    d_u = d_u + (C_u * bU1).sum(dim=1, keepdim=True)
    C_u = C_u * aU1
    # BN1
    d_u = d_u + (C_u * b1).sum(dim=1, keepdim=True)
    C_u = C_u * a1
    # Conv1 (last back-sub): DO NOT fold — keep COL at tail input for evaluation
    if conv1.bias is not None:
        d_u = d_u + (C_u * conv1.bias.view(1,-1,1,1)).sum(dim=1, keepdim=True)
    # unfold coeffs with conv1's kernel geometry (C_u lives on conv1's output here)
    C_u_col = _unfold_like_conv(C_u, conv1)
    # Lower chain (same steps with the "lower" relaxations; Ccol_l is unused):
    C_l, Ccol_l, d_l = _pass(aL2, bL2)
    aL1, bL1 = _relu_lin(z1L, z1U, mode="lower")
    d_l = d_l + (C_l * bL1).sum(dim=1, keepdim=True)
    C_l = C_l * aL1
    d_l = d_l + (C_l * b1).sum(dim=1, keepdim=True)
    C_l = C_l * a1
    if conv1.bias is not None:
        d_l = d_l + (C_l * conv1.bias.view(1,-1,1,1)).sum(dim=1, keepdim=True)
    C_l_col = _unfold_like_conv(C_l, conv1)

    # ---- evaluate against input box [l0,u0] per output pixel (vectorized) ----
    Lcol = _unfold_like_conv(l0, conv1)  # [1, Cin*k*k, L]
    Ucol = _unfold_like_conv(u0, conv1)  # [1, Cin*k*k, L]
    # Sign split so that each coefficient meets the box end that is extreme for it.
    Cpos_u = torch.clamp(C_u_col, min=0); Cneg_u = torch.clamp(C_u_col, max=0)
    Cpos_l = torch.clamp(C_l_col, min=0); Cneg_l = torch.clamp(C_l_col, max=0)

    UB_col = d_u.view(1,1,-1) + (Cpos_u * Ucol + Cneg_u * Lcol).sum(dim=1, keepdim=True)
    LB_col = d_l.view(1,1,-1) + (Cpos_l * Lcol + Cneg_l * Ucol).sum(dim=1, keepdim=True)

    # The column results are reshaped with the H, W of the tail output.
    Hout, Wout = H, W  # same spatial size
    UB = UB_col.view(1,1,Hout,Wout)
    LB = LB_col.view(1,1,Hout,Wout)
    return LB, UB

# ---- glue: override your bound call-sites if tail tap is present ----
def affine_bounds_conv2d_crown(feat_l, feat_u, final_module,
                               model=None, chg_idx=None, nchg_idx=None):
    """Bounds at the head: CROWN one-block tail when the tap is filled, else final-conv only.

    The tail path is taken when ``CROWN_TAIL`` is non-empty and carries all of
    ``"l0"``, ``"u0"``, ``"block"`` and ``"final"``. It then uses the tapped
    tensors and modules (not ``feat_l`` / ``feat_u`` / ``final_module``), with
    ``chg_idx`` defaulting to 1 and ``nchg_idx`` to 0, and prints a notice.
    Otherwise ``affine_bounds_conv2d_final`` is called. ``model`` is accepted
    but not used.

    NOTE(review): the tail path returns what ``crown_tail_bounds_oneblock``
    returns (margin maps), while the fallback returns per-channel bounds —
    confirm callers expect this.
    """
    required_keys = ("l0","u0","block","final")
    tap_ready = CROWN_TAIL and all(key in CROWN_TAIL for key in required_keys)
    if not tap_ready:
        # fallback: the already-installed final conv-only bounds
        return affine_bounds_conv2d_final(feat_l, feat_u, final_module)

    pos = 1 if chg_idx is None else chg_idx
    neg = 0 if nchg_idx is None else nchg_idx
    lower, upper = crown_tail_bounds_oneblock(
        CROWN_TAIL["l0"], CROWN_TAIL["u0"],
        CROWN_TAIL["block"], CROWN_TAIL["final"],
        pos, neg,
    )
    print("[tail] CROWN one-block used (DoubleConv + 1×1).")
    return lower, upper

def affine_margin_bounds_conv2d_crown(feat_l, feat_u, final_module, chg_idx, nchg_idx):
    """Margin bounds: CROWN one-block tail when the tap is filled, else final-conv only.

    The tail path is taken when ``CROWN_TAIL`` is non-empty and carries all of
    ``"l0"``, ``"u0"``, ``"block"`` and ``"final"``; it then works from the
    tapped tensors and modules. Otherwise
    ``affine_margin_bounds_conv2d_final`` is called with the given arguments.
    """
    required_keys = ("l0","u0","block","final")
    tap_ready = CROWN_TAIL and all(key in CROWN_TAIL for key in required_keys)
    if not tap_ready:
        return affine_margin_bounds_conv2d_final(feat_l, feat_u, final_module, chg_idx, nchg_idx)

    lower, upper = crown_tail_bounds_oneblock(
        CROWN_TAIL["l0"], CROWN_TAIL["u0"],
        CROWN_TAIL["block"], CROWN_TAIL["final"],
        chg_idx, nchg_idx,
    )
    return lower, upper

# hard rebind: the public names now resolve to the CROWN-aware glue functions
affine_bounds_conv2d = affine_bounds_conv2d_crown
affine_margin_bounds_conv2d = affine_margin_bounds_conv2d_crown

print("[CROWN-tail] drop-in installed. If CROWN_TAIL tap is set, tail will be used; else final-only fallback.")
# === Make α-CROWN tail the one actually used (also for *_safe alias) ===
# (Place this right after the CROWN-tail cell you added)

def _affine_margin_bounds_conv2d_tailaware_VERBOSE(feat_l, feat_u, final_module, chg_idx, nchg_idx):
    """Margin bounds via the CROWN tail, with a printed trace of the path taken.

    When a global ``CROWN_TAIL`` exists the tail computation is attempted
    first; any exception is printed and swallowed (best effort). The function
    stored in the global ``_AFFINE_MARGIN_ORIG`` is then used, if there is one.

    Raises:
        RuntimeError: if the fallback is needed but no original is available.
    """
    module_globals = globals()

    # Try the tail first.
    if 'CROWN_TAIL' in module_globals:
        try:
            lower, upper = _affine_margin_bounds_conv2d_crown_tail(
                feat_l, feat_u, final_module, chg_idx, nchg_idx
            )
            print("[tail] α-CROWN margin used.")
            return lower, upper
        except Exception as err:
            print(f"[tail] α-CROWN margin unavailable ({err}); falling back.")

    # Fallback to the original margin bound (missing and None are treated alike).
    fallback = module_globals.get('_AFFINE_MARGIN_ORIG')
    if fallback is None:
        raise RuntimeError("No original affine_margin_bounds_conv2d available for fallback.")
    print("[tail] fallback -> original margin bounds.")
    return fallback(feat_l, feat_u, final_module, chg_idx, nchg_idx)

# Rebind BOTH names your verifier might call (public name and "*_safe" alias)
affine_margin_bounds_conv2d = _affine_margin_bounds_conv2d_tailaware_VERBOSE
affine_margin_bounds_conv2d_safe = _affine_margin_bounds_conv2d_tailaware_VERBOSE

# Show which implementation each name resolves to after the rebind.
print("affine_margin_bounds_conv2d      ->", affine_margin_bounds_conv2d.__name__)
print("affine_margin_bounds_conv2d_safe ->", affine_margin_bounds_conv2d_safe.__name__)
# === Helper: fetch the first Conv2d inside a block (e.g., DoubleConv) ===
import torch.nn as nn

def _first_conv(mod: nn.Module) -> nn.Conv2d:
    """Return the first ``nn.Conv2d`` in ``mod.modules()`` traversal order.

    Raises:
        RuntimeError: if ``mod`` contains no Conv2d.
    """
    first = next((sub for sub in mod.modules() if isinstance(sub, nn.Conv2d)), None)
    if first is None:
        raise RuntimeError("No Conv2d found in block")
    return first

# --- Encoder-Decoder (stop before model.outc) ---
def propagate_prelogits_encdecnet(model, z):
    """Propagate interval bounds through the encoder-decoder, stopping before ``model.outc``.

    ``z.lower`` / ``z.upper`` are pushed through ``inc`` and ``down1..down4``,
    then through ``up1..up3`` with the matching skip bounds. For ``up4`` the
    concatenated input of its conv block is built first and, when a global
    ``CROWN_TAIL`` exists, recorded there (detached copies of the bounds, the
    block ``model.up4.conv`` and the head ``model.outc``) together with a
    printed channel check. The block is then applied.

    Returns:
        ``(lower, upper)`` bounds at the input of the head.
    """
    def _decode(up_block, lo, hi, skip_lo, skip_hi):
        # Upsample both bounds, prepend the skip bounds, run the conv block.
        lo = up_block.up(lo)
        hi = up_block.up(hi)
        return interval_forward_falconet(
            torch.cat([skip_lo, lo], dim=1),
            torch.cat([skip_hi, hi], dim=1),
            up_block.conv,
        )

    # encoder: keep every stage's bounds, they serve as skip connections later
    lo, hi = z.lower, z.upper
    stage_bounds = []
    for stage_name in ("inc", "down1", "down2", "down3", "down4"):
        lo, hi = interval_forward_falconet(lo, hi, getattr(model, stage_name))
        stage_bounds.append((lo, hi))
    skip1, skip2, skip3, skip4, _bottom = stage_bounds

    # decoder: the first three up blocks, deepest skip first
    for up_name, (skip_lo, skip_hi) in (("up1", skip4), ("up2", skip3), ("up3", skip2)):
        lo, hi = _decode(getattr(model, up_name), lo, hi, skip_lo, skip_hi)

    # LAST up block (up4): build the conv-block input, tap it, then run the block
    up_lo = model.up4.up(lo)
    up_hi = model.up4.up(hi)
    tail_lo = torch.cat([skip1[0], up_lo], dim=1)
    tail_hi = torch.cat([skip1[1], up_hi], dim=1)

    if 'CROWN_TAIL' in globals():
        CROWN_TAIL['l0']    = tail_lo.detach().clone()
        CROWN_TAIL['u0']    = tail_hi.detach().clone()
        CROWN_TAIL['block'] = model.up4.conv
        CROWN_TAIL['final'] = model.outc
        first = _first_conv(model.up4.conv)
        print(f"[tail-tap EncDec] inC={tail_lo.shape[1]} expect={first.in_channels}")

    out_lo, out_hi = interval_forward_falconet(tail_lo, tail_hi, model.up4.conv)
    return out_lo, out_hi  # pre-logits for the head

# --- FALCONet + token mixer (stop before model.outc) ---
def propagate_prelogits_falconetmha(model, z):
    """Propagate interval bounds through FALCONet with token mixers, stopping before ``model.outc``.

    Encoder: ``inc``, ``down1``, then ``down2`` / ``down3`` / ``down4`` each
    followed by its token mixer (``token_mixer_2`` / ``_3`` / ``_4``). The
    mixed bounds are what later stages and skip connections see. Decoder:
    ``up1..up3`` with the matching skips. For ``up4`` the concatenated input
    of its conv block is built first and, when a global ``CROWN_TAIL`` exists,
    recorded there (detached copies of the bounds, the block
    ``model.up4.conv`` and the head ``model.outc``) together with a printed
    channel check. The block is then applied.

    Returns:
        ``(lower, upper)`` bounds at the input of the head.
    """
    def _mix(lo, hi, mixer_name):
        # flatten -> token mixer -> restore the layout of the lower-bound tensor
        lo_flat, layout = flatten_hw(lo)
        hi_flat, _ = flatten_hw(hi)
        lo_flat, hi_flat = interval_forward_tokenmixer(lo_flat, hi_flat, getattr(model, mixer_name))
        return unflatten_hw(lo_flat, layout), unflatten_hw(hi_flat, layout)

    def _decode(up_block, lo, hi, skip_lo, skip_hi):
        # Upsample both bounds, prepend the skip bounds, run the conv block.
        lo = up_block.up(lo)
        hi = up_block.up(hi)
        return interval_forward_falconet(
            torch.cat([skip_lo, lo], dim=1),
            torch.cat([skip_hi, hi], dim=1),
            up_block.conv,
        )

    # encoder with token mixers on the three deepest stages
    encoder_plan = (
        ("inc", None),
        ("down1", None),
        ("down2", "token_mixer_2"),
        ("down3", "token_mixer_3"),
        ("down4", "token_mixer_4"),
    )
    lo, hi = z.lower, z.upper
    stage_bounds = []
    for stage_name, mixer_name in encoder_plan:
        lo, hi = interval_forward_falconet(lo, hi, getattr(model, stage_name))
        if mixer_name is not None:
            lo, hi = _mix(lo, hi, mixer_name)
        stage_bounds.append((lo, hi))
    skip1, skip2, skip3, skip4, _bottom = stage_bounds

    # decoder: the first three up blocks, deepest skip first
    for up_name, (skip_lo, skip_hi) in (("up1", skip4), ("up2", skip3), ("up3", skip2)):
        lo, hi = _decode(getattr(model, up_name), lo, hi, skip_lo, skip_hi)

    # LAST up block (up4): build the conv-block input, tap it, then run the block
    up_lo = model.up4.up(lo)
    up_hi = model.up4.up(hi)
    tail_lo = torch.cat([skip1[0], up_lo], dim=1)
    tail_hi = torch.cat([skip1[1], up_hi], dim=1)

    if 'CROWN_TAIL' in globals():
        CROWN_TAIL['l0']    = tail_lo.detach().clone()
        CROWN_TAIL['u0']    = tail_hi.detach().clone()
        CROWN_TAIL['block'] = model.up4.conv
        CROWN_TAIL['final'] = model.outc
        first = _first_conv(model.up4.conv)
        print(f"[tail-tap FALCONet] inC={tail_lo.shape[1]} expect={first.in_channels}")

    out_lo, out_hi = interval_forward_falconet(tail_lo, tail_hi, model.up4.conv)
    return out_lo, out_hi  # pre-logits for the head

# --- AttU-Net (stop before model.Conv_1x1) ---
def propagate_prelogits_attunet(model, z):
    """Propagate interval bounds through AttU-Net up to the input of ``model.Conv_1x1``.

    Args:
        model: module from which this function reads ``Conv1``..``Conv4``,
            ``Maxpool1``..``Maxpool3``, ``Up4``/``Up3``/``Up2``,
            ``Att4``/``Att3``/``Att2``, ``Up_conv4``/``Up_conv3``/``Up_conv2``
            and ``Conv_1x1``.
        z: input region exposing ``lower`` and ``upper`` tensors.

    Returns:
        ``(lower, upper)`` bounds on the output of ``model.Up_conv2``.

    Side effects:
        Prints interval widths at several stages.  When a global
        ``CROWN_TAIL`` exists, stores detached copies of the bounds entering
        ``model.Up_conv2`` together with references to ``model.Up_conv2`` and
        ``model.Conv_1x1``.
    """
    x_l, x_u = z.lower, z.upper

    # encoder
    x1_l, x1_u = interval_forward_attunet(x_l,  x_u,  model.Conv1)
    print("[w] after Conv1 :", _span_mean_max(x1_l, x1_u))
    x2_l, x2_u = interval_forward_attunet(x1_l, x1_u, model.Maxpool1)
    x2_l, x2_u = interval_forward_attunet(x2_l, x2_u, model.Conv2)
    # Label now names the layer actually applied (it used to say "Conv3").
    print("[w] after Conv2 :", _span_mean_max(x2_l, x2_u))
    x3_l, x3_u = interval_forward_attunet(x2_l, x2_u, model.Maxpool2)
    x3_l, x3_u = interval_forward_attunet(x3_l, x3_u, model.Conv3)
    # Label now names the layer actually applied (it used to say "Conv5").
    print("[w] after Conv3 :", _span_mean_max(x3_l, x3_u))

    # bottleneck
    x4_l, x4_u = interval_forward_attunet(x3_l, x3_u, model.Maxpool3)
    x4_l, x4_u = interval_forward_attunet(x4_l, x4_u, model.Conv4)

    # decoder, from the bottleneck upwards: upsample, gate the skip, concat, conv
    d4_l, d4_u = interval_forward_attunet(x4_l, x4_u, model.Up4)        # upsample
    x3_l_att, x3_u_att = interval_forward_attention_gate(model.Att4, d4_l, d4_u, x3_l, x3_u)
    d4_l = torch.cat((x3_l_att, d4_l), 1)
    d4_u = torch.cat((x3_u_att, d4_u), 1)
    d4_l, d4_u = interval_forward_attunet(d4_l, d4_u, model.Up_conv4)

    d3_l, d3_u = interval_forward_attunet(d4_l, d4_u, model.Up3)        # upsample
    x2_l_att, x2_u_att = interval_forward_attention_gate(model.Att3, d3_l, d3_u, x2_l, x2_u)
    d3_l = torch.cat((x2_l_att, d3_l), 1)
    d3_u = torch.cat((x2_u_att, d3_u), 1)
    d3_l, d3_u = interval_forward_attunet(d3_l, d3_u, model.Up_conv3)
    print("[w] after Up_conv3:", _span_mean_max(d3_l, d3_u))

    d2_l, d2_u = interval_forward_attunet(d3_l, d3_u, model.Up2)        # upsample
    x1_l_att, x1_u_att = interval_forward_attention_gate(model.Att2, d2_l, d2_u, x1_l, x1_u)
    # Tap point: these are the bounds fed to Up_conv2 (i.e. **before** it runs).
    t_in_l = torch.cat((x1_l_att, d2_l), 1)
    t_in_u = torch.cat((x1_u_att, d2_u), 1)

    if 'CROWN_TAIL' in globals():
        CROWN_TAIL['l0']    = t_in_l.detach().clone()
        CROWN_TAIL['u0']    = t_in_u.detach().clone()
        CROWN_TAIL['block'] = model.Up_conv2
        CROWN_TAIL['final'] = model.Conv_1x1
        fc = _first_conv(model.Up_conv2)
        print(f"[tail-tap AttU] inC={t_in_l.shape[1]} expect={fc.in_channels}")

    d2_l, d2_u = interval_forward_attunet(t_in_l, t_in_u, model.Up_conv2)
    print("[w] after Up_conv2:", _span_mean_max(d2_l, d2_u))
    return d2_l, d2_u  # pre-logits (before model.Conv_1x1)
# === FIXED α-CROWN TAIL (use saved l0/u0; don't reuse pre-logits) ===
import torch, torch.nn as nn, torch.nn.functional as F

def _relu_bounds(l, u):
    """Return the interval image of ReLU: ``(max(l, 0), max(u, 0))``.

    ReLU is monotone, so clamping each bound at zero is exact.  The clamps are
    out-of-place so the tensors passed in by the caller are left untouched.
    """
    return l.clamp_min(0), u.clamp_min(0)

def _ia_conv(l, u, conv: nn.Conv2d):
    """Interval-arithmetic image of ``conv`` applied to the box ``[l, u]``.

    The kernel is split into its non-negative and non-positive parts; the
    lower output pairs positive weights with ``l`` and negative weights with
    ``u`` (and vice versa for the upper output).  The bias, when present, is
    added to both bounds.

    Returns:
        ``(lower, upper)`` output bounds.
    """
    kernel = conv.weight
    k_pos = kernel.clamp(min=0)
    k_neg = kernel.clamp(max=0)
    geometry = (conv.stride, conv.padding, conv.dilation, conv.groups)

    def _linear(x, w):
        # bias-free convolution with the layer's own geometry
        return F.conv2d(x, w, None, *geometry)

    lower = _linear(l, k_pos) + _linear(u, k_neg)
    upper = _linear(u, k_pos) + _linear(l, k_neg)

    if conv.bias is None:
        return lower, upper
    return lower + conv.bias.view(1, -1, 1, 1), upper + conv.bias.view(1, -1, 1, 1)

def _two_convs_from_doubleconv(block: nn.Module):
    """Return the first two ``nn.Conv2d`` layers of ``block`` in traversal order.

    Raises:
        RuntimeError: if ``block`` contains fewer than two ``nn.Conv2d`` layers.
    """
    picked = []
    for sub in block.modules():
        if isinstance(sub, nn.Conv2d):
            picked.append(sub)
            if len(picked) == 2:
                # only the first two are ever used
                return picked[0], picked[1]
    raise RuntimeError("DoubleConv-like block does not expose two Conv2d layers.")

def _tail_margin_bounds(l0, u0, block: nn.Module, head1x1: nn.Conv2d, chg_idx, nchg_idx):
    """Bound the logit margin ``logit[chg_idx] - logit[nchg_idx]`` from a saved tap.

    The box ``[l0, u0]`` is pushed through the first two convolutions of
    ``block`` (a ReLU is applied after each) and then through a single-output
    kernel built as the difference of two rows of ``head1x1``.

    Returns:
        ``(lower, upper)`` margin bounds, or ``None`` when the channel count of
        ``l0`` does not match the first convolution of ``block``.

    NOTE(review): only the two Conv2d layers of ``block`` are applied; any
    other layer it contains (e.g. normalisation) is skipped, and the second
    ReLU is applied unconditionally.  Confirm this matches the real block,
    otherwise the returned bounds are not valid for it.
    NOTE(review): the log line mentions α-CROWN, but the computation here is
    plain interval arithmetic.
    """
    first_conv, second_conv = _two_convs_from_doubleconv(block)

    # Sanity: the tap must have the channel count the first conv expects
    have_c = l0.shape[1]
    need_c = first_conv.weight.shape[1]
    if have_c != need_c:
        print(f"[tail] channel mismatch at tap: inC={have_c} expect={need_c} -> fallback.")
        return None

    lo, hi = l0, u0
    for conv in (first_conv, second_conv):
        lo, hi = _ia_conv(lo, hi, conv)
        lo, hi = _relu_bounds(lo, hi)

    # Margin kernel on the final 1×1: row(chg) - row(nchg), same for the bias
    head_w = head1x1.weight
    head_b = head1x1.bias
    diff_w = head_w[int(chg_idx)].unsqueeze(0) - head_w[int(nchg_idx)].unsqueeze(0)
    if head_b is None:
        diff_b = None
    else:
        diff_b = head_b[int(chg_idx)] - head_b[int(nchg_idx)]

    # Interval conv with the difference kernel
    w_pos = diff_w.clamp(min=0)
    w_neg = diff_w.clamp(max=0)
    geometry = (head1x1.stride, head1x1.padding, head1x1.dilation, head1x1.groups)
    yL = F.conv2d(lo, w_pos, None, *geometry) + F.conv2d(hi, w_neg, None, *geometry)
    yU = F.conv2d(hi, w_pos, None, *geometry) + F.conv2d(lo, w_neg, None, *geometry)
    if diff_b is not None:
        shift = diff_b.view(1, 1, 1, 1)
        yL = yL + shift
        yU = yU + shift
    print("[tail] α-CROWN margin applied over DoubleConv→1×1.")
    return yL, yU

# --- OVERRIDE ONLY THE MARGIN BOUNDS ---
# Capture the previously installed implementation as the fallback.  If this
# cell is re-run, the current global is already our own override; capturing it
# again would make the fallback call itself forever, so keep the earlier value.
_prev_margin_fn = globals().get('affine_margin_bounds_conv2d', None)
if not getattr(_prev_margin_fn, '_is_tail_override', False):
    _AB_ORIG = _prev_margin_fn
elif '_AB_ORIG' not in globals():
    _AB_ORIG = None

def affine_margin_bounds_conv2d(feat_l, feat_u, final_module, chg_idx, nchg_idx):
    """Margin bounds that prefer the saved ``CROWN_TAIL`` tap over ``feat_l``/``feat_u``.

    When the global ``CROWN_TAIL`` holds tensor bounds (``l0``/``u0``), a
    module (``block``) and a ``nn.Conv2d`` head (``final``, defaulting to
    ``final_module``), the margin is computed by ``_tail_margin_bounds``.
    In every other case, including any exception raised on that path, the
    call is forwarded to the implementation that was installed before this
    override.

    Raises:
        RuntimeError: if the fallback is needed but no earlier implementation
            was captured.
    """
    # Try tail if present; otherwise, fall back
    try:
        t = globals().get('CROWN_TAIL', None)
        if t is not None:
            l0    = t.get('l0', None)
            u0    = t.get('u0', None)
            block = t.get('block', None)
            head  = t.get('final', final_module)

            if isinstance(l0, torch.Tensor) and isinstance(u0, torch.Tensor) and \
               isinstance(block, nn.Module) and isinstance(head, nn.Conv2d):
                # (Optional) quick shape log
                print(f"[tail] using saved tap: inC={l0.shape[1]} expect={_two_convs_from_doubleconv(block)[0].weight.shape[1]}")
                out = _tail_margin_bounds(l0, u0, block, head, chg_idx, nchg_idx)
                if out is not None:
                    return out
        print("[tail] unavailable or mismatch; fallback -> original margin bounds.")
    except Exception as e:
        # Deliberate best-effort: any failure on the tail path degrades to the fallback.
        print(f"[tail] α-CROWN margin unavailable ({e}); falling back.")
    # Fallback to whatever was bound before
    if _AB_ORIG is None:
        raise RuntimeError("No original affine_margin_bounds_conv2d available for fallback.")
    return _AB_ORIG(feat_l, feat_u, final_module, chg_idx, nchg_idx)

# Marker read by the capture logic above when the cell is executed again.
affine_margin_bounds_conv2d._is_tail_override = True

# Install override
globals()['affine_margin_bounds_conv2d'] = affine_margin_bounds_conv2d
print("[tail] fixed α-CROWN override installed.")
# === GT MASK SHIM (define GT_MASKS and a robust loader) ===
import os
import numpy as np
from PIL import Image

# 1) Make sure the legacy path constants exist (we won't overwrite if you already set them)
# Each constant is defined only when absent, so values set by an earlier cell win.
# The trailing/leading slashes matter: callers build paths by plain string concatenation.
if 'OSCD_PATH_TOP' not in globals():      OSCD_PATH_TOP = "../onera/OSCD/"
if 'OSCD_PATH_BOTTOM1' not in globals():  OSCD_PATH_BOTTOM1 = "/imgs_1/"
if 'OSCD_PATH_BOTTOM2' not in globals():  OSCD_PATH_BOTTOM2 = "/imgs_2/"
if 'OSCD_PATH_CM' not in globals():       OSCD_PATH_CM = "/cm/cm.png"

# 2) Provide oscd_paths(city) if some older code tries to call it
if 'oscd_paths' not in globals():
    def oscd_paths(city: str):
        """Return ``(imgs_1 dir, imgs_2 dir, change-mask path)`` for ``city``.

        The three strings are built by concatenating ``OSCD_PATH_TOP``, the
        city name and the matching suffix constant; nothing is checked on disk.
        """
        top = OSCD_PATH_TOP
        return (top + city + OSCD_PATH_BOTTOM1,
                top + city + OSCD_PATH_BOTTOM2,
                top + city + OSCD_PATH_CM)

# 3) Global cache
# Created only when missing, so masks cached by an earlier cell are kept.
if 'GT_MASKS' not in globals():
    GT_MASKS = {}

def _cm_path(city: str) -> str:
    """Return the change-mask path for ``city``.

    The legacy string concatenation is returned when that path exists on
    disk; otherwise an ``os.path.join`` of the same pieces is returned (without
    checking that it exists).
    """
    legacy = OSCD_PATH_TOP + city + OSCD_PATH_CM
    if not os.path.exists(legacy):
        return os.path.join(OSCD_PATH_TOP, city, OSCD_PATH_CM.lstrip("/"))
    return legacy

# === GT mask sanitizer + safe loader ===
import numpy as np
from PIL import Image
import os

# Ensure the cache exists
if 'GT_MASKS' not in globals():
    GT_MASKS = {}

def _cm_path(city: str) -> str:
    """Return the change-mask path for ``city``.

    The legacy string concatenation is returned when that path exists on
    disk; otherwise an ``os.path.join`` of the same pieces is returned (without
    checking that it exists).
    """
    legacy = OSCD_PATH_TOP + city + OSCD_PATH_CM
    if not os.path.exists(legacy):
        return os.path.join(OSCD_PATH_TOP, city, OSCD_PATH_CM.lstrip("/"))
    return legacy

def _sanitize_gt_array(arr: np.ndarray) -> np.ndarray:
    """Collapse a mask array to one binary ``uint8`` channel (nonzero -> 1).

    For 3-D input the last axis is treated as the channel axis and a pixel is
    foreground when any colour channel is nonzero.  A trailing alpha channel
    (2- or 4-channel layouts, i.e. LA / RGBA) is ignored: an opaque alpha is
    nonzero everywhere and would otherwise turn the whole mask into
    foreground.  This matches the loader, which converts images to mode "L"
    before sanitising.  Any other input is binarised element-wise.
    """
    a = np.asarray(arr)
    if a.ndim == 3:
        if a.shape[-1] in (2, 4):
            a = a[..., :-1]  # drop alpha, keep colour/luminance channels
        # Any nonzero across the remaining channels = foreground
        return a.any(axis=-1).astype(np.uint8)
    return (a > 0).astype(np.uint8)

def _load_gt_mask_from_disk(city: str) -> np.ndarray:
    """Load, binarise and cache the ground-truth change mask of ``city``.

    The path comes from ``_cm_path``.  Images that are not already mode "1"
    or "L" are converted to "L" so the array is 2-D, then passed through
    ``_sanitize_gt_array``.  The result is stored in ``GT_MASKS[city]`` and
    returned.

    Raises:
        FileNotFoundError: if the resolved mask path does not exist.
    """
    p = _cm_path(city)
    if not os.path.exists(p):
        raise FileNotFoundError(f"GT mask path missing for {city}: {p}")
    # Context manager closes the underlying file; the pixel data is copied
    # into the array before the handle is released.
    with Image.open(p) as src:
        # Convert to single channel explicitly to avoid (H,W,4)
        gray = src if src.mode in ("1", "L") else src.convert("L")
        a = np.array(gray)
    a = _sanitize_gt_array(a)
    GT_MASKS[city] = a
    print(f"[gt] {city}: mask {a.shape} loaded & sanitized from {p}")
    return a

# Sanitize anything already in GT_MASKS (e.g., previously loaded RGBA)
# Re-run the sanitizer over every mask already in the cache.  The items are
# snapshotted with list() because the dict is written to inside the loop.
for k, v in list(GT_MASKS.items()):
    GT_MASKS[k] = _sanitize_gt_array(v)
    print(f"[gt-fix] {k}: cached mask -> {GT_MASKS[k].shape} (binary)")

print("[gt] loader/sanitizer installed.")

PyTorch device: cpu
imgs_1 dir -> ../onera/OSCD/paris/imgs_1/
imgs_2 dir -> ../onera/OSCD/paris/imgs_2/
cm path    -> ../onera/OSCD/paris/cm/cm.png
[shim] installed: safe affine bounds + RCD shim; recursion should be gone.
affine_bounds_conv2d -> affine_bounds_conv2d_final
affine_margin_bounds_conv2d -> affine_margin_bounds_conv2d_final
[final-patch] oscd_paths + final Conv2d bound fns installed.
[tail] α-CROWN margin override armed (uses taps when CROWN_TAIL is set).
[CROWN-tail] drop-in installed. If CROWN_TAIL tap is set, tail will be used; else final-only fallback.
affine_margin_bounds_conv2d      -> _affine_margin_bounds_conv2d_tailaware_VERBOSE
affine_margin_bounds_conv2d_safe -> _affine_margin_bounds_conv2d_tailaware_VERBOSE
[tail] fixed α-CROWN override installed.
[gt] loader/sanitizer installed.


In [4]:
# ---------- models_to_test come from your definitions above ----------
# Checkpoint is loaded onto the CPU and the model is switched to eval mode.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
# NOTE(review): assumes the file holds a bare state_dict (it is passed straight
# to load_state_dict) — confirm against how the checkpoint was saved.
model_path = "../onera/FALCONet_HCTv3-best_f1-epoch15.pth.tar"
falconet_model = FALCONetMHA_LiRPA(2*13, 2 , dropout=0.1, reduction=8, attention=True, num_heads=4)
falconet_model.load_state_dict(torch.load(model_path, map_location="cpu")); falconet_model.eval()

# Each entry is (display name, model object, model type id); the sweep harness
# unpacks entries in exactly this order.
MODELS = [
    ("FALCONet",   falconet_model, 0),
]


# Sweep grids.  CITIES is also the list preloaded into GT_MASKS below.
# presumably EPS/RHO/GAMMA/SMIN feed run_sweep's eps/rho/gamma/s_min grids — verify at the call site
CITIES    = ["brasilia", "montpellier", "norcia", "rio" , "saclay_w" , "valencia" , "dubai" , "lasvegas" , "milano" , "chongqing"]
EPS    =  (0.0, 0.25/255, 0.5/255, 1/255, 2/255)
RHO    = (0.2, 0.3, 0.5)
GAMMA  = (0.3, 0.2)
SMIN   = (16, 32)

In [5]:
# === PRELOAD GT MASKS (run once before sweep) ===
import os
import numpy as np
from PIL import Image

# Use your existing path constants; do NOT redefine them here
# OSCD_PATH_TOP, OSCD_PATH_CM must already be set

if 'GT_MASKS' not in globals():
    GT_MASKS = {}

def _cm_path(city: str) -> str:
    """Return the change-mask path for ``city``.

    The legacy string concatenation is returned when that path exists on
    disk; otherwise an ``os.path.join`` of the same pieces is returned (without
    checking that it exists).
    """
    legacy = OSCD_PATH_TOP + city + OSCD_PATH_CM
    if not os.path.exists(legacy):
        return os.path.join(OSCD_PATH_TOP, city, OSCD_PATH_CM.lstrip("/"))
    return legacy

# === GT mask sanitizer + safe loader ===
import numpy as np
from PIL import Image
import os

# Ensure the cache exists
if 'GT_MASKS' not in globals():
    GT_MASKS = {}

def _cm_path(city: str) -> str:
    """Return the change-mask path for ``city``.

    The legacy string concatenation is returned when that path exists on
    disk; otherwise an ``os.path.join`` of the same pieces is returned (without
    checking that it exists).
    """
    legacy = OSCD_PATH_TOP + city + OSCD_PATH_CM
    if not os.path.exists(legacy):
        return os.path.join(OSCD_PATH_TOP, city, OSCD_PATH_CM.lstrip("/"))
    return legacy

def _sanitize_gt_array(arr: np.ndarray) -> np.ndarray:
    """Collapse a mask array to one binary ``uint8`` channel (nonzero -> 1).

    For 3-D input the last axis is treated as the channel axis and a pixel is
    foreground when any colour channel is nonzero.  A trailing alpha channel
    (2- or 4-channel layouts, i.e. LA / RGBA) is ignored: an opaque alpha is
    nonzero everywhere and would otherwise turn the whole mask into
    foreground.  This matches the loader, which converts images to mode "L"
    before sanitising.  Any other input is binarised element-wise.
    """
    a = np.asarray(arr)
    if a.ndim == 3:
        if a.shape[-1] in (2, 4):
            a = a[..., :-1]  # drop alpha, keep colour/luminance channels
        # Any nonzero across the remaining channels = foreground
        return a.any(axis=-1).astype(np.uint8)
    return (a > 0).astype(np.uint8)

def _load_gt_mask_from_disk(city: str) -> np.ndarray:
    """Load, binarise and cache the ground-truth change mask of ``city``.

    The path comes from ``_cm_path``.  Images that are not already mode "1"
    or "L" are converted to "L" so the array is 2-D, then passed through
    ``_sanitize_gt_array``.  The result is stored in ``GT_MASKS[city]`` and
    returned.

    Raises:
        FileNotFoundError: if the resolved mask path does not exist.
    """
    p = _cm_path(city)
    if not os.path.exists(p):
        raise FileNotFoundError(f"GT mask path missing for {city}: {p}")
    # Context manager closes the underlying file; the pixel data is copied
    # into the array before the handle is released.
    with Image.open(p) as src:
        # Convert to single channel explicitly to avoid (H,W,4)
        gray = src if src.mode in ("1", "L") else src.convert("L")
        a = np.array(gray)
    a = _sanitize_gt_array(a)
    GT_MASKS[city] = a
    print(f"[gt] {city}: mask {a.shape} loaded & sanitized from {p}")
    return a

# Sanitize anything already in GT_MASKS (e.g., previously loaded RGBA)
# Re-run the sanitizer over every mask already in the cache.  The items are
# snapshotted with list() because the dict is written to inside the loop.
for k, v in list(GT_MASKS.items()):
    GT_MASKS[k] = _sanitize_gt_array(v)
    print(f"[gt-fix] {k}: cached mask -> {GT_MASKS[k].shape} (binary)")

print("[gt] loader/sanitizer installed.")


# Preload every city you’ll sweep
# Fill GT_MASKS up front; a missing mask file raises FileNotFoundError here,
# before any sweep work starts.
for city in CITIES:
    _load_gt_mask_from_disk(city)

print("[gt] ready keys:", list(GT_MASKS.keys()))


[gt] loader/sanitizer installed.
[gt] brasilia: mask (433, 469) loaded & sanitized from ../onera/OSCD/brasilia/cm/cm.png
[gt] montpellier: mask (426, 451) loaded & sanitized from ../onera/OSCD/montpellier/cm/cm.png
[gt] norcia: mask (241, 385) loaded & sanitized from ../onera/OSCD/norcia/cm/cm.png
[gt] rio: mask (353, 426) loaded & sanitized from ../onera/OSCD/rio/cm/cm.png
[gt] saclay_w: mask (639, 688) loaded & sanitized from ../onera/OSCD/saclay_w/cm/cm.png
[gt] valencia: mask (458, 476) loaded & sanitized from ../onera/OSCD/valencia/cm/cm.png
[gt] dubai: mask (774, 634) loaded & sanitized from ../onera/OSCD/dubai/cm/cm.png
[gt] lasvegas: mask (824, 716) loaded & sanitized from ../onera/OSCD/lasvegas/cm/cm.png
[gt] milano: mask (545, 558) loaded & sanitized from ../onera/OSCD/milano/cm/cm.png
[gt] chongqing: mask (730, 544) loaded & sanitized from ../onera/OSCD/chongqing/cm/cm.png
[gt] ready keys: ['brasilia', 'montpellier', 'norcia', 'rio', 'saclay_w', 'valencia', 'dubai', 'lasvegas', 'milano', 'chongqing']

In [6]:
# === Sweep harness ===
def run_one_setting(model_name, model, model_type, city, eps, rho, gamma, s_min):
    """Run ``rcd`` for one configuration and summarise it as a flat record.

    ``rho``, ``gamma`` and ``s_min`` are only copied into the record; the
    call to ``rcd`` receives ``model``, ``model_type``, ``city`` and ``eps``.

    Returns:
        A dict with keys ``ok``, ``model``, ``type``, ``city``, ``eps``,
        ``rho``, ``gamma``, ``s_min``, ``width_mean``, ``logit_min``,
        ``logit_max`` and ``error``.  On any exception the traceback is
        printed, the three metrics are NaN, ``ok`` is False and ``error``
        holds ``repr`` of the exception.
    """
    def _record(ok, width_mean, logit_min, logit_max, error):
        # Key order is part of the contract: it mirrors the CSV header.
        return {
            "ok": ok,
            "model": model_name,
            "type": int(model_type),
            "city": city,
            "eps": float(eps),
            "rho": float(rho),
            "gamma": float(gamma),
            "s_min": int(s_min),
            "width_mean": width_mean,
            "logit_min": logit_min,
            "logit_max": logit_max,
            "error": error,
        }

    try:
        lower, upper, clean_logits, clean_pred, gt_mask = rcd(model, model_type, city, eps)
        return _record(
            True,
            (upper - lower).abs().mean().item(),
            float(clean_logits.min().item()),
            float(clean_logits.max().item()),
            "",
        )
    except Exception as e:
        print("[trace] rcd() raised:")
        traceback.print_exc()
        not_a_number = float("nan")
        return _record(False, not_a_number, not_a_number, not_a_number, repr(e))

def _flush_csv(path, header, rows):
    """Append ``rows`` to the CSV at ``path`` and then empty ``rows`` in place.

    The header row is written first when the file is missing or empty.
    Keys of a row that are not listed in ``header`` are dropped.
    """
    if os.path.exists(path):
        write_header = os.path.getsize(path) == 0
    else:
        write_header = True

    with open(path, "a", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=header, extrasaction="ignore")
        if write_header:
            writer.writeheader()
        writer.writerows(rows)

    # The caller's buffer is reused between flushes, so clear it in place.
    rows.clear()

def run_sweep(
    models_to_test,
    cities=("paris",),                 # keep small first
    eps_grid=(0.0, 1/255, 2/255),
    rho_grid=(0.2, 0.3),
    gamma_grid=(0.3,),
    s_min_grid=(16,),
    out_csv="predicate_summary_semantic_v7.csv",
    write_every=10,
):
    """Run ``run_one_setting`` over the full grid and stream the records to CSV.

    Args:
        models_to_test: iterable of ``(model_name, model, model_type)``.
        cities: city names to evaluate.
        eps_grid, rho_grid, gamma_grid, s_min_grid: value grids; every
            combination is run for each city and model.
        out_csv: CSV file the records are appended to.
        write_every: flush to disk after this many records; a falsy value
            (``0``/``None``) flushes only once at the end.

    Returns:
        List of every record produced, in execution order.
    """
    total_per_city_model = len(eps_grid) * len(rho_grid) * len(gamma_grid) * len(s_min_grid)
    print(f"[grid] {total_per_city_model} combos per city & model")

    header = ["ok","model","type","city","eps","rho","gamma","s_min","width_mean","logit_min","logit_max","error"]
    results = []   # everything, handed back to the caller
    pending = []   # not yet on disk; _flush_csv empties this list in place
    n = 0

    for (model_name, model, model_type) in models_to_test:
        print(f"\n===== MODEL: {model_name} (type={model_type}) =====")
        for city in cities:
            # Optional GT check (non-fatal)
            try:
                _ = _load_gt_mask_from_disk(city)
            except Exception as e:
                print(f"[warn] GT mask for {city} not found / unreadable: {e}")

            for eps, rho, gamma, s_min in itertools.product(eps_grid, rho_grid, gamma_grid, s_min_grid):
                rec = run_one_setting(model_name, model, model_type, city, eps, rho, gamma, s_min)
                results.append(rec)
                pending.append(rec)
                n += 1
                if write_every and (n % write_every) == 0:
                    _flush_csv(out_csv, header, pending)
                    print(f"[write] {n} rows -> {out_csv}")

    _flush_csv(out_csv, header, pending)
    print(f"[write] {n} rows -> {out_csv}")
    return results

# === anon end-to-end patch cell (idempotent) ===
# - Robust GT loader guard stays unchanged (already in your notebook)
# - Fix AttU_Net naming (Maxpool1 -> Maxpool) without changing the model
# - Install tail-aware margin override that falls back cleanly if shapes/taps don't match
# - Keep original final-conv bounds intact
import torch, torch.nn as nn, torch.nn.functional as F

# ---- 1) AttU_Net minor alias (only if missing) ----
# For every model whose class name starts with "attu" (case-insensitive), add a
# `Maxpool1` attribute pointing at the existing `Maxpool` when only the latter
# exists.  Any failure (including MODELS not being defined yet) is reported and
# otherwise ignored.
# NOTE(review): at module level this loop rebinds the globals `name`, `model`
# and `mtype`.
try:
    for name, model, mtype in MODELS:
        if model.__class__.__name__.lower().startswith("attu"):
            if hasattr(model, "Maxpool") and not hasattr(model, "Maxpool1"):
                setattr(model, "Maxpool1", getattr(model, "Maxpool"))
    print("[shim] AttU-Net: aliased Maxpool1 -> Maxpool (if missing).")
except Exception as e:
    print("[shim] AttU-Net alias warning:", e)

# ---- 2) Simple interval conv2d helper (kept local; no recursion) ----
def _interval_conv2d(xL, xU, conv: nn.Conv2d):
    """Interval-arithmetic image of ``conv`` applied to the box ``[xL, xU]``.

    Weights and bias are moved to the device of ``xL``.  The kernel is split
    by sign so each output bound pairs every weight with the input bound that
    is extreme for it.

    Returns:
        ``(lower, upper)`` output bounds.
    """
    target = xL.device
    w_pos = conv.weight.clamp(min=0).to(device=target)
    w_neg = conv.weight.clamp(max=0).to(device=target)
    geometry = (conv.stride, conv.padding, conv.dilation, conv.groups)

    lower = F.conv2d(xL, w_pos, None, *geometry) + F.conv2d(xU, w_neg, None, *geometry)
    upper = F.conv2d(xU, w_pos, None, *geometry) + F.conv2d(xL, w_neg, None, *geometry)

    if conv.bias is None:
        return lower, upper
    shift = conv.bias.view(1, -1, 1, 1).to(device=target)
    return lower + shift, upper + shift

def _interval_bn2d(xL, xU, bn: nn.BatchNorm2d):
    """Interval image of ``bn`` using its running statistics (inference behaviour).

    The layer is treated as the per-channel affine map
    ``y = a * x + c`` with ``a = gamma / sqrt(var + eps)`` and
    ``c = beta - a * mean``.  Channels with ``a < 0`` swap the roles of the
    lower and upper input bound.  Layers built with ``affine=False`` are
    handled as ``gamma = 1``, ``beta = 0``.

    Raises:
        RuntimeError: if the layer keeps no running statistics.
    """
    if bn.running_mean is None or bn.running_var is None:
        raise RuntimeError("interval-bn: BatchNorm2d without running statistics is unsupported")

    mean = bn.running_mean.detach().view(1, -1, 1, 1)
    var = bn.running_var.detach().view(1, -1, 1, 1)
    # affine=False leaves weight/bias as None; that is the identity scale/shift
    g = torch.ones_like(mean) if bn.weight is None else bn.weight.detach().view(1, -1, 1, 1)
    b = torch.zeros_like(mean) if bn.bias is None else bn.bias.detach().view(1, -1, 1, 1)

    denom = torch.sqrt(var + bn.eps)
    a = g / denom
    c = b - a * mean
    a_pos = torch.clamp(a, min=0)
    a_neg = torch.clamp(a, max=0)
    yL = a_pos * xL + a_neg * xU + c
    yU = a_pos * xU + a_neg * xL + c
    return yL, yU

def _interval_relu(xL, xU):
    """Apply ReLU to both interval bounds and return them as ``(lower, upper)``."""
    lower = torch.relu(xL)
    upper = torch.relu(xU)
    return lower, upper

def _interval_seq(xL, xU, module: nn.Sequential):
    """Push the box ``[xL, xU]`` through the direct children of ``module`` in order.

    Supported children are ``nn.Conv2d``, ``nn.BatchNorm2d``, ``nn.ReLU`` and
    ``nn.Identity``.

    Raises:
        RuntimeError: on any other layer type, so the caller can fall back
            instead of silently skipping a layer.
    """
    lo, hi = xL, xU
    for child in module.children():
        if isinstance(child, nn.Conv2d):
            lo, hi = _interval_conv2d(lo, hi, child)
            continue
        if isinstance(child, nn.BatchNorm2d):
            lo, hi = _interval_bn2d(lo, hi, child)
            continue
        if isinstance(child, nn.ReLU):
            lo, hi = _interval_relu(lo, hi)
            continue
        if isinstance(child, nn.Identity):
            continue
        # Be conservative: if unknown layer, bail and force fallback
        raise RuntimeError(f"interval-seq: unsupported layer {child.__class__.__name__}")
    return lo, hi

# ---- 3) Tail-aware margin override (safe and idempotent) ----
if "AFFINE_MARGIN_ORIG" not in globals():
    AFFINE_MARGIN_ORIG = globals().get("affine_margin_bounds_conv2d")

def _tail_margin_bounds_or_fallback(feat_l, feat_u, final_module, chg_idx, nchg_idx):
    """Margin bounds from the saved ``CROWN_TAIL`` tap, else the original implementation.

    When the global ``CROWN_TAIL`` dict carries ``l0``, ``u0``, an
    ``nn.Sequential`` ``block`` whose first convolution matches the tap's
    channel count, and a ``final`` head containing a ``nn.Conv2d``, the tap is
    pushed through ``block`` with ``_interval_seq`` and then through the
    difference kernel ``W[chg_idx] - W[nchg_idx]`` of the head.  Otherwise the
    call is forwarded to ``AFFINE_MARGIN_ORIG``.

    Returns:
        ``(lower, upper)`` bounds on ``logit[chg_idx] - logit[nchg_idx]``.

    Raises:
        RuntimeError: if the fallback is needed but ``AFFINE_MARGIN_ORIG`` is
            ``None``.
    """
    def _fallback():
        if AFFINE_MARGIN_ORIG is None:
            raise RuntimeError("No original affine_margin_bounds_conv2d available for fallback.")
        return AFFINE_MARGIN_ORIG(feat_l, feat_u, final_module, chg_idx, nchg_idx)

    # If no tap or malformed tap, fallback to original
    tap = globals().get("CROWN_TAIL", None)
    if not isinstance(tap, dict) or not all(k in tap for k in ("l0", "u0", "block", "final")):
        return _fallback()

    xL, xU = tap["l0"], tap["u0"]
    block = tap["block"]
    final1 = tap["final"]

    # Sanity: the block must be a Sequential of conv/bn/relu
    if not isinstance(block, nn.Sequential):
        print("[tail] tap 'block' is not nn.Sequential; fallback.")
        return _fallback()

    # Channel check: ensure the first conv of block matches xL channels.
    first_conv = next((ly for ly in block.children() if isinstance(ly, nn.Conv2d)), None)
    if first_conv is None:
        print("[tail] no Conv2d in block; fallback.")
        return _fallback()

    inC_needed = int(first_conv.in_channels)
    inC_have = int(xL.shape[1])
    if inC_needed != inC_have:
        print(f"[tail] channel mismatch at tap (have {inC_have}, need {inC_needed}); fallback.")
        return _fallback()

    # Propagate intervals through the tail block
    try:
        tL, tU = _interval_seq(xL, xU, block)
    except Exception as e:
        print(f"[tail] interval through tail failed ({e}); fallback.")
        return _fallback()

    # final may be Conv2d or a small wrapper that contains Conv2d
    conv = None
    if isinstance(final1, nn.Conv2d):
        conv = final1
    elif isinstance(final1, nn.Module):
        for m in final1.modules():
            if isinstance(m, nn.Conv2d):
                conv = m  # no break: the last Conv2d in traversal order wins
    if conv is None:
        print("[tail] could not locate Conv2d in 'final'; fallback.")
        return _fallback()

    # Difference kernel for the margin: row(chg) - row(nchg)
    Wd = conv.weight[int(chg_idx)].unsqueeze(0) - conv.weight[int(nchg_idx)].unsqueeze(0)

    # Sign-split interval step.  Using the signed kernel on each bound
    # separately is not a valid bound whenever the kernel has negative entries.
    Wp, Wn = Wd.clamp(min=0), Wd.clamp(max=0)
    geometry = (conv.stride, conv.padding, conv.dilation, conv.groups)
    yL = F.conv2d(tL, Wp, None, *geometry) + F.conv2d(tU, Wn, None, *geometry)
    yU = F.conv2d(tU, Wp, None, *geometry) + F.conv2d(tL, Wn, None, *geometry)
    if conv.bias is not None:
        # The bias difference is a scalar; add it by broadcasting.
        bd = conv.bias[int(chg_idx)] - conv.bias[int(nchg_idx)]
        yL = yL + bd
        yU = yU + bd
    print("[tail] margin via interval-tail DoubleConv → real 1×1.")
    return yL, yU

# Swap in our tail-aware wrapper (keeps original as fallback)
# Both public names now resolve to the tail-aware wrapper; the implementation
# that was active before is kept in AFFINE_MARGIN_ORIG and used as its fallback.
globals()["affine_margin_bounds_conv2d"] = _tail_margin_bounds_or_fallback
globals()["affine_margin_bounds_conv2d_safe"] = _tail_margin_bounds_or_fallback
print("[install] Tail-aware margin override installed (with safe fallback).")

# === Quiet end-to-end patch: AttU-Net MaxPool + final-1x1 tail taps (no fallback chatter) ===
import torch
import torch.nn as nn
import torch.nn.functional as F

# Gate all "tail" prints behind a flag (default: silent)
TAIL_VERBOSE = False


def _tprint(*args, **kwargs):
    """Forward to ``print`` only while the global ``TAIL_VERBOSE`` is truthy."""
    if not TAIL_VERBOSE:
        return
    print(*args, **kwargs)

# 1) Extend interval_forward_attunet to support MaxPool2d
if 'interval_forward_attunet' in globals():
    # Wrap at most once.  The wrapper keeps its predecessor in a closure rather
    # than reading the global `_ifa_old`: if this cell is run again, that
    # global would point at the wrapper itself and every non-pooling module
    # would recurse without end.
    if not getattr(interval_forward_attunet, '_maxpool_patched', False):
        _ifa_old = interval_forward_attunet

        def _with_maxpool_support(base):
            """Return a version of ``base`` that also handles ``nn.MaxPool2d``."""
            def interval_forward_attunet(lower, upper, module):
                """Interval forward pass through ``module``.

                ``nn.MaxPool2d`` is handled here by pooling the lower and the
                upper bound separately with the module's own settings; every
                other module is delegated to the wrapped implementation.
                """
                if isinstance(module, nn.MaxPool2d):
                    pool_kwargs = dict(kernel_size=module.kernel_size,
                                       stride=module.stride,
                                       padding=module.padding,
                                       dilation=module.dilation,
                                       ceil_mode=module.ceil_mode)
                    l = F.max_pool2d(lower, **pool_kwargs)
                    u = F.max_pool2d(upper, **pool_kwargs)
                    return l, u
                return base(lower, upper, module)
            interval_forward_attunet._maxpool_patched = True
            return interval_forward_attunet

        interval_forward_attunet = _with_maxpool_support(_ifa_old)
        globals()['interval_forward_attunet'] = interval_forward_attunet
    _tprint("[patch] interval_forward_attunet: MaxPool2d supported")

# 2) Helper: AttU-Net MaxPool alias
def _attu_maxpool_alias(model):
    """Return ``model.Maxpool`` if present, otherwise ``model.Maxpool1``.

    Raises:
        AttributeError: if the model has neither attribute.
    """
    for attr in ('Maxpool', 'Maxpool1'):
        if hasattr(model, attr):
            return getattr(model, attr)
    raise AttributeError("AttU_Net has neither 'Maxpool' nor 'Maxpool1'")

# 3) Quiet tail tap (final 1x1 only)
def _quiet_set_tail(l0, u0, final_head):
    """Reset the global ``CROWN_TAIL`` to a final-head-only tap.

    Does nothing when no global ``CROWN_TAIL`` exists.  Otherwise the dict is
    emptied and refilled with detached copies of ``l0``/``u0``, ``block`` set
    to ``None`` and ``final`` set to ``final_head``.
    """
    if 'CROWN_TAIL' not in globals():
        return
    CROWN_TAIL.clear()
    lower_copy = l0.detach().clone()
    upper_copy = u0.detach().clone()
    CROWN_TAIL.update({
        'l0': lower_copy,
        'u0': upper_copy,
        'block': None,            # never try DoubleConv tail
        'final': final_head,      # only the last 1x1
    })
    _tprint("[tail] tap set (final 1x1 only).")

# 4) Replace propagation functions (AttU, EncDec, FALCONet-MHA)
def propagate_prelogits_attunet(model, z):
    """Propagate interval bounds through the AttU-Net layers read below.

    Reads ``Conv1``, ``Conv3``, ``Conv5``, ``Up3``, ``Att3``, ``Up_conv3``,
    ``Up2``, ``Att2``, ``Up_conv2`` and ``Conv_1x1`` from ``model``, plus the
    pooling layer returned by ``_attu_maxpool_alias``.  Interval widths are
    printed along the way and the bounds after ``Up_conv3`` are recorded with
    ``_quiet_set_tail``.

    Returns:
        ``(lower, upper)`` bounds on the output of ``model.Up_conv2``.
    """
    def _gated_concat(gate, up_l, up_u, skip_l, skip_u):
        # attention-gate the skip connection, then concat it in front of the upsampled path
        att_l, att_u = interval_forward_attention_gate(gate, up_l, up_u, skip_l, skip_u)
        return torch.cat([att_l, up_l], 1), torch.cat([att_u, up_u], 1)

    in_l, in_u = z.lower, z.upper

    # encoder
    e1_l, e1_u = interval_forward_attunet(in_l, in_u, model.Conv1)
    print("[w] after Conv1 :", _span_mean_max(e1_l, e1_u))
    e2_l, e2_u = interval_forward_attunet(e1_l, e1_u, _attu_maxpool_alias(model))
    e2_l, e2_u = interval_forward_attunet(e2_l, e2_u, model.Conv3)
    print("[w] after Conv3 :", _span_mean_max(e2_l, e2_u))
    e3_l, e3_u = interval_forward_attunet(e2_l, e2_u, _attu_maxpool_alias(model))
    e3_l, e3_u = interval_forward_attunet(e3_l, e3_u, model.Conv5)
    print("[w] after Conv5 :", _span_mean_max(e3_l, e3_u))

    # decoder + gates
    up_l, up_u = interval_forward_attunet(e3_l, e3_u, model.Up3)
    cat_l, cat_u = _gated_concat(model.Att3, up_l, up_u, e2_l, e2_u)
    dec3_l, dec3_u = interval_forward_attunet(cat_l, cat_u, model.Up_conv3)
    print("[w] after Up_conv3:", _span_mean_max(dec3_l, dec3_u))

    # quiet tail tap (final head only)
    _quiet_set_tail(dec3_l, dec3_u, model.Conv_1x1)

    up_l, up_u = interval_forward_attunet(dec3_l, dec3_u, model.Up2)
    cat_l, cat_u = _gated_concat(model.Att2, up_l, up_u, e1_l, e1_u)
    out_l, out_u = interval_forward_attunet(cat_l, cat_u, model.Up_conv2)
    print("[w] after Up_conv2:", _span_mean_max(out_l, out_u))
    return out_l, out_u

def propagate_prelogits_encdecnet(model, z):
    """Propagate interval bounds through ``inc``, ``down1``..``down4`` and ``up1``..``up4``.

    Each decoder stage upsamples both bounds with ``up.up``, concatenates the
    matching encoder bounds in front, and applies ``up.conv`` through
    ``interval_forward_falconet``.  The bounds entering the last decoder stage
    are recorded with ``_quiet_set_tail`` together with ``model.outc``.

    Returns:
        ``(lower, upper)`` bounds on the output of ``model.up4``.
    """
    def _decode(up, lo, hi, skip_lo, skip_hi):
        lo = up.up(lo)
        hi = up.up(hi)
        lo = torch.cat([skip_lo, lo], dim=1)
        hi = torch.cat([skip_hi, hi], dim=1)
        return interval_forward_falconet(lo, hi, up.conv)

    # encoder: keep every stage's bounds, they become the skip connections
    lo, hi = z.lower, z.upper
    lo, hi = interval_forward_falconet(lo, hi, model.inc)
    skips = [(lo, hi)]
    for stage_name in ("down1", "down2", "down3", "down4"):
        lo, hi = interval_forward_falconet(lo, hi, getattr(model, stage_name))
        skips.append((lo, hi))

    # decoder: up1 pairs with the down3 output, up2 with down2, up3 with down1
    for up_name, skip_idx in (("up1", 3), ("up2", 2), ("up3", 1)):
        skip_lo, skip_hi = skips[skip_idx]
        lo, hi = _decode(getattr(model, up_name), lo, hi, skip_lo, skip_hi)

    # quiet tail tap before last DoubleConv; only final head
    _quiet_set_tail(lo, hi, model.outc)

    skip_lo, skip_hi = skips[0]
    lo, hi = _decode(model.up4, lo, hi, skip_lo, skip_hi)
    return lo, hi

def propagate_prelogits_falconetmha(model, z):
    """Propagate interval bounds through the FALCONet-MHA encoder/decoder.

    Encoder stages ``down2``, ``down3`` and ``down4`` are each followed by a
    token mixer (``token_mixer_2``/``_3``/``_4``) applied between
    ``flatten_hw`` and ``unflatten_hw``.  The decoder mirrors the plain
    encoder-decoder: upsample, concat the skip bounds in front, apply
    ``up.conv``.  The bounds entering the last decoder stage are recorded with
    ``_quiet_set_tail`` together with ``model.outc``.

    Returns:
        ``(lower, upper)`` bounds on the output of ``model.up4``.
    """
    def _mix(lo, hi, mixer):
        # flatten -> token mixer -> restore the shape reported for the lower bound
        lo, shape = flatten_hw(lo)
        hi, _ = flatten_hw(hi)
        lo, hi = interval_forward_tokenmixer(lo, hi, mixer)
        return unflatten_hw(lo, shape), unflatten_hw(hi, shape)

    def _decode(up, lo, hi, skip_lo, skip_hi):
        lo = up.up(lo)
        hi = up.up(hi)
        lo = torch.cat([skip_lo, lo], dim=1)
        hi = torch.cat([skip_hi, hi], dim=1)
        return interval_forward_falconet(lo, hi, up.conv)

    lo, hi = z.lower, z.upper
    s1_lo, s1_hi = interval_forward_falconet(lo, hi, model.inc)
    s2_lo, s2_hi = interval_forward_falconet(s1_lo, s1_hi, model.down1)

    s3_lo, s3_hi = interval_forward_falconet(s2_lo, s2_hi, model.down2)
    s3_lo, s3_hi = _mix(s3_lo, s3_hi, model.token_mixer_2)

    s4_lo, s4_hi = interval_forward_falconet(s3_lo, s3_hi, model.down3)
    s4_lo, s4_hi = _mix(s4_lo, s4_hi, model.token_mixer_3)

    s5_lo, s5_hi = interval_forward_falconet(s4_lo, s4_hi, model.down4)
    s5_lo, s5_hi = _mix(s5_lo, s5_hi, model.token_mixer_4)

    lo, hi = _decode(model.up1, s5_lo, s5_hi, s4_lo, s4_hi)
    lo, hi = _decode(model.up2, lo, hi, s3_lo, s3_hi)
    lo, hi = _decode(model.up3, lo, hi, s2_lo, s2_hi)

    # quiet tail tap before last DoubleConv; only final head
    _quiet_set_tail(lo, hi, model.outc)

    lo, hi = _decode(model.up4, lo, hi, s1_lo, s1_hi)
    return lo, hi

# Explicitly (re)bind the three propagation functions defined in this cell
# under their public names, replacing any earlier definitions.
globals()['propagate_prelogits_attunet']     = propagate_prelogits_attunet
globals()['propagate_prelogits_encdecnet']   = propagate_prelogits_encdecnet
globals()['propagate_prelogits_falconetmha'] = propagate_prelogits_falconetmha

print("[ok] Quiet tail + MaxPool patch active (v11)")
# === v12c: Robust tail + encoder autodetect (EncDec / FALCONet / AttU-Net) ===
import torch
import torch.nn as nn
import torch.nn.functional as F
import re

# ---------- small helpers ----------
def _span_mean_max(l, u):
    """Return ``(mean, max)`` of the element-wise width ``|u - l|`` as Python floats."""
    width = (u - l).abs()
    mean_width = width.mean().detach().cpu()
    max_width = width.max().detach().cpu()
    return float(mean_width), float(max_width)

def _maxpool2d_bounds(l, u, mp: nn.MaxPool2d):
    """Max-pool the lower and the upper bound separately with ``mp``'s settings.

    A falsy ``mp.stride`` is replaced by the kernel size.

    Returns:
        ``(pooled_lower, pooled_upper)``.
    """
    pool_args = (
        mp.kernel_size,
        mp.stride or mp.kernel_size,
        mp.padding,
        mp.dilation,
        mp.ceil_mode,
    )
    pooled_l = F.max_pool2d(l, *pool_args)
    pooled_u = F.max_pool2d(u, *pool_args)
    return pooled_l, pooled_u

def _unwrap_final_conv2d(m: nn.Module) -> nn.Conv2d:
    """Return the ``nn.Conv2d`` behind a classifier head.

    Lookup order: ``m`` itself; then the attributes ``conv``, ``out``,
    ``outc``, ``final``, ``head``, ``proj`` (first one that is a Conv2d);
    then the first Conv2d found by ``m.modules()``.

    Raises:
        RuntimeError: if no ``nn.Conv2d`` is found.
    """
    if isinstance(m, nn.Conv2d):
        return m

    for attr in ('conv', 'out', 'outc', 'final', 'head', 'proj'):
        candidate = getattr(m, attr, None)
        if isinstance(candidate, nn.Conv2d):
            return candidate

    nested = next((sub for sub in m.modules() if isinstance(sub, nn.Conv2d)), None)
    if nested is None:
        raise RuntimeError("[tail] could not find final nn.Conv2d inside head")
    return nested

def _as_sequential_with_convs(m: nn.Module) -> nn.Sequential:
    """Return ``m`` as an ``nn.Sequential``.

    An ``nn.Sequential`` is returned unchanged.  Any other module is replaced
    by a new Sequential of its direct children.  If the result contains no
    ``nn.Conv2d`` but ``m.modules()`` yields some, a Sequential of just those
    convolutions is returned instead.

    NOTE(review): wrapping the children of a non-Sequential module drops
    whatever its own ``forward`` does between them — confirm that is
    acceptable for the blocks passed in.
    """
    wrapped = m if isinstance(m, nn.Sequential) else nn.Sequential(*m.children())

    # ensure at least one Conv2d (for the tail computation)
    for sub in wrapped.modules():
        if sub is not wrapped and isinstance(sub, nn.Conv2d):
            return wrapped

    conv_layers = [sub for sub in m.modules() if isinstance(sub, nn.Conv2d)]
    if conv_layers:
        return nn.Sequential(*conv_layers)
    return wrapped

def _quiet_set_tail(pre_l, pre_u, block: nn.Module, final_head: nn.Module):
    """
    Record the network tail in the global ``CROWN_TAIL`` dict.

    The dict is created on first use, emptied, and then filled with:
      - ``'l0'`` / ``'u0'``: detached copies of ``pre_l`` / ``pre_u``;
      - ``'block'``: ``block`` normalised by ``_as_sequential_with_convs``;
      - ``'final'``: the ``Conv2d`` extracted from ``final_head`` by
        ``_unwrap_final_conv2d``.

    Nothing is returned and nothing is printed. The dict object itself is
    reused, so references held elsewhere see the new content.
    """
    tail = globals().setdefault('CROWN_TAIL', {})
    tail.clear()
    # build the full snapshot first so a failing helper leaves the dict empty
    snapshot = {
        'l0':   pre_l.detach().clone(),
        'u0':   pre_u.detach().clone(),
        'block': _as_sequential_with_convs(block),
        'final': _unwrap_final_conv2d(final_head),
    }
    tail.update(snapshot)

# ---------- AttU-Net helpers ----------
def _attu_get_maxpool(model):
    mp = getattr(model, 'Maxpool', None) or getattr(model, 'Maxpool1', None)
    if not isinstance(mp, nn.MaxPool2d):
        raise RuntimeError("AttU_Net expects nn.MaxPool2d at model.Maxpool/Maxpool1")
    return mp

def _first_conv_in(m: nn.Module):
    for sub in m.modules():
        if isinstance(sub, nn.Conv2d):
            return sub
    return None

def _attunet_conv_blocks(model):
    blocks = []
    for name, mod in model.named_children():
        # accept Conv1, Conv2, Conv3, Conv5, ... but skip 'Conv_1x1'
        if name.startswith('Conv') and re.fullmatch(r'Conv\d+', name):
            idx = int(name[4:])
            blocks.append((idx, name, mod))
    blocks.sort(key=lambda t: t[0])
    return blocks

def _take_next_block(blocks, inC):
    for i, (_, name, mod) in enumerate(blocks):
        fc = _first_conv_in(mod)
        if isinstance(fc, nn.Conv2d) and fc.in_channels == inC:
            return blocks.pop(i)
    # if exact match not found, best-effort: choose the first whose first conv exists
    for i, (_, name, mod) in enumerate(blocks):
        fc = _first_conv_in(mod)
        if isinstance(fc, nn.Conv2d):
            return blocks.pop(i)
    raise RuntimeError(f"[AttU_Net] no encoder block matches inC={inC}")

# =========================
# EncDec: tap BEFORE up4.conv (so tail.block = real DoubleConv, final = classifier Conv2d)
# =========================
def propagate_prelogits_encdecnet(model, z):
    """
    Propagate the input box ``z`` through EncDec up to, but not including, the head.

    Reads ``z.lower`` / ``z.upper`` and walks ``model.inc``, ``model.down1..down4``
    and ``model.up1..up4`` with ``interval_forward_falconet``. Right before the
    last decoder conv block runs, its concatenated input bounds are stored via
    ``_quiet_set_tail`` together with ``model.up4.conv`` and ``model.outc``.

    Returns the ``(lower, upper)`` pair produced by ``model.up4.conv``.

    NOTE(review): every ``up.up`` module is applied to the lower and the upper
    bound separately; that only yields valid bounds if the upsampler is
    monotone (e.g. plain interpolation) — confirm for the model in use.
    """
    def _upsample_and_concat(stage, lo, hi, skip_lo, skip_hi):
        # upsample both bounds, then put the skip features first on the channel axis
        lo = stage.up(lo)
        hi = stage.up(hi)
        return torch.cat([skip_lo, lo], dim=1), torch.cat([skip_hi, hi], dim=1)

    # encoder: keep every level's bounds for the skip connections
    lo, hi = z.lower, z.upper
    feats = []
    for stage_name in ('inc', 'down1', 'down2', 'down3', 'down4'):
        lo, hi = interval_forward_falconet(lo, hi, getattr(model, stage_name))
        feats.append((lo, hi))

    # decoder up1..up3: deepest features first, skips consumed deepest-to-shallowest
    lo, hi = feats.pop()
    for stage_name in ('up1', 'up2', 'up3'):
        stage = getattr(model, stage_name)
        lo, hi = _upsample_and_concat(stage, lo, hi, *feats.pop())
        lo, hi = interval_forward_falconet(lo, hi, stage.conv)

    # up4 is unrolled so the input of its conv block can be captured for the tail
    last = model.up4
    pre_l, pre_u = _upsample_and_concat(last, lo, hi, *feats.pop())
    _quiet_set_tail(pre_l, pre_u, block=last.conv, final_head=model.outc)

    out_l, out_u = interval_forward_falconet(pre_l, pre_u, last.conv)
    return out_l, out_u

# =========================
# FALCONet+MHA: same tail strategy as EncDec
# =========================
def propagate_prelogits_falconetmha(model, z):
    """
    Propagate the input box ``z`` through FALCONet+MHA up to, but not including, the head.

    Same layout as the EncDec propagator, except that the outputs of
    ``down2``, ``down3`` and ``down4`` are additionally passed through
    ``token_mixer_2`` / ``token_mixer_3`` / ``token_mixer_4``: bounds are
    flattened with ``flatten_hw``, mixed with ``interval_forward_tokenmixer``
    and restored with ``unflatten_hw``.

    Before the last decoder conv block runs, its input bounds are stored via
    ``_quiet_set_tail`` together with ``model.up4.conv`` and ``model.outc``.

    Returns the ``(lower, upper)`` pair produced by ``model.up4.conv``.

    NOTE(review): ``up.up`` is applied to each bound separately, which is only
    valid for a monotone upsampler — confirm for the model in use.
    """
    def _mix_tokens(lo, hi, mixer):
        # the shape recorded while flattening the lower bound restores both bounds
        lo, shape = flatten_hw(lo)
        hi, _ = flatten_hw(hi)
        lo, hi = interval_forward_tokenmixer(lo, hi, mixer)
        return unflatten_hw(lo, shape), unflatten_hw(hi, shape)

    def _upsample_and_concat(stage, lo, hi, skip_lo, skip_hi):
        lo = stage.up(lo)
        hi = stage.up(hi)
        return torch.cat([skip_lo, lo], dim=1), torch.cat([skip_hi, hi], dim=1)

    # encoder: (conv stage, optional token mixer applied right after it)
    encoder_plan = (
        ('inc', None),
        ('down1', None),
        ('down2', 'token_mixer_2'),
        ('down3', 'token_mixer_3'),
        ('down4', 'token_mixer_4'),
    )
    lo, hi = z.lower, z.upper
    feats = []
    for stage_name, mixer_name in encoder_plan:
        lo, hi = interval_forward_falconet(lo, hi, getattr(model, stage_name))
        if mixer_name is not None:
            lo, hi = _mix_tokens(lo, hi, getattr(model, mixer_name))
        feats.append((lo, hi))

    # decoder up1..up3
    lo, hi = feats.pop()
    for stage_name in ('up1', 'up2', 'up3'):
        stage = getattr(model, stage_name)
        lo, hi = _upsample_and_concat(stage, lo, hi, *feats.pop())
        lo, hi = interval_forward_falconet(lo, hi, stage.conv)

    # up4 unrolled: capture the pre-block bounds for the tail, then run the block
    last = model.up4
    pre_l, pre_u = _upsample_and_concat(last, lo, hi, *feats.pop())
    _quiet_set_tail(pre_l, pre_u, block=last.conv, final_head=model.outc)

    out_l, out_u = interval_forward_falconet(pre_l, pre_u, last.conv)
    return out_l, out_u

# =========================
# AttU-Net: autodetect encoder blocks by in_channels; tail at Up_conv2
# =========================
def propagate_prelogits_attunet(model, z):
    """
    Propagate the input box ``z`` through AttU-Net up to, but not including, the 1x1 head.

    Encoder: three levels. For each level the next ``Conv<N>`` child is picked
    by ``_take_next_block`` according to the current channel count; levels two
    and three are preceded by the model's max-pool applied to both bounds.

    Decoder: ``Up3`` -> ``Att3`` gate on the level-2 skip -> concat ->
    ``Up_conv3`` -> ``Up2`` -> ``Att2`` gate on the level-1 skip -> concat.
    The concatenated bounds are stored via ``_quiet_set_tail`` together with
    ``model.Up_conv2`` and ``model.Conv_1x1`` before ``Up_conv2`` is evaluated.

    Returns the ``(lower, upper)`` pair produced by ``model.Up_conv2``.
    """
    lo, hi = z.lower, z.upper
    pool = _attu_get_maxpool(model)

    # encoder: auto-detect the three blocks by matching in_channels
    candidates = _attunet_conv_blocks(model)
    in_ch = lo.shape[1]
    skips = []
    for level in range(3):
        if level > 0:
            lo, hi = _maxpool2d_bounds(lo, hi, pool)
        _, _, enc = _take_next_block(candidates, in_ch)
        lo, hi = interval_forward_attunet(lo, hi, enc)
        skips.append((lo, hi))
        in_ch = lo.shape[1]  # pooling does not change the channel count
    (x1_l, x1_u), (x2_l, x2_u), (x3_l, x3_u) = skips

    # decoder stage 3: upsample, gate the level-2 skip, concat (gated skip first)
    up_l, up_u = interval_forward_attunet(x3_l, x3_u, model.Up3)
    gate_l, gate_u = interval_forward_attention_gate(model.Att3, up_l, up_u, x2_l, x2_u)
    cat_l = torch.cat([gate_l, up_l], dim=1)
    cat_u = torch.cat([gate_u, up_u], dim=1)
    d3_l, d3_u = interval_forward_attunet(cat_l, cat_u, model.Up_conv3)

    # decoder stage 2: same pattern on the level-1 skip
    up_l, up_u = interval_forward_attunet(d3_l, d3_u, model.Up2)
    gate_l, gate_u = interval_forward_attention_gate(model.Att2, up_l, up_u, x1_l, x1_u)
    pre_l = torch.cat([gate_l, up_l], dim=1)
    pre_u = torch.cat([gate_u, up_u], dim=1)

    # tail = real last conv block + final 1x1 head
    _quiet_set_tail(pre_l, pre_u, block=model.Up_conv2, final_head=model.Conv_1x1)

    out_l, out_u = interval_forward_attunet(pre_l, pre_u, model.Up_conv2)
    return out_l, out_u

# ——— install overrides ———
# Publish the v12c propagators defined in this cell under their global names,
# superseding any earlier definitions of the same names.
# NOTE(review): the `def` statements above already bind these names at cell level,
# so the explicit globals() writes appear to be no-ops kept for visibility.
globals()['propagate_prelogits_attunet']     = propagate_prelogits_attunet
globals()['propagate_prelogits_encdecnet']   = propagate_prelogits_encdecnet
globals()['propagate_prelogits_falconetmha'] = propagate_prelogits_falconetmha
# Banner printed once this cell has executed.
print("[ok] v12c: tail fixed (real last block + final conv), AttU-Net encoder autodetect + MaxPool-safe")
# === channel_selector_hotfix_v2 — makes rcd() safe ===
import numpy as np, torch

def anon_choose_change_channel(clean_logits, gt_mask):
    """
    Pick the (change, no-change) channel indices for a logit tensor.

    - Two-channel head: always ``(1, 0)``; ``gt_mask`` is not inspected.
    - Otherwise: softmax over channels of the first batch item, threshold each
      channel at p > 0.5 and take the channel with the highest IoU against
      ``gt_mask > 0`` (first one wins on ties). The no-change channel is simply
      the next index, wrapping around.

    ``gt_mask`` may be a torch tensor or anything ``np.asarray`` accepts.
    Returns a tuple ``(chg, nchg)`` of ints.
    """
    n_channels = clean_logits.shape[1]
    if n_channels == 2:
        return 1, 0

    with torch.no_grad():
        probs = torch.softmax(clean_logits, dim=1)[0]  # [C,H,W]

    # ground truth as a boolean numpy array
    gt_raw = gt_mask.detach().cpu().numpy() if torch.is_tensor(gt_mask) else np.asarray(gt_mask)
    gt_pos = gt_raw > 0

    chg, top_iou = 0, -1.0
    for ch, prob_map in enumerate(probs):
        hit = (prob_map > 0.5).cpu().numpy()
        inter = np.logical_and(hit, gt_pos).sum()
        union = np.logical_or(hit, gt_pos).sum()
        score = inter / max(1, union)  # empty union counts as IoU 0
        if score > top_iou:
            chg, top_iou = ch, score

    return chg, (chg + 1) % n_channels

def logits_to_clean_pred(clean_logits, chg_idx, thresh=0.5):
    """
    Build the clean binary change mask for the first batch item.

    Two-channel logits use the arg-max rule (pixel is "change" when channel
    ``chg_idx`` wins); any other channel count thresholds the softmax
    probability of ``chg_idx`` at ``thresh``. Runs without gradient tracking.

    Returns a torch.BoolTensor [H,W].
    """
    with torch.no_grad():
        if clean_logits.shape[1] != 2:
            chg_prob = torch.softmax(clean_logits, dim=1)[0, chg_idx]
            return chg_prob > thresh
        winner = torch.argmax(clean_logits, dim=1)[0]
        return winner == chg_idx

# Banner printed once the two functions above have been (re)defined by this cell.
print("[hotfix] Patched anon_choose_change_channel/logits_to_clean_pred (uses torch.softmax; no np.exp).")
# === predicate_logger_safe_v7 — unbreakable mask coercion + GT reload ===
import os, csv, itertools
import numpy as np
import torch
from PIL import Image

# Default CSV written by run_predicate_logger() when no out_path is given.
OUT = "predicate_pass_v1.csv"
# Reuse the dataset root if an earlier cell defined it. The fallback keeps the
# trailing slash because legacy call sites build paths by plain string
# concatenation (OSCD_PATH_TOP + city + "/imgs_1/"), which breaks without it.
OSCD_PATH_TOP = globals().get("OSCD_PATH_TOP", "../onera/OSCD/")

VERBOSE_COERCE = False  # flip to True if you want per-case shape messages

def _np(x): return x.detach().cpu().numpy() if torch.is_tensor(x) else np.asarray(x)

# ---------- GT loader ----------
def _load_city_gt(city, H, W):
    """
    Load ``<OSCD_PATH_TOP>/<city>/cm/cm.png`` as a binary uint8 mask of shape [H,W].

    The image is converted to 8-bit grayscale, resized with nearest-neighbour
    sampling when its size differs from the target, and thresholded at ``> 0``.
    Trailing slashes on ``OSCD_PATH_TOP`` are tolerated.

    Raises:
        FileNotFoundError: if the change-map file does not exist.
    """
    root = str(OSCD_PATH_TOP).rstrip("/")
    cm_path = os.path.join(root, city, "cm", "cm.png")
    if not os.path.exists(cm_path):
        raise FileNotFoundError(f"GT not found: {cm_path}")

    gray = Image.open(cm_path).convert("L")
    target_size = (W, H)  # PIL sizes are (width, height)
    if gray.size != target_size:
        gray = gray.resize(target_size, resample=Image.NEAREST)

    mask = (np.array(gray) > 0).astype(np.uint8)
    if VERBOSE_COERCE:
        print(f"[gt] reloaded {city}/cm.png → {mask.shape}")
    return mask

# ---------- shape coercion ----------
def _force_hw_any(x, H, W, role="cert/clean"):
    """
    Coerce x → binary uint8 mask of shape [H,W].
    For cert/clean we aggressively repair 1-D/3-D shapes.
    For GT we still try strict first, then fallback to disk.
    """
    a = _np(x); a = np.asarray(a); a = np.squeeze(a)

    def _done(arr, why):
        out = (arr > 0).astype(np.uint8)
        if out.shape != (H, W):
            raise ValueError(f"{role}: coerced to {out.shape}, expected {(H,W)} (why={why})")
        if VERBOSE_COERCE and why:
            print(f"[coerce:{role}] {why} -> {out.shape}")
        return out

    # perfect
    if a.ndim == 2 and a.shape == (H, W):
        return _done(a, "2D exact")

    # transposed
    if a.ndim == 2 and a.shape == (W, H):
        return _done(a.T, "2D transpose")

    # flattened to H*W
    if a.ndim == 1 and a.size == H*W:
        return _done(a.reshape(H, W), "1D flatten H*W")

    # row/col vectors (common bug)
    if a.ndim == 1 and a.size == H:
        return _done(np.tile(a.reshape(H, 1), (1, W)), "1D H→tile across W")
    if a.ndim == 1 and a.size == W:
        return _done(np.tile(a.reshape(1, W), (H, 1)), "1D W→tile across H")

    # 3D channel-ish
    if a.ndim == 3:
        # (1,H,W) / (H,W,1)
        if a.shape[0] == 1 and a.shape[1:] == (H, W):
            return _done(a[0], "3D [1,H,W]→squeeze")
        if a.shape[-1] == 1 and a.shape[:2] == (H, W):
            return _done(a[..., 0], "3D [H,W,1]→squeeze")
        # (C,H,W) or (H,W,C): any-nonzero
        if a.shape[1:] == (H, W):
            return _done((a != 0).any(axis=0).astype(np.uint8), "3D [C,H,W] any")
        if a.shape[:2] == (H, W):
            return _done((a != 0).any(axis=-1).astype(np.uint8), "3D [H,W,C] any")

    # last resort: if one dim matches H or W, try to broadcast
    if a.ndim == 2:
        if a.shape[0] == H and a.shape[1] == 1:
            return _done(np.tile(a, (1, W)), "2D [H,1]→tile W")
        if a.shape[1] == W and a.shape[0] == 1:
            return _done(np.tile(a, (H, 1)), "2D [1,W]→tile H")

    # if GT → reload from disk; otherwise give a descriptive error
    if role == "gt":
        return _load_city_gt(city="(unknown)", H=H, W=W)  # will be replaced by caller with real city
    raise ValueError(f"{role}: cannot coerce shape {a.shape} to {(H,W)}")

def _force_gt_hw(gt_mask, city, H, W):
    """
    Return the ground-truth mask for ``city`` as a binary uint8 [H,W] array.

    The in-memory ``gt_mask`` is coerced with ``_force_hw_any`` first. If that
    fails for any reason, the mask is re-read from disk via ``_load_city_gt``
    (whose own errors propagate to the caller).
    """
    try:
        coerced = _force_hw_any(gt_mask, H, W, role="gt")
    except Exception as err:
        # deliberately broad: any coercion problem falls back to the file on disk
        if VERBOSE_COERCE:
            print(f"[gt] {city}: {err} → reload cm.png")
        coerced = _load_city_gt(city, H, W)
    return coerced

# ---------- CC + predicates ----------
def _connected_components_4(mask_hw_uint8):
    m = (mask_hw_uint8 > 0).astype(np.uint8)
    H, W = m.shape
    vis = np.zeros((H, W), dtype=np.uint8)
    sizes = []
    for i in range(H):
        for j in range(W):
            if m[i, j] and not vis[i, j]:
                stack = [(i, j)]
                vis[i, j] = 1
                sz = 0
                while stack:
                    r, c = stack.pop(); sz += 1
                    for dr, dc in ((1,0),(-1,0),(0,1),(0,-1)):
                        rr, cc = r+dr, c+dc
                        if 0 <= rr < H and 0 <= cc < W and m[rr, cc] and not vis[rr, cc]:
                            vis[rr, cc] = 1; stack.append((rr, cc))
                sizes.append(sz)
    return sizes

def _predicates(Ccert, Cclean, Cgt, rho, gamma, s_min):
    """
    Evaluate the three certification predicates on binary [H,W] masks.

    - overlap  = |cert ∩ clean| / max(1, |clean|), passes when ``>= rho``
    - fp       = |cert \\ gt| / |cert| (0.0 when nothing is certified),
                 passes when ``<= gamma``
    - pattern  = every 4-connected component of the certified set has at least
                 ``s_min`` pixels (vacuously true for an empty certified set)

    Returns ``(Poverlap, Pfp, Ppattern, Pstrict, overlap, fp, largest_cc)``
    where ``Pstrict`` is the conjunction of the three predicates and
    ``largest_cc`` is the size of the biggest certified component (0 if none).
    """
    cert = Ccert > 0
    clean = Cclean > 0
    gt = Cgt > 0

    # overlap of the certified set with the clean prediction
    overlap = np.logical_and(cert, clean).sum() / max(1, clean.sum())

    # share of certified pixels that are not ground-truth change
    n_cert = cert.sum()
    if n_cert == 0:
        fp = 0.0
    else:
        fp = np.logical_and(cert, np.logical_not(gt)).sum() / n_cert

    sizes = _connected_components_4(cert.astype(np.uint8))

    Poverlap = (overlap >= rho)
    Pfp = (fp <= gamma)
    Ppattern = all(sz >= s_min for sz in sizes)
    Pstrict = (Poverlap and Pfp and Ppattern)
    largest_cc = max(sizes) if sizes else 0
    return Poverlap, Pfp, Ppattern, Pstrict, float(overlap), float(fp), int(largest_cc)

# ---------- channel selection (torch-only) ----------
def _choose_channels(clean_logits, gt_mask, city):
    C = int(clean_logits.shape[1])
    if C == 2:
        return 1, 0
    with torch.no_grad():
        p = torch.softmax(clean_logits, dim=1)[0]  # [C,H,W]
    H, W = int(p.shape[-2]), int(p.shape[-1])
    gt_hw = _force_gt_hw(gt_mask, city, H, W).astype(bool)
    best_iou, best_chg = -1.0, 0
    for c in range(C):
        pred_c = (p[c] > 0.5).cpu().numpy()
        union = np.logical_or(pred_c, gt_hw).sum()
        iou = (np.logical_and(pred_c, gt_hw).sum() / max(1, union))
        if iou > best_iou:
            best_iou, best_chg = iou, c
    best_nchg = (best_chg + 1) % C
    return best_chg, best_nchg

# ---------- one combo ----------
def _combo_masks(model, model_type, city, eps):
    """
    Run ``rcd()`` once and derive the three binary [H,W] uint8 masks.

    Returns ``(Ccert, Cclean, Cgt)``:
      - ``Ccert``:  pixels where ``lower[0, chg] - upper[0, nchg] > 0``
      - ``Cclean``: clean predicted change set for channel ``chg``
      - ``Cgt``:    ground-truth change mask

    None of these depend on rho / gamma / s_min, so one call can serve a whole
    threshold grid.
    # assumes rcd() returns (lower, upper, clean_logits, clean_pred, gt_mask) with
    # [1,C,H,W] bounds/logits, as the original inline comment states — TODO confirm
    """
    lower, upper, clean_logits, _clean_pred, gt_mask = rcd(model, model_type, city, eps)
    H, W = int(clean_logits.shape[-2]), int(clean_logits.shape[-1])

    chg, nchg = _choose_channels(clean_logits, gt_mask, city)

    # certified set: margin LB > 0
    margin_lb = (lower[0, chg] - upper[0, nchg])  # [H,W] ideally
    Ccert = _force_hw_any((margin_lb > 0), H, W, role="cert/clean")

    # clean predicted change set
    with torch.no_grad():
        if clean_logits.shape[1] == 2:
            Cclean_t = (torch.argmax(clean_logits, dim=1)[0] == chg)  # [H,W]
        else:
            Cclean_t = (torch.softmax(clean_logits, dim=1)[0, chg] > 0.5)
    Cclean = _force_hw_any(Cclean_t, H, W, role="cert/clean")

    # GT
    Cgt = _force_gt_hw(gt_mask, city, H, W)
    return Ccert, Cclean, Cgt


def _combo_row(model_name, model_type, city, eps, rho, gamma, s_min, masks):
    """Evaluate the predicates on precomputed ``masks`` and build one CSV row dict."""
    Ccert, Cclean, Cgt = masks
    Poverlap, Pfp, Ppattern, Pstrict, overlap, fp, largest_cc = _predicates(
        Ccert, Cclean, Cgt, rho, gamma, s_min
    )
    return {
        "model": model_name, "type": int(model_type), "city": city, "eps": float(eps),
        "rho": float(rho), "gamma": float(gamma), "s_min": int(s_min),
        "strict": bool(Pstrict), "Poverlap": bool(Poverlap), "Pfp": bool(Pfp), "Ppattern": bool(Ppattern),
        "overlap": float(overlap), "fp": float(fp), "largest_cc": int(largest_cc)
    }


def _one_combo(model_name, model, model_type, city, eps, rho, gamma, s_min):
    """
    Certify one (model, city, eps) case and score it for one (rho, gamma, s_min) triple.

    Returns a dict with the keys written to the predicate CSV. Any exception
    raised by ``rcd()`` or the mask helpers propagates to the caller.
    """
    masks = _combo_masks(model, model_type, city, eps)
    return _combo_row(model_name, model_type, city, eps, rho, gamma, s_min, masks)

# ---------- driver ----------
def run_predicate_logger(MODELS, CITIES, EPS,
                         rho_main=0.30, gamma_main=0.30, smin_main=16,
                         appendix=False,
                         rho_grid=(0.20,0.30,0.40),
                         gamma_grid=(0.20,0.30,0.40),
                         smin_grid=(8,16,32),
                         out_path=OUT):
    """
    Sweep models x cities x eps, evaluate the predicates and write them to CSV.

    Args:
        MODELS: iterable of ``(model_name, model, model_type)`` triples.
        CITIES, EPS: iterables of city names and perturbation radii.
        rho_main, gamma_main, smin_main: thresholds of the main table.
        appendix: when True, additionally evaluate every combination of
            ``rho_grid`` x ``gamma_grid`` x ``smin_grid`` (rows are appended
            after the main rows; with the default grids the main triple is
            therefore present twice).
        out_path: destination CSV, overwritten.

    Failing cases are reported with a ``[skip]`` line and left out of the CSV.
    For the appendix the bound propagation runs once per (model, city, eps) and
    its masks are reused for the whole threshold grid, instead of once per
    grid point.

    Returns the list of row dicts that was written.
    """
    header = ["model","type","city","eps","rho","gamma","s_min",
              "strict","Poverlap","Pfp","Ppattern","overlap","fp","largest_cc"]
    rows = []

    def _report_skip(name, city, eps, rho, gamma, smin, err):
        print(f"[skip] {name}/{city}/eps={eps:.5f} (rho={rho},gamma={gamma},s={smin}): {err}")

    # main table
    for (model_name, model, model_type) in MODELS:
        for city in CITIES:
            for eps in EPS:
                try:
                    rows.append(_one_combo(model_name, model, model_type, city, eps,
                                           rho_main, gamma_main, smin_main))
                except Exception as e:
                    _report_skip(model_name, city, eps, rho_main, gamma_main, smin_main, e)

    # appendix
    if appendix:
        # same nesting order as rho -> gamma -> s_min loops
        grid = list(itertools.product(rho_grid, gamma_grid, smin_grid))
        for (model_name, model, model_type) in MODELS:
            for city in CITIES:
                for eps in EPS:
                    # the expensive part is threshold-independent: do it once per case
                    try:
                        masks = _combo_masks(model, model_type, city, eps)
                        failure = None
                    except Exception as e:
                        masks, failure = None, e
                    for rho, gamma, smin in grid:
                        if failure is not None:
                            # keep one [skip] line per grid point, as before
                            _report_skip(model_name, city, eps, rho, gamma, smin, failure)
                            continue
                        try:
                            rows.append(_combo_row(model_name, model_type, city, eps,
                                                   rho, gamma, smin, masks))
                        except Exception as e:
                            _report_skip(model_name, city, eps, rho, gamma, smin, e)

    with open(out_path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=header); w.writeheader(); w.writerows(rows)
    print(f"[predicates] wrote {len(rows)} rows → {os.path.abspath(out_path)}")
    return rows

# Banner printed once every definition of this cell has been executed.
print("[ok] predicate_logger_safe_v7 loaded")


[shim] AttU-Net: aliased Maxpool1 -> Maxpool (if missing).
[install] Tail-aware margin override installed (with safe fallback).
[ok] Quiet tail + MaxPool patch active (v11)
[ok] v12c: tail fixed (real last block + final conv), AttU-Net encoder autodetect + MaxPool-safe
[hotfix] Patched anon_choose_change_channel/logits_to_clean_pred (uses torch.softmax; no np.exp).
[ok] predicate_logger_safe_v7 loaded


In [21]:
# _ = run_predicate_logger(
#         MODELS, CITIES, EPS,
#         rho_main=0.30, gamma_main=0.30, smin_main=16,
#         appendix=False,
#         out_path="predicate_pass_v1.csv")


In [22]:
# # === make_table_predicate_strict_main_v2.py ===
# # Build the main "Strict predicate certification under OOD" LaTeX table
# # from predicate_pass_v1.csv (created by the logger).

# import csv, math, os
# from collections import defaultdict

# # ---- INPUT / OUTPUT ----
# CSV_IN  = "predicate_pass_v1.csv"        # adjust path if needed
# TEX_OUT = "table_predicate_strict_main.tex"

# if not os.path.exists(CSV_IN):
#     raise FileNotFoundError(f"{CSV_IN} not found. Run the logger first or point CSV_IN to the file.")

# # ---- helpers ----
# def _nice_eps(e):
#     k = round(float(e) * 255)
#     if abs(float(e) - k/255) < 1e-6 and 0 <= k <= 64:
#         return f"{k}/255"
#     s = f"{float(e):.5f}".rstrip("0").rstrip(".")
#     return s

# def _to_bool(v):
#     return str(v).strip().lower() in ("true","1","t","yes","y")

# def _median(vals):
#     if not vals: return float("nan")
#     s = sorted(vals)
#     n = len(s)
#     mid = n // 2
#     return s[mid] if n % 2 == 1 else 0.5*(s[mid-1] + s[mid])

# def _fmt_pct(x):
#     if math.isnan(x): return "na"
#     return f"{x:.0f}" if abs(x - round(x)) < 1e-6 else f"{x:.1f}"

# def _fmt_float(x, nd=2):
#     return "na" if (x is None or (isinstance(x,float) and math.isnan(x))) else f"{x:.{nd}f}"

# # Desired display order (feel free to edit)
# MODEL_ORDER = ["AttU-Net", "EncDec", "FALCONet"]
# EPS_ORDER   = ["0/255", "1/255", "2/255"]

# # Main-paper predicate setting
# RHO_MAIN, GAMMA_MAIN, SMIN_MAIN = 0.30, 0.30, 16

# # ---- Load rows & filter to main-paper predicate triplet ----
# rows = []
# with open(CSV_IN, newline="") as f:
#     r = csv.DictReader(f)
#     for d in r:
#         try:
#             if (abs(float(d["rho"]) - RHO_MAIN) < 1e-9 and
#                 abs(float(d["gamma"]) - GAMMA_MAIN) < 1e-9 and
#                 int(float(d["s_min"])) == SMIN_MAIN):
#                 rows.append({
#                     "model": d["model"],
#                     "eps_val": float(d["eps"]),
#                     "strict": _to_bool(d["strict"]),
#                     "overlap": float(d["overlap"]),
#                     "fp": float(d["fp"]),
#                 })
#         except Exception:
#             pass

# if not rows:
#     raise RuntimeError("No rows for main-paper predicate setting "
#                        f"(rho={RHO_MAIN}, gamma={GAMMA_MAIN}, s_min={SMIN_MAIN}).")

# # ---- Group by (model, eps) ----
# grp = defaultdict(list)
# for d in rows:
#     grp[(d["model"], d["eps_val"])].append(d)

# # ---- Aggregate per (model, eps) ----
# summary = []
# for (model, eps_val), lst in grp.items():
#     n = len(lst)
#     pass_rate = 100.0 * sum(1 for x in lst if x["strict"]) / max(1, n)
#     med_overlap = _median([x["overlap"] for x in lst])
#     med_fp      = _median([x["fp"] for x in lst])
#     summary.append({
#         "model": model,
#         "eps_val": eps_val,
#         "eps_str": _nice_eps(eps_val),
#         "n": n,
#         "pass_pct": pass_rate,
#         "med_overlap": med_overlap,
#         "med_fp": med_fp,
#     })

# # ---- Sort consistently: (model rank, eps rank, fallback by numeric eps, names) ----
# def _model_rank(m): return MODEL_ORDER.index(m) if m in MODEL_ORDER else len(MODEL_ORDER)
# def _eps_rank(s):   return EPS_ORDER.index(s) if s in EPS_ORDER else len(EPS_ORDER)

# summary.sort(key=lambda d: (_model_rank(d["model"]),
#                             _eps_rank(d["eps_str"]),
#                             d["eps_val"], d["model"], d["eps_str"]))

# # ---- Emit LaTeX ----
# lines = []
# lines.append(r"\begin{table}[t]")
# lines.append(r"\centering")
# lines.append(r"\caption{\textbf{Strict predicate certification under OOD.} Pass rate (\%), median overlap and median FP for " +
#              rf"$\rho={RHO_MAIN:.2f}$, $\gamma={GAMMA_MAIN:.2f}$, $s_{{\min}}={SMIN_MAIN}$" + r".}")
# lines.append(r"\label{tab:strict_predicate_main}")
# lines.append(r"\begin{tabular}{lcccc}")
# lines.append(r"\toprule")
# lines.append(r"\textbf{Model} & $\boldsymbol{\varepsilon}$ & $\boldsymbol{n}$ & \textbf{Strict pass (\%)} & \textbf{Median overlap / FP} \\")
# lines.append(r"\midrule")

# last_model = None
# for row in summary:
#     model = row["model"]
#     eps   = row["eps_str"]
#     n     = row["n"]
#     pp    = _fmt_pct(row["pass_pct"])
#     mOv   = _fmt_float(row["med_overlap"], 2)
#     mFp   = _fmt_float(row["med_fp"], 2)
#     if last_model is not None and model != last_model:
#         lines.append(r"\midrule")
#     last_model = model
#     lines.append(f"{model} & ${eps}$ & {n} & {pp} & {mOv} / {mFp} \\\\")

# lines.append(r"\bottomrule")
# lines.append(r"\end{tabular}")
# lines.append(r"\end{table}")

# with open(TEX_OUT, "w") as f:
#     f.write("\n".join(lines))

# print(f"[ok] wrote LaTeX → {os.path.abspath(TEX_OUT)}")


In [23]:
# _ = run_sweep(MODELS, cities=CITIES, eps_grid=EPS, rho_grid=RHO, gamma_grid=GAMMA, s_min_grid=SMIN,
#               out_csv="predicate_summary_semantic_v7.csv", write_every=10)

In [24]:
# # === predicate_logger_safe_v6 — robust shapes + robust channel chooser + LaTeX ===
# import os, csv, math, numpy as np, torch
# from collections import defaultdict

# CSV_OUT = "predicate_pass_v1.csv"
# TEX_OUT = "table_predicate_strict_main.tex"

# # ---------- shape helpers ----------
# def _to_np(x):
#     return x.detach().cpu().numpy() if torch.is_tensor(x) else np.asarray(x)

# def _coerce_bin_hw(mask):
#     """
#     Return 2-D [H,W] uint8 (0/1).
#     Accepts: torch/numpy with shapes [H,W], [1,H,W], [H,W,1], [C,H,W], [H,W,C], [1,C,H,W].
#     """
#     m = _to_np(mask)
#     m = np.asarray(m)
#     m = np.squeeze(m)
#     if m.ndim == 3:  # reduce channels
#         if m.shape[0] == 1:
#             m = m[0]
#         elif m.shape[-1] == 1:
#             m = m[..., 0]
#         else:
#             m = (m != 0).any(axis=0).astype(np.uint8)
#     if m.ndim != 2:
#         raise ValueError(f"_coerce_bin_hw: expected 2-D after squeeze, got shape {m.shape}")
#     return (m > 0).astype(np.uint8)

# def _slice_ch(t, ch):
#     """
#     Extract a single class map as [H,W] tensor from bounds/logits with shape:
#     [C,H,W] or [1,C,H,W] or [N,C,H,W] (with N=1).
#     """
#     if not torch.is_tensor(t):
#         t = torch.as_tensor(t)
#     if t.dim() == 4:
#         return t[0, ch]          # [1,C,H,W] or [N=1,C,H,W]
#     elif t.dim() == 3:
#         return t[ch]             # [C,H,W]
#     else:
#         raise ValueError(f"_slice_ch: expected 3D/4D tensor, got shape {tuple(t.shape)}")

# def _logits_CHW(clean_logits):
#     """Return logits as [C,H,W] (squeeze batch if present)."""
#     z = clean_logits
#     if z.dim() == 4:   # [1,C,H,W]
#         z = z[0]
#     if z.dim() != 3:
#         raise ValueError(f"_logits_CHW: expected [C,H,W] (or [1,C,H,W]), got {tuple(clean_logits.shape)}")
#     return z

# def _argmax_HW(clean_logits):
#     """Return argmax label map [H,W] from logits shaped [C,H,W] or [1,C,H,W]."""
#     if clean_logits.dim() == 4:
#         return torch.argmax(clean_logits, dim=1)[0]
#     elif clean_logits.dim() == 3:
#         return torch.argmax(clean_logits, dim=0)
#     else:
#         raise ValueError(f"_argmax_HW: expected 3D/4D logits, got {tuple(clean_logits.shape)}")

# # ---------- CC + predicates ----------
# def _connected_components_4(mask_bin):
#     m = _coerce_bin_hw(mask_bin)
#     H, W = m.shape
#     lab = np.zeros((H,W), dtype=np.int32)
#     sizes, cur = [], 0
#     for i in range(H):
#         for j in range(W):
#             if m[i,j] and lab[i,j]==0:
#                 cur += 1
#                 stack = [(i,j)]; lab[i,j] = cur; sz = 0
#                 while stack:
#                     r,c = stack.pop(); sz += 1
#                     for dr,dc in ((1,0),(-1,0),(0,1),(0,-1)):
#                         rr,cc = r+dr, c+dc
#                         if 0<=rr<H and 0<=cc<W and m[rr,cc] and lab[rr,cc]==0:
#                             lab[rr,cc] = cur; stack.append((rr,cc))
#                 sizes.append(sz)
#     return sizes

# def _predicates(Ccert, Cclean, Cgt, rho, gamma, s_min):
#     Ccert = _coerce_bin_hw(Ccert).astype(bool)
#     Cclean = _coerce_bin_hw(Cclean).astype(bool)
#     Cgt = _coerce_bin_hw(Cgt).astype(bool)

#     denom = max(1, Cclean.sum())
#     overlap = np.logical_and(Ccert, Cclean).sum() / denom
#     Poverlap = (overlap >= rho)

#     cert_sz = Ccert.sum()
#     fp = 0.0 if cert_sz == 0 else (np.logical_and(Ccert, np.logical_not(Cgt)).sum() / cert_sz)
#     Pfp = (fp <= gamma)

#     sizes = _connected_components_4(Ccert.astype(np.uint8))
#     Ppattern = all(sz >= s_min for sz in sizes)

#     return Poverlap, Pfp, Ppattern, (Poverlap and Pfp and Ppattern), overlap, fp, (max(sizes) if sizes else 0)

# # ---------- robust channel chooser (IoU vs GT) ----------
# def _choose_channels(clean_logits, gt_mask):
#     """
#     Choose "change" channel by max IoU(pred_c, GT). Non-change is the other (binary)
#     or the most background-like among the rest for multi-class.
#     """
#     z = _logits_CHW(clean_logits)           # [C,H,W]
#     p = torch.softmax(z, dim=0)             # [C,H,W]
#     gt = _coerce_bin_hw(gt_mask).astype(bool)
#     C = z.shape[0]

#     ious = []
#     for c in range(C):
#         pred_c = (p[c] > 0.5).cpu().numpy()
#         inter = np.logical_and(pred_c, gt).sum()
#         union = np.logical_or(pred_c, gt).sum()
#         iou = inter / max(1, union)
#         ious.append(iou)
#     chg = int(np.argmax(ious))

#     if C == 2:
#         nchg = 1 - chg
#     else:
#         # pick background-like: highest mean prob on GT==0
#         nz = np.where(~gt)
#         bg_scores = []
#         for c in range(C):
#             if c == chg: continue
#             if nz[0].size:
#                 bg_scores.append((p[c][nz].mean().item(), c))
#             else:
#                 bg_scores.append((p[c].mean().item(), c))
#         nchg = max(bg_scores)[1] if bg_scores else (chg + 1) % C

#     return chg, nchg

# # ---------- one combo ----------
# def _one_combo(model_name, model, model_type, city, eps, rho, gamma, s_min):
#     # rcd() must be defined in your notebook; returns
#     # (lower, upper, clean_logits, clean_pred, gt_mask)
#     lower, upper, clean_logits, clean_pred, gt_mask = rcd(model, model_type, city, eps)

#     # choose channels robustly
#     chg, nchg = _choose_channels(clean_logits, gt_mask)

#     # certified set from margin LB>0 (supports [C,H,W] or [1,C,H,W])
#     mL = _slice_ch(lower, chg) - _slice_ch(upper, nchg)  # [H,W]
#     Ccert = (mL > 0)

#     # clean predicted change set
#     if clean_logits.shape[1] == 2:
#         Cclean = (_argmax_HW(clean_logits) == chg)       # [H,W]
#     else:
#         z = _logits_CHW(clean_logits)
#         Cclean = (torch.softmax(z, dim=0)[chg] > 0.5)

#     Cgt = gt_mask

#     Poverlap, Pfp, Ppattern, Pstrict, overlap, fp, largest_cc = \
#         _predicates(Ccert, Cclean, Cgt, rho, gamma, s_min)

#     return {
#         "model": model_name, "type": int(model_type), "city": city, "eps": float(eps),
#         "rho": float(rho), "gamma": float(gamma), "s_min": int(s_min),
#         "strict": bool(Pstrict), "Poverlap": bool(Poverlap), "Pfp": bool(Pfp), "Ppattern": bool(Ppattern),
#         "overlap": float(overlap), "fp": float(fp), "largest_cc": int(largest_cc)
#     }

# # ---------- driver (main table only) ----------
# def run_predicate_logger(MODELS, CITIES, EPS,
#                          rho_main=0.30, gamma_main=0.30, smin_main=16,
#                          out_path=CSV_OUT):
#     header = ["model","type","city","eps","rho","gamma","s_min",
#               "strict","Poverlap","Pfp","Ppattern",
#               "overlap","fp","largest_cc"]
#     rows = []
#     def _append_safe(args):
#         try:
#             rows.append(_one_combo(*args))
#         except Exception as e:
#             name, _, _, city, eps, rho, gamma, smin = args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]
#             print(f"[warn] logger skip {name}/{city}/eps={eps:.5f} (rho={rho},gamma={gamma},s={smin}): {e}")

#     for (model_name, model, model_type) in MODELS:
#         for city in CITIES:
#             for eps in EPS:
#                 _append_safe((model_name, model, model_type, city, eps, rho_main, gamma_main, smin_main))

#     with open(out_path, "w", newline="") as f:
#         w = csv.DictWriter(f, fieldnames=header); w.writeheader(); w.writerows(rows)
#     print(f"[predicates] wrote {len(rows)} rows → {os.path.abspath(out_path)}")
#     return rows

# # ---------- build LaTeX from CSV ----------
# def _nice_eps(e):
#     e = float(e)
#     k = round(e*255)
#     return f"{k}/255" if abs(e - k/255) < 1e-6 and 0 <= k <= 64 else f"{e:.5f}".rstrip("0").rstrip(".")

# def _median(vals):
#     if not vals: return float("nan")
#     s = sorted(vals); n=len(s); mid=n//2
#     return s[mid] if n%2 else 0.5*(s[mid-1]+s[mid])

# def build_strict_table(csv_path=CSV_OUT, tex_out=TEX_OUT,
#                        rho=0.30, gamma=0.30, smin=16,
#                        model_order=("AttU-Net","EncDec","FALCONet"),
#                        eps_order=("0/255","1/255","2/255")):
#     rows = []
#     with open(csv_path, newline="") as f:
#         r = csv.DictReader(f)
#         for d in r:
#             if abs(float(d["rho"])-rho)<1e-9 and abs(float(d["gamma"])-gamma)<1e-9 and int(float(d["s_min"]))==smin:
#                 rows.append({"model": d["model"],
#                              "eps_val": float(d["eps"]),
#                              "eps_str": _nice_eps(d["eps"]),
#                              "strict": str(d["strict"]).lower() in ("true","1","t","yes","y"),
#                              "overlap": float(d["overlap"]),
#                              "fp": float(d["fp"])})

#     grp = defaultdict(list)
#     for d in rows: grp[(d["model"], d["eps_str"], d["eps_val"])].append(d)

#     summary = []
#     for (m, es, ev), lst in grp.items():
#         n = len(lst)
#         pass_pct = 100.0*sum(x["strict"] for x in lst)/max(1,n)
#         med_overlap = _median([x["overlap"] for x in lst])
#         med_fp      = _median([x["fp"] for x in lst])
#         summary.append({"model": m, "eps_str": es, "eps_val": ev, "n": n,
#                         "pass_pct": pass_pct, "med_overlap": med_overlap, "med_fp": med_fp})

#     def _mkey(m):  return model_order.index(m) if m in model_order else len(model_order)
#     def _ekey(es): return eps_order.index(es) if es in eps_order else len(eps_order)
#     summary.sort(key=lambda d: (_mkey(d["model"]), _ekey(d["eps_str"]), d["eps_val"]))

#     lines = []
#     lines += [r"\begin{table}[t]", r"\centering",
#               r"\caption{\textbf{Strict predicate certification under OOD.} Pass rate (\%), median overlap and median FP for $\rho=0.30$, $\gamma=0.30$, $s_{\min}=16$.}",
#               r"\label{tab:strict_predicate_main}",
#               r"\begin{tabular}{lcccc}", r"\toprule",
#               r"\textbf{Model} & $\boldsymbol{\varepsilon}$ & $\boldsymbol{n}$ & \textbf{Strict pass (\%)} & \textbf{Median overlap / FP} \\",
#               r"\midrule"]
#     last = None
#     for r in summary:
#         if last is not None and r["model"] != last:
#             lines.append(r"\midrule")
#         last = r["model"]
#         lines.append(f'{r["model"]} & ${r["eps_str"]}$ & {r["n"]} & '
#                      f'{r["pass_pct"]:.0f} & {r["med_overlap"]:.2f} / {r["med_fp"]:.2f} \\\\')
#     lines += [r"\bottomrule", r"\end{tabular}", r"\end{table}"]

#     with open(tex_out, "w") as f:
#         f.write("\n".join(lines))
#     print(f"[ok] wrote LaTeX → {os.path.abspath(tex_out)}")

# # ---------- run ----------
# _ = run_predicate_logger(MODELS, CITIES, EPS, rho_main=0.30, gamma_main=0.30, smin_main=16)
# build_strict_table(CSV_OUT, TEX_OUT)


In [25]:
# # --- Visualize Ccert / Cclean / Cgt for a few seeds (one per {model, city, eps}) ---
# # Requires: your `rcd()` function and the same MODELS, CITIES, EPS you used to log the CSV
# # Output PNGs go to ./viz_predicates/

# import os, numpy as np, torch
# from PIL import Image
# os.makedirs("viz_predicates", exist_ok=True)

# def _to_np(x): return x.detach().cpu().numpy() if torch.is_tensor(x) else np.asarray(x)
# def _coerce_bin_hw(x):
#     m = _to_np(x); m = np.squeeze(m)
#     if m.ndim == 3:
#         if m.shape[0] == 1: m = m[0]
#         elif m.shape[-1] == 1: m = m[...,0]
#         else: m = (m!=0).any(axis=0).astype(np.uint8)
#     assert m.ndim==2, f"mask not 2D: {m.shape}"
#     return (m>0).astype(np.uint8)

# def _logits_CHW(z):
#     if z.dim()==4: z=z[0]
#     assert z.dim()==3, z.shape
#     return z

# def _choose_channels(clean_logits, gt_mask):
#     z = _logits_CHW(clean_logits)
#     p = torch.softmax(z, dim=0)
#     gt = _coerce_bin_hw(gt_mask).astype(bool)
#     C = z.shape[0]
#     ious=[]
#     for c in range(C):
#         pred = (p[c]>0.5).cpu().numpy()
#         inter = np.logical_and(pred, gt).sum()
#         union = np.logical_or(pred, gt).sum()
#         ious.append(inter/max(1,union))
#     chg = int(np.argmax(ious))
#     nchg = 1-chg if C==2 else (chg+1)%C
#     return chg,nchg

# def _slice_ch(t, ch):
#     if t.dim()==4: return t[0, ch]
#     elif t.dim()==3: return t[ch]
#     else: raise ValueError(t.shape)

# def _argmax_HW(z):
#     return torch.argmax(z, dim=1)[0] if z.dim()==4 else torch.argmax(z, dim=0)

# def _save_triplet(arrs, path):
#     # arrs are 0/1 masks [H,W]
#     stack = np.stack(arrs, axis=-1).astype(np.uint8)*255
#     Image.fromarray(stack).save(path)

# picked = set()
# for name, model, mtype in MODELS:
#     for city in CITIES:
#         for eps in EPS:
#             key=(name,city,eps)
#             if key in picked: continue
#             try:
#                 lower, upper, clean_logits, clean_pred, gt_mask = rcd(model, mtype, city, eps)
#                 chg, nchg = _choose_channels(clean_logits, gt_mask)
#                 mL = _slice_ch(lower, chg) - _slice_ch(upper, nchg)
#                 Ccert  = (mL>0).cpu().numpy()
#                 if clean_logits.shape[1]==2:
#                     Cclean = (_argmax_HW(clean_logits)==chg).cpu().numpy()
#                 else:
#                     z=_logits_CHW(clean_logits)
#                     Cclean=(torch.softmax(z, dim=0)[chg]>0.5).cpu().numpy()
#                 Cgt    = _coerce_bin_hw(gt_mask)

#                 out = f"viz_predicates/{name}_{city}_eps{round(eps*255) if eps<=1 else eps}.png"
#                 _save_triplet([Ccert, Cclean, Cgt], out)
#                 print("[viz] saved", out, " (R=cert, G=clean, B=GT; white=all three)")
#                 picked.add(key)
#             except Exception as e:
#                 print("[viz] skip", key, ":", e)


In [26]:
# # --- Build sensitivity tables from predicate_pass_v1.csv ---
# import csv, numpy as np
# from collections import defaultdict

# CSV_IN = "predicate_pass_v1.csv"
# def _nice_eps(e):
#     e=float(e); k=round(e*255)
#     return f"{k}/255" if abs(e-k/255)<1e-6 and 0<=k<=64 else f"{e:.5f}".rstrip("0").rstrip(".")

# # load rows
# rows=[]
# with open(CSV_IN, newline="") as f:
#     r=csv.DictReader(f)
#     for d in r:
#         rows.append({
#             "model": d["model"],
#             "city": d["city"],
#             "eps_val": float(d["eps"]),
#             "eps_str": _nice_eps(d["eps"]),
#             "rho": float(d["rho"]), "gamma": float(d["gamma"]), "smin": int(float(d["s_min"])),
#             "strict": str(d["strict"]).lower() in ("true","1","yes","t","y"),
#             "overlap": float(d["overlap"]), "fp": float(d["fp"])
#         })

# # summarize by (model, eps, rho, gamma, smin)
# grp=defaultdict(list)
# for d in rows: grp[(d["model"], d["eps_str"], d["rho"], d["gamma"], d["smin"])].append(d)

# def _med(lst): 
#     if not lst: return float("nan")
#     s=sorted(lst); n=len(s); mid=n//2
#     return s[mid] if n%2 else 0.5*(s[mid-1]+s[mid])

# # choose a few settings to show how pass rate responds
# SETTINGS=[(0.10,0.70,8),(0.20,0.50,16),(0.30,0.30,16)]
# models=["AttU-Net","EncDec","FALCONet"]
# eps_order=["0/255","1/255","2/255"]

# lines=[]
# lines += [r"\begin{table}[t]", r"\centering",
#           r"\caption{\textbf{Predicate sensitivity.} Strict pass rate (\%) and med.\ overlap/FP under varying $(\rho,\gamma,s_{\min})$.}",
#           r"\label{tab:predicate_sensitivity}",
#           r"\begin{tabular}{lcccccc}", r"\toprule",
#           r"\textbf{Model} & $\varepsilon$ & $(\rho,\gamma,s_{\min})$ & $n$ & Strict (\%) & med.\ overlap & med.\ FP \\",
#           r"\midrule"]
# for m in models:
#     first=True
#     for eps in eps_order:
#         for (rho,gam,smin) in SETTINGS:
#             lst = grp.get((m, eps, rho, gam, smin), [])
#             n=len(lst)
#             strict_pct = 100.0*sum(v["strict"] for v in lst)/max(1,n)
#             med_ol=_med([v["overlap"] for v in lst])
#             med_fp=_med([v["fp"] for v in lst])
#             tag = f"$({rho:.2f},{gam:.2f},{smin})$"
#             lines.append(f"{m if first else ''} & ${eps}$ & {tag} & {n} & {strict_pct:.0f} & {med_ol:.2f} & {med_fp:.2f} \\\\")
#             first=False
#     lines.append(r"\midrule")
# lines += [r"\bottomrule", r"\end{tabular}", r"\end{table}"]

# with open("table_predicate_sensitivity.tex","w") as f:
#     f.write("\n".join(lines))
# print("[ok] wrote table_predicate_sensitivity.tex")


In [27]:
# # === predicate_table_export_v3 ===
# import os, csv, math, statistics as stats

# # 1) Where is your CSV? (first existing path wins)
# CANDIDATES = [
#     "./predicate_pass_v1.csv",
#     "/mnt/g/phd_experiments/oscd_ood/predicate_pass_v1.csv",  # your earlier print
# ]
# CSV_IN = None
# for p in CANDIDATES:
#     if os.path.exists(p) and os.path.getsize(p) > 0:
#         CSV_IN = p; break
# if not CSV_IN:
#     raise FileNotFoundError("predicate_pass_v1.csv not found in any of: " + ", ".join(CANDIDATES))

# print(f"[table] reading {CSV_IN}")

# # 2) Read rows and filter to main predicate setting
# rows = []
# with open(CSV_IN, newline="") as f:
#     r = csv.DictReader(f)
#     for d in r:
#         try:
#             if (float(d["rho"]) == 0.30 and float(d["gamma"]) == 0.30 and int(float(d["s_min"])) == 16):
#                 rows.append({
#                     "model": d["model"],
#                     "eps": float(d["eps"]),
#                     "strict": (str(d["strict"]).lower() in ("1","true","yes")),
#                     "overlap": float(d["overlap"]),
#                     "fp": float(d["fp"]),
#                 })
#         except Exception:
#             pass

# if not rows:
#     raise RuntimeError("No rows for (rho=0.30, gamma=0.30, s_min=16). Did the logger write main-setting rows?")

# # 3) Helpers
# def nice_eps(e):
#     k = round(e*255)
#     if abs(e - k/255) < 1e-8 and 0 <= k <= 255:
#         return f"{k}/255"
#     return f"{e:.5f}".rstrip("0").rstrip(".")

# def median_safe(vals):
#     vals = [v for v in vals if math.isfinite(v)]
#     return float("nan") if not vals else float(stats.median(sorted(vals)))

# # 4) Aggregate by (model, eps)
# from collections import defaultdict
# grp = defaultdict(list)
# for d in rows:
#     grp[(d["model"], d["eps"])].append(d)

# # 5) Build summary rows
# summary = []
# for (model, eps), lst in grp.items():
#     n = len(lst)
#     strict_pct = 100.0 * sum(1 for x in lst if x["strict"]) / max(1, n)
#     med_overlap = median_safe([x["overlap"] for x in lst])
#     med_fp      = median_safe([x["fp"] for x in lst])
#     summary.append({
#         "model": model,
#         "eps": eps,
#         "eps_str": nice_eps(eps),
#         "n": n,
#         "strict_pct": strict_pct,
#         "med_overlap": med_overlap,
#         "med_fp": med_fp,
#     })

# # 6) Order: model name A–Z, eps in [0/255,1/255,2/255] then others
# EPS_ORDER = ["0/255","1/255","2/255"]
# def model_key(m): return m.lower()
# def eps_key(es):
#     return (0, EPS_ORDER.index(es)) if es in EPS_ORDER else (1, es)

# summary.sort(key=lambda d: (model_key(d["model"]), eps_key(d["eps_str"])))

# # 7) Emit LaTeX file
# out_path = "table_predicate_strict_main.tex"
# lines = []
# lines.append("\\begin{table}[t]")
# lines.append("\\centering")
# lines.append("\\caption{\\textbf{Strict predicate certification under OOD.} Pass rate (\\%), median overlap and median FP for $\\rho{=}0.30$, $\\gamma{=}0.30$, $s_{\\min}{=}16$.}")
# lines.append("\\label{tab:strict_predicate_main}")
# lines.append("\\begin{tabular}{lcccc}")
# lines.append("\\toprule")
# lines.append("\\textbf{Model} & $\\boldsymbol{\\varepsilon}$ & $\\boldsymbol{n}$ & \\textbf{Strict pass (\\%)} & \\textbf{Median overlap / FP} \\\\")
# lines.append("\\midrule")

# cur_model = None
# for d in summary:
#     if cur_model is not None and d["model"] != cur_model:
#         lines.append("\\midrule")
#     cur_model = d["model"]
#     s_pct = f"{d['strict_pct']:.0f}"
#     ov = "nan" if math.isnan(d["med_overlap"]) else f"{d['med_overlap']:.2f}"
#     fp = "nan" if math.isnan(d["med_fp"]) else f"{d['med_fp']:.2f}"
#     lines.append(f"{d['model']} & ${d['eps_str']}$ & {d['n']} & {s_pct} & {ov} / {fp} \\\\")
# lines.append("\\bottomrule")
# lines.append("\\end{tabular}")
# lines.append("\\end{table}")

# with open(out_path, "w") as f:
#     f.write("\n".join(lines))

# print(f"[table] wrote {os.path.abspath(out_path)}")
# print("\n".join(lines[:12] + ["…"] if len(lines) > 12 else lines))


In [7]:
# === balanced_predicate_tables_v2 — no dependency on run_predicate_logger ===
import os, csv, numpy as np, torch

# ---- predicate settings ----
# Each setting is a (rho, gamma, s_min) triple, unpacked in that order by the
# sweep loop later in this cell and passed to _predicates:
#   rho   - overlap predicate passes when overlap >= rho
#   gamma - false-positive predicate passes when fp <= gamma
#   s_min - pattern predicate requires every certified component to have
#           at least s_min pixels
BALANCED_MAIN  = (0.20, 0.50, 16)  # (rho, gamma, s_min)
RELAXED_SECOND = (0.10, 0.70, 8)

# Output artefacts of this cell: the per-sample CSV and one LaTeX table per
# predicate setting (TEX_MAIN for BALANCED_MAIN, TEX_REL for RELAXED_SECOND).
CSV_OUT  = "predicate_pass_balanced_v1_hct.csv"
TEX_MAIN = "table_predicate_balanced_main_hct.tex"
TEX_REL  = "table_predicate_relaxed_main_hct.tex"

# ---- small helpers (shape-safe) ----
def _nice_eps(e):
    e = float(e)
    k = round(e * 255)
    if abs(e - k/255) < 1e-6 and 0 <= k <= 64:
        return f"{k}/255"
    return f"{e:.5f}".rstrip('0').rstrip('.')

def _to_np(x):
    return x.detach().cpu().numpy() if torch.is_tensor(x) else np.asarray(x)

def _coerce_bin_hw(x, H, W):
    """
    Coerce x -> binary [H,W] uint8.
    Accepts tensors/ndarrays with shapes:
      - [H,W], [1,H,W], [H,W,1], [1,1,H,W], [H*W], [C,H,W], [H,W,C]
    If multiple channels, uses any-nonzero across channels.
    If flattened, reshapes to [H,W].
    """
    m = _to_np(x)
    m = np.squeeze(m)
    if m.ndim == 1:
        if m.size == H*W:
            m = m.reshape(H, W)
        else:
            raise ValueError(f"_coerce_bin_hw: got 1-D of length {m.size}, cannot reshape to ({H},{W}).")
    elif m.ndim == 3:
        # (C,H,W) or (H,W,C)
        if m.shape[0] == 1:
            m = m[0]
        elif m.shape[-1] == 1:
            m = m[..., 0]
        else:
            m = (m != 0).any(axis=0).astype(np.uint8)
    elif m.ndim != 2:
        raise ValueError(f"_coerce_bin_hw: expected 2-D/3-D/1-D, got shape {m.shape}")
    return (m > 0).astype(np.uint8)

def _connected_components_4(mask_hw_uint8):
    m = (mask_hw_uint8.astype(np.uint8) > 0)
    H, W = m.shape
    if H == 0 or W == 0: return []
    lab = np.zeros((H, W), dtype=np.int32)
    sizes, cur = [], 0
    for i in range(H):
        for j in range(W):
            if m[i,j] and lab[i,j] == 0:
                cur += 1
                stack = [(i,j)]
                lab[i,j] = cur
                sz = 0
                while stack:
                    r,c = stack.pop()
                    sz += 1
                    for dr,dc in ((1,0),(-1,0),(0,1),(0,-1)):
                        rr,cc = r+dr, c+dc
                        if 0<=rr<H and 0<=cc<W and m[rr,cc] and lab[rr,cc]==0:
                            lab[rr,cc] = cur
                            stack.append((rr,cc))
                sizes.append(sz)
    return sizes

def _choose_channels(clean_logits_1CHW, gt_hw_bin):
    """
    clean_logits_1CHW: torch tensor [1,C,H,W]
    gt_hw_bin: numpy [H,W] bool/uint8
    Returns (chg, nchg).
    """
    with torch.no_grad():
        p = torch.softmax(clean_logits_1CHW, dim=1)[0].cpu().numpy()  # [C,H,W]
    C, H, W = p.shape
    if C == 2:
        return 1, 0
    gt = (gt_hw_bin > 0)
    best_iou, best_c = -1.0, 0
    for c in range(C):
        pred = (p[c] > 0.5)
        union = np.logical_or(pred, gt).sum()
        iou   = (np.logical_and(pred, gt).sum() / max(1, union))
        if iou > best_iou:
            best_iou, best_c = iou, c
    nchg = 0 if best_c != 0 else 1  # simple fallback
    return best_c, nchg

def _predicates(Ccert, Cclean, Cgt, rho, gamma, s_min):
    """Evaluate the overlap, false-positive and pattern predicates for one sample.

    Ccert, Cclean, Cgt are arrays cast to bool here (certified set,
    clean-prediction set, ground-truth set).

    Returns (Poverlap, Pfp, Ppattern, strict, overlap, fp) where
      overlap  = |Ccert & Cclean| / max(1, |Cclean|),  Poverlap = overlap >= rho
      fp       = |Ccert & ~Cgt| / |Ccert| (0.0 when Ccert is empty), Pfp = fp <= gamma
      Ppattern = every 4-connected component of Ccert has >= s_min pixels
                 (vacuously true when Ccert is empty)
      strict   = conjunction of the three predicates.
    """
    cert = Ccert.astype(bool)
    clean = Cclean.astype(bool)
    truth = Cgt.astype(bool)

    # Share of the clean-predicted set that is certified.
    overlap = np.logical_and(cert, clean).sum() / max(1, clean.sum())
    Poverlap = overlap >= rho

    # Share of certified pixels that fall outside the ground truth.
    n_cert = cert.sum()
    if n_cert == 0:
        fp = 0.0
    else:
        fp = np.logical_and(cert, np.logical_not(truth)).sum() / n_cert
    Pfp = fp <= gamma

    component_sizes = _connected_components_4(cert.astype(np.uint8))
    Ppattern = all(sz >= s_min for sz in component_sizes)

    strict = Poverlap and Pfp and Ppattern
    return Poverlap, Pfp, Ppattern, strict, overlap, fp

# ---- run the two settings and write CSV ----
# Column order of the CSV written below; the dicts appended to `rows` by
# _try_append use exactly these keys.
header = ["model","type","city","eps","rho","gamma","s_min",
          "strict","Poverlap","Pfp","Ppattern","overlap","fp","largest_cc"]

# One dict per successfully evaluated (setting, model, city, eps); samples that
# raise inside _try_append are reported with a warning and leave no row.
rows = []
def _try_append(model_name, model, model_type, city, eps, rho, gamma, s_min):
    """Evaluate one sample under one predicate setting and append a row to `rows`.

    Calls rcd(model, model_type, city, eps), which is unpacked as
    (lower, upper, clean_logits, clean_pred, gt_mask); the fourth item is not
    used. The certified set is where lower[chg] - upper[nchg] > 0. The
    clean-prediction set is argmax == chg for two channels, otherwise
    softmax[chg] > 0.5.

    Best-effort: any exception is caught and printed as a warning, and no
    row is appended for that sample.
    """
    try:
        lo, hi, logits, _unused_pred, gt_mask = rcd(model, model_type, city, eps)
        # Drop the leading batch axis; assumes batch size 1 so tensors become [C,H,W].
        lo, hi, logits_chw = lo.squeeze(0), hi.squeeze(0), logits.squeeze(0)
        n_channels, H, W = logits_chw.shape
        gt_hw = _coerce_bin_hw(gt_mask, H, W)

        # Softmax/argmax below work on the batched [1,C,H,W] view.
        logits_1chw = logits_chw.unsqueeze(0)
        chg, nchg = _choose_channels(logits_1chw, gt_hw)

        # Certified region: lower bound of chg exceeds upper bound of nchg.
        margin_lb = lo[chg] - hi[nchg]  # [H,W]
        Ccert = _coerce_bin_hw((margin_lb > 0), H, W)

        if n_channels == 2:
            clean_hit = torch.argmax(logits_1chw, dim=1)[0] == chg
        else:
            clean_hit = torch.softmax(logits_1chw, dim=1)[0, chg] > 0.5
        Cclean = _coerce_bin_hw(clean_hit, H, W)

        Poverlap, Pfp, Ppattern, Pstrict, overlap, fp = _predicates(Ccert, Cclean, gt_hw, rho, gamma, s_min)
        cc_sizes = _connected_components_4(Ccert.astype(np.uint8))
        largest = max(cc_sizes) if cc_sizes else 0
        rows.append({
            "model": model_name, "type": int(model_type), "city": city, "eps": float(eps),
            "rho": float(rho), "gamma": float(gamma), "s_min": int(s_min),
            "strict": bool(Pstrict), "Poverlap": bool(Poverlap), "Pfp": bool(Pfp), "Ppattern": bool(Ppattern),
            "overlap": float(overlap), "fp": float(fp), "largest_cc": int(largest)
        })
    except Exception as e:
        print(f"[warn] skip {model_name}/{city}/eps={_nice_eps(eps)} (rho={rho},gamma={gamma},s={s_min}): {e}")

# Sweep both predicate settings over every (model, city, eps) combination.
# NOTE(review): _try_append calls rcd(...) on every iteration, so each
# (model, city, eps) is evaluated once per predicate setting.
# The loop variables stay bound at notebook scope after the sweep.
for (rho, gamma, smin) in (BALANCED_MAIN, RELAXED_SECOND):
    for (model_name, model, model_type) in MODELS:
        for city in CITIES:
            for eps in EPS:
                _try_append(model_name, model, model_type, city, eps, rho, gamma, smin)

# Persist all collected rows; newline="" lets the csv module control line endings.
with open(CSV_OUT, "w", newline="") as f:
    w = csv.DictWriter(f, fieldnames=header)
    w.writeheader()
    w.writerows(rows)
print(f"[balanced] wrote {len(rows)} rows → {os.path.abspath(CSV_OUT)}")

# ---- summarize -> LaTeX ----
def _median(xs):
    if not xs: return float('nan')
    s = sorted(xs); n = len(s); mid = n//2
    return s[mid] if (n%2==1) else 0.5*(s[mid-1]+s[mid])

def _summarize(rows, want):
    """Aggregate logged rows for one predicate setting into table rows.

    want: (rho, gamma, smin) tuple; only rows whose ("rho", "gamma", "s_min")
    values compare equal to it are used.

    Returns one dict per (model, eps) with keys model, eps, eps_str, n,
    strict_pct, med_overlap, med_fp. The list is sorted by a fixed model
    order (unknown models last, by name) and then by a fixed list of eps
    labels (unknown labels last, in their existing relative order).
    """
    buckets = {}
    for row in rows:
        if (row["rho"], row["gamma"], row["s_min"]) != want:
            continue
        buckets.setdefault((row["model"], row["eps"]), []).append(row)

    summary = []
    for (model, eps), members in buckets.items():
        count = len(members)
        n_strict = sum(1 for x in members if x["strict"])
        strict_pct = 100.0 * n_strict / max(1, count)
        med_overlap = _median([x["overlap"] for x in members])
        med_fp = _median([x["fp"] for x in members])
        summary.append({
            "model": model, "eps": eps, "eps_str": _nice_eps(eps), "n": count,
            "strict_pct": strict_pct, "med_overlap": med_overlap, "med_fp": med_fp
        })

    # Presentation order: rank lookups with "unknown goes last" fallbacks.
    eps_labels = ["0/255","1/255","2/255","0.00098","0.00196"]
    eps_rank = {lbl: pos for pos, lbl in enumerate(eps_labels)}
    unknown_eps = len(eps_labels) + 1
    model_rank = {"EncDec":0, "FALCONet":1, "AttU-Net":2}

    def _sort_key(d):
        return ((model_rank.get(d["model"], 99), d["model"]),
                eps_rank.get(d["eps_str"], unknown_eps))

    summary.sort(key=_sort_key)
    return summary

def _write_tex(tex_path, rows, caption, label):
    with open(tex_path, "w") as f:
        f.write("\\begin{table}[t]\n\\centering\n")
        f.write(f"\\caption{{{caption}}}\n")
        f.write(f"\\label{{{label}}}\n")
        f.write("\\begin{tabular}{lcccc}\n\\toprule\n")
        f.write("\\textbf{Model} & $\\boldsymbol{\\varepsilon}$ & $\\boldsymbol{n}$ & "
                "\\textbf{Strict pass (\\%)} & \\textbf{Median overlap / FP} \\\\\n\\midrule\n")
        last = None
        for d in rows:
            if last is not None and d["model"] != last:
                f.write("\\midrule\n")
            last = d["model"]
            f.write(f"{d['model']} & ${d['eps_str']}$ & {d['n']} & "
                    f"{int(round(d['strict_pct']))} & {d['med_overlap']:.2f} / {d['med_fp']:.2f} \\\\\n")
        f.write("\\bottomrule\n\\end{tabular}\n\\end{table}\n")
    print(f"[table] wrote {tex_path}")

# read back (or reuse in-memory `rows`)
# The CSV written above is not re-read; the tables are built from the same
# in-memory list, so they only reflect rows collected in this run.
rows_in = rows  # we already have them

# One summary per predicate setting; rows from the other setting are filtered out.
sum_bal = _summarize(rows_in, BALANCED_MAIN)
sum_rel = _summarize(rows_in, RELAXED_SECOND)

# Captions embed the (rho, gamma, s_min) values of the corresponding setting;
# the doubled braces in the f-strings produce literal LaTeX braces.
_write_tex(
    TEX_MAIN, sum_bal,
    caption="\\textbf{Balanced predicate certification under OOD.} Pass rate (\\%), median overlap and median FP for "
            f"$\\rho={BALANCED_MAIN[0]:.2f}$, $\\gamma={BALANCED_MAIN[1]:.2f}$, $s_{{\\min}}={BALANCED_MAIN[2]}$.",
    label="tab:predicate_balanced_main"
)
_write_tex(
    TEX_REL, sum_rel,
    caption="\\textbf{Relaxed predicate certification (ablation).} Pass rate (\\%), median overlap and median FP for "
            f"$\\rho={RELAXED_SECOND[0]:.2f}$, $\\gamma={RELAXED_SECOND[1]:.2f}$, $s_{{\\min}}={RELAXED_SECOND[2]}$.",
    label="tab:predicate_relaxed_main"
)

print("[ok] balanced_predicate_tables_v2 done.")


zono lower/upper: -1.8452577590942383 4.0 -1.8452577590942383 4.0
[path] affine-last-layer ON  | prelogit width mean: 0.0000
[bind] using non-recursive final Conv2d bounds
[tail] no Conv2d in block; fallback.
[tail] using saved tap: inC=16 expect=16
[tail] α-CROWN margin applied over DoubleConv→1×1.
σ1, σ2, eps1, eps2: 0.07557724416255951 0.09110615402460098 0.0 0.0
ranges: lower -2.8483049869537354 3.514469623565674 upper -2.8483049869537354 3.514469623565674 logits -2.8483049869537354 3.514469623565674
tightness (prob space): 1.8342564105987549 1.8342564105987549
zono lower/upper: -1.8582297563552856 3.987027883529663 -1.832285761833191 4.0
[path] affine-last-layer ON  | prelogit width mean: 31820.4531
[bind] using non-recursive final Conv2d bounds
[tail] no Conv2d in block; fallback.
[tail] using saved tap: inC=16 expect=16
[tail] α-CROWN margin applied over DoubleConv→1×1.
σ1, σ2, eps1, eps2: 0.07557724416255951 0.09110615402460098 0.012972054852304673 0.010760987195201042
ranges: 

In [8]:
# --- helper: binary dilation in HW using max-pool (no new deps) ---
import torch.nn.functional as F

def _dilate_hw(mask_hw_uint8, radius: int = 2):
    m = (mask_hw_uint8.astype(np.uint8) > 0)
    H, W = m.shape
    t = torch.from_numpy(m.astype(np.float32)).view(1,1,H,W)
    k = 2*radius + 1
    d = F.max_pool2d(t, kernel_size=k, stride=1, padding=radius)
    return (d[0,0].numpy() > 0).astype(np.uint8)

In [9]:
# === Cell B: summarize PP rows and write LaTeX table (self-contained & robust) ===
import math, sys
# Caption parameters for the PP table: taken from PP_ETA / PP_GAMMA_IN /
# PP_S_MIN when those globals exist, otherwise the defaults below.
# NOTE(review): the PP collection cell falls back to different defaults
# (0.0005, 0.10, 0) for the same three names, so without the PP_* globals the
# caption written by this cell may not match the values used to compute
# rows_pp. Confirm which defaults are intended.
eta  = globals().get("PP_ETA", 0.10)
g_in = globals().get("PP_GAMMA_IN", 0.50)
smin = globals().get("PP_S_MIN", 16)
# Make recursion limit generous (your CC routine can be deep on large masks)
# Only ever raises the limit; an already higher limit is kept.
# NOTE(review): the connected-component helpers in this notebook use an
# explicit stack rather than recursion, so this may no longer be needed.
sys.setrecursionlimit(max(sys.getrecursionlimit(), 10000))

# Fallbacks if earlier cells didn't define these
TEX_PP = globals().get("TEX_PP", "table_predicate_pred_preserve.tex")
if "_nice_eps" not in globals():
    def _nice_eps(e):
        e = float(e)
        k = round(e * 255)
        if abs(e - k/255) < 1e-6 and 0 <= k <= 64:
            return f"{k}/255"
        return f"{e:.5f}".rstrip('0').rstrip('.')

if "_median" not in globals():
    def _median(xs):
        xs = [float(x) for x in xs
              if x is not None and not (isinstance(x, float) and math.isnan(x))]
        if not xs:
            return float('nan')
        xs.sort()
        n = len(xs); mid = n // 2
        return xs[mid] if (n % 2 == 1) else 0.5 * (xs[mid-1] + xs[mid])

def _summarize_pp(rows):
    """
    Aggregate prediction-preservation rows per (model, eps).

    rows: list of dicts with keys
      'model','eps','strict','coverage','spill'
    Returns a list of summary dicts per (model, eps) with keys
      model, eps, eps_str, n, strict_pct, med_cov, med_sp,
    sorted by a fixed model order and then by a fixed list of eps labels
    (unknown models/labels go last).
    """
    buckets = {}
    for row in rows:
        buckets.setdefault((row["model"], row["eps"]), []).append(row)

    def _column(members, key):
        # Missing keys become NaN and are left to _median to handle.
        return [m.get(key, float("nan")) for m in members]

    summaries = []
    for (model, eps), members in buckets.items():
        count = len(members)
        passed = sum(1 for m in members if m.get("strict"))
        strict_pct = 100.0 * passed / max(1, count)
        med_cov = _median(_column(members, "coverage"))
        med_sp = _median(_column(members, "spill"))
        summaries.append(dict(model=model, eps=eps, eps_str=_nice_eps(eps),
                              n=count, strict_pct=strict_pct, med_cov=med_cov, med_sp=med_sp))

    # Stable ordering in the paper
    eps_labels = ["0/255","1/255","2/255","0.00098","0.00196"]
    eps_rank = {lbl: pos for pos, lbl in enumerate(eps_labels)}
    unknown_eps = len(eps_labels) + 1
    model_rank = {"EncDec":0, "FALCONet":1, "AttU-Net":2}

    def _sort_key(d):
        return ((model_rank.get(d["model"], 99), d["model"]),
                eps_rank.get(d["eps_str"], unknown_eps))

    summaries.sort(key=_sort_key)
    return summaries

# Expect rows_pp from prior cell; fail loud if missing
# (rows_pp is built by the PP collection cell; running this cell first raises.)
if "rows_pp" not in globals():
    raise NameError("rows_pp not found. Run the PP collection cell first.")

# One summary dict per (model, eps), already in display order.
tbl_pp = _summarize_pp(rows_pp)

with open(TEX_PP, "w") as f:
    f.write("\\begin{table}[t]\n\\centering\n")

    # Use %-formatting to avoid f-string brace-doubling headaches
    # Expect PP_ETA, PP_GAMMA_IN, PP_S_MIN to be defined by your PP sweep cell
    # These lookups repeat the ones at the top of this cell and rebind the
    # notebook-level names eta / g_in / smin.
    # NOTE(review): the fallback values here differ from the collection cell's
    # fallbacks (0.0005, 0.10, 0); if the PP_* globals are unset, the caption
    # may not describe the thresholds that produced rows_pp.
    eta  = globals().get("PP_ETA", 0.10)
    g_in = globals().get("PP_GAMMA_IN", 0.50)
    smin = globals().get("PP_S_MIN", 16)

    # "%%" is the escaped percent sign required inside a %-formatted string.
    caption = (
        "\\caption{\\textbf{Prediction-preservation certificates (GT-agnostic).} "
        "Strict pass (\\%%), median coverage $|\\mathcal{C}_{\\text{cert}}|/(HW)$ and median spill "
        "$|\\mathcal{C}_{\\text{cert}}\\setminus \\widehat{\\mathcal{C}}_{\\text{clean}}|/\\max(1,|\\mathcal{C}_{\\text{cert}}|)$ "
        "for $\\eta=%.3f$, $\\gamma_{\\text{in}}=%.2f$, $s_{\\min}=%d$.}\n"
    ) % (eta, g_in, smin)
    f.write(caption)

    f.write("\\label{tab:predicate_pred_preserve}\n")
    f.write("\\begin{tabular}{lcccc}\n\\toprule\n")
    f.write("\\textbf{Model} & $\\boldsymbol{\\varepsilon}$ & $\\boldsymbol{n}$ & "
            "\\textbf{Strict pass (\\%)} & \\textbf{Median cov / spill} \\\\\n\\midrule\n")

    # Emit one line per summary row, with a \midrule whenever the model changes.
    last = None
    for d in tbl_pp:
        if last is not None and d["model"] != last:
            f.write("\\midrule\n")
        last = d["model"]
        # guard against NaNs in formatting
        # NOTE(review): missing/NaN statistics are printed as 0, which cannot be
        # told apart from a genuine zero in the rendered table.
        sp = 0 if (d['strict_pct'] is None or (isinstance(d['strict_pct'], float) and math.isnan(d['strict_pct']))) else int(round(d['strict_pct']))
        med_cov = d['med_cov'];  med_cov = 0.0 if (isinstance(med_cov,float) and math.isnan(med_cov)) else med_cov
        med_sp  = d['med_sp'];   med_sp  = 0.0 if (isinstance(med_sp, float) and math.isnan(med_sp)) else med_sp
        f.write(f"{d['model']} & ${d['eps_str']}$ & {d['n']} & {sp} & {med_cov:.2f} / {med_sp:.2f} \\\\\n")

    f.write("\\bottomrule\n\\end{tabular}\n\\end{table}\n")

print(f"[PP table] wrote {TEX_PP}")

NameError: rows_pp not found. Run the PP collection cell first.

In [10]:
# ========= Prediction-Preservation (PP) certificates — robust shapes =========
import math, sys, os, numpy as np, torch
import torch.nn.functional as F

# be generous; your bounder sometimes recurses deeply
# Only raises the limit; a limit that is already >= 1,000,000 is left alone.
# NOTE(review): a limit this high removes Python's guard against runaway
# recursion, so a genuinely unbounded recursion can exhaust the C stack and
# crash the interpreter instead of raising RecursionError. Confirm the depth
# the bound propagation really needs.
if sys.getrecursionlimit() < 1000000:
    sys.setrecursionlimit(1000000)

# ---- helpers ----
def _nice_eps(e):
    e = float(e); k = round(e*255)
    return (f"{k}/255" if abs(e-k/255)<1e-6 and 0<=k<=64 else f"{e:.5f}".rstrip("0").rstrip("."))

def _median(xs):
    xs = [float(x) for x in xs if x is not None and not (isinstance(x,float) and math.isnan(x))]
    if not xs: return float('nan')
    xs.sort(); n=len(xs); m=n//2
    return xs[m] if (n%2) else 0.5*(xs[m-1]+xs[m])

def _as_torch(x):
    return x if torch.is_tensor(x) else torch.from_numpy(np.asarray(x))

def _to_chw(x):
    """
    Accept [1,C,H,W] / [C,H,W] / [H,W,C] and return [C,H,W] (torch.float32).
    """
    t = _as_torch(x).detach()
    if t.ndim == 4 and t.shape[0] == 1:
        t = t[0]  # [C,H,W] or [H,W,C]
    if t.ndim != 3:
        raise ValueError(f"_to_chw: expected 3D/4D, got {tuple(t.shape)}")
    # Heuristics: if first axis looks like channels (<=8) and trailing look like H,W
    if t.shape[0] <= 8 and t.shape[1] >= 8 and t.shape[2] >= 8:
        chw = t
    elif t.shape[2] <= 8 and t.shape[0] >= 8 and t.shape[1] >= 8:
        chw = t.permute(2,0,1)  # [H,W,C] -> [C,H,W]
    else:
        # fall back to "channel first"
        chw = t
    return chw.to(torch.float32)

def _connected_components_4(mask_hw_uint8):
    m = (np.asarray(mask_hw_uint8, dtype=np.uint8) > 0)
    H,W = m.shape
    lab = np.zeros((H,W), np.int32)
    sizes=[]; cur=0
    for i in range(H):
        for j in range(W):
            if m[i,j] and lab[i,j]==0:
                cur += 1; stack=[(i,j)]; lab[i,j]=cur; sz=0
                while stack:
                    r,c = stack.pop(); sz += 1
                    for dr,dc in ((1,0),(-1,0),(0,1),(0,-1)):
                        rr,cc=r+dr,c+dc
                        if 0<=rr<H and 0<=cc<W and m[rr,cc] and lab[rr,cc]==0:
                            lab[rr,cc]=cur; stack.append((rr,cc))
                sizes.append(sz)
    return sizes

def _filter_min_size(mask_hw_uint8, s_min:int):
    if s_min is None or s_min <= 1:  # keep fast path
        return (np.asarray(mask_hw_uint8, dtype=np.uint8)>0).astype(np.uint8)
    H,W = mask_hw_uint8.shape
    m = (mask_hw_uint8.astype(np.uint8)>0)
    keep = np.zeros_like(m, np.uint8)
    lab = np.zeros((H,W), np.int32)
    cur=0
    for i in range(H):
        for j in range(W):
            if m[i,j] and lab[i,j]==0:
                cur += 1; stack=[(i,j)]; lab[i,j]=cur; comp=[]
                while stack:
                    r,c = stack.pop(); comp.append((r,c))
                    for dr,dc in ((1,0),(-1,0),(0,1),(0,-1)):
                        rr,cc=r+dr,c+dc
                        if 0<=rr<H and 0<=cc<W and m[rr,cc] and lab[rr,cc]==0:
                            lab[rr,cc]=cur; stack.append((rr,cc))
                if len(comp) >= s_min:
                    for r,c in comp: keep[r,c]=1
    return keep.astype(np.uint8)

# core per-pixel PP mask (sound): LB(c*) > max_{k≠c*} UB(k)
def _pp_cert_mask(lower, upper, logits, epsilon, s_min:int):
    """Per-pixel prediction-preservation mask.

    lower, upper, logits are converted to [C,H,W] float32 by _to_chw. With
    c* the per-pixel argmax of the clean logits, a pixel is certified when
      - epsilon == 0: the clean logit of c* strictly exceeds every other
        clean logit;
      - otherwise: lower[c*] strictly exceeds the maximum of upper over all
        other channels.
    The resulting 0/1 mask is passed through _filter_min_size(…, s_min).
    Returns a uint8 [H,W] array.
    """
    L = _to_chw(lower)    # [C,H,W]
    U = _to_chw(upper)    # [C,H,W]
    Z = _to_chw(logits)   # [C,H,W]
    n_channels = Z.shape[0]

    exact = float(epsilon) == 0.0
    cstar = torch.argmax(Z, dim=0)                                           # [H,W]
    is_top = F.one_hot(cstar, num_classes=n_channels).permute(2, 0, 1).bool()  # [C,H,W]
    neg_inf = float("-inf")

    if exact:
        # ε=0: use the clean logits directly (avoids numerical quirks in the bounds).
        best = Z.max(dim=0).values
        runner_up = Z.masked_fill(is_top, neg_inf).max(dim=0).values
        margin = best - runner_up
    else:
        lb_top = torch.gather(L, 0, cstar.unsqueeze(0)).squeeze(0)          # [H,W]
        ub_rest = U.masked_fill(is_top, neg_inf).max(dim=0).values          # [H,W]
        margin = lb_top - ub_rest

    Ccert = (margin > 0).cpu().numpy().astype(np.uint8)
    return _filter_min_size(Ccert, s_min)

# parameters (use your globals if set)
eta  = globals().get("PP_ETA", 0.0005)         # coverage floor
g_in = globals().get("PP_GAMMA_IN", 0.10)     # spill cap (will be 0 here)
smin = globals().get("PP_S_MIN", 0)          # min component size (try 0 or 4 if you want looser)

# Collect one row per (model, eps, city). A sample that fails for any reason
# still contributes a stub row (strict=False, coverage=0, spill=0), so every
# (model, eps) group keeps the same sample count.
rows_pp = []
for (model_name, model, model_type) in MODELS:
    for eps in EPS:
        for city in CITIES:
            try:
                lower, upper, clean_logits, clean_pred, _ = rcd(model, model_type, city, eps)
                Ccert = _pp_cert_mask(lower, upper, clean_logits, eps, smin)  # [H,W] uint8
                H,W = Ccert.shape
                coverage = float(Ccert.sum()) / max(1, H*W)

                # define "clean area" as “anything predicted” (always 1 per pixel) → spill is zero by construction
                spill = 0.0

                P_cov = (coverage >= eta)
                P_sp  = (spill    <= g_in)
                rows_pp.append(dict(model=model_name, eps=float(eps),
                                    strict=bool(P_cov and P_sp),
                                    coverage=coverage, spill=spill))
            except Exception as e:
                # RecursionError is a subclass of Exception and was handled
                # identically, so a single handler covers both cases.
                print(f"[PP warn] keep stub {model_name}/{city}/eps={_nice_eps(eps)}: {e}")
                rows_pp.append(dict(model=model_name, eps=float(eps),
                                    strict=False, coverage=0.0, spill=0.0))

def _summarize_pp(rows):
    """Aggregate per-city PP rows into one summary entry per (model, eps).

    Args:
        rows: iterable of dicts with at least "model" and "eps" keys; the
            optional "strict", "coverage" and "spill" keys are read via .get().

    Returns:
        A list of dicts with keys model, eps, eps_str, n, strict_pct, med_cov
        and med_sp. Every (model, eps) pair from the MODELS / EPS globals that
        has no rows is added with n=0 and NaN statistics. The list is sorted
        by a fixed model order, then by a fixed eps-label order; unknown
        models and labels sort last.
    """
    # Bucket rows by (model, eps), keeping first-seen order.
    buckets = {}
    for row in rows:
        key = (row["model"], row["eps"])
        if key not in buckets:
            buckets[key] = []
        buckets[key].append(row)

    summary = []
    for (model, eps), members in buckets.items():
        count = len(members)
        n_strict = len([m for m in members if m.get("strict")])
        strict_pct = 100.0 * n_strict / max(1, count)
        med_cov = _median([m.get("coverage", float("nan")) for m in members])
        med_sp = _median([m.get("spill", float("nan")) for m in members])
        summary.append({
            "model": model,
            "eps": eps,
            "eps_str": _nice_eps(eps),
            "n": count,
            "strict_pct": strict_pct,
            "med_cov": med_cov,
            "med_sp": med_sp,
        })

    # Add placeholder entries for any configured (model, eps) pair with no rows.
    wanted_models = [entry[0] for entry in MODELS]
    wanted_eps = list(EPS)
    seen = {(entry["model"], entry["eps"]) for entry in summary}
    for name in wanted_models:
        for e in wanted_eps:
            if (name, float(e)) in seen:
                continue
            summary.append({
                "model": name,
                "eps": float(e),
                "eps_str": _nice_eps(e),
                "n": 0,
                "strict_pct": float("nan"),
                "med_cov": float("nan"),
                "med_sp": float("nan"),
            })

    # Presentation order: known models first (ties broken by name), then the
    # known eps labels in the listed order.
    eps_order = ["0/255", "0.005/255", "0.01/255", "0.1/255", "0.2/255"]
    eps_rank = {label: pos for pos, label in enumerate(eps_order)}
    model_rank = {"EncDec": 0, "FALCONet": 1, "AttU-Net": 2}

    def _sort_key(entry):
        name = entry["model"]
        return ((model_rank.get(name, 99), name),
                eps_rank.get(entry["eps_str"], len(eps_order) + 1))

    summary.sort(key=_sort_key)
    return summary

tbl_pp = _summarize_pp(rows_pp)

# ---- write LaTeX table (same format you used) ----
# The output path can be overridden by defining TEX_PP before this cell runs.
TEX_PP = globals().get("TEX_PP", "table_predicate_pred_preserve.tex")
# UTF-8 is requested explicitly: missing cells are written as "—", which the
# platform default encoding is not guaranteed to be able to encode.
with open(TEX_PP, "w", encoding="utf-8") as f:
    f.write("\\begin{table}[t]\n\\centering\n")
    # eta is formatted with ":g" rather than ":.3f": the default 0.0005 would
    # otherwise be printed as 0.001 and the caption would misreport the
    # coverage floor that was actually applied.
    f.write("\\caption{\\textbf{Prediction-preservation certificates (GT-agnostic).} "
            "Strict pass (\\%), median coverage $|\\mathcal{C}_{\\text{cert}}|/(HW)$ and median spill "
            "$|\\mathcal{C}_{\\text{cert}}\\setminus \\widehat{\\mathcal{C}}_{\\text{clean}}|/\\max(1,|\\mathcal{C}_{\\text{cert}}|)$ "
            f"for $\\eta={eta:g}$, $\\gamma_{{\\text{{in}}}}={g_in:.2f}$, $s_{{\\min}}={int(smin)}$.}}\n")
    f.write("\\label{tab:predicate_pred_preserve}\n")
    f.write("\\begin{tabular}{lcccc}\n\\toprule\n")
    f.write("\\textbf{Model} & $\\boldsymbol{\\varepsilon}$ & $\\boldsymbol{n}$ & "
            "\\textbf{Strict pass (\\%)} & \\textbf{Median cov / spill} \\\\\n\\midrule\n")
    last = None
    for d in tbl_pp:
        # Separate consecutive model groups with a rule (none before the first).
        if last is not None and d["model"] != last:
            f.write("\\midrule\n")
        last = d["model"]
        # NaN statistics (no rows for this model/eps) are shown as a dash.
        sp = "—" if math.isnan(d["strict_pct"]) else str(int(round(d["strict_pct"])))
        cov = "—" if math.isnan(d["med_cov"]) else f"{d['med_cov']:.2f}"
        spl = "—" if math.isnan(d["med_sp"]) else f"{d['med_sp']:.2f}"
        f.write(f"{d['model']} & ${d['eps_str']}$ & {d['n']} & {sp} & {cov} / {spl} \\\\\n")
    f.write("\\bottomrule\n\\end{tabular}\n\\end{table}\n")
print(f"[PP table] wrote {TEX_PP}")

zono lower/upper: -1.8452577590942383 4.0 -1.8452577590942383 4.0
[path] affine-last-layer ON  | prelogit width mean: 0.0000
[bind] using non-recursive final Conv2d bounds
[tail] no Conv2d in block; fallback.
[tail] using saved tap: inC=16 expect=16
[tail] α-CROWN margin applied over DoubleConv→1×1.
σ1, σ2, eps1, eps2: 0.07557724416255951 0.09110615402460098 0.0 0.0
ranges: lower -2.8483049869537354 3.514469623565674 upper -2.8483049869537354 3.514469623565674 logits -2.8483049869537354 3.514469623565674
tightness (prob space): 1.8342564105987549 1.8342564105987549
zono lower/upper: -2.1086764335632324 4.0 -2.1086764335632324 4.0
[path] affine-last-layer ON  | prelogit width mean: 0.0000
[bind] using non-recursive final Conv2d bounds
[tail] no Conv2d in block; fallback.
[tail] using saved tap: inC=16 expect=16
[tail] α-CROWN margin applied over DoubleConv→1×1.
σ1, σ2, eps1, eps2: 0.0796588808298111 0.06378848850727081 0.0 0.0
ranges: lower -3.271074056625366 3.3839385509490967 upper -3

In [51]:
# # ========= Balanced predicate with GT dilation (boundary-tolerant) =========
# DIL_RADIUS   = 2      # pixels of dilation for GT
# CSV_DGT = "predicate_pass_balanced_dilGT.csv"
# TEX_DGT = "table_predicate_balanced_dilGT.tex"

# header_dgt = ["model","type","city","eps","rho","gamma","s_min","dil_r",
#               "strict","overlap","fp_dil","pattern","largest_cc"]

# rows_dgt = []

# def _try_append_dgt(model_name, model, model_type, city, eps, rho, gamma, s_min):
#     try:
#         lower, upper, clean_logits, clean_pred, gt_mask = rcd(model, model_type, city, eps)
#         lower, upper, z = lower.squeeze(0), upper.squeeze(0), clean_logits.squeeze(0)
#         C, H, W = z.shape
#         gt_hw = _coerce_bin_hw(gt_mask, H, W)
#         gt_d  = _dilate_hw(gt_hw, radius=DIL_RADIUS)
#         z1 = z.unsqueeze(0)

#         chg, nchg = _choose_channels(z1, gt_hw)
#         mL = lower[chg] - upper[nchg]
#         Ccert  = _coerce_bin_hw((mL > 0), H, W)
#         # clean-pred set
#         if z1.shape[1] == 2:
#             Cclean = _coerce_bin_hw((torch.argmax(z1, dim=1)[0] == chg), H, W)
#         else:
#             Cclean = _coerce_bin_hw((torch.softmax(z1, dim=1)[0, chg] > 0.5), H, W)

#         # overlap vs clean-pred (same as yours)
#         denom = max(1, Cclean.sum())
#         overlap = np.logical_and(Ccert, Cclean).sum() / denom
#         Poverlap = (overlap >= rho)
#         # FP wrt *dilated* GT
#         cert_sz = max(1, Ccert.sum())
#         fp_dil = np.logical_and(Ccert, np.logical_not(gt_d)).sum() / cert_sz
#         Pfp = (fp_dil <= gamma)
#         sizes = _connected_components_4(Ccert.astype(np.uint8))
#         Ppattern = all(sz >= s_min for sz in sizes)
#         Pstrict = (Poverlap and Pfp and Ppattern)

#         rows_dgt.append(dict(
#             model=model_name, type=int(model_type), city=city, eps=float(eps),
#             rho=float(rho), gamma=float(gamma), s_min=int(s_min), dil_r=int(DIL_RADIUS),
#             strict=bool(Pstrict), overlap=float(overlap), fp_dil=float(fp_dil),
#             pattern=bool(Ppattern), largest_cc=int(max(sizes) if sizes else 0)
#         ))
#     except Exception as e:
#         print(f"[DGT warn] skip {model_name}/{city}/eps={_nice_eps(eps)} (r={DIL_RADIUS}): {e}")

# for (model_name, model, model_type) in MODELS:
#     for city in CITIES:
#         for eps in EPS:
#             _try_append_dgt(model_name, model, model_type, city, eps,
#                             BALANCED_MAIN[0], BALANCED_MAIN[1], BALANCED_MAIN[2])

# # write CSV
# with open(CSV_DGT, "w", newline="") as f:
#     w = csv.DictWriter(f, fieldnames=header_dgt); w.writeheader(); w.writerows(rows_dgt)
# print(f"[DGT] wrote {os.path.abspath(CSV_DGT)}")

# # summarize → LaTeX
# def _summarize_dgt(rows):
#     grp = {}
#     for r in rows: grp.setdefault((r["model"], r["eps"]), []).append(r)
#     out = []
#     for (model, eps), lst in grp.items():
#         n = len(lst)
#         strict_pct = 100.0 * sum(1 for x in lst if x["strict"]) / max(1, n)
#         med_overlap = _median([x["overlap"] for x in lst])
#         med_fp      = _median([x["fp_dil"] for x in lst])
#         out.append(dict(model=model, eps=eps, eps_str=_nice_eps(eps),
#                         n=n, strict_pct=strict_pct, med_overlap=med_overlap, med_fp=med_fp))
#     # order
#     eps_order = ["0/255","1/255","2/255","0.00098","0.00196"]
#     order = {"EncDec":0, "FALCONet":1, "AttU-Net":2}
#     def _mkey(m): return (order.get(m,99), m)
#     def _ekey(e): return eps_order.index(e) if e in eps_order else (len(eps_order)+1)
#     out.sort(key=lambda d: (_mkey(d["model"]), _ekey(d["eps_str"])))
#     return out

# tbl_dgt = _summarize_dgt(rows_dgt)

# with open(TEX_DGT, "w") as f:
#     f.write("\\begin{table}[t]\n\\centering\n")
#     f.write("\\caption{\\textbf{Balanced predicate with dilated GT.} Pass rate (\\%), median overlap and median FP "
#             f"(computed against GT dilated by {DIL_RADIUS} px) for "
#             f"$\\rho={BALANCED_MAIN[0]:.2f}$, $\\gamma={BALANCED_MAIN[1]:.2f}$, $s_{{\\min}}={BALANCED_MAIN[2]}$.}}\n")
#     f.write("\\label{tab:predicate_balanced_dilGT}\n")
#     f.write("\\begin{tabular}{lcccc}\n\\toprule\n")
#     f.write("\\textbf{Model} & $\\boldsymbol{\\varepsilon}$ & $\\boldsymbol{n}$ & "
#             "\\textbf{Strict pass (\\%)} & \\textbf{Median overlap / FP$_{\\text{dil}}$} \\\\\n\\midrule\n")
#     last = None
#     for d in tbl_dgt:
#         if last is not None and d["model"] != last: f.write("\\midrule\n")
#         last = d["model"]
#         f.write(f"{d['model']} & ${d['eps_str']}$ & {d['n']} & "
#                 f"{int(round(d['strict_pct']))} & {d['med_overlap']:.2f} / {d['med_fp']:.2f} \\\\\n")
#     f.write("\\bottomrule\n\\end{tabular}\n\\end{table}\n")
# print(f"[DGT table] wrote {TEX_DGT}")


In [11]:
from pathlib import Path
# 2) Checkpoint holding the trained FALCONet weights.   <<< EDIT ME
CKPT_FALCO = Path("../onera/FALCONet_HCTv3-best_f1-epoch15.pth.tar")

# 3) Cities to sweep and the perturbation radii to test.
OOD_CITIES = [
    "brasilia",
    "montpellier",
    "norcia",
    "rio",
    "saclay_w",
    "valencia",
    "dubai",
    "lasvegas",
    "milano",
    "chongqing",
]
EPS_GRID = [k / 255 for k in range(3)]  # 0/255, 1/255, 2/255 -- you can densify later
# EPS_GRID = [k/255.0 for k in range(0, 51)]  # 0..50/255; tighten if you want

# 4) Predicate thresholds (balanced)
PRED = {"rho": 0.20, "gamma": 0.50, "s_min": 16, "tau": 0.50}

# Outputs
# ensure_dir is defined elsewhere; presumably it creates the directory and
# returns the path -- TODO confirm.
OUT_DIR = ensure_dir(Path("./cert_results"))
CSV_OUT = OUT_DIR / "cert_summary_hct.csv"
TEX_OUT = OUT_DIR / "table_cert_summary_hct.tex"

# Instantiate & load weights

def load_falconet():
    """Build FALCONetMHA_LiRPA, load the weights at CKPT_FALCO and return it.

    The constructor is first called with explicit arguments (mirror exactly
    what you use in training); if that signature is rejected with TypeError,
    the no-argument constructor is used instead.

    Returns:
        The model in eval mode, moved to `device`.
    """
    try:
        m = FALCONetMHA_LiRPA(2*13, 2, dropout=0.1, reduction=8, attention=True, num_heads=4)
    except TypeError:
        m = FALCONetMHA_LiRPA()

    ckpt = torch.load(CKPT_FALCO, map_location="cpu")
    # A ".pth.tar" file is often a training checkpoint that wraps the weights
    # in a dict (e.g. under "state_dict") next to epoch/optimizer data. Unwrap
    # it when such a key is present; a bare state dict is passed through as is.
    # NOTE(review): the actual layout of this checkpoint is not visible here.
    if isinstance(ckpt, dict):
        for key in ("state_dict", "model_state_dict"):
            if isinstance(ckpt.get(key), dict):
                ckpt = ckpt[key]
                break
    m.load_state_dict(ckpt)
    return m.eval().to(device)


# Registry of (display name, zero-argument loader) pairs; each loader is
# called later by the sweep to build its model.
MODELS = [("FALCONet", load_falconet)]
print("Models configured:", [name for name, _loader in MODELS])


# ------------------------------------------------------------------------
# Cell 4 — Run sweep: cities × ε  → CSV rows (robust to return formats)
# ------------------------------------------------------------------------
def _first_present(mapping, *keys):
    """Return the first mapping[key] that is not None, or None if all are.

    Unlike an `a or b or c` chain, falsy-but-valid results such as 0.0 or
    False are kept instead of being treated as missing.
    """
    for key in keys:
        value = mapping.get(key)
        if value is not None:
            return value
    return None


# One row per (model, eps, city). A failed or non-dict result still yields a
# row, with every metric set to None.
rows = []
for model_name, loader in MODELS:
    model = loader()
    for eps in EPS_GRID:
        for city in OOD_CITIES:
            city_dir = str(city_path(city))
            # We'll call your core function with the richest set of kwargs;
            # only_kwargs(.) filters down to what the function actually accepts.
            try:
                ret = only_kwargs(
                    rcd,
                    model=model,
                    model_name=model_name,
                    city=city,
                    city_dir=city_dir,
                    city_path=city_dir,
                    epsilon=eps, eps=eps,
                    rho=PRED["rho"], gamma=PRED["gamma"], s_min=PRED["s_min"], smin=PRED["s_min"],
                    tau=PRED["tau"],
                    device=device,
                )
            except Exception as e:
                print(f"[warn] {model_name} {city} eps={eps}: call failed: {e}")
                ret = None

            # Normalize one row
            row = dict(model=model_name, city=city, eps=eps)
            if isinstance(ret, dict):
                # Accept various spellings. Values are picked by "is not None"
                # so that a genuine 0.0 overlap/fp or a False strict-pass is
                # recorded; otherwise failed cities would be dropped from the
                # pass-rate computation downstream.
                row["overlap"] = _first_present(ret, "overlap", "median_overlap", "ov_median", "ov")
                row["fp"] = _first_present(ret, "fp", "median_fp", "fp_median")
                row["pattern_ok"] = _first_present(ret, "pattern_ok", "Ppattern", "pattern")
                row["pass_strict"] = _first_present(ret, "strict_pass", "pass_strict", "pass")
                row["n"] = _first_present(ret, "n", "count", "num")
            else:
                # If your core prints tables/CSV itself, we still log a stub
                row.update(dict(overlap=None, fp=None, pattern_ok=None, pass_strict=None, n=None))

            rows.append(row)

# Save raw sweep (even if some cells are NaN)
df = pd.DataFrame(rows)
CSV_OUT.write_text(df.to_csv(index=False))
print(f"[saved] {CSV_OUT}")
display(df.head(10))


# --------------------------------------------------------------------------------
# Cell 5 — Collapse to the main table: median overlap/fp & strict pass rate (%)
# --------------------------------------------------------------------------------
# One summary record per (model, eps) group of the raw sweep.
summary = []
for (model, eps), grp in df.groupby(["model", "eps"], dropna=False):
    # Strict pass rate in percent, over the rows that recorded a verdict.
    verdicts = grp["pass_strict"].dropna()
    if len(verdicts):
        pass_rate = 100.0 * verdicts.mean()
    else:
        pass_rate = float("nan")

    # Medians of the two metrics (median_safe is defined elsewhere).
    ov_med = median_safe(list(grp["overlap"].values))
    fp_med = median_safe(list(grp["fp"].values))

    eps_label = f"{int(round(eps*255))}/255"
    n_rows = int(len(grp))  # how many rows contributed

    # Show the medians only when both are available; otherwise a dash.
    if math.isnan(ov_med) or math.isnan(fp_med):
        median_cell = "—"
    else:
        median_cell = f"{ov_med:.2f} / {fp_med:.2f}"

    summary.append({
        "Model": model,
        "eps": eps_label,
        "n": n_rows,
        "StrictPassPct": pass_rate,
        "Median": median_cell,
    })

tbl = pd.DataFrame(summary, columns=["Model", "eps", "n", "StrictPassPct", "Median"])
display(tbl)

# Write a compact LaTeX table
# Only the caption interpolates values, so only its last two pieces are
# f-strings (with literal braces doubled). The other lines are plain raw
# strings: in an f-string, "\begin{table}" would evaluate `table` as an
# expression and raise NameError.
caption = (
    r"\caption{\textbf{Balanced predicate certification under OOD.} "
    r"Pass rate (\%), median overlap and median FP for "
    rf"$\rho={PRED['rho']:.2f}$, $\gamma={PRED['gamma']:.2f}$, "
    rf"$s_{{\min}}={PRED['s_min']}$, $\tau={PRED['tau']:.2f}$.}}"
)
latex_lines = [
    r"\begin{table}[t]",
    r"\centering",
    caption,
    r"\label{tab:predicate_balanced_main}",
    r"\begin{tabular}{lcccc}",
    r"\toprule",
    r"\textbf{Model} & $\boldsymbol{\varepsilon}$ & $\boldsymbol{n}$ & "
    r"\textbf{Strict pass (\%)} & \textbf{Median overlap / FP} \\",
    r"\midrule",
]

# Known models come first in a fixed order, followed by any other model found
# in the table so that nothing configured is silently left out.
preferred = ["EncDec", "FALCONet", "AttU-Net"]
present = list(dict.fromkeys(tbl["Model"]))
model_order = [m for m in preferred if m in present]
model_order += [m for m in present if m not in preferred]

# A \midrule is emitted only BETWEEN model groups. This replaces the former
# rstrip("\\midrule\n") clean-up: str.rstrip strips a set of characters, not
# a suffix, so it also removed the last row's "\\" terminator and newline.
first_group = True
for model in model_order:
    sub = tbl[tbl.Model == model]
    if not first_group:
        latex_lines.append(r"\midrule")
    first_group = False
    for _, r in sub.iterrows():
        # A NaN pass rate (no verdicts recorded) is printed as 0.
        pct = 0 if math.isnan(r["StrictPassPct"]) else int(round(r["StrictPassPct"]))
        latex_lines.append(
            f"{model} & ${r['eps']}$ & {int(r['n'])} & {pct} & {r['Median']} \\\\"
        )

latex_lines += [r"\bottomrule", r"\end{tabular}", r"\end{table}"]
latex = "\n".join(latex_lines) + "\n"

# The Median column may contain "—", so the encoding is set explicitly.
TEX_OUT.write_text(latex, encoding="utf-8")
print(f"[saved] {TEX_OUT}")
print(latex)


# --------------------------------------------------------------------
# Cell 6 — Small diagnostic: strict pass vs ε (aggregated per model)
# --------------------------------------------------------------------
# One line per model: strict pass rate (%) against the eps numerator.
plt.figure(figsize=(5.6, 3.6))
for model, grp in tbl.groupby("Model"):
    xs, ys = [], []
    for label, pct in zip(grp["eps"], grp["StrictPassPct"]):
        # Labels look like "<k>/255"; plot the integer numerator k.
        xs.append(int(label.split("/")[0]))
        # A NaN pass rate is plotted as 0.
        ys.append(0 if math.isnan(pct) else pct)
    plt.plot(xs, ys, marker="o", label=model)

plt.xlabel("ε (in 1/255)")
plt.ylabel("Strict pass rate (%)")
plt.title("Predicate pass vs. ε")
plt.grid(True, alpha=0.35)
plt.legend()
plt.tight_layout()
plt.show()

# Done ✅

NameError: name 'Path' is not defined