In [3]:
from pathlib import Path
import sys, json
from typing import Dict, List, Tuple, Any

import numpy as np
import cv2
import torch
from PIL import Image
from torchvision.transforms.functional import resize
from accelerate import Accelerator

# Label Anything
sys.path.append(str(Path.cwd().parent))
sys.path.append(str(Path.cwd().parent / 'label_anything'))
from label_anything import LabelAnything
from label_anything.data import get_preprocessing, utils
from label_anything.data.transforms import PromptsProcessor

# ==== Constants ====
# Input locations, both relative to the current working directory.
ANNOT_DIR, IMAGE_DIR = (Path.cwd() / sub for sub in ("annotations", "images"))
# Long side used for image preprocessing and prompt conversion.
IMAGE_SIZE = 1024
# Side length used for prompt masks and for mask visualisation.
MASK_SIDE = 256
# Class-name -> class-id lookup, covering spelling variants of the COCO category
# names and the Japanese names (手すり/handrail, 中桟/midrail, 巾木/toeboard).
# NOTE(review): not referenced by the visible code; category ids from the
# annotation files are used as-is.
CLASS_NAME_TO_ID = {
    "handrail": 1,
    "midrail": 2,
    "toeboard": 3,
    "base_board": 3,
    "baseboard": 3,
    "手すり": 1,
    "中桟": 2,
    "巾木": 3,
}

# ==== 可視化ユーティリティ（必要なら使用） ====
def draw_masks(img: Image.Image, masks: torch.Tensor, colors):
    """Paint every mask onto a MASK_SIDE-resized copy of *img* and blend the result.

    Args:
        img: PIL image to draw on; resized with torchvision ``resize(img, MASK_SIDE)``.
        masks: iterable of 2-D tensors, one per category; non-zero pixels are painted.
            Each mask must match the resized image's height and width.
        colors: colour tuples, cycled by mask index.

    Returns:
        uint8 ``(H, W, 3)`` array: 30 % resized original + 70 % painted image.
    """
    # Resize once and work on a real ndarray from the start: with an empty *masks*
    # the old code reached ``.astype`` while still holding a PIL image.
    base = np.array(resize(img, MASK_SIDE))
    painted = base.copy()
    for i, mask in enumerate(masks):
        # detach/cpu so tensors that require grad or live on an accelerator also work.
        mask_np = mask.detach().cpu().numpy().astype(bool)
        color = np.asarray(colors[i % len(colors)], dtype=np.uint8)
        painted = np.where(mask_np[:, :, np.newaxis], color, painted)
    painted = painted.astype(np.uint8)
    return cv2.addWeighted(base, 0.3, painted, 0.7, 0)

def draw_boxes(img: Image.Image, boxes: torch.Tensor, colors):
    img = np.array(img)
    for i, cat in enumerate(boxes):
        for x1,y1,x2,y2 in cat:
            cv2.rectangle(img, (int(x1),int(y1)), (int(x2),int(y2)), colors[i % len(colors)], 2)
    return img

def draw_points(img: Image.Image, points: torch.Tensor, colors):
    img = np.array(img)
    for i, cat in enumerate(points):
        for x,y in cat:
            cv2.circle(img, (int(x),int(y)), 5, colors[i % len(colors)], -1)
    return img

def draw_all(img: Image.Image, masks, boxes, points, colors):
    """Overlay masks, then boxes and points, on *img* and return a PIL image.

    Masks are blended at MASK_SIDE resolution (via ``draw_masks``). The result is then
    resized to IMAGE_SIZE before boxes and points are drawn. That presumably matches
    the coordinate space of the prompt tensors — confirm.
    """
    overlay = Image.fromarray(draw_masks(img, masks, colors))
    # Shared constant instead of a hard-coded 1024, so the canvas follows IMAGE_SIZE.
    overlay = resize(overlay, IMAGE_SIZE)
    overlay = Image.fromarray(draw_boxes(overlay, boxes, colors))
    return Image.fromarray(draw_points(overlay, points, colors))

