In [6]:
import ast
import gc
import os

import cv2
import numpy as np
import torch
from tqdm import tqdm

from sam2.build_sam import build_sam2_video_predictor
from sam2.utils.misc import load_video_frames

# Project root. Defaults to the original author's checkout; set the
# FIGHT_MOTION_ROOT environment variable to run the notebook elsewhere.
PROJECT_ROOT = os.environ.get("FIGHT_MOTION_ROOT", "D:/Documents/devs/fight_motion").rstrip("/\\")

DIR_RAW = f"{PROJECT_ROOT}/data/raw"      # input videos (*.mp4)
DIR_INT = f"{PROJECT_ROOT}/data/interim"  # prompt files (*.txt) and mask-overlay videos
DIR_SAM = f"{PROJECT_ROOT}/sam2-main"     # SAM2 checkout (checkpoints/ and sam2/configs/)


In [2]:
# Run SAM2 on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Overlay colours, indexed by object (fighter) id. Frames are read with
# cv2.VideoCapture, so these tuples are BGR: id 0 -> blue, id 1 -> red.
colors = [
    (255, 0, 0),
    (0, 0, 255),
]

In [3]:
# Load prompts from the text file
def load_prompts(txt_path):
    with open(txt_path, 'r') as f:
        lines = f.readlines()
    # The first line indicates the frame index
    frame_idx = int(lines[0].strip())
    # The second line contains the bounding boxes
    bbox_line = lines[1].strip()
    # Parse the bounding boxes
    bbox_list = eval(bbox_line)  # Use eval to convert string representation to list of tuples
    return frame_idx, bbox_list

In [17]:
def _track_fighters(predictor, state, frame0, prompts):
    """Prompt SAM2 with one box per fighter on ``frame0`` and propagate masks both ways.

    Args:
        predictor: SAM2 video predictor from ``build_sam2_video_predictor``.
        state: inference state from ``predictor.init_state``.
        frame0: index of the frame the box prompts refer to.
        prompts: sequence of boxes; object id ``i`` is assigned to ``prompts[i]``.

    Returns:
        dict: object id -> {frame index -> boolean numpy mask}. Keying by frame
        index (instead of list position) keeps every mask aligned with its frame
        whatever order the passes produce them in.
    """
    fighter_masks = {obj_id: {} for obj_id in range(len(prompts))}

    def store(frame_idx, object_ids, masks):
        # Store one thresholded mask per object for this frame.
        for obj_id, mask in zip(object_ids, masks):
            # mask[0] drops the leading axis (presumably (1, H, W) -> (H, W));
            # values > 0 are treated as foreground, as in the original notebook.
            fighter_masks.setdefault(obj_id, {})[frame_idx] = (mask[0] > 0.0).cpu().numpy()

    # fp16 autocast only when running on CUDA; disabled (not just warned about) on CPU.
    amp = torch.autocast("cuda", dtype=torch.float16, enabled=device.type == "cuda")
    with torch.inference_mode(), amp:
        for obj_id, bbox in enumerate(prompts):
            predictor.add_new_points_or_box(state, frame_idx=frame0, obj_id=obj_id, box=bbox)

        # Forward pass (no start frame given; presumably starts at frame0, the only prompted frame).
        for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
            store(frame_idx, object_ids, masks)

        # Backward pass from frame0; frame0 itself is kept from the forward pass.
        for frame_idx, object_ids, masks in predictor.propagate_in_video(
            state, start_frame_idx=frame0, reverse=True
        ):
            if frame_idx != frame0:
                store(frame_idx, object_ids, masks)

    return fighter_masks


def _overlay_masks(frame, fighter_masks, frame_idx, alpha=0.5):
    """Return ``frame`` with every fighter mask available for ``frame_idx`` tinted in."""
    img = frame.copy()
    for obj_id in sorted(fighter_masks):
        mask = fighter_masks[obj_id].get(frame_idx)
        if mask is None:
            continue  # no mask for this object on this frame
        mask_img = np.zeros_like(img)
        # Cycle through the palette so more fighters than colours do not crash.
        mask_img[mask] = colors[obj_id % len(colors)]
        # mask_img is zero outside the mask, so only masked pixels get alpha * colour added.
        img = cv2.addWeighted(img, 1, mask_img, alpha, 0)
    return img


