# Setting up the Workspace

To run this Jupyter notebook, you will need `test.mp4` and `requirements.txt` from this [GitHub repo](https://github.com/therealnaveenkamal/gest_detectron). Make sure you have Python 3.12 installed.

In [1]:
# Alternative: install everything from the repository's requirements file.
#!pip install -r requirements.txt

# Install the dependencies one by one.
!pip install mediapipe
!pip install opencv-python
!pip install matplotlib
# SAM 2 is installed straight from its GitHub repository.
!pip install "git+https://github.com/facebookresearch/sam2.git"
# The extra index URL adds the PyTorch cu118 wheel index as a package source.
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118

# If you run the notebook in Colab, you might encounter a prompt asking you to restart because of a package import. Feel free to click CANCEL.

Collecting mediapipe
  Downloading mediapipe-0.10.20-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (9.7 kB)
Collecting sounddevice>=0.4.4 (from mediapipe)
  Downloading sounddevice-0.5.1-py3-none-any.whl.metadata (1.4 kB)
Downloading mediapipe-0.10.20-cp311-cp311-manylinux_2_28_x86_64.whl (35.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m35.6/35.6 MB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading sounddevice-0.5.1-py3-none-any.whl (32 kB)
Installing collected packages: sounddevice, mediapipe
Successfully installed mediapipe-0.10.20 sounddevice-0.5.1
Collecting git+https://github.com/facebookresearch/sam2.git
  Cloning https://github.com/facebookresearch/sam2.git to /tmp/pip-req-build-xuyh1xio
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/sam2.git /tmp/pip-req-build-xuyh1xio
  Resolved https://github.com/facebookresearch/sam2.git to commit 2b90b9f5ceec907a1c18123530e92e794ad901a4
  Installing build depend

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu118


In [2]:
# Download the SAM 2.1 checkpoint; init_sam_predictor loads it by this file name by default.
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
# Download the MediaPipe hand landmarker model referenced by model_asset_path in the options.
!wget -q https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/1/hand_landmarker.task

# Import Required Modules

In [3]:
import torch
import torchvision
import sys
import os
import numpy as np
import cv2

import matplotlib.pyplot as plt
from PIL import Image
import mediapipe as mp
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2

# Set PyTorch's MPS fallback environment flag for this process.
# NOTE(review): this assignment runs after `import torch`; confirm PyTorch still
# honours the flag when it is set this late, otherwise move it above the import.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# Report library versions and CUDA availability so a run can be reproduced.
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())

PyTorch version: 2.5.1+cu121
Torchvision version: 0.20.1+cu121
CUDA is available: True


# Frame Extraction and Hand Landmarker data extraction using MediaPipe

In [4]:
# Drawing constants read by draw_landmarks_on_image.
MARGIN = 10  # pixels; subtracted from the hand's topmost y to place the label above it
FONT_SIZE = 1  # passed to cv2.putText as the font scale
FONT_THICKNESS = 1  # passed to cv2.putText as the stroke thickness
HANDEDNESS_TEXT_COLOR = (88, 205, 54) # vibrant green

def draw_landmarks_on_image(rgb_image, detection_result):
    """Return a copy of ``rgb_image`` annotated with every detected hand.

    For each hand in ``detection_result`` the landmark skeleton is drawn with
    MediaPipe's default hand styles, and the top-scoring handedness category
    name is written just above the hand's top-left landmark extent.

    Args:
        rgb_image: Image array to annotate; it is copied, never modified.
        detection_result: Object exposing ``hand_landmarks`` and ``handedness``
            sequences that are indexed in parallel (one entry per hand).

    Returns:
        The annotated copy of ``rgb_image``.
    """
    all_hands = detection_result.hand_landmarks
    all_handedness = detection_result.handedness
    canvas = np.copy(rgb_image)

    for hand_index, points in enumerate(all_hands):
        hand_label = all_handedness[hand_index]

        # Repackage the landmarks as a protobuf list before handing them to draw_landmarks.
        proto = landmark_pb2.NormalizedLandmarkList()
        proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=pt.x, y=pt.y, z=pt.z) for pt in points
        ])
        solutions.drawing_utils.draw_landmarks(
            canvas,
            proto,
            solutions.hands.HAND_CONNECTIONS,
            solutions.drawing_styles.get_default_hand_landmarks_style(),
            solutions.drawing_styles.get_default_hand_connections_style())

        # Landmark coordinates are scaled by the image size to get pixel positions;
        # the label sits MARGIN pixels above the smallest x / smallest y corner.
        img_h, img_w, _ = canvas.shape
        leftmost = min([pt.x for pt in points])
        topmost = min([pt.y for pt in points])
        label_origin = (int(leftmost * img_w), int(topmost * img_h) - MARGIN)

        # Write the handedness category name next to the hand.
        cv2.putText(canvas, f"{hand_label[0].category_name}",
                    label_origin, cv2.FONT_HERSHEY_DUPLEX,
                    FONT_SIZE, HANDEDNESS_TEXT_COLOR, FONT_THICKNESS, cv2.LINE_AA)

    return canvas


