In [3]:
import torch
import cv2
import numpy as np
import os
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Locations of the SAM2 weights, the SAM2 config, the input image and the
# directory that receives the results (created if it does not exist yet).
checkpoint = "C:/Users/User/Desktop/SAM_2/sam2/checkpoints/sam2.1_hiera_large.pt"
model_cfg = "C:/Users/User/Desktop/SAM_2/sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
image_path = "C:/Users/User/Desktop/SAM_2/images/can-you-have-too-many-kittens.png"
output_folder = "C:/Users/User/Desktop/SAM_2/segmented_output"
os.makedirs(output_folder, exist_ok=True)

# Fail fast on the first missing input; checked in the order
# config -> checkpoint -> image.
for required_path, missing_message in (
    (model_cfg, f"Config file not found at {model_cfg}"),
    (checkpoint, f"Checkpoint file not found at {checkpoint}"),
    (image_path, f"Image not found at {image_path}"),
):
    if not os.path.exists(required_path):
        raise FileNotFoundError(missing_message)

# Read the image; a None result means the file exists but could not be decoded.
image = cv2.imread(image_path)
if image is None:
    raise ValueError(f"Image at {image_path} could not be loaded.")

# RGB copy handed to the predictor, plus a scratch copy that finished boxes
# are drawn onto so the pristine `image` is never modified.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
clone = image.copy()

# Shared state for the mouse callback: the list of boxes (each a list of
# corner tuples) and a flag that is True while the left button is held.
ref_points = []
drawing = False

# Mouse callback function
def draw_rectangle(event, x, y, flags, param):
    """Mouse callback for the "Draw Bounding Box" window.

    A left-button press starts a new box in ``ref_points``; moving the mouse
    while that box is open shows a live preview; releasing the button stores
    the second corner and draws the finished box onto ``clone``.

    Events that arrive without an open box (for example a button release whose
    press was never seen) are ignored instead of corrupting ``ref_points``.

    Args:
        event: OpenCV mouse event code (``cv2.EVENT_*``).
        x: Pointer x coordinate reported by OpenCV.
        y: Pointer y coordinate reported by OpenCV.
        flags: Unused; required by the OpenCV callback signature.
        param: Unused; required by the OpenCV callback signature.
    """
    global ref_points, drawing, clone  # Use clone to not overwrite the original image

    # Keep every stored corner inside the image. A drag that leaves the window
    # may report out-of-range coordinates (NOTE(review): backend dependent),
    # and negative values would later be treated as wrap-around slice indices.
    height, width = clone.shape[:2]
    x = min(max(x, 0), width - 1)
    y = min(max(y, 0), height - 1)

    if event == cv2.EVENT_LBUTTONDOWN:
        # Drop a stale half-finished box whose release was never delivered.
        if ref_points and len(ref_points[-1]) == 1:
            ref_points.pop()
        drawing = True
        ref_points.append([(x, y)])  # Start a new box
        return

    # A box is "open" only if a press was recorded and it still lacks its
    # second corner.
    box_open = drawing and bool(ref_points) and len(ref_points[-1]) == 1

    if event == cv2.EVENT_MOUSEMOVE and box_open:
        preview = clone.copy()  # Fresh copy so old previews do not accumulate
        cv2.rectangle(preview, ref_points[-1][0], (x, y), (0, 255, 0), 2)
        cv2.imshow("Draw Bounding Box", preview)
    elif event == cv2.EVENT_LBUTTONUP:
        drawing = False
        if not box_open:
            return  # Release without a matching press: nothing to finish
        ref_points[-1].append((x, y))  # Finish the current box
        cv2.rectangle(clone, ref_points[-1][0], ref_points[-1][1], (0, 255, 0), 2)
        cv2.imshow("Draw Bounding Box", clone)

# Show the image and register the mouse callback that records the boxes
cv2.namedWindow("Draw Bounding Box")
cv2.setMouseCallback("Draw Bounding Box", draw_rectangle)
cv2.imshow("Draw Bounding Box", clone)

# Key loop: 'u' undoes the last finished box, Enter finishes drawing
while True:
    key = cv2.waitKey(1) & 0xFF  # Poll for a key press
    if key == 13:  # Press 'Enter' to finish drawing
        break  # Exit the loop once finished

    # Ignore everything but 'u'. Also ignore 'u' while a drag is in progress:
    # popping the box that is still being drawn would make the pending
    # button-release attach its corner to the previous box (or to nothing).
    if key != ord('u') or drawing or not ref_points:
        continue

    ref_points.pop()  # Remove the last drawn bounding box
    clone = image.copy()  # Reset the scratch image to the original
    for box in ref_points:  # Redraw the remaining bounding boxes
        if len(box) >= 2:  # Only boxes that have both corners can be drawn
            cv2.rectangle(clone, box[0], box[1], (0, 255, 0), 2)
    cv2.imshow("Draw Bounding Box", clone)

