In [1]:
### IMPORTS ###

import b3d
import jax.numpy as jnp
import os
from b3d import Mesh, Pose
import jax
import genjax
from genjax import Pytree
import rerun as rr
from b3d.modeling_utils import uniform_discrete, uniform_pose, gaussian_vmf
import matplotlib.pyplot as plt
from functools import partial
import importlib
from ipywidgets import interact
import ipywidgets as widgets
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from genjax import SelectionBuilder as S
from genjax import ChoiceMapBuilder as C

# Start the rerun logging session used by every visualization in this notebook.
RERUN_SESSION_NAME = "dynamics2"
b3d.rr_init(RERUN_SESSION_NAME)

In [None]:
### Loading data ###
# Load a subsampled sequence of RGB-D test frames from one YCB-Video scene, the
# meshes of the objects that appear in it, and a renderer whose intrinsics match
# the first loaded frame.
b3d.reload(b3d.io.data_loader)  # pick up local edits to the loader module

scene_id = 49
FRAME_RATE = 50  # keep every FRAME_RATE-th image of the scene
ycb_dir = os.path.join(b3d.get_assets_path(), "bop/ycbv")
print(f"Scene {scene_id}")

# Presumably the number of test images in this scene (the loader function is
# `get_ycbv_num_test_images`); image ids run from 1 to num_scenes inclusive.
num_scenes = b3d.io.data_loader.get_ycbv_num_test_images(ycb_dir, scene_id)
image_ids = range(1, num_scenes + 1, FRAME_RATE)
all_data = b3d.io.data_loader.get_ycbv_test_images(ycb_dir, scene_id, image_ids)

# Object type k lives in models/obj_<k+1 zero-padded to 6 digits>.ply.
# The 0.001 factor presumably converts millimetres to metres — TODO confirm.
meshes = [
    Mesh.from_obj_file(
        os.path.join(ycb_dir, f"models/obj_{int(object_type) + 1:06d}.ply")
    ).scale(0.001)
    for object_type in all_data[0]["object_types"]
]

image_height, image_width = all_data[0]["rgbd"].shape[:2]
fx, fy, cx, cy = all_data[0]["camera_intrinsics"]
scaling_factor = 1.0
# Width/height are pixel counts: multiplying by a float scaling_factor yields
# floats, so convert them back to ints explicitly.
renderer = b3d.renderer.renderer_original.RendererOriginal(
    round(image_width * scaling_factor),
    round(image_height * scaling_factor),
    fx * scaling_factor,
    fy * scaling_factor,
    cx * scaling_factor,
    cy * scaling_factor,
    0.01,  # presumably the near clipping plane
    2.0,  # presumably the far clipping plane
)
b3d.viz_rgb(all_data[0]["rgbd"])

In [136]:
# Import the gen3d modules and hot-reload them so edits on disk are picked up
# without restarting the kernel (b3d.reload presumably wraps importlib.reload).
# NOTE(review): statement order matters here — keep each reload before any
# `from ... import` that should see the reloaded definitions.
import b3d
import b3d.chisight.gen3d.model
b3d.reload(b3d.chisight.gen3d.model)
import b3d.chisight.gen3d.transition_kernels as transition_kernels
b3d.reload(b3d.chisight.gen3d.transition_kernels)
import b3d.chisight.gen3d.image_kernel as image_kernel
b3d.reload(b3d.chisight.gen3d.image_kernel)
import b3d.io.data_loader
import jax
import jax.numpy as jnp
from b3d import Mesh, Pose
# These names are bound after reload(model) above, so they refer to the
# reloaded module's definitions.
from b3d.chisight.gen3d.model import (
    make_colors_choicemap,
    make_depth_nonreturn_prob_choicemap,
    make_visibility_prob_choicemap,
)
from b3d.chisight.gen3d.model import dynamic_object_generative_model
from genjax import ChoiceMapBuilder as C
from genjax import Pytree
from b3d.chisight.gen3d.projection import PixelsPointsAssociation

p_resample_color = 0.005  # NOTE(review): not referenced by any visible cell

# Resample probability shared by all of the discrete "flip" kernels below.
FLIP_RESAMPLE_PROBABILITY = 0.05

