# In[2]:
import torch
import cv2
import numpy as np
import os
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Paths
checkpoint = "C:/Users/User/Desktop/SAM_2/sam2/checkpoints/sam2.1_hiera_large.pt"
model_cfg = "C:/Users/User/Desktop/SAM_2/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
image_path = "C:/Users/User/Desktop/SAM_2/images/can-you-have-too-many-kittens.png"
output_folder = "C:/Users/User/Desktop/SAM_2/segmented_output"

# Fail fast on a missing checkpoint: the model is only built after the
# interactive point selection, so a bad path would otherwise be discovered
# only after the user has finished clicking.
if not os.path.isfile(checkpoint):
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")

os.makedirs(output_folder, exist_ok=True)

# Load image. cv2.imread signals failure by returning None rather than raising.
image = cv2.imread(image_path)
if image is None:
    raise ValueError(f"Image at {image_path} could not be loaded.")
# `image_rgb` is what gets fed to the predictor; `image` (BGR) is kept for
# OpenCV drawing, masking and saving.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Working copy that point markers are drawn onto during selection.
clone = image.copy()

# Clicked points ([x, y] as reported by the mouse callback) and their labels,
# kept in parallel: 1 = foreground, 0 = background.
click_points = []
point_labels = []

# Mouse callback
def mouse_click(event, x, y, flags, param):
    """Record a clicked prompt point for the "Select Points" window.

    A left click adds (x, y) as a foreground point (label 1, green marker);
    a right click adds it as a background point (label 0, red marker). All
    other mouse events are ignored. The point goes into the module-level
    ``click_points`` / ``point_labels`` lists, a marker is drawn onto the
    module-level ``clone`` image, and the window is refreshed.

    ``flags`` and ``param`` are unused; they are accepted to match the
    callback signature registered via ``cv2.setMouseCallback``.
    """
    if event == cv2.EVENT_LBUTTONDOWN:
        label, marker_color = 1, (0, 255, 0)  # foreground
    elif event == cv2.EVENT_RBUTTONDOWN:
        label, marker_color = 0, (0, 0, 255)  # background (optional)
    else:
        return

    # The lists and `clone` are only mutated here, never rebound, so no
    # `global` declaration is needed to reach the module-level objects.
    click_points.append([x, y])
    point_labels.append(label)
    cv2.circle(clone, (x, y), 5, marker_color, -1)
    cv2.imshow("Select Points", clone)

# Display image and collect points.
# Controls: left click = foreground point, right click = background point
# (both handled by mouse_click), 'u' = undo the last point, Enter = finish.
# Closing the window also finishes selection.
cv2.namedWindow("Select Points")
cv2.setMouseCallback("Select Points", mouse_click)
cv2.imshow("Select Points", clone)

while True:
    key = cv2.waitKey(1) & 0xFF
    # If the user closes the window with its title-bar button, Enter can never
    # arrive and this loop would spin forever; treat closing as "done".
    if cv2.getWindowProperty("Select Points", cv2.WND_PROP_VISIBLE) < 1:
        break
    if key in (13, 10):  # Enter (CR; LF accepted in case a backend reports it)
        break
    if key == ord('u') and click_points:
        # Undo: drop the last point, then redraw the remaining markers on a
        # fresh copy. Rebinding the module-level `clone` is picked up by
        # mouse_click, which looks `clone` up each time it is called.
        click_points.pop()
        point_labels.pop()
        clone = image.copy()
        for pt, label in zip(click_points, point_labels):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            cv2.circle(clone, tuple(pt), 5, color, -1)
        cv2.imshow("Select Points", clone)

cv2.destroyAllWindows()

if not click_points:
    raise ValueError("At least one point must be selected.")

# Load model. Choose the device explicitly so the script also runs on a
# machine without CUDA; bfloat16 autocast below is only enabled on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint, device=device))
base_filename = os.path.splitext(os.path.basename(image_path))[0]

# Run inference on the RGB image using the clicked points as prompts.
with torch.inference_mode(), torch.autocast(
    device, dtype=torch.bfloat16, enabled=(device == "cuda")
):
    predictor.set_image(image_rgb)
    point_coords = np.array(click_points)
    point_labels_np = np.array(point_labels)
    masks, scores, _ = predictor.predict(
        point_coords=point_coords, point_labels=point_labels_np
    )

# Get best mask by confidence score
if masks is None or len(masks) == 0 or scores is None:
    raise ValueError("No masks were returned by the model.")

best_idx = scores.argmax().item()
best_mask = masks[best_idx]

if best_mask is None or not np.any(best_mask):
    raise ValueError("Best mask is empty.")

# Binarize to a 0/255 uint8 mask. Thresholding with `> 0` matches `* 255` for
# 0/1 or boolean masks, but cannot wrap around if the values ever fall
# outside [0, 1] (e.g. logits).
full_size_mask = np.where(best_mask > 0, 255, 0).astype(np.uint8)
masked = cv2.bitwise_and(image, image, mask=full_size_mask)

if masked is None or masked.size == 0:
    raise ValueError("Masked image is empty.")

output_filename = f"{base_filename}_segmented_point_best.png"
output_mask_filename = f"{base_filename}_mask_point_best.png"
output_path = os.path.join(output_folder, output_filename)
output_mask_path = os.path.join(output_folder, output_mask_filename)

# cv2.imwrite reports failure by returning False instead of raising; check it
# so the "Saved ..." messages are never printed for files that were not written.
if not cv2.imwrite(output_path, masked):
    raise OSError(f"Failed to write segmented image to {output_path}")
if not cv2.imwrite(output_mask_path, full_size_mask):
    raise OSError(f"Failed to write mask to {output_mask_path}")

print(f"Saved best segmented image: {output_filename}")
print(f"Saved best mask: {output_mask_filename}")

# Optional: Show results
cv2.imshow("Best Segmented", masked)
cv2.imshow("Best Mask", full_size_mask)
cv2.waitKey(0)
cv2.destroyAllWindows()

# Output:
# Saved best segmented image: can-you-have-too-many-kittens_segmented_point_best.png
# Saved best mask: can-you-have-too-many-kittens_mask_point_best.png
