In [None]:
# Automatically reload edited project modules (e.g. condorgmm) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import genjax
# Presumably switches genjax objects to their pretty-printed display in cell outputs.
genjax.pretty()

In [None]:
## Load a video ##

import condorgmm
import condorgmm.data as data
import matplotlib.pyplot as plt

# Alternative single-object source:
# original_video = data.YCBinEOATVideo("cracker_box_reorient")
original_video = data.YCBVVideo.training_scene(2)

# Index into each frame's per-object arrays (masks, object_poses, object_ids).
OBJECT_IDX = 1

video = original_video.downscale(2)

plt.imshow(video[0].rgb)

In [None]:
# Depth image of the first frame of the downscaled video.
plt.imshow(video[0].depth)

In [None]:
import jax.numpy as jnp

# Zero out every pixel where mask OBJECT_IDX is set (presumably the object's
# own pixels) and keep the RGB values everywhere else.
first_frame = video[0]
outside_mask = jnp.logical_not(first_frame.masks[OBJECT_IDX][..., None])
plt.imshow(
    jnp.where(outside_mask, first_frame.rgb, jnp.zeros_like(first_frame.rgb))
)

In [None]:
import condorgmm.condor.interface.object_tracking as ot

In [None]:
# Alias (not a copy) of the default object-tracking configuration.
cfg = ot.default_cfg

In [None]:
# Initialise object tracking on the first frame for object OBJECT_IDX.
# log=True; the returned `metadata` dict is unpacked in the next cell.
first_frame = video[0]
gmm0, gmm1, ccts, metadata = ot.initialize(
    first_frame,
    condorgmm.Pose(first_frame.camera_pose),
    condorgmm.Pose(first_frame.object_poses[OBJECT_IDX]),
    original_video.get_object_mesh_from_id(first_frame.object_ids[OBJECT_IDX]),
    first_frame.masks[OBJECT_IDX],
    cfg,
    log=True,
)

In [None]:
import rerun as rr
from condorgmm.condor.rerun import log_state

# Object-model fitting results.
object_fitting_meta, object_fitting_hypers = (
    metadata['object_model_metadata']['meta'],
    metadata['object_model_metadata']['hypers'],
)

# Background-only results (the hypers key suggests the object was masked out).
bkg_only_meta, bkg_only_hypers = (
    metadata['bkg_only_metadata'],
    metadata['hypers_masking_object'],
)

# Single-state snapshots recorded by `ot.initialize`.
st_after_adding_in_object, st_after_dp_update = (
    metadata['st_after_adding_in_object'],
    metadata['st_after_dp_update'],
)

# Final inference trace, and the hypers carried by the returned tracker state.
final_metadata, final_hypers = metadata['final_metadata'], ccts.hypers

In [None]:
from matplotlib.ticker import FuncFormatter, LogLocator, MaxNLocator
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax

def plot_param(param_name, *, idxs=(), miny=None, ax=None, colors=None, inference_metadata=None):
    """Plot one background-initialization parameter across inference steps.

    For every visited state ``i`` this reads
    ``states[i].matter.background_initialization_params.<param_name>.value``
    and draws it against the step index on a log-scaled y axis.

    Args:
        param_name: attribute name on ``background_initialization_params``.
        idxs: optional component indices; one line is drawn per index. When
            empty, the whole value is plotted as a single line.
        miny: optional offset subtracted from every value before plotting, so
            the log scale can resolve values sitting just above ``miny``. Tick
            labels are then rendered as ``miny + 10^k``.
        ax: axes to draw into. A new 10x4 figure is created when None.
        colors: optional per-line colours, matched to ``idxs`` by position.
            Missing entries fall back to matplotlib's colour cycle.
        inference_metadata: object exposing ``visited_states.states``.
            Defaults to the notebook-global ``final_metadata``.

    Returns:
        The axes that were drawn on.
    """
    if inference_metadata is None:
        inference_metadata = final_metadata
    states = inference_metadata.visited_states.states

    def get_value(step, idx=None):
        params = states[step].matter.background_initialization_params
        value = getattr(params, param_name).value
        return value if idx is None else value[idx]

    steps = jnp.arange(len(states))

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4))

    # One (label, component index, colour) triple per line to draw.
    if len(idxs) > 0:
        lines = [
            (
                f'{param_name}[{idx}]',
                idx,
                colors[i] if colors is not None and i < len(colors) else None,
            )
            for i, idx in enumerate(idxs)
        ]
    else:
        lines = [(param_name, None, None)]

    for label, idx, color in lines:
        # vmap is applied immediately, so closing over the loop's idx is safe.
        values = jax.vmap(lambda step: get_value(step, idx))(steps)
        if miny is not None:
            values = values - miny
        ax.plot(steps, values, label=label, color=color)

    ax.set_xlabel('Inference Step')
    ax.set_ylabel(param_name)
    ax.set_yscale('log')
    if miny is not None:
        # Wrap in $...$ so matplotlib mathtext renders 10^{k} as a superscript
        # instead of printing the braces literally.
        formatter = FuncFormatter(lambda y, _: f'${miny} + 10^{{{jnp.log10(y):.2f}}}$')
        ax.yaxis.set_major_formatter(formatter)

    ax.set_title(f'{param_name} over Inference Steps')
    ax.legend()

    return ax