def get_image(image_tensor: torch.Tensor) -> Image.Image:
    """Undo per-channel mean/std normalisation of a ``(3, H, W)`` tensor; return a PIL image.

    The tensor is detached and moved to CPU first. Values are rounded (not truncated)
    and clipped to ``[0, 255]`` before the uint8 cast. Without the clip, pixels outside
    the range would wrap around (e.g. 256 -> 0). Without rounding, float round-trip
    error would turn 254.9999 into 254.
    """
    mean = np.array([123.675, 116.280, 103.530]) / 255
    std = np.array([58.395, 57.120, 57.375]) / 255
    x = image_tensor.detach().cpu().numpy()
    x = x * std[:, None, None] + mean[:, None, None]
    x = np.clip(np.rint(x * 255), 0, 255).astype(np.uint8)
    return Image.fromarray(np.moveaxis(x, 0, -1))

# ==== JSON ローダ（COCO / 簡易） ====
def _mask_from_polygons(size_hw, polygons):
    H, W = int(size_hw[0]), int(size_hw[1])
    m = np.zeros((H, W), dtype=np.uint8)
    for poly in polygons or []:
        if not poly: continue
        pts = np.array(poly, dtype=np.float32).reshape(-1, 2)
        pts = np.round(pts).astype(np.int32)
        cv2.fillPoly(m, [pts], 255)
    return m

def _mask_from_rle(size_hw, counts):
    """Decode a COCO RLE into a uint8 ``(H, W)`` mask with 255 on foreground.

    Args:
        size_hw: ``(H, W)`` of the mask.
        counts: compressed RLE (``str``/``bytes``) or uncompressed RLE (``list`` of
            run lengths, as used by COCO ``iscrowd=1`` annotations).

    Requires pycocotools; it is imported lazily so files without RLE work without it.
    """
    from pycocotools import mask as mask_utils

    h, w = int(size_hw[0]), int(size_hw[1])
    if isinstance(counts, list):
        # decode() only accepts compressed RLE; uncompressed run lengths must be
        # converted with frPyObjects first.
        rle = mask_utils.frPyObjects({"size": [h, w], "counts": counts}, h, w)
    else:
        rle = {"size": [h, w], "counts": counts}
    m = mask_utils.decode(rle)
    if m.ndim == 3:
        m = m[..., 0]
    return m.astype(np.uint8) * 255

