# Segment Anything experiment

Recreating the `ritm_predict` predictions using SAM (Segment Anything Model).

In [1]:
import os
import sys
import torch
import torchvision
import cv2
import numpy as np
from modules.dataset import ImageAnnotations

import subprocess

SAM_REPO_URL = 'https://github.com/facebookresearch/segment-anything.git'
SAM_REPO_DIR = 'segment-anything'

# Clone the SAM repository once and enter it. The guards keep this cell
# re-runnable: a second run must neither re-clone into an existing directory
# nor chdir into a nested segment-anything/segment-anything.
if os.path.basename(os.getcwd()) != SAM_REPO_DIR:
    if not os.path.isdir(SAM_REPO_DIR):
        subprocess.run(['git', 'clone', SAM_REPO_URL], check=True)
    os.chdir(SAM_REPO_DIR)

# Relative sys.path entries resolve against the current working directory, so
# after the chdir above '..' points back at the project root. This keeps the
# project's own packages (e.g. `modules`) importable; `segment_anything` is
# expected to be importable from the working directory (the cloned repo).
if '..' not in sys.path:
    sys.path.insert(0, '..')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [2]:
def sam_input_points_from_annotations(annotation: "ImageAnnotations", bbox_id: int, include_aux=True, coord_bb=True):
    """Collect SAM prompt points and labels for one bounding box.

    Args:
        annotation: object exposing ``box_annotations`` (dicts with ``'id'`` and
            ``'points_inside'``) and ``point_annotations`` (dicts with ``'id'``,
            ``'is_aux'``, ``'is_positive'``, ``'coord_bb'`` and ``'coords'``).
        bbox_id: index into ``annotation.box_annotations``; the first box whose
            ``'id'`` equals that entry's ``'id'`` is used.
        include_aux: also emit auxiliary points (``is_aux`` true).
        coord_bb: read coordinates from ``'coord_bb'`` instead of ``'coords'``.

    Returns:
        ``(points, labels)`` numpy arrays. ``points`` holds one flat
        ``[coord[1], coord[0]]`` pair per emitted point (the stored pair is
        swapped, presumably (row, col) -> SAM's (x, y); TODO confirm) and
        ``labels`` holds each point's ``is_positive`` value. Points keep the
        order of the box's ``points_inside`` list.

    Raises:
        KeyError: if an id in ``points_inside`` has no matching point annotation.
    """
    target_id = annotation.box_annotations[bbox_id]['id']
    bbox = next(box for box in annotation.box_annotations if box['id'] == target_id)

    # Index points by id once; setdefault keeps the first occurrence, matching
    # what a linear first-match search would return.
    points_by_id = {}
    for point in annotation.point_annotations:
        points_by_id.setdefault(point['id'], point)

    coord_key = 'coord_bb' if coord_bb else 'coords'
    input_points, input_labels = [], []
    for point_id in bbox['points_inside']:
        point = points_by_id[point_id]
        if point['is_aux'] and not include_aux:
            continue
        coord = point[coord_key]
        # Every point, auxiliary or not, contributes one flat [x, y] pair so the
        # result is a proper (N, 2) array (aux points were previously nested).
        input_points.append([coord[1], coord[0]])
        input_labels.append(point['is_positive'])
    return np.array(input_points), np.array(input_labels)

In [7]:
def get_sam_prediction_point(sam_predictor, annotations: "ImageAnnotations",
                             jsons_folder='../datasets/supervisely_annotations/',
                             wgisd_folder='../datasets/wgisd_annotations/',
                             include_aux=True, wgisd_n_points=3,
                             images_folder='../datasets/images/'):
    """Run point-prompted SAM on every bounding box of one image.

    Loads the image and its annotations (``read_supervisely`` for ``DSC*``
    images, ``read_wgisd`` for ``SYH*``/``CSV*`` images). For each box it then
    sets ``annotations.bb_image_2_np(bb)`` (presumably the box crop) as the
    predictor image and prompts it with the points belonging to that box.

    Args:
        sam_predictor: ``SamPredictor``-like object (``set_image``/``predict``).
        annotations: image annotations; image and annotations are loaded into it.
        jsons_folder: folder holding ``<image_name>.json`` Supervisely files.
        wgisd_folder: folder holding WGISD ``.npz`` annotation files.
        include_aux: also prompt with auxiliary points.
        wgisd_n_points: selects the ``-<n>point``/``-<n>points`` suffix given to
            ``read_wgisd`` as ``excel_sufx``.
        images_folder: folder holding the image files.

    Returns:
        ``(masks_list, scores_list, logits_list)`` with one entry per box, in
        ``annotations.box_annotations`` order.
    """
    image_path = images_folder + annotations.image_name
    annotations.load_image(image_path)
    print(image_path)

    if annotations.image_name.startswith('DSC'):
        annotations.read_supervisely(jsons_folder + annotations.image_name + '.json')
    elif annotations.image_name.startswith(('SYH', 'CSV')):
        sufx = f'-{wgisd_n_points}point' if wgisd_n_points == 1 else f'-{wgisd_n_points}points'
        annotations.read_wgisd(wgisd_folder + annotations.image_name.replace('.jpg', '.npz'), excel_sufx=sufx)

    masks_list, scores_list, logits_list = [], [], []
    for box_idx, bb in enumerate(annotations.box_annotations):
        print(bb)
        # Each box gets its own image, so set_image must be called per box.
        crop = annotations.bb_image_2_np(bb)
        sam_predictor.set_image(crop)
        input_points, input_labels = sam_input_points_from_annotations(annotations, box_idx, include_aux=include_aux)
        masks, scores, logits = sam_predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=False)
        masks_list.append(masks)
        scores_list.append(scores)
        logits_list.append(logits)
    return masks_list, scores_list, logits_list

