
# Ant **Face Component** Copy–Paste (Semantic-Constrained) — Preview

This notebook keeps the **largest-component constraint** (treating it as the ant body) and
implements a **heuristic face detector** on the template sketch. We then paste the detected face component
into target ant sketches that **lack a face**, and preview *before vs after* (RAW & DIST), including a difference heatmap.


In [1]:

import os, sys, random
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt

from loader import AeDataset  # uses your current loader (no dataset-level CP used here)

random.seed(1); np.random.seed(1); torch.manual_seed(1)


  check_for_updates()


<torch._C.Generator at 0x20dfe4a2c50>

In [2]:

# === Config (auto from train_Embed.py; editable) ===
DB_DIR_DEFAULT = "data_embed_pt"
IMG_SIZE_DEFAULT = 256
PREFIX_DEFAULT = "train"

try:
    import importlib.util, types
    spec = importlib.util.spec_from_file_location('train_Embed', 'train_Embed.py')
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    hp = getattr(mod, 'hyper_params', {})
    DB_DIR_DEFAULT = hp.get('dbDir', DB_DIR_DEFAULT)
    IMG_SIZE_DEFAULT = hp.get('imgSize', IMG_SIZE_DEFAULT)
except Exception as e:
    print("[Info] Could not import train_Embed.py defaults:", e)

DB_DIR = DB_DIR_DEFAULT
IMG_SIZE = IMG_SIZE_DEFAULT
PREFIX = PREFIX_DEFAULT

print("DB_DIR =", DB_DIR, "| IMG_SIZE =", IMG_SIZE, "| PREFIX =", PREFIX)


DB_DIR = data_embed_pt | IMG_SIZE = 256 | PREFIX = train


In [3]:

# === Build dataset (no dataset-level copy-paste; we'll do it here) ===
ds = AeDataset(DB_DIR, [IMG_SIZE, IMG_SIZE], PREFIX, augment=False, sa_copy_paste=False)
print("Num files:", len(ds))
assert len(ds) > 0, "No .pt files found. Please check DB_DIR/PREFIX."


Num files: 25084


In [4]:

# ---- Utilities: connected components, bbox, chamfer distance, largest component, masks ----

def t_to01(t):
    t = torch.as_tensor(t)
    if t.dim()==2: t = t.unsqueeze(0)
    t = t.to(torch.float32)
    if t.max() > 1: t = t / 255.0
    return t.clamp(0,1)

def mask_from_raw(raw01, thr=0.5):
    return (raw01.squeeze(0) > thr).to(torch.uint8).cpu().numpy()

def connected_components(mask_np):
    H, W = mask_np.shape
    labels = np.zeros((H, W), dtype=np.int32)
    visited = np.zeros_like(mask_np, dtype=bool)
    sizes = []
    label = 0
    for y in range(H):
        for x in range(W):
            if mask_np[y, x] and not visited[y, x]:
                label += 1
                stack = [(y, x)]
                visited[y, x] = True
                labels[y, x] = label
                sz = 1
                while stack:
                    cy, cx = stack.pop()
                    for ny, nx in ((cy-1,cx),(cy+1,cx),(cy,cx-1),(cy,cx+1)):
                        if 0 <= ny < H and 0 <= nx < W and mask_np[ny,nx] and not visited[ny,nx]:
                            visited[ny, nx] = True
                            labels[ny, nx] = label
                            stack.append((ny, nx))
                            sz += 1
                sizes.append(sz)
    return labels, sizes

def bbox_of_label(labels, L):
    ys, xs = np.where(labels==L)
    if xs.size==0:
        return (0,0,labels.shape[1]-1, labels.shape[0]-1)
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

def largest_component_mask(mask_np):
    lab, sizes = connected_components(mask_np)
    if not sizes: return np.zeros_like(mask_np, dtype=np.uint8)
    L = int(np.argmax(sizes)+1)
    return (lab==L).astype(np.uint8), lab, sizes, L

