
## Setup

### Download packages and models

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install ipympl --quiet
%pip install git+https://github.com/facebookresearch/segment-anything.git --quiet

# -nc (no-clobber) skips a checkpoint that is already on disk, so re-running this
# cell neither re-downloads the large files nor creates "*.pth.1" duplicates.
!wget -q -nc https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
!wget -q -nc https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
# Sample image; -O overwrites "cat.jpg" on every run.
!wget -q https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/320px-Felis_catus-cat_on_snow.jpg -O "cat.jpg"

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m511.6/511.6 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m53.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for segment-anything (setup.py) ... [?25l[?25hdone


### SAM initialization

In [None]:
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

device = 'cuda' # @param ["cuda", "cpu"]
model_type = "vit_b" # @param ["vit_b", "vit_h"]

# Moving the model to 'cuda' fails on a runtime without a GPU, so fall back to
# the CPU (slower, but functional) and tell the user about it.
if device == "cuda" and not torch.cuda.is_available():
    print("CUDA is not available; running SAM on the CPU instead.")
    device = "cpu"

# Checkpoint file (downloaded in the setup cell) matching the selected model type.
sam_checkpoint = {"vit_b": "sam_vit_b_01ec64.pth",
                  "vit_h": "sam_vit_h_4b8939.pth"}[model_type]

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Both helpers wrap the same model instance: the predictor answers point
# prompts, the mask generator is created for automatic segmentation.
sam_predictor = SamPredictor(sam)
sam_mask_generator = SamAutomaticMaskGenerator(sam)

### Utils

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width). It is
            reshaped to (height, width, 1), so it must contain exactly
            height * width elements (e.g. shape (1, H, W) or (H, W)).
        ax: Object providing ``imshow`` (a Matplotlib axes in this notebook).
        random_color: When true, use a random RGB colour drawn from NumPy's
            global random state; otherwise use a fixed blue. Alpha is 0.6 in
            both cases.

    Returns:
        Whatever ``ax.imshow`` returns for the overlay image.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcasting (H, W, 1) against (1, 1, 4) paints every set pixel with the
    # colour and leaves unset pixels fully transparent (all zeros).
    overlay = mask.reshape(height, width, 1) * rgba[None, None, :]
    return ax.imshow(overlay)

def show_points(coords, labels, ax):
    """Scatter prompt points on ``ax``, coloured by their label.

    Args:
        coords: Array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array aligned with ``coords``; entries equal to 1 are drawn in
            blue and entries equal to 0 in red. Other values are not drawn.
        ax: Object providing ``scatter`` (a Matplotlib axes in this notebook).
    """
    # Label-1 points are drawn first, then label-0 points.
    for label_value, point_color in ((1, 'blue'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=point_color, marker='.')

def float2uint8(x: np.ndarray):
    """Scale an array by 255 and convert it to ``uint8``.

    Values are multiplied by 255, clipped to the range 0..255 and then cast
    (the cast truncates the fractional part). Clipping prevents inputs outside
    [0, 1] from overflowing or wrapping around in the ``uint8`` cast; inputs
    already inside [0, 1] produce the same result as a plain scale-and-cast.

    Args:
        x: Float or boolean array (anything ``np.asarray`` accepts).

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as ``x``.
    """
    scaled = np.asarray(x) * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)

## Annotation tool

* Upload a file by dragging it into the Files pane
* Enter the file name in the input box and run the cell below
* Left-click to add an area to the mask
* Right-click to remove an area from the mask
* The segmentation mask is saved as `mask_[original_file_name].png` each time the segmentation changes
* Re-run the cell to reset

In [None]:
%matplotlib ipympl
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.backend_bases import MouseButton
from google.colab import output

# Colab needs this enabled before it will render the interactive ipympl canvas.
output.enable_custom_widget_manager()

input_file = "cat.jpg" # @param {type:"string"}

# cv2.imread returns None (it does not raise) when the file is missing or
# cannot be decoded; fail here with a clear message instead of a cryptic
# cvtColor error.
image_bgr = cv2.imread(input_file)
if image_bgr is None:
    raise FileNotFoundError(f"Could not read image file: {input_file!r}")
# OpenCV loads images as BGR; convert to RGB for the predictor and for display.
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

sam_predictor.set_image(image)

fig, ax = plt.subplots(figsize=(12, 10))
ax.imshow(image)
ax.axis("off")

# Click history, reset every time this cell is re-run.
pos_points = []
neg_points = []

def onclick(event):
    """Handle a mouse click on the figure: record a prompt point and re-segment.

    A left click adds a positive point and a right click adds a negative
    point. The predictor is then run on all points collected so far, the image
    is redrawn with the points and the mask overlay, and the mask is written to
    ``mask_<input file stem>.png`` next to the input file. PNG is lossless, so
    the saved mask is not degraded by compression.

    Clicks outside the axes (no data coordinates) and clicks with any other
    mouse button are ignored.

    Reads the notebook globals ``ax``, ``image``, ``sam_predictor`` and
    ``input_file``; appends to ``pos_points`` / ``neg_points``.

    Args:
        event: Matplotlib mouse event passed by ``button_press_event``.
    """
    import os  # local import keeps this cell self-contained

    # xdata/ydata are None when the click lands outside the axes.
    if event.xdata is None or event.ydata is None:
        return

    if event.button == MouseButton.LEFT:
        clicked = pos_points
    elif event.button == MouseButton.RIGHT:
        clicked = neg_points
    else:
        # Middle/extra buttons add no point; skip the needless re-prediction
        # (and a failure when no point has been recorded yet).
        return
    clicked.append((event.xdata, event.ydata))

    # Start from a clean image before drawing the updated overlay.
    ax.cla()
    ax.imshow(image)
    ax.axis("off")

    input_point = np.array(neg_points + pos_points)
    input_label = np.array([0]*len(neg_points) + [1]*len(pos_points))
    sam_mask, _, _ = sam_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False
    )
    show_points(input_point, input_label, ax)
    show_mask(sam_mask, ax)

    # Build "mask_<stem>.png" in the same directory as the input file.
    folder, file_name = os.path.split(input_file)
    stem = os.path.splitext(file_name)[0]
    mask_path = os.path.join(folder, "mask_" + stem + ".png")
    cv2.imwrite(mask_path, cv2.cvtColor(float2uint8(sam_mask[0, ...]), cv2.COLOR_GRAY2BGR))

    # Ask for a repaint so the new overlay shows up even if interactive
    # auto-redraw is off.
    event.canvas.draw_idle()

# Run onclick on every mouse-button press in the figure. The returned
# connection id is kept in `cid` so the handler can be detached later with
# fig.canvas.mpl_disconnect(cid).
cid = fig.canvas.mpl_connect('button_press_event', onclick)