In [None]:
%load_ext autoreload
%autoreload 2
%cd ..

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import hydra
import numpy as np
import mediapy as media
from einops import rearrange
from tqdm import tqdm
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
import logging
import pickle

from densetrack3d.models.geometry_utils import least_square_align

from densetrack3d.datasets.custom_data import read_data, read_data_with_depthcrafter

# from densetrack3d.models.densetrack3d.densetrack3d import DenseTrack3D
from densetrack3d.models.densetrack3d.densetrack3d import DenseTrack3D

from densetrack3d.utils.visualizer import Visualizer, flow_to_rgb



from densetrack3d.models.model_utils import (
    smart_cat, 
    get_points_on_a_grid, 
    bilinear_sample2d,
    get_grid,
    bilinear_sampler,
    reduce_masked_mean
)

from densetrack3d.models.predictor.dense_predictor import DensePredictor3D




In [3]:
checkpoint = "checkpoints/densetrack3d.pth"

model = DenseTrack3D(
    stride=4,
    window_len=16,
    add_space_attn=True,
    num_virtual_tracks=64,
    model_resolution=(384, 512),
)

# NOTE(review): torch.load unpickles arbitrary objects (this is what triggers
# torch's `weights_only` FutureWarning). Only load checkpoints from a trusted source.
with open(checkpoint, "rb") as f:
    state_dict = torch.load(f, map_location="cpu")

# Some checkpoints nest the weights under a "model" key next to other data.
if "model" in state_dict:
    state_dict = state_dict["model"]

# strict=False tolerates key mismatches. However, a wholesale mismatch (e.g. a key
# prefix difference) would silently leave the model randomly initialised, so
# surface any missing/unexpected keys instead of discarding them.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    logging.warning(
        "Checkpoint %s: %d missing keys, %d unexpected keys "
        "(first missing: %s; first unexpected: %s)",
        checkpoint,
        len(load_result.missing_keys),
        len(load_result.unexpected_keys),
        load_result.missing_keys[:5],
        load_result.unexpected_keys[:5],
    )

predictor = DensePredictor3D(model=model)
predictor = predictor.eval().cuda()

  state_dict = torch.load(f, map_location="cpu")


In [4]:
# Visualizer configured with save_dir="results/test", fps=7, first-frame display off, linewidth 1.
vis = Visualizer(linewidth=1, fps=7, show_first_frame=0, save_dir="results/test")

In [5]:
# Clip names to process (whitespace-separated); each is read from "demo_data" by the export cell.
vid_names = "rollerblade".split()

In [None]:

save_dir = "results/demo/"
os.makedirs(save_dir, exist_ok=True)

data_dir = "demo_data"

for vid_name in tqdm(vid_names, desc="Exporting 3D tracks"):
    # Release the previous clip's predictions (GPU tensors) before processing the
    # next one. Otherwise they stay alive during the next forward pass and roughly
    # double peak GPU memory. The last clip's results remain in `out_dict`.
    out_dict = None

    video, videodepth, videodisp = read_data_with_depthcrafter(data_dir, vid_name)

    # When a disparity video is returned, combine it with the depth via
    # least_square_align (presumably a least-squares fit of depth to disparity;
    # confirm in densetrack3d.models.geometry_utils).
    if videodisp is not None:
        videodepth = least_square_align(videodepth, videodisp)

    # permute(0, 3, 1, 2) implies 4-D frames, assumed (T, H, W, C) -> (1, T, C, H, W).
    # Depth is assumed (T, H, W) -> (1, T, 1, H, W). TODO confirm against the loader.
    video = torch.from_numpy(video).permute(0, 3, 1, 2).cuda()[None].float()
    videodepth = torch.from_numpy(videodepth).unsqueeze(1).cuda()[None].float()

    H, W = video.shape[-2:]

    # Inference only: no autograd graph is needed. This is harmless if the
    # predictor already disables gradients internally.
    with torch.no_grad():
        out_dict = predictor(
            video,
            videodepth,
            grid_query_frame=0,
        )

    # Keep the first (presumably batch) entry of each trajectory tensor, as NumPy on the host.
    trajs_3d_dict = {k: v[0].cpu().numpy() for k, v in out_dict["trajs_3d_dict"].items()}

    with open(os.path.join(save_dir, f"{vid_name}.pkl"), "wb") as handle:
        pickle.dump(trajs_3d_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