# Transition kernel for each latent, plus the RGB-D observation kernel.
# Built with dict(...) keywords so each entry reads as `latent=kernel`;
# construction order (and therefore dict order) matches the keyword order.
hyperparams = dict(
    pose_kernel=transition_kernels.UniformPoseDriftKernel(max_shift=0.2),
    color_kernel=transition_kernels.LaplaceNotTruncatedColorDriftKernel(scale=0.05),
    visibility_prob_kernel=transition_kernels.DiscreteFlipKernel(
        resample_probability=FLIP_RESAMPLE_PROBABILITY,
        support=jnp.array([0.001, 0.999]),
    ),
    depth_nonreturn_prob_kernel=transition_kernels.DiscreteFlipKernel(
        resample_probability=FLIP_RESAMPLE_PROBABILITY,
        support=jnp.array([0.001, 0.999]),
    ),
    depth_scale_kernel=transition_kernels.DiscreteFlipKernel(
        resample_probability=FLIP_RESAMPLE_PROBABILITY,
        support=jnp.array([0.0025, 0.005, 0.01, 0.02]),
    ),
    color_scale_kernel=transition_kernels.DiscreteFlipKernel(
        resample_probability=FLIP_RESAMPLE_PROBABILITY,
        support=jnp.array([0.01]),
    ),
    image_kernel=image_kernel.NoOcclusionPerVertexImageKernel(
        image_kernel.OldOcclusionPixelRGBDDistribution()
    ),
)



In [137]:

# Build the initial trace at frame T for one object:
# - the pose is fixed to the object's ground-truth pose in the camera frame;
# - per-vertex attributes come from the object model;
# - the generative model is conditioned on frame T's observed RGB-D image.
T = 0
b3d.rr_set_time(T)
OBJECT_INDEX = 2
# True: the object model is the mesh's own vertices/colors.
# False: the object model is built from observed pixels that agree with a
# render of the mesh at the ground-truth pose.
USE_MESH_VERTICES = True

mesh = meshes[OBJECT_INDEX]
# Ground-truth object pose expressed in the camera frame.
template_pose = all_data[T]["camera_pose"].inv() @ all_data[T]["object_poses"][OBJECT_INDEX]
# Read frame T's intrinsics *before* any back-projection so that rendered and
# observed points are lifted with the same camera.
fx, fy, cx, cy = all_data[T]["camera_intrinsics"]

if USE_MESH_VERTICES:
    model_vertices = mesh.vertices
    model_colors = mesh.vertex_attributes
else:
    rendered_rgbd = renderer.render_rgbd_from_mesh(mesh.transform(template_pose))
    xyz_rendered = b3d.xyz_from_depth(rendered_rgbd[..., 3], fx, fy, cx, cy)
    xyz_observed = b3d.xyz_from_depth(all_data[T]["rgbd"][..., 3], fx, fy, cx, cy)
    # Keep pixels that satisfy all three conditions:
    # - inside the object mask;
    # - positive observed depth;
    # - observed 3D point within 0.01 (1 cm if units are metres) of the rendered point.
    mask = (
        all_data[T]["masks"][OBJECT_INDEX]
        * (xyz_observed[..., 2] > 0)
        * (jnp.linalg.norm(xyz_rendered - xyz_observed, axis=-1) < 0.01)
    )
    # Express the kept points in the object frame.
    model_vertices = template_pose.inv().apply(xyz_rendered[mask])
    model_colors = all_data[T]["rgbd"][..., :3][mask]

hyperparams["intrinsics"] = {
    "fx": fx,
    "fy": fy,
    "cx": cx,
    "cy": cy,
    "image_height": Pytree.const(image_height),
    "image_width": Pytree.const(image_width),
    "near": 0.01,
    "far": 3.0,
}
hyperparams["vertices"] = model_vertices

# Initial latent state. Every vertex starts with:
# - the last visibility-support value;
# - the first depth-nonreturn-support value.
# The depth and color noise scales start at the first entry of their supports.
num_vertices = model_vertices.shape[0]
previous_state = {
    "pose": template_pose,
    "colors": model_colors,
    "visibility_prob": jnp.ones(num_vertices)
    * hyperparams["visibility_prob_kernel"].support[-1],
    "depth_nonreturn_prob": jnp.ones(num_vertices)
    * hyperparams["depth_nonreturn_prob_kernel"].support[0],
    "depth_scale": hyperparams["depth_scale_kernel"].support[0],
    "color_scale": hyperparams["color_scale_kernel"].support[0],
}

