In [None]:
import os
import cv2
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor

# ====== Settings ======
# Input images, their label files (matched by file stem, "<stem>.txt"),
# and the folder that receives one PNG mask per processed box.
image_folder = r"C:/Users/hbi/project/images"
label_folder = r"C:/Users/hbi/project/labels"
output_folder = r"C:/Users/hbi/project/output_masks"
# Weights passed to the "vit_h" SAM builder below.
sam_checkpoint = r"C:/Users/hbi/project/sam_vit_h.pth"

# Create the output folder up front (no-op if it already exists).
os.makedirs(output_folder, exist_ok=True)

# ====== Load SAM (ViT-H variant) ======
# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
sam.to(device)  # nn.Module.to moves parameters in place
predictor = SamPredictor(sam)

# ====== Image processing loop ======
# File extensions (lower-case) treated as input images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
# Label class id for "person"; boxes of any other class are ignored.
PERSON_CLASS_ID = 0


def _read_image(path):
    """Load an image as a BGR array, or return None if it cannot be read.

    ``cv2.imread`` returns None for paths it cannot open; on Windows builds
    of OpenCV this notably includes non-ASCII (e.g. Korean) paths. In that
    case the file is read as raw bytes and decoded in memory instead.
    """
    image = cv2.imread(path)
    if image is not None:
        return image
    try:
        data = np.fromfile(path, dtype=np.uint8)
    except OSError:
        return None
    if data.size == 0:
        # cv2.imdecode rejects an empty buffer with an error.
        return None
    return cv2.imdecode(data, cv2.IMREAD_COLOR)  # None if not decodable


def _write_png(path, image):
    """Encode *image* as PNG and write it to *path*; return True on success.

    ``cv2.imencode`` + ``ndarray.tofile`` is used instead of ``cv2.imwrite``
    so that non-ASCII output paths work and failures are detectable
    (``cv2.imwrite`` only signals failure through its return value).
    """
    ok, buffer = cv2.imencode(".png", image)
    if not ok:
        return False
    try:
        buffer.tofile(path)
    except OSError:
        return False
    return True


def _parse_person_boxes(lines, width, height, source):
    """Convert YOLO-style label lines into pixel boxes for the person class.

    A usable line is ``cls x_center y_center w h`` with coordinates
    normalized by the image size. One extra trailing column (some
    exporters append e.g. a confidence) is tolerated and ignored. Any
    other column count, such as polygon-style lines, is skipped.
    Non-numeric or non-finite values are reported and skipped.

    Args:
        lines: label file lines.
        width, height: image size in pixels.
        source: label path, used only in warning messages.

    Returns:
        list of ``(line_index, box)`` where ``box`` is an int array
        ``[x1, y1, x2, y2]`` clipped to the image. The line index is kept
        so mask file names keep referring to the label line they came from.
    """
    boxes = []
    for i, line in enumerate(lines):
        parts = line.split()
        if len(parts) not in (5, 6):
            continue
        try:
            values = [float(p) for p in parts[:5]]
        except ValueError:
            values = None
        if values is None or not np.all(np.isfinite(values)):
            print(f"[!] 잘못된 라벨 줄 무시: {source} (줄 {i})")
            continue

        cls_id, x_center, y_center, w, h = values
        if int(cls_id) != PERSON_CLASS_ID:
            continue

        x_c, y_c = x_center * width, y_center * height
        half_w, half_h = w * width / 2, h * height / 2
        # Clip to the image: label boxes may extend slightly past the border.
        x1 = int(max(0.0, x_c - half_w))
        y1 = int(max(0.0, y_c - half_h))
        x2 = int(min(float(width), x_c + half_w))
        y2 = int(min(float(height), y_c + half_h))
        if x2 <= x1 or y2 <= y1:
            # Zero-area box, or one lying entirely outside the image.
            print(f"[!] 유효하지 않은 바운딩박스 무시: {source} (줄 {i})")
            continue
        boxes.append((i, np.array([x1, y1, x2, y2])))
    return boxes


# Sorted so runs process (and log) files in a reproducible order.
for filename in sorted(os.listdir(image_folder)):
    if not filename.lower().endswith(IMAGE_EXTENSIONS):
        continue

    stem = os.path.splitext(filename)[0]
    image_path = os.path.join(image_folder, filename)
    label_path = os.path.join(label_folder, stem + ".txt")

    # Check for the label file before paying for image decoding.
    if not os.path.exists(label_path):
        print(f"[!] 바운딩박스 파일 없음: {label_path}")
        continue

    image_bgr = _read_image(image_path)
    if image_bgr is None:
        print(f"[!] 이미지 로드 실패: {filename}")
        continue

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    height, width = image_rgb.shape[:2]

    with open(label_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    person_boxes = _parse_person_boxes(lines, width, height, label_path)
    if not person_boxes:
        print(f"[!] 사람 클래스 바운딩박스 없음: {filename}")
        continue

    # set_image computes the image embedding once per image; it is shared by
    # every box prompt below and skipped when there is nothing to segment.
    predictor.set_image(image_rgb)

    for i, input_box in person_boxes:
        try:
            masks, _, _ = predictor.predict(
                box=input_box[None, :],
                multimask_output=False,
            )
        except Exception as e:
            # Best-effort batch: report the failure and continue with the rest.
            print(f"[!] 마스크 생성 실패: {filename} / 에러: {e}")
            continue

        mask = masks[0].astype(np.uint8) * 255
        out_name = f"{stem}_mask_{i}.png"
        out_path = os.path.join(output_folder, out_name)
        if _write_png(out_path, mask):
            print(f"[✓] 마스크 저장됨: {out_name}")
        else:
            print(f"[!] 마스크 저장 실패: {out_path}")