# One stacked panel per background-initialization parameter (six in total).
fig, ax = plt.subplots(6, 1, figsize=(10, 8))

rgb_colors = ['red', 'green', 'blue']
panel_specs = [
    ('xyz_cov_pcnt', dict(miny=2.)),
    ('xyz_cov_isotropic_prior_stds', dict(idxs=(0, 1, 2))),
    ('xyz_mean_pcnt', dict()),
    ('rgb_var_n_pseudo_obs', dict(idxs=(0, 1, 2), colors=rgb_colors)),
    ('rgb_var_pseudo_sample_stds', dict(idxs=(0, 1, 2), colors=rgb_colors)),
    ('rgb_mean_n_pseudo_obs', dict(idxs=(0, 1, 2), colors=rgb_colors)),
]
for panel_ax, (panel_param, panel_kwargs) in zip(ax, panel_specs):
    plot_param(panel_param, ax=panel_ax, **panel_kwargs)

plt.tight_layout()
plt.show()


In [None]:

condorgmm.rr_init("full_inference_seq_02")

# (state-or-trace, hypers, display name), logged in this order. The two
# commented-out entries would additionally log the object and background fits.
animation_sequence = [
    # (object_fitting_meta, object_fitting_hypers, "object_fitting"),
    # (bkg_only_meta, bkg_only_hypers, "bkg_only"),
    (st_after_adding_in_object, final_hypers, "st_after_adding_in_object"),
    (st_after_dp_update, final_hypers, "st_after_dp_update"),
    (final_metadata, final_hypers, "final_metadata"),
]

ctr = 0  # position on the rerun "step" timeline, shared by every entry
for meta, hypers, name in animation_sequence:
    print(f"Reached {name}")

    if isinstance(meta, ot.CondorGMMState):
        # A single state: log it as one timeline step.
        rr.set_time_sequence("step", ctr)
        log_state(meta, hypers)
        rr.log("text", rr.TextLog(name))
        ctr += 1
        continue

    # Otherwise a trace of visited states: log every 5th one.
    for i, label in enumerate(meta.visited_states.all_labels):
        if i % 5:
            continue
        rr.set_time_sequence("step", ctr)
        state = meta.visited_states.states[i]
        log_state(state, hypers)
        rr.log("text", rr.TextLog(f"{name}:::{label}"))
        ctr += 1

In [None]:
# Public attributes available on a frame object (rgb, depth, masks, ...).
[attr for attr in dir(video[0]) if not attr.startswith("_")]

In [None]:
import jax

# `gaussian_idx` of the datapoints: rows are visited states 139-144,
# columns presumably the first 20 datapoints.
visited = final_metadata.visited_states.states
visited.datapoints.value.gaussian_idx[139:145, :20]

In [None]:
# Mixture weight at column 39 (presumably Gaussian index 39) for visited states 139-144.
final_metadata.visited_states.states.gaussians.mixture_weight[139:145, 39]

In [None]:
# rgb_vars at column 39 (presumably Gaussian index 39) for visited states 100-144.
final_metadata.visited_states.states.gaussians.rgb_vars[100:145, 39]

In [None]:
from matplotlib.ticker import FuncFormatter, LogLocator, MaxNLocator
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax

# NOTE: `plot_param` is already defined, identically, in the first plotting
# cell of this notebook. It is intentionally not redefined here: a second copy
# silently shadows the first, so an edit to either copy would be lost. The
# calls below reuse the earlier definition.

# Re-draw the same six-panel figure as the first plotting cell.
fig, ax = plt.subplots(6, 1, figsize=(10, 8))
xyz_axes, rgb_axes = ax[:3], ax[3:]

