In [None]:
import io
from PIL import Image
from datasets import load_dataset
from transformers import SamProcessor, SamModel

def filter_category(data):
    """Tell whether an example contains a dog or a giraffe annotation.

    The category ids are read from ``data["objects"]["category"]``; the
    example is kept when either wanted id is present.
    """
    # 16 = dog
    # 23 = giraffe
    categories = data["objects"]["category"]
    return any(wanted in categories for wanted in (16, 23))

def convert_image(data):
    """Open the encoded bytes stored under ``data["image"]["bytes"]`` with PIL.

    Returns:
        A one-key dict ``{"img": image}`` holding the opened PIL image.
    """
    raw_bytes = data["image"]["bytes"]
    return {"img": Image.open(io.BytesIO(raw_bytes))}

# Checkpoint identifier shared by the processor and the model below (and reused
# later in the notebook when building the mask-generation pipeline).
model_name = "facebook/sam-vit-base"
processor = SamProcessor.from_pretrained(model_name) 
model = SamModel.from_pretrained(model_name)

# Keep only the validation examples accepted by filter_category, then replace
# the "image" column with the "img" column produced by convert_image.
dataset = load_dataset("s076923/coco-val")
filtered_dataset = dataset["validation"].filter(filter_category)
converted_dataset = filtered_dataset.map(convert_image, remove_columns=["image"])

In [None]:
import numpy as np
from matplotlib import pyplot as plt


def show_point_box(image, input_points, input_labels, input_boxes=None, marker_size=375):
    """Show ``image`` with point prompts and optional box prompts drawn on top.

    Points labelled 1 are drawn as green stars, points labelled 0 as red
    stars, and every box as a green outline.

    Args:
        image: Image to display; passed straight to ``imshow``.
        input_points: Sequence of ``[x, y]`` coordinates with shape ``(N, 2)``.
            May be empty when only boxes should be drawn.
        input_labels: One label per point. Either a flat sequence
            (``[1, 0]``) or a nested one (``[[1, 0]]``); for nested input only
            the first row is used.
        input_boxes: Optional iterable of ``[x0, y0, x1, y1]`` boxes.
        marker_size: Marker size forwarded to ``scatter`` as ``s``.
    """
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)

    points = np.asarray(input_points)
    labels = np.asarray(input_labels)

    # An empty point list becomes a (0,) array; give it an explicit (0, 2)
    # shape so the column slicing below still works (box-only display).
    if points.size == 0:
        points = points.reshape(0, 2)

    # Nested labels ([[1, 0]]) carry one extra leading axis; flat labels
    # ([1, 0]) are already one entry per point and must not be indexed again.
    if labels.ndim > 1:
        labels = labels[0]

    for label_value, color in ((1, "green"), (0, "red")):
        selected = points[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=color,
            marker="*",
            s=marker_size,
            edgecolor="white",
            linewidth=1.25
        )

    if input_boxes is not None:
        for box in input_boxes:
            x0, y0 = box[0], box[1]
            w, h = box[2] - box[0], box[3] - box[1]
            ax.add_patch(
                plt.Rectangle(
                    (x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2
                )
            )

    ax.axis("on")
    plt.show()


# First example of the converted dataset, prompted with a single point at
# (x=250, y=200) whose label is 1 (drawn as a positive point by show_point_box).
image = converted_dataset[0]["img"]
input_points = [[[250, 200]]]
input_labels = [[[1]]]

# The plotting helper receives the lists with the outermost level removed;
# the processor receives the fully nested lists.
show_point_box(image, input_points[0], input_labels[0])
inputs = processor(
    image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
)

# Inspect the tensors returned by the processor.
print("input_points shape :", inputs["input_points"].shape)
print("input_points :", inputs["input_points"])
print("input_labels shape :", inputs["input_labels"].shape)
print("input_labels :", inputs["input_labels"])
print("pixel_values shape :", inputs["pixel_values"].shape)
print("pixel_values :", inputs["pixel_values"])

In [None]:
import torch


def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: Array-like whose last two dimensions are taken as height and
            width; it is reshaped to ``(height, width, 1)`` before colouring.
        ax: Object exposing ``imshow`` (e.g. a matplotlib ``Axes``).
        random_color: When true, the RGB part of the overlay colour is drawn
            from ``np.random.random``; otherwise a fixed blue is used. The
            alpha channel is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    # Multiplying the (H, W, 1) mask by the (1, 1, 4) colour gives an
    # (H, W, 4) image that is zero wherever the mask is zero.
    overlay = mask.reshape(height, width, 1) * rgba[None, None, :]
    ax.imshow(overlay)


def show_masks_on_image(raw_image, masks, scores):
    """Plot every predicted mask over ``raw_image``, one subplot per mask.

    Each subplot is titled with the mask's 1-based index and its score.

    Args:
        raw_image: Image convertible with ``np.array`` (e.g. a PIL image).
        masks: Tensor whose last two dimensions are height and width; all
            leading dimensions are flattened into a single "prediction" axis.
        scores: Tensor holding one score per mask; flattened the same way, so
            its elements must be in the same order as the flattened masks.

    Leading singleton dimensions are handled for any number of predictions,
    including exactly one (where ``squeeze()`` would collapse the prediction
    axis itself and ``plt.subplots`` would return a bare ``Axes``).
    """
    height, width = masks.shape[-2:]
    masks = masks.reshape(-1, height, width)
    scores = scores.reshape(-1)

    nb_predictions = scores.shape[0]
    # squeeze=False always returns a 2-D array of Axes, so indexing stays
    # valid even when there is only one subplot.
    _, axes = plt.subplots(1, nb_predictions, figsize=(30, 15), squeeze=False)

    # Convert once; the same background image is shown in every subplot.
    background = np.array(raw_image)

    for i, (mask, score, ax) in enumerate(zip(masks, scores, axes[0])):
        mask = mask.cpu().detach()
        ax.imshow(background)
        show_mask(mask, ax)
        ax.title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
        ax.axis("off")
    plt.show()


# Inference only: put the model in eval mode and run it without gradient tracking.
model.eval()
with torch.no_grad():
    outputs = model(**inputs)

# NOTE(review): post_process_masks presumably rescales the predicted masks to
# the original image size using the two size tensors — confirm against the
# transformers documentation.
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)

