### Pose proposals via stratified (grid) sampling

We add an option in the inference to generate pose proposals via sampling in a uniform pose grid, rather than the default of Gaussian-VMF pose proposals.  

This notebook visualizes example pose grids and demonstrates a run of the end-to-end pipeline with `use_gt_pose=False`, 
across the first 2 scenes for each object. 
We will also aggregate metrics to confirm that pose gridding works at the integrated inference level.

In [None]:
import os

# XLA memory settings; this cell runs before the cell that imports jax.
# NOTE(review): per the JAX GPU-memory docs these select the on-demand
# "platform" allocator and turn off up-front GPU preallocation -- confirm
# against the JAX version in use.
os.environ.update(
    {
        "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    }
)

In [None]:
import jax
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation
from tqdm import tqdm

In [None]:
import condorgmm 
from condorgmm import Pose, Mesh
from condorgmm.utils.pose_gridding import make_pose_grid_enumeration_simple_jit, rr_log_pose_grid, rr_log_cloud_and_pose

### Preliminaries: grid visualization

##### Generate pose proposals

In [None]:
# PRNG key and centre pose used by the proposal cells below.
k1 = jax.random.PRNGKey(1222)
previous_pose = Pose(jnp.array([0.0, 0.0, 0.0]), Pose.identity_quaternion)

# Half-widths of the grid, in number of steps on each side of the centre.
n_translation_half_dim = 2
n_rotation_half_dim = 2

# A half-width of h gives (2h + 1) values per axis; the total is the product
# over three translation axes and three rotation axes.
translation_steps_per_axis = 2 * n_translation_half_dim + 1
rotation_steps_per_axis = 2 * n_rotation_half_dim + 1
n_poses = translation_steps_per_axis**3 * rotation_steps_per_axis**3
print(f"Generating {n_poses} poses")

Gaussian VMF proposals

In [None]:
# Coarse-to-fine schedule of (std, concentration) pairs,
# as in end_to_end.py#L127.
c2f_schedule = [
    (0.04, 1500.0),
    (0.02, 2000.0),
    (0.01, 2000.0),
    (0.005, 3000.0),
]

# Only the first (coarsest) stage is visualized here.
std, concentration = c2f_schedule[0]
print(f"std: {std}, concentration: {concentration}")

# One PRNG key per proposal, so the VMF sample count equals the grid size.
proposal_keys = jax.random.split(k1, n_poses)
vmf_poses = Pose.sample_gaussian_vmf_pose_vmap(
    proposal_keys, previous_pose, std, concentration
)

Grid proposals

In [None]:
## To get approximate "parity" between the range of pose deltas
# generated by the vmf and the uniform grid, we size the pose grid
# from the extremes of the vmf samples.
# (NOTE that the vmf is less interpretable in this way)

# Largest translation coordinate among the vmf samples (over all axes).
per_axis_translation_max = vmf_poses.position.max(axis=0)
translation_delta = jnp.max(per_axis_translation_max)

# Largest "ZYX" Euler angle among the vmf samples (over all axes).
vmf_euler_angles = Rotation.from_quat(vmf_poses.quaternion).as_euler("ZYX")
rotation_euler_delta = vmf_euler_angles.max()

### Enumerate the uniform pose grid centred on previous_pose
grid_poses = make_pose_grid_enumeration_simple_jit(
    pose_center=previous_pose,
    half_dtr=translation_delta,
    half_ntr=n_translation_half_dim,
    half_dangle=rotation_euler_delta,
    half_nrot=n_rotation_half_dim,
)


##### Visualize pose proposals

In [None]:
YCB_OBJ_ID = 13   # beloved mug
ycb_mesh_dir = os.path.join(condorgmm.get_assets_path(), "bop/ycbv")

# Model files are numbered from 1 and zero-padded to six digits.
mesh_relative_path = f"models/obj_{YCB_OBJ_ID + 1:06d}.ply"
mesh_path = os.path.join(ycb_mesh_dir, mesh_relative_path)

# note that this processing is the same as in `rr_log_pose_grid`
mesh = Mesh.from_obj_file(mesh_path).scale(0.001)

In [None]:
## For a more immediately visually comparable Rerun log,
# we sort each set of proposals by its translation distance from
# previous_pose.

vmf_sort = jnp.argsort(jnp.linalg.norm(vmf_poses.position - previous_pose.position, axis=1))
grid_sort = jnp.argsort(jnp.linalg.norm(grid_poses.position - previous_pose.position, axis=1))

# Each set must be reordered with its OWN permutation: indexing vmf_poses
# with grid_sort would only shuffle the vmf samples, not sort them.
grid_poses = grid_poses[grid_sort]
vmf_poses = vmf_poses[vmf_sort]

In [None]:
DEBUG_VIZ = True

if DEBUG_VIZ:
    # Both channels go to the same Rerun session for the same object.
    shared_log_kwargs = dict(rerun_session_name="vmf_vs_grid", ycb_obj_id=YCB_OBJ_ID)
    # rr_log_pose_grid(vmf_poses, channel_name="vmf", **shared_log_kwargs)
    rr_log_pose_grid(grid_poses, channel_name="grid", **shared_log_kwargs)

##### Simulate inference with grid vs vmf based pose proposal
We decompose inference.py's `c2f_inference_step` to compare the results of a single inference step under the two pose-proposal methods (uniform grid vs. Gaussian-VMF).

In [None]:
## Setup 
from condorgmm import load_scene
from condorgmm.model import viz_trace
from condorgmm.end_to_end import init_metrics_dict, get_initial_state_for_object, initialize_inference, update_hyperparams_for_subsequent_frames, extend_metrics
from condorgmm.end_to_end import run_inference_step

In [None]:
from condorgmm.config.default import configuration as config

## inference settings
use_gt_pose = False
use_grid = True         # True: grid pose proposals; False: Gaussian-VMF
use_vmf = not use_grid

## ycb scene/object to test
YCB_SCENE = 1
OBJECT_INDEX = 0   # obj index in scene
FRAME_RATE = 50
ycb_dir = condorgmm.get_root_path() / "assets/bop/ycbv/train_real"

## rerun logging
live_rerun = True
save_rerun = True
RERUN_SESSION_NAME = f"ycbv_grid_debug_grid_{use_grid}"


In [None]:
# ## load ycb
all_data, meshes, intrinsics = load_scene(ycb_dir, YCB_SCENE, FRAME_RATE)
initial_object_poses = all_data[0]["object_poses"]
all_scores = init_metrics_dict()


def gt_pose(T):
    """Return the pose of object OBJECT_INDEX at frame T, composed with the inverse camera pose."""
    frame = all_data[T]
    return frame["camera_pose"].inv() @ frame["object_poses"][OBJECT_INDEX]   # from end_to_end.py#L272

In [None]:
## The below code is taken from end_to_end.py/run_tracking

key = jax.random.PRNGKey(0)

# Shorthands for expressions used several times in this cell.
log_to_rerun = live_rerun or save_rerun
object_vertices = meshes[OBJECT_INDEX].vertices

if live_rerun:
    condorgmm.rr_init(RERUN_SESSION_NAME)

T = 0

# Object poses of the first frame, composed with the inverse camera pose.
first_frame = all_data[T]
initial_object_poses = first_frame["camera_pose"].inv() @ first_frame["object_poses"]

initial_state = get_initial_state_for_object(meshes, OBJECT_INDEX, initial_object_poses)

trace = initialize_inference(
    initial_state,
    all_data,
    config.model_hyperparams_first_frame,
    object_vertices,
    intrinsics,
)
if log_to_rerun:
    # Log to rerun as "frame -1"
    viz_trace(trace, -1, object_vertices, gt_pose(0))

# Update the remaining variables while holding the pose at its current value.
current_pose = trace.get_choices()["pose"]
trace, _, metadata = condorgmm.inference.update_all_variables_given_pose(
    key,
    trace,
    current_pose,
    config.point_attribute_proposal,
)

if log_to_rerun:
    viz_trace(trace, 0, object_vertices, gt_pose(0))

# Run inference for the rest of the frames
trace = update_hyperparams_for_subsequent_frames(
    trace, config.model_hyperparams_subsequent_frames
)

inferred_poses = [trace.get_choices()["pose"]]

In [None]:
# copied from inference.py
def run_inference_step(
    trace,
    gt_pose,
    use_gt_pose,
    point_attribute_proposal,
    observed_img,
    key=jax.random.PRNGKey(0),
    do_advance_time=True,
    use_grid=False
):
    """Run one coarse-to-fine pose-inference step on ``trace``.

    If ``do_advance_time`` is true, the trace is first passed through
    ``condorgmm.inference.advance_time`` together with ``observed_img``;
    otherwise ``observed_img`` is not used. The trace is then refined over
    a four-stage schedule, using either grid proposals
    (``condorgmm.inference.c2f_inference_grid``) or Gaussian-VMF proposals
    (``condorgmm.inference.c2f_inference_gaussian_vmf``). A fresh subkey is
    derived from ``key`` before every stage.

    Args:
        trace: Current inference trace.
        gt_pose: Ground-truth pose, forwarded to the inference routine.
        use_gt_pose: Flag forwarded to the inference routine.
        point_attribute_proposal: Proposal forwarded to the inference routine.
        observed_img: Observation handed to ``advance_time``.
        key: JAX PRNG key.
        do_advance_time: Whether to call ``advance_time`` before refining.
        use_grid: Select grid proposals (True) or Gaussian-VMF (False).

    Returns:
        ``(trace, metadata)`` where ``metadata`` is the one returned by the
        last stage of the schedule.
    """
    if do_advance_time:
        trace = condorgmm.inference.advance_time(key, trace, observed_img)

    if use_grid:
        ## gridding pose proposals
        half_ntr = 2
        half_nrot = 1
        # (half_dtr, half_dangle) per stage, coarse to fine
        grid_schedule = (
            (0.015, jnp.pi / 10),
            (0.0075, jnp.pi / 12),
            (0.0025, jnp.pi / 15),
            (0.001, jnp.pi / 20),
        )
        for half_dtr, half_dangle in grid_schedule:
            _, key = jax.random.split(key)
            trace, metadata = condorgmm.inference.c2f_inference_grid(
                key,
                trace,
                half_dtr,
                half_dangle,
                half_ntr,
                half_nrot,
                point_attribute_proposal,
                use_gt_pose=use_gt_pose,
                gt_pose=gt_pose,
                get_metadata=True,
            )
        return trace, metadata

    ## VMF pose proposals
    proposals_per_stage = 2000
    # (std, concentration) per stage, coarse to fine
    vmf_schedule = (
        (0.04, 1500.0),
        (0.02, 2000.0),
        (0.01, 2000.0),
        (0.005, 3000.0),
    )
    for stage_std, stage_concentration in vmf_schedule:
        _, key = jax.random.split(key)
        trace, metadata = condorgmm.inference.c2f_inference_gaussian_vmf(
            key,
            trace,
            stage_std,
            stage_concentration,
            proposals_per_stage,
            point_attribute_proposal,
            use_gt_pose=use_gt_pose,
            gt_pose=gt_pose,
            get_metadata=True,    # for visualization
        )
    return trace, metadata

In [None]:
maxT = 10

# Which frame each inference step is run against:
#   False -> every step refines the pose on the frame-0 observation without
#            advancing the trace (the behaviour this cell has always had);
#   True  -> the trace is advanced to frame T before each step.
# The same frame index is used for the observation, the ground-truth pose
# handed to inference, and the ground-truth pose used for scoring, so the
# metrics always refer to the frame that was actually inferred.
TRACK_FRAMES = False

for T in tqdm(range(1, maxT)):
    frame = T if TRACK_FRAMES else 0
    key, sk = jax.random.split(key)
    trace, metadata = run_inference_step(
        trace,
        gt_pose(frame),
        use_gt_pose,
        config.point_attribute_proposal,
        all_data[frame]["rgbd"],
        sk,
        do_advance_time=TRACK_FRAMES,
        use_grid=use_grid,
    )
    inferred_poses.append(trace.get_choices()["pose"])
    print(trace.get_choices()['pose'].position)

    # if live_rerun or save_rerun:
    #     viz_trace(trace, T, meshes[OBJECT_INDEX].vertices, gt_pose(frame))
    #     for sample_idx_to_viz in range(0,len(metadata['poses']), 100):
    #         rr_log_cloud_and_pose(metadata['poses'][sample_idx_to_viz], 
    #         vertices=meshes[OBJECT_INDEX].vertices, t=T, 
    #         channel="grid_proposals" if use_grid else "vmf_proposals")

    # Score against the ground truth of the frame inference was run on.
    extend_metrics(
        all_scores,
        trace,
        "test_object",
        gt_pose(frame),
        meshes[OBJECT_INDEX],
    )
