In [None]:
import os
import numpy as np
import torch
import cv2
from sam2.build_sam import build_sam2_video_predictor

# Choose the compute backend: CUDA when available, then MPS, otherwise CPU.
# The MPS fallback environment flag is set before the device is selected.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "mps":
    print("\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might give numerically different outputs and sometimes degraded performance on MPS. See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion.")
elif device.type == "cuda":
    # The autocast context is entered by hand and never exited, so bfloat16
    # autocast stays active for everything that runs afterwards.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # TF32 is switched on only when device 0 reports a major version >= 8.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

# Build the SAM 2 video predictor from the config and checkpoint paths below.
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

# RGB values (0-255) of matplotlib's "tab10" palette, inlined so the overlay
# does not depend on `plt`, which is not imported in the visible part of this file.
_TAB10_RGB = np.array(
    [
        [31, 119, 180],
        [255, 127, 14],
        [44, 160, 44],
        [214, 39, 40],
        [148, 103, 189],
        [140, 86, 75],
        [227, 119, 194],
        [127, 127, 127],
        [188, 189, 34],
        [23, 190, 207],
    ],
    dtype=np.float64,
)


def show_mask(mask, frame, obj_id=None, random_color=False):
    """Alpha-blend a segmentation mask onto ``frame`` in place.

    Pixels where ``mask > 0`` are replaced by ``0.4 * pixel + 0.6 * color``,
    with ``color`` expressed on a 0-255 scale.

    Args:
        mask: Array whose last two dimensions are ``(H, W)`` and whose total
            size is ``H * W`` (e.g. ``(H, W)`` or ``(1, H, W)``). Boolean or
            numeric; values ``> 0`` select the pixels to tint.
        frame: ``(H, W, 3)`` image array, modified in place. Integer frames
            get the blended values rounded and clipped to the dtype's range.
        obj_id: Selects the palette entry (``obj_id % 10``); ``None`` uses
            entry 0. Ignored when ``random_color`` is true.
        random_color: When true, draw a random color from ``np.random``
            instead of using the palette.

    Returns:
        None. The overlay is written into ``frame``.
    """
    # NOTE(review): the palette is in RGB order; if `frame` comes straight from
    # cv2 (BGR), the tint will look channel-swapped -- confirm the intended order.
    alpha = 0.6
    if random_color:
        rgb = np.random.random(3) * 255.0
    else:
        palette_index = 0 if obj_id is None else int(obj_id) % len(_TAB10_RGB)
        rgb = _TAB10_RGB[palette_index]

    # Collapse any leading singleton dimensions to a 2-D boolean region so it
    # can index the first two axes of the frame.
    mask = np.asarray(mask)
    region = mask.reshape(mask.shape[-2:]) > 0

    blended = frame[region].astype(np.float64) * (1.0 - alpha) + rgb * alpha
    if np.issubdtype(frame.dtype, np.integer):
        # Round instead of truncating, and stay inside the dtype's range.
        limits = np.iinfo(frame.dtype)
        blended = np.clip(np.rint(blended), limits.min, limits.max)
    frame[region] = blended


In [None]:
video_dir = "./videos/aria"

# Gather the JPEG frame files and order them by the integer value of their
# file stem (so every stem must parse as an int).
frame_names = [
    entry
    for entry in os.listdir(video_dir)
    if os.path.splitext(entry)[-1].lower() in (".jpg", ".jpeg")
]
frame_names.sort(key=lambda entry: int(os.path.splitext(entry)[0]))

# Create the inference state for this video directory, then reset it.
inference_state = predictor.init_state(video_path=video_dir)
predictor.reset_state(inference_state)

# One point prompt with label 1 for object 1 on frame 0.
points = np.array([[1100, 900]], dtype=np.float32)
labels = np.array([1], np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state,
    frame_idx=0,
    obj_id=1,
    points=points,
    labels=labels,
)

# Propagate through the video and keep, per frame and per object id, the
# boolean mask obtained by thresholding the logits at 0. The loop variables
# use their own names so `out_obj_ids` / `out_mask_logits` keep the values
# returned by the prompt call above.
video_segments = {}
for prop_frame_idx, prop_obj_ids, prop_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[prop_frame_idx] = {
        prop_obj_id: (prop_mask_logits[slot] > 0.0).cpu().numpy()
        for slot, prop_obj_id in enumerate(prop_obj_ids)
    }