# VOC WSSS: Robustness Explainability Pipeline

This notebook links **explainability seeds** (CAMs/attention) to **weakly‑supervised segmentation (WSSS)** on **PASCAL VOC 2012**, and measures how **perturbations** affect both **seed robustness** and **downstream segmentation** (mIoU).

**Flow**
1) Setup & VOC paths
2) Perturbations (flip, rotation, blur, brightness, Gaussian noise)
3) Models: ResNet‑50 (Grad‑CAM), ViT‑B/16 (Attention Rollout)
4) Seed generation (clean + perturbed) → saved PNG heatmaps
5) Seed robustness metrics (IoU/SSIM)
6) Pseudo‑masks (threshold) → minimal DeepLabV3 segmentation training
7) mIoU evaluation & summary tables


In [10]:
from pathlib import Path

DATA_ROOT = Path("/kaggle/input/pascal-voc-2012-dataset")
VOC_ROOT  = DATA_ROOT / "VOC2012_train_val/VOC2012_train_val"

IMGSETS = VOC_ROOT / "ImageSets"
SEG_DIR = IMGSETS / "Segmentation"
ACT_DIR = IMGSETS / "Action"

OUTPUT = Path("./outputs"); OUTPUT.mkdir(parents=True, exist_ok=True)

print("ImageSets subfolders:", [p.name for p in IMGSETS.iterdir() if p.is_dir()])

# Prefer Segmentation splits; warn if missing and (only if you really must) fall back to Action
if (SEG_DIR / "val.txt").exists():
    SPLIT_DIR = SEG_DIR
    print("Using segmentation splits:", SPLIT_DIR)
elif (ACT_DIR / "val.txt").exists():
    SPLIT_DIR = ACT_DIR
    print("WARNING: 'Segmentation/val.txt' not found. Using 'Action/val.txt' (NOT suitable for mIoU).")
else:
    raise FileNotFoundError("No val.txt in ImageSets/Segmentation or ImageSets/Action.")

def load_ids(split="val"):
    fp = SPLIT_DIR / f"{split}.txt"
    with open(fp, "r") as f:
        ids = [x.strip() for x in f]
    return ids

VAL_IDS = load_ids("val")
print("val count:", len(VAL_IDS), "sample:", VAL_IDS[:8])

ImageSets subfolders: ['Segmentation', 'Main', 'Layout', 'Action']
Using segmentation splits: /kaggle/input/pascal-voc-2012-dataset/VOC2012_train_val/VOC2012_train_val/ImageSets/Segmentation
val count: 1449 sample: ['2007_000033', '2007_000042', '2007_000061', '2007_000123', '2007_000129', '2007_000175', '2007_000187', '2007_000323']


In [11]:
!pip install timm opencv-python scikit-image tqdm matplotlib



In [12]:
import os, json, time, math, shutil
from PIL import Image
import numpy as np
import torch
import torchvision as tv
import torchvision.transforms as T
from torchvision.transforms import functional as TF
from tqdm import tqdm

try:
    import timm
except:
    print("Install timm: pip install timm")

try:
    from skimage.metrics import structural_similarity as ssim
except:
    ssim = None
    print("Install scikit-image for SSIM: pip install scikit-image")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)


Device: cuda


In [30]:
import xml.etree.ElementTree as ET

VOC_CLASSES = [
    "aeroplane","bicycle","bird","boat","bottle","bus","car","cat","chair","cow",
    "diningtable","dog","horse","motorbike","person","pottedplant","sheep","sofa","train","tvmonitor"
]

def load_voc_split(split="val"):
    ids_path = VOC_ROOT / "ImageSets" / "Segmentation" / f"{split}.txt"
    with open(ids_path, "r") as f:
        ids = [x.strip() for x in f.readlines()]
    return ids

def voc_image_path(img_id): return VOC_ROOT / "JPEGImages" / f"{img_id}.jpg"
def voc_mask_path(img_id):  return VOC_ROOT / "SegmentationClass" / f"{img_id}.png"
def voc_xml_path(img_id):   return VOC_ROOT / "Annotations" / f"{img_id}.xml"

def voc_image_classes(img_id):
    xml = ET.parse(voc_xml_path(img_id)).getroot()
    return list({obj.find("name").text for obj in xml.findall("object")})

VAL_IDS = load_voc_split("val")

In [31]:
import cv2

def perturbations(img_pil, rot_deg=30, blur_ks=5, bright_factor=1.5, noise_std=0.1):
    out = {}
    out["hflip"] = TF.hflip(img_pil)
    out["rotation"] = img_pil.rotate(rot_deg, resample=Image.BILINEAR, expand=False, fillcolor=(0,0,0))
    img_cv = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    blur = cv2.GaussianBlur(img_cv, (blur_ks, blur_ks), 0)
    out["blur"] = Image.fromarray(cv2.cvtColor(blur, cv2.COLOR_BGR2RGB))
    out["brightness"] = T.functional.adjust_brightness(img_pil, bright_factor)
    arr = np.array(img_pil).astype(np.float32)/255.0
    noise = np.random.normal(0, noise_std, arr.shape).astype(np.float32)
    noisy = np.clip(arr + noise, 0, 1)
    out["gauss"] = Image.fromarray((noisy*255).astype(np.uint8))
    return out