# Constrain the listed latents to previous_state and the observation to frame T.
choicemap = (
    genjax.ChoiceMap.d(
        {
            "pose": previous_state["pose"],
            "color_scale": previous_state["color_scale"],
            "depth_scale": previous_state["depth_scale"],
            "rgbd": all_data[T]["rgbd"],
        }
    )
    ^ make_visibility_prob_choicemap(previous_state["visibility_prob"])
    ^ make_colors_choicemap(previous_state["colors"])
    ^ make_depth_nonreturn_prob_choicemap(previous_state["depth_nonreturn_prob"])
)
key = jax.random.PRNGKey(0)

trace = dynamic_object_generative_model.importance(key, choicemap, (hyperparams, previous_state))[0]
print(trace.get_score())
og_trace = trace  # keep a handle on the initial trace
b3d.chisight.gen3d.model.viz_trace(
    trace,
    T,
    ground_truth_vertices=mesh.vertices,
    ground_truth_pose=template_pose,
)
results = {}  # frame index -> trace, filled by the tracking cells below

-10053.277


In [138]:
# Load the inference code, its settings, and the visualization helpers used by
# the cells below. Hot-reload each one so edits on disk take effect without a
# kernel restart.
import b3d.chisight.gen3d.image_kernel
import b3d.chisight.gen3d.inference_old as inference
import b3d.chisight.gen3d.settings
import b3d.chisight.gen3d.visualization as viz

b3d.reload(b3d.chisight.gen3d.inference_old)
# Reload settings as well; otherwise edits to it are silently ignored, unlike
# every other module reloaded here.
b3d.reload(b3d.chisight.gen3d.settings)
inference_hyperparams = b3d.chisight.gen3d.settings.inference_hyperparams
b3d.reload(b3d.chisight.gen3d.visualization)
b3d.reload(b3d.chisight.gen3d.image_kernel)

In [139]:
# Optional, currently disabled. When uncommented, this cell:
# - runs one vertex-attribute update from the current trace;
# - visualizes the result;
# - saves the resulting choices as `choicemap_good`.
# The next cell reads `choicemap_good`; on a fresh kernel, uncomment this cell to create it.
# trace, _ = inference.update_vertex_attributes(key, trace, inference_hyperparams)
# b3d.chisight.gen3d.model.viz_trace(trace, 1,
#         ground_truth_vertices=meshes[OBJECT_INDEX].vertices,
#         ground_truth_pose=all_data[T]["camera_pose"].inv() @ all_data[T]["object_poses"][OBJECT_INDEX]                                   
# )
# choicemap_good = trace.get_choices()

In [140]:
# Re-impose a previously saved set of choices on the trace, then visualize it.
# `choicemap_good` is only created by the (commented-out) cell above. On a fresh
# kernel it does not exist, so skip the update instead of raising NameError.
if "choicemap_good" in globals():
    trace = trace.update(key, choicemap_good)[0]
else:
    print("choicemap_good is not defined; skipping trace.update (uncomment the previous cell to create it).")
b3d.chisight.gen3d.model.viz_trace(
    trace,
    0,
    ground_truth_vertices=meshes[OBJECT_INDEX].vertices,
    ground_truth_pose=all_data[T]["camera_pose"].inv() @ all_data[T]["object_poses"][OBJECT_INDEX],
)

In [141]:
# Track the object through every loaded frame. For each frame T:
# - advance the trace to frame T;
# - run one inference step;
# - store the result in `results`;
# - visualize it against the ground truth.
for T in tqdm(range(len(all_data))):
    # Derive fresh, deterministic keys per frame. Reusing one key for every
    # random call correlates the randomness across calls and frames.
    advance_key, step_key = jax.random.split(jax.random.fold_in(key, T))
    trace = inference.advance_time(advance_key, trace, all_data[T]["rgbd"])
    trace = inference.inference_step(trace, step_key, inference_hyperparams)[0]
    results[T] = trace

    b3d.chisight.gen3d.model.viz_trace(
        trace,
        T,
        ground_truth_vertices=meshes[OBJECT_INDEX].vertices,
        ground_truth_pose=all_data[T]["camera_pose"].inv() @ all_data[T]["object_poses"][OBJECT_INDEX],
    )

