In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import genjax
genjax.pretty()

In [None]:
## Load a video ##

import condorgmm
import condorgmm.data as data

# YCB-V training scene 2, reduced via `.downscale(4)`.
raw_video = data.YCBVVideo.training_scene(2)
video = raw_video.downscale(4)

import matplotlib.pyplot as plt
# Quick visual sanity check of frame index 1.
frame1_rgb = video[1].rgb
plt.imshow(frame1_rgb)

In [None]:
# Shape of the RGB image of video frame 1 (after the `.downscale(4)` applied when loading).
video[1].rgb.shape

In [None]:
from condorgmm.condor.interface.camera_tracking import initialize, update, slow_config
import rerun as rr
from condorgmm.condor.rerun import log_state
from tqdm import tqdm
import condorgmm

# Set up rerun logging for this run under the name "step_inference_sequence_02"
# (condorgmm.rr_init — presumably a thin wrapper around rr.init; confirm).
condorgmm.rr_init("step_inference_sequence_02")

In [None]:
# Start from the slow preset, turn on the camera-pose update, and set the
# sweep counts used in phase 1 / phase 2 of each step.
cfg_overrides = {
    "do_pose_update": True,
    "step_n_sweeps_phase_1": 10,
    "step_n_sweeps_phase_2": 20,
}
cfg = slow_config.replace(**cfg_overrides)

In [None]:
# Build the initial tracker state from frame 0, posed at frame 0's `camera_pose`.
first_frame = video[0]
first_frame_pose = condorgmm.Pose(first_frame.camera_pose)
_, ccts0 = initialize(first_frame, first_frame_pose, cfg, log=False)

In [None]:

# Log the state produced by `initialize` on frame 0 (with its hyperparameters) via condor's rerun helper.
log_state(ccts0.state, ccts0.hypers)

In [None]:
fr = 20  # index of the frame tracked from the frame-0 state

target_frame = video[fr]
# NOTE: the pose handed to `update` is frame 0's camera pose, not frame `fr`'s —
# presumably a starting guess for tracking (frame `fr`'s pose is logged as "gt_pose" later).
pose_guess = condorgmm.Pose(video[0].camera_pose)
_, ccts, meta = update(target_frame, pose_guess, ccts0, cfg, log=True)

In [None]:
# Plot evolution parameters over inference steps
from matplotlib.ticker import FuncFormatter, LogLocator, MaxNLocator
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax

def plot_evolution_param(param_name, *, miny=None, ax=None, states=None):
    """Plot one background evolution parameter across inference steps on a log y-axis.

    Args:
        param_name: Attribute name on
            `states[i].matter.background_evolution_params`; its `.value` is plotted.
        miny: Optional offset subtracted from the values before plotting, so values
            just above `miny` are resolvable on the log axis. Tick labels are then
            rendered as `miny + 10^k` so they still read as the original values.
        ax: Axes to draw into; a new figure of size (10, 4) is created when omitted.
        states: Batched visited states to read from. Defaults to the notebook-level
            `meta.visited_states.states` returned by `update`.

    Returns:
        The Axes that was drawn into.
    """
    import math
    import warnings

    if states is None:
        states = meta.visited_states.states

    def get_value(i):
        params = states[i].matter.background_evolution_params
        return getattr(params, param_name).value

    xs = jnp.arange(len(states))

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4))

    values = jax.vmap(get_value)(xs)
    if miny is not None:
        values = values - miny
    if jnp.any(values <= 0):
        # A log axis cannot represent non-positive values (matplotlib would clip or
        # mask them), so make this visible instead of producing a misleading curve.
        offset_note = f" after subtracting miny={miny}" if miny is not None else ""
        warnings.warn(
            f"{param_name}: some plotted values are <= 0{offset_note}; "
            "they cannot be shown faithfully on a log axis."
        )
    ax.plot(xs, values, label=param_name)

    ax.set_xlabel('Inference Step')
    ax.set_ylabel(param_name)
    ax.set_yscale('log')

    if miny is not None:
        def offset_tick_label(y, _pos):
            # Ticks sit at (value - miny); label them with the original value using
            # mathtext, e.g. "$2.0 + 10^{-1.30}$". Without the `$...$` delimiters
            # the braces were rendered literally.
            if y <= 0:
                return ''
            return f'${miny} + 10^{{{math.log10(y):.2f}}}$'
        ax.yaxis.set_major_formatter(FuncFormatter(offset_tick_label))

    ax.set_title(f'{param_name} over Inference Steps')
    ax.legend()

    return ax

# One row per background evolution parameter, stacked in a single figure.
panel_specs = [
    ('prob_gaussian_is_new', {}),
    ('xyz_cov_pcnt', {'miny': 2.}),  # plotted as xyz_cov_pcnt - 2 on the log axis
    ('rgb_var_pcnt', {}),
    ('target_xyz_mean_std', {}),
]