preprocess = T.Compose([
    T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])


In [32]:
def load_resnet50():
    try:
        model = tv.models.resnet50(weights=tv.models.ResNet50_Weights.IMAGENET1K_V2)
    except:
        model = tv.models.resnet50(pretrained=True)
    return model.eval().to(DEVICE)

def load_vit_b16():
    try:
        model = timm.create_model("vit_base_patch16_224", pretrained=True)
        return model.eval().to(DEVICE)
    except:
        print("timm ViT not available"); return None

resnet = load_resnet50()
vit    = load_vit_b16()
print("Loaded:", bool(resnet), bool(vit))


Loaded: True True


In [33]:
class GradCAM:
    def __init__(self, model, target_layer_name="layer4"):
        self.model = model; self.model.eval()
        self.target_layer = dict([*model.named_modules()])[target_layer_name]
        self.activations = None; self.gradients = None
        def fwd_hook(m, inp, out): self.activations = out.detach()
        def bwd_hook(m, gin, gout): self.gradients = gout[0].detach()
        self.target_layer.register_forward_hook(fwd_hook)
        self.target_layer.register_full_backward_hook(bwd_hook)

    def __call__(self, img_tensor, target_class=None):
        x = img_tensor.to(DEVICE).unsqueeze(0)
        logits = self.model(x)
        if target_class is None:
            target_class = logits.argmax(dim=1).item()
        self.model.zero_grad(set_to_none=True)
        logits[0, target_class].backward()
        w = self.gradients.mean(dim=(2,3), keepdim=True)
        cam = (w * self.activations).sum(dim=1)
        cam = torch.relu(cam); cam = cam - cam.min()
        if cam.max() > 0: cam = cam / cam.max()
        return cam.squeeze().cpu().numpy(), int(target_class)

gradcam = GradCAM(resnet, "layer4")
print("Grad-CAM ready.")


Grad-CAM ready.


In [34]:
import math
import torch

