In [1]:
import torch
import cv2
import matplotlib.pyplot as plt
import numpy as np
import json
import sys
import os
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

# --- Configuration -----------------------------------------------------------
# Path to the SAM weights file and the registry key of the matching architecture.
sam_checkpoint = "segment-anything/sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Directory holding the per-image "<id>_points.json" prompt files.
sp_path = "voc_points"

# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing in `sam.to(...)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Model -------------------------------------------------------------------
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)

# The predictor wraps the model and caches the embedding of the image most
# recently passed to `predictor.set_image`.
predictor = SamPredictor(sam)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def show_mask(mask, ax, random_color=False):
    """Draw a semi-transparent overlay of `mask` on a matplotlib-like axis.

    Args:
        mask: Array whose last two dimensions are (height, width); it must
            contain exactly height * width elements.
        ax: Object exposing `imshow` (e.g. a matplotlib Axes).
        random_color: When True, draw the RGB colour from `np.random.random`;
            otherwise use the fixed colour (2/255, 2/255, 2/255).

    The alpha channel is always 0.6. Nothing is returned.
    """
    alpha = 0.6
    rgb = np.random.random(3) if random_color else np.full(3, 2 / 255)
    rgba = np.append(rgb, alpha)

    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to get an RGBA image that is
    # zero (fully transparent) wherever the mask is zero.
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)



In [3]:
# Locate the Pascal VOC 2012 tree relative to the current working directory.
curr_dir = os.getcwd()
voc_root = os.path.join(curr_dir, "VOCtrainval_11-May-2012", "VOCdevkit", "VOC2012")

# Fail early if the dataset has not been extracted where we expect it.
if not os.path.isdir(voc_root):
    raise RuntimeError("VOC Directory is wrong")

# The segmentation split file lists one image id per line.
splits_dir = os.path.join(voc_root, "ImageSets", "Segmentation")
splits_f = os.path.join(splits_dir, "{}.txt".format("train"))

with open(splits_f) as f:
    # Skip blank lines: an empty id would later resolve to ".jpg" and make
    # image loading fail.
    file_names = [line.strip() for line in f if line.strip()]

img_dir = os.path.join(voc_root, "JPEGImages")

In [4]:
# Score at or above which a SAM candidate mask is treated as confident.
SCORE_THRESHOLD = 0.95
# Directory that receives one "<image id>.png" label map per image.
OUTPUT_DIR = "voc_pseudomasks_Gdino_SAM"


def _select_best_mask(masks, scores, threshold=SCORE_THRESHOLD):
    """Choose one mask out of the candidates returned for a single prompt.

    If no candidate scores at least `threshold`, the highest-scoring candidate
    is returned. Otherwise the candidate with the largest area among those at
    or above the threshold is returned (ties go to the first one). A mask is
    always returned, even when every confident candidate is empty.
    """
    scores = np.asarray(scores)
    confident = np.flatnonzero(scores >= threshold)
    if confident.size == 0:
        return masks[int(np.argmax(scores))]
    areas = [masks[i].sum() for i in confident]
    return masks[confident[int(np.argmax(areas))]]


def _flatten_layers(layers):
    """Merge per-object label maps into one; later layers overwrite earlier ones."""
    flattened = layers[0]
    for next_layer in layers[1:]:
        covered = next_layer != 0
        flattened[covered] = next_layer[covered]
    return flattened


os.makedirs(OUTPUT_DIR, exist_ok=True)

for x in file_names:
    # Read the prompts first so images without any prompt are skipped before
    # the expensive image embedding is computed.
    point_file = os.path.join(sp_path, "{}_points.json".format(x))
    with open(point_file, "r") as input_file:
        whole_input = json.load(input_file)

    if not whole_input:
        print("input is empty")
        continue

    image_loc = os.path.join(img_dir, "{}.jpg".format(x))
    image = cv2.imread(image_loc)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError("Could not read image: {}".format(image_loc))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    predictor.set_image(image)

    # One label map per prompted object: 0 is background, cat_id elsewhere.
    semantic_mask = []

    for obj in whole_input:
        cat_id = obj["cat_id"]
        print(cat_id)

        s_pts = np.array(obj["sampled_points"], dtype=int)
        box = np.array(obj["box"], dtype=int) if "box" in obj else None

        has_points = len(s_pts) > 0
        if not has_points and box is None:
            print("object has no prompts")
            continue

        # Every sampled point is a foreground click (label 1). An empty point
        # list is passed as None so that a box-only prompt is still usable.
        masks, scores, _logits = predictor.predict(
            point_coords=s_pts if has_points else None,
            point_labels=np.ones(len(s_pts), dtype=int) if has_points else None,
            box=box,
            multimask_output=True,
        )

        semantic_mask.append(_select_best_mask(masks, scores) * cat_id)

    # Flatten and save once per image, after all of its objects are processed.
    if not semantic_mask:
        print("no masks")
        continue

    flattened_mask = _flatten_layers(semantic_mask)

    final_mask_img = Image.fromarray(flattened_mask.astype("uint8"), mode='P')
    final_mask_img.save(os.path.join(OUTPUT_DIR, "{}.png".format(x)))
    

1
1
15
15
20
12
9
3
20
20
5
5
5
5
15
15
4
4
4
1
11
5
1
19
3
3
14
14
15
13
15
15
15
15
1
10
15
15
15
2
15
7
7
8
8
2
18
3
3
15
15
15
15
6
1
7
4
4
12
15
14
1
6
6
15
15
15
15
15
15
15
15
2
6
14
14
15
13
8
8
10
15
13
9
9
9
9
18
2
20
10
10
10
10
10
10
18
20
16
8
15
11
5
12
12
9
15
15
12
17
17
17
17
17
17
13
15
16
11
9
4
6
6
5
5
11
9
9
9
9
4
4
20
15
15
14
13
10
10
10
10
10
10
8
12
11
7
15
15
15
15
15
15
17
9
9
9
11
18
10
10
10
13
6
6
6
15
12
15
18
10
1
14
15
1
3
15
15
15
15
1
3
6
6
20
2
4
4
13
15
7
15
15
15
16
15
18
18
9
9
11
15
15
7
7
15
3
4
19
15
14
15
5
18
18
15
12
15
15
15
11
15
15
15
15
16
10
10
15
10
8
8
15
10
10
10
10
10
10
10
10
7
7
7
7
17
17
17
17
17
17
15
15
3
11
15
5
5
5
20
20
15
15
16
16
16
1
3
3
3
15
15
19
13
15
16
16
16
17
17
17
17
15
15
9
5
5
5
5
5
5
5
5
5
5
11
9
9
9
9
3
19
3
3
15
5
5
18
9
20
5
8
11
15
15
15
9
3
15
17
17
12
20
11
6
8
16
16
16
8
9
7
15
15
15
15
1
13
13
13
15
15
15
15
9
4
15
14
1
6
15
15
15
10
16
20
9
9
9
9
18
11
12
20
15
15
15
5
15
15
17
17
17
15
15
15
1
15
15
1