-
Notifications
You must be signed in to change notification settings - Fork 146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Does the EfficientSAM model not support bounding-box prompts as input? #40
Comments
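@duxuan11 Dear friend, although the official code does not provide an example of box prompts, you can pass each bounding box to EfficientSAM as a pair of corner-point prompts, as in the following script:
import cv2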
from torchvision import transforms
import torch
import numpy as np
import argparse
import os
from models.build_efficient_sam import efficient_sam_model_registry
def _build_parser():
    """Create the command-line parser used by this demo script.

    Options, in registration order: four required string options (--input,
    --output, --model-type, --checkpoint), --device (default "cuda"), and two
    boolean flags (--convert-to-rle, --show).

    NOTE(review): the description and several help texts appear to be carried
    over from a SAM automatic-mask-generation script; in this file main()
    reads a fixed sample image and does not consume --input, --output,
    --device or --convert-to-rle — confirm before relying on them.
    """
    cli = argparse.ArgumentParser(
        description=(
            "Runs automatic mask generation on an input image or directory of images, "
            "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
            "as well as pycocotools if saving in RLE format."
        ),
    )

    # Mandatory string-valued options.
    required_options = (
        ("--input", "Path to either a single input image or folder of images."),
        (
            "--output",
            "Path to the directory where masks will be output. Output will be either a folder "
            "of PNGs per image or a single json with COCO-style masks.",
        ),
        ("--model-type", "The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']"),
        ("--checkpoint", "The path to the SAM checkpoint to use for mask generation."),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument("--device", type=str, default="cuda",
                     help="The device to run generation on.")

    # Boolean switches: present -> True, absent -> False.
    switches = (
        (
            "--convert-to-rle",
            "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
            "Requires pycocotools.",
        ),
        ("--show", "To show the segmentation results on the input image."),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli


parser = _build_parser()
# Hard-coded sample input/output used by this demo; the boxes in main() were
# chosen for this particular image.
_SAMPLE_IMAGE_PATH = "data/images/ex1.jpg"
_OUTPUT_DIR = "outputs/efficient_sam/"
# Opacity used when blending each mask color onto the image.
_MASK_ALPHA = 0.6


def _boxes_to_corner_prompts(bboxes):
    """Convert bounding boxes into EfficientSAM corner-point prompts.

    Each box ``[x0, y0, x1, y1]`` is split into two points, ``(x0, y0)`` and
    ``(x1, y1)``, labelled 2 (top-left) and 3 (bottom-right).

    Args:
        bboxes: non-empty sequence of ``num_queries`` boxes, 4 numbers each.

    Returns:
        ``(points, labels)``: points shaped ``[1, num_queries, 2, 2]`` and
        labels shaped ``[1, num_queries, 2]`` (batch size 1).

    Raises:
        ValueError: if ``bboxes`` is empty.
    """
    num_queries = len(bboxes)
    if num_queries == 0:
        raise ValueError("at least one bounding box is required")
    # [num_queries, 4] -> [1, num_queries, 2 points, 2 coords]
    points = torch.as_tensor(bboxes).unsqueeze(0).view(-1, num_queries, 2, 2)
    # One (top-left, bottom-right) label pair per query; repeat() makes a real copy.
    labels = torch.tensor([2, 3])[None, None].repeat(1, num_queries, 1)
    return points, labels


def _select_best_masks(predicted_logits, predicted_iou):
    """Return, for each query, the binary mask with the highest predicted IoU.

    Assumes ``predicted_logits`` is ``[bs, num_queries, num_candidates, H, W]``
    and ``predicted_iou`` is ``[bs, num_queries, num_candidates]``; only batch
    element 0 is returned, as a boolean array ``[num_queries, H, W]``.
    """
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    # Reorder candidates by IoU so index 0 along dim 2 is the best one.
    ranked_logits = torch.take_along_dim(
        predicted_logits, sorted_ids[..., None, None], dim=2
    )
    masks = torch.ge(ranked_logits, 0).cpu().detach().numpy()
    return masks[0, :, 0, :, :]


def _overlay_masks(image_bgr, masks, alpha=_MASK_ALPHA):
    """Alpha-blend each boolean ``[H, W]`` mask onto ``image_bgr`` in a random color.

    Colors come from ``np.random``, so seeding it makes the output reproducible.
    Returns a ``uint8`` image with the same shape as ``image_bgr``.
    """
    blended = image_bgr
    for mask in masks:
        color = np.array(
            [np.random.randint(255), np.random.randint(255), np.random.randint(255)]
        )
        mask3 = np.repeat(mask[..., None], 3, axis=-1)  # [H, W] -> [H, W, 3]
        blended = (blended * (1 - mask3 * alpha) + mask3 * color * alpha).astype(np.uint8)
    return blended


def main(args):
    """Segment the sample image with two box prompts and save a mask overlay.

    Args:
        args: parsed command-line namespace; this function reads
            ``model_type``, ``checkpoint`` and ``show``.

    NOTE(review): ``args.input``, ``args.output``, ``args.device`` and
    ``args.convert_to_rle`` are not used; the image path, boxes and output
    path are fixed for the bundled sample image, and everything runs on the
    device the model registry places the model on.

    Raises:
        FileNotFoundError: if the sample image cannot be read.
        OSError: if the result image cannot be written.
    """
    # Build the EfficientSAM model.
    model = efficient_sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    model.eval()  # assumes a torch nn.Module; idempotent if already in eval mode

    image_bgr = cv2.imread(_SAMPLE_IMAGE_PATH)
    if image_bgr is None:  # cv2.imread reports failure by returning None
        raise FileNotFoundError(f"Could not read image: {_SAMPLE_IMAGE_PATH}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    image_tensor = transforms.ToTensor()(image_rgb)

    # Bounding boxes for the sample image, as [x0, y0, x1, y1].
    bboxes = [[85.7600, 196.6265, 469.7600, 543.6144],
              [236.8000, 82.8916, 325.1200, 441.4458]]
    input_points, input_labels = _boxes_to_corner_prompts(bboxes)

    print(f"Running inference using EfficientSAM ({args.model_type})")
    with torch.no_grad():
        predicted_logits, predicted_iou = model(
            image_tensor[None, ...],
            input_points,
            input_labels,
        )
    masks = _select_best_masks(predicted_logits, predicted_iou)  # [num_queries, H, W]

    # Always build the overlay so the result can be saved with or without --show.
    masked_image_np = _overlay_masks(image_bgr, masks)
    if args.show:
        cv2.imshow("masked image", masked_image_np)
        cv2.waitKey(0)

    # save the results
    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    output_path = os.path.join(_OUTPUT_DIR, "result.png")
    if not cv2.imwrite(output_path, masked_image_np):
        raise OSError(f"Could not write result image: {output_path}")
if __name__ == "__main__":
    args = parser.parse_args()
    # Fixed seed so the random mask colors in the overlay are reproducible.
    np.random.seed(12)
    main(args)
Does the EfficientSAM model not support multiple bounding boxes as input?
The text was updated successfully, but these errors were encountered: