In [None]:
!pip install 'git+https://github.com/facebookresearch/detectron2.git'


In [None]:
# Mount Google Drive at /content/drive; the later cells read and write files
# under /content/drive/MyDrive.
# NOTE(review): force_remount=True presumably re-mounts even when Drive is
# already mounted -- confirm against the Colab documentation.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)


In [None]:
#!/usr/bin/env python

import os
import json
import random
import cv2
import numpy as np
from PIL import Image, ImageOps
from tqdm.auto import tqdm
import torch
from transformers import pipeline
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo

# Root folder on the mounted Drive; every input and output path below is built from it.
BASE_DIR        = "/content/drive/MyDrive"
# JSON file loaded by main(); each entry is indexed with the key "img_paths".
SAMPLE_JSON     = os.path.join(BASE_DIR, "sample_list.json")
# Output .npz archives written by main(): human masks and reference-object masks.
HUMAN_NPZ_PATH  = os.path.join(BASE_DIR, "human_mask_cache_224.npz")
REF_NPZ_PATH    = os.path.join(BASE_DIR, "ref_mask_cache_224.npz")
# Text prompts handed to the zero-shot detector as candidate labels.
REF_LABELS      = ["a white ball.", "a dark paper."]
# Number of samples drawn at random from the sample list; None processes all of them.
SAMPLE_COUNT    = None
# Default (width, height) target size used by letterbox_mask().
OUT_W, OUT_H    = 224, 224

def letterbox_mask(mask: np.ndarray, new_size=(OUT_W, OUT_H)):
    """Resize a 2-D mask to fit inside ``new_size`` and zero-pad it to that size.

    The mask is scaled by a single factor (aspect ratio preserved) with
    nearest-neighbour interpolation and then centred on a zero-filled canvas.

    Args:
        mask: 2-D array of shape ``(height, width)``.
        new_size: Target size as ``(width, height)``.

    Returns:
        Array of shape ``(new_height, new_width)`` with the dtype of ``mask``.
    """
    h, w = mask.shape
    nw, nh = new_size
    out = np.zeros((nh, nw), dtype=mask.dtype)
    if h == 0 or w == 0:
        # Nothing to paste; also avoids dividing by zero when computing the scale.
        return out

    scale = min(nw / w, nh / h)
    # Truncation can give 0 for very elongated masks, which cv2.resize rejects,
    # so keep at least one pixel along each axis.
    rw = max(1, int(w * scale))
    rh = max(1, int(h * scale))

    # cv2.resize does not accept boolean arrays; resize them as uint8 instead.
    # Assigning into ``out`` below casts the values back to the input dtype.
    src = mask.astype(np.uint8) if mask.dtype == np.bool_ else mask
    mask_resized = cv2.resize(src, (rw, rh), interpolation=cv2.INTER_NEAREST)

    top   = (nh - rh) // 2
    left  = (nw - rw) // 2
    out[top:top+rh, left:left+rw] = mask_resized
    return out

def _label_key(label):
    """Normalise a prompt or detection label: strip trailing periods and lower-case it."""
    return label.rstrip('.').lower()


def _box_mask(x0, y0, x1, y1, height, width):
    """Return a (height, width) uint8 mask with the XYXY box set to 1, clipped to the image."""
    mask = np.zeros((height, width), dtype=np.uint8)
    # Clip first: a negative coordinate would otherwise be interpreted by
    # slicing as an offset from the far edge and mark the wrong region.
    left, right = (min(max(int(v), 0), width) for v in (x0, x1))
    top, bottom = (min(max(int(v), 0), height) for v in (y0, y1))
    mask[top:bottom, left:right] = 1
    return mask


def _person_mask(instances, height, width):
    """Return a uint8 mask of the highest-scoring class-0 instance, or all zeros if there is none."""
    if len(instances) == 0:
        return np.zeros((height, width), dtype=np.uint8)
    classes = instances.pred_classes.cpu().numpy()
    candidate_ids = np.where(classes == 0)[0]
    if len(candidate_ids) == 0:
        return np.zeros((height, width), dtype=np.uint8)
    scores = instances.scores.cpu().numpy()[candidate_ids]
    # Plain int so the index behaves the same for tensors and Detectron2 structures.
    best = int(candidate_ids[scores.argmax()])
    if instances.has("pred_masks"):
        return instances.pred_masks[best].cpu().numpy().astype(np.uint8)
    # Fallback when the model outputs no masks: Detectron2 boxes are stored
    # as (x0, y0, x1, y1), so unpack in that order.
    x0, y0, x1, y1 = instances.pred_boxes.tensor[best].cpu().numpy()
    return _box_mask(x0, y0, x1, y1, height, width)


