In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from transformers import pipeline

from morphology_utils import smooth_binary_mask
from multi_tab_bokeh import MultiToolVisualizer
from sam_utils import (
    show_masks_on_image,
    get_mask_and_scores_from_sam_input,
    sample_points_from_box,
    show_mask
)

In [None]:
# Checkpoint used for both the prompt-driven model and the automatic mask generator.
model_name = "facebook/sam-vit-huge"
# Pick the device once: use the GPU when available, otherwise fall back to the
# CPU so the notebook still runs (slowly) instead of failing on `.to('cuda')`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model + processor used for the prompt-based (box / tap) predictions.
sam_model = SamModel.from_pretrained(model_name).to(device)
sam_processor = SamProcessor.from_pretrained(model_name)
# Pipeline that proposes masks for the whole image without any prompt.
# NOTE(review): passing the checkpoint name makes the pipeline load its own copy
# of the weights in addition to `sam_model` — confirm the memory cost is acceptable.
mask_generator = pipeline("mask-generation", model=model_name, device=device)

In [None]:
# The test image is resolved relative to the current working directory.
image_path = os.path.join(os.getcwd(), "assets", "images", "Lenna_(test_image).png")
# Open inside a context manager so the file handle is released right away;
# `convert` returns an in-memory image that stays valid after the file is closed.
with Image.open(image_path) as image_file:
    raw_image = image_file.convert("RGB")

# Send the pre-processed tensors to whatever device the model lives on rather
# than a hard-coded device name, so inputs and weights can never disagree.
inputs = sam_processor(raw_image, return_tensors="pt").to(sam_model.device)
# Computed once here and passed to the prompt-based cells further down.
image_embeddings = sam_model.get_image_embeddings(inputs["pixel_values"])
# Prompt-free mask proposals for the whole image; the "masks" entry is what the
# interactive viewer and the mask-merging cell below consume.
automatic_mask_outputs = mask_generator(raw_image, points_per_batch=256, pred_iou_thresh=0.8)

### automatic masks

In [None]:
# Build the interactive viewer from the image and the automatically generated
# masks, then start serving it (port 5003 is passed to the visualizer).
multi_tab_visualizer = MultiToolVisualizer(
    raw_image,
    automatic_mask_outputs["masks"],
    resize=1,
    port=5003,
)
multi_tab_visualizer.serve()

### outputs
- Make your selection(s) in the browser, then come back here and run the cells below to see the outputs.

### hover output

In [None]:
# Indices of the masks currently pinned in the browser viewer.
ind = list(multi_tab_visualizer.mask_visualizer.fixed_mask_indices)
if not ind:
    # Fail with an actionable message instead of a bare IndexError on ind[0].
    raise ValueError(
        "No mask is selected yet: pick at least one mask in the browser, "
        "then re-run this cell."
    )

# Union of all selected masks, built in a NEW array. An in-place `|=` on the
# first selected mask would silently modify `automatic_mask_outputs["masks"]`,
# so re-running this cell with another selection would start from a corrupted mask.
# NOTE(review): assumes the masks are same-shaped boolean arrays — confirm.
mask = np.logical_or.reduce([automatic_mask_outputs["masks"][_ind] for _ind in ind])

# Clean up the merged mask with the project's morphology helper.
mask = smooth_binary_mask(mask, dilate=3, fill_holes=True, erode=6,
                          kernel_size=1, smooth=True, closing=True)

# Overlay the merged mask on the original image.
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.imshow(np.array(raw_image))
show_mask(mask, ax)
ax.axis("off")
plt.subplots_adjust(wspace=0.05, hspace=0.05)
plt.tight_layout()
plt.show()


### click & draw a rectangular box

In [None]:
# Read the rectangle(s) drawn in the browser and derive point prompts from them
# (num_samples_per_box=10 is forwarded to the sampling helper).
drawn_boxes = multi_tab_visualizer.box_editor.get_box_coordinates(as_list=True)
box_prompt_points = sample_points_from_box(drawn_boxes, num_samples_per_box=10)

# Run SAM on the cached image embeddings with the sampled points as the prompt.
masks, scores = get_mask_and_scores_from_sam_input(
    sam_model,
    sam_processor,
    raw_image,
    image_embeddings,
    input_points=box_prompt_points,
)

show_masks_on_image(raw_image, masks[0], scores)

In [None]:
# Read the points tapped in the browser; they are wrapped in one extra list
# before being handed to the SAM helper as the point prompt.
tapped_points = multi_tab_visualizer.tap_selector.get_tap_coordinates(as_list=True)

# Run SAM on the cached image embeddings with the tapped points as the prompt.
masks, scores = get_mask_and_scores_from_sam_input(
    sam_model,
    sam_processor,
    raw_image,
    image_embeddings,
    input_points=[tapped_points],
)

show_masks_on_image(raw_image, masks[0], scores)