# MediaPipe setup
# MediaPipe setup: short aliases for the task classes used below.
BaseOptions = mp.tasks.BaseOptions
HandLandmarker = mp.tasks.vision.HandLandmarker
HandLandmarkerOptions = mp.tasks.vision.HandLandmarkerOptions
VisionRunningMode = mp.tasks.vision.RunningMode

# Configure video processing options.
# These module-level options are read by process_video when it creates the detector.
# The model file 'hand_landmarker.task' is resolved relative to the working directory.
options = HandLandmarkerOptions(
    base_options=BaseOptions(model_asset_path='hand_landmarker.task'),
    running_mode=VisionRunningMode.VIDEO,  # VIDEO mode
    num_hands=2)

def process_video(input_path, output_path, frames_dir="frames"):
    """Run the MediaPipe hand landmarker over a video, one frame at a time.

    For every frame read from ``input_path`` this function:
      * detects hand landmarks using the module-level ``options``,
      * writes an annotated copy of the frame to the video at ``output_path``,
      * saves the untouched frame as ``<frames_dir>/<index:05d>.jpg``.

    Args:
        input_path: Path of the video to read.
        output_path: Path of the annotated ``mp4v`` video to write.
        frames_dir: Directory that receives the extracted JPEG frames. It is
            created if missing. NOTE(review): files left over from an earlier
            run are not removed.

    Returns:
        A list with one dict per frame, holding ``"timestamp"`` (integer
        milliseconds), ``"landmarks"`` (per hand, a list of ``(x, y)`` tuples
        as reported by the detector) and ``"handedness"`` (per hand, a list of
        ``(category_name, score)`` tuples).

    Raises:
        IOError: If ``input_path`` cannot be opened.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open input video: {input_path}")

    out = None
    frame_data = []

    # try/finally guarantees the capture and writer are released even when
    # detection or file writing raises part-way through the video.
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Covers 0, negative and NaN frame rates, which would otherwise
            # cause a division by zero or nonsensical timestamps below.
            print("Warning: video reports no valid FPS; assuming 30 FPS")
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        with HandLandmarker.create_from_options(options) as detector:
            count = 0
            print("Hand Landmarking In Progress...")

            os.makedirs(frames_dir, exist_ok=True)

            while True:
                ret, frame = cap.read()

                if not ret:
                    print("Hand Landmarked Video Rendering Completed")
                    break

                # Derive the timestamp from the frame index so that per-frame
                # truncation does not accumulate into drift over long videos.
                frame_timestamp = int(count * 1000 / fps)

                # OpenCV delivers BGR; the MediaPipe image is built as SRGB.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)

                detection_result = detector.detect_for_video(mp_image, frame_timestamp)

                frame_data.append({
                    "timestamp": frame_timestamp,
                    "landmarks": [
                        [(lm.x, lm.y) for lm in hand_landmarks]
                        for hand_landmarks in detection_result.hand_landmarks
                    ],
                    "handedness": [
                        [(c.category_name, c.score) for c in classification]
                        for classification in detection_result.handedness
                    ],
                })

                # Annotate in RGB, then convert back to BGR for the video writer.
                annotated_image = draw_landmarks_on_image(rgb_frame, detection_result)
                out.write(cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR))

                # Save the original (un-annotated) frame as a JPG with quality 95.
                output_file = os.path.join(frames_dir, f"{count:05d}.jpg")
                cv2.imwrite(output_file, frame, [int(cv2.IMWRITE_JPEG_QUALITY), 95])

                count += 1
    finally:
        cap.release()
        if out is not None:
            out.release()

    return frame_data

# Helper Functions for SAM2

In [5]:
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Draw a segmentation mask on a matplotlib axes as a translucent overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to ``(h, w, 1)`` and multiplied by the RGBA colour.
        ax: Object with an ``imshow`` method (a matplotlib axes).
        obj_id: Object id used to pick the colour when ``random_color`` is
            false. Ids up to 20 use ``tab10`` index 1, larger ids use index 2.
            ``None`` (the default) uses index 0 instead of raising a
            ``TypeError`` from comparing ``None`` with an integer.
        random_color: When true, use a random RGB colour instead of ``tab10``.

    The overlay always has an alpha of 0.6.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        if obj_id is None:
            # No id given: fall back to the first palette entry.
            cmap_idx = 0
        elif obj_id <= 20:
            cmap_idx = 1
        else:
            cmap_idx = 2
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=20):
    """Scatter prompt points on ``ax``, coloured by their label.

    Rows of ``coords`` whose label equals 1 are plotted first in green, rows
    whose label equals 0 second in red; column 0 is x and column 1 is y.
    Both groups use a '.' marker of size ``marker_size`` with a white edge.
    """
    # Select both groups up front, then plot them in green-then-red order.
    groups = [
        (coords[labels == flag], colour)
        for flag, colour in ((1, 'green'), (0, 'red'))
    ]
    for selected, colour in groups:
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='.', s=marker_size, edgecolor='white', linewidth=1)


def show_box(box, ax):
    """Add a green, unfilled rectangle for ``box`` to ``ax``.

    ``box`` is indexed as ``(x0, y0, x1, y1)``: the first two entries are the
    rectangle's anchor, and the width/height are the differences to the last
    two entries.
    """
    left, top = box[0], box[1]
    span_x = box[2] - box[0]
    span_y = box[3] - box[1]
    outline = plt.Rectangle((left, top), span_x, span_y, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)

# Predictor Initialization - Feeding Clicks from Hand Landmarker data

In [6]:
def init_sam_predictor(fd, device, sam2_checkpoint = "sam2.1_hiera_large.pt", model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml", video_dir = "./frames", video_path = "test.mp4"):
    """Build a SAM 2 video predictor and seed it with hand-landmark clicks.

    Every fourth landmark of each hand found in the first frame is registered
    as a single positive click on frame 0, each under its own object id. Ids
    are ``offset + landmark_index``, where ``offset`` is the total number of
    landmarks of the preceding hands; with 21 landmarks per hand the first
    hand gets ids 0..20 and the second 21..41.

    Args:
        fd: Per-frame landmark data as returned by ``process_video``; only
            ``fd[0]['landmarks']`` is read. Coordinates are multiplied by the
            video width/height to obtain pixel positions.
        device: Device passed through to ``build_sam2_video_predictor``.
        sam2_checkpoint: Path of the SAM 2 checkpoint file.
        model_cfg: SAM 2 model config name.
        video_dir: Directory of extracted frames given to ``init_state``.
        video_path: Video used only to read the frame width and height.

    Returns:
        ``(predictor, inference_state)`` ready for ``propagate_in_video``.

    Raises:
        ValueError: If ``fd`` is empty or the first frame has no hands.
        IOError: If the frame size cannot be read from ``video_path``.

    Note:
        ``build_sam2_video_predictor`` is looked up as a global at call time,
        so it must have been imported before this function is called.
    """
    # Validate the prompts first so we fail before loading the large model.
    if not fd or not fd[0]['landmarks']:
        raise ValueError("No hand landmarks found in the first frame; cannot create SAM 2 prompts")
    first_frame_hands = fd[0]['landmarks']

    # Only the frame size is needed; release the capture straight away.
    cap = cv2.VideoCapture(video_path)
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        cap.release()
    if width <= 0 or height <= 0:
        # A zero size would silently place every click at pixel (0, 0).
        raise IOError(f"Could not read frame size from video: {video_path}")

    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device)
    inference_state = predictor.init_state(video_path=video_dir)
    predictor.reset_state(inference_state)

    ann_frame_idx = 0  # the frame index we interact with
    id_offset = 0

    for hand_points in first_frame_hands:
        for lm_idx, elem in enumerate(hand_points):
            if lm_idx % 4 != 0:
                continue
            points = np.array([[elem[0] * width, elem[1] * height]], dtype=np.float32)
            labels = np.ones(1, dtype=np.int32)  # label 1 = positive click

            predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=ann_frame_idx,
                obj_id=id_offset + lm_idx,
                points=points,
                labels=labels,
            )
        # Keep the ids of the next hand disjoint from this one.
        id_offset += len(hand_points)

    return predictor, inference_state


# Render Final SAM2 Segmented Video Output

In [7]:
def render_sam_video(video_path, video_segments, output_path, alpha=0.5, segments_dir="segments"):
    """Overlay per-object masks on a video and save the result.

    Frames of ``video_path`` that have an entry in ``video_segments`` get
    every object mask blended on top, are written to the ``mp4v`` video at
    ``output_path`` and are also saved as ``<segments_dir>/<index:05d>.jpg``
    (JPEG quality 50). Frames without an entry are skipped entirely.

    Args:
        video_path: Path of the source video.
        video_segments: Mapping ``frame_index -> {obj_id: mask}``. Index 0 of
            each mask is taken as the 2-D mask for that object.
        output_path: Path of the rendered video.
        alpha: Weight of the colour layer in ``cv2.addWeighted``.
        segments_dir: Directory for the per-frame JPEGs; created if missing.

    Raises:
        IOError: If ``video_path`` cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open input video: {video_path}")

    out = None
    # try/finally guarantees the capture and writer are released on any error.
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        os.makedirs(segments_dir, exist_ok=True)

        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_idx in video_segments:
                # Blend in RGB, convert back to BGR just before writing.
                overlay = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                for obj_id, mask in video_segments[frame_idx].items():
                    # Force a boolean 2-D mask: indexing with an integer array
                    # would select rows by number instead of masking pixels.
                    mask = mask[0].astype(bool)
                    if mask.shape != (height, width):
                        mask = cv2.resize(
                            mask.astype(np.uint8), (width, height),
                            interpolation=cv2.INTER_NEAREST,
                        ).astype(bool)

                    # RGB colours: ids up to 20 are red, larger ids are green.
                    if obj_id <= 20:
                        color = [255, 0, 0]
                    else:
                        color = [0, 255, 0]

                    mask_rgb = np.zeros_like(overlay)
                    mask_rgb[mask] = color

                    overlay = cv2.addWeighted(overlay, 1, mask_rgb, alpha, 0)

                frame_out = cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)
                output_file = os.path.join(segments_dir, f"{frame_idx:05d}.jpg")
                # Save the frame as a JPG with specified quality
                cv2.imwrite(output_file, frame_out, [int(cv2.IMWRITE_JPEG_QUALITY), 50])

                out.write(frame_out)

            frame_idx += 1
    finally:
        cap.release()
        if out is not None:
            out.release()

    print(f"Saved SAM masked video to {output_path}")