plot_param('xyz_cov_pcnt', miny=2., ax=xyz_axes[0])
plot_param('xyz_cov_isotropic_prior_stds', idxs=(0, 1, 2), ax=xyz_axes[1])
plot_param('xyz_mean_pcnt', ax=xyz_axes[2])

rgb_params = (
    'rgb_var_n_pseudo_obs',
    'rgb_var_pseudo_sample_stds',
    'rgb_mean_n_pseudo_obs',
)
for panel_ax, rgb_param in zip(rgb_axes, rgb_params):
    plot_param(rgb_param, idxs=(0, 1, 2), ax=panel_ax, colors=['red', 'green', 'blue'])

plt.tight_layout()
plt.show()


In [None]:
from condorgmm.condor.rerun import log_state

# Log the tracker state returned by the first `ot.initialize` call to the rerun
# recording "final_state_00" (the last cell of the notebook reuses this name).
condorgmm.rr_init("final_state_00")

log_state(ccts.state, ccts.hypers)

In [None]:
# Mixture weights for visited states 138-141 of the final inference trace.
metadata['final_metadata'].visited_states.states[138:142].gaussians.mixture_weight

In [None]:
# Label of visited state 0 in the final inference trace.
metadata['final_metadata'].visited_states.get_label(0)

In [None]:
# Bind the same inputs that `ot.initialize` receives, so the cells below can
# re-run pieces of its internals step by step. Pose, mesh and mask are all
# selected with OBJECT_IDX so they describe the same object, which is also the
# object tracked by the first `ot.initialize` call. The mesh is looked up via
# the frame's `object_ids` rather than a video-level id.
(
    frame,
    camera_pose_world_frame,
    object_pose_world_frame,
    object_mesh,
    object_mask,  # (H, W) boolean array
    cfg,
) = (
    video[0],
    condorgmm.Pose(video[0].camera_pose),
    condorgmm.Pose(video[0].object_poses[OBJECT_IDX]),
    original_video.get_object_mesh_from_id(video[0].object_ids[OBJECT_IDX]),
    video[0].masks[OBJECT_IDX],
    ot.default_cfg,
)

In [None]:
import jax
from jax.random import split
import jax.numpy as jnp

from condorgmm.condor.interface.object_tracking import _frame_to_intrinsics

# Hand-written replay of the hyperparameter setup (presumably mirroring what
# `ot.initialize` does internally), for step-by-step inspection.
key = jax.random.key(0)
k1, k2, k3, k4, k5, k6 = split(key, 6)

# Datapoint mask: every pixel when depth non-returns get repopulated,
# otherwise only pixels with a positive depth reading.
mask = (
    jnp.ones(frame.width * frame.height, dtype=bool)
    if cfg.repopulate_depth_nonreturns
    else jnp.array(frame.depth > 0, dtype=bool).flatten()
)

scene_overrides = {
    "n_gaussians": cfg.n_gaussians_for_background + cfg.n_gaussians_for_object,
    "tile_size_x": cfg.tile_size_x,
    "tile_size_y": cfg.tile_size_y,
    "intrinsics": _frame_to_intrinsics(
        frame, ensure_fx_eq_fy=cfg.repopulate_depth_nonreturns
    ),
    "datapoint_mask": mask,
    "max_n_gaussians_per_tile": cfg.max_n_gaussians_per_tile,
    "repopulate_depth_nonreturns": cfg.repopulate_depth_nonreturns,
}
hypers = cfg.base_hypers.replace(scene_overrides, do_replace_none=False)


In [None]:
from condorgmm.condor.interface.object_tracking import BackgroundOnlySceneState, Pose

# Sample coloured points from the object mesh surface to fit the object model to.
xyz, rgb = condorgmm.mesh.sample_surface_points(
    object_mesh, cfg.n_pts_for_object_fitting
)
xyz = jnp.array(xyz, dtype=jnp.float32)
rgb = jnp.array(rgb, dtype=jnp.float32)

# Jitter the sampled colours with Gaussian noise scaled by the RGB noise-floor std.
rgb_noise = jax.random.normal(k1, rgb.shape, dtype=jnp.float32)
rgb = rgb + rgb_noise * hypers.rgb_noisefloor_std

# Re-target the hypers at the object-only point cloud: every sampled point is
# a datapoint and the camera sits at the identity pose. The right-hand side
# still reads the previous `hypers` before it is rebound.
hypers = hypers.replace(
    n_gaussians=cfg.n_gaussians_for_object,
    datapoint_mask=jnp.ones(cfg.n_pts_for_object_fitting, dtype=bool),
    repopulate_depth_nonreturns=False,
    use_monolithic_tiling=True,
    initial_scene=BackgroundOnlySceneState(
        transform_World_Camera=Pose.identity(),
        background_rigidity=hypers.initial_scene.background_rigidity,  # doesn't matter
    ),
)


