Does the EfficientSAM model not support bounding-box prompts as input? #40

Closed

duxuan11 opened this issue Jan 15, 2024 · 1 comment

duxuan11 commented Jan 15, 2024

Does the EfficientSAM model not support multiple bounding boxes as input prompts?

yjh0410 commented Mar 12, 2024

@duxuan11 Dear friend, although the official code does not provide a bbox-prompt example, following the code of the SAM project, you only need to convert each bbox into point format and set the point labels to 2 (top-left corner) and 3 (bottom-right corner). Below is an example in which I use two bboxes (xyxy format) to have EfficientSAM segment two objects. You can adapt this code to your own needs (please do not copy and paste it directly, because I slightly modified the file structure of the project)...

import cv2
from torchvision import transforms
import torch
import numpy as np
import argparse
import os

from models.build_efficient_sam import efficient_sam_model_registry

parser = argparse.ArgumentParser(description=("Runs box-prompted mask prediction with EfficientSAM on an input image, "
                                              "overlays the predicted masks on it, and saves the result as a PNG. "
                                              "Requires opencv-python."),
                                 )

parser.add_argument("--input", type=str, required=True,
                    help="Path to either a single input image or folder of images.",
                    )

parser.add_argument("--output", type=str, required=True,
                    help=("Path to the directory where masks will be output. Output will be either a folder "
                          "of PNGs per image or a single json with COCO-style masks."),
                    )

parser.add_argument("--model-type", type=str, required=True,
                    help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
                    )

parser.add_argument("--checkpoint", type=str, required=True,
                    help="The path to the SAM checkpoint to use for mask generation.",
                    )

parser.add_argument("--device", type=str, default="cuda",
                    help="The device to run generation on.")

parser.add_argument("--convert-to-rle", action="store_true",
                    help=("Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
                          "Requires pycocotools."),
                    )

parser.add_argument("--show", action="store_true",
                    help=("To show the segmentation results on the input image."),
                    )


def main(args):
    # Build the EfficientSAM model and put it in inference mode.
    model = efficient_sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    model.to(args.device)
    model.eval()

    # load an image (the example bboxes below were chosen for data/images/ex1.jpg)
    sample_image_np = cv2.imread(args.input)
    sample_image_np = cv2.cvtColor(sample_image_np, cv2.COLOR_BGR2RGB)
    sample_image_tensor = transforms.ToTensor()(sample_image_np)

    # bboxes of the sample
    bboxes = [[ 85.7600, 196.6265, 469.7600, 543.6144],
              [236.8000,  82.8916, 325.1200, 441.4458]]
    
    # convert the bboxes into the point prompts
    num_queries = len(bboxes)
    input_points = torch.as_tensor(bboxes).unsqueeze(0)      # [bs, num_queries, 4], bs = 1
    input_points = input_points.view(-1, num_queries, 2, 2)  # [bs, num_queries, num_pts, 2]
    input_labels = torch.tensor([2, 3])  # top-left, bottom-right
    input_labels = input_labels[None, None].repeat(1, num_queries, 1) # [bs, num_queries, num_pts]

    print(f"Running inference with model type '{args.model_type}' ...")
    with torch.no_grad():
        predicted_logits, predicted_iou = model(
            sample_image_tensor[None, ...].to(args.device),
            input_points.to(args.device),
            input_labels.to(args.device),
        )
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    # [bs, num_queries, num_candidate_masks, img_h, img_w]
    predicted_logits = torch.take_along_dim(
        predicted_logits, sorted_ids[..., None, None], dim=2
    )
    masks = torch.ge(predicted_logits, 0).cpu().detach().numpy()
    masks = masks[0, :, 0, :, :]  # [num_queries, img_h, img_w]
    
    # overlay each predicted mask on the input image with a random color
    masked_image_np = cv2.cvtColor(sample_image_np, cv2.COLOR_RGB2BGR)
    for i in range(num_queries):
        mask = masks[i]
        color = [(np.random.randint(255), np.random.randint(255), np.random.randint(255))]
        # [H, W] -> [H, W, 3]
        mask = np.repeat(mask[..., None], 3, axis=-1)
        mask_rgb = mask * color * 0.6
        inv_alph_mask = (1 - mask * 0.6)
        masked_image_np = (masked_image_np * inv_alph_mask + mask_rgb).astype(np.uint8)

    if args.show:
        cv2.imshow("masked image", masked_image_np)
        cv2.waitKey(0)

    # save the results
    os.makedirs(args.output, exist_ok=True)
    cv2.imwrite(os.path.join(args.output, "result.png"), masked_image_np)


if __name__ == "__main__":
    args = parser.parse_args()
    np.random.seed(12)

    main(args)
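
For reference, a hypothetical invocation of the script above (the script name, checkpoint path, and registry key are placeholders; substitute the ones from your own project):

python segment_with_boxes.py \
    --input data/images/ex1.jpg \
    --output outputs/efficient_sam/ \
    --model-type <your-registry-key> \
    --checkpoint weights/efficient_sam.pt \
    --device cpu \
    --show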