def vit_attention_rollout(model, x, head_fusion="mean"):
    """
    Robust attention rollout for timm ViTs.
    - Hooks blk.attn.qkv
    - Computes attention per layer (softmax(QK^T / sqrt(d)))
    - Adds residual + row-normalize each layer
    - Multiplies attentions across layers in order
    - Returns a HxW heatmap for the patch grid
    """
    qkv_per_layer, hooks = [], []

    def hook_qkv(m, inp, out):
        # out: [B, N, 3*dim]
        qkv_per_layer.append(out.detach().cpu())

    # Register on each transformer block
    for blk in model.blocks:
        h = blk.attn.qkv.register_forward_hook(hook_qkv)
        hooks.append(h)

    with torch.no_grad():
        _ = model(x)

    for h in hooks:
        h.remove()

    if len(qkv_per_layer) == 0:
        raise RuntimeError("No qkv captured; check model type or hook path.")

    # Build per-layer attentions [N, N]
    att_layers = []
    for qkv in qkv_per_layer:
        B, N, threeC = qkv.shape
        assert B == 1, "Rollout code assumes batch size 1"
        num_heads = model.blocks[0].attn.num_heads
        head_dim  = (threeC // 3) // num_heads

        # [B,N,3C] -> [3,B,H,N,D] -> q,k,v: [B,H,N,D]
        qkv = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k = qkv[0], qkv[1]

        # Attention per head: [B,H,N,N]
        att = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
        att = torch.softmax(att, dim=-1)  # softmax along last dim

        # Fuse heads -> [N,N]
        if head_fusion == "mean":
            att = att.mean(dim=1)[0]
        elif head_fusion == "max":
            att = att.max(dim=1)[0][0]
        elif head_fusion == "min":
            att = att.min(dim=1)[0][0]
        else:
            raise ValueError("head_fusion must be mean|max|min")

        eye = torch.eye(att.size(-1))
        att = att + eye
        att = att / att.sum(dim=-1, keepdim=True)
        att_layers.append(att)

    # Rollout across layers in order
    joint = att_layers[0]
    for l in range(1, len(att_layers)):
        joint = att_layers[l] @ joint

    # CLS-to-patch vector
    N = joint.size(0)               
    vec = joint[0]                  
    def vec_to_map(v):
        m = int(round(math.sqrt(len(v))))
        if m * m != len(v):
            return None
        return v.reshape(m, m)

    grid = vec[1:]                 
    hm = vec_to_map(grid.numpy())
    if hm is None:
        grid2 = vec[2:]
        hm = vec_to_map(grid2.numpy())
        if hm is None:
            raise ValueError(f"Cannot reshape attention vector of length {len(vec)} "
                             "to a square patch grid. Check tokens/patch size.")

    # Normalize to [0,1]
    hm = (hm - hm.min()) / (hm.max() - hm.min() + 1e-8)
    return hm

In [35]:
import cv2, pandas as pd

SEEDS_DIR = OUTPUT / "seeds"
(SEEDS_DIR / "resnet_gradcam").mkdir(parents=True, exist_ok=True)
(SEEDS_DIR / "vit_attnroll").mkdir(parents=True, exist_ok=True)

def save_heatmap(hm, out_path, size=(224,224)):
    arr = (hm*255).astype(np.uint8)
    if size is not None:
        arr = cv2.resize(arr, size, interpolation=cv2.INTER_LINEAR)
    Image.fromarray(arr).save(out_path)

def generate_seeds(img_id, do_resnet=True, do_vit=True):
    img = Image.open(voc_image_path(img_id)).convert("RGB")
    x = preprocess(img).unsqueeze(0).to(DEVICE)

    # clean
    if do_resnet:
        cam, _ = gradcam(x.squeeze(0))
        save_heatmap(cam, SEEDS_DIR/"resnet_gradcam"/f"{img_id}_clean.png")
    if do_vit and vit is not None:
        camv = vit_attention_rollout(vit, x)
        save_heatmap(camv, SEEDS_DIR/"vit_attnroll"/f"{img_id}_clean.png")

    # perturbed
    for name, pimg in perturbations(img).items():
        px = preprocess(pimg).unsqueeze(0).to(DEVICE)
        if do_resnet:
            cam_p, _ = gradcam(px.squeeze(0))
            save_heatmap(cam_p, SEEDS_DIR/"resnet_gradcam"/f"{img_id}_{name}.png")
        if do_vit and vit is not None:
            camv_p = vit_attention_rollout(vit, px)
            save_heatmap(camv_p, SEEDS_DIR/"vit_attnroll"/f"{img_id}_{name}.png")

for img_id in tqdm(VAL_IDS, desc="Seeds (subset)"):
    generate_seeds(img_id)
print("Done subset — expand VAL_IDS for full run.")


Seeds (subset): 100%|██████████| 1449/1449 [10:51<00:00,  2.23it/s]

Done subset — expand VAL_IDS for full run.





In [36]:
def load_gray(path): return np.array(Image.open(path).convert("L")).astype(np.float32)/255.0
def binarize(a, thr=0.3): return (a >= thr).astype(np.uint8)
def iou(a, b):
    inter = np.logical_and(a>0, b>0).sum()
    union = np.logical_or(a>0, b>0).sum()
    return inter/union if union>0 else 0.0

PERTS = ["hflip","rotation","blur","brightness","gauss"]

def seed_similarity(method_dir: Path, ids):
    rows = []
    for img_id in ids:
        c = load_gray(method_dir/f"{img_id}_clean.png")
        cb = binarize(c, 0.3)
        for p in PERTS:
            pp = load_gray(method_dir/f"{img_id}_{p}.png")
            pb = binarize(pp, 0.3)
            rows.append({
                "id": img_id, "method": method_dir.name, "perturb": p,
                "IoU": float(iou(cb, pb)),
                "SSIM": float(ssim(c, pp, data_range=1.0)) if ssim else None
            })
    return pd.DataFrame(rows)

df_res = seed_similarity(SEEDS_DIR/"resnet_gradcam", VAL_IDS)
df_vit = seed_similarity(SEEDS_DIR/"vit_attnroll", VAL_IDS) if (SEEDS_DIR/"vit_attnroll").exists() else pd.DataFrame()
df_all = pd.concat([df_res, df_vit], ignore_index=True)
display(df_all.groupby(["method","perturb"]).agg({"IoU":"mean","SSIM":"mean"}))
df_all.to_csv(OUTPUT/"seed_similarity_subset.csv", index=False)
print("Saved:", (OUTPUT/"seed_similarity_subset.csv").resolve())


Unnamed: 0_level_0,Unnamed: 1_level_0,IoU,SSIM
method,perturb,Unnamed: 2_level_1,Unnamed: 3_level_1
resnet_gradcam,blur,0.787784,0.860329
resnet_gradcam,brightness,0.792918,0.862774
resnet_gradcam,gauss,0.698969,0.772518
resnet_gradcam,hflip,0.45229,0.563718
resnet_gradcam,rotation,0.486979,0.554929
vit_attnroll,blur,0.823758,0.974954
vit_attnroll,brightness,0.504413,0.823859
vit_attnroll,gauss,0.640053,0.904561
vit_attnroll,hflip,0.237995,0.579485
vit_attnroll,rotation,0.222032,0.49734


Saved: /kaggle/working/outputs/seed_similarity_subset.csv


In [37]:
PSEUDO_DIR = OUTPUT / "pseudo_masks"
(PSEUDO_DIR / "resnet_gradcam").mkdir(parents=True, exist_ok=True)
(PSEUDO_DIR / "vit_attnroll").mkdir(parents=True, exist_ok=True)

def seed_to_mask(seed_path, thr=0.3):
    a = load_gray(seed_path)
    m = (a >= thr).astype(np.uint8) * 255
    return Image.fromarray(m)

# Example: clean seeds → pseudo masks (subset)
for img_id in VAL_IDS:
    sp = SEEDS_DIR/"resnet_gradcam"/f"{img_id}_clean.png"
    if sp.exists():
        seed_to_mask(sp, thr=0.3).save(PSEUDO_DIR/"resnet_gradcam"/f"{img_id}.png")


In [41]:
from torchvision.models.segmentation import deeplabv3_resnet50
from torch.utils.data import Dataset, DataLoader

class VOCPseudoDataset(Dataset):
    def __init__(self, ids, image_root, mask_root):
        self.ids = ids; self.image_root = image_root; self.mask_root = mask_root
        self.img_tf = T.Compose([T.Resize((256,256)), T.ToTensor(),
                                 T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
        self.mask_tf = T.Compose([T.Resize((256,256), interpolation=T.InterpolationMode.NEAREST)])

    def __len__(self): return len(self.ids)
    def __getitem__(self, i):
        img_id = self.ids[i]
        img = Image.open(self.image_root/f"{img_id}.jpg").convert("RGB")
        mask_p = self.mask_root/f"{img_id}.png"
        if not mask_p.exists():
            # empty mask if missing
            mask = Image.fromarray(np.zeros((256,256), dtype=np.uint8))
        else:
            mask = Image.open(mask_p).convert("L")
        return self.img_tf(img), torch.from_numpy(np.array(self.mask_tf(mask))).long(), img_id

train_ids = load_voc_split("trainval")[:200]
val_ids   = load_voc_split("val")[:200]

train_ds = VOCPseudoDataset(train_ids, VOC_ROOT/"JPEGImages", PSEUDO_DIR/"resnet_gradcam")
val_ds   = VOCPseudoDataset(val_ids,   VOC_ROOT/"JPEGImages", PSEUDO_DIR/"resnet_gradcam")
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2)
val_dl   = DataLoader(val_ds,   batch_size=4, shuffle=False, num_workers=2)

model = deeplabv3_resnet50(weights=None, num_classes=2).to(DEVICE)  
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
ce  = torch.nn.CrossEntropyLoss(ignore_index=255)

def train_one_epoch():
    model.train(); total=0
    for x, y, _ in train_dl:
        x = x.to(DEVICE); y = y.to(DEVICE)
        opt.zero_grad()
        out = model(x)["out"]
        loss = ce(out, (y>0).long())
        loss.backward(); opt.step()
        total += loss.item()*x.size(0)
    return total/len(train_dl.dataset)

@torch.no_grad()
def evaluate_miou():
    model.eval(); inter=0; union=0
    for x, y, _ in val_dl:
        x = x.to(DEVICE); y = y.to(DEVICE)
        pred = model(x)["out"].argmax(1)
        # IoU for foreground (class 1)
        p = (pred==1); g = (y>0)
        inter += (p & g).sum().item()
        union += (p | g).sum().item()
    miou = inter/union if union>0 else 0.0
    return miou

for ep in range(15):  
    tl = train_one_epoch()
    miou = evaluate_miou()
    print(f"Epoch {ep+1}: loss={tl:.4f}  mIoU_fg={miou:.3f}")


Epoch 1: loss=0.4739  mIoU_fg=0.435
Epoch 2: loss=0.3016  mIoU_fg=0.358
Epoch 3: loss=0.2182  mIoU_fg=0.478
Epoch 4: loss=0.1881  mIoU_fg=0.490
Epoch 5: loss=0.1386  mIoU_fg=0.513
Epoch 6: loss=0.1050  mIoU_fg=0.502
Epoch 7: loss=0.0911  mIoU_fg=0.501
Epoch 8: loss=0.0857  mIoU_fg=0.539
Epoch 9: loss=0.0751  mIoU_fg=0.519
Epoch 10: loss=0.0692  mIoU_fg=0.521
Epoch 11: loss=0.0597  mIoU_fg=0.532
Epoch 12: loss=0.0533  mIoU_fg=0.552
Epoch 13: loss=0.0499  mIoU_fg=0.514
Epoch 14: loss=0.0489  mIoU_fg=0.552
Epoch 15: loss=0.0434  mIoU_fg=0.551
