In [1]:
import torch
import cv2
import matplotlib.pyplot as plt
import numpy as np
import json
import sys
import os
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "segment-anything/sam_vit_h_4b8939.pth"
model_type = "vit_h"
sp_path = "voc_points"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to("cuda")

predictor = SamPredictor(sam)

In [2]:
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([2 / 255, 2 / 255, 2 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)



In [3]:
curr_dir = os.getcwd()
voc_root = os.path.join(curr_dir, "VOCtrainval_11-May-2012", "VOCdevkit", "VOC2012")

if not os.path.isdir(voc_root):
    raise RuntimeError("VOC Directory is wrong")

splits_dir = os.path.join(voc_root, "ImageSets", "Segmentation")
splits_f = os.path.join(splits_dir, "{}.txt".format("train"))

with open(os.path.join(splits_f)) as f:
    file_names = [x.strip() for x in f.readlines()]

img_dir = os.path.join(voc_root, "JPEGImages")

In [4]:
for x in file_names:
  point_file = "voc_points/{}_points.json".format(x)
  image_loc = os.path.join(img_dir, "{}.jpg".format(x))
  image = cv2.imread(image_loc)
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

  predictor.set_image(image)

  with open(point_file, "r") as input_file:
    whole_input = json.load(input_file)

  if whole_input == [] or len(whole_input) == 0:
    print("input is empty")
    continue

  semantic_mask = []

  for object in whole_input:
    sampled_points = object["sampled_points"]
    cat_id = object["cat_id"]
    print(cat_id)

    s_pts = np.array(sampled_points, dtype=int)
    label = np.ones(len(s_pts), dtype=int)

    masks, scores, logits = predictor.predict(
            point_coords=s_pts,
            point_labels=label,
            multimask_output=True,
        )
    
    mask_list = list(zip(masks, scores))

    best_mask = None

    if max(scores) < 0.95:
            best_mask = masks[np.where(scores == max(scores))[0][0]] * cat_id
    else:
        best_area = 0
        for idx in range(len(mask_list)):
            if mask_list[idx][1] >= 0.95:
                area = mask_list[idx][0].sum()
                if area > best_area:
                    best_mask = masks[idx] * cat_id
                    best_area = area

    semantic_mask.append(best_mask)

    if semantic_mask == []:
        print("no masks")
        continue

    flattened_mask = semantic_mask[0]

    for sm in range(len(semantic_mask) - 1):
        next_layer = semantic_mask[sm + 1]
        flattened_mask[next_layer != 0] = next_layer[next_layer != 0]

    print(flattened_mask)

    final_mask_img = Image.fromarray(flattened_mask.astype("uint8"), mode='P')
    final_mask_img.save("voc_pseudomasks_SAM/{}.png".format(x))



    

1
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
1
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
15
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
15
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
20
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
12
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
9
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
3
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
20
[[0 0 0 ... 0 0 0