In [None]:
# Toy nearest-neighbour initialisation: pick n_gaussians random seed points
# and give each Gaussian the n_neighbors sampled points closest to its seed.
n_gaussians = 20
n_neighbors = 20  # points per Gaussian; a separate knob from n_gaussians
n_datapoints = xyz.shape[0]

# Use the unused subkey k3: `key` was already split into k1..k6 above, and a
# JAX key must not be reused after it has been split.
gaussian_to_datapoint = jax.random.choice(
    k3, jnp.arange(n_datapoints), shape=(n_gaussians,), replace=False
)

# Despite the "_4_" in these names (kept because later cells use them), this
# selects the n_neighbors closest points by Euclidean distance in xyz, giving
# a (n_gaussians, n_neighbors) index table.
dp_to_4_closest = jax.vmap(
    lambda dp_idx: jnp.argsort(jnp.linalg.norm(xyz - xyz[dp_idx], axis=-1))[:n_neighbors]
)
gaussian_to_4_datapoints = dp_to_4_closest(gaussian_to_datapoint)
gaussian_to_4_datapoints

In [None]:
# Flatten the (n_gaussians, neighbours) index table row by row into one vector.
gaussian_to_4_datapoints.ravel()

In [None]:
# Owning Gaussian for each entry of the flattened neighbour table: every
# Gaussian id repeated once per neighbour. The count comes from the table
# itself, so it cannot drift from the slice length used to build it.
jnp.repeat(jnp.arange(n_gaussians), gaussian_to_4_datapoints.shape[1])

In [None]:
# Map each datapoint to a Gaussian whose neighbourhood contains it, or -1 if none.
# NOTE: neighbourhoods can overlap. For duplicate indices, JAX leaves the order
# of `.at[].set` updates implementation-defined, so a shared point ends up with
# an arbitrary one of its Gaussians.
# The repeat count is read from the neighbour table so it always matches.
dp_to_gaussian = (-jnp.ones(n_datapoints, dtype=jnp.int32)).at[
    gaussian_to_4_datapoints.ravel()
].set(jnp.repeat(jnp.arange(n_gaussians), gaussian_to_4_datapoints.shape[1]))
dp_to_gaussian

In [None]:
from condorgmm.condor.interface.object_tracking import _get_sparse_datapoint_assignment_initialization

# The object_tracking module's own sparse datapoint -> Gaussian assignment.
# -1 marks datapoints left unassigned; show only the assigned entries.
gaussian_idxs = _get_sparse_datapoint_assignment_initialization(
    k2, cfg.n_gaussians_for_object, xyz, rgb
)
is_assigned = gaussian_idxs != -1
gaussian_idxs[is_assigned]

In [None]:
import condorgmm.condor.interface.object_tracking as ot
# Re-run initialisation with 120 object Gaussians. Module edits are already
# picked up by `%autoreload 2`, so no manual importlib.reload is needed.
# Pose, mesh and mask all use OBJECT_IDX, so they describe the same object as
# the first `ot.initialize` call.
# NOTE: this overwrites gmm0, gmm1, ccts and metadata from the earlier run.
gmm0, gmm1, ccts, metadata = ot.initialize(
    video[0],
    condorgmm.Pose(video[0].camera_pose),
    condorgmm.Pose(video[0].object_poses[OBJECT_IDX]),
    original_video.get_object_mesh_from_id(video[0].object_ids[OBJECT_IDX]),
    video[0].masks[OBJECT_IDX],
    ot.default_cfg.replace(n_gaussians_for_object=120),
    log=True,
)

In [None]:
import rerun as rr
from condorgmm.condor.rerun import log_state

hyp = metadata['object_model_metadata']['hypers']
meta = metadata['object_model_metadata']['meta']

condorgmm.rr_init("object_fitting_00")

# Log every visited state up to step 30, then only every 40th step after that.
for (i, label) in enumerate(meta.visited_states.all_labels):
    if i <= 30 or i % 40 == 0:
        rr.set_time_sequence("frame", i)
        log_state(meta.visited_states.states[i], hyp)

In [None]:
# Log the final state of the 120-object-Gaussian re-initialisation.
# NOTE(review): this reuses the recording name "final_state_00", which an
# earlier cell also uses for the first run. Confirm whether it should be distinct.
condorgmm.rr_init("final_state_00")

log_state(ccts.state, ccts.hypers)