
# Ant Face Copy–Paste — Writer Mirror Preview

Run `write_data_embed_cp.py` first to generate:
- `data_embed_pt/` (original)
- `data_embed_pt_cp/` (semantic CP mirror)

This notebook filters to `category == "ant"`, finds targets **without a face** in the original
whose mirror copy **has** one, and shows them **before/after** using the mirror.


In [1]:

import os, torch, numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Root directory of the original dataset (expected to be produced by write_data_embed_cp.py).
ORIG = Path("data_embed_pt")
# Root directory of the copy-paste ("CP") mirror of that dataset; files are matched by name.
MIRR = Path("data_embed_pt_cp")

def load_pt(p):
    """Load one saved sample and unpack the fields this notebook uses.

    Args:
        p: path of a file readable by ``torch.load``; its payload must be a
           dict containing at least ``"img_raw"`` and ``"edis_raw"``.

    Returns:
        ``(img_raw, edis_raw, category, payload)`` where ``category`` falls
        back to ``"unknown"`` when the payload has no ``"category"`` key and
        ``payload`` is the full loaded dict.

    Raises:
        KeyError: if ``"img_raw"`` or ``"edis_raw"`` is missing.
    """
    payload = torch.load(p, map_location="cpu")  # always deserialize onto the CPU
    return (
        payload["img_raw"],
        payload["edis_raw"],
        payload.get("category", "unknown"),
        payload,
    )


In [2]:

SPLIT = "train"  # 'train' / 'valid' / 'test'
orig_dir = ORIG / SPLIT
mirr_dir = MIRR / SPLIT

# glob() over a missing directory silently yields nothing, which would look
# exactly like "this split contains no ants". Report it instead of hiding it.
for split_dir in (orig_dir, mirr_dir):
    if not split_dir.is_dir():
        print(f"WARNING: {split_dir} does not exist - run write_data_embed_cp.py first.")

files = sorted(orig_dir.glob("*.pt"))

# Each entry: (orig_path, raw, dis, orig_payload, mirror_path, raw_m, dis_m, mirror_payload)
pairs = []
n_ant = 0             # originals whose category is "ant"
n_missing_mirror = 0  # ...of which no same-named file exists in the mirror
for f in files:
    raw, dis, cat, d = load_pt(f)
    if cat != "ant":
        continue
    n_ant += 1
    mirr = mirr_dir / f.name
    if not mirr.exists():
        n_missing_mirror += 1
        continue
    raw_m, dis_m, cat_m, dm = load_pt(mirr)
    pairs.append((f, raw, dis, d, mirr, raw_m, dis_m, dm))

print("ant pairs:", len(pairs))
if not pairs:
    # Explain an empty result so the reader can tell "no data" from "no ants".
    print(f"  scanned {len(files)} file(s): {n_ant} ant, {n_missing_mirror} without a mirror file")


ant pairs: 0


In [3]:

# Inclusive pixel-count window a non-body component must fall in to be
# considered by has_face().
AREA_MIN = 15
AREA_MAX = 800
# Fraction of the body bounding box's longer side that has_face() uses as the
# depth of the band at each end of the box.
HEAD_BAND_FRAC = 0.35

def connected_components(mask_np):
    """Label the 4-connected foreground regions of a 2-D mask.

    Pixels are visited in row-major order; every truthy pixel that has not
    been assigned yet seeds a new region, which is grown with an explicit
    stack (iterative flood fill). Regions are therefore numbered 1, 2, ... in
    the order their first pixel is encountered.

    Args:
        mask_np: 2-D array; truthy entries are foreground.

    Returns:
        ``(labels, sizes)``: ``labels`` is an int32 array with the mask's
        shape holding 0 for background and the region id elsewhere;
        ``sizes[k - 1]`` is the pixel count of region ``k``.
    """
    n_rows, n_cols = mask_np.shape
    labels = np.zeros((n_rows, n_cols), dtype=np.int32)
    seen = np.zeros_like(mask_np, dtype=bool)
    sizes = []
    steps = ((-1, 0), (1, 0), (0, -1), (0, 1))  # up, down, left, right

    for row in range(n_rows):
        for col in range(n_cols):
            if seen[row, col] or not mask_np[row, col]:
                continue

            # New region: its id is one past the number of regions finished so far.
            region_id = len(sizes) + 1
            seen[row, col] = True
            labels[row, col] = region_id
            pending = [(row, col)]
            area = 0

            while pending:
                r, c = pending.pop()
                area += 1  # every pixel is pushed exactly once, so count on pop
                for dr, dc in steps:
                    rr, cc = r + dr, c + dc
                    if not (0 <= rr < n_rows and 0 <= cc < n_cols):
                        continue
                    if seen[rr, cc] or not mask_np[rr, cc]:
                        continue
                    seen[rr, cc] = True
                    labels[rr, cc] = region_id
                    pending.append((rr, cc))

            sizes.append(area)

    return labels, sizes