def chamfer_distance(mask_np):
    H, W = mask_np.shape
    INF = 10**7
    d = np.full((H, W), INF, np.int32)
    d[mask_np.astype(bool)] = 0
    for y in range(H):
        for x in range(W):
            v = d[y, x]
            if y>0:
                v = min(v, d[y-1,x]+3)
                if x>0: v=min(v, d[y-1,x-1]+4)
                if x<W-1: v=min(v, d[y-1,x+1]+4)
            if x>0: v=min(v, d[y,x-1]+3)
            d[y,x]=v
    for y in range(H-1,-1,-1):
        for x in range(W-1,-1,-1):
            v = d[y, x]
            if y<H-1:
                v = min(v, d[y+1,x]+3)
                if x>0: v=min(v, d[y+1,x-1]+4)
                if x<W-1: v=min(v, d[y+1,x+1]+4)
            if x<W-1: v=min(v, d[y,x+1]+3)
            d[y,x]=v
    d = d.astype(np.float32)
    if d.max()>0: d/=d.max()
    return d


In [5]:

# ---- Heuristic "face" detector for ants ----
AREA_MIN = 15
AREA_MAX = 800
HEAD_BAND_FRAC = 0.35  # fraction along the shorter axis as 'head band'

def detect_face_component_mask(raw01):
    m = mask_from_raw(raw01)
    body_mask, lab, sizes, Lmax = largest_component_mask(m)
    H, W = m.shape
    x0,y0,x1,y1 = bbox_of_label(lab, Lmax)
    bw, bh = x1-x0+1, y1-y0+1

    if bw >= bh:  # wide -> head band on left or right
        band_w = max(1, int(round(bw*HEAD_BAND_FRAC)))
        head_band = np.zeros_like(m, dtype=np.uint8); head_band[y0:y1+1, x0:x0+band_w] = 1
        alt_band  = np.zeros_like(m, dtype=np.uint8); alt_band[y0:y1+1, x1-band_w+1:x1+1] = 1
    else:         # tall -> top or bottom
        band_h = max(1, int(round(bh*HEAD_BAND_FRAC)))
        head_band = np.zeros_like(m, dtype=np.uint8); head_band[y0:y0+band_h, x0:x1+1] = 1
        alt_band  = np.zeros_like(m, dtype=np.uint8); alt_band[y1-band_h+1:y1+1, x0:x1+1] = 1

    cand = []
    for L in range(1, int(max(lab.max(), 0))+1):
        if L == Lmax: continue
        sz = sizes[L-1]
        if sz < AREA_MIN or sz > AREA_MAX: continue
        bx0,by0,bx1,by1 = bbox_of_label(lab, L)
        cx, cy = (bx0+bx1)/2, (by0+by1)/2
        if head_band[int(round(cy)), int(round(cx))] or alt_band[int(round(cy)), int(round(cx))]:
            cand.append((sz, L))

    if not cand:
        return None
    cand.sort(reverse=True)
    Lsel = cand[0][1]
    return (lab==Lsel).astype(np.uint8)


In [6]:

# ---- Paste face component into target (constrained to target's largest component) ----
def paste_component_into_target(target_raw01, comp_mask):
    H, W = target_raw01.shape[-2:]
    tgt_m = mask_from_raw(target_raw01)
    body_mask, lab, sizes, Lmax = largest_component_mask(tgt_m)
    if Lmax == 0:
        return target_raw01, None

    ys, xs = np.where(comp_mask>0)
    if xs.size==0: return target_raw01, None
    cx0, cy0, cx1, cy1 = int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
    bw, bh = cx1-cx0+1, cy1-cy0+1

    from numpy.lib.stride_tricks import sliding_window_view
    area_mask = body_mask.astype(np.uint8)
    valid = np.zeros_like(area_mask, dtype=bool)
    if H >= bh and W >= bw:
        win = sliding_window_view(area_mask, (bh, bw))
        full = (win.sum(axis=(-2,-1)) == (bh*bw))
        valid[:full.shape[0], :full.shape[1]] = full
    ys2, xs2 = np.where(valid)
    if ys2.size==0:
        return target_raw01, None

    k = np.random.randint(0, ys2.size)
    dy, dx = int(ys2[k]), int(xs2[k])

    raw_aug = target_raw01.clone()
    patch = torch.from_numpy(comp_mask[cy0:cy1+1, cx0:cx1+1].astype(np.float32)).to(raw_aug.device)
    raw_aug[..., dy:dy+bh, dx:dx+bw] = torch.maximum(raw_aug[..., dy:dy+bh, dx:dx+bw], patch.unsqueeze(0))

    new_mask = (raw_aug.squeeze(0).cpu().numpy() > 0.5).astype(np.uint8)
    dist = chamfer_distance(new_mask)
    dist_t = torch.from_numpy(dist).to(raw_aug.device).unsqueeze(0).to(torch.float32)
    return raw_aug, dist_t


