# In [None]:
import os
import cv2
import numpy as np

def calculate_dice(mask1, mask2):
    """Return the Dice coefficient of two masks.

    Every strictly positive pixel is treated as foreground; zero and negative
    values are background. A 1e-8 term in the denominator prevents division
    by zero, so two all-background masks score 0.0.

    Args:
        mask1: First mask (NumPy array).
        mask2: Second mask (NumPy array, broadcast-compatible with ``mask1``).

    Returns:
        ``2 * |A & B| / (|A| + |B| + 1e-8)`` as a NumPy float.
    """
    fg_a = (mask1 > 0).astype(np.uint8)
    fg_b = (mask2 > 0).astype(np.uint8)

    overlap = np.sum(fg_a * fg_b)
    total = np.sum(fg_a) + np.sum(fg_b)
    return 2. * overlap / (total + 1e-8)

# Input/output locations -- fill these in before running.
cam_binary_dir = ""  # directory containing the binarized CAM masks
all_sam_dir = ""     # one sub-folder per CAM image, named after the CAM file's stem
output_dir = ""      # destination for the selected SAM masks
DICE_THRESHOLD = 0.5  # a SAM mask is kept only if its Dice with the CAM mask reaches this


def _best_sam_mask(cam_mask, sam_folder):
    """Find the SAM mask in *sam_folder* that overlaps *cam_mask* best.

    Unreadable files and masks whose shape differs from *cam_mask* are
    reported and skipped.

    Returns:
        ``(best_dice, best_mask)``. ``best_mask`` is None when no candidate
        scored a Dice above 0, in which case ``best_dice`` is 0.
    """
    best_dice = 0
    best_mask = None
    # sorted() makes tie-breaking deterministic; os.listdir order is arbitrary.
    for sam_file in sorted(os.listdir(sam_folder)):
        sam_path = os.path.join(sam_folder, sam_file)

        sam_mask = cv2.imread(sam_path, cv2.IMREAD_GRAYSCALE)
        if sam_mask is None:
            print(f"Failed to read SAM mask: {sam_path}")
            continue
        # Mismatched shapes would make calculate_dice fail (or broadcast
        # silently), so skip them instead of aborting the whole run.
        if sam_mask.shape != cam_mask.shape:
            print(f"Shape mismatch, skipping SAM mask: {sam_path} "
                  f"({sam_mask.shape} vs CAM {cam_mask.shape})")
            continue

        dice = calculate_dice(cam_mask, sam_mask)
        if dice > best_dice:
            best_dice = dice
            best_mask = sam_mask  # keep the array; no need to re-read it later
    return best_dice, best_mask


os.makedirs(output_dir, exist_ok=True)

for cam_file in sorted(os.listdir(cam_binary_dir)):
    cam_path = os.path.join(cam_binary_dir, cam_file)

    cam_mask = cv2.imread(cam_path, cv2.IMREAD_GRAYSCALE)
    if cam_mask is None:
        print(f"Failed to read CAM mask: {cam_path}")
        continue

    # Find the folder in ALL_SAM matching this image's name. splitext keeps
    # dotted stems intact ("a.b.png" -> "a.b"); split('.') would yield "a".
    sam_folder = os.path.join(all_sam_dir, os.path.splitext(cam_file)[0])
    if not os.path.isdir(sam_folder):
        print(f"SAM folder not found for: {cam_file}")
        continue

    best_dice, best_mask = _best_sam_mask(cam_mask, sam_folder)

    if best_dice >= DICE_THRESHOLD and best_mask is not None:
        output_path = os.path.join(output_dir, cam_file)
        # cv2.imwrite reports failure via its return value, not an exception.
        if cv2.imwrite(output_path, best_mask):
            print(f"Saved: {output_path} with Dice: {best_dice:.4f}")
        else:
            print(f"Failed to write: {output_path}")
    else:
        print(f"Skipped: {cam_file} (Max Dice: {best_dice:.4f})")