def get_sam_prediction_bbox(sam_predictor, annotations: "ImageAnnotations",
                            jsons_folder='../datasets/supervisely_annotations/',
                            wgisd_folder='../datasets/wgisd_annotations/',
                            bounding_boxes_only=True,
                            images_folder='../datasets/images/'):
    """Run box-prompted SAM on the whole image, optionally adding one point per box.

    Loads the image and its annotations (``read_supervisely`` for ``DSC*``
    images, ``read_wgisd`` for ``SYH*``/``CSV*`` images). Then it sends all
    boxes to ``sam_predictor.predict_torch`` as one batch.

    Args:
        sam_predictor: ``SamPredictor``-like object (``device``, ``transform``,
            ``set_image``, ``predict_torch``).
        annotations: image annotations; image and annotations are loaded into it.
        jsons_folder: folder holding ``<image_name>.json`` Supervisely files.
        wgisd_folder: folder holding WGISD ``.npz`` annotation files.
        bounding_boxes_only: if False, the non-auxiliary points are added as
            prompts, paired with the boxes by list order (one point per box).
        images_folder: folder holding the image files.

    Returns:
        ``(masks, scores, logits)`` exactly as returned by ``predict_torch``.

    Raises:
        ValueError: if points are requested and the number of non-auxiliary
            points differs from the number of boxes.
    """
    image_path = images_folder + annotations.image_name
    annotations.load_image(image_path)
    print(image_path)

    if annotations.image_name.startswith('DSC'):
        annotations.read_supervisely(jsons_folder + annotations.image_name + '.json')
    elif annotations.image_name.startswith(('SYH', 'CSV')):
        annotations.read_wgisd(wgisd_folder + annotations.image_name.replace('.jpg', '.npz'))

    image_size = annotations.image.shape[:2]
    target_device = sam_predictor.device

    # Each box's 'coords' is a pair of corners; flatten to one [x1, y1, x2, y2] row.
    converted_boxes = []
    for bb in annotations.box_annotations:
        (x1, y1), (x2, y2) = bb['coords']
        converted_boxes.append([x1, y1, x2, y2])
    input_boxes = torch.tensor(converted_boxes, device=target_device)
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(input_boxes, image_size)

    if bounding_boxes_only:
        transformed_coords, transformed_labels = None, None
    else:
        prompt_points = [point for point in annotations.point_annotations if not point['is_aux']]
        # predict_torch pairs point and box prompts by batch index, so the counts must match.
        # NOTE(review): pairing is purely by list order; assumes one non-aux point per box,
        # stored in box order. Confirm against the annotation readers.
        if len(prompt_points) != len(converted_boxes):
            raise ValueError(
                f'{annotations.image_name}: {len(prompt_points)} non-auxiliary points '
                f'but {len(converted_boxes)} boxes; expected exactly one point per box.')
        # Coordinates (B, 1, 2) with the stored pair swapped; labels (B, 1) as SAM expects.
        input_coords = torch.tensor(
            [[[point['coord'][1], point['coord'][0]]] for point in prompt_points], device=target_device)
        transformed_coords = sam_predictor.transform.apply_coords_torch(input_coords, image_size)
        transformed_labels = torch.tensor(
            [[int(point['is_positive'])] for point in prompt_points], dtype=torch.int, device=target_device)

    sam_predictor.set_image(annotations.image)
    masks, scores, logits = sam_predictor.predict_torch(
        point_coords=transformed_coords,
        point_labels=transformed_labels,
        boxes=transformed_boxes,
        multimask_output=False)

    return masks, scores, logits

