In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import genjax
# Presumably switches genjax into its pretty-printing mode for objects shown in
# notebook output — NOTE(review): confirm against the genjax documentation.
genjax.pretty()

In [None]:
import condorgmm
from condorgmm.utils.common import get_assets_path
import condorgmm.data as data
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
fig, ax = plt.subplots()

# Load the recording, then crop and downscale it; later cells use this same
# `video` object. NOTE(review): the meaning of the four `crop` arguments
# (frame range vs. pixel window) is not visible here — confirm in condorgmm.data.
video = data.R3DVideo(get_assets_path() / "nearfar.r3d").crop(0, 180, 16, 256).downscale(2)

# Frame indices shown in the preview animation.
frame_idxs = range(360, 840)


def draw_frame(idx):
    """Draw the `idx`-th entry of `frame_idxs` onto the shared axes `ax`.

    Called by FuncAnimation with idx in [0, len(frame_idxs)).

    Deliberately NOT named `update`: a later cell imports the camera-tracking
    `update` under that name, and re-running this cell afterwards would rebind
    `update` to this plotting callback and break the tracking loop.
    """
    frame_idx = frame_idxs[idx]
    ax.clear()
    ax.imshow(video[frame_idx].rgb)
    ax.set_title(f"Frame {frame_idx}")


ani = FuncAnimation(fig, draw_frame, frames=len(frame_idxs), repeat=False)
# Uncomment to render the preview inline. `to_jshtml` renders every frame in
# `frame_idxs`, so this can be slow.
# HTML(ani.to_jshtml())

In [None]:
from condorgmm.condor.interface.camera_tracking import initialize, update, fast_config
import rerun as rr
from condorgmm.condor.rerun import log_state
from tqdm import tqdm
import jax.numpy as jnp

def _f(x):
    return jnp.array(x, dtype=jnp.float32)

# Hyperparameters: start from `fast_config` and, judging by the key names,
# turn off inference of the background evolution parameters in favour of the
# fixed defaults given here.
hyp = fast_config.base_hypers.replace({
    "infer_background_evolution_params": False,
    "default_background_evolution_params": {
        "target_xyz_mean_std": _f(0.01),
        "xyz_cov_pcnt": _f(4),
    },
})
cfg = fast_config.replace(base_hypers=hyp)

# Scene depth: mean depth over the valid (strictly positive) pixels of frame 0.
depth0 = video[0].depth  # read the frame once
valid0 = depth0 > 0
n_valid0 = jnp.sum(valid0)
assert int(n_valid0) > 0, "frame 0 has no positive depth values; cannot estimate scene depth"
# Sum over exactly the pixels counted in the denominator, so invalid entries
# (NaN or negative depth) cannot poison or skew the mean.
scenedepth = jnp.sum(jnp.where(valid0, depth0, 0.0)) / n_valid0

# NOTE(review): the 7-vector looks like [translation(3), quaternion(4)]; confirm
# the quaternion component order condorgmm.Pose expects makes [1, 0, 0, 0] the
# intended rotation.
scenepose = condorgmm.Pose(jnp.array([0., 0., scenedepth, 1., 0., 0., 0.], dtype=jnp.float32))

In [None]:
# Set up the rerun recording that the `rr.log` / `log_state` calls in the
# tracking loop write to. NOTE(review): `rr_init` is a condorgmm helper,
# presumably wrapping rerun's init / viewer spawn — confirm.
condorgmm.rr_init("condor_nonrigid_00")

In [None]:
# Track the camera on every FRAME_STRIDE-th frame in [0, N_FRAMES), logging the
# inferred state plus observed/inferred depth images to rerun on the "frame"
# timeline.
# NOTE(review): assumes the recording has at least N_FRAMES frames.
N_FRAMES = 1000
FRAME_STRIDE = 10

ccts = None  # reset so re-running this cell re-initializes the tracker
for i in tqdm(range(0, N_FRAMES, FRAME_STRIDE)):
    frame = video[i]  # read each frame once; reused for inference and logging
    if ccts is None:
        # First processed frame: build tracker state with `initialize`.
        _, ccts = initialize(frame, scenepose, cfg)
    else:
        _, ccts = update(frame, scenepose, ccts, cfg, get_gmm=False)

    rr.set_time_sequence("frame", i)
    log_state(ccts.state, ccts.hypers)
    rr.log("depth_img/observation", rr.DepthImage(frame.depth))
    # z component of the inferred datapoint positions, reshaped to the image grid.
    inferred_depth = ccts.state.datapoints.value.xyz[..., 2].reshape(frame.depth.shape)
    rr.log("depth_img/inferred", rr.DepthImage(inferred_depth))