def _write_overlay_video(video_path, output_video_path, fighter_masks, alpha=0.5):
    """Stream ``video_path`` frame by frame, overlay the masks, and write ``output_video_path``.

    Frames are read, tinted and written one at a time rather than buffered in memory.

    Raises:
        IOError: if the input video cannot be opened or the output writer cannot be created.
        ValueError: if no frame could be decoded from the input video.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"could not open video {video_path}")
    out = None
    try:
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        # Only used for the progress bar; some containers report 0 or -1.
        n_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or None
        frame_idx = 0
        with tqdm(total=n_frames) as progress:
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                if out is None:
                    # Size the writer from the first decoded frame.
                    height, width = frame.shape[:2]
                    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                    out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (width, height))
                    if not out.isOpened():
                        raise IOError(f"could not create output video {output_video_path}")
                out.write(_overlay_masks(frame, fighter_masks, frame_idx, alpha))
                frame_idx += 1
                progress.update(1)
        if out is None:
            raise ValueError(f"no frames could be read from {video_path}")
    finally:
        cap.release()
        if out is not None:
            out.release()


def extract_fighter_masks(video_path, txt_path, output_video_path, checkpoint=None, model_cfg=None):
    """Track each prompted fighter with SAM2 and write a video with the masks overlaid.

    Args:
        video_path: input video.
        txt_path: prompt file read by ``load_prompts`` (prompt frame index + list of boxes).
        output_video_path: destination of the overlay video (mp4v fourcc).
        checkpoint: SAM2 checkpoint; defaults to ``checkpoints/sam2.1_hiera_small.pt`` under ``DIR_SAM``.
        model_cfg: SAM2 config; defaults to ``sam2/configs/sam2.1/sam2.1_hiera_s.yaml`` under
            ``DIR_SAM``. Pass it together with a matching ``checkpoint`` (e.g. the hiera_large pair).

    Raises:
        ValueError: if the prompt file has no boxes or the video has no readable frames.
        IOError: if the input video cannot be opened or the output video cannot be created.
    """
    # Parse prompts before loading the model so a bad prompt file fails fast.
    frame0, prompts = load_prompts(txt_path)
    if not prompts:
        raise ValueError(f"no box prompts found in {txt_path}")

    if checkpoint is None:
        checkpoint = os.path.join(DIR_SAM, "checkpoints/sam2.1_hiera_small.pt")
    if model_cfg is None:
        model_cfg = os.path.join(DIR_SAM, "sam2/configs/sam2.1/sam2.1_hiera_s.yaml")

    predictor = build_sam2_video_predictor(model_cfg, checkpoint, device=device)
    state = None
    try:
        state = predictor.init_state(video_path, offload_video_to_cpu=True)
        fighter_masks = _track_fighters(predictor, state, frame0, prompts)
    finally:
        # Release the model and state even when tracking fails: a stored traceback
        # (e.g. of a failed notebook cell) would otherwise keep their tensors alive.
        predictor = state = None
        gc.collect()
        torch.clear_autocast_cache()
        torch.cuda.empty_cache()

    _write_overlay_video(video_path, output_video_path, fighter_masks)

In [18]:
# Videos to process from DIR_RAW; set to None to process every .mp4 file.
SELECTED_VIDEOS = {"aldo_holloway_2.mp4"}

# Sorted so batch order (and the printed log) is reproducible across runs.
for video_name in sorted(os.listdir(DIR_RAW)):
    stem, ext = os.path.splitext(video_name)
    if ext.lower() != ".mp4":
        continue
    if SELECTED_VIDEOS is not None and video_name not in SELECTED_VIDEOS:
        continue
    # splitext only touches the final extension, unlike str.replace(".mp4", ...).
    txt_path = os.path.join(DIR_INT, stem + ".txt")
    if not os.path.isfile(txt_path):
        print("skipping", video_name, "- no prompt file at", txt_path)
        continue
    video_path = os.path.join(DIR_RAW, video_name)
    output_video_path = os.path.join(DIR_INT, stem + "_masks.mp4")
    print("extracting fighter masks for", video_name)
    extract_fighter_masks(video_path, txt_path, output_video_path)

extracting fighter masks for cerrone_story_2.mp4



Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  pred_masks_gpu = fill_holes_in_mask_scores(
propagate in video: 100%|██████████| 201/201 [00:23<00:00,  8.58it/s]
propagate in video: 100%|██████████| 10/10 [00:01<00:00, 10.00it/s]
100%|██████████| 210/210 [00:02<00:00, 82.57it/s]