fig, ax = plt.subplots(len(panel_specs), 1, figsize=(10, 8))
for panel_ax, (spec_name, spec_kwargs) in zip(ax, panel_specs):
    plot_evolution_param(spec_name, ax=panel_ax, **spec_kwargs)

plt.tight_layout()
plt.show()

# Check if any parameters hit their domain boundaries
from condorgmm.condor.types import FloatFromDiscreteSet
visited_evolution_param_values = meta.visited_states.states.matter.background_evolution_params
for (param_name, param) in visited_evolution_param_values.__dict__.items():
    if not isinstance(param, FloatFromDiscreteSet):
        continue
    # `param` is read from the batched visited states (one entry per step). If
    # `domain` is an array leaf it is batched to (n_steps, dom_size), and `len`
    # would count steps instead of domain entries, so use the trailing axis when
    # a shape is available; fall back to `len` for non-array domains.
    domain = param.domain
    dom_size = domain.shape[-1] if hasattr(domain, "shape") else len(domain)
    assert jnp.all(param.idx != 0), f"Parameter {param_name} hit the bottom of its range during inference."
    assert jnp.all(param.idx != dom_size - 1), f"Parameter {param_name} hit the top of its range during inference."

In [None]:
condorgmm.rr_init("tmp")

# Compare the Gaussians of the frame-0 state with the same-index Gaussians in
# visited state `step_idx` of the `update` run; only Gaussians that have an
# association at frame 0 (gaussian_has_assoc_mask) are shown.
step_idx = 20
step_state = meta.visited_states.states[step_idx]
mask = ccts0.state.gaussian_has_assoc_mask
xyz0 = ccts0.state.gaussians.xyz[mask]
xyz1 = step_state.gaussians.xyz[mask]

rr.log("frame0", rr.Points3D(
    xyz0,
    colors=ccts0.state.gaussians.rgb[mask] / 255,
))
rr.log("step", rr.Points3D(
    xyz1,
    colors=step_state.gaussians.rgb[mask] / 255,
))
# Draw lines connecting corresponding points between frame0 and step.
# Stack into one array of 2-point segments and move it to host once, instead of
# issuing a separate small JAX concatenate per Gaussian.
segments = jax.device_get(jnp.stack([xyz0, xyz1], axis=1))
rr.log("lines2", rr.LineStrips3D(
    list(segments),
    colors=jnp.array([[0.5, 0.5, 0.5]]),
))


In [None]:
# Alternative single-polyline encoding of the frame0 -> step segments: rows go
# xyz0[k], xyz1[k], NaN for each k (assuming xyz0/xyz1 are (N, 3)); the NaN rows
# presumably act as breaks between segments.
nan_rows = jnp.full_like(xyz0, jnp.nan)
interleaved = jnp.stack([xyz0, xyz1, nan_rows], axis=1)
interleaved.reshape(-1, 3)

In [None]:
for i, label in enumerate(meta.visited_states.all_labels):
    # Only every 5th visited state is logged.
    if i % 5:
        continue
    state_i = meta.visited_states.states[i]
    rr.set_time_sequence("inference_step", i + 1)
    log_state(state_i, ccts.hypers, log_in_world_frame=True)
    condorgmm.rr_log_posquat(
        state_i.scene.transform_World_Camera.posquat,
        channel="inferred_camera_pose",
    )
    rr.log("inference_move", rr.TextDocument(label))
    observed_depth = video[fr].depth
    rr.log("depth_img/observation", rr.DepthImage(observed_depth))
    # NOTE(review): this reads the final `ccts.state`, not `state_i`, so the same
    # image is logged at every step — confirm that is intended.
    inferred_xyz = ccts.state.datapoints.value.xyz
    rr.log("depth_img/inferred", rr.DepthImage(inferred_xyz[..., 2].reshape(observed_depth.shape)))
    condorgmm.rr_log_pose(condorgmm.Pose(video[fr].camera_pose), "gt_pose")

In [None]:
# Display transform_World_Camera read from the batched visited states
# (presumably one camera pose per visited inference state).
meta.visited_states.states.scene.transform_World_Camera

In [None]:
sts = meta.visited_states.states
# Association count of a single Gaussian, evaluated at every visited state.
probe_gaussian_idx = 83
jax.vmap(lambda st: st.n_assocs_per_gaussian[probe_gaussian_idx])(sts)

In [None]:
import jax.numpy as jnp

# Scratch inspection: association count of Gaussian `gi` at visited state `i`.
sts = meta.visited_states.states
i, gi = 2, 8

# Disabled check — Gaussians whose origin changed between states i and i+1
# while having an association at state i:
# jnp.logical_and(
#     sts[i].gaussians.origin != sts[i+1].gaussians.origin,
#     sts[i].gaussian_has_assoc_mask
# )

sts[i].n_assocs_per_gaussian[gi]