In [81]:
import numpy as np
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor
import plotly.express as px
from plotly.subplots import make_subplots
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator


In [82]:
# https://stackoverflow.com/questions/33987060/python-context-manager-that-measures-time
from time import perf_counter

class catchtime:

    def __init__(self, name=None):
        self.name = name

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, type, value, traceback):
        self.time = perf_counter() - self.start
        if self.name:
            self.readout = f'Time ({self.name}): {self.time:.3f} seconds'
        else:
            self.readout = f'Time: {self.time:.3f} seconds'
        print(self.readout)

In [84]:
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

IMAGE_PATH = "./a2d2-preview/camera_lidar/20180810_150607/camera/cam_front_center/20180810150607_camera_frontcenter_000000083.png"
IMAGE_PROMPTS = ["car", "sidewalk", "tree"]

image = imread(IMAGE_PATH)


fig = make_subplots(rows=1, cols=3)
with catchtime("Load Image"):
    fig.add_trace(px.imshow(image, title="Original Image").data[0], row=1, col=1)

with catchtime("Manual Points"):
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        predictor.set_image(image)
        masks, _, _ = predictor.predict([[331, 804],[1075, 758],[1567.0, 832.0]], [1, 1, 1], multimask_output=True)

    mask_overlay = np.zeros_like(image)
    for mask in masks:
        mask_overlay[mask > 0.0] = True
fig.add_trace(px.imshow(mask_overlay, title="Manual Points").data[0], row=1, col=2)

with catchtime("Automatic Mask Generation"):
    generator = SAM2AutomaticMaskGenerator(predictor.model)
    masks = generator.generate(image)
    mask_overlay = np.zeros_like(image)
    for mask in masks:
        mask_overlay[mask["segmentation"] > 0.0] = True
fig.add_trace(px.imshow(mask_overlay, title="Automatic Points").data[0], row=1, col=3)
fig.show()

Time (Load Image): 0.191 seconds
Time (Manual Points): 0.388 seconds
Time (Automatic Mask Generation): 22.979 seconds


- Time (Load Image): 0.186 seconds
- Time (Manual Points): 0.438 seconds
- Time (Automatic Mask Generation): 23.076 seconds

Conclusion: Bounding boxes from an external model is necessary for performance