In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2


def show_mask(mask, ax, random_color=False):
    """Overlay a segmentation mask on ``ax`` as a semi-transparent RGBA image.

    Args:
        mask: array whose last two dimensions are (height, width); it must
            contain exactly height * width elements so that it can be
            reshaped to (height, width, 1).
        ax: object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: when True the RGB part of the overlay color is drawn
            from ``np.random.random``; otherwise a fixed blue is used. The
            alpha channel is 0.6 in both cases.
    """
    alpha = np.array([0.6])
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, alpha], axis=0)

    height, width = mask.shape[-2:]
    # (H, W, 1) * (1, 1, 4) broadcasts to (H, W, 4); pixels where the mask is
    # zero/False become all-zero, i.e. fully transparent.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
    
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: positives in green, negatives in red.

    Args:
        coords: array-like of shape (N, 2) holding one (x, y) pair per point.
        labels: array-like with one label per point, 1 for a positive point
            and 0 for a negative point. Any layout with N elements is
            accepted (e.g. (N,) or a column vector (N, 1)); it is flattened
            before being used to select rows of ``coords``.
        ax: object exposing ``scatter`` (e.g. a matplotlib Axes).
        marker_size: marker area forwarded to ``scatter`` as ``s``.
    """
    coords = np.asarray(coords)
    # A 2-D boolean index of shape (N, 1) does not match coords' (N, 2) shape
    # and raises IndexError, so reduce the labels to one entry per row first.
    labels = np.asarray(labels).reshape(-1)
    # Positive points are drawn first, then negative points.
    for label_value, color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=color, marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    
def show_box(box, ax):
    """Draw a box given as (x0, y0, x1, y1) on ``ax`` as a green outline.

    Args:
        box: indexable whose first four entries are the corner coordinates
            x0, y0, x1, y1.
        ax: object exposing ``add_patch`` (e.g. a matplotlib Axes).
    """
    x_start, y_start, x_end, y_end = box[0], box[1], box[2], box[3]
    # A face color with zero alpha leaves the interior transparent, so only
    # the 2-pt green edge is visible.
    outline = plt.Rectangle(
        (x_start, y_start),
        x_end - x_start,
        y_end - y_start,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)

In [None]:
import os
import sys

# NOTE(review): absolute, machine-specific path; adjust it for your checkout.
# dirname() is applied twice below, so BASE_DIR is two directory levels above
# this location.
inference_ipynb_path='/root/code/SimpleAICV_pytorch_training_examples/13.interactive_segmentation_training/sam_predict_example'
BASE_DIR = os.path.dirname(os.path.dirname(inference_ipynb_path))
# Presumably makes the `SimpleAICV` package imported later in this cell
# resolvable -- confirm against the repository layout.
sys.path.append(BASE_DIR)

# Set before torch is imported further down in this cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torch.nn as nn
import torch.nn.functional as F

from SimpleAICV.interactive_segmentation.models.segment_anything.sam import sam_b
from SimpleAICV.interactive_segmentation.common import load_state_dict


# NOTE(review): absolute, machine-specific checkpoint path; adjust it for your environment.
sam_checkpoint = '/root/autodl-tmp/pretrained_models/sam_segmentation_train_on_interactive_segmentation_dataset/sam_b_multilevel_epoch_2.pth'

# Build the sam_b model, move it to the GPU and switch it to eval mode.
sam_model = sam_b()
sam_model = sam_model.cuda()
sam_model = sam_model.eval()

# `load_state_dict` is the project helper imported above; its exact behavior
# (e.g. how missing or mismatched keys are handled) is not visible here.
load_state_dict(sam_checkpoint,sam_model)

In [None]:
# Read the raw file bytes, decode them with OpenCV as a color image, then
# convert the channel order from BGR to RGB before displaying it.
test_image_path = '/root/code/SimpleAICV_pytorch_training_examples/13.interactive_segmentation_training/sam_predict_example/test_images/truck.jpg'
encoded_bytes = np.fromfile(test_image_path, dtype=np.uint8)
bgr_image = cv2.imdecode(encoded_bytes, cv2.IMREAD_COLOR)
origin_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
print(origin_image.shape, origin_image.dtype)

plt.figure(figsize=(10, 10))
plt.imshow(origin_image)
plt.axis('on')
plt.show()

In [None]:
# Scale the image so that its longer side becomes 1024 pixels while keeping
# the aspect ratio; `factor` is reused later to scale prompt coordinates.
origin_h, origin_w = origin_image.shape[:2]
factor = 1024 / max(origin_h, origin_w)
resize_h = int(round(origin_h * factor))
resize_w = int(round(origin_w * factor))
resized_image = cv2.resize(origin_image, (resize_w, resize_h))
print(resized_image.shape, resized_image.dtype, factor)

plt.figure(figsize=(10, 10))
plt.imshow(resized_image)
plt.axis('on')
plt.show()

# Per-channel mean / std used to normalize the resized image.
mean = [123.675, 116.28, 103.53]
std = [58.395, 57.12, 57.375]
norm_image = (resized_image - mean) / std
print(norm_image.shape, np.max(norm_image), np.min(norm_image))

# Place the normalized image in the top-left corner of a zero-filled square
# canvas whose side equals the longer resized side.
padded_side = max(resize_h, resize_w)
padded_img = np.zeros((padded_side, padded_side, 3), dtype=np.float32)
padded_img[:resize_h, :resize_w, :] = norm_image
print(padded_img.shape, np.max(padded_img), np.min(padded_img))

# (H, W, C) numpy array -> (1, C, H, W) float tensor on the GPU.
padded_img = torch.tensor(padded_img).float().cuda().permute(2, 0, 1).unsqueeze(0)
print(padded_img.shape, torch.max(padded_img), torch.min(padded_img))

In [None]:
# One click; the coordinates are multiplied by `factor` so they line up with
# resized_image.
input_point = np.array([[500, 375]])*factor
# Label 1 marks a positive point, label 0 a negative point.
input_label = np.array([[1]])
print(input_point.shape,input_label.shape)

# Each prompt row is (x, y, label).
prompt_point=np.concatenate([input_point,input_label],axis=1)
print(prompt_point.shape)

plt.figure(figsize=(10,10))
plt.imshow(resized_image)
# Pass the label column as a 1-D array with one entry per point.
# `input_label[0]` would select the first *row*, which only has the right
# length when there is exactly one point.
show_points(input_point, np.squeeze(input_label,axis=1), plt.gca())
plt.axis('on')
plt.show()

In [None]:
# Add a leading batch dimension to the (x, y, label) prompt rows.
input_prompt_point = torch.tensor(np.expand_dims(prompt_point,axis=0)).float().cuda()
print(padded_img.shape,input_prompt_point.shape,padded_img.dtype,input_prompt_point.dtype)

batch_images = padded_img.clone()
batch_prompts = {'prompt_point':input_prompt_point,
                'prompt_box':None,
                'prompt_mask':None}

with torch.no_grad():
    mask_preds, iou_preds = sam_model(batch_images, batch_prompts, mask_out_idxs=[0, 1, 2, 3])
    # Keep the predictions of the first (only) image in the batch.
    mask_preds, iou_preds = mask_preds[0], iou_preds[0]
    # Threshold the raw mask predictions at 0 to obtain boolean masks.
    binary_mask_preds = mask_preds > 0.
print(mask_preds.shape,iou_preds.shape,iou_preds)

# One figure per candidate mask, titled with its predicted score.
for i, (mask, score) in enumerate(zip(binary_mask_preds, iou_preds)):
    mask=mask.cpu().numpy()
    score=score.cpu().numpy()
    plt.figure(figsize=(10,10))
    plt.imshow(resized_image)
    show_mask(mask, plt.gca())
    # Pass the label column (one entry per point); `input_label[0]` would
    # select the first row and only matches for a single point.
    show_points(input_point, np.squeeze(input_label,axis=1), plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
# Two positive clicks, multiplied by `factor` so they line up with resized_image.
input_point = np.array([[500, 375], [1125, 625]])*factor
input_label = np.array([[1], [1]])
print(input_point.shape,input_label.shape)

# Each prompt row is (x, y, label).
prompt_point = np.concatenate([input_point,input_label],axis=1)
print(prompt_point.shape)

# Choose the model's best mask: the candidate with the highest predicted
# score, selected by argmax rather than by a fixed channel index.
# NOTE: `binary_mask_preds` and `iou_preds` are the results left behind by
# the previously executed prediction cell.
best_mask_idx = int(torch.argmax(iou_preds))
prompt_mask = binary_mask_preds.clone().unsqueeze(0).float()
prompt_mask = F.interpolate(prompt_mask,(256, 256), mode = "nearest")
# Slice (instead of index) so the channel dimension is kept.
input_prompt_mask = prompt_mask[:,best_mask_idx:best_mask_idx+1,:,:]
print(input_prompt_mask.shape)

input_prompt_point = torch.tensor(np.expand_dims(prompt_point,axis=0)).float().cuda()
print(padded_img.shape,input_prompt_point.shape,padded_img.dtype,input_prompt_point.dtype)

batch_images = padded_img.clone()
batch_prompts = {'prompt_point':input_prompt_point,
                'prompt_box':None,
                'prompt_mask':input_prompt_mask}

with torch.no_grad():
    mask_preds, iou_preds = sam_model(batch_images, batch_prompts, mask_out_idxs=[0, 1, 2, 3])
    # Keep the predictions of the first (only) image in the batch.
    mask_preds, iou_preds = mask_preds[0], iou_preds[0]
    # Threshold the raw mask predictions at 0 to obtain boolean masks.
    binary_mask_preds = mask_preds > 0.
print(mask_preds.shape,iou_preds.shape,iou_preds)

# Only the first of the returned masks is displayed.
plt.figure(figsize=(10,10))
plt.imshow(resized_image)
show_mask(binary_mask_preds.cpu().numpy()[0], plt.gca())
show_points(input_point, np.squeeze(input_label,axis=1), plt.gca())
plt.axis('off')
plt.show()

In [None]:
# One positive and one negative click, multiplied by `factor` so they line up
# with resized_image.
input_point = np.array([[500, 375], [1125, 625]])*factor
input_label = np.array([[1], [0]])
print(input_point.shape,input_label.shape)

# Each prompt row is (x, y, label).
prompt_point = np.concatenate([input_point,input_label],axis=1)
print(prompt_point.shape)

# Choose the model's best mask: the candidate with the highest predicted
# score, selected by argmax rather than by a fixed channel index.
# NOTE: `binary_mask_preds` and `iou_preds` are the results left behind by
# the previously executed prediction cell.
best_mask_idx = int(torch.argmax(iou_preds))
prompt_mask = binary_mask_preds.clone().unsqueeze(0).float()
prompt_mask = F.interpolate(prompt_mask,(256, 256), mode = "nearest")
# Slice (instead of index) so the channel dimension is kept.
input_prompt_mask = prompt_mask[:,best_mask_idx:best_mask_idx+1,:,:]
print(input_prompt_mask.shape)

input_prompt_point = torch.tensor(np.expand_dims(prompt_point,axis=0)).float().cuda()
print(padded_img.shape,input_prompt_point.shape,padded_img.dtype,input_prompt_point.dtype)

batch_images = padded_img.clone()
batch_prompts = {'prompt_point':input_prompt_point,
                'prompt_box':None,
                'prompt_mask':input_prompt_mask}

with torch.no_grad():
    mask_preds, iou_preds = sam_model(batch_images, batch_prompts, mask_out_idxs=[0, 1, 2, 3])
    # Keep the predictions of the first (only) image in the batch.
    mask_preds, iou_preds = mask_preds[0], iou_preds[0]
    # Threshold the raw mask predictions at 0 to obtain boolean masks.
    binary_mask_preds = mask_preds > 0.
print(mask_preds.shape,iou_preds.shape,iou_preds)

# Only the first of the returned masks is displayed.
plt.figure(figsize=(10,10))
plt.imshow(resized_image)
show_mask(binary_mask_preds.cpu().numpy()[0], plt.gca())
show_points(input_point, np.squeeze(input_label,axis=1), plt.gca())
plt.axis('off')
plt.show()

In [None]:
# Box prompt as (x0, y0, x1, y1), multiplied by `factor` so it lines up with
# resized_image.
input_box = np.array([425, 600, 700, 875])*factor

# Add a leading batch dimension: (4,) -> (1, 4).
input_prompt_box = torch.tensor(input_box[np.newaxis, :]).float().cuda()
print(padded_img.shape, input_prompt_box.shape, padded_img.dtype, input_prompt_box.dtype)

batch_images = padded_img.clone()
batch_prompts = dict(
    prompt_point=None,
    prompt_box=input_prompt_box,
    prompt_mask=None,
)

with torch.no_grad():
    mask_preds, iou_preds = sam_model(batch_images, batch_prompts, mask_out_idxs=[0, 1, 2, 3])
    # Keep the predictions of the first (only) image in the batch.
    mask_preds = mask_preds[0]
    iou_preds = iou_preds[0]
    # Threshold the raw mask predictions at 0 to obtain boolean masks.
    binary_mask_preds = mask_preds > 0.
print(mask_preds.shape, iou_preds.shape, iou_preds)

# Only the first of the returned masks is displayed, together with the box.
plt.figure(figsize=(10, 10))
plt.imshow(resized_image)
show_mask(binary_mask_preds.cpu().numpy()[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

In [None]:
# Box prompt (x0, y0, x1, y1) combined with one negative click; both are
# multiplied by `factor` so they line up with resized_image.
input_box = np.array([425, 600, 700, 875])*factor
input_point = np.array([[575, 750]])*factor
input_label = np.array([[0]])
print(input_box.shape,input_point.shape,input_label.shape)

# Each prompt row is (x, y, label).
prompt_point = np.concatenate([input_point,input_label],axis=1)
print(prompt_point.shape)

# Add a leading batch dimension to both prompts.
input_prompt_point = torch.tensor(np.expand_dims(prompt_point,axis=0)).float().cuda()
print(padded_img.shape,input_prompt_point.shape,padded_img.dtype,input_prompt_point.dtype)

input_prompt_box = torch.tensor(np.expand_dims(input_box,axis=0)).float().cuda()
print(padded_img.shape,input_prompt_box.shape,padded_img.dtype,input_prompt_box.dtype)

batch_images = padded_img.clone()
batch_prompts = {'prompt_point':input_prompt_point,
                'prompt_box':input_prompt_box,
                'prompt_mask':None}

with torch.no_grad():
    mask_preds, iou_preds = sam_model(batch_images, batch_prompts, mask_out_idxs=[0, 1, 2, 3])
    # Keep the predictions of the first (only) image in the batch.
    mask_preds, iou_preds = mask_preds[0], iou_preds[0]
    # Threshold the raw mask predictions at 0 to obtain boolean masks.
    binary_mask_preds = mask_preds > 0.
print(mask_preds.shape,iou_preds.shape,iou_preds)

# Only the first of the returned masks is displayed.
plt.figure(figsize=(10,10))
plt.imshow(resized_image)
show_mask(binary_mask_preds.cpu().numpy()[0], plt.gca())
show_box(input_box, plt.gca())
# Pass the label column (one entry per point); `input_label[0]` would select
# the first row and only matches for a single point.
show_points(input_point, np.squeeze(input_label,axis=1), plt.gca())
plt.axis('off')
plt.show()

In [None]:
# Run the image encoder once; the embedding is reused for every box below.
per_images = padded_img.clone()

with torch.no_grad():
    per_image_embedding = sam_model.forward_image_encoder(per_images)
print(per_image_embedding.shape)

# One (x0, y0, x1, y1) box per row, multiplied by `factor` so the boxes line
# up with resized_image.
input_boxes = np.array([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
])*factor
print(input_boxes.shape)

# Decode one mask per box from the shared embedding.
batch_masks = []
for per_box in input_boxes:
    # (4,) -> (1, 4): add the leading batch dimension.
    per_prompt_box = torch.tensor(per_box[np.newaxis, :]).float().cuda()
    per_prompt = dict(
        prompt_point=None,
        prompt_box=per_prompt_box,
        prompt_mask=None,
    )
    mask_out_idxs = [0]
    with torch.no_grad():
        mask_preds, iou_preds = sam_model.forward_prompt_encoder_mask_decoder(
            per_image_embedding, per_prompt, mask_out_idxs=mask_out_idxs)
        # Keep the predictions of the first (only) image in the batch.
        mask_preds = mask_preds[0]
        iou_preds = iou_preds[0]
        binary_mask_preds = mask_preds > 0.
    print(mask_preds.shape, iou_preds.shape, iou_preds)

    batch_masks.append(binary_mask_preds)
print(len(batch_masks))

# Overlay every mask (random color each) and its box on the resized image.
plt.figure(figsize=(10, 10))
plt.imshow(resized_image)
for idx, (per_mask, per_box) in enumerate(zip(batch_masks, input_boxes)):
    per_mask = per_mask.cpu().numpy()[0]
    show_mask(per_mask, plt.gca(), random_color=True)
    show_box(per_box, plt.gca())
plt.axis('off')
plt.show()

In [None]:
# Second test image: decode it, convert BGR -> RGB and resize it so that the
# longer side becomes 1024 pixels.
test_image_path2 = '/root/code/SimpleAICV_pytorch_training_examples/13.interactive_segmentation_training/sam_predict_example/test_images/groceries.jpg'
encoded_bytes2 = np.fromfile(test_image_path2, dtype=np.uint8)
bgr_image2 = cv2.imdecode(encoded_bytes2, cv2.IMREAD_COLOR)
origin_image2 = cv2.cvtColor(bgr_image2, cv2.COLOR_BGR2RGB)
origin_h2, origin_w2 = origin_image2.shape[:2]
factor2 = 1024 / max(origin_h2, origin_w2)
resize_h2 = int(round(origin_h2 * factor2))
resize_w2 = int(round(origin_w2 * factor2))
resized_image2 = cv2.resize(origin_image2, (resize_w2, resize_h2))
print(resized_image2.shape, resized_image2.dtype, factor2)

# Per-channel mean / std used to normalize the resized image.
mean = [123.675, 116.28, 103.53]
std = [58.395, 57.12, 57.375]
norm_image2 = (resized_image2 - mean) / std
print(norm_image2.shape, np.max(norm_image2), np.min(norm_image2))

# Place the normalized image in the top-left corner of a zero-filled square
# canvas whose side equals the longer resized side.
padded_side2 = max(resize_h2, resize_w2)
padded_img2 = np.zeros((padded_side2, padded_side2, 3), dtype=np.float32)
padded_img2[:resize_h2, :resize_w2, :] = norm_image2
print(padded_img2.shape, np.max(padded_img2), np.min(padded_img2))

# (H, W, C) numpy array -> (1, C, H, W) float tensor on the GPU.
padded_img2 = torch.tensor(padded_img2).float().cuda().permute(2, 0, 1).unsqueeze(0)
print(padded_img2.shape, torch.max(padded_img2), torch.min(padded_img2))

per_images2 = padded_img2.clone()

# NOTE: this rebinds `per_image_embedding`, replacing any embedding computed
# for a previous image.
with torch.no_grad():
    per_image_embedding = sam_model.forward_image_encoder(per_images2)
print(per_image_embedding.shape)

# One (x0, y0, x1, y1) box per row, multiplied by `factor2` so the boxes line
# up with resized_image2.
input_boxes2 = np.array([
    [450, 170, 520, 350],
    [350, 190, 450, 350],
    [500, 170, 580, 350],
    [580, 170, 640, 350],
])*factor2

# Decode one mask per box from the shared embedding.
batch_masks2 = []
for per_box in input_boxes2:
    # (4,) -> (1, 4): add the leading batch dimension.
    per_prompt_box = torch.tensor(per_box[np.newaxis, :]).float().cuda()
    per_prompt = dict(
        prompt_point=None,
        prompt_box=per_prompt_box,
        prompt_mask=None,
    )
    mask_out_idxs = [0]
    with torch.no_grad():
        mask_preds, iou_preds = sam_model.forward_prompt_encoder_mask_decoder(
            per_image_embedding, per_prompt, mask_out_idxs=mask_out_idxs)
        # Keep the predictions of the first (only) image in the batch.
        mask_preds = mask_preds[0]
        iou_preds = iou_preds[0]
        binary_mask_preds = mask_preds > 0.
    print(mask_preds.shape, iou_preds.shape, iou_preds)

    batch_masks2.append(binary_mask_preds)
print(len(batch_masks2))

# Overlay every mask (random color each) and its box on the resized image.
plt.figure(figsize=(10, 10))
plt.imshow(resized_image2)
for idx, (per_mask, per_box) in enumerate(zip(batch_masks2, input_boxes2)):
    per_mask = per_mask.cpu().numpy()[0]
    show_mask(per_mask, plt.gca(), random_color=True)
    show_box(per_box, plt.gca())
plt.axis('off')
plt.show()