# Grounded Segment Anything

This notebook implements Grounded SAM, a pipeline that predicts bounding boxes from input text using the Grounding DINO model and then prompts Segment Anything with those boxes to generate segmentation masks.

Output images with segmentation masks are saved to a specified output folder to visualize results.

## Prepare Environments

In [3]:
%%capture

!pip install segment_anything
!pip install groundingdino-py
!pip install pycocotools pillow numpy

In [4]:
import os

# Restrict this process to a single GPU. On a multi-GPU machine, change the
# value to choose another device index; "0" selects the first GPU.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [5]:
import argparse
import os
import copy

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision.ops import box_convert

# Grounding DINO
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.inference import annotate, load_image, predict

import supervision as sv

# segment anything
from segment_anything import build_sam, SamPredictor 
import cv2
import numpy as np
import matplotlib.pyplot as plt

from huggingface_hub import hf_hub_download
from pycocotools.coco import COCO
import pandas as pd



## Load Grounding DINO model

In [6]:
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Download a Grounding DINO config and checkpoint from the Hugging Face Hub and build the model.

    Parameters:
    - repo_id: Hub repository that hosts both the config and the checkpoint.
    - filename: Name of the checkpoint file inside the repository.
    - ckpt_config_filename: Name of the model config file inside the repository.
    - device: Value stored in the config's `device` field before the model is built.
      The returned model is not moved to this device here.

    Returns:
    - The model in eval mode with the checkpoint weights loaded.
    """
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)

    args = SLConfig.fromfile(cache_config_file)
    # Set the device on the config first so build_model receives the requested value.
    args.device = device
    model = build_model(args)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    # Deserialize the weights onto the CPU regardless of where they were saved.
    checkpoint = torch.load(cache_file, map_location='cpu')
    # strict=False tolerates key mismatches; they are reported in the printed log.
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(cache_file, log))
    model.eval()
    return model

In [7]:
# Hugging Face Hub coordinates used to fetch the Grounding DINO model.
# Alternatively, download the files yourself and load them from disk.
ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"  # model config file
ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"  # checkpoint (weights) file
ckpt_repo_id = "ShilongLiu/GroundingDINO"  # repository hosting both files

In [8]:
# Build Grounding DINO from the Hub files configured above.
groundingdino_model = load_model_hf(
    repo_id=ckpt_repo_id,
    filename=ckpt_filenmae,
    ckpt_config_filename=ckpt_config_filename,
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


final text_encoder_type: bert-base-uncased
Model loaded from /home/jovyan/.cache/huggingface/hub/models--ShilongLiu--GroundingDINO/snapshots/a94c9b567a2a374598f05c584e96798a170c56fb/groundingdino_swinb_cogcoor.pth 
 => _IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight', 'bert.embeddings.position_ids'])


## Load SAM model

The following command downloads the Segment Anything model.  If you already have the file, skip this! 

In [7]:
! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

--2025-02-06 17:19:02--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 13.227.74.12, 13.227.74.9, 13.227.74.118, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|13.227.74.12|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2564550879 (2.4G) [binary/octet-stream]
Saving to: ‘sam_vit_h_4b8939.pth.1’

sam_vit_h_4b8939.pt  16%[==>                 ] 399.92M  88.1MB/s    eta 26s    ^C


In [8]:
# Use the first CUDA device when one is available; otherwise stay on the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Build SAM from the local checkpoint file, move it to DEVICE and wrap it in a predictor.
sam_checkpoint = 'sam_vit_h_4b8939.pth'
sam = build_sam(checkpoint=sam_checkpoint).to(device=DEVICE)
sam_predictor = SamPredictor(sam)

## Initialize helper functions

In [10]:
import numpy as np


def iou(gtmask, test_mask):
    """Compute the intersection-over-union of two masks.

    Parameters:
    - gtmask: Reference mask; non-zero entries count as foreground.
    - test_mask: Mask compared against `gtmask`; must broadcast with it.

    Returns:
    - Intersection area divided by union area, or 0.0 when both masks are empty.
    """
    intersection = np.logical_and(gtmask, test_mask)
    union = np.logical_or(gtmask, test_mask)
    union_area = np.sum(union)
    if union_area == 0:
        # Two empty masks have nothing to overlap; avoid a 0/0 division (NaN).
        return 0.0
    return np.sum(intersection) / union_area

# Usually there is a mask for the entire plant in addition to individual leaves.
# This function attempts to remove the full plant mask by calculating the IoU of each mask with the union of all masks.
def check_full_plant(masks, threshold=0.9):
    """Flag which masks are NOT (nearly) identical to the union of all masks.

    Parameters:
    - masks: Sequence or array of equally shaped masks, indexed along the first axis.
    - threshold: A mask whose IoU with the union of all masks reaches this value
      is treated as the full-plant mask.

    Returns:
    - Boolean array with one entry per mask: True when the mask's IoU with the
      union is below `threshold`. Empty input yields an empty boolean array.

    Note: with a single non-empty mask the union equals that mask, so its IoU
    is 1 and it is reported as a full-plant mask (False).
    """
    if len(masks) == 0:
        # Nothing to compare; indexing masks[0] below would raise IndexError.
        return np.zeros(0, dtype=bool)

    # Accumulate all masks; any pixel with a positive total belongs to the union.
    mask_all = np.zeros(masks[0].shape, dtype=np.float32)
    for mask in masks:
        mask_all += mask.astype(np.float32)
    union = mask_all > 0

    # IoU of each individual mask against the union of all masks.
    iou_withall = [iou(mask, union) for mask in masks]

    idx_notall = np.array(iou_withall) < threshold
    return idx_notall

In [20]:
def show_masks(masks, image, include, random_color=True):
    """Overlay every mask on the image as a semi-transparent colored layer.

    Parameters:
    - masks: Array of masks indexed along its first axis; the last two
      dimensions of each mask are taken as (height, width).
    - image: Image array accepted by `Image.fromarray`.
    - include: Accepted but not consulted; every mask is drawn.
      NOTE(review): looks like a per-mask filter that was switched off — confirm.
    - random_color: When True each mask gets a random RGB color with alpha 0.8;
      otherwise a fixed blue with alpha 0.6 is used.

    Returns:
    - RGBA array of the image with all mask layers composited on top.
    """
    canvas = Image.fromarray(image).convert("RGBA")

    for mask_idx in range(masks.shape[0]):
        current = masks[mask_idx]

        # Pick the RGBA color (components in [0, 1]) for this mask.
        if random_color:
            rgba = np.append(np.random.random(3), 0.8)
        else:
            rgba = np.array([30/255, 144/255, 255/255, 0.6])

        # Broadcast the mask against the color to get an (H, W, 4) layer that
        # is fully transparent wherever the mask is zero.
        height, width = current.shape[-2:]
        layer = current.reshape(height, width, 1) * rgba[None, None, :]
        layer_pil = Image.fromarray((layer * 255).astype(np.uint8)).convert("RGBA")

        # Stack the layer on top of what has been drawn so far.
        canvas = Image.alpha_composite(canvas, layer_pil)

    return np.array(canvas)

## Run Grounding DINO for detection

In [14]:
def is_bbox_large(bbox, threshold=0.9):
    """
    Tell whether a bounding box occupies at least `threshold` of the image.

    Parameters:
    - bbox: Four values; the last two are used as the box width and height,
      the first two are ignored. Normalized coordinates are assumed.
    - threshold: Minimum fraction of the image area for a box to count as large.

    Returns:
    - True if width * height is at least `threshold`, False otherwise.
    """
    _, _, box_w, box_h = bbox
    # With normalized coordinates the whole image has unit area.
    full_image_area = 1.0
    return box_w * box_h >= threshold * full_image_area

def filter_large_bboxes(boxes, threshold=0.9):
    """
    Filter out bounding boxes that cover a large portion of the image.

    Applies the same criterion as `is_bbox_large` to all boxes at once.

    Parameters:
    - boxes: Tensor of shape (N, 4) whose last two columns are the box width
      and height in normalized coordinates.
    - threshold: Boxes whose width * height is at least this fraction of the
      image are removed.

    Returns:
    - Tensor of shape (K, 4) with the boxes that were kept, in their original
      order. K is 0 when every box is removed, so callers can test
      `result.size(0) == 0`.
    """
    if boxes.shape[0] == 0:
        # No boxes at all: hand back an empty tensor of the same shape.
        return torch.empty_like(boxes)

    areas = boxes[:, 2] * boxes[:, 3]
    keep = ~(areas >= threshold)
    # Boolean indexing yields a (0, 4) tensor when nothing is kept, rather than
    # an uninitialized tensor shaped like the input.
    return boxes[keep]

In the cell below, set `image_dir` and `output_folder` as follows:

+ `image_dir`: Directory where your images are
+ `output_folder`: Directory where images visualizing segmentation results will be saved to

In [1]:
import os
import numpy as np
from PIL import Image

# Directories
image_dir = '/home/jovyan/work/segment_anything/2024-06-04_cropped'
output_folder = '/home/jovyan/work/segment_anything/2025_prompt_test'

# Grounding DINO settings
TEXT_PROMPT = "leaf or small sprouting leaf"
BOX_TRESHOLD = 0.3
TEXT_TRESHOLD = 0.25
# Boxes whose normalized area reaches this fraction of the image are discarded.
LARGE_BOX_THRESHOLD = 0.9

# File extensions (lower case) treated as images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

# Saving results fails if the output directory is missing, so create it up front.
os.makedirs(output_folder, exist_ok=True)

# Iterate through all images in the directory (sorted for a reproducible order)
for file_name in sorted(os.listdir(image_dir)):
    # Check if it's an image file
    if not file_name.lower().endswith(IMAGE_EXTENSIONS):
        continue

    name, ext = os.path.splitext(file_name)
    file_path = os.path.join(image_dir, file_name)

    print(f"Processing: {file_path}")

    if not os.path.isfile(file_path):
        print('File not found, skipping...')
        continue

    # Load image: image_source is the array handed to SAM and used for drawing,
    # image is the input passed to Grounding DINO.
    image_source, image = load_image(file_path)

    # Run Grounding DINO predictions.
    # predict is expected to return (boxes, logits, phrases); slicing keeps the
    # unpacking valid should the installed version return extra values.
    boxes, logits, phrases = predict(
        model=groundingdino_model,
        image=image,
        caption=TEXT_PROMPT,
        box_threshold=BOX_TRESHOLD,
        text_threshold=TEXT_TRESHOLD,
        device=DEVICE
    )[:3]

    # Drop boxes that cover (almost) the whole image. The same selection is
    # applied to logits and phrases so the three stay aligned for annotate().
    keep = torch.tensor(
        [not bool(is_bbox_large(box, LARGE_BOX_THRESHOLD)) for box in boxes],
        dtype=torch.bool,
    )
    boxes = boxes[keep]
    logits = logits[keep]
    phrases = [phrase for phrase, kept in zip(phrases, keep.tolist()) if kept]

    if boxes.size(0) == 0:
        print("No boxes detected, skipping...")
        continue

    H, W, _ = image_source.shape

    # Annotate image with bounding boxes
    annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
    annotated_frame = annotated_frame[..., ::-1]  # Convert BGR to RGB

    # Save Grounding DINO bbox visualization to image
    dino_result = Image.fromarray(annotated_frame)
    dino_result.save(os.path.join(output_folder, f"{name}_dino_bboxes.png"))

    # Predict masks using SAM, prompting it with the boxes converted from
    # normalized cxcywh to absolute xyxy pixel coordinates.
    sam_predictor.set_image(image_source)
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).to(DEVICE)

    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    # Process and save masks
    masks_cpu = masks.cpu().numpy()
    idx_notall = check_full_plant(masks_cpu)

    # Draw on the array already in memory instead of re-reading the file from disk.
    annotated_frame_with_mask = show_masks(masks_cpu, image_source, idx_notall)
    output_image = Image.fromarray(annotated_frame_with_mask)
    output_image.save(os.path.join(output_folder, f"{name}_sam_masks.png"))

    print(f"Saved results for: {file_name}")

print("Processing complete.")
    

Processing: /home/jovyan/work/segment_anything/2024-06-04_cropped/IMG_6056.JPG


NameError: name 'load_image' is not defined