def load_ann_any_json(json_path: Path, expected_stem: str | None = None) -> Dict[str, Dict[int, list]]:
    out = {"bboxes": {}, "points": {}, "masks": {}}
    if not json_path.exists():
        return out

    data = json.load(open(json_path, "r", encoding="utf-8"))
    is_coco = all(k in data for k in ("images","categories","annotations"))

    if not is_coco:
        # 簡易形式: {'bboxes':{'1':[[x1,y1,x2,y2],...]}, 'points':{'2':[[x,y],...]}, 'masks':{'3':[entry,...]}}
        for k in ("bboxes","points"):
            raw = data.get(k, {})
            if isinstance(raw, dict):
                for cid, arr in raw.items():
                    try: cid = int(cid)
                    except: continue
                    if isinstance(arr, list):
                        out[k].setdefault(cid, []).extend(arr)
        raw_masks = data.get("masks", {})
        if isinstance(raw_masks, dict):
            for cid, entries in raw_masks.items():
                try: cid = int(cid)
                except: continue
                for entry in (entries or []):
                    fmt = str(entry.get("format","")).lower()
                    size = entry.get("size", None)
                    if not size: continue
                    if fmt == "polygons":
                        m = _mask_from_polygons(size, entry.get("polygons", []))
                        if (m>0).any(): out["masks"].setdefault(cid, []).append(m)
                    elif fmt == "rle":
                        m = _mask_from_rle(size, entry.get("counts"))
                        if (m>0).any(): out["masks"].setdefault(cid, []).append(m)
        return out

    # COCO形式
    images = data.get("images", [])
    anns   = data.get("annotations", [])

    # stem一致 → 1枚のみならフォールバック → 部分一致
    img_entry = None
    if expected_stem is not None:
        for im in images:
            if Path(str(im.get("file_name",""))).stem == expected_stem:
                img_entry = im; break
    if img_entry is None and len(images)==1:
        img_entry = images[0]
    if img_entry is None and expected_stem is not None:
        for im in images:
            if expected_stem in Path(str(im.get("file_name",""))).stem:
                img_entry = im; break
    if img_entry is None:
        return out

    image_id = img_entry["id"]
    H, W = int(img_entry.get("height", 0)), int(img_entry.get("width", 0))
    tanns = [a for a in anns if a.get("image_id")==image_id]

    def add(dic, cid, v): dic.setdefault(int(cid), []).append(v)

    for a in tanns:
        cid = int(a.get("category_id"))

        # bbox [x,y,w,h] → [x1,y1,x2,y2]
        bb = a.get("bbox")
        if isinstance(bb, (list,tuple)) and len(bb)==4:
            x,y,w,h = bb
            x1,y1,x2,y2 = float(x), float(y), float(x)+float(w), float(y)+float(h)
            if x2>x1 and y2>y1:
                add(out["bboxes"], cid, [x1,y1,x2,y2])

        # keypoints → points（v>0のみ）
        kps = a.get("keypoints")
        if isinstance(kps, list) and len(kps)>=3:
            for i in range(0,len(kps),3):
                xk, yk, v = kps[i], kps[i+1], kps[i+2]
                if v and xk is not None and yk is not None:
                    add(out["points"], cid, [float(xk), float(yk)])

        # segmentation → masks（RLE / polygons, 空配列は無視）
        seg = a.get("segmentation")
        if seg is not None:
            try:
                if isinstance(seg, dict) and "counts" in seg and "size" in seg:
                    m = _mask_from_rle(seg["size"], seg["counts"])
                    if (m>0).any(): add(out["masks"], cid, m)
                elif isinstance(seg, list) and seg and isinstance(seg[0], dict) and "counts" in seg[0]:
                    for r in seg:
                        m = _mask_from_rle(r["size"], r["counts"])
                        if (m>0).any(): add(out["masks"], cid, m)
                elif isinstance(seg, list) and len(seg)>0:
                    m = _mask_from_polygons([H,W], seg)
                    if (m>0).any(): add(out["masks"], cid, m)
            except Exception:
                # pycocotools未インストールでRLEが来た場合などは無視（polygonsを推奨）
                pass

    # 全ゼロマスクの最終除外（保険）
    for cid, arrs in list(out["masks"].items()):
        out["masks"][cid] = [m for m in arrs if isinstance(m, np.ndarray) and (m>0).any()]

    return out

def union_class_ids(dicts_per_support: List[Dict[str, Dict[int, list]]]) -> List[int]:
    """Return the sorted class ids found in any support annotation.

    Only the keys of the ``bboxes``, ``points`` and ``masks`` sub-dicts are inspected.
    A class counts even when its annotation list is empty.
    """
    return sorted({
        cid
        for ann in dicts_per_support
        for kind in ("bboxes", "points", "masks")
        for cid in ann.get(kind, {}).keys()
    })

# ==== 小技：空マスク除外 & 重心点補完 ====
def filter_empty_masks(masks_per_img: Dict[int, List[np.ndarray]]):
    """Remove, in place, every entry that is not a 2-D ndarray with a non-zero pixel.

    Keys are kept even when all of their masks are removed. Returns None.
    """
    def _is_usable(m) -> bool:
        return isinstance(m, np.ndarray) and m.ndim == 2 and bool((m > 0).any())

    for cid in list(masks_per_img):
        masks_per_img[cid] = list(filter(_is_usable, masks_per_img[cid]))

def mask_centroid(mask: np.ndarray):
    """Return an ``(x, y)`` foreground pixel at or near the centroid of *mask*.

    The centroid is the mean of the foreground coordinates, truncated to int. It is
    returned when it lies on the foreground. For non-convex shapes (rings, L/U shapes,
    curved rails) the centroid can fall on background; because the result is used as
    a point prompt for the mask's class, the nearest foreground pixel is returned
    instead (ties: first in row-major order).

    Returns:
        ``(x, y)`` ints, or ``None`` if *mask* has no foreground pixel.
    """
    ys, xs = np.nonzero(mask > 0)
    if xs.size == 0:
        return None
    cx, cy = int(xs.mean()), int(ys.mean())
    if mask[cy, cx] > 0:
        return cx, cy
    nearest = int(np.argmin((xs - cx) ** 2 + (ys - cy) ** 2))
    return int(xs[nearest]), int(ys[nearest])

