In [1]:
import os
import torch
import numpy as np
from tqdm import tqdm
from PIL import Image
from repvit_sam import SamPredictor, sam_model_registry

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the RepViT-SAM weights and the registry key of the model variant.
sam_checkpoint = "/home/minhnh/project_drive/CV/FewshotObjectDetection/VoxDet-simplified/repvit_sam.pt"
model_type = "repvit"
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (slowly) on a machine without a visible GPU instead of raising.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build the model from the checkpoint, move it to the chosen device and switch
# it to evaluation mode; it is only used for inference in this notebook.
repvit_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
repvit_sam = repvit_sam.to(device=device)
repvit_sam.eval()

def _mask_to_box(mask):
    """Return ``[x_min, y_min, x_max, y_max]`` enclosing the non-zero pixels of a 2-D mask.

    An all-zero mask has no extent, so a small placeholder box (centre +/- 2
    pixels) in the middle of the mask is returned instead.
    """
    mask = np.asarray(mask)
    mask_h, mask_w = mask.shape
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        center_x, center_y = mask_w // 2, mask_h // 2
        return [center_x - 2, center_y - 2, center_x + 2, center_y + 2]
    return [xs.min(), ys.min(), xs.max(), ys.max()]


def rpn_sam(filepath, points_per_side=25):
    """Generate box proposals for one image by prompting SAM on a regular point grid.

    A ``points_per_side`` x ``points_per_side`` grid of foreground point
    prompts is laid over the image; each point yields one mask from the
    module-level ``repvit_sam`` model, and each mask is reduced to its
    bounding box.

    Args:
        filepath: Path of the image file to process.
        points_per_side: Number of grid points along each image axis
            (default 25, i.e. 625 prompts per image).

    Returns:
        A float ``torch.Tensor`` of shape ``(points_per_side ** 2, 4)`` holding
        ``[x_min, y_min, x_max, y_max]`` per prompt, with inclusive max
        coordinates, expressed in the pixel grid of the masks returned by the
        predictor.
    """
    # Close the file handle promptly and force a 3-channel layout: grayscale,
    # palette or RGBA files would otherwise reach the predictor with the wrong
    # number of channels.
    with Image.open(filepath) as image:
        nd_image = np.array(image.convert('RGB'))
    img_h, img_w = nd_image.shape[:2]

    # Valid pixel coordinates run from 0 to size - 1; ending the grid at
    # ``size`` would place the last row/column of prompts outside the image.
    xvalues = np.linspace(0, img_w - 1, points_per_side, dtype='int')
    yvalues = np.linspace(0, img_h - 1, points_per_side, dtype='int')
    xx, yy = np.meshgrid(xvalues, yvalues)
    positions = np.column_stack([xx.ravel(), yy.ravel()]).astype(int)

    predictor = SamPredictor(repvit_sam)
    predictor.set_image(nd_image)
    point_label = np.array([1])  # label 1 marks the prompt as a foreground point
    # Shape (num_points, 1, 2): every iteration passes a single (1, 2) prompt.
    prompt_points = np.expand_dims(positions, 1)

    boxes = []
    for point in prompt_points:
        masks, _scores, _logits = predictor.predict(
            point_coords=point,
            point_labels=point_label,
            multimask_output=False,
        )
        # multimask_output=False: only the first returned mask is used.
        boxes.append(_mask_to_box(masks[0]))

    # reshape keeps the (N, 4) layout even when there are no proposals at all.
    box_array = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
    return torch.from_numpy(box_array)

In [3]:
# Dataset split to process; images are read from <prefix>images/ and one
# proposal array per image is written to <prefix>proposals/.
prefix = '/home/minhnh/project_drive/CV/FewshotObjectDetection/data/OWID//P2/'
imgs_base_path = prefix + 'images/'
proposals_base_path = prefix + 'proposals/'
# np.save does not create missing directories, so make sure the target exists.
os.makedirs(proposals_base_path, exist_ok=True)
# Sorted listing so that [start_idx:end_idx] selects the same files on every run.
img_filepath = sorted(os.listdir(imgs_base_path))
start_idx = 0
end_idx = 5000
# Inference only: disable autograd once for the whole run.
with torch.no_grad():
    for img_filename in tqdm(img_filepath[start_idx:end_idx]):
        # Strip only the final extension so names containing extra dots
        # (e.g. "a.b.jpg") keep a unique id instead of collapsing to "a".
        img_id = os.path.splitext(img_filename)[0]

        output = rpn_sam(os.path.join(imgs_base_path, img_filename))
        np.save(os.path.join(proposals_base_path, f'{img_id}.npy'), output.cpu().numpy())

  0%|          | 7/5000 [00:29<5:55:59,  4.28s/it]


KeyboardInterrupt: 