In [1]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

import sys

import cv2
import numpy as np
import torch

sys.path.append("..")
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from segment_anything import sam_model_registry, SamPredictor

# Checkpoint file and the matching backbone key in sam_model_registry.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (a hard-coded "cuda" fails there).
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Default-configured generator; the next cell rebinds `mask_generator`
# with tuned settings.
mask_generator = SamAutomaticMaskGenerator(sam)
predictor = SamPredictor(sam)

In [2]:
# Rebuild the automatic mask generator with tuned settings, replacing the
# default-configured instance created in the first cell.
mask_generator = SamAutomaticMaskGenerator(
    sam,
    min_mask_region_area=1000,     # minimum region area (presumably pixels — confirm in SAM docs)
    stability_score_thresh=0.80,   # stability-score cutoff for keeping a mask
    pred_iou_thresh=0.80,          # predicted-IoU cutoff for keeping a mask
    points_per_side=16,            # density of the prompt-point sampling grid
)

In [4]:
import os
import cv2
import numpy as np


def calculate_mask_coverage(mask):
    """Return the fraction of elements in ``mask`` that are strictly positive.

    Args:
        mask: NumPy array of any shape, e.g. a 0/255 ``uint8`` mask.

    Returns:
        A float in [0, 1]. An empty array yields 0.0 instead of the NaN
        (plus RuntimeWarning) a plain 0/0 division would produce.
    """
    total_pixels = mask.size
    if total_pixels == 0:
        return 0.0
    # Compare with ``> 0`` rather than plain nonzero so negative values are
    # not counted as foreground.
    return float(np.count_nonzero(mask > 0)) / total_pixels

# Image_list
# Input image directory and output root were left blank in this notebook;
# fill them in before running (os.listdir("") / os.makedirs("") would
# otherwise fail with a less helpful FileNotFoundError).
input_dir = ""
output_root = ""
if not input_dir or not output_root:
    raise ValueError("Set input_dir and output_root before running this cell.")

# Masks covering more than this fraction of the image are not saved
# (presumably near-whole-image / background segments).
MAX_MASK_COVERAGE = 0.95

# Masks are saved losslessly as PNG: JPEG compression would smear the
# 0/255 values into intermediate grays along mask edges.
MASK_EXT = ".png"

# Sorted so images are processed in a deterministic order on every platform.
file_list = sorted(os.listdir(input_dir))

os.makedirs(output_root, exist_ok=True)

for file in file_list:
    image_path = os.path.join(input_dir, file)
    if not os.path.isfile(image_path):
        continue  # skip sub-directories and other non-file entries

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # cvtColor below would crash on it.
        print(f"Skipping unreadable file: {image_path}")
        continue
    # OpenCV loads BGR; convert to RGB before mask generation.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    image_name, _ = os.path.splitext(file)
    image_output_dir = os.path.join(output_root, image_name)
    os.makedirs(image_output_dir, exist_ok=True)

    masks = mask_generator.generate(image)

    for i, mask_dict in enumerate(masks):
        # Convert the 'segmentation' entry to a 0/255 uint8 image for saving.
        mask = mask_dict['segmentation'].astype(np.uint8) * 255

        if calculate_mask_coverage(mask) > MAX_MASK_COVERAGE:
            continue

        # Keep the generator's index `i` in the name so files stay tied to
        # the output order even when some masks are skipped.
        mask_filename = f"{image_name}_{i}{MASK_EXT}"
        mask_path = os.path.join(image_output_dir, mask_filename)
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(mask_path, mask):
            print(f"Failed to write mask: {mask_path}")

print("모든 마스크가 저장되었습니다.")


모든 마스크가 저장되었습니다.