def bbox_of_label(labels, L):
    """Return the inclusive bounding box ``(x0, y0, x1, y1)`` of label ``L``.

    Args:
        labels: 2-D label image.
        L: label value to look for.

    Returns:
        Tuple of Python ints ``(x_min, y_min, x_max, y_max)``. When no pixel
        carries ``L`` the box covering the whole image is returned instead.
    """
    rows, cols = np.where(labels == L)
    if cols.size == 0:
        # Label absent: fall back to the full image extent.
        height, width = labels.shape[0], labels.shape[1]
        return (0, 0, width - 1, height - 1)
    left, right = int(cols.min()), int(cols.max())
    top, bottom = int(rows.min()), int(rows.max())
    return left, top, right, bottom

def largest_component_mask(mask_np):
    """Isolate the largest 4-connected region of ``mask_np``.

    Args:
        mask_np: 2-D array; truthy entries are foreground.

    Returns:
        ``(mask, labels, sizes, label_id)`` where ``labels`` and ``sizes`` come
        from ``connected_components``, ``label_id`` is the id of the region
        with the most pixels (the first one on ties) and ``mask`` is a uint8
        array that is 1 on that region. If the input has no foreground,
        ``mask`` is all zeros and ``label_id`` is 0.
    """
    labels, sizes = connected_components(mask_np)
    if len(sizes) == 0:
        empty = np.zeros_like(mask_np, dtype=np.uint8)
        return empty, labels, sizes, 0
    # Region ids are 1-based while ``sizes`` is 0-based.
    winner = int(np.argmax(sizes) + 1)
    winner_mask = (labels == winner).astype(np.uint8)
    return winner_mask, labels, sizes, winner

def has_face(raw):
    """Heuristically decide whether a binarized image shows a "face" blob.

    The image is thresholded at 0.5 and its largest connected component is
    taken as the body. Two bands are laid over the body's bounding box, one at
    each end of its longer side, each ``HEAD_BAND_FRAC`` of that side deep
    (at least one pixel). The result is True as soon as some other component
    whose pixel count lies in ``[AREA_MIN, AREA_MAX]`` has its bounding-box
    centre inside either band.

    Args:
        raw: torch tensor, either 2-D or with a leading dimension that
             ``squeeze(0)`` removes.

    Returns:
        bool: True if such a component exists, False otherwise (including
        when the thresholded image is empty).
    """
    plane = raw if raw.dim() == 2 else raw.squeeze(0)
    binary = (plane.to(torch.float32).numpy() > 0.5).astype(np.uint8)

    _, lab, sizes, body_id = largest_component_mask(binary)
    if body_id == 0:
        return False

    x0, y0, x1, y1 = bbox_of_label(lab, body_id)
    box_w = x1 - x0 + 1
    box_h = y1 - y0 + 1

    # Union of the two end bands; membership in either one is all that matters.
    end_bands = np.zeros(binary.shape, dtype=bool)
    if box_w >= box_h:
        # Wider than tall: bands sit at the left and right ends.
        depth = max(1, int(round(box_w * HEAD_BAND_FRAC)))
        end_bands[y0:y1 + 1, x0:x0 + depth] = True
        end_bands[y0:y1 + 1, x1 - depth + 1:x1 + 1] = True
    else:
        # Taller than wide: bands sit at the top and bottom ends.
        depth = max(1, int(round(box_h * HEAD_BAND_FRAC)))
        end_bands[y0:y0 + depth, x0:x1 + 1] = True
        end_bands[y1 - depth + 1:y1 + 1, x0:x1 + 1] = True

    # sizes[k - 1] belongs to label k, so enumerate from 1.
    for comp_id, area in enumerate(sizes, start=1):
        if comp_id == body_id:
            continue
        if not (AREA_MIN <= area <= AREA_MAX):
            continue
        cx0, cy0, cx1, cy1 = bbox_of_label(lab, comp_id)
        mid_row = int(round((cy0 + cy1) / 2))
        mid_col = int(round((cx0 + cx1) / 2))
        if end_bands[mid_row, mid_col]:
            return True
    return False


In [4]:

# Show up to MAX_SHOWN pairs whose original has no face but whose mirror does.
MAX_SHOWN = 5

shown = 0
for (f, raw, dis, d, g, raw_m, dis_m, dm) in pairs:
    if has_face(raw) or not has_face(raw_m):
        continue

    r0 = raw.squeeze(0).numpy().astype(float)
    r1 = raw_m.squeeze(0).numpy().astype(float)
    d0 = dis.squeeze(0).numpy().astype(float)
    d1 = dis_m.squeeze(0).numpy().astype(float)
    diff = np.abs(r1 - r0)  # pixels the copy-paste step changed

    panels = [
        (r0, f"Original RAW\n{f.name}"),
        (d0, "Original DIST"),
        (r1, "CP RAW"),
        (d1, "CP DIST"),
        (diff, "|ΔRAW|"),
    ]
    # plt comes from the notebook's import cell; no per-iteration re-import.
    fig, axes = plt.subplots(1, len(panels), figsize=(5*3.2, 3.0))
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img, cmap='gray', vmin=0, vmax=1)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure so repeated runs do not accumulate them

    shown += 1
    if shown >= MAX_SHOWN:
        break

if shown == 0:
    print("No qualifying 'no-face → has-face' ant pairs found in the selected split. Try another split.")


No qualifying 'no-face → has-face' ant pairs found in the selected split. Try another split.
