In [None]:
import os
import cv2
import torch
from PIL import Image
import numpy as np
from transformers import AutoModel
from torchvision.io import read_video, write_video
from theia.decoding import load_feature_stats, prepare_depth_decoder, prepare_mask_generator, decode_everything

# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Target models (presumably the teachers Theia was distilled from). The list is
# passed unchanged, in this order, to load_feature_stats.
target_model_names = [
    "google/vit-huge-patch14-224-in21k",
    "facebook/dinov2-large",
    "openai/clip-vit-large-patch14",
    "facebook/sam-vit-huge",
    "LiheYoung/depth-anything-large-hf",
]

# Load the Theia checkpoint from the Hugging Face Hub and move it onto `device`.
# NOTE(review): trust_remote_code=True executes Python code fetched from the Hub
# repository; consider pinning `revision=` to a known commit.
theia_model = AutoModel.from_pretrained(
    "theaiinstitute/theia-tiny-patch16-224-cddsv",
    trust_remote_code=True,
).to(device)

# Per-model feature means and variances. The stats root is a relative path,
# presumably resolved against the notebook's working directory. The stats are
# presumably used later to de-normalize Theia's predicted features (confirm).
feature_means, feature_vars = load_feature_stats(
    target_model_names,
    stat_file_root="../../../feature_stats",
)

# Hub id of the Depth-Anything checkpoint whose decoder renders depth outputs.
depth_anything_model_name = "LiheYoung/depth-anything-large-hf"

# Returns a mask generator together with the SAM model; both are handed to
# decode_everything.
mask_generator, sam_model = prepare_mask_generator(device)

# Only the depth decoder is used; the second element of the returned pair is discarded.
depth_anything_decoder, _ = prepare_depth_decoder(depth_anything_model_name, device)

example_video_path = "../../../media/example_video_to_visualize.mp4"

# read_video returns (video_frames, audio_frames, info); only the frames are used.
# With output_format="THWC" the frames come back as a (T, H, W, C) tensor.
video, _, _ = read_video(example_video_path, pts_unit="sec", output_format="THWC")

# torchvision returns an empty (0, ...) frame tensor when nothing could be decoded
# (e.g. an unreadable or zero-length clip). Fail here with a clear message instead
# of deep inside decode_everything / np.stack.
if video.shape[0] == 0:
    raise ValueError(f"No frames could be read from {example_video_path!r}")

video = video.numpy()

# Resize every frame to 224x224 (aspect ratio is not preserved); the "224" in the
# checkpoint name suggests this matches Theia's expected input size. Each frame is
# wrapped as a PIL image for decode_everything.
images = [Image.fromarray(cv2.resize(frame, (224, 224))) for frame in video]

# Decode every frame with Theia; with gt=True a second set of "ground-truth"
# decodings is also returned for comparison. The two results are expected to be
# parallel per-frame sequences (TODO confirm against decode_everything).
theia_decode_results, gt_decode_results = decode_everything(
    images=images,
    device=device,
    # Theia backbone and the per-model feature statistics.
    theia_model=theia_model,
    feature_means=feature_means,
    feature_vars=feature_vars,
    # Segmentation and depth decoding components.
    mask_generator=mask_generator,
    sam_model=sam_model,
    depth_anything_decoder=depth_anything_decoder,
    # Mask-filtering thresholds; presumably forwarded to the SAM mask generator (confirm).
    pred_iou_thresh=0.5,
    stability_score_thresh=0.7,
    # Also produce the ground-truth decodings.
    gt=True,
)

# For each frame, stack the Theia decoding on top of the matching ground-truth
# decoding (np.vstack joins along axis 0; assumes each result is an (H, W, C)
# array — TODO confirm). Stacking all frames gives the (T, H', W, C) layout that
# write_video expects.
# strict=True: both lists are produced from the same `images`, so a length
# mismatch indicates a bug and should fail loudly instead of silently
# truncating the output video.
vis_video = np.stack(
    [
        np.vstack([tr, gtr])
        for tr, gtr in zip(theia_decode_results, gt_decode_results, strict=True)
    ]
)

# The decoded frames appear to be floats in [0, 1] (hence the * 255.0).
# - Round to the nearest integer instead of truncating, which biases pixels darker.
# - Clip into [0, 255] before the uint8 cast: casting out-of-range floats to an
#   unsigned integer type is undefined and typically wraps around, which shows up
#   as speckle artifacts.
vis_video = torch.from_numpy(np.clip(np.rint(vis_video * 255.0), 0, 255).astype(np.uint8))

vis_save_path = "./visualized.mp4"
# Playback rate is fixed at 10 fps; it is not derived from the source video's frame rate.
write_video(vis_save_path, vis_video, fps=10)