# Starter notebook for analyzing inference behavior

This notebook is intended to be a template to start from for analyzing inference behavior.
If you push changes to this notebook, they are changes to the starter template!

When working on studying inference behavior, please make a copy of this notebook in a folder
with your name on it.  (E.g. George would make a copy of this notebook in `notebooks/gm`.)
Then feel free to modify that notebook as you please in your debugging.

In [None]:
import jax
import jax.numpy as jnp
import condorgmm
import condorgmm.model
import genjax
import condorgmm.end_to_end
from condorgmm.config.default import configuration as config
from tqdm import tqdm
from condorgmm.utils import inference_analysis_utils as iau
import matplotlib.pyplot as plt
# NOTE(review): presumably switches on GenJAX's pretty-printed output for the
# objects displayed in later cells -- confirm against the GenJAX docs.
genjax.pretty()

In [None]:
# Scene to load; passed to `condorgmm.load_scene` in the loading cell.
scene_id = 1
# Selects which of the scene's objects to analyze: used to index
# `object_poses` and passed to `get_initial_state_for_object`.
OBJECT_INDEX = 1

In [None]:
# Load the scene for `scene_id`: per-frame data, meshes and intrinsics.
FRAME_RATE = 50
ycb_dir = condorgmm.get_root_path() / "assets/bop/ycbv/train_real"
all_data, meshes, intrinsics = condorgmm.load_scene(ycb_dir, scene_id, FRAME_RATE)

# First-frame object poses composed with the inverse of the first-frame camera pose.
initial_object_poses = all_data[0]["camera_pose"].inv() @ all_data[0]["object_poses"]
def gt_pose(T):
    """Return the tracked object's ground-truth pose at frame index ``T``.

    The result is the inverse of the frame's ``camera_pose`` composed with the
    ``object_poses`` entry selected by ``OBJECT_INDEX``.
    """
    frame = all_data[T]
    camera_inv = frame["camera_pose"].inv()
    return camera_inv @ frame["object_poses"][OBJECT_INDEX]

# Build the starting state (and the object's vertices) for the selected object
# from the meshes and the first-frame poses.
initial_state, vertices = condorgmm.end_to_end.get_initial_state_for_object(
    meshes,
    OBJECT_INDEX,
    initial_object_poses,
)

In [None]:
# Initialize rerun, under a name that encodes the scene id and object index
# being analyzed.  `rr_name` is reused later as the prefix for the
# frame-specific recording.
rr_name = f"result_analysis_sc{scene_id}_obj{OBJECT_INDEX}"
condorgmm.rr_init(rr_name)

In [None]:
# Seed the PRNG key used by the rest of the notebook, then build the trace at
# the initial pose from the deterministically initialized state.
key = jax.random.PRNGKey(0)
trace = condorgmm.end_to_end.initialize_inference(
    initial_state,
    all_data,
    config.model_hyperparams_first_frame,
    vertices,
    intrinsics,
)
# Visualized at index -1; the updated T=0 trace is visualized at index 0 next.
condorgmm.model.viz_trace(trace, -1, ground_truth_pose=gt_pose(0))

In [None]:
# Update the point properties at this fixed pose, at T=0.
# Split the carried key and hand the fresh `subkey` to the update, so that
# `key` is not both consumed here and split again later (JAX PRNG keys must
# not be reused).
key, subkey = jax.random.split(key)
trace, _, _ = condorgmm.inference.update_all_variables_given_pose(
    subkey,
    trace,
    trace.get_choices()["pose"],  # keep the pose currently stored in the trace
    config.point_attribute_proposal,
)
condorgmm.model.viz_trace(trace, 0, ground_truth_pose=gt_pose(0))

In [None]:
# Update hyperparams for subsequent frames: swap in
# `config.model_hyperparams_subsequent_frames` (the trace so far was built with
# `config.model_hyperparams_first_frame`).  The result is kept under its own
# name so the main loop can be restarted from this point.
trace_pre_loop = condorgmm.end_to_end.update_hyperparams_for_subsequent_frames(
    trace, config.model_hyperparams_subsequent_frames
)

In [None]:
# Run inference over the full video.
#
# Before each frame T is processed, the incoming trace and PRNG key are
# recorded, so `trs[T - 1]` / `keys[T - 1]` are the state from which frame T
# was processed (T starts at 1).
maxT = len(all_data)
trace = trace_pre_loop
trs = []
keys = []
for T in tqdm(range(1, maxT)):
    trs.append(trace)
    keys.append(key)
    # Take the two halves of the split as distinct keys: the first drives this
    # frame's inference step, the second is carried to the next frame.  The key
    # consumed by the step is therefore never split again (no PRNG key reuse).
    # `step_key` equals `jax.random.split(keys[-1])[0]`, which is what the
    # frame-inspection cell recomputes from `keys`.
    step_key, key = jax.random.split(key)
    trace = condorgmm.end_to_end.run_inference_step(
        trace,
        gt_pose(T),
        True,
        config.point_attribute_proposal,
        all_data[T]["rgbd"],
        step_key,
        do_advance_time=True,
    )
    condorgmm.model.viz_trace(trace, T, ground_truth_pose=gt_pose(T))

