
# Semantic-Aware Copy–Paste (Writer) — Ant *Face* Preview

This notebook compares the **original** dataset with the **semantic-aware copy–paste (SA-CP) mirror**
generated by `write_data_embed_sa_cp.py` for the class **ant**.
It displays up to five samples whose face-detector result changed from **no-face** (original) to **has-face** (mirror).


In [None]:

import os, torch, numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Dataset roots: the untouched samples and their semantic-aware copy-paste mirror.
ORIG, MIRR = Path("data_embed_pt"), Path("data_embed_pt_sa_cp")
# Which split sub-folder to inspect: 'train' / 'valid' / 'test'.
SPLIT = "train"

def load_pt(p):
    """Load one saved sample dict from ``p`` onto the CPU.

    Returns:
        ``(img_raw, edis_raw, category, d)`` where ``category`` falls back to
        ``"unknown"`` when the dict has no ``"category"`` key and ``d`` is the
        full loaded dict.
    """
    sample = torch.load(p, map_location="cpu")
    category = sample.get("category", "unknown")
    return sample["img_raw"], sample["edis_raw"], category, sample


In [None]:

# Thresholds for the preview-only face heuristic below; they must stay identical
# to the writer's fallback detector so the preview agrees with what was written.
AREA_MIN = 15          # inclusive lower bound on a candidate blob's pixel count
AREA_MAX = 1200        # inclusive upper bound on a candidate blob's pixel count
HEAD_BAND_FRAC = 0.45  # fraction of the body bbox's long side used as each end band

def connected_components(mask_np):
    """Label the 4-connected foreground components of a 2-D binary mask.

    Args:
        mask_np: 2-D array; every truthy pixel counts as foreground.

    Returns:
        ``(labels, sizes)``. ``labels`` is an int32 array shaped like
        ``mask_np`` holding 0 for background and 1..K for the K components,
        numbered in row-major order of each component's first pixel.
        ``sizes[k - 1]`` is the pixel count of component ``k``.
    """
    H, W = mask_np.shape
    labels = np.zeros((H, W), dtype=np.int32)
    visited = np.zeros((H, W), dtype=bool)
    sizes = []
    label = 0
    for y in range(H):
        for x in range(W):
            if not mask_np[y, x] or visited[y, x]:
                continue
            label += 1
            stack = [(y, x)]
            visited[y, x] = True
            labels[y, x] = label
            sz = 1
            while stack:
                cy, cx = stack.pop()
                for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                    if 0 <= ny < H and 0 <= nx < W and mask_np[ny, nx] and not visited[ny, nx]:
                        # Mark the neighbour itself (not the seed row) as soon as it
                        # is pushed, so each pixel is counted once and never re-pushed.
                        visited[ny, nx] = True
                        labels[ny, nx] = label
                        stack.append((ny, nx))
                        sz += 1
            sizes.append(sz)
    return labels, sizes

def bbox_of_label(labels, L):
    """Return the inclusive bounding box ``(x0, y0, x1, y1)`` of label ``L``.

    When ``L`` does not occur in ``labels``, the whole-image box is returned.
    """
    rows, cols = np.nonzero(labels == L)
    if cols.size == 0:
        height, width = labels.shape
        return (0, 0, width - 1, height - 1)
    return int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())

def largest_component_mask(mask_np):
    """Isolate the largest 4-connected component of a binary mask.

    Returns:
        ``(mask, labels, sizes, L)``. ``mask`` is uint8 and equals 1 exactly on
        component ``L``, the first component with the maximal pixel count.
        ``labels`` and ``sizes`` come straight from ``connected_components``.
        With no foreground at all, ``mask`` is all zeros and ``L`` is 0.
    """
    lab, sizes = connected_components(mask_np)
    if sizes:
        biggest = 1 + int(np.argmax(sizes))
        return (lab == biggest).astype(np.uint8), lab, sizes, biggest
    return np.zeros_like(mask_np, dtype=np.uint8), lab, sizes, 0

