In [1]:
import b3d
import jax.numpy as jnp
import os
from b3d import Mesh, Pose
import jax
import genjax
from genjax import Pytree
import rerun as rr
from b3d.modeling_utils import uniform_discrete, uniform_pose, gaussian_vmf


In [2]:
b3d.rr_init("interactive")

ycb_dir = os.path.join(b3d.get_assets_path(), "bop/ycbv")
scene_id = 49
image_id = 100

all_data = b3d.io.get_ycbv_test_images(ycb_dir, scene_id, [image_id])

meshes = [
    Mesh.from_obj_file(os.path.join(ycb_dir, f'models/obj_{f"{id + 1}".rjust(6, "0")}.ply')).scale(0.001)
    for id in all_data[0]["object_types"]
]

height, width = all_data[0]["rgbd"].shape[:2]
fx,fy,cx,cy = all_data[0]["camera_intrinsics"]
scaling_factor = 0.3
renderer = b3d.renderer.renderer_original.RendererOriginal(
    width * scaling_factor, height * scaling_factor, fx * scaling_factor, fy * scaling_factor, cx * scaling_factor, cy * scaling_factor, 0.01, 2.0
)

100%|██████████| 1/1 [00:03<00:00,  3.01s/it]


In [3]:
@jax.jit
def surface_krays_likelihood_intermediate(observed_rgbd, rendered_rgbd, likelihood_args):
    fx = likelihood_args["fx"]
    fy = likelihood_args["fy"]

    observed_rgb = observed_rgbd[..., :3]
    observed_depth = observed_rgbd[...,3]

    rendered_rgb = rendered_rgbd[..., :3]
    rendered_depth = rendered_rgbd[...,3]

    observed_lab = b3d.colors.rgb_to_lab(observed_rgb)
    rendered_lab = b3d.colors.rgb_to_lab(rendered_rgb)

    color_variance = likelihood_args["color_variance"]
    depth_variance = likelihood_args["depth_variance"]
    outlier_prob = likelihood_args["outlier_prob"]
    outlier_volume = likelihood_args["outlier_volume"]

    multiplier = likelihood_args["multiplier"]

    observed_depth_corrected = observed_depth + (observed_depth == 0.0) * renderer.far
    rendered_areas = (rendered_depth / fx) * (rendered_depth / fy)
    observed_areas = (observed_depth_corrected / fx) * (observed_depth_corrected / fy)

    color_pdf = jax.scipy.stats.truncnorm.pdf(
        (observed_lab - rendered_lab) / jnp.array([100.0, 120.0, 120.0]),
        a=jnp.ones(3) * -1.0,
        b=jnp.ones(3) * 1.0,
        loc=jnp.zeros(3),
        scale=color_variance
    ).prod(-1)

    depth_width = 0.1
    depth_pdf = jax.scipy.stats.truncnorm.pdf(
        (observed_depth - rendered_depth) / 1.0,
        a=-depth_width,
        b=depth_width,
        loc=0.0,
        scale=depth_variance
    )

    inliers_integral = color_pdf * depth_pdf * rendered_areas * (1.0 - outlier_prob)
    outlier_integral = observed_areas * 1 / outlier_volume * outlier_prob 
    pixelwise_score = (inliers_integral + outlier_integral)

    return {
        "score": jnp.log(jnp.sum(pixelwise_score))  * multiplier,
        "pixelwise_score": pixelwise_score,
        "rendered_rgbd": rendered_rgbd
    }

import b3d.chisight.dense.likelihoods.image_likelihood
from b3d.chisight.dense.likelihoods.simple_likelihood import simple_likelihood

intermediate_func = surface_krays_likelihood_intermediate
image_likelihood = b3d.chisight.dense.likelihoods.image_likelihood.make_image_likelihood(
    intermediate_func,
)

