In [None]:
%load_ext autoreload
%autoreload 2

from condorgmm.data.mp4_vda import MP4DepthAnythingVideo

In [None]:
# og_video = MP4DepthAnythingVideo(
#     "/home/georgematheos/condorgmm/assets/custom/trees-01.mp4",
#     min_depth_meters=0.5,
#     max_depth_meters=2.0,
#     encoder="vits",
#     camera_type="short_focal_length",
# )
# video = og_video.downscale(3)

In [None]:
# video = og_video.crop(200, 600, 400, 1000)

In [None]:
from condorgmm.data.r3d_dataloader import R3DVideo
# og_video = R3DVideo(
#     "/home/georgematheos/condorgmm/assets/custom/folding-01.r3d",
# )
# video = og_video

import condorgmm
from condorgmm.utils.common import get_assets_path
import condorgmm.data as data
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
fig, ax = plt.subplots()

video = data.R3DVideo(get_assets_path() / "red-blanket.r3d")
# Frames of the recording that are previewed here and tracked further below.
frame_idxs = range(700, 840, 1)
# video = video.crop(0, 180, 16, 256)


def _draw_frame(anim_step):
    """Draw the RGB image of ``video[frame_idxs[anim_step]]`` onto ``ax``.

    Deliberately NOT named ``update``: the tracking cells import ``update``
    from ``condorgmm.condor.interface.camera_tracking`` and call it as
    ``update(frame, pose, ccts, cfg, ...)``. Re-running this cell would
    otherwise shadow that function and break the tracking loop.
    """
    frame_idx = frame_idxs[anim_step]
    ax.clear()
    ax.imshow(video[frame_idx].rgb)
    ax.set_title(f"Frame {frame_idx}")


ani = FuncAnimation(fig, _draw_frame, len(frame_idxs), repeat=False)
# HTML(ani.to_jshtml())

In [None]:
import matplotlib.pyplot as plt

# Quick look at the RGB image of the first frame of the recording.
first_rgb = video[0].rgb
plt.imshow(first_rgb)

In [None]:
# Depth map of the first frame, with a colorbar attached to that image.
plt.colorbar(plt.imshow(video[0].depth))

In [None]:
from condorgmm.condor.interface.camera_tracking import initialize, update, fast_config
import rerun as rr
from condorgmm.condor.rerun import log_state
from tqdm import tqdm
import jax.numpy as jnp
import condorgmm

# Log the first few frames to a separate rerun recording as a sanity check.
condorgmm.rr_init("pointcloud-00")
n_preview_frames = 10
for frame_no in range(n_preview_frames):
    rr.set_time_sequence("frame", frame_no)
    condorgmm.rr_log_frame(video[frame_no])

In [None]:
import genjax
# NOTE(review): presumably enables genjax's pretty/rich display of its objects
# in this notebook — confirm against the genjax documentation.
genjax.pretty()

In [None]:
from condorgmm.condor.interface.camera_tracking import initialize, update, fast_config
import rerun as rr
from condorgmm.condor.rerun import log_state
from tqdm import tqdm
import jax.numpy as jnp
import condorgmm


def _f(x):
    """Convert ``x`` to a float32 jax array."""
    return jnp.array(x, dtype=jnp.float32)

hyp = fast_config.base_hypers

cfg = fast_config.replace(
    base_hypers=hyp,
    n_gaussians=384,
    tile_size_x=16,
    tile_size_y=16,
    step_n_sweeps_phase_1=4,
    repopulate_depth_nonreturns=False
)

# The scene pose is derived from the first *tracked* frame (frame_idxs[0]),
# because that is the frame handed to `initialize` in the tracking loop.
# Using video[0] here would take the depth from an unrelated part of the clip.
reference_depth = video[frame_idxs[0]].depth
valid_depth = reference_depth > 0
n_valid = jnp.sum(valid_depth)
assert int(n_valid) > 0, "reference frame has no positive depth values"
# Mean over valid (positive) pixels only. NaNs and non-positive values are
# excluded from both the numerator and the count.
scenedepth = jnp.sum(jnp.where(valid_depth, reference_depth, 0.0)) / n_valid
# 7-vector for condorgmm.Pose; presumably (x, y, z) followed by an identity
# quaternion — confirm against condorgmm.Pose.
scenepose = condorgmm.Pose(jnp.array([0., 0., scenedepth, 1., 0., 0., 0.], dtype=jnp.float32))

