Loading the Segment Anything Model

In [1]:
import torch
from segment_anything import sam_model_registry

CHECKPOINT_PATH = 'sam_vit_l_0b3195.pth'
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_l"
print(DEVICE)

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam.to(device=DEVICE)

cpu


Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-23): 24 x Block(
        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1024, out_features=3072, bias=True)
          (proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=1024, out_features=4096, bias=True)
          (lin2): Linear(in_features=4096, out_features=1024, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d

Load Image

In [6]:
import cv2
from segment_anything import SamAutomaticMaskGenerator

mask_generator = SamAutomaticMaskGenerator(sam)

image_bgr = cv2.imread("dog_image.jpg")
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
result = mask_generator.generate(image_rgb)

Supervision

In [14]:
import supervision as sv
from supervision.annotators.utils import ColorLookup

mask_annotator = sv.MaskAnnotator(color_lookup = ColorLookup.INDEX)
detections = sv.Detections.from_sam(result)
annotated_image = mask_annotator.annotate(image_bgr, detections)

print(annotated_image)

[[[202 157 120]
  [202 157 120]
  [202 157 120]
  ...
  [132 112 192]
  [198 160 128]
  [198 160 128]]

 [[201 156 119]
  [201 156 119]
  [201 156 119]
  ...
  [132 112 192]
  [199 161 129]
  [198 160 128]]

 [[201 156 119]
  [201 156 119]
  [201 156 119]
  ...
  [132 113 192]
  [199 161 129]
  [199 161 129]]

 ...

 [[220 114 102]
  [230 124 112]
  [236 130 118]
  ...
  [109 124 150]
  [112 126 154]
  [106 123 150]]

 [[218 112 100]
  [232 126 114]
  [230 124 112]
  ...
  [106 121 147]
  [105 119 147]
  [105 121 150]]

 [[224 118 106]
  [216 110  98]
  [230 124 112]
  ...
  [103 118 144]
  [106 120 148]
  [118 134 163]]]


Generate Segmentation Mask with Bounding Box

In [16]:
import cv2
from segment_anything import SamPredictor
import numpy as np

mask_predictor = SamPredictor(sam)

image_bgr = cv2.imread("dog_image.jpg")
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
mask_predictor.set_image(image_rgb)

box = np.array([70, 247, 626, 926])
masks, scores, logits = mask_predictor.predict(
    box=box,
    multimask_output=True
)