@genjax.gen
def dense_multiobject_model(num_objects, meshes, likelihood_args):
    all_poses = []
    for i in range(num_objects.const):
        object_pose = uniform_pose(jnp.ones(3)*-100.0, jnp.ones(3)*100.0) @ f"object_pose_{i}"
        all_poses.append(object_pose)

    all_poses = Pose.stack_poses(all_poses)
    scene_mesh = Mesh.transform_and_merge_meshes(meshes, all_poses)

    rendered_rgbd = renderer.render_rgbd_from_mesh(scene_mesh)
    image = image_likelihood(rendered_rgbd, likelihood_args) @ "image"
    return {"scene_mesh": scene_mesh, "image": image}

importance_jit = jax.jit(dense_multiobject_model.importance)

In [4]:

color_variance=0.01
depth_variance=0.02
outlier_prob=0.0001
outlier_volume=5.0
multiplier=65.0
likelihood_args= {
    "color_variance": jnp.ones(3) * color_variance,
    "depth_variance": depth_variance,
    "outlier_prob": outlier_prob,
    "outlier_volume": outlier_volume,
    "multiplier": multiplier,
    "fx": fx,
    "fy": fy
}
initial_camera_pose = all_data[0]["camera_pose"]
initial_object_poses = all_data[0]["object_poses"]
IDX = 0
pose = initial_camera_pose.inv() @ initial_object_poses[IDX]
observed_image_resized = b3d.utils.resize_image(
    all_data[0]["rgbd"], renderer.height, renderer.width
)
choicemap = genjax.ChoiceMap.d(
    dict(
        [
            ("object_pose_0",  pose),
            ("image", observed_image_resized
            )
        ]
    )
)


b3d.rr_init("interactive")

num_samples = 10
for t in range(num_samples):
    rr.log(
        f"img_{t}",
        rr.Image(observed_image_resized[..., :3]),
        timeless=True
    )

trace, _ = importance_jit(
    jax.random.PRNGKey(2),
    choicemap,
    (Pytree.const(1), [meshes[IDX]], likelihood_args),
)

In [5]:
key = jax.random.PRNGKey(1)

In [10]:
key = jax.random.split(key, 2)[-1]
w = 0.05
batch_length = 700
test_poses = Pose.concatenate_poses(
    [
        pose @ Pose.sample_uniform_pose_vmap(
            jax.random.split(key, 20000),
            -w/2*jnp.ones(3),
            w/2*jnp.ones(3)
        ),
        # pose[None,...]
    ]
)
split_poses = [test_poses[i] for i in jnp.array_split(jnp.arange(test_poses.shape[0]), test_poses.shape[0] // batch_length + 1)]
print(len(split_poses))

29


In [7]:
@jax.jit
def fine_grain_inference(
    color_variance,
    depth_variance,
    outlier_prob,
    outlier_volume,
    multiplier,
    split_poses
):
    key = jax.random.split(jax.random.PRNGKey(0), 2)[-1]
    likelihood_args= {
        "color_variance": jnp.ones(3) * color_variance,
        "depth_variance": depth_variance,
        "outlier_prob": outlier_prob,
        "outlier_volume": outlier_volume,
        "multiplier": multiplier,
        "fx": fx,
        "fy": fy
    }
    trace, _ = dense_multiobject_model.importance(
        jax.random.PRNGKey(2),
        choicemap,
        (Pytree.const(1), [meshes[IDX]], likelihood_args),
    )

    scores = jnp.concatenate([
        b3d.enumerate_choices_get_scores(
            trace, key, Pytree.const(("object_pose_0",)), p
        )
        for p in split_poses
    ])

    sampled_indices = jax.random.categorical(key, scores, shape=(num_samples,))
    return sampled_indices, scores, trace, key

In [8]:
def visualize_posterior_samples(
    color_variance,
    depth_variance,
    outlier_prob,
    outlier_volume,
    multiplier,
    split_poses,
):
    sampled_indices, scores, trace, key = fine_grain_inference(
        color_variance,
        depth_variance,
        outlier_prob,
        outlier_volume,
        multiplier,
        split_poses,
    )
    print(sampled_indices)
    print(scores[sampled_indices])

    rendered_rgbds = jax.vmap(
        lambda i: renderer.render_rgbd_from_mesh(
            meshes[IDX].transform(test_poses[i])
        )
    )(sampled_indices)

    for t in range(len(sampled_indices)):
        rr.log(
            f"img_{t}/rerender",
            rr.Image(rendered_rgbds[t][...,:3])
        )    
    
    # for t in range(len(sampled_indices)):
    #     sampled_trace = b3d.update_choices_jit(trace, key,  Pytree.const(("object_pose_0",)), test_poses[sampled_indices[t]])
    #     intermediate_info = surface_krays_likelihood_intermediate(
    #         sampled_trace.get_choices()["image"], 
    #         sampled_trace.get_retval()["scene_mesh"],
    #         renderer,
    #         sampled_trace.get_args()[2]
    #     )
    #     rr.log(
    #         f"img_{t}/rerender",
    #         rr.Image(intermediate_info["rendered_rgbd"][...,:3])
    #     )
    
visualize_posterior_samples(
    color_variance=0.01,
    depth_variance=0.02,
    outlier_prob=0.0001,
    outlier_volume=5.0,
    multiplier=65.0,
    split_poses=split_poses
)

[5420 5420 5420 5420 5420 5420 5420 5420 5420 5420]
[188.60237 188.60237 188.60237 188.60237 188.60237 188.60237 188.60237
 188.60237 188.60237 188.60237]


In [11]:
from ipywidgets import interact, interactive, fixed, interact_manual
from ipywidgets import FloatSlider
rr.set_time_sequence("time", 0)
cu = False
interact(
    visualize_posterior_samples,
    color_variance=FloatSlider(value=0.13,min=0.001, max=0.5, step=0.001, continuous_update=cu),
    depth_variance=FloatSlider(value=0.06,min=0.001, max=0.5, step=0.001, continuous_update=cu),
    outlier_prob=FloatSlider(value=0.001, min=0.00001, max=1.0, step=0.05, continuous_update=cu),
    outlier_volume=FloatSlider(value=1.0,min=1.0, max=10.0, step=1.0, continuous_update=cu),
    multiplier=FloatSlider(value=100.0,min=1.0, max=200.0, step=1.0, continuous_update=cu),
    split_poses=fixed(split_poses)
)

interactive(children=(FloatSlider(value=0.13, continuous_update=False, description='color_variance', max=0.5, …

<function __main__.visualize_posterior_samples(color_variance, depth_variance, outlier_prob, outlier_volume, multiplier, split_poses)>

In [33]:
scores, scores.max()

(Array([-6.003183 , -5.982863 , -5.952487 , ..., -5.967316 , -5.9038825,
        -5.699437 ], dtype=float32),
 Array(-5.699437, dtype=float32))

In [102]:
print(sampled_indices)
pose_samples = test_poses[sampled_indices]
images = renderer.render_rgbd_from_mesh(
    meshes[IDX].transform(pose_samples[0])
)
images = jax.vmap(
    lambda i: renderer.render_rgbd_from_mesh(
        meshes[IDX].transform(pose_samples[i])
    )
)(jnp.arange(len(pose_samples)))



for t in range(len(images)):
    rr.set_time_sequence("time", t)
    rr.log(
        "image/rerender",
        rr.Image(images[t][..., :3]),
    )


[14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095
 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095
 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095
 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095
 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095
 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095
 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095
 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095 14095
 14095 14095 14095 14095]


In [104]:
scores.max()

Array(2492.451, dtype=float32)

In [103]:
scores

Array([-6.0253487, -6.025215 , 86.84036  , ..., -3.1231146,  3.211216 ,
       99.93143  ], dtype=float32)

Array([2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138,
       2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138,
       2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138,
       2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138,
       2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138,
       2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138,
       2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138,
       2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138,
       2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138, 2138,
       2138], dtype=int32)