In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import genjax
# NOTE(review): presumably switches on genjax's pretty-printing of its objects in cell output — confirm.
genjax.pretty()

In [None]:
## Load a video ##

import condorgmm
import condorgmm.data as data

import matplotlib.pyplot as plt
import jax

# YCB-V training scene 1, passed through `.downscale(4)`.
video = (
    data.YCBVVideo.training_scene(1)
    .downscale(4)
)

# Quick visual check of one RGB frame.
# NOTE(review): this previews frame 1, while the inference below runs on
# frame 0 (`frame = video[0]`) — confirm which frame is intended.
fig, ax = plt.subplots()
ax.imshow(video[1].rgb)
ax.set_title("YCB-V training scene 1, frame 1 (RGB)")
ax.axis("off")
plt.show()

## Visualize the sequence of states visited by inference

In [None]:
from condorgmm.condor.interface.camera_tracking import initialize, fast_config
import jax

# Inference settings for this notebook.
FRAME_IDX = 0  # which video frame to fit
SEED = 101     # PRNG seed for initialize

n_gaussians = 100
frame = video[FRAME_IDX]
# Take the camera pose from the same frame object that is fitted, so changing
# FRAME_IDX can never pair one frame's image with another frame's pose.
camera_pose_world_frame = condorgmm.Pose(frame.camera_pose)

gmm, ccts, meta = initialize(
    frame,
    camera_pose_world_frame,
    fast_config.replace(
        n_gaussians=n_gaussians,
        n_sweeps_per_phase=(20, 20, 20, 100),
    ),
    # NOTE(review): presumably `log=True` is what populates `meta.visited_states`,
    # which the cells below read — confirm.
    log=True,
    key=jax.random.key(SEED),
)
hypers = ccts.hypers

In [None]:
import jax.numpy as jnp
from condorgmm.condor.types import FloatFromDiscreteSet

# Sanity check: across every visited inference state, no discretized parameter
# should have its index at either end of its domain; hitting an end suggests
# the domain may be too narrow.
# NOTE(review): if `domain` is stacked along the visited-state axis like `idx`,
# len(value.domain) counts states rather than domain entries — confirm.
visited_global_param_values = meta.visited_states.states.matter.background_initialization_params
for param_name, value in visited_global_param_values.__dict__.items():
    if not isinstance(value, FloatFromDiscreteSet):
        continue
    top_idx = len(value.domain) - 1
    assert not jnp.any(value.idx == 0), f"Parameter {param_name} hit the bottom of its range during inference."
    assert not jnp.any(value.idx == top_idx), f"Parameter {param_name} hit the top of its range during inference."


A common symptom of failed inference is that all datapoints end up assigned to the same Gaussian. Displaying the datapoint → Gaussian assignments therefore gives a quick read on whether inference worked.

In [None]:
# Datapoint -> Gaussian assignment (`gaussian_idx`) for the visited inference states.
visited_datapoints = meta.visited_states.states.datapoints
visited_datapoints.value.gaussian_idx

In [None]:
from matplotlib.ticker import FuncFormatter, LogLocator, MaxNLocator
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax

def plot_param(param_name, *, idxs=(), miny=None, ax=None, colors=None, states=None):
    """Plot one background-initialization parameter over the visited inference states.

    For every step, reads
    ``states[step].matter.background_initialization_params.<param_name>.value``
    and draws it against the step index on a log-scale y-axis.

    Args:
        param_name: Attribute name on ``background_initialization_params`` whose
            ``.value`` is plotted.
        idxs: If non-empty, plot ``value[idx]`` for each ``idx`` as its own line
            (for vector-valued parameters); otherwise plot ``value`` directly.
        miny: If given, it is subtracted from every value before plotting, and
            the y tick labels are rewritten as ``miny + 10^k``.
        ax: Axes to draw on. A new ``(10, 4)`` figure is created when ``None``.
        colors: Optional line color per entry of ``idxs``. Entries beyond
            ``len(colors)`` fall back to matplotlib's default color cycle.
            Ignored when ``idxs`` is empty.
        states: Stacked visited states to read from. Defaults to the notebook's
            global ``meta.visited_states.states``, looked up at call time.

    Returns:
        The Axes that was drawn on.
    """
    import math

    if states is None:
        states = meta.visited_states.states

    def get_value(step, idx=None):
        params = states[step].matter.background_initialization_params
        value = getattr(params, param_name).value
        return value if idx is None else value[idx]

    steps = jnp.arange(len(states))

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4))

    if len(idxs) > 0:
        for k, idx in enumerate(idxs):
            # vmap is applied immediately, so the lambda sees the current `idx`.
            values = jax.vmap(lambda step: get_value(step, idx))(steps)
            if miny is not None:
                values = values - miny
            color = colors[k] if colors is not None and k < len(colors) else None
            ax.plot(steps, values, label=f'{param_name}[{idx}]', color=color)
    else:
        values = jax.vmap(get_value)(steps)
        if miny is not None:
            values = values - miny
        ax.plot(steps, values, label=param_name)

    ax.set_xlabel('Inference Step')
    ax.set_ylabel(param_name)
    ax.set_yscale('log')

    if miny is not None:
        def offset_tick_label(y, _pos):
            # Log-axis ticks should be positive; guard anyway so math.log10
            # cannot raise inside matplotlib's drawing code.
            if y <= 0:
                return ''
            # $...$ makes matplotlib's mathtext render 10^{k} as a superscript
            # instead of showing the braces literally.
            return f'${miny} + 10^{{{math.log10(y):.2f}}}$'

        ax.yaxis.set_major_formatter(FuncFormatter(offset_tick_label))

    ax.set_title(f'{param_name} over Inference Steps')
    ax.legend()

    return ax

# One subplot row per parameter. Parameters indexed by RGB channel get
# matching line colors.
rgb_colors = ['red', 'green', 'blue']
param_plot_specs = [
    ('xyz_cov_pcnt', dict(miny=2.)),
    ('xyz_cov_isotropic_prior_stds', dict(idxs=(0, 1, 2))),
    ('xyz_mean_pcnt', dict()),
    ('rgb_var_n_pseudo_obs', dict(idxs=(0, 1, 2), colors=rgb_colors)),
    ('rgb_var_pseudo_sample_stds', dict(idxs=(0, 1, 2), colors=rgb_colors)),
    ('rgb_mean_n_pseudo_obs', dict(idxs=(0, 1, 2), colors=rgb_colors)),
]

fig, ax = plt.subplots(6, 1, figsize=(10, 8))
for row, (spec_name, spec_kwargs) in enumerate(param_plot_specs):
    plot_param(spec_name, ax=ax[row], **spec_kwargs)

plt.tight_layout()
plt.show()


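Log the visited inference states, together with the observed and inferred depth images, to [Rerun](https://rerun.io) for interactive inspection.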

In [None]:
# Recording name passed to condorgmm's Rerun initialization helper.
RERUN_RECORDING = "condor2/frame0_02"
condorgmm.rr_init(RERUN_RECORDING)

In [None]:
from condorgmm.condor.rerun import log_state
import rerun as rr

# Log every one of the first 41 inference steps, then only every 10th step.
for i, label in enumerate(meta.visited_states.all_labels):
    keep_step = i <= 40 or i % 10 == 0
    if not keep_step:
        continue

    state_i = meta.visited_states.states[i]
    rr.set_time_sequence("inference_step", i)
    log_state(state_i, hypers)
    rr.log("inference_move", rr.TextDocument(label))
    rr.log("depth_img/observation", rr.DepthImage(frame.depth))
    # Inferred depth: z component of each datapoint's xyz, laid out on the
    # observed depth image's grid.
    inferred_depth = state_i.datapoints.value.xyz[..., 2].reshape(frame.depth.shape)
    rr.log("depth_img/inferred", rr.DepthImage(inferred_depth))

In [None]:
# Background-initialization parameters in the state held by `ccts` after `initialize`.
ccts_state = ccts.state
ccts_state.matter.background_initialization_params

## Runtime test

In [None]:
# JIT warm-up: run `initialize` once with the same settings as the inference
# run above, so any compilation cost is paid here rather than in a later timed run.
# Logging is disabled because only the runtime matters in this section.
# NOTE(review): assumes `initialize` reuses its compiled computation for an
# equal config on later calls — confirm, otherwise timings include compilation.
jax.block_until_ready(
    initialize(
        frame,
        camera_pose_world_frame,
        fast_config.replace(
            n_gaussians=n_gaussians,
            n_sweeps_per_phase=(20, 20, 20, 100),
        ),
        log=False,
        key=jax.random.key(101),
    )
);

In [None]:
# Time
# Wall-clock runtime of one `initialize` call with the same settings as the
# inference run above. Run an identical call once beforehand so JIT compilation
# is excluded. `jax.block_until_ready` waits for JAX's asynchronously dispatched
# work to finish before the timer is read.
import time

start = time.perf_counter()
jax.block_until_ready(
    initialize(
        frame,
        camera_pose_world_frame,
        fast_config.replace(
            n_gaussians=n_gaussians,
            n_sweeps_per_phase=(20, 20, 20, 100),
        ),
        log=False,
        key=jax.random.key(101),
    )
)
print(f"initialize: {time.perf_counter() - start:.2f} s")