100%|██████████| 49/49 [02:09<00:00,  2.65s/it]


In [84]:
# Peek at the visibility probability of a single vertex in the current trace.
vertex_index = 2426
current_state = trace.get_retval()["new_state"]
current_state["visibility_prob"][vertex_index]

Array(0., dtype=float32)

In [None]:
# Run another tracking pass over all frames, continuing from the current trace.
# This pass draws its keys from a stream separate from the first pass, so its
# randomness does not repeat the first pass's.
second_pass_key = jax.random.split(key)[1]
for T in tqdm(range(len(all_data))):
    # Fresh, deterministic keys per frame; never reuse one key for two random calls.
    advance_key, step_key = jax.random.split(jax.random.fold_in(second_pass_key, T))
    trace = inference.advance_time(advance_key, trace, all_data[T]["rgbd"])
    trace = inference.inference_step(trace, step_key, inference_hyperparams)[0]
    results[T] = trace

    b3d.chisight.gen3d.model.viz_trace(
        trace,
        T,
        ground_truth_vertices=meshes[OBJECT_INDEX].vertices,
        ground_truth_pose=all_data[T]["camera_pose"].inv() @ all_data[T]["object_poses"][OBJECT_INDEX],
    )

In [42]:
# Advance the current trace by a single step to frame T and visualize it.
T = 1
# Separate keys for the two random calls instead of reusing `key` for both.
advance_key, step_key = jax.random.split(jax.random.fold_in(key, T))
trace = inference.advance_time(advance_key, trace, all_data[T]["rgbd"])
trace = inference.inference_step(trace, step_key, inference_hyperparams)[0]
b3d.chisight.gen3d.model.viz_trace(
    trace,
    T,
    ground_truth_vertices=meshes[OBJECT_INDEX].vertices,
    ground_truth_pose=all_data[T]["camera_pose"].inv() @ all_data[T]["object_poses"][OBJECT_INDEX],
)

In [72]:
# Interactively inspect the attribute proposal for one vertex. This compares the
# latent and observed RGB-D values associated with the vertex, given the
# vertex's attributes in the trace's previous state.
b3d.reload(b3d.chisight.gen3d.visualization)
import b3d.chisight.gen3d.visualization as viz

# Use the trace's own arguments for both the correspondences and the
# visualization. The global `hyperparams` / `previous_state` may not be what
# this trace was built from, and must not be overwritten here.
trace_hyperparams = trace.get_args()[0]
trace_previous_state = trace.get_args()[1]

latent_rgbd_per_point, observed_rgbd_per_point = (
    b3d.chisight.gen3d.image_kernel.get_latent_and_observed_correspondences(
        trace.get_retval()["new_state"], trace_hyperparams, trace.get_choices()["rgbd"]
    )
)

vertex_index = 602
latent_rgbd_for_point = latent_rgbd_per_point[vertex_index]
observed_rgbd_for_point = observed_rgbd_per_point[vertex_index]
print(latent_rgbd_for_point, observed_rgbd_for_point)

previous_color = trace_previous_state["colors"][vertex_index]
previous_visibility_prob = trace_previous_state["visibility_prob"][vertex_index]
previous_dnrp = trace_previous_state["depth_nonreturn_prob"][vertex_index]
attribute_proposal_function = inference.attribute_proposal_only_color_and_visibility
viz.create_interactive_visualization(
    observed_rgbd_for_point,
    latent_rgbd_for_point,
    trace_hyperparams,
    inference_hyperparams,
    previous_color,
    previous_visibility_prob,
    previous_dnrp,
    attribute_proposal_function,
)

[0.54117644 0.37647057 0.2352941  0.9140787 ] [0.54117644 0.37647057 0.2352941  0.851     ]


interactive(children=(FloatSlider(value=0.009999999776482582, continuous_update=False, description='Color Scal…