In [1]:
import os
import sys
from pathlib import Path

# Set before `import torch` below, presumably so PyTorch sees the MPS
# fallback flag when it is imported (TODO confirm it is still needed, since
# the later cells hard-code a CUDA device).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
from PIL import Image
from tqdm import trange

# Add the parent directory to the import path; the `core.*` imports below
# presumably resolve against it, so this must run before them.
sys.path.append("..")

from core.config import default_config
from core.extract_video_rgb import extract_video_rgb
from core.image_segmentor import load_sam_auto_gen, segment_with_sam
from core.mask_handler import MaskHandler
from core.propagator import load_sam2_video_predictor_and_initial_state
from core.task_handler import TaskHandler
from core.utils import ID2RGBConverter, flatten_mask, viz_mask

### Step 1: Extract RGB frames into a folder

In [2]:
# Input video, and the workspace directory that every later cell reads from
# and writes to.
work_dir = Path("/scratch/quanta/Experiments/Thesis/exps/sam2_full_video_track_example")
video_path = "/scratch/quanta/Experiments/Thesis/exps/prob_video_4_open_laptop/viz/rgb.mp4"

In [None]:
# Extract the video's RGB data into the workspace directory.
extract_video_rgb(save_workspace_dir=work_dir, video_path=video_path)

### Step 2: Key frame image segmentation

In [3]:
# Load the SAM generator from its checkpoint, taking the sampling density and
# IoU threshold from the project's default config.
sam_settings = {
    "points_per_side": default_config['sam1_points_per_side'],
    "pred_iou_thresh": default_config['sam1_pred_iou_thrshold'],
}
sam = load_sam_auto_gen(
    ckpt_pth="/scratch/quanta/Models/SAM/sam_vit_h_4b8939.pth",
    **sam_settings,
)

In [None]:
# Run SAM over the frames in data/rgb and write the results to data/sam_1_seg.
rgb_frames_dir = work_dir / "data/rgb"
sam1_seg_dir = work_dir / "data/sam_1_seg"

segment_with_sam(
    rgb_dir=str(rgb_frames_dir),
    save_dir=str(sam1_seg_dir),
    sam_auto_gen=sam,
    min_size=default_config['min_size'],
    step=default_config['step'],
    max_masks_per_frame=default_config['max_masks_per_frame'],
)

In [None]:
# Release the SAM model before the SAM2 predictor is loaded in the next step.
import gc

del sam
# Dropping the name is not enough if the model is still reachable through
# reference cycles; collect them so the tensors are actually freed.
gc.collect()
# Hand the now-unused cached blocks back to the CUDA driver.
torch.cuda.empty_cache()

### Step 3: Use SAM2 to propagate and associate all masks

In [None]:
# SAM2 settings: target device, checkpoint, and model config.
device = "cuda:0"
ckpt_pth = "/scratch/quanta/Models/SAM2/sam2.1_hiera_large.pt"
config_pth = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Build the video predictor together with its initial inference state over
# the frames stored in data/rgb.
sam2_bundle = load_sam2_video_predictor_and_initial_state(
    ckpt_pth=ckpt_pth,
    rgb_jpg_dir=str(work_dir / "data/rgb"),
    model_config=config_pth,
    device=device,
)
video_predictor, inference_state = sam2_bundle

In [None]:
# Directories the task handler works with: its queue, the SAM masks it reads,
# and the location where it saves masks.
queue_dir = work_dir / "data/sam_2_queue"
sam_mask_dir = work_dir / "data/sam_1_seg"
track_mask_dir = work_dir / "data/sam_2_track"

task_handler = TaskHandler(
    queue_dir=str(queue_dir),
    sam_mask_dir=str(sam_mask_dir),
    save_mask_dir=str(track_mask_dir),
    video_predictor=video_predictor,
    inference_state=inference_state,
    step=default_config['step'],
    disappear_thresh=default_config['disappear_threshold'],
    iou_thresh=default_config['iou_threshold'],
    device=device,
)

In [None]:
# Submit the initial tasks before the processing loop in the following cell
# starts (presumably seeded from the SAM masks; confirm in TaskHandler).
task_handler.submit_initial_tasks()

In [None]:
# Process tasks one at a time until run_one_task() returns something other
# than True, clearing the CUDA cache before each task.
while True:
    torch.cuda.empty_cache()
    task_return = task_handler.run_one_task()
    if task_return is not True:
        break

In [None]:
# Load the raw id map from the queue directory, collapse every id onto the
# root of its chain (the "find" step of union-find), and save the result.
import json


def _find_root(parent_of, obj_id):
    """Follow parent links from ``obj_id`` until an id that maps to itself.

    Raises:
        KeyError: if a link points at an id that is missing from the map.
        ValueError: if the links form a cycle that never reaches a root.
    """
    visited = set()
    while parent_of[obj_id] != obj_id:
        if obj_id in visited:
            raise ValueError(
                "cycle detected in id_map while resolving id {}".format(obj_id))
        visited.add(obj_id)
        obj_id = parent_of[obj_id]
    return obj_id


with open(str(work_dir / "data/sam_2_queue/id_map.json")) as f:
    tmp_map = json.load(f)

# JSON object keys are always strings, so normalise both sides to int up
# front; this avoids comparing a str id against an int id while walking.
parent_of = {int(child): int(parent) for child, parent in tmp_map.items()}

# id -> root id, with int keys and int values, in the file's key order.
id_map = {obj_id: _find_root(parent_of, obj_id) for obj_id in parent_of}

with open(str(work_dir / "data/sam_2_queue/united_id_map.json"), "w") as f:
    json.dump(id_map, f, indent=4)

### Step 4: Visualize

In [None]:
# Helpers for visualisation: the id-to-RGB converter and a handler for the
# masks stored in data/sam_2_track.
converter = ID2RGBConverter()
track_mask_dir = work_dir / "data/sam_2_track"
mask_handler = MaskHandler(str(track_mask_dir))

# Folder for the rendered frames; created on first run, reused afterwards.
viz_save_pth = work_dir / "temp/viz_sam2_association"
viz_save_pth.mkdir(exist_ok=True, parents=True)

In [None]:
# Save one visualisation PNG per frame, labelling each mask with the id that
# id_map assigns to its original object id.
for frame_idx in trange(task_handler.num_frames):
    frame_masks = mask_handler.load_masks(frame_idx)

    merged_ids = []
    for entry in frame_masks:
        merged_ids.append(id_map[entry['original_obj_id']])

    flat_mask = flatten_mask(mask=frame_masks, object_id_list=merged_ids)
    frame_img = viz_mask(flattened_mask=flat_mask, converter=converter)

    frame_png = viz_save_pth / "{:06d}.png".format(frame_idx)
    Image.fromarray(frame_img).save(str(frame_png))

In [None]:
import ffmpeg

# Encode the per-frame PNGs into a single 30 fps video.
# NOTE: the file name's spelling is kept as-is so existing output paths stay valid.
video_out_pth = work_dir / "viz/sam2_assotiation.mp4"
# ffmpeg does not create missing output directories, so make sure it exists.
video_out_pth.parent.mkdir(parents=True, exist_ok=True)

(ffmpeg.input(
    str(viz_save_pth / "*.png"),
    pattern_type="glob",
    framerate=30,
).output(
    str(video_out_pth)
# Overwrite a video left by a previous run so the cell can be re-executed.
).run(overwrite_output=True))