# Video segmentation with SAM 2

In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import shutil
import pandas as pd
import IPython
from tqdm import tqdm

In [None]:
# use bfloat16 for the entire notebook
# The autocast context is entered manually and never exited so that it stays
# active for every cell executed afterwards.
# Everything is guarded by torch.cuda.is_available(): without the guard,
# torch.cuda.get_device_properties(0) raises on a machine that has no CUDA device.
if torch.cuda.is_available():
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

In [3]:
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent single-color RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1) before being colored.
        ax: Object with an ``imshow`` method (e.g. a matplotlib Axes).
        obj_id: Index into the "tab10" colormap; ``None`` selects index 0.
            Ignored when ``random_color`` is true.
        random_color: When true, use a random RGB color instead of the colormap.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        palette = plt.get_cmap("tab10")
        palette_idx = obj_id if obj_id is not None else 0
        rgba = np.array([*palette(palette_idx)[:3], 0.6])

    # Broadcast the (H, W, 1) mask against the (1, 1, 4) color: pixels where the
    # mask is zero become fully transparent, the rest get alpha 0.6.
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


def show_points(coords, labels, ax, marker_size=200):
    """Scatter labeled points on ``ax`` as star markers.

    Rows of ``coords`` whose label equals 1 are drawn in green, rows whose
    label equals 0 in red; column 0 is used as x and column 1 as y.

    Args:
        coords: 2-D array indexable as ``coords[mask]`` and ``[:, 0]`` / ``[:, 1]``.
        labels: Array compared element-wise against 1 and 0 to select rows.
        ax: Object with a ``scatter`` method (e.g. a matplotlib Axes).
        marker_size: Value passed to ``scatter`` as ``s``.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    positives = coords[labels == 1]
    negatives = coords[labels == 0]
    for selected, color in ((positives, 'green'), (negatives, 'red')):
        ax.scatter(selected[:, 0], selected[:, 1], color=color, **star_style)


def show_box(box, ax):
    """Outline ``box`` on ``ax`` with an unfilled green rectangle.

    Args:
        box: Sequence read at indices 0..3 as (x0, y0, x1, y1); the rectangle
            is anchored at (x0, y0) with width x1 - x0 and height y1 - y0.
        ax: Object with an ``add_patch`` method (e.g. a matplotlib Axes).
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

def df_set(df, image_name, key, value):
    """Write ``value`` into column ``key`` of the first row whose 'image_name' equals ``image_name``.

    Mutates ``df`` in place. Raises IndexError when no row matches.
    """
    matching_labels = df.index[df['image_name'] == image_name]
    first_label = matching_labels[0]
    df.at[first_label, key] = value

def df_get(df, image_name, key):
    """Return column ``key`` of the first row whose 'image_name' equals ``image_name``.

    Raises IndexError when no row matches.
    """
    matching_labels = df.index[df['image_name'] == image_name]
    first_label = matching_labels[0]
    return df.at[first_label, key]

### Set up each split

In [4]:
# Directory that is listed for per-video folders; each is expected to contain
# a "frames" sub-folder. Replace the placeholder with a real path before running.
root_dir = "/path/to/root/dir"

In [7]:
# Collect every split directory laid out as <root_dir>/<video>/frames/<split>.
video_paths = []

# sorted(): os.listdir returns entries in arbitrary order, which would make the
# order in which splits are processed differ from run to run.
for vid_name in sorted(os.listdir(root_dir)):
    vid_path = os.path.join(root_dir, vid_name)

    frame_path = os.path.join(vid_path, "frames")
    if not os.path.isdir(frame_path):
        # Stray files in root_dir (or video folders without "frames") would
        # otherwise make os.listdir(frame_path) raise.
        continue

    for split_name in sorted(os.listdir(frame_path)):
        split_path = os.path.join(frame_path, split_name)
        # Only directories are splits; a plain file here would crash the
        # later cells that call os.listdir on every collected path.
        if os.path.isdir(split_path):
            video_paths.append(split_path)

print(f"Found {len(video_paths)} videos")

Found 20 videos (splits)


In [8]:
# Give every split directory a data.csv with one row per directory entry
# (column 'image_name'). Splits that already contain a data.csv are left
# untouched and only reported.
for vid_path in video_paths:
    images = os.listdir(vid_path)
    if "data.csv" not in images:
        csv_path = os.path.join(vid_path, "data.csv")
        df = pd.DataFrame({'image_name': images})
        df.to_csv(csv_path, index=False)
    else:
        print(f"WARN - {vid_path} already has data.csv")