In [None]:
# Rerun recording that the tracking loop below logs into.
tracking_recording_name = "blanket-01"
condorgmm.rr_init(tracking_recording_name)

In [None]:
# Track over `frame_idxs`: `initialize` on the first frame, `update` on every
# later one. Every tracker state is kept in `saved_states` (index k <-> frame_idxs[k]).
saved_states = []
for step, frame_idx in enumerate(tqdm(frame_idxs)):
    frame = video[frame_idx]
    if step == 0:
        _, ccts = initialize(frame, scenepose, cfg)
    else:
        _, ccts = update(frame, scenepose, ccts, cfg, get_gmm=False)
    saved_states.append(ccts)

    # To keep only a window of states instead:
    # if 159 < frame_idx and frame_idx < 201:
    #     saved_states.append(ccts)

    rr.set_time_sequence("frame", frame_idx)
    condorgmm.rr_log_rgb(frame.rgb)
    # Alternative: ellipse_mode=rr.components.FillMode.DenseWireframe, ellipse_scalar=2
    log_state(ccts.state, ccts.hypers, ellipse_scalar=1.5)
    # rr.log("depth_img/observation", rr.DepthImage(frame.depth))
    # rr.log("depth_img/inferred", rr.DepthImage(ccts.state.datapoints.value.xyz[..., 2].reshape(frame.depth.shape)))

In [None]:
### Figures: tracking a hand-picked set of persistent Gaussians across frames

In [None]:
# og_gaussian_indices = [
#     24, 323, 205, 229, 269, 210, 357, 28, 67, 221, 247, 148, 381, 127, 200, 130, 378, 23, 195, 358,
#     254,
#     #
#     356, 112, 304, 198, 359, 8,
#     316, 263, 238, 27
# ]

## Final selection for the jeans sequence --
# og_gaussian_indices = [
#     9, 326, 84, 103, 180, 287, 163, 178, 370, 189, 172, 328, 327, 305, 163, 346, 192, 139, 53, 183, 239, 16, 38, 336, 361, 332, 78, 375, 265, 66, 147, 366, 38
# ]

## Final selection for the blanket sequence --
# NOTE(review): 352 appears twice in this list — confirm whether one entry was
# meant to be a different Gaussian index.
og_gaussian_indices = [
    270, 238, 352, 207, 305,
    173, 239, 352, 276, 268,
]

In [None]:
from condorgmm.condor.rerun import _ellipsoids

# Number of leading tracker states (and their frames) not logged here.
n_skipped_states = 10

# Remove repeated indices (keeping first occurrence) so each selected Gaussian
# is drawn once and keeps a single, stable class id across frames.
gaussian_indices = list(dict.fromkeys(og_gaussian_indices))
all_ids = {gaussian_idx: class_id for class_id, gaussian_idx in enumerate(gaussian_indices)}

# The tracking loop appends one state per entry of `frame_idxs`, so
# saved_states[k] belongs to frame frame_idxs[k]; pair them directly instead of
# recomputing the frame number from a hard-coded offset.
for t, st in zip(frame_idxs[n_skipped_states:], saved_states[n_skipped_states:]):
    rr.set_time_sequence("frame", t)

    # log_state(st.state, st.hypers)

    # Drop Gaussians whose `origin` is -1 in this state. The filtered list
    # carries over, so a dropped Gaussian is never shown again in later frames.
    gaussian_indices = [
        g for g in gaussian_indices
        if st.state.gaussians.origin[g] != -1
    ]

    rr.log("selected", _ellipsoids(
        st.state.gaussians[jnp.array(gaussian_indices, dtype=jnp.int32)],
        do_color=False,
        class_ids=jnp.array([all_ids[g] for g in gaussian_indices]),
        std_scalar=2,
        fill_mode=rr.components.FillMode.DenseWireframe,
    ))

In [None]:
# Log the RGB image of every frame in the recording on the "frame" timeline.
n_frames = len(video)
for frame_no in tqdm(range(n_frames)):
    rr.set_time_sequence("frame", frame_no)
    condorgmm.rr_log_rgb(video[frame_no].rgb)