# ==== Model / preprocessing ====
accelerator = Accelerator(cpu=True)
device = accelerator.device

# Inference only: put the model on the same device as the inputs and switch it to
# eval mode, so train-time behaviour (e.g. dropout, if present) is disabled.
# NOTE(review): assumes from_pretrained returns a torch.nn.Module (it is called
# like one below) — confirm.
la = LabelAnything.from_pretrained("pasqualedem/label_anything_sam_1024_coco").to(device).eval()

# Collect images by case-insensitive suffix in a single pass. Separate "*.jpg" and
# "*.JPG" globs list every file twice on Windows (pathlib globbing is
# case-insensitive there), which made the query image reappear as a support image.
img_paths = sorted(
    p
    for p in (IMAGE_DIR.iterdir() if IMAGE_DIR.is_dir() else ())
    if p.is_file() and p.suffix.lower() in {".jpg", ".jpeg", ".png"}
)
# Explicit check instead of `assert`, which is stripped under `python -O`.
if len(img_paths) < 2:
    raise RuntimeError("画像はクエリ1 + サポート1以上が必要です（./images に配置）。")

def open_rgb(p: Path) -> Image.Image:
    """Load the image at *p* and return it as an RGB PIL image.

    The file is opened in a ``with`` block so its handle is released before returning.
    This matters for multi-frame formats (e.g. MPO JPEGs from cameras), where Pillow
    keeps the file open after loading. ``convert`` returns a new, fully loaded image,
    so the result stays valid after the file is closed.

    NOTE(review): EXIF orientation is not applied. If annotations were made on
    auto-rotated photos, consider ``ImageOps.exif_transpose``.
    """
    with Image.open(p) as im:
        return im.convert("RGB")

# The first image in sorted order is the query; every remaining image is a support.
query_orig, *support_orig_images = (open_rgb(p) for p in img_paths)

preprocess = get_preprocessing({"common": {"custom_preprocess": True, "image_size": IMAGE_SIZE}})
query_image, *support_images = (preprocess(im) for im in (query_orig, *support_orig_images))

# PIL reports Image.size as (W, H); the query's size comes first in all_sizes.
all_sizes: List[Tuple[int, int]] = [im.size for im in (query_orig, *support_orig_images)]
support_sizes: List[Tuple[int, int]] = all_sizes[1:]

prompts_processor = PromptsProcessor(
    long_side_length=IMAGE_SIZE,
    masks_side_length=MASK_SIDE,
    custom_preprocess=True,
)

# ==== JSON → bboxes/points/masks ====
# Load one annotation file per support image: ANNOT_DIR/<image stem>.json.
support_annots: List[Dict[str, Dict[int, list]]] = []
for p in img_paths[1:]:
    ann_path = (ANNOT_DIR / f"{p.stem}.json").resolve()
    support_annots.append(load_ann_any_json(ann_path, expected_stem=p.stem))

# Classes = union over all supports; [1, 2, 3] is the fallback when no support
# annotation contributed any class id.
cat_ids = union_class_ids(support_annots) or [1,2,3]
# Effectively a no-op: union_class_ids already returns a sorted list.
cat_ids = sorted(cat_ids)

# One {class_id: [...]} dict per support image, with a (possibly empty) list for
# every class in cat_ids.
bboxes_list: List[Dict[int, List[List[float]]]] = []
points_list: List[Dict[int, List[List[float]]]] = []
masks_list : List[Dict[int, List[np.ndarray]]] = []