cv2.destroyAllWindows()

# Refuse to continue unless at least one box exists and every recorded box
# has exactly two corners.
boxes_complete = all(len(pt) == 2 for pt in ref_points)
if not ref_points or not boxes_complete:
    raise ValueError("At least one bounding box must be drawn.")

# Build the SAM2 model and wrap it in an image predictor. Any failure is
# reported on stdout and then re-raised unchanged.
try:
    sam2_model = build_sam2(model_cfg, checkpoint)
    predictor = SAM2ImagePredictor(sam2_model)
except Exception as e:
    print(f"Error during SAM2 model initialization: {e}")
    raise

# Get the base filename without extension
base_filename = os.path.splitext(os.path.basename(image_path))[0]
img_h, img_w = image.shape[:2]

# The image is the same for every box, so hand it to the predictor once
# instead of once per box.
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(image_rgb)

# Run inference for each bounding box
for idx, (pt1, pt2) in enumerate(ref_points):
    # Normalise the two corners to (x0, y0, x1, y1) and clamp them to the
    # image; negative values would otherwise act as wrap-around slice indices.
    x0 = min(max(min(pt1[0], pt2[0]), 0), img_w)
    y0 = min(max(min(pt1[1], pt2[1]), 0), img_h)
    x1 = min(max(max(pt1[0], pt2[0]), 0), img_w)
    y1 = min(max(max(pt1[1], pt2[1]), 0), img_h)
    if x1 <= x0 or y1 <= y0:
        # A click without a drag (or a box entirely off-image) has no area.
        print(f"Bounding box {idx} has zero area. Skipping.")
        continue
    bbox = [x0, y0, x1, y1]

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        masks, scores, _ = predictor.predict(box=bbox)

    if masks is None or len(masks) == 0:
        continue

    # Pick the candidate mask with the highest score
    best_idx = scores.argmax().item()
    best_mask = masks[best_idx]

    if best_mask is None or not np.any(best_mask):
        print(f"Best mask for bounding box {idx} is empty. Skipping.")
        continue

    # Full-size mask with the same height/width as the input image; only the
    # part of the predicted mask that lies inside the box is kept.
    full_size_mask = np.zeros(image.shape[:2], dtype=np.uint8)
    full_size_mask[y0:y1, x0:x1] = best_mask[y0:y1, x0:x1]

    # Set the mask region to white (255)
    full_size_mask[full_size_mask > 0] = 255

    # Keep only the masked pixels of the original (BGR) image
    masked = cv2.bitwise_and(image, image, mask=full_size_mask)

    if masked is None or masked.size == 0:
        print(f"Masked image for bbox {idx} is empty. Skipping.")
        continue

    # Save results; cv2.imwrite signals failure through its return value, so
    # only report success when the file was actually written.
    output_filename = f"{base_filename}_segmented_{idx}.png"
    output_full_path = os.path.join(output_folder, output_filename)
    if cv2.imwrite(output_full_path, masked):
        print(f"Saved: {output_full_path}")
    else:
        print(f"Failed to save: {output_full_path}")

    mask_filename = f"{base_filename}_mask_{idx}.png"
    mask_full_path = os.path.join(output_folder, mask_filename)
    if cv2.imwrite(mask_full_path, full_size_mask):
        print(f"Saved mask: {mask_full_path}")
    else:
        print(f"Failed to save mask: {mask_full_path}")

    # Optional: display the result and wait for a key before the next box
    cv2.imshow(f"Segmented {idx}", masked)
    cv2.imshow(f"Mask {idx}", full_size_mask)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

Saved: C:/Users/User/Desktop/SAM_2/segmented_output\can-you-have-too-many-kittens_segmented_0.png
Saved mask: C:/Users/User/Desktop/SAM_2/segmented_output\can-you-have-too-many-kittens_mask_0.png
Saved: C:/Users/User/Desktop/SAM_2/segmented_output\can-you-have-too-many-kittens_segmented_1.png
Saved mask: C:/Users/User/Desktop/SAM_2/segmented_output\can-you-have-too-many-kittens_mask_1.png
Saved: C:/Users/User/Desktop/SAM_2/segmented_output\can-you-have-too-many-kittens_segmented_2.png
Saved mask: C:/Users/User/Desktop/SAM_2/segmented_output\can-you-have-too-many-kittens_mask_2.png