# Only one image was passed to the processor, so only masks[0] is plotted.
show_masks_on_image(image, masks[0], outputs.iou_scores)
print("iou_scores shape :", outputs.iou_scores.shape)
print("iou_scores :", outputs.iou_scores)
print("pred_masks shape :", outputs.pred_masks.shape)
print("pred_masks :", outputs.pred_masks)

In [None]:
# Combined prompt: the point (250, 200) with label 0, the point (15, 50) with
# label 1, and one box given as [x0, y0, x1, y1].
input_points = [[[250, 200], [15, 50]]]
input_labels = [[[0, 1]]]
input_boxes = [[[100, 100, 400, 600]]]

# The plotting helper receives the lists with the outermost level removed;
# the processor receives the fully nested lists.
show_point_box(image, input_points[0], input_labels[0], input_boxes[0])
inputs = processor(
    image,
    input_points=input_points,
    input_labels=input_labels,
    input_boxes=input_boxes,
    return_tensors="pt"
)

# Same inference and post-processing steps as for the single-point prompt.
model.eval()
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)

show_masks_on_image(image, masks[0], outputs.iou_scores)

In [None]:
from transformers import pipeline

# Build a "mask-generation" pipeline from the same checkpoint and run it on
# the current image without passing any point or box prompt.
generator = pipeline("mask-generation", model=model_name)
outputs = generator(image, points_per_batch=32)

# Overlay every returned mask on the image, each in its own random colour.
plt.imshow(np.array(image))
ax = plt.gca()
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()

# Report how many masks and how many scores the pipeline returned.
print("outputs mask의 개수 :", len(outputs["masks"]))
print("outputs scores의 개수 :", len(outputs["scores"]))

In [None]:
# Zero-shot object detector whose boxes are used as SAM box prompts below.
detector = pipeline(
    model="google/owlv2-base-patch16", task="zero-shot-object-detection"
)

# Note: this rebinds the notebook-level `image` to a different example (index 24).
image = converted_dataset[24]["img"]
labels = ["dog", "giraffe"]
# NOTE(review): threshold=0.5 is presumably a minimum detection score — confirm
# against the pipeline documentation.
results = detector(image, candidate_labels=labels, threshold=0.5)

# Turn each detection's box dict into an [xmin, ymin, xmax, ymax] list.
input_boxes = []
for result in results:
    input_boxes.append(
        [
            result["box"]["xmin"],
            result["box"]["ymin"],
            result["box"]["xmax"],
            result["box"]["ymax"]
        ]
    )
    print(result)

# All boxes belong to the single image, so they are wrapped in one outer list.
inputs = processor(image, input_boxes=[input_boxes], return_tensors="pt")

model.eval()
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)

plt.imshow(np.array(image))
ax = plt.gca()

# For each paired entry of masks[0] and iou_scores[0], draw only the candidate
# mask whose IoU score is highest.
for mask, iou in zip(masks[0], outputs.iou_scores[0]):
    max_iou_idx = torch.argmax(iou)
    best_mask = mask[max_iou_idx]
    show_mask(best_mask, ax=ax, random_color=True)

plt.axis("off")
plt.show()