for ann in support_annots:
    per_b = {cid: list(ann.get("bboxes", {}).get(cid, [])) for cid in cat_ids}
    per_p = {cid: list(ann.get("points", {}).get(cid, [])) for cid in cat_ids}
    per_m = {cid: list(ann.get("masks",  {}).get(cid, [])) for cid in cat_ids}

    # Trick 1: drop masks that are not 2-D arrays or have no foreground pixel.
    filter_empty_masks(per_m)

    # Trick 2: for a class that has masks but no points, add one point derived from
    # its first mask (mask_centroid), in that mask's pixel coordinates.
    for cid, arrs in per_m.items():
        if len(arrs)>0 and len(per_p.get(cid, []))==0:
            c = mask_centroid(arrs[0])
            if c is not None:
                per_p.setdefault(cid, []).append(list(c))

    bboxes_list.append(per_b)
    points_list.append(per_p)
    masks_list.append(per_m)

# Convert bboxes to the Label Anything format, passing each support image's PIL size (W, H).
# NOTE(review): load_ann_any_json already returns [x1, y1, x2, y2] for COCO input;
# confirm convert_bbox expects corner format and not COCO [x, y, w, h], otherwise
# the boxes are converted twice.
converted_bboxes: List[Dict[int, List[List[float]]]] = []
for img_bboxes, orig_img in zip(bboxes_list, support_orig_images):
    out = {}
    for cid, cat_bboxes in img_bboxes.items():
        out[cid] = [prompts_processor.convert_bbox(b, *orig_img.size, noise=False) for b in cat_bboxes]
    converted_bboxes.append(out)

# Background class -1: an always-empty entry inserted as the FIRST key of every dict.
bboxes_list_bg = [{**{-1: []}, **bb} for bb in converted_bboxes]
points_list_bg = [{**{-1: []}, **pp} for pp in points_list]
masks_list_bg  = [{**{-1: []}, **mm} for mm in masks_list]
cat_ids_bg = [-1] + cat_ids

# Turn bbox/point lists into float32 arrays (empty lists become empty arrays too).
for i in range(len(bboxes_list_bg)):
    for cid in cat_ids_bg:
        bboxes_list_bg[i][cid] = np.array(bboxes_list_bg[i][cid], dtype=np.float32)
        points_list_bg[i][cid] = np.array(points_list_bg[i][cid], dtype=np.float32)
        # Masks stay lists of ndarrays (the original note says LA resizes them to 256).

# Pack into tensors. Each call also returns a flag_* tensor, presumably marking which
# slots hold real prompts.
# NOTE(review): support_sizes holds PIL (W, H) pairs — confirm the order expected here.
bboxes, flag_bboxes = utils.annotations_to_tensor(prompts_processor, bboxes_list_bg, support_sizes, utils.PromptType.BBOX)
points, flag_points = utils.annotations_to_tensor(prompts_processor, points_list_bg, support_sizes, utils.PromptType.POINT)
masks,  flag_masks  = utils.annotations_to_tensor(prompts_processor, masks_list_bg,  support_sizes, utils.PromptType.MASK)

# Merge the per-prompt-type flags into the flag_examples tensor fed to the model.
flag_examples = utils.flags_merge(flag_bboxes=flag_bboxes, flag_points=flag_points, flag_masks=flag_masks)

# ==== Inference ====
# One episode: the query image first, then the support images. Every entry gets a
# leading batch axis of size 1 (``x[None]`` is equivalent to ``x.unsqueeze(0)``).
# NOTE(review): all_sizes holds PIL (W, H) pairs — confirm DIMS expects that order.
input_dict = {
    utils.BatchKeys.IMAGES: torch.stack([query_image, *support_images])[None],
    utils.BatchKeys.PROMPT_BBOXES: bboxes[None],
    utils.BatchKeys.FLAG_BBOXES: flag_bboxes[None],
    utils.BatchKeys.PROMPT_POINTS: points[None],
    utils.BatchKeys.FLAG_POINTS: flag_points[None],
    utils.BatchKeys.PROMPT_MASKS: masks[None],
    utils.BatchKeys.FLAG_MASKS: flag_masks[None],
    utils.BatchKeys.FLAG_EXAMPLES: flag_examples[None],
    utils.BatchKeys.DIMS: torch.tensor([all_sizes], dtype=torch.int32),
}
def dict_to_device(d, device):
    if isinstance(d, torch.Tensor): return d.to(device)
    if isinstance(d, dict): return {k: dict_to_device(v, device) for k,v in d.items()}
    if isinstance(d, list): return [dict_to_device(v, device) for v in d]
    return d
