In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay a mask on ``ax`` as a semi-transparent coloured layer.

    Args:
        mask: array whose last two dimensions are (height, width) and which
            holds exactly ``height * width`` elements. Each pixel's RGBA value
            is the chosen colour scaled by the mask value, so zero entries
            stay fully transparent.
        ax: object exposing ``imshow`` (a matplotlib Axes in this notebook).
        obj_id: integer object id used to pick a colour from the ``tab10``
            palette; ``None`` selects the first colour.
        random_color: when True, ignore ``obj_id`` and use a random RGB colour.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        # Wrap the id into the palette. An integer index outside [0, cmap.N)
        # is clipped by matplotlib, which would give every object id beyond
        # the palette size the same colour.
        cmap_idx = 0 if obj_id is None else obj_id % cmap.N
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    # (h, w, 1) * (1, 1, 4) broadcasts to an (h, w, 4) RGBA image.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    """Scatter prompt clicks on ``ax`` as star markers.

    Rows of ``coords`` whose label equals 1 are drawn in green, rows whose
    label equals 0 in red; column 0 is used as x and column 1 as y.
    ``labels`` must be usable as a boolean row selector for ``coords``.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    # Positive clicks are drawn first, then negative ones.
    for label_value, point_color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=point_color, **star_style)


def show_box(box, ax):
    """Draw ``box`` on ``ax`` as a green rectangle outline.

    ``box`` is indexed as (x_min, y_min, x_max, y_max); width and height are
    the differences between the max and min corners.
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # zero alpha: no fill, outline only
        lw=2,
    )
    ax.add_patch(outline)

# Load Model

In [None]:
import os
import sys

# Directory of this example; the repository root sits two levels above it and
# is put on sys.path so the `SimpleAICV` imports below can be resolved.
inference_ipynb_path='/root/code/SimpleAICV_pytorch_training_examples/14.video_interactive_segmentation_training/sam2_predict_example'
BASE_DIR = os.path.dirname(os.path.dirname(inference_ipynb_path))
# Guarded so that re-running this cell does not stack duplicate entries.
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)

# Expose only GPU 0 to this process; set here, ahead of the torch import below.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torch.nn as nn
import torch.nn.functional as F

from SimpleAICV.video_interactive_segmentation.models.segment_anything2.sam2video_test import hiera_b_plus_sam2video_test
from SimpleAICV.video_interactive_segmentation.common import load_state_dict

# Checkpoint file whose weights are loaded into the model below.
sam2_checkpoint = '/root/autodl-tmp/pretrained_models/sam2_segmentation_train_on_video_interactive_segmentation_dataset/hiera_b_plus_sam2video_multilevel_stage3_epoch_20.pth'

# Build the model, move it to the GPU and switch it to eval mode. Each call
# hands back the model, so the three steps are chained into one expression.
sam2_model = hiera_b_plus_sam2video_test().cuda().eval()

# Project helper imported above; it is called for its effect on `sam2_model`
# and its return value is not used.
load_state_dict(sam2_checkpoint, sam2_model)

In [None]:
video_dir_path = '/root/code/SimpleAICV_pytorch_training_examples/14.video_interactive_segmentation_training/sam2_predict_example/test_videos/bedroom'

# Collect the frame images in name order. The extension is matched as a
# suffix: a substring test would also accept names such as `x.jpg.bak`.
frames_name_list = sorted(
    per_frame_name for per_frame_name in os.listdir(video_dir_path)
    if per_frame_name.endswith('.jpg'))
frames_path_list = [
    os.path.join(video_dir_path, n) for n in frames_name_list
]

# Preview every `vis_frame_stride`-th frame of the clip, one figure per frame.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.set_title(f"frame: {frame_idx}")
    show_image = Image.open(frames_path_list[frame_idx])
    ax.imshow(show_image)

# Init Video State Dict

In [None]:
# Build the per-video state for the frames under `video_dir_path`; the cells
# below pass this dict to every prompt, prediction and tracking call.
video_state_dict = sam2_model.init_video_state_dict(video_dir_path=video_dir_path)

# Clear Video State Dict All Info

In [None]:
# Reset the state through the model's `clear_video_state_dict_all_info` and
# keep the dict it returns.
# NOTE(review): what exactly is cleared is defined by the model and not shown
# here — presumably everything, since the state is re-initialised next; confirm.
video_state_dict = sam2_model.clear_video_state_dict_all_info(video_state_dict)

# Init Video State Dict Again

In [None]:
# Rebuild the video state from the same frame directory after the full clear.
video_state_dict = sam2_model.init_video_state_dict(video_dir_path=video_dir_path)

# Add one object with one point prompt

In [None]:
frame_idx = 0

# Add a single positive click at (x, y) = (210, 350). For labels, `1` means positive click and `0` means negative click.
input_point = np.array([[210, 350]], dtype=np.float32)
input_label = np.array([[1]], dtype=np.int32)
print(input_point.shape, input_label.shape)

# Pack each click as one (x, y, label) float32 row, then add a leading axis of size 1.
input_prompt_point = np.concatenate([input_point, input_label], axis=1, dtype=np.float32)
input_prompt_point = np.expand_dims(input_prompt_point, axis=0)
print(input_prompt_point.shape,input_prompt_point.dtype)

exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask = sam2_model.add_new_object_prompt_input(
                                    video_state_dict,
                                    frame_idx=frame_idx,
                                    prompt_point=input_prompt_point,
                                    prompt_box=None,
                                    prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# Flatten the (N, 1) label column to (N,) so it selects rows of the (N, 2)
# point array. Indexing with `[0]` would only be right for exactly one point.
point_labels = np.squeeze(input_label, axis=1)

plt.figure(figsize=(9, 6))
plt.title(f"frame: {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
plt.imshow(show_image)
show_points(input_point, point_labels, plt.gca())

# Request the outputs at indices 0-3 and drop the leading axis of each result.
mask_preds, iou_preds = sam2_model.forward_one_image_test(video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0, 1, 2, 3])
mask_preds, iou_preds = mask_preds[0], iou_preds[0]
binary_mask_preds = mask_preds > 0.
print(binary_mask_preds.shape,iou_preds.shape,iou_preds)

# One figure per candidate mask. The frame image opened above is reused rather
# than being re-read from disk on every iteration.
for i, (mask, score) in enumerate(zip(binary_mask_preds, iou_preds)):
    mask=mask.cpu().float().numpy()
    score=score.cpu().float().numpy()
    plt.figure(figsize=(9, 6))
    plt.imshow(show_image)
    show_mask(mask, plt.gca(), obj_id=new_object_id)
    show_points(input_point, point_labels, plt.gca())
    plt.title(f"frame: {frame_idx}, Mask: {i+1},IoU Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

# Tracking one object with one point prompt

In [None]:
tracking_object_id = 0
# Run the model's tracking pass for this object from frame 0 with only the
# point prompt enabled.
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=0,
    tracking_object_ids=[tracking_object_id],
    use_point_prompt_input=True,
    use_box_prompt_input=False,
    use_mask_prompt_input=False,
)

frame_num = video_state_dict['video_frame_num']
object_track_state = video_state_dict['object_track_state'][tracking_object_id]
object_track_result = video_state_dict['object_track_result'][tracking_object_id]
print(frame_num, object_track_state)

# Show the tracked mask on every `vis_frame_stride`-th frame.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    frame_result = object_track_result[frame_idx]
    frame_mask = frame_result['pred_mask'] > 0.  # values above zero count as foreground
    frame_iou = frame_result['pred_iou']
    frame_object_score = frame_result['pred_object_score']

    fig, ax = plt.subplots(figsize=(9, 6))
    show_image = Image.open(frames_path_list[frame_idx])
    ax.imshow(show_image)
    show_mask(frame_mask, ax, obj_id=tracking_object_id)
    ax.set_title(
        f"frame: {frame_idx}, IoU Score: {frame_iou:.3f}, Object Score: {frame_object_score:.3f}",
        fontsize=18)
    ax.axis('off')
    plt.show()

# Add one object with two point prompts

In [None]:
frame_idx = 0

# Add two positive clicks, at (x, y) = (210, 350) and (250, 220). For labels, `1` means positive click and `0` means negative click.
input_point = np.array([[210, 350], [250, 220]], dtype=np.float32)
input_label = np.array([[1], [1]], dtype=np.int32)
print(input_point.shape, input_label.shape)

# Pack each click as one (x, y, label) float32 row, then add a leading axis of size 1.
input_prompt_point = np.concatenate([input_point, input_label], axis=1, dtype=np.float32)
input_prompt_point = np.expand_dims(input_prompt_point, axis=0)
print(input_prompt_point.shape,input_prompt_point.dtype)

exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask = sam2_model.add_new_object_prompt_input(
                                    video_state_dict,
                                    frame_idx=frame_idx,
                                    prompt_point=input_prompt_point,
                                    prompt_box=None,
                                    prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# (N, 1) label column flattened to (N,) so it selects rows of the point array.
point_labels = np.squeeze(input_label, axis=1)

plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
plt.imshow(show_image)
show_points(input_point, point_labels, plt.gca())

# Request only output index 0 and unwrap the two leading axes of each result.
mask_preds, iou_preds = sam2_model.forward_one_image_test(video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0])
mask_preds, iou_preds = mask_preds[0][0], iou_preds[0][0]
binary_mask_preds = mask_preds > 0.
binary_mask_preds = binary_mask_preds.cpu().float().numpy()
iou_preds = iou_preds.cpu().float().numpy()
print(binary_mask_preds.shape,iou_preds.shape,iou_preds)

plt.figure(figsize=(9, 6))
plt.imshow(show_image)
# Colour the overlay by the new object's id so it matches the colour used for
# this object in the tracking visualisation.
show_mask(binary_mask_preds, plt.gca(), obj_id=new_object_id)
show_points(input_point, point_labels, plt.gca())
plt.title(f"frame {frame_idx}, Score: {iou_preds:.3f}", fontsize=18)
plt.axis('off')
plt.show()

# Tracking one object with two point prompts

In [None]:
tracking_object_id = 1
# Run the model's tracking pass for this object from frame 0 with only the
# point prompt enabled.
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=0,
    tracking_object_ids=[tracking_object_id],
    use_point_prompt_input=True,
    use_box_prompt_input=False,
    use_mask_prompt_input=False,
)

frame_num = video_state_dict['video_frame_num']
object_track_state = video_state_dict['object_track_state'][tracking_object_id]
object_track_result = video_state_dict['object_track_result'][tracking_object_id]
print(frame_num, object_track_state)

# Show the tracked mask on every `vis_frame_stride`-th frame.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    frame_result = object_track_result[frame_idx]
    frame_mask = frame_result['pred_mask'] > 0.  # values above zero count as foreground
    frame_iou = frame_result['pred_iou']
    frame_object_score = frame_result['pred_object_score']

    fig, ax = plt.subplots(figsize=(9, 6))
    show_image = Image.open(frames_path_list[frame_idx])
    ax.imshow(show_image)
    show_mask(frame_mask, ax, obj_id=tracking_object_id)
    ax.set_title(
        f"frame: {frame_idx}, IoU Score: {frame_iou:.3f}, Object Score: {frame_object_score:.3f}",
        fontsize=18)
    ax.axis('off')
    plt.show()

# Clear Video State Dict All Object Info

In [None]:
# Reset the per-object part of the state via the model and keep the returned
# dict; the following cells register a new object and track it as id 0.
# NOTE(review): presumably the loaded frames are kept, since no re-init follows — confirm.
video_state_dict = sam2_model.clear_video_state_dict_all_object_info(video_state_dict)

# Add one object with box prompt

In [None]:
frame_idx = 0

# Box prompt in (x_min, y_min, x_max, y_max) order: (300, 0, 500, 400).
input_box = np.array([300, 0, 500, 400], dtype=np.float32)
print(input_box.shape)

# The model call below receives the box with a leading axis of size 1.
input_prompt_box = input_box[np.newaxis, :]
print(input_prompt_box.shape,input_prompt_box.dtype)

(exist_object_ids, frame_idx, new_object_id,
 has_prompt_point, has_prompt_box, has_prompt_mask) = sam2_model.add_new_object_prompt_input(
    video_state_dict,
    frame_idx=frame_idx,
    prompt_point=None,
    prompt_box=input_prompt_box,
    prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# First figure: the frame with the prompt box only.
fig, ax = plt.subplots(figsize=(9, 6))
ax.set_title(f"frame {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
ax.imshow(show_image)
show_box(input_box, ax)

# Request only output index 0 and unwrap the two leading axes of each result.
mask_preds, iou_preds = sam2_model.forward_one_image_test(video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0])
mask_preds, iou_preds = mask_preds[0][0], iou_preds[0][0]
binary_mask_preds = (mask_preds > 0.).cpu().float().numpy()
iou_preds = iou_preds.cpu().float().numpy()
print(binary_mask_preds.shape,iou_preds.shape,iou_preds)

# Second figure: the predicted mask together with the prompt box.
fig, ax = plt.subplots(figsize=(9, 6))
show_image = Image.open(frames_path_list[frame_idx])
ax.imshow(show_image)
show_mask(binary_mask_preds, ax)
show_box(input_box, ax)
ax.set_title(f"frame {frame_idx}, Score: {iou_preds:.3f}", fontsize=18)
ax.axis('off')
plt.show()

# Tracking one object with box prompt

In [None]:
tracking_object_id = 0
# Run the model's tracking pass for this object from frame 0 with only the
# box prompt enabled.
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=0,
    tracking_object_ids=[tracking_object_id],
    use_point_prompt_input=False,
    use_box_prompt_input=True,
    use_mask_prompt_input=False,
)

frame_num = video_state_dict['video_frame_num']
object_track_state = video_state_dict['object_track_state'][tracking_object_id]
object_track_result = video_state_dict['object_track_result'][tracking_object_id]
print(frame_num, object_track_state)

# Show the tracked mask on every `vis_frame_stride`-th frame.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    frame_result = object_track_result[frame_idx]
    frame_mask = frame_result['pred_mask'] > 0.  # values above zero count as foreground
    frame_iou = frame_result['pred_iou']
    frame_object_score = frame_result['pred_object_score']

    fig, ax = plt.subplots(figsize=(9, 6))
    show_image = Image.open(frames_path_list[frame_idx])
    ax.imshow(show_image)
    show_mask(frame_mask, ax, obj_id=tracking_object_id)
    ax.set_title(
        f"frame: {frame_idx}, IoU Score: {frame_iou:.3f}, Object Score: {frame_object_score:.3f}",
        fontsize=18)
    ax.axis('off')
    plt.show()

# Clear Video State Dict All Object Info

In [None]:
# Reset the per-object part of the state via the model and keep the returned
# dict; the following cells register a new object and track it as id 0.
# NOTE(review): presumably the loaded frames are kept, since no re-init follows — confirm.
video_state_dict = sam2_model.clear_video_state_dict_all_object_info(video_state_dict)

# Add one object with box prompt at an intermediate frame

In [None]:
frame_idx = 60

# Box prompt in (x_min, y_min, x_max, y_max) order: (180, 0, 450, 430).
input_box = np.array([180, 0, 450, 430], dtype=np.float32)
print(input_box.shape)

# The model call below receives the box with a leading axis of size 1.
input_prompt_box = input_box[np.newaxis, :]
print(input_prompt_box.shape,input_prompt_box.dtype)

(exist_object_ids, frame_idx, new_object_id,
 has_prompt_point, has_prompt_box, has_prompt_mask) = sam2_model.add_new_object_prompt_input(
    video_state_dict,
    frame_idx=frame_idx,
    prompt_point=None,
    prompt_box=input_prompt_box,
    prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# First figure: the frame with the prompt box only.
fig, ax = plt.subplots(figsize=(9, 6))
ax.set_title(f"frame {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
ax.imshow(show_image)
show_box(input_box, ax)

# Request only output index 0 and unwrap the two leading axes of each result.
mask_preds, iou_preds = sam2_model.forward_one_image_test(video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0])
mask_preds, iou_preds = mask_preds[0][0], iou_preds[0][0]
binary_mask_preds = (mask_preds > 0.).cpu().float().numpy()
iou_preds = iou_preds.cpu().float().numpy()
print(binary_mask_preds.shape, iou_preds.shape, iou_preds)

# Second figure: the predicted mask together with the prompt box.
fig, ax = plt.subplots(figsize=(9, 6))
show_image = Image.open(frames_path_list[frame_idx])
ax.imshow(show_image)
show_mask(binary_mask_preds, ax)
show_box(input_box, ax)
ax.set_title(f"frame {frame_idx}, Score: {iou_preds:.3f}", fontsize=18)
ax.axis('off')
plt.show()

# Tracking one object with box prompt at an intermediate frame

In [None]:
tracking_object_id = 0
# Run the model's tracking pass for this object starting at frame 60 (the
# frame that carries the box prompt) with only the box prompt enabled.
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=60,
    tracking_object_ids=[tracking_object_id],
    use_point_prompt_input=False,
    use_box_prompt_input=True,
    use_mask_prompt_input=False,
)

frame_num = video_state_dict['video_frame_num']
object_track_state = video_state_dict['object_track_state'][tracking_object_id]
object_track_result = video_state_dict['object_track_result'][tracking_object_id]
print(frame_num, object_track_state)

# Show the tracked mask on every `vis_frame_stride`-th frame.
# NOTE(review): the loop starts at frame 0 although tracking started at frame
# 60 — this assumes `object_track_result` also holds entries for frames before
# the start frame; confirm.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    frame_result = object_track_result[frame_idx]
    frame_mask = frame_result['pred_mask'] > 0.  # values above zero count as foreground
    frame_iou = frame_result['pred_iou']
    frame_object_score = frame_result['pred_object_score']

    fig, ax = plt.subplots(figsize=(9, 6))
    show_image = Image.open(frames_path_list[frame_idx])
    ax.imshow(show_image)
    show_mask(frame_mask, ax, obj_id=tracking_object_id)
    ax.set_title(
        f"frame: {frame_idx}, IoU Score: {frame_iou:.3f}, Object Score: {frame_object_score:.3f}",
        fontsize=18)
    ax.axis('off')
    plt.show()

# Clear Video State Dict All Object Info

In [None]:
# Reset the per-object part of the state via the model and keep the returned
# dict; the following cells register a new object and track it as id 0.
# NOTE(review): presumably the loaded frames are kept, since no re-init follows — confirm.
video_state_dict = sam2_model.clear_video_state_dict_all_object_info(video_state_dict)

# Add one object with box prompt at final frame

In [None]:
# Prompt on the last frame of the clip. The index is derived from the frame
# list instead of being hard-coded, so it stays the final frame (199 for a
# 200-frame clip) whatever the clip length.
frame_idx = len(frames_path_list) - 1

# Box prompt in (x_min, y_min, x_max, y_max) order: (300, 50, 470, 375).
input_box = np.array([300, 50, 470, 375], dtype=np.float32)
print(input_box.shape)

# The model call below receives the box with a leading axis of size 1.
input_prompt_box = np.expand_dims(input_box, axis=0)
print(input_prompt_box.shape,input_prompt_box.dtype)

exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask = sam2_model.add_new_object_prompt_input(
                                    video_state_dict,
                                    frame_idx=frame_idx,
                                    prompt_point=None,
                                    prompt_box=input_prompt_box,
                                    prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# First figure: the frame with the prompt box only.
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
plt.imshow(show_image)
show_box(input_box, plt.gca())

# Request only output index 0 and unwrap the two leading axes of each result.
mask_preds, iou_preds = sam2_model.forward_one_image_test(video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0])
mask_preds, iou_preds = mask_preds[0][0], iou_preds[0][0]
binary_mask_preds = mask_preds > 0.
binary_mask_preds = binary_mask_preds.cpu().float().numpy()
iou_preds = iou_preds.cpu().float().numpy()
print(binary_mask_preds.shape, iou_preds.shape, iou_preds)

# Second figure: the predicted mask together with the prompt box.
plt.figure(figsize=(9, 6))
show_image = Image.open(frames_path_list[frame_idx])
plt.imshow(show_image)
show_mask(binary_mask_preds, plt.gca())
show_box(input_box, plt.gca())
plt.title(f"frame {frame_idx}, Score: {iou_preds:.3f}", fontsize=18)
plt.axis('off')
plt.show()

# Tracking one object with box prompt at final frame

In [None]:
tracking_object_id = 0
# Start tracking at the last frame of the clip. The index is derived from the
# frame list instead of being hard-coded (199 for a 200-frame clip).
final_frame_idx = len(frames_path_list) - 1
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=final_frame_idx,
    tracking_object_ids=[tracking_object_id],
    use_point_prompt_input=False,
    use_box_prompt_input=True,
    use_mask_prompt_input=False)

frame_num = video_state_dict['video_frame_num']
object_track_state = video_state_dict['object_track_state'][tracking_object_id]
object_track_result = video_state_dict['object_track_result'][tracking_object_id]
print(frame_num, object_track_state)

# Show the tracked mask on every `vis_frame_stride`-th frame.
# NOTE(review): every frame shown here precedes the start frame — this assumes
# `object_track_result` also holds entries for frames before it; confirm.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    per_frame_object_result = object_track_result[frame_idx]
    per_frame_per_object_mask = per_frame_object_result['pred_mask']
    per_frame_per_object_iou = per_frame_object_result['pred_iou']
    per_frame_per_object_score = per_frame_object_result['pred_object_score']

    # Values above zero count as foreground.
    per_frame_per_object_mask = per_frame_per_object_mask > 0.

    plt.figure(figsize=(9, 6))
    show_image = Image.open(frames_path_list[frame_idx])
    plt.imshow(show_image)
    show_mask(per_frame_per_object_mask, plt.gca(), obj_id=tracking_object_id)
    plt.title(
        f"frame: {frame_idx}, IoU Score: {per_frame_per_object_iou:.3f}, Object Score: {per_frame_per_object_score:.3f}",
        fontsize=18)
    plt.axis('off')
    plt.show()

# Clear Video State Dict All Object Info

In [None]:
# Reset the per-object part of the state via the model and keep the returned
# dict; the following cells register a new object and track it as id 0.
# NOTE(review): presumably the loaded frames are kept, since no re-init follows — confirm.
video_state_dict = sam2_model.clear_video_state_dict_all_object_info(video_state_dict)

# Add one object with mask prompt

In [None]:
frame_idx = 0
prompt_mask_path= '/root/code/SimpleAICV_pytorch_training_examples/14.video_interactive_segmentation_training/sam2_predict_example/test_videos/bedroom/00000_prompt_mask_for_target_0.png'

# Read the prompt mask as 8-bit grayscale, divide by 255 and store as float32.
mask_gray = Image.open(prompt_mask_path).convert('L')
input_mask = (np.array(mask_gray, dtype=np.uint8) / 255.).astype(np.float32)
print(input_mask.shape)

# Preview the prompt mask on top of its frame.
fig, ax = plt.subplots(figsize=(9, 6))
ax.set_title(f"frame {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
ax.imshow(show_image)
show_mask(input_mask, ax)

# The mask is handed to the model as-is; no extra axis is added here.
input_prompt_mask = input_mask
print(input_prompt_mask.shape,input_prompt_mask.dtype)

(exist_object_ids, frame_idx, new_object_id,
 has_prompt_point, has_prompt_box, has_prompt_mask) = sam2_model.add_new_object_prompt_input(
    video_state_dict,
    frame_idx=frame_idx,
    prompt_point=None,
    prompt_box=None,
    prompt_mask=input_prompt_mask)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# Tracking one object with mask prompt

In [None]:
tracking_object_id = 0
# Run the model's tracking pass for this object from frame 0 with only the
# mask prompt enabled.
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=0,
    tracking_object_ids=[tracking_object_id],
    use_point_prompt_input=False,
    use_box_prompt_input=False,
    use_mask_prompt_input=True,
)

frame_num = video_state_dict['video_frame_num']
object_track_state = video_state_dict['object_track_state'][tracking_object_id]
object_track_result = video_state_dict['object_track_result'][tracking_object_id]
print(frame_num, object_track_state)

# Show the tracked mask on every `vis_frame_stride`-th frame.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    frame_result = object_track_result[frame_idx]
    frame_mask = frame_result['pred_mask'] > 0.  # values above zero count as foreground
    frame_iou = frame_result['pred_iou']
    frame_object_score = frame_result['pred_object_score']

    fig, ax = plt.subplots(figsize=(9, 6))
    show_image = Image.open(frames_path_list[frame_idx])
    ax.imshow(show_image)
    show_mask(frame_mask, ax, obj_id=tracking_object_id)
    ax.set_title(
        f"frame: {frame_idx}, IoU Score: {frame_iou:.3f}, Object Score: {frame_object_score:.3f}",
        fontsize=18)
    ax.axis('off')
    plt.show()

# Clear Video State Dict All Object Info

In [None]:
# Reset the per-object part of the state via the model and keep the returned
# dict; the following cells register two new objects, tracked as ids 0 and 1.
# NOTE(review): presumably the loaded frames are kept, since no re-init follows — confirm.
video_state_dict = sam2_model.clear_video_state_dict_all_object_info(video_state_dict)

# Add two objects with box prompts

In [None]:
frame_idx = 0

# Box prompt for the first object, in (x_min, y_min, x_max, y_max) order: (150, 130, 290, 410).
input_box = np.array([150, 130, 290, 410], dtype=np.float32)
print(input_box.shape)

# The model call below receives the box with a leading axis of size 1.
input_prompt_box = input_box[np.newaxis, :]
print(input_prompt_box.shape,input_prompt_box.dtype)

(exist_object_ids, frame_idx, new_object_id,
 has_prompt_point, has_prompt_box, has_prompt_mask) = sam2_model.add_new_object_prompt_input(
    video_state_dict,
    frame_idx=frame_idx,
    prompt_point=None,
    prompt_box=input_prompt_box,
    prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# First figure: the frame with the prompt box only.
fig, ax = plt.subplots(figsize=(9, 6))
ax.set_title(f"frame {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
ax.imshow(show_image)
show_box(input_box, ax)

# Request only output index 0 and unwrap the two leading axes of each result.
mask_preds, iou_preds = sam2_model.forward_one_image_test(video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0])
mask_preds, iou_preds = mask_preds[0][0], iou_preds[0][0]
binary_mask_preds = (mask_preds > 0.).cpu().float().numpy()
iou_preds = iou_preds.cpu().float().numpy()
print(binary_mask_preds.shape, iou_preds.shape, iou_preds)

# Second figure: the predicted mask together with the prompt box.
fig, ax = plt.subplots(figsize=(9, 6))
show_image = Image.open(frames_path_list[frame_idx])
ax.imshow(show_image)
show_mask(binary_mask_preds, ax)
show_box(input_box, ax)
ax.set_title(f"frame {frame_idx}, Score: {iou_preds:.3f}", fontsize=18)
ax.axis('off')
plt.show()

In [None]:
frame_idx = 0

# Box prompt for the second object, in (x_min, y_min, x_max, y_max) order: (300, 0, 500, 400).
input_box = np.array([300, 0, 500, 400], dtype=np.float32)
print(input_box.shape)

# The model call below receives the box with a leading axis of size 1.
input_prompt_box = np.expand_dims(input_box, axis=0)
print(input_prompt_box.shape,input_prompt_box.dtype)

exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask = sam2_model.add_new_object_prompt_input(
                                    video_state_dict,
                                    frame_idx=frame_idx,
                                    prompt_point=None,
                                    prompt_box=input_prompt_box,
                                    prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# First figure: the frame with the prompt box only.
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
plt.imshow(show_image)
show_box(input_box, plt.gca())

# Request only output index 0 and unwrap the two leading axes of each result.
mask_preds, iou_preds = sam2_model.forward_one_image_test(video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0])
mask_preds, iou_preds = mask_preds[0][0], iou_preds[0][0]
binary_mask_preds = mask_preds > 0.
binary_mask_preds = binary_mask_preds.cpu().float().numpy()
iou_preds = iou_preds.cpu().float().numpy()
print(binary_mask_preds.shape, iou_preds.shape, iou_preds)

# Second figure: the predicted mask together with the prompt box.
plt.figure(figsize=(9, 6))
plt.imshow(show_image)
# Colour the overlay by the new object's id so it matches the colour used for
# this object in the tracking visualisation.
show_mask(binary_mask_preds, plt.gca(), obj_id=new_object_id)
show_box(input_box, plt.gca())
plt.title(f"frame {frame_idx}, Score: {iou_preds:.3f}", fontsize=18)
plt.axis('off')
plt.show()

# Tracking two objects with box prompts

In [None]:
tracking_object_ids = [0, 1]
# Run one tracking pass from frame 0 for both objects with only the box
# prompts enabled.
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=0,
    tracking_object_ids=tracking_object_ids,
    use_point_prompt_input=False,
    use_box_prompt_input=True,
    use_mask_prompt_input=False)

frame_num = video_state_dict['video_frame_num']
vis_frame_stride = 30

# Visualise each tracked object in turn. Iterating over `tracking_object_ids`
# replaces two copies of the same loop with hard-coded ids 0 and 1; the
# printed lines and figures come out in the same order as before.
for tracking_object_id in tracking_object_ids:
    object_track_state = video_state_dict['object_track_state'][tracking_object_id]
    object_track_result = video_state_dict['object_track_result'][tracking_object_id]
    print(frame_num, object_track_state)

    # Show the tracked mask on every `vis_frame_stride`-th frame.
    for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
        per_frame_object_result = object_track_result[frame_idx]
        per_frame_per_object_mask = per_frame_object_result['pred_mask']
        per_frame_per_object_iou = per_frame_object_result['pred_iou']
        per_frame_per_object_score = per_frame_object_result['pred_object_score']

        # Values above zero count as foreground.
        per_frame_per_object_mask = per_frame_per_object_mask > 0.

        plt.figure(figsize=(9, 6))
        show_image = Image.open(frames_path_list[frame_idx])
        plt.imshow(show_image)
        show_mask(per_frame_per_object_mask, plt.gca(), obj_id=tracking_object_id)
        plt.title(
            f"frame: {frame_idx}, IoU Score: {per_frame_per_object_iou:.3f}, Object Score: {per_frame_per_object_score:.3f}",
            fontsize=18)
        plt.axis('off')
        plt.show()