def has_face_tensor(raw01):
    """Heuristically decide whether a RAW mask shows a "face" blob near the body's ends.

    The tensor is squeezed to 2-D if needed and binarised at 0.5. The largest
    4-connected component is taken as the body. Its bounding box is cut along
    the longer side (width wins ties) into two end bands, each about
    ``HEAD_BAND_FRAC`` of that side.

    Returns True as soon as some other component meets both conditions:
    its pixel count lies in ``[AREA_MIN, AREA_MAX]``, and its bounding-box
    centre falls inside either band. Returns False otherwise, including when
    there is no foreground.
    """
    plane = raw01.squeeze(0) if raw01.dim() != 2 else raw01
    binary = (plane.to(torch.float32).numpy() > 0.5).astype(np.uint8)
    _, lab, sizes, body_id = largest_component_mask(binary)
    if body_id == 0:
        return False

    x0, y0, x1, y1 = bbox_of_label(lab, body_id)
    width, height = x1 - x0 + 1, y1 - y0 + 1
    band_a = np.zeros_like(binary)
    band_b = np.zeros_like(binary)
    if width < height:
        # Tall body: bands are the top and bottom slices of the bbox.
        span = max(1, int(round(height * HEAD_BAND_FRAC)))
        band_a[y0:y0 + span, x0:x1 + 1] = 1
        band_b[y1 - span + 1:y1 + 1, x0:x1 + 1] = 1
    else:
        # Wide (or square) body: bands are the left and right slices of the bbox.
        span = max(1, int(round(width * HEAD_BAND_FRAC)))
        band_a[y0:y1 + 1, x0:x0 + span] = 1
        band_b[y0:y1 + 1, x1 - span + 1:x1 + 1] = 1

    n_labels = int(max(lab.max(), 0))
    for comp in range(1, n_labels + 1):
        if comp == body_id:
            continue
        area = sizes[comp - 1]
        if not (AREA_MIN <= area <= AREA_MAX):
            continue
        ys, xs = np.where(lab == comp)
        if xs.size == 0:
            continue
        # Centre of the candidate's bounding box (not its pixel centroid).
        row = int(round((ys.min() + ys.max()) / 2))
        col = int(round((xs.min() + xs.max()) / 2))
        if band_a[row, col] or band_b[row, col]:
            return True
    return False


In [None]:

orig_dir = ORIG / SPLIT
mirr_dir = MIRR / SPLIT
MAX_SHOW = 5  # maximum number of comparison figures rendered below

if not orig_dir.is_dir():
    print(f"Original split directory not found: {orig_dir}")
if not mirr_dir.is_dir():
    print(f"Mirror split directory not found: {mirr_dir}")

# Collect (original, mirror) pairs of ant samples that exist in both datasets.
# Tuple layout: (f, raw_o, dis_o, d_o, f_m, raw_m, dis_m, d_m).
pairs = []
for f in sorted(orig_dir.glob("*.pt")):
    f_m = mirr_dir / f.name
    # Check the mirror first so unmatched files are never torch.load-ed.
    if not f_m.exists():
        continue
    raw_o, dis_o, cat_o, d_o = load_pt(f)
    if cat_o != "ant":
        continue
    raw_m, dis_m, cat_m, d_m = load_pt(f_m)
    pairs.append((f, raw_o, dis_o, d_o, f_m, raw_m, dis_m, d_m))

print("ant pairs:", len(pairs))

# Stats: run the preview face heuristic on both sides of every pair.
n_face_o = 0
n_face_m = 0
changed = []
for pair in pairs:
    raw_o, raw_m = pair[1], pair[5]
    face_o = has_face_tensor(raw_o)
    face_m = has_face_tensor(raw_m)
    n_face_o += int(face_o)
    n_face_m += int(face_m)
    if not face_o and face_m:
        changed.append(pair)

print(f"has-face: original {n_face_o}/{len(pairs)}, SA-CP {n_face_m}/{len(pairs)}")
print(f"no-face → has-face: {len(changed)}")

# Show up to MAX_SHOW of the changed examples.
shown = 0
for (f, raw_o, dis_o, d_o, f_m, raw_m, dis_m, d_m) in changed[:MAX_SHOW]:
    r0 = raw_o.squeeze(0).numpy().astype(float)
    r1 = raw_m.squeeze(0).numpy().astype(float)
    d0 = dis_o.squeeze(0).numpy().astype(float)
    d1 = dis_m.squeeze(0).numpy().astype(float)
    diff = np.abs(r1 - r0)

    panels = [
        (r0, f"Original RAW\n{f.name}"),
        (d0, "Original DIST"),
        (r1, "SA-CP RAW"),
        (d1, "SA-CP DIST"),
        (diff, "|ΔRAW|"),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=(len(panels) * 3.2, 3.0))
    for ax, (img, title) in zip(axes, panels):
        ax.axis('off')
        ax.imshow(img, cmap='gray', vmin=0, vmax=1)
        ax.set_title(title)
    plt.tight_layout()
    plt.show()
    shown += 1

if shown == 0:
    print("No qualifying 'no-face → has-face' ant pairs in this split. Try another split or relax thresholds in writer.")