def main():
    """Build and save the letterboxed human and reference-object mask caches.

    Loads the sample list from ``SAMPLE_JSON`` (optionally sub-sampled to
    ``SAMPLE_COUNT`` entries) and visits the first two ``img_paths`` of every
    sample. For each image not seen before it stores:

    * ``human_cache[path]``: mask of the highest-scoring class-0 instance
      found by the Mask R-CNN predictor (all zeros if none);
    * ``ref_cache["<path>|<label>"]``: box mask of the highest-scoring
      zero-shot detection for each entry of ``REF_LABELS`` (all zeros if none).

    Every mask is passed through ``letterbox_mask`` before being cached. The
    two caches are written to ``HUMAN_NPZ_PATH`` and ``REF_NPZ_PATH``.
    Missing or unreadable image files are reported and skipped.
    """
    with open(SAMPLE_JSON, "r") as f:
        all_samples = json.load(f)

    if SAMPLE_COUNT is not None:
        samples = random.sample(all_samples, min(SAMPLE_COUNT, len(all_samples)))
    else:
        samples = all_samples

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    if not torch.cuda.is_available():
        # Mirror the CPU fallback used for the zero-shot detector below.
        cfg.MODEL.DEVICE = "cpu"
    human_predictor = DefaultPredictor(cfg)

    ref_detector = pipeline(
        task="zero-shot-object-detection",
        model="IDEA-Research/grounding-dino-tiny",
        device=0 if torch.cuda.is_available() else -1
    )

    human_cache = {}
    ref_cache   = {}
    label_keys  = [_label_key(lbl) for lbl in REF_LABELS]

    for s in tqdm(samples, desc="Building 224×224 mask cache (letterbox)"):
        for p in s["img_paths"][:2]:
            if p in human_cache:
                continue
            if not os.path.exists(p):
                print("Missing file:", p)
                continue

            try:
                # Apply the EXIF orientation on the freshly opened image, then
                # convert; the context manager releases the file handle.
                with Image.open(p) as raw:
                    pil = ImageOps.exif_transpose(raw).convert("RGB")
            except OSError as err:
                # One corrupt image should not discard the whole cache build.
                print("Unreadable file:", p, err)
                continue
            arr = np.array(pil)
            H, W = arr.shape[:2]

            # The predictor is fed a BGR array, hence the colour conversion.
            out = human_predictor(cv2.cvtColor(arr, cv2.COLOR_RGB2BGR))
            human_cache[p] = letterbox_mask(_person_mask(out["instances"], H, W))

            dets = ref_detector(pil, candidate_labels=REF_LABELS, threshold=0.7)
            # Keep the highest-scoring detection per label, independent of the
            # order in which the pipeline returns its results.
            best_by_label = {}
            for det in dets:
                name = _label_key(det["label"])
                if name not in best_by_label or det["score"] > best_by_label[name]["score"]:
                    best_by_label[name] = det

            for name in label_keys:
                det = best_by_label.get(name)
                if det:
                    box = det["box"]
                    rm = _box_mask(box["xmin"], box["ymin"], box["xmax"], box["ymax"], H, W)
                else:
                    rm = np.zeros((H, W), dtype=np.uint8)
                ref_cache[f"{p}|{name}"] = letterbox_mask(rm)

    np.savez_compressed(HUMAN_NPZ_PATH, **human_cache)
    np.savez_compressed(REF_NPZ_PATH,   **ref_cache)

    print("Saved human masks to", HUMAN_NPZ_PATH, "entries:", len(human_cache))
    print("Saved ref masks to", REF_NPZ_PATH, "entries:", len(ref_cache))

# Build the caches only when this code is executed directly, not when it is imported.
if __name__ == "__main__":
    main()



In [None]:
# Visual spot-check: show a few random cached images next to their human mask
# and their reference-object masks.
import numpy as np
import matplotlib.pyplot as plt
import random
from PIL import Image, ImageOps

human_npz_path = '/content/drive/MyDrive/human_mask_cache_224.npz'
ref_npz_path   = '/content/drive/MyDrive/ref_mask_cache_224.npz'

# The caches hold plain numeric arrays, so pickle support is not needed;
# leaving it off avoids executing arbitrary code from a tampered file.
human_masks = np.load(human_npz_path, allow_pickle=False)
ref_masks   = np.load(ref_npz_path, allow_pickle=False)

keys = list(human_masks.files)
# Never request more samples than there are cached images.
selected = random.sample(keys, min(5, len(keys)))

for key in selected:
    # Human-mask keys are the image paths themselves. Apply the EXIF
    # orientation so the picture is shown the way the masks were computed.
    with Image.open(key) as raw:
        img = ImageOps.exif_transpose(raw).convert('RGB')
    mask_h = human_masks[key]
    # Reference masks are stored under "<image path>|<label>".
    ref_keys = [k for k in ref_masks.files if k.startswith(key + '|')]
    n_cols = 2 + len(ref_keys)
    fig, axes = plt.subplots(1, n_cols, figsize=(4*n_cols, 4))
    axes[0].imshow(img)
    axes[0].axis('off')
    axes[0].set_title('Original')
    axes[1].imshow(mask_h, cmap='gray')
    axes[1].axis('off')
    axes[1].set_title('Human Mask')
    for i, rk in enumerate(ref_keys, start=2):
        mask_r = ref_masks[rk]
        axes[i].imshow(mask_r, cmap='gray')
        axes[i].axis('off')
        axes[i].set_title(f'Ref Mask\n{rk.split("|")[-1]}')
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)
