In [5]:
import torch
import os
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
import numpy as np

In [6]:
def generate_masks(image_paths, output_dir, model_type="vit_b", model_checkpoint="sam_vit_b.pth"):
    # Ensure output directory exists
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Load the SAM model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[model_type](checkpoint=model_checkpoint)
    sam.to(device)
    predictor = SamPredictor(sam)

    # Process each image
    for image_path in image_paths:
        # Load the image
        image = Image.open(image_path).convert("RGB")
        image_np = np.array(image)

        # Set the image in the predictor
        predictor.set_image(image_np)

        # Generate masks (for simplicity, we use one point in the center of the image)
        height, width = image_np.shape[:2]
        center_point = np.array([[width // 2, height // 2]])
        labels = np.array([1])  # Label the center point as "foreground"

        # Generate mask predictions
        masks, scores, logits = predictor.predict(point_coords=center_point, point_labels=labels)

        # Save each mask as an image
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        for i, mask in enumerate(masks):
            mask_img = Image.fromarray(mask.astype("uint8") * 255)  # Convert to binary mask (0 or 255)
            mask_output_path = os.path.join(output_dir, f"{base_name}_mask_{i}.png")
            mask_img.save(mask_output_path)
            print(f"Saved mask for {image_path} at {mask_output_path}")


In [15]:
image_paths = [
    r"C:\Users\lakho\Desktop\URECA\Whole Slide Images .svs\Level 1\TCGA-A6-2678-01Z-00-DX1.bded5c5c-555a-492a-91c7-151492d0ee5e_level_1.jpg",
]
output_dir = r"C:\Users\lakho\Desktop\URECA\Whole Slide Images .svs\Level 1\Masks"
model_checkpoint = r"C:\Users\lakho\Downloads\sam_vit_b_01ec64.pth"

In [16]:
generate_masks(image_paths, output_dir, model_type="vit_b", model_checkpoint=model_checkpoint)


DecompressionBombError: Image size (306473420 pixels) exceeds limit of 178956970 pixels, could be decompression bomb DOS attack.