In [1]:
import jax
import jax.numpy as jnp
import genjax
genjax.pretty()
import b3d
import b3d.chisight.gen3d as gen3d
from tqdm import tqdm

In [2]:
# (scene_id, object_idx) pairs flagged for debugging — presumably cases where
# tracking has been seen to fail; TODO confirm the selection criterion.
problematic_scene_specs = [(2, 5), (3, 0), (3, 1), (3, 2), (4, 1), (4, 3), (5, 0)]

# Which of the cases above this notebook investigates (edit this to switch case).
SPEC_INDEX = 1
spec = problematic_scene_specs[SPEC_INDEX]
scene_id, object_idx = spec

In [None]:
# Load the chosen scene from the "train_real" split.
# NOTE(review): FRAME_RATE=50 is passed straight to load_scene — presumably a
# frame subsampling setting; confirm its meaning there.
(
    all_data,
    meshes,
    renderer,
    intrinsics,
    initial_object_poses,
) = gen3d.dataloading.load_scene(scene_id, FRAME_RATE=50, subdir="train_real")

# Show the first frame's RGB-D observation as a sanity check.
b3d.viz_rgb(all_data[0]["rgbd"])

In [4]:
# Pull the template pose and model vertices/colors for object `object_idx`
# out of the scene that was just loaded.
(
    template_pose,
    model_vertices,
    model_colors,
) = gen3d.dataloading.load_object_given_scene(all_data, meshes, renderer, object_idx)

In [None]:
# Scene-specific model hyperparameters: the library defaults plus this scene's
# camera intrinsics and the object's model vertices.
# Build a new mapping rather than assigning into `gen3d.settings.hyperparams`,
# which is a shared module-level object; mutating it would leak these
# scene-specific values into any other code that reads gen3d.settings.
# NOTE(review): assumes the settings object is a Mapping (it is indexed with
# string keys here) — confirm it is a plain dict.
hyperparams = {
    **gen3d.settings.hyperparams,
    "intrinsics": intrinsics,
    "vertices": model_vertices,
}
hyperparams

In [None]:
# Summarize the loaded frames instead of dumping every RGB-D array:
# number of frames and the fields available on each frame.
len(all_data), list(all_data[0].keys())

In [None]:
# Start from the library's default inference settings and override only the
# pose-proposal schedule with a single (0.04, 1000.0) stage.
# NOTE(review): the tuple is presumably (position spread, rotation
# concentration) — confirm against InferenceHyperparams.
_default_inference_attrs = gen3d.settings.inference_hyperparams.attributes_dict()
inference_hyperparams = gen3d.hyperparams.InferenceHyperparams(
    **{**_default_inference_attrs, "pose_proposal_args": [(0.04, 1000.0)]}
)
inference_hyperparams

In [25]:
def get_gt_pose(T):
    """Return the ground-truth pose of the tracked object at frame ``T``.

    Composes the inverse of frame ``T``'s ``"camera_pose"`` with that frame's
    ``"object_poses"[object_idx]``, i.e. expresses the object pose relative to
    the camera. Reads the notebook globals ``all_data`` and ``object_idx``.
    """
    frame = all_data[T]
    camera_pose_inv = frame["camera_pose"].inv()
    object_pose = frame["object_poses"][object_idx]
    return camera_pose_inv @ object_pose

In [26]:
# Initial state built from the template pose, the object's model
# vertices/colors, and the scene-specific hyperparameters.
initial_state = gen3d.dataloading.get_initial_state(template_pose, model_vertices, model_colors, hyperparams)

In [27]:
# Initial trace, conditioned on the first frame's RGB-D observation.
# The seed is fixed (156) so the starting trace is reproducible.
first_frame_rgbd = all_data[0]["rgbd"]
key = jax.random.PRNGKey(156)
og_trace = gen3d.inference.get_initial_trace(
    key,
    hyperparams,
    initial_state,
    first_frame_rgbd,
)

In [28]:
# Name of the visualization recording the tracking loop logs into.
# NOTE(review): b3d.rr_init presumably initializes a Rerun session — confirm.
RERUN_RECORDING_NAME = "inference_debugging-2"
b3d.rr_init(RERUN_RECORDING_NAME)


In [None]:
# Frame-by-frame tracking with the ground-truth pose supplied to inference
# (use_gt_pose=True). After each step, regenerate the trace corresponding to
# the ground-truth proposal and stop at the first frame where the inferred
# position is more than `divergence_threshold` away from it, visualizing that
# ground-truth trace for comparison.
tracking_results = {}
maxT = 20
# Position-error cutoff for stopping; same units as Pose.position
# (presumably metres — confirm).
divergence_threshold = 0.02
# Loop-invariant: the ground-truth mesh of the tracked object.
gt_mesh_vertices = meshes[object_idx].vertices

key = jax.random.PRNGKey(1)
trace = og_trace
for T in tqdm(range(maxT)):
    key = b3d.split_key(key)
    prev_trace = trace
    # Computed once per frame; reused for inference and both visualizations.
    gt_pose_T = get_gt_pose(T)
    trace, _, all_weights, all_poses, keys_to_regen = gen3d.inference.inference_step(
        key,
        trace,
        all_data[T]["rgbd"],
        inference_hyperparams,
        gt_pose=gt_pose_T,
        use_gt_pose=True,
        get_all_weights=True,
    )
    tracking_results[T] = trace

    # NOTE(review): assumes that with use_gt_pose=True the ground-truth pose is
    # proposal 0, so all_poses[0] / keys_to_regen[0] belong to it — confirm in
    # inference_step.
    gt_pose, gt_key = all_poses[0], keys_to_regen[0]
    # Regenerate, from the previous trace, the trace inference would have
    # produced had it selected the ground-truth proposal.
    gt_trace = gen3d.inference.get_trace_generated_during_inference(
        gt_key, prev_trace, gt_pose, inference_hyperparams
    )

    gen3d.model.viz_trace(
        trace,
        T,
        ground_truth_vertices=gt_mesh_vertices,
        ground_truth_pose=gt_pose_T,
    )

    inferred_position = gen3d.model.get_new_state(trace)["pose"].position
    if jnp.linalg.norm(gt_pose.position - inferred_position) > divergence_threshold:
        # Log the ground-truth trace to its own recording for side-by-side
        # inspection, then stop tracking.
        # NOTE(review): rr_init is presumably switching the active recording,
        # so anything logged after this loop goes to "gt_pose_trace" — confirm.
        b3d.rr_init("gt_pose_trace")
        gen3d.model.viz_trace(
            gt_trace,
            T,
            ground_truth_vertices=gt_mesh_vertices,
            ground_truth_pose=gt_pose_T,
        )
        break