# **Main Code - Execution**

In [10]:
# NOTE(review): init_sam_predictor looks up build_sam2_video_predictor as a global
# when it is called, so this import has to run before that call.
from sam2.build_sam import build_sam2_video_predictor

# Generating Hand Landmarks
# Also writes landmarked_output.mp4 and extracts the raw frames as JPEGs.

frame_data = process_video("test.mp4", "landmarked_output.mp4")

print("Video Frame Extracted")

# Pick the compute device: CUDA first, then MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")


if device.type == "cuda":
    # Enter a bfloat16 autocast context by calling __enter__ directly; it is
    # never exited, so it stays active for the code that follows.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Allow TF32 in matmul and cuDNN when device 0 reports a major version >= 8.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

print("Predictor Calling")

# Build the predictor and register the first-frame landmark clicks as prompts.
predictor, inference_state = init_sam_predictor(frame_data, device, sam2_checkpoint = "sam2.1_hiera_large.pt", model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml", video_dir = "./frames")

print("Predictor Initialized")

# Propagate the prompts through the video and collect, per frame, one mask per
# object id: logits above 0.0 become True, moved to the CPU as NumPy arrays.
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }


# Render the masks over the original video; alpha=1 adds the colour layer at full weight.
render_sam_video(
    video_path="test.mp4",
    video_segments=video_segments,
    output_path="sam_masked_output_final.mp4",
    alpha=1
)

Hand Landmarking In Progress...
Hand Landmarked Video Rendering Completed
Video Frame Extracted
using device: cuda
Predictor Calling


frame loading (JPEG): 100%|██████████| 210/210 [00:09<00:00, 21.85it/s]


Predictor Initialized


propagate in video: 100%|██████████| 210/210 [14:00<00:00,  4.00s/it]


Saved SAM masked video to sam_masked_output_final.mp4