In [8]:
# --- Compatibility shim for largest_component_mask ---
def largest_component_mask_compat(mask_np):
    """Always returns (body_mask, lab, sizes, Lmax)."""
    res = largest_component_mask(mask_np)
    # If it's already 4 items, just use first 4
    if isinstance(res, tuple):
        if len(res) >= 4:
            body_mask, lab, sizes, Lmax = res[:4]
            return body_mask.astype(np.uint8), lab, sizes, int(Lmax)
        elif len(res) == 2:
            # Some variants return (lab, sizes)
            lab, sizes = res
            Lmax = int(np.argmax(sizes) + 1) if len(sizes) else 0
            body_mask = (lab == Lmax).astype(np.uint8) if Lmax > 0 else np.zeros_like(lab, dtype=np.uint8)
            return body_mask, lab, sizes, Lmax
        elif len(res) == 1:
            # Just body_mask
            body_mask = res[0].astype(np.uint8)
            # Reconstruct labels/sizes crudely
            lab, sizes = connected_components(body_mask)
            Lmax = int(np.argmax(sizes) + 1) if len(sizes) else 0
            return body_mask, lab, sizes, Lmax
    # Fallback: treat as a binary mask
    body_mask = np.asarray(res).astype(np.uint8)
    lab, sizes = connected_components(body_mask)
    Lmax = int(np.argmax(sizes) + 1) if len(sizes) else 0
    return body_mask, lab, sizes, Lmax

# --- Rebind detect_* to use the compat shim ---
def detect_face_component_mask(raw01):
    m = mask_from_raw(raw01)
    body_mask, lab, sizes, Lmax = largest_component_mask_compat(m)
    H, W = m.shape
    x0,y0,x1,y1 = bbox_of_label(lab, Lmax)
    bw, bh = x1-x0+1, y1-y0+1

    # Build two candidate head-bands along the shorter axis
    HEAD_BAND_FRAC = 0.35
    if bw >= bh:
        band_w = max(1, int(round(bw*HEAD_BAND_FRAC)))
        head_band = np.zeros_like(m, dtype=np.uint8); head_band[y0:y1+1, x0:x0+band_w] = 1
        alt_band  = np.zeros_like(m, dtype=np.uint8); alt_band[y0:y1+1, x1-band_w+1:x1+1] = 1
    else:
        band_h = max(1, int(round(bh*HEAD_BAND_FRAC)))
        head_band = np.zeros_like(m, dtype=np.uint8); head_band[y0:y0+band_h, x0:x1+1] = 1
        alt_band  = np.zeros_like(m, dtype=np.uint8); alt_band[y1-band_h+1:y1+1, x0:x1+1] = 1

    AREA_MIN, AREA_MAX = 15, 800
    cand = []
    for L in range(1, int(max(lab.max(), 0)) + 1):
        if L == Lmax: continue
        sz = sizes[L-1] if L-1 < len(sizes) else 0
        if sz < AREA_MIN or sz > AREA_MAX: continue
        bx0,by0,bx1,by1 = bbox_of_label(lab, L)
        cx, cy = (bx0+bx1)/2, (by0+by1)/2
        if head_band[int(round(cy)), int(round(cx))] or alt_band[int(round(cy)), int(round(cx))]:
            cand.append((sz, L))

    if not cand:
        return None
    cand.sort(reverse=True)
    Lsel = cand[0][1]
    return (lab == Lsel).astype(np.uint8)