In [4]:
from segment_anything import sam_model_registry, SamPredictor

# ViT-H SAM backbone and its checkpoint (path relative to the cloned repo,
# which is the working directory after the setup cell).
model_type = "vit_h"
sam_checkpoint = "weights/sam_vit_h_4b8939.pth"

# Build the model from the checkpoint and move it to the selected device;
# Module.to returns the module itself, so the result can be bound directly.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)

# Predictor wrapper used by the prediction helpers.
predictor = SamPredictor(sam)

In [8]:
## Time to iterate
import os
import sys

def save_sam_predictions(imgs, output_folder, sam_predictor, prediction_name, include_aux=True,
                         jsons_folder='../datasets/supervisely_annotations/',
                         wgisd_folder='../datasets/wgisd_annotations/', wgisd_n_points=None,
                         bounding_boxes=False, bounding_boxes_only=True):
    """Run SAM on each image and save its masks to ``<output_folder>/<image>_<prediction_name>.npz``.

    Args:
        imgs: list of image names.
        output_folder: folder to save the predictions in.
        sam_predictor: ``SamPredictor``-like object.
        prediction_name: key used in ``anot.image_pred`` and in the output file name.
        include_aux: include auxiliary points in point-prompted predictions.
        jsons_folder: folder with Supervisely annotations.
        wgisd_folder: folder with WGISD annotations.
        wgisd_n_points: number of points to use for WGISD annotations; ``None``
            keeps ``get_sam_prediction_point``'s own default.
        bounding_boxes: prompt SAM with the bounding boxes instead of per-box points.
        bounding_boxes_only: only used when ``bounding_boxes`` is True; if False,
            one point per box is added to the box prompts.

    Each saved file contains ``'full'`` (uint8 mask for the whole image) and
    ``'bbs'`` (per-box uint8 masks). ``'bbs'`` is a stacked array in box mode
    and a dict keyed by box index in point mode. The dict is pickled, so load it
    with ``np.load(..., allow_pickle=True)``.
    """
    for img in imgs:
        anot = ImageAnnotations(img)
        print('Processing', anot.image_name)
        if bounding_boxes:
            masks, scores, logits = get_sam_prediction_bbox(
                sam_predictor, anot, jsons_folder, wgisd_folder, bounding_boxes_only)
            # predict_torch with multimask_output=False gives (B, 1, H, W); drop the mask axis.
            per_box = masks.squeeze(1)
            # Union of all box masks: elementwise max over the box axis.
            combined_mask = torch.max(per_box, dim=0)[0]
            anot.image_pred[prediction_name] = combined_mask.cpu().numpy().astype(np.uint8)
            arrays = {'full': anot.image_pred[prediction_name],
                      'bbs': per_box.cpu().numpy().astype(np.uint8)}
        else:
            # Forward the annotation folders too; previously they were silently ignored here.
            point_kwargs = {'include_aux': include_aux,
                            'jsons_folder': jsons_folder,
                            'wgisd_folder': wgisd_folder}
            # Forward wgisd_n_points only when given: None would otherwise become
            # a bogus '-Nonepoints' suffix when reading WGISD annotations.
            if wgisd_n_points is not None:
                point_kwargs['wgisd_n_points'] = wgisd_n_points
            masks, scores, logits = get_sam_prediction_point(sam_predictor, anot, **point_kwargs)
            anot.reconstruct_prediction_mask(masks, prediction_name)
            arrays = {'full': anot.image_pred[prediction_name].astype(np.uint8),
                      'bbs': {i: pred[0, :, :].astype(np.uint8) for i, pred in enumerate(masks)}}

        # No allow_pickle kwarg: older NumPy's savez_compressed stores every keyword
        # as an array, so it would add a spurious 'allow_pickle' entry to the file.
        out_path = os.path.join(output_folder, f'{anot.image_name}_{prediction_name}.npz')
        np.savez_compressed(out_path, **arrays)
        print(f'{anot.image_name} annotation saved.')

In [16]:
from copy import copy
import time

# os.listdir order is arbitrary; sort so runs process images in a reproducible order.
all_jpgs = sorted(os.listdir('../datasets/images/'))
wgisd_csv = [img for img in all_jpgs if img.startswith('CSV')]
wgisd_syh = [img for img in all_jpgs if img.startswith('SYH')]
supervisely = [img for img in all_jpgs if img.startswith('DSC')]

start = time.time()

