# In [None]:
import torch
import cv2
import numpy as np
import os
import cv2
import scipy
from ultralytics import YOLO

# Weights of the fine-tuned YOLO model used for segmentation.
model_path = "models/yolo_finetuned/best.pt"

# Class IDs to keep, numbered in the same order as the dataset's data.yaml.
# Only class 2 (bed sheet) is kept.
allowed_classes = [2]

# Load the model once at module level; extract_mask_compare uses this
# global for every image it processes.
model = YOLO(model=model_path)

def _unletterbox_mask(mask, orig_h, orig_w):
    """Map a mask at the model's inference resolution onto the original image.

    Ultralytics letterboxes inputs (one uniform scale, then centered padding)
    and, unless ``retina_masks=True``, returns ``Results.masks.data`` at that
    padded inference size. Resizing such a mask straight to the original size
    would stretch the padding into the picture and misalign the mask on
    non-square images, so the padding is cropped off first.

    Args:
        mask: 2-D uint8 mask at inference resolution.
        orig_h: height of the original image in pixels.
        orig_w: width of the original image in pixels.

    Returns:
        2-D uint8 mask of shape (orig_h, orig_w).
    """
    pred_h, pred_w = mask.shape[:2]
    if (pred_h, pred_w) != (orig_h, orig_w):
        gain = min(pred_h / orig_h, pred_w / orig_w)
        pad_x = (pred_w - orig_w * gain) / 2
        pad_y = (pred_h - orig_h * gain) / 2
        # Same rounding as the letterbox step, so the crop lands on the
        # exact boundary between image content and padding.
        top, left = int(round(pad_y - 0.1)), int(round(pad_x - 0.1))
        bottom = pred_h - int(round(pad_y + 0.1))
        right = pred_w - int(round(pad_x + 0.1))
        if bottom > top and right > left:  # guard against degenerate crops
            mask = mask[top:bottom, left:right]
    return cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)


def extract_mask_compare(image_path, output_dir="results-bed-images"):
    """Black out everything except allowed-class segments and save the image.

    Runs the module-level YOLO ``model`` on *image_path* and unions the masks
    of every detection whose class ID is in ``allowed_classes``. Pixels
    outside that union are set to 0. The result is written to *output_dir*
    under the input's file name.

    Args:
        image_path: path of the image to process.
        output_dir: directory for the masked image; created if missing.
            Defaults to the previously hard-coded "results-bed-images".

    Returns:
        The masked image as read by OpenCV (BGR, same size as the input).

    Raises:
        FileNotFoundError: if OpenCV cannot read *image_path*.
        OSError: if the masked image cannot be written.
    """
    image_name = os.path.basename(image_path)

    # cv2.imread signals failure by returning None; check it before running
    # inference instead of failing later with an opaque AttributeError.
    orig_img = cv2.imread(image_path)
    if orig_img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    h, w = orig_img.shape[:2]

    results = model(image_path)

    # Union of all kept masks at full resolution (nonzero = keep).
    mask_all = np.zeros((h, w), dtype=np.uint8)
    for r in results:
        if r.masks is None:  # no segmentation output for this result
            continue
        masks = r.masks.data.cpu().numpy()   # [N, H_pred, W_pred]
        classes = r.boxes.cls.cpu().numpy()  # [N] class ID per detection
        for m, cls_id in zip(masks, classes):
            if int(cls_id) not in allowed_classes:
                continue  # skip classes not in the allow-list
            # Presumably 0/1 mask values; scale to 0/255 for OpenCV ops.
            m = (m * 255).astype(np.uint8)
            m = _unletterbox_mask(m, h, w)
            mask_all = cv2.bitwise_or(mask_all, m)

    masked_image = orig_img.copy()
    masked_image[mask_all == 0] = 0

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, image_name)
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(out_path, masked_image):
        raise OSError(f"Failed to write image: {out_path}")
    return masked_image

# Run the extraction over every JPEG in the input directory.
# os.listdir returns entries in arbitrary order, so sort them for
# reproducible runs and logs. The extension check ignores case so that
# files saved as ".JPG" are not silently skipped.
input_dir = "bed-images/jpg_out"
img_files = sorted(os.listdir(input_dir))
for img_file in img_files:
    if not img_file.lower().endswith(".jpg"):
        continue
    image_path = os.path.join(input_dir, img_file)
    extract_mask_compare(image_path)
    print(f"Processed {img_file}")