input_dict = dict_to_device(input_dict, device)

# Forward pass without autograd bookkeeping.
with torch.no_grad():
    output = la(input_dict)
logits = output["logits"]
# Per-pixel index of the highest-scoring class channel (dim 1).
predictions = logits.argmax(dim=1)

# ==== Quick log ====
print(f"shapes: bboxes {tuple(bboxes.shape)} points {tuple(points.shape)} masks {tuple(masks.shape)}")
print(
    "flags: "
    f"bbox: {int(flag_bboxes.sum().item())} "
    f"point: {int(flag_points.sum().item())} "
    f"mask: {int(flag_masks.sum().item())} "
    f"examples: {int(flag_examples.sum().item())}"
)

# ==== Visualisation (optional; any list of colour tuples works) ====
colors = [
    (255, 255, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255),
    (255, 0, 255), (0, 255, 255), (255, 165, 0),
]
drawn_images = [
    draw_all(get_image(sup_tensor), sup_masks, sup_boxes, sup_points, colors)
    for sup_tensor, sup_masks, sup_boxes, sup_points in zip(support_images, masks, bboxes, points)
]
# Each entry is already a PIL image, e.g. drawn_images[0].save("debug_support0_overlay.png")



shapes: bboxes (1, 4, 1, 4) points (1, 4, 7, 2) masks (1, 4, 256, 256)
flags: bbox: 1 point: 15 mask: 1 examples: 4


In [5]:
# ===== Save predictions to ./output =====
out_dir = Path("output")
out_dir.mkdir(exist_ok=True)

# Drop the batch axis (this script runs a single episode, B = 1); move to numpy.
pred = predictions[0].detach().cpu().numpy().astype(np.int32)  # (H, W) class indices

# The prediction should already match the query image. If not, resize with nearest
# neighbour so the integer labels are not interpolated.
qw, qh = query_orig.size  # PIL reports (W, H)
if pred.shape != (qh, qw):
    pred = cv2.resize(pred, (qw, qh), interpolation=cv2.INTER_NEAREST)

# Prediction index -> class id: index 0 is the background (-1), 1..C follow cat_ids.
cat_ids_with_bg = [-1] + cat_ids

# Colour table indexed by PREDICTION INDEX (row k colours index k, not class id k).
palette = np.array([
    (0, 0, 0),        # index 0: background, black
    (255, 255, 0),    # index 1
    (255, 0, 0),      # index 2
    (0, 255, 0),      # index 3
    (0, 0, 255),
    (255, 0, 255),
    (0, 255, 255),
    (255, 165, 0),
    (255, 192, 203),
], dtype=np.uint8)

num_needed = pred.max() + 1
if num_needed > len(palette):
    # Too few colours: cycle the non-background rows; row 0 stays the background.
    extra = np.tile(palette[1:], ((num_needed // (len(palette) - 1)) + 1, 1))
    palette = np.vstack([palette[:1], extra[:num_needed - 1]])

# Colour-coded prediction, (H, W, 3) uint8, saved on its own.
color_map = palette[np.clip(pred, 0, len(palette) - 1)]
Image.fromarray(color_map).save(out_dir / "query_pred_colormap.png")

# 50/50 blend of the query image and the colour map.
query_rgb = np.array(query_orig)
overlay = cv2.addWeighted(query_rgb, 0.5, color_map, 0.5, 0.0)
Image.fromarray(overlay).save(out_dir / "query_pred_overlay.png")

# One binary mask per foreground class, written only if that class was predicted.
for idx, cid in enumerate(cat_ids_with_bg):
    if cid != -1:
        mask_bin = (pred == idx).astype(np.uint8) * 255
        if mask_bin.any():
            Image.fromarray(mask_bin).save(out_dir / f"class_{cid}_mask.png")

print(
    *(f"[saved] {out_dir}/{fname}" for fname in ("query_pred_colormap.png", "query_pred_overlay.png")),
    sep="\n",
)


[saved] output/query_pred_colormap.png
[saved] output/query_pred_overlay.png