# save_sam_predictions(wgisd_csv, '../datasets/sam_predictions/', predictor, f'sam_vit_h_3_points', include_aux=True, wgisd_n_points=3)
# print('Tempo 3 pontos:', time.time() - start)


# save_sam_predictions(supervisely, '../datasets/sam_predictions/', predictor, f'sam_vit_h_bbs', bounding_boxes=True)
# print('Tempo bbs:', time.time() - start)

# One call per image. The shared predictor is reused directly: a shallow copy
# would share the same model anyway, and set_image replaces the cached image
# state on every call.
for img in supervisely:
    save_sam_predictions([img], '../datasets/sam_predictions/', predictor, 'sam_vit_h_bbs_point',
                         bounding_boxes=True, bounding_boxes_only=False)
# save_sam_predictions(supervisely, '../datasets/sam_predictions/', predictor, f'sam_vit_h_bbs_point', bounding_boxes=True, bounding_boxes_only=False)
print('Tempo bbs e um ponto:', time.time() - start)  # elapsed seconds for boxes + one point

Processing DSC_0109.JPG
../datasets/images/DSC_0109.JPG
DSC_0109.JPG annotation saved.
Processing DSC_0110.JPG
../datasets/images/DSC_0110.JPG
DSC_0110.JPG annotation saved.
Processing DSC_0111.JPG
../datasets/images/DSC_0111.JPG
DSC_0111.JPG annotation saved.
Processing DSC_0112.JPG
../datasets/images/DSC_0112.JPG
DSC_0112.JPG annotation saved.
Processing DSC_0113.JPG
../datasets/images/DSC_0113.JPG
DSC_0113.JPG annotation saved.
Processing DSC_0114.JPG
../datasets/images/DSC_0114.JPG
DSC_0114.JPG annotation saved.
Processing DSC_0115.JPG
../datasets/images/DSC_0115.JPG
DSC_0115.JPG annotation saved.
Processing DSC_0120.JPG
../datasets/images/DSC_0120.JPG
DSC_0120.JPG annotation saved.
Processing DSC_0122.JPG
../datasets/images/DSC_0122.JPG
DSC_0122.JPG annotation saved.
Processing DSC_0124.JPG
../datasets/images/DSC_0124.JPG
DSC_0124.JPG annotation saved.
Processing DSC_0126.JPG
../datasets/images/DSC_0126.JPG
DSC_0126.JPG annotation saved.
Processing DSC_0129.JPG
../datasets/images/

In [15]:
import time

# Time this cell on its own. Notebook cells can run out of order, so reusing
# `start` from another cell would report a misleading elapsed time.
syh_start = time.time()
for img in wgisd_syh:
    save_sam_predictions([img], '../datasets/sam_predictions/', predictor, 'sam_vit_h_bbs_point',
                         bounding_boxes=True, bounding_boxes_only=False)
# save_sam_predictions(supervisely[:1], '../datasets/sam_predictions/', predictor, f'sam_vit_h_bbs_point', bounding_boxes=True, bounding_boxes_only=False)

# Earlier cumulative measurement: up to image 16 -> 26.95 min + 24.72 min = 51.67 min
print('Tempo bbs:', time.time() - syh_start)

Processing SYH_2017-04-27_1233.jpg
../datasets/images/SYH_2017-04-27_1233.jpg
WGisd annotations loaded.
SYH_2017-04-27_1233.jpg annotation saved.
Processing SYH_2017-04-27_1236.jpg
../datasets/images/SYH_2017-04-27_1236.jpg
WGisd annotations loaded.
SYH_2017-04-27_1236.jpg annotation saved.
Processing SYH_2017-04-27_1238.jpg
../datasets/images/SYH_2017-04-27_1238.jpg
WGisd annotations loaded.
SYH_2017-04-27_1238.jpg annotation saved.
Processing SYH_2017-04-27_1239.jpg
../datasets/images/SYH_2017-04-27_1239.jpg
WGisd annotations loaded.
SYH_2017-04-27_1239.jpg annotation saved.
Processing SYH_2017-04-27_1241.jpg
../datasets/images/SYH_2017-04-27_1241.jpg
WGisd annotations loaded.
SYH_2017-04-27_1241.jpg annotation saved.
Processing SYH_2017-04-27_1251.jpg
../datasets/images/SYH_2017-04-27_1251.jpg
WGisd annotations loaded.
SYH_2017-04-27_1251.jpg annotation saved.
Processing SYH_2017-04-27_1253.jpg
../datasets/images/SYH_2017-04-27_1253.jpg
WGisd annotations loaded.
SYH_2017-04-27_1253.