In [9]:

# ---- Find template with face, and targets without face ----
def has_face(raw01):
    return detect_face_component_mask(raw01) is not None

def find_indices(ds, want_face=True, max_try=5000):
    idxs = []
    total = len(ds)
    tried = 0
    while len(idxs) < 5 and tried < min(max_try, total):
        i = np.random.randint(0, total)
        raw, dis = ds[i]
        raw01 = t_to01(raw)
        if (detect_face_component_mask(raw01) is not None) == want_face:
            idxs.append(i)
        tried += 1
    return list(dict.fromkeys(idxs))

template_idxs = find_indices(ds, want_face=True)
target_idxs   = find_indices(ds, want_face=False)

print("Templates with face (candidates):", template_idxs[:5])
print("Targets without face (candidates):", target_idxs[:5])
assert template_idxs, "Could not find a template with a detectable face."
assert target_idxs, "Could not find targets without a detectable face."


Templates with face (candidates): [19431, 17032, 24345, 4039, 13252]
Targets without face (candidates): [6713, 12016, 7676, 14546, 16480]


In [10]:

# ---- Preview transfer ----
def preview_transfer(tidx, tgt_indices):
    t_raw, t_dis = ds[tidx]
    t_raw01 = t_to01(t_raw)
    comp = detect_face_component_mask(t_raw01)
    assert comp is not None, "Template has no detectable face by heuristic."
    t_raw_np = t_raw01.squeeze(0).cpu().numpy()

    for idx in tgt_indices:
        s_raw, s_dis = ds[idx]
        s_raw01 = t_to01(s_raw)
        s_dis01 = t_to01(s_dis)

        s_aug_raw, s_aug_dis = paste_component_into_target(s_raw01, comp)
        if s_aug_dis is None:
            print(f"[Skip idx={idx}] cannot place component inside largest target component.")
            continue

        def im(ax, img, title):
            ax.imshow(img, cmap='gray', vmin=0, vmax=1); ax.set_title(title); ax.axis('off')

        raw_diff = (s_aug_raw - s_raw01).abs().squeeze(0).cpu().numpy()
        dis_diff = (s_aug_dis - s_dis01).abs().squeeze(0).cpu().numpy()

        fig, axes = plt.subplots(1, 6, figsize=(6*3.1, 3.2))
        im(axes[0], s_raw01.squeeze(0).cpu().numpy(), f"Target RAW (idx={idx})")
        im(axes[1], s_dis01.squeeze(0).cpu().numpy(), "Target DIST")
        im(axes[2], t_raw_np, f"Template RAW (idx={tidx})")
        im(axes[3], s_aug_raw.squeeze(0).cpu().numpy(), "After paste — RAW")
        im(axes[4], s_aug_dis.squeeze(0).cpu().numpy(), "After paste — DIST")
        im(axes[5], raw_diff, "ΔRAW (abs)")
        plt.tight_layout(); plt.show()

# Run demo
preview_transfer(template_idxs[0], target_idxs[:3])


[Skip idx=6713] cannot place component inside largest target component.
[Skip idx=12016] cannot place component inside largest target component.
[Skip idx=7676] cannot place component inside largest target component.


In [11]:
import torch, os

# 換成你資料夾裡真實的 pt 檔路徑
pt_path = os.path.join("data_embed_pt", "train", "0.pt")
data = torch.load(pt_path, map_location="cpu")

print("Keys in this .pt file:", list(data.keys()))
for k, v in data.items():
    print(f"{k}: type={type(v)}, shape={getattr(v, 'shape', None)}")

Keys in this .pt file: ['img_raw', 'edis_raw']
img_raw: type=<class 'torch.Tensor'>, shape=torch.Size([256, 256])
edis_raw: type=<class 'torch.Tensor'>, shape=torch.Size([256, 256])
