In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


def show_mask(mask, ax, obj_id=None, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent, single-colour RGBA overlay.

    Args:
        mask: array whose last two dimensions are (height, width). It is
            reshaped to (height, width, 1), so it must hold exactly
            height * width values.
        ax: object exposing ``imshow`` (e.g. a matplotlib Axes).
        obj_id: index passed to the "tab10" colormap; ``None`` selects
            index 0. Ignored when ``random_color`` is true.
        random_color: when true, take the RGB colour from
            ``np.random.random`` instead of the colormap.
    """
    overlay_alpha = 0.6
    if random_color:
        rgb = np.random.random(3)
    else:
        palette = plt.get_cmap("tab10")
        rgb = palette(0 if obj_id is None else obj_id)[:3]
    color = np.array([*rgb, overlay_alpha])

    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to get an (H, W, 4) RGBA image.
    ax.imshow(mask.reshape(height, width, 1) * color.reshape(1, 1, -1))


def show_points(coords, labels, ax, marker_size=200):
    """Scatter prompt clicks on ``ax`` as star markers.

    Args:
        coords: array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: array used to select rows of ``coords``; rows with label 1
            are drawn in green, rows with label 0 in red.
        ax: object exposing ``scatter`` (e.g. a matplotlib Axes).
        marker_size: value forwarded as the scatter ``s`` argument.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    # Select both groups first, then draw positives followed by negatives.
    groups = [(coords[labels == label_value], color)
              for label_value, color in ((1, 'green'), (0, 'red'))]
    for points, color in groups:
        ax.scatter(points[:, 0], points[:, 1], color=color, **star_style)


def show_box(box, ax):
    """Outline a box on ``ax`` with a green, unfilled rectangle.

    Args:
        box: sequence indexed as (x_min, y_min, x_max, y_max).
        ax: object exposing ``add_patch`` (e.g. a matplotlib Axes).
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)


In [None]:
import os
import sys

# Directory of this example; the directory two levels above it is appended to
# sys.path before the SimpleAICV imports below (assumes the SimpleAICV package
# lives in that directory — confirm for your checkout).
inference_ipynb_path='/root/code/SimpleAICV_pytorch_training_examples/14.video_interactive_segmentation_training/sam2_predict_example'
BASE_DIR = os.path.dirname(os.path.dirname(inference_ipynb_path))
sys.path.append(BASE_DIR)

# Expose only GPU 0 to this process; set before torch is imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torch.nn as nn
import torch.nn.functional as F

from SimpleAICV.video_interactive_segmentation.models.segment_anything2_matting.sam2videomatting_test import hiera_b_plus_sam2video_matting_test
from SimpleAICV.video_interactive_segmentation.common import load_state_dict

# NOTE(review): the checkpoint path is empty — point it at the trained weights
# before running. What load_state_dict does with '' is not visible here.
sam2_checkpoint = ''

# Instantiate the model, move it to the GPU and switch it to eval mode.
sam2_model = hiera_b_plus_sam2video_matting_test()
sam2_model = sam2_model.cuda()
sam2_model = sam2_model.eval()

# Load the checkpoint into the model.
load_state_dict(sam2_checkpoint, sam2_model)

In [None]:
video_dir_path = '/root/code/SimpleAICV_pytorch_training_examples/14.video_interactive_segmentation_training/sam2_predict_example/test_videos/bedroom'

# Collect the JPEG frames in name order. Matching on the file extension
# (rather than on a '.jpg' substring anywhere in the name) keeps files such as
# 'frame.jpg.bak' out of the frame list.
frames_name_list = sorted(
    per_frame_name for per_frame_name in os.listdir(video_dir_path)
    if per_frame_name.endswith('.jpg'))
frames_path_list = [
    os.path.join(video_dir_path, n) for n in frames_name_list
]

# Preview every `vis_frame_stride`-th frame.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame: {frame_idx}")
    show_image = Image.open(frames_path_list[frame_idx])
    plt.imshow(show_image)

In [None]:
# Initialise the video state dict for the frame directory; the cells below
# pass this dict to the sam2_model prompt and tracking calls.
video_state_dict = sam2_model.init_video_state_dict(video_dir_path=video_dir_path)

In [None]:
# Reset the video state dict via clear_video_state_dict_all_info.
# NOTE(review): presumably wipes everything stored in the state, not only the
# per-object entries — confirm against the model implementation.
video_state_dict = sam2_model.clear_video_state_dict_all_info(video_state_dict)

In [None]:
# Initialise the video state dict again so the prompt cells below start from a
# freshly built state.
# NOTE(review): assumed to be required after the full clear above — confirm.
video_state_dict = sam2_model.init_video_state_dict(video_dir_path=video_dir_path)

In [None]:
frame_idx = 0

# Prompt: one positive click at (x, y) = (210, 350).
# Labels: `1` marks a positive click, `0` a negative click.
input_point = np.array([[210, 350]], dtype=np.float32)
input_label = np.array([[1]], dtype=np.int32)
print(input_point.shape, input_label.shape)

# Pack each click as (x, y, label) and add a leading axis -> (1, num_points, 3).
input_prompt_point = np.concatenate([input_point, input_label], axis=1, dtype=np.float32)
input_prompt_point = np.expand_dims(input_prompt_point, axis=0)
print(input_prompt_point.shape, input_prompt_point.dtype)

# Flat (num_points,) label vector for show_points. Squeezing axis 1 stays
# correct when more clicks are added, whereas indexing the first row only
# works for a single click.
point_labels = np.squeeze(input_label, axis=1)

(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box,
 has_prompt_mask) = sam2_model.add_new_object_prompt_input(
     video_state_dict,
     frame_idx=frame_idx,
     prompt_point=input_prompt_point,
     prompt_box=None,
     prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# Show the prompted frame with the click overlaid.
plt.figure(figsize=(9, 6))
plt.title(f"frame: {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
plt.imshow(show_image)
show_points(input_point, point_labels, plt.gca())

# Run the model on the prompted frame and move the outputs to NumPy.
global_preds, local_preds, fused_preds, iou_preds = sam2_model.forward_one_image_test(
    video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0, 1, 2, 3])
global_preds = global_preds[0][0].permute(1, 2, 0).float().cpu().numpy()
local_preds = local_preds[0][0][0].float().cpu().numpy()
fused_preds = fused_preds[0][0][0].float().cpu().numpy()
iou_preds = iou_preds[0][0].float().cpu().numpy()
print(global_preds.shape, local_preds.shape, fused_preds.shape,
      iou_preds.shape, iou_preds)

# Add a trailing channel axis so the predictions broadcast against the frame.
local_preds = np.expand_dims(local_preds, axis=-1)
fused_preds = np.expand_dims(fused_preds, axis=-1)
print(local_preds.shape, np.max(local_preds), np.min(local_preds))
print(fused_preds.shape, np.max(fused_preds), np.min(fused_preds))

# One figure per prediction, each with the click overlaid.
for pred_name, pred in (("global", global_preds), ("local", local_preds),
                        ("fused", fused_preds)):
    plt.figure(figsize=(9, 6))
    plt.imshow(pred)
    show_points(input_point, point_labels, plt.gca())
    plt.title(f"frame: {frame_idx}, {pred_name} pred, IoU Score: {iou_preds:.3f}",
              fontsize=18)
    plt.axis('off')
    plt.show()

# Create a green background with the frame's shape.
frame_array = np.asarray(show_image)
green_background = np.zeros_like(frame_array, dtype=np.float32)
green_background[:, :] = [0, 255, 0]  # RGB order
print(green_background.shape, np.max(green_background),
      np.min(green_background))

# Weight the frame and the green background by the fused prediction, then merge.
foreground = frame_array * fused_preds
background = green_background * (1 - fused_preds)
result_image = foreground + background

foreground = foreground.astype(np.uint8)
background = background.astype(np.uint8)
result_image = result_image.astype(np.uint8)

# Foreground, background and final composite, each with the click overlaid.
for composite in (foreground, background, result_image):
    plt.figure(figsize=(9, 6))
    plt.imshow(composite)
    show_points(input_point, point_labels, plt.gca())
    plt.title(f"frame: {frame_idx}, IoU Score: {iou_preds:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
tracking_object_id = 0

# Run tracking for this object with only the point-prompt flag enabled.
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=0,
    tracking_object_ids=[tracking_object_id],
    use_point_prompt_input=True,
    use_box_prompt_input=False,
    use_mask_prompt_input=False,
)

frame_num = video_state_dict['video_frame_num']
object_track_state = video_state_dict['object_track_state'][tracking_object_id]
object_track_result = video_state_dict['object_track_result'][tracking_object_id]
print(frame_num, object_track_state)

# Composite every `vis_frame_stride`-th frame over a green background, using
# the fused prediction as the per-pixel blend weight.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    frame_result = object_track_result[frame_idx]
    # Trailing channel axis so the weight broadcasts against the RGB frame.
    blend_weight = np.expand_dims(frame_result['fuse_pred'][0], axis=-1)
    pred_iou = frame_result['pred_iou']
    object_score = frame_result['pred_object_score']

    show_image = Image.open(frames_path_list[frame_idx])
    green_background = np.zeros_like(show_image, dtype=np.float32)
    green_background[:, :] = [0, 255, 0]  # RGB order

    # Weight the frame and the green background, then merge.
    foreground = show_image * blend_weight
    background = green_background * (1 - blend_weight)
    result_image = (foreground + background).astype(np.uint8)

    plt.figure(figsize=(9, 6))
    plt.imshow(result_image)
    plt.title(
        f"frame: {frame_idx}, IoU Score: {pred_iou:.3f}, Object Score: {object_score:.3f}",
        fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
frame_idx = 0

# Prompt: two positive clicks at (x, y) = (210, 350) and (250, 220).
# Labels: `1` marks a positive click, `0` a negative click.
input_point = np.array([[210, 350], [250, 220]], dtype=np.float32)
input_label = np.array([[1], [1]], dtype=np.int32)
print(input_point.shape, input_label.shape)

# Pack each click as (x, y, label) and add a leading axis -> (1, num_points, 3).
input_prompt_point = np.concatenate([input_point, input_label], axis=1, dtype=np.float32)
input_prompt_point = np.expand_dims(input_prompt_point, axis=0)
print(input_prompt_point.shape, input_prompt_point.dtype)

# Flat (num_points,) label vector for show_points.
point_labels = np.squeeze(input_label, axis=1)

(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box,
 has_prompt_mask) = sam2_model.add_new_object_prompt_input(
     video_state_dict,
     frame_idx=frame_idx,
     prompt_point=input_prompt_point,
     prompt_box=None,
     prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# Show the prompted frame with the clicks overlaid.
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
plt.imshow(show_image)
show_points(input_point, point_labels, plt.gca())

# Run the model on the prompted frame and move the outputs to NumPy.
global_preds, local_preds, fused_preds, iou_preds = sam2_model.forward_one_image_test(
    video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0, 1, 2, 3])
global_preds = global_preds[0][0].permute(1, 2, 0).float().cpu().numpy()
local_preds = local_preds[0][0][0].float().cpu().numpy()
fused_preds = fused_preds[0][0][0].float().cpu().numpy()
iou_preds = iou_preds[0][0].float().cpu().numpy()
print(global_preds.shape, local_preds.shape, fused_preds.shape,
      iou_preds.shape, iou_preds)

# Add a trailing channel axis so the predictions broadcast against the frame.
local_preds = np.expand_dims(local_preds, axis=-1)
fused_preds = np.expand_dims(fused_preds, axis=-1)
print(local_preds.shape, np.max(local_preds), np.min(local_preds))
print(fused_preds.shape, np.max(fused_preds), np.min(fused_preds))

# One figure per prediction, each with the clicks overlaid.
for pred_name, pred in (("global", global_preds), ("local", local_preds),
                        ("fused", fused_preds)):
    plt.figure(figsize=(9, 6))
    plt.imshow(pred)
    show_points(input_point, point_labels, plt.gca())
    plt.title(f"frame: {frame_idx}, {pred_name} pred, IoU Score: {iou_preds:.3f}",
              fontsize=18)
    plt.axis('off')
    plt.show()

# Create a green background with the frame's shape.
frame_array = np.asarray(show_image)
green_background = np.zeros_like(frame_array, dtype=np.float32)
green_background[:, :] = [0, 255, 0]  # RGB order
print(green_background.shape, np.max(green_background),
      np.min(green_background))

# Weight the frame and the green background by the fused prediction, then merge.
foreground = frame_array * fused_preds
background = green_background * (1 - fused_preds)
result_image = foreground + background

foreground = foreground.astype(np.uint8)
background = background.astype(np.uint8)
result_image = result_image.astype(np.uint8)

# Foreground, background and final composite, each with the clicks overlaid.
for composite in (foreground, background, result_image):
    plt.figure(figsize=(9, 6))
    plt.imshow(composite)
    show_points(input_point, point_labels, plt.gca())
    plt.title(f"frame: {frame_idx}, IoU Score: {iou_preds:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
tracking_object_id = 1

# Run tracking for this object with only the point-prompt flag enabled.
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=0,
    tracking_object_ids=[tracking_object_id],
    use_point_prompt_input=True,
    use_box_prompt_input=False,
    use_mask_prompt_input=False,
)

frame_num = video_state_dict['video_frame_num']
object_track_state = video_state_dict['object_track_state'][tracking_object_id]
object_track_result = video_state_dict['object_track_result'][tracking_object_id]
print(frame_num, object_track_state)

# Composite every `vis_frame_stride`-th frame over a green background, using
# the fused prediction as the per-pixel blend weight.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    frame_result = object_track_result[frame_idx]
    # Trailing channel axis so the weight broadcasts against the RGB frame.
    blend_weight = np.expand_dims(frame_result['fuse_pred'][0], axis=-1)
    pred_iou = frame_result['pred_iou']
    object_score = frame_result['pred_object_score']

    show_image = Image.open(frames_path_list[frame_idx])
    green_background = np.zeros_like(show_image, dtype=np.float32)
    green_background[:, :] = [0, 255, 0]  # RGB order

    # Weight the frame and the green background, then merge.
    foreground = show_image * blend_weight
    background = green_background * (1 - blend_weight)
    result_image = (foreground + background).astype(np.uint8)

    plt.figure(figsize=(9, 6))
    plt.imshow(result_image)
    plt.title(
        f"frame: {frame_idx}, IoU Score: {pred_iou:.3f}, Object Score: {object_score:.3f}",
        fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
# Clear the state before the next prompt experiment.
# NOTE(review): presumably removes every registered object and its tracking
# results while keeping the loaded video, so object ids start again at 0 —
# confirm against the model implementation.
video_state_dict = sam2_model.clear_video_state_dict_all_object_info(video_state_dict)

In [None]:
frame_idx = 0

# Prompt: a box given as (x_min, y_min, x_max, y_max) = (300, 0, 500, 400).
input_box = np.array([300, 0, 500, 400], dtype=np.float32)
print(input_box.shape)

# Add a leading axis -> shape (1, 4).
input_prompt_box = np.expand_dims(input_box, axis=0)
print(input_prompt_box.shape, input_prompt_box.dtype)

(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box,
 has_prompt_mask) = sam2_model.add_new_object_prompt_input(
     video_state_dict,
     frame_idx=frame_idx,
     prompt_point=None,
     prompt_box=input_prompt_box,
     prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# Show the prompted frame with the box overlaid.
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
plt.imshow(show_image)
show_box(input_box, plt.gca())

# Run the model on the prompted frame and move the outputs to NumPy.
global_preds, local_preds, fused_preds, iou_preds = sam2_model.forward_one_image_test(
    video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0, 1, 2, 3])
global_preds = global_preds[0][0].permute(1, 2, 0).float().cpu().numpy()
local_preds = local_preds[0][0][0].float().cpu().numpy()
fused_preds = fused_preds[0][0][0].float().cpu().numpy()
iou_preds = iou_preds[0][0].float().cpu().numpy()
print(global_preds.shape, local_preds.shape, fused_preds.shape,
      iou_preds.shape, iou_preds)

# Add a trailing channel axis so the predictions broadcast against the frame.
local_preds = np.expand_dims(local_preds, axis=-1)
fused_preds = np.expand_dims(fused_preds, axis=-1)
print(local_preds.shape, np.max(local_preds), np.min(local_preds))
print(fused_preds.shape, np.max(fused_preds), np.min(fused_preds))

# One figure per prediction, each with the box overlaid.
for pred_name, pred in (("global", global_preds), ("local", local_preds),
                        ("fused", fused_preds)):
    plt.figure(figsize=(9, 6))
    plt.imshow(pred)
    show_box(input_box, plt.gca())
    plt.title(f"frame: {frame_idx}, {pred_name} pred, IoU Score: {iou_preds:.3f}",
              fontsize=18)
    plt.axis('off')
    plt.show()

# Create a green background with the frame's shape.
frame_array = np.asarray(show_image)
green_background = np.zeros_like(frame_array, dtype=np.float32)
green_background[:, :] = [0, 255, 0]  # RGB order
print(green_background.shape, np.max(green_background),
      np.min(green_background))

# Weight the frame and the green background by the fused prediction, then merge.
foreground = frame_array * fused_preds
background = green_background * (1 - fused_preds)
result_image = foreground + background

foreground = foreground.astype(np.uint8)
background = background.astype(np.uint8)
result_image = result_image.astype(np.uint8)

# Foreground, background and final composite, each with the box overlaid.
for composite in (foreground, background, result_image):
    plt.figure(figsize=(9, 6))
    plt.imshow(composite)
    show_box(input_box, plt.gca())
    plt.title(f"frame: {frame_idx}, IoU Score: {iou_preds:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
tracking_object_id = 0

# Run tracking for this object with only the box-prompt flag enabled.
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=0,
    tracking_object_ids=[tracking_object_id],
    use_point_prompt_input=False,
    use_box_prompt_input=True,
    use_mask_prompt_input=False,
)

frame_num = video_state_dict['video_frame_num']
object_track_state = video_state_dict['object_track_state'][tracking_object_id]
object_track_result = video_state_dict['object_track_result'][tracking_object_id]
print(frame_num, object_track_state)

# Composite every `vis_frame_stride`-th frame over a green background, using
# the fused prediction as the per-pixel blend weight.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    frame_result = object_track_result[frame_idx]
    # Trailing channel axis so the weight broadcasts against the RGB frame.
    blend_weight = np.expand_dims(frame_result['fuse_pred'][0], axis=-1)
    pred_iou = frame_result['pred_iou']
    object_score = frame_result['pred_object_score']

    show_image = Image.open(frames_path_list[frame_idx])
    green_background = np.zeros_like(show_image, dtype=np.float32)
    green_background[:, :] = [0, 255, 0]  # RGB order

    # Weight the frame and the green background, then merge.
    foreground = show_image * blend_weight
    background = green_background * (1 - blend_weight)
    result_image = (foreground + background).astype(np.uint8)

    plt.figure(figsize=(9, 6))
    plt.imshow(result_image)
    plt.title(
        f"frame: {frame_idx}, IoU Score: {pred_iou:.3f}, Object Score: {object_score:.3f}",
        fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
# Clear the state before the next prompt experiment.
# NOTE(review): presumably removes every registered object and its tracking
# results while keeping the loaded video, so object ids start again at 0 —
# confirm against the model implementation.
video_state_dict = sam2_model.clear_video_state_dict_all_object_info(video_state_dict)

In [None]:
frame_idx = 60

# Prompt: a box given as (x_min, y_min, x_max, y_max) = (180, 0, 450, 430).
input_box = np.array([180, 0, 450, 430], dtype=np.float32)
print(input_box.shape)

# Add a leading axis -> shape (1, 4).
input_prompt_box = np.expand_dims(input_box, axis=0)
print(input_prompt_box.shape, input_prompt_box.dtype)

(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box,
 has_prompt_mask) = sam2_model.add_new_object_prompt_input(
     video_state_dict,
     frame_idx=frame_idx,
     prompt_point=None,
     prompt_box=input_prompt_box,
     prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# Show the prompted frame with the box overlaid.
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
plt.imshow(show_image)
show_box(input_box, plt.gca())

# Run the model on the prompted frame and move the outputs to NumPy.
global_preds, local_preds, fused_preds, iou_preds = sam2_model.forward_one_image_test(
    video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0, 1, 2, 3])
global_preds = global_preds[0][0].permute(1, 2, 0).float().cpu().numpy()
local_preds = local_preds[0][0][0].float().cpu().numpy()
fused_preds = fused_preds[0][0][0].float().cpu().numpy()
iou_preds = iou_preds[0][0].float().cpu().numpy()
print(global_preds.shape, local_preds.shape, fused_preds.shape,
      iou_preds.shape, iou_preds)

# Add a trailing channel axis so the predictions broadcast against the frame.
local_preds = np.expand_dims(local_preds, axis=-1)
fused_preds = np.expand_dims(fused_preds, axis=-1)
print(local_preds.shape, np.max(local_preds), np.min(local_preds))
print(fused_preds.shape, np.max(fused_preds), np.min(fused_preds))

# One figure per prediction, each with the box overlaid.
for pred_name, pred in (("global", global_preds), ("local", local_preds),
                        ("fused", fused_preds)):
    plt.figure(figsize=(9, 6))
    plt.imshow(pred)
    show_box(input_box, plt.gca())
    plt.title(f"frame: {frame_idx}, {pred_name} pred, IoU Score: {iou_preds:.3f}",
              fontsize=18)
    plt.axis('off')
    plt.show()

# Create a green background with the frame's shape.
frame_array = np.asarray(show_image)
green_background = np.zeros_like(frame_array, dtype=np.float32)
green_background[:, :] = [0, 255, 0]  # RGB order
print(green_background.shape, np.max(green_background),
      np.min(green_background))

# Weight the frame and the green background by the fused prediction, then merge.
foreground = frame_array * fused_preds
background = green_background * (1 - fused_preds)
result_image = foreground + background

foreground = foreground.astype(np.uint8)
background = background.astype(np.uint8)
result_image = result_image.astype(np.uint8)

# Foreground, background and final composite, each with the box overlaid.
for composite in (foreground, background, result_image):
    plt.figure(figsize=(9, 6))
    plt.imshow(composite)
    show_box(input_box, plt.gca())
    plt.title(f"frame: {frame_idx}, IoU Score: {iou_preds:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
tracking_object_id = 0

# Run tracking for this object from frame 60 with only the box-prompt flag
# enabled.
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=60,
    tracking_object_ids=[tracking_object_id],
    use_point_prompt_input=False,
    use_box_prompt_input=True,
    use_mask_prompt_input=False,
)

frame_num = video_state_dict['video_frame_num']
object_track_state = video_state_dict['object_track_state'][tracking_object_id]
object_track_result = video_state_dict['object_track_result'][tracking_object_id]
print(frame_num, object_track_state)

# Composite every `vis_frame_stride`-th frame over a green background, using
# the fused prediction as the per-pixel blend weight.
# NOTE(review): the loop reads results from frame 0 although tracking starts at
# frame 60 — this assumes forward_tracking_for_test also fills in results for
# the earlier frames; confirm.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    frame_result = object_track_result[frame_idx]
    # Trailing channel axis so the weight broadcasts against the RGB frame.
    blend_weight = np.expand_dims(frame_result['fuse_pred'][0], axis=-1)
    pred_iou = frame_result['pred_iou']
    object_score = frame_result['pred_object_score']

    show_image = Image.open(frames_path_list[frame_idx])
    green_background = np.zeros_like(show_image, dtype=np.float32)
    green_background[:, :] = [0, 255, 0]  # RGB order

    # Weight the frame and the green background, then merge.
    foreground = show_image * blend_weight
    background = green_background * (1 - blend_weight)
    result_image = (foreground + background).astype(np.uint8)

    plt.figure(figsize=(9, 6))
    plt.imshow(result_image)
    plt.title(
        f"frame: {frame_idx}, IoU Score: {pred_iou:.3f}, Object Score: {object_score:.3f}",
        fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
# Clear the state before the next prompt experiment.
# NOTE(review): presumably removes every registered object and its tracking
# results while keeping the loaded video, so object ids start again at 0 —
# confirm against the model implementation.
video_state_dict = sam2_model.clear_video_state_dict_all_object_info(video_state_dict)

In [None]:
frame_idx = 199

# Prompt: a box given as (x_min, y_min, x_max, y_max) = (300, 50, 470, 375).
input_box = np.array([300, 50, 470, 375], dtype=np.float32)
print(input_box.shape)

# Add a leading axis -> shape (1, 4).
input_prompt_box = np.expand_dims(input_box, axis=0)
print(input_prompt_box.shape, input_prompt_box.dtype)

(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box,
 has_prompt_mask) = sam2_model.add_new_object_prompt_input(
     video_state_dict,
     frame_idx=frame_idx,
     prompt_point=None,
     prompt_box=input_prompt_box,
     prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

# Show the prompted frame with the box overlaid.
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
plt.imshow(show_image)
show_box(input_box, plt.gca())

# Run the model on the prompted frame and move the outputs to NumPy.
global_preds, local_preds, fused_preds, iou_preds = sam2_model.forward_one_image_test(
    video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0, 1, 2, 3])
global_preds = global_preds[0][0].permute(1, 2, 0).float().cpu().numpy()
local_preds = local_preds[0][0][0].float().cpu().numpy()
fused_preds = fused_preds[0][0][0].float().cpu().numpy()
iou_preds = iou_preds[0][0].float().cpu().numpy()
print(global_preds.shape, local_preds.shape, fused_preds.shape,
      iou_preds.shape, iou_preds)

# Add a trailing channel axis so the predictions broadcast against the frame.
local_preds = np.expand_dims(local_preds, axis=-1)
fused_preds = np.expand_dims(fused_preds, axis=-1)
print(local_preds.shape, np.max(local_preds), np.min(local_preds))
print(fused_preds.shape, np.max(fused_preds), np.min(fused_preds))

# One figure per prediction, each with the box overlaid.
for pred_name, pred in (("global", global_preds), ("local", local_preds),
                        ("fused", fused_preds)):
    plt.figure(figsize=(9, 6))
    plt.imshow(pred)
    show_box(input_box, plt.gca())
    plt.title(f"frame: {frame_idx}, {pred_name} pred, IoU Score: {iou_preds:.3f}",
              fontsize=18)
    plt.axis('off')
    plt.show()

# Create a green background with the frame's shape.
frame_array = np.asarray(show_image)
green_background = np.zeros_like(frame_array, dtype=np.float32)
green_background[:, :] = [0, 255, 0]  # RGB order
print(green_background.shape, np.max(green_background),
      np.min(green_background))

# Weight the frame and the green background by the fused prediction, then merge.
foreground = frame_array * fused_preds
background = green_background * (1 - fused_preds)
result_image = foreground + background

foreground = foreground.astype(np.uint8)
background = background.astype(np.uint8)
result_image = result_image.astype(np.uint8)

# Foreground, background and final composite, each with the box overlaid.
for composite in (foreground, background, result_image):
    plt.figure(figsize=(9, 6))
    plt.imshow(composite)
    show_box(input_box, plt.gca())
    plt.title(f"frame: {frame_idx}, IoU Score: {iou_preds:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
tracking_object_id = 0

# Run tracking for this object from frame 199 with only the box-prompt flag
# enabled.
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=199,
    tracking_object_ids=[tracking_object_id],
    use_point_prompt_input=False,
    use_box_prompt_input=True,
    use_mask_prompt_input=False,
)

frame_num = video_state_dict['video_frame_num']
object_track_state = video_state_dict['object_track_state'][tracking_object_id]
object_track_result = video_state_dict['object_track_result'][tracking_object_id]
print(frame_num, object_track_state)

# Composite every `vis_frame_stride`-th frame over a green background, using
# the fused prediction as the per-pixel blend weight.
# NOTE(review): the loop reads results from frame 0 although tracking starts at
# frame 199 — this assumes forward_tracking_for_test also fills in results for
# the earlier frames; confirm.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    frame_result = object_track_result[frame_idx]
    # Trailing channel axis so the weight broadcasts against the RGB frame.
    blend_weight = np.expand_dims(frame_result['fuse_pred'][0], axis=-1)
    pred_iou = frame_result['pred_iou']
    object_score = frame_result['pred_object_score']

    show_image = Image.open(frames_path_list[frame_idx])
    green_background = np.zeros_like(show_image, dtype=np.float32)
    green_background[:, :] = [0, 255, 0]  # RGB order

    # Weight the frame and the green background, then merge.
    foreground = show_image * blend_weight
    background = green_background * (1 - blend_weight)
    result_image = (foreground + background).astype(np.uint8)

    plt.figure(figsize=(9, 6))
    plt.imshow(result_image)
    plt.title(
        f"frame: {frame_idx}, IoU Score: {pred_iou:.3f}, Object Score: {object_score:.3f}",
        fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
# Clear the state before the next prompt experiment.
# NOTE(review): presumably removes every registered object and its tracking
# results while keeping the loaded video, so object ids start again at 0 —
# confirm against the model implementation.
video_state_dict = sam2_model.clear_video_state_dict_all_object_info(video_state_dict)

In [None]:
frame_idx = 0
prompt_mask_path = '/root/code/SimpleAICV_pytorch_training_examples/14.video_interactive_segmentation_training/sam2_predict_example/test_videos/bedroom/00000_prompt_mask_for_target_0.png'

# Read the mask as 8-bit grayscale, divide by 255 and store it as float32.
mask_gray = Image.open(prompt_mask_path).convert('L')
input_mask = (np.array(mask_gray, dtype=np.uint8) / 255.).astype(np.float32)
print(input_mask.shape)

# Show the prompted frame with the mask overlaid.
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
show_image = Image.open(frames_path_list[frame_idx])
plt.imshow(show_image)
show_mask(input_mask, plt.gca())

# The 2-D float mask is passed to the model as the prompt unchanged.
input_prompt_mask = input_mask
print(input_prompt_mask.shape, input_prompt_mask.dtype)

(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box,
 has_prompt_mask) = sam2_model.add_new_object_prompt_input(
     video_state_dict,
     frame_idx=frame_idx,
     prompt_point=None,
     prompt_box=None,
     prompt_mask=input_prompt_mask)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask)

In [None]:
tracking_object_id = 0

# Run tracking for this object with only the mask-prompt flag enabled.
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=0,
    tracking_object_ids=[tracking_object_id],
    use_point_prompt_input=False,
    use_box_prompt_input=False,
    use_mask_prompt_input=True,
)

frame_num = video_state_dict['video_frame_num']
object_track_state = video_state_dict['object_track_state'][tracking_object_id]
object_track_result = video_state_dict['object_track_result'][tracking_object_id]
print(frame_num, object_track_state)

# Composite every `vis_frame_stride`-th frame over a green background, using
# the fused prediction as the per-pixel blend weight.
vis_frame_stride = 30
for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
    frame_result = object_track_result[frame_idx]
    # Trailing channel axis so the weight broadcasts against the RGB frame.
    blend_weight = np.expand_dims(frame_result['fuse_pred'][0], axis=-1)
    pred_iou = frame_result['pred_iou']
    object_score = frame_result['pred_object_score']

    show_image = Image.open(frames_path_list[frame_idx])
    green_background = np.zeros_like(show_image, dtype=np.float32)
    green_background[:, :] = [0, 255, 0]  # RGB order

    # Weight the frame and the green background, then merge.
    foreground = show_image * blend_weight
    background = green_background * (1 - blend_weight)
    result_image = (foreground + background).astype(np.uint8)

    plt.figure(figsize=(9, 6))
    plt.imshow(result_image)
    plt.title(
        f"frame: {frame_idx}, IoU Score: {pred_iou:.3f}, Object Score: {object_score:.3f}",
        fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
# Rebind video_state_dict to the state returned by the model's clear call.
# NOTE(review): presumably this drops every object's prompts and tracking
# results so the following cells start again from object id 0 - confirm
# against the model implementation.
video_state_dict = sam2_model.clear_video_state_dict_all_object_info(video_state_dict)

In [None]:
# Prompt frame 0 with one box, run the single-frame prediction for the object
# id returned by the prompt call, and visualise the raw predictions plus the
# green-screen composites.
frame_idx = 0

# Let's add a box at (x_min, y_min, x_max, y_max) = (150, 130, 290, 410) to get started
input_box = np.array([150, 130, 290, 410], dtype=np.float32)
print(input_box.shape)

# The prompt is passed with a leading axis: (4,) -> (1, 4).
input_prompt_box = np.expand_dims(input_box, axis=0)
print(input_prompt_box.shape, input_prompt_box.dtype)

# Box-only prompt: the point and mask prompts are explicitly disabled.
# Note that frame_idx is rebound to the value returned by the call.
exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask = sam2_model.add_new_object_prompt_input(
    video_state_dict,
    frame_idx=frame_idx,
    prompt_point=None,
    prompt_box=input_prompt_box,
    prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point,
      has_prompt_box, has_prompt_mask)

# Open the frame once, normalise it to 3-channel RGB and keep an explicit
# array copy for the compositing below. Multiplying the PIL image directly
# only worked through NumPy's implicit coercion, and a grayscale / palette /
# RGBA frame would not match the 3-channel green fill and alpha broadcast.
show_image = Image.open(frames_path_list[frame_idx]).convert('RGB')
frame_rgb = np.asarray(show_image)

# No plt.show() here, as in the original cell: this figure is displayed
# together with the next one.
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(show_image)
show_box(input_box, plt.gca())

global_preds, local_preds, fused_preds, iou_preds = sam2_model.forward_one_image_test(
    video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0, 1, 2, 3])
# Keep only the first entry along the two leading axes of every output.
# NOTE(review): presumably those axes are (batch, mask index) - confirm.
global_preds = global_preds[0][0].permute(1, 2, 0).float().cpu().numpy()
local_preds = local_preds[0][0][0].float().cpu().numpy()
fused_preds = fused_preds[0][0][0].float().cpu().numpy()
iou_preds = iou_preds[0][0].float().cpu().numpy()
print(global_preds.shape, local_preds.shape, fused_preds.shape,
      iou_preds.shape, iou_preds)

# Add a trailing axis so the predictions broadcast against the RGB frame.
local_preds = np.expand_dims(local_preds, axis=-1)
fused_preds = np.expand_dims(fused_preds, axis=-1)
print(local_preds.shape, np.max(local_preds), np.min(local_preds))
print(fused_preds.shape, np.max(fused_preds), np.min(fused_preds))

# One figure per raw prediction, each with the prompt box drawn on top.
for pred_name, pred_image in (('global', global_preds),
                              ('local', local_preds),
                              ('fused', fused_preds)):
    plt.figure(figsize=(9, 6))
    plt.imshow(pred_image)
    show_box(input_box, plt.gca())
    plt.title(
        f"frame: {frame_idx}, {pred_name} pred, IoU Score: {iou_preds:.3f}",
        fontsize=18)
    plt.axis('off')
    plt.show()

# Create the green background.
green_background = np.zeros_like(frame_rgb, dtype=np.float32)
green_background[:, :] = [0, 255, 0]  # RGB order
print(green_background.shape, np.max(green_background),
      np.min(green_background))

# Blend foreground and background, using the fused prediction as the
# per-pixel weight.
foreground = frame_rgb * fused_preds
background = green_background * (1 - fused_preds)
result_image = foreground + background

foreground = foreground.astype(np.uint8)
background = background.astype(np.uint8)
result_image = result_image.astype(np.uint8)

# Show the foreground, the background and the merged result in that order.
for composite_image in (foreground, background, result_image):
    plt.figure(figsize=(9, 6))
    plt.imshow(composite_image)
    show_box(input_box, plt.gca())
    plt.title(f"frame: {frame_idx}, IoU Score: {iou_preds:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
# Prompt frame 0 with a second box, run the single-frame prediction for the
# object id returned by the prompt call, and visualise the raw predictions
# plus the green-screen composites.
frame_idx = 0

# Let's add a box at (x_min, y_min, x_max, y_max) = (300, 0, 500, 400) to get started
input_box = np.array([300, 0, 500, 400], dtype=np.float32)
print(input_box.shape)

# The prompt is passed with a leading axis: (4,) -> (1, 4).
input_prompt_box = np.expand_dims(input_box, axis=0)
print(input_prompt_box.shape, input_prompt_box.dtype)

# Box-only prompt: the point and mask prompts are explicitly disabled.
# Note that frame_idx is rebound to the value returned by the call.
exist_object_ids, frame_idx, new_object_id, has_prompt_point, has_prompt_box, has_prompt_mask = sam2_model.add_new_object_prompt_input(
    video_state_dict,
    frame_idx=frame_idx,
    prompt_point=None,
    prompt_box=input_prompt_box,
    prompt_mask=None)
print(exist_object_ids, frame_idx, new_object_id, has_prompt_point,
      has_prompt_box, has_prompt_mask)

# Open the frame once, normalise it to 3-channel RGB and keep an explicit
# array copy for the compositing below. Multiplying the PIL image directly
# only worked through NumPy's implicit coercion, and a grayscale / palette /
# RGBA frame would not match the 3-channel green fill and alpha broadcast.
show_image = Image.open(frames_path_list[frame_idx]).convert('RGB')
frame_rgb = np.asarray(show_image)

# No plt.show() here, as in the original cell: this figure is displayed
# together with the next one.
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(show_image)
show_box(input_box, plt.gca())

global_preds, local_preds, fused_preds, iou_preds = sam2_model.forward_one_image_test(
    video_state_dict, new_object_id, frame_idx, mask_out_idxs=[0, 1, 2, 3])
# Keep only the first entry along the two leading axes of every output.
# NOTE(review): presumably those axes are (batch, mask index) - confirm.
global_preds = global_preds[0][0].permute(1, 2, 0).float().cpu().numpy()
local_preds = local_preds[0][0][0].float().cpu().numpy()
fused_preds = fused_preds[0][0][0].float().cpu().numpy()
iou_preds = iou_preds[0][0].float().cpu().numpy()
print(global_preds.shape, local_preds.shape, fused_preds.shape,
      iou_preds.shape, iou_preds)

# Add a trailing axis so the predictions broadcast against the RGB frame.
local_preds = np.expand_dims(local_preds, axis=-1)
fused_preds = np.expand_dims(fused_preds, axis=-1)
print(local_preds.shape, np.max(local_preds), np.min(local_preds))
print(fused_preds.shape, np.max(fused_preds), np.min(fused_preds))

# One figure per raw prediction, each with the prompt box drawn on top.
for pred_name, pred_image in (('global', global_preds),
                              ('local', local_preds),
                              ('fused', fused_preds)):
    plt.figure(figsize=(9, 6))
    plt.imshow(pred_image)
    show_box(input_box, plt.gca())
    plt.title(
        f"frame: {frame_idx}, {pred_name} pred, IoU Score: {iou_preds:.3f}",
        fontsize=18)
    plt.axis('off')
    plt.show()

# Create the green background.
green_background = np.zeros_like(frame_rgb, dtype=np.float32)
green_background[:, :] = [0, 255, 0]  # RGB order
print(green_background.shape, np.max(green_background),
      np.min(green_background))

# Blend foreground and background, using the fused prediction as the
# per-pixel weight.
foreground = frame_rgb * fused_preds
background = green_background * (1 - fused_preds)
result_image = foreground + background

foreground = foreground.astype(np.uint8)
background = background.astype(np.uint8)
result_image = result_image.astype(np.uint8)

# Show the foreground, the background and the merged result in that order.
for composite_image in (foreground, background, result_image):
    plt.figure(figsize=(9, 6))
    plt.imshow(composite_image)
    show_box(input_box, plt.gca())
    plt.title(f"frame: {frame_idx}, IoU Score: {iou_preds:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

In [None]:
# Track both box-prompted objects through the video, then preview every
# `vis_frame_stride`-th frame of each object composited over a green
# background. Both objects share one preview loop instead of two copies.
tracking_object_ids = [0, 1]
video_state_dict = sam2_model.forward_tracking_for_test(
    video_state_dict,
    start_tracking_frame_idx=0,
    tracking_object_ids=tracking_object_ids,
    use_point_prompt_input=False,
    use_box_prompt_input=True,
    use_mask_prompt_input=False)

frame_num = video_state_dict['video_frame_num']
# The per-object names are kept so any later cell that refers to them still
# finds them defined.
object_track_state_0 = video_state_dict['object_track_state'][0]
object_track_result_0 = video_state_dict['object_track_result'][0]
object_track_state_1 = video_state_dict['object_track_state'][1]
object_track_result_1 = video_state_dict['object_track_result'][1]

vis_frame_stride = 30
for per_object_track_state, per_object_track_result in (
    (object_track_state_0, object_track_result_0),
    (object_track_state_1, object_track_result_1),
):
    # Same output order as before: the state line first, then the figures.
    print(frame_num, per_object_track_state)

    for frame_idx in range(0, len(frames_path_list), vis_frame_stride):
        per_frame_object_result = per_object_track_result[frame_idx]
        per_frame_per_object_fuse_pred = per_frame_object_result[
            'fuse_pred'][0]
        per_frame_per_object_pred_iou = per_frame_object_result['pred_iou']
        per_frame_per_pred_object_score = per_frame_object_result[
            'pred_object_score']

        # Normalise the frame to 3-channel RGB and convert it to an array
        # once. Multiplying the PIL image directly only worked through
        # NumPy's implicit coercion, and a grayscale / palette / RGBA frame
        # would not match the 3-channel green fill and alpha broadcast.
        show_image = Image.open(frames_path_list[frame_idx]).convert('RGB')
        frame_rgb = np.asarray(show_image)

        green_background = np.zeros_like(frame_rgb, dtype=np.float32)
        green_background[:, :] = [0, 255, 0]  # RGB order

        # Blend foreground and background, using the fused prediction as the
        # per-pixel weight.
        # NOTE(review): assumes 'fuse_pred'[0] is an (H, W) map in [0, 1] at
        # the frame's resolution - confirm against the model output.
        per_frame_per_object_fuse_pred = np.expand_dims(
            per_frame_per_object_fuse_pred, axis=-1)
        foreground = frame_rgb * per_frame_per_object_fuse_pred
        background = green_background * (1 - per_frame_per_object_fuse_pred)
        result_image = foreground + background
        result_image = result_image.astype(np.uint8)

        plt.figure(figsize=(9, 6))
        plt.imshow(result_image)
        plt.title(
            f"frame: {frame_idx}, IoU Score: {per_frame_per_object_pred_iou:.3f}, Object Score: {per_frame_per_pred_object_score:.3f}",
            fontsize=18)
        plt.axis('off')
        plt.show()