In [None]:
import os
import cv2
import numpy as np
from sam2.modeling import SamPredictor
from sam2.utils.visualization import show_masks
from sam2.utils.helpers import load_checkpoint

# Step 1: Load SAM2 model
# Path to the SAM2 checkpoint (the file name suggests the SAM 2.1 Hiera-Large
# weights). It is presumably resolved relative to the current working
# directory by load_checkpoint — point it at your own download.
model_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
sam_model = load_checkpoint(model_checkpoint)
predictor = SamPredictor(sam_model)

# Step 2: Input and output folder setup
# Images are read from input_folder; masks go to output_folder, which is
# created (with parents) when missing and left as-is when it already exists.
input_folder, output_folder = "notebooks/images", "notebooks/output_masks"
os.makedirs(name=output_folder, exist_ok=True)

# Step 3: Process each image
# A single positive click is used as the prompt for every image. Coordinates
# are (x, y) pixels; the point must lie inside each image for the prompt to be
# meaningful, so adjust it to your data.
input_point = np.array([[100, 100]])  # Example point (adjust as needed)
input_label = np.array([1])  # Positive point label
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# sorted() gives a deterministic processing order (os.listdir order is arbitrary).
for filename in sorted(os.listdir(input_folder)):
    if not filename.startswith("color_"):
        continue
    # Case-insensitive match so e.g. "color_01.JPG" is not silently skipped.
    if not filename.lower().endswith(IMAGE_EXTENSIONS):
        continue

    print(f"Processing {filename}...")
    image_path = os.path.join(input_folder, filename)

    # Load the image. cv2.imread does not raise on unreadable/corrupt files;
    # it returns None, which would otherwise crash cvtColor with an opaque error.
    image = cv2.imread(image_path)
    if image is None:
        print(f"Skipping {filename}: could not be read as an image")
        continue
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB

    # Run SAM2 prediction
    predictor.set_image(image)
    # The upstream segment-anything / SAM2 predictors return three values
    # (masks, scores, logits). Index instead of unpacking exactly two names so
    # this works whether this SamPredictor returns two or three values.
    outputs = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    masks, scores = outputs[0], outputs[1]

    # Save the mask as a black-and-white image. With multimask_output=False,
    # masks[0] is presumably the single (H, W) mask; threshold explicitly so
    # the PNG is strictly 0/255 whether the mask is bool or 0.0/1.0 floats.
    mask_output = np.where(masks[0] > 0, 255, 0).astype(np.uint8)
    mask_path = os.path.join(output_folder, f"{os.path.splitext(filename)[0]}_mask.png")
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(mask_path, mask_output):
        print(f"Failed to write mask for {filename} to {mask_path}")
        continue

    print(f"Saved mask for {filename} to {mask_path}")