### Loading the SAM 2 video predictor

In [9]:
from sam2.build_sam import build_sam2_video_predictor

# Model config and checkpoint handed to the SAM 2 video predictor builder
# (the file names suggest the "tiny" Hiera variant).
model_cfg = "sam2_hiera_t.yaml"
sam2_checkpoint = "../checkpoints/sam2_hiera_tiny.pt"

predictor = build_sam2_video_predictor(
    config_file=model_cfg,
    ckpt_path=sam2_checkpoint,
)

In [None]:
# Interactive labeling loop: for every split, show its first frame, let the
# user enter (or confirm) one positive click, preview the mask the predictor
# returns for that click, and store the confirmed click in the split's data.csv
# (columns 'click_x' / 'click_y', -1.0 meaning "not labeled yet").
video_dir = "./videos/to_process"


def _start_frame_figure(first_frame):
    """Clear the cell output and open a fresh figure showing ``first_frame``."""
    IPython.display.clear_output()
    plt.clf()
    plt.figure(figsize=(9, 6))
    plt.imshow(first_frame)


def _preview_click(inference_state, x, y):
    """Draw a positive click at (x, y) on the current axes plus the mask predicted for it."""
    points = np.array([[x, y]], dtype=np.float32)
    labels = np.array([1], np.int32)  # label 1 is drawn as a green (positive) point
    show_points(points, labels, plt.gca())

    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=1,
        points=points,
        labels=labels,
    )
    show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])


def _prompt_click():
    """Ask for "x y" until two numbers are entered and return them as floats."""
    while True:
        inp = input("Enter x y")
        try:
            x, y = (float(token) for token in inp.split())
        except ValueError:
            # Wrong token count or non-numeric text: ask again instead of
            # aborting the whole labeling session.
            print("Expected two numbers separated by whitespace, e.g. '120 245'")
            continue
        return x, y


i = 0
skipped = []  # splits without any JPEG frame, reported once the loop is done

while i < len(video_paths):
    # The predictor is pointed at a scratch directory that holds only the
    # first frame of the current split, so start from an empty directory.
    if os.path.exists(video_dir):
        shutil.rmtree(video_dir)
    os.makedirs(video_dir)

    path = video_paths[i]
    frame_names = [
        p for p in os.listdir(path)
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
    ]
    # Frame files are ordered by the integer value of their file stem.
    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
    if not frame_names:
        # Nothing to label; frame_names[0] below would raise IndexError.
        skipped.append(path)
        i += 1
        continue

    first_name = frame_names[0]
    frame_path = os.path.join(path, first_name)
    first_frame = Image.open(frame_path)

    shutil.copy(frame_path, video_dir)
    inference_state = predictor.init_state(video_path=video_dir)
    predictor.reset_state(inference_state)

    _start_frame_figure(first_frame)
    print(frame_path)

    csv_path = os.path.join(path, "data.csv")
    df = pd.read_csv(csv_path)
    if 'click_point' in df.columns:
        # DataFrame.drop returns a new frame; without the assignment the
        # column was never actually removed.
        df = df.drop(columns=['click_point'])
    if 'click_x' not in df.columns:
        df['click_x'] = -1.0
    if 'click_y' not in df.columns:
        df['click_y'] = -1.0

    saved_x = df_get(df, first_name, 'click_x')
    if saved_x != -1.0:
        # A click was stored earlier: preview it and let the user keep it.
        saved_y = df_get(df, first_name, 'click_y')
        _preview_click(inference_state, saved_x, saved_y)

        plt.show()
        if input("confirm y/n?") == "y":
            i += 1
            continue
    else:
        plt.show()

    x, y = _prompt_click()

    _start_frame_figure(first_frame)
    _preview_click(inference_state, x, y)
    plt.show()

    inp = input("confirm y/n?")

    IPython.display.clear_output()
    plt.clf()

    # Only a confirmed click is persisted; any other answer repeats this split.
    if inp == "y":
        df_set(df, first_name, 'click_x', x)
        df_set(df, first_name, 'click_y', y)
        df.to_csv(csv_path, index=False)
        i += 1

for skipped_path in skipped:
    print(f"WARN - {skipped_path} has no JPEG frames, skipped")