In [None]:
## Below here -- setup for inspecting specific frames with issues

In [None]:
T_bad = 13 # The frame to be inspected (a frame at which tracking got off)

# The main loop appended to `trs` / `keys` *before* processing each frame,
# starting at T=1, so index T_bad - 1 holds the trace and PRNG key from which
# frame T_bad was processed.  T_bad must be >= 1: index -1 would silently
# select the last recorded entry instead.
trace_pre_error = trs[T_bad - 1]
key_pre_error = keys[T_bad - 1]

# Reconstruct the first step of C2F at this frame where something went wrong.
# Use the same PRNGKey as from the main loop above,
# so this gives you the exact thing that happened internally
# in the algorithm above.
# The first half of the split is the key the main loop passed to
# `run_inference_step` for this frame.
key_advanced, _ = jax.random.split(key_pre_error)
tr_advanced = condorgmm.inference.advance_time(
    key_advanced, trace_pre_error, all_data[T_bad]["rgbd"]
)
# NOTE(review): the key derivation below (last key of a 2-way split) and the
# positional constants 0.04, 1500.0 and 2000 presumably mirror what
# `run_inference_step` does for its first C2F step -- confirm against that
# function whenever it changes.  The 2000 is also the split count used when
# the GT-pose trace is regenerated in a later cell.
key_stp1 = jax.random.split(key_advanced)[-1]
tr_stp1, metadata = condorgmm.inference.inference_step(
    key_stp1,
    tr_advanced,
    0.04,
    1500.0,
    2000,
    config.point_attribute_proposal,
    use_gt_pose=True,
    gt_pose=gt_pose(T_bad),
    get_metadata=True,
)

In [None]:
# Gap between the largest score and the score of the last entry, which the
# printed label identifies as the trace at the GT pose.
print(
    "Max log importance weight - log importance weight of trace at GT pose:",
    jnp.max(metadata["scores"]) - metadata["scores"][-1],
)

In [None]:
# Plot the 100 largest importance weights, in descending order
# (ascending sort, reversed, then truncated).
plt.plot(jnp.sort(metadata["scores"])[::-1][:100])

In [None]:
# Regenerate the trace that was at the ground truth pose
# NOTE(review): this takes the last of 2000 keys split from the step's
# point-proposal key.  Presumably 2000 must equal the count passed to
# `inference_step` in the frame-inspection cell, and the GT pose occupies the
# last slot (cf. `metadata["scores"][-1]`) -- confirm.
k = jax.random.split(metadata["key_for_point_proposals"], 2000)[-1]
gt_pose_tr, gt_pose_score, _ = condorgmm.inference.update_all_variables_given_pose(
    k, tr_stp1, gt_pose(T_bad), config.point_attribute_proposal
)

In [None]:
# Visualize the preceding trace, the resampled trace at T_bad, and the trace at the
# ground truth pose generated in the first step of C2F at T_bad.
# `rr_init` is called again with a frame-specific name; the three traces are
# then visualized at indices 0, 1 and 10 respectively.
condorgmm.rr_init(f"{rr_name}-T{T_bad}--1")
condorgmm.model.viz_trace(trace_pre_error, 0, ground_truth_pose=gt_pose(T_bad-1))
condorgmm.model.viz_trace(tr_stp1, 1, ground_truth_pose=gt_pose(T_bad))
condorgmm.model.viz_trace(gt_pose_tr, 10, ground_truth_pose=gt_pose(T_bad))

In [None]:
# Compare the resampled trace against the GT-pose trace: sub-score
# differences first, then the visibility-flag flip count of each trace.
iau.print_trace_subscore_diffs(tr_stp1, gt_pose_tr)
print(
    f"Num visibility flag flips in resampled trace: {iau.get_n_visibility_flips(tr_stp1)}"
)
print(
    f"Num visibility flag flips in gt pose trace: {iau.get_n_visibility_flips(gt_pose_tr)}"
)

In [None]:
# Penzai inspection of the two traces:

In [None]:
# Trace regenerated at the ground-truth pose (shown as this cell's output).
gt_pose_tr

In [None]:
# Trace returned by the reconstructed first C2F step (shown as this cell's output).
tr_stp1

In [None]:
# Difference between the summed point-level proposal scores of the sampled
# trace and those of the last entry (the GT-pose trace).  The Q score of the
# pose proposals is not included, but that is usually small [magnitude <15].
d = metadata["point_proposal_metadata"]["proposal_scores"]
d[metadata["sampled_index"]].sum() - d[-1].sum()

In [None]:
# Note that you can also look at the `metadata` dict to see a lot more
# details about the c2f step.