In [51]:
import numpy as np
import jax.numpy as jnp
import jax
import bayes3d as b
import time
from PIL import Image
from scipy.spatial.transform import Rotation as R
import matplotlib.pyplot as plt
import cv2
import trimesh
import os

# Can be helpful for debugging:
# jax.config.update('jax_enable_checks', True) 



In [52]:
# Start the bayes3d visualizer; the call prints the local URL of the viewer (see the cell output).
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7014/static/


In [94]:
# Camera intrinsics defined at 100x100, then scaled by 0.5 (the renders below are 50x50)
# before the global renderer is initialised with them.
original_intrinsics = b.Intrinsics(
    height=100, width=100,
    fx=80.0, fy=80.0,
    cx=50.0, cy=50.0,
    near=0.001, far=6.0,
)
intrinsics = b.scale_camera_parameters(original_intrinsics, 0.5)
b.setup_renderer(intrinsics)

# Load YCB-V model #15; model files are named obj_<6-digit zero-padded id>.ply.
model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models")
idx = 15
mesh_path = os.path.join(model_dir, f"obj_{idx:06d}.ply")
b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 100.0)


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)


[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


In [95]:
# Inspect the bounding-box dimensions the renderer reports for the mesh added above
# (after the 1/100 scaling). The output has one row for the single loaded mesh.
b.RENDERER.model_box_dims

Array([[1.84241, 1.87434, 0.57317]], dtype=float32)

In [96]:
# Base camera pose; per the helper's name the arguments are (position, target, up):
# the camera sits at (0, 1.5, 1.0), looks at the origin, with +z as up.
camera_pose = b.t3d.transform_from_pos_target_up(
    jnp.array([0.0, 1.5, 1.0]),
    jnp.array([0.0, 0.0, 0.0]),
    jnp.array([0.0, 0.0, 1.0]),
)

# Orbit the camera one full turn about the world z-axis (left-multiplying applies the
# rotation in the world frame).
# endpoint=False: with the endpoint included, angles 0 and 2*pi yield the same pose, so the
# first and last frames were duplicates (a visible stall when the sequence is looped).
num_frames = 120
camera_poses = jnp.array([
    b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle) @ camera_pose
    for angle in jnp.linspace(0, 2 * jnp.pi, num_frames, endpoint=False)
])


In [97]:
# Invert each camera pose; presumably this yields the object's pose in each camera frame,
# which is what the renderer consumes.
poses = jnp.linalg.inv(camera_poses)

# Render mesh 0 once per pose. The inserted axis of size 1 is the per-scene object axis.
observed_images = b.RENDERER.render_many(jnp.expand_dims(poses, axis=1), jnp.array([0]))
print("observed_images.shape", observed_images.shape)


observed_images.shape (120, 50, 50, 4)


In [98]:
# Turn channel 2 of every observed frame (presumably the z/depth coordinate) into a
# viewable depth image, then write the sequence out as a GIF.
input_viz_images = list(
    map(lambda frame: b.get_depth_image(frame[:, :, 2]), observed_images)
)
b.make_gif_from_pil_images(input_viz_images, "input.gif")

In [99]:
# Translation proposals: presumably an 11 x 11 x 11 grid of offsets spanning [-0.2, 0.2]
# on each axis (arguments read as min xyz, max xyz, counts xyz) — confirm against bayes3d.
translation_deltas = b.utils.make_translation_grid_enumeration(
    -0.2, -0.2, -0.2,
    0.2, 0.2, 0.2,
    11, 11, 11,
)

# Rotation proposals: 500 samples from a zero-mean Gaussian/vMF distribution (per the
# helper's name), one per split of a fixed PRNG key so the proposal set is reproducible.
rotation_deltas = jax.vmap(
    lambda subkey: b.distributions.gaussian_vmf_zero_mean(subkey, 0.00001, 800.0)
)(jax.random.split(jax.random.PRNGKey(3), 500))

# Batched likelihood: maps over argument 1 (the stack of rendered images) only; the
# observed image and the five trailing hyperparameters are shared across the batch.
likelihood = jax.vmap(b.threedp3_likelihood_old, in_axes=(None, 0) + (None,) * 5)

# Hyperparameters forwarded positionally to b.threedp3_likelihood_old after the
# (observed, rendered) pair. NOTE(review): the meaning of each entry is defined by
# bayes3d, not in this notebook — confirm before tuning.
LIKELIHOOD_PARAMS = (0.05, 0.1, 10**3, 0.1, 3)


def _best_proposal(pose_estimate, deltas, gt_image):
    """Return the proposal ``pose_estimate @ delta`` that best explains ``gt_image``.

    Each delta is right-multiplied onto the current pose, i.e. applied in the object's
    local frame. Every proposal is rendered, and the one with the highest likelihood
    is returned.
    """
    proposals = jnp.einsum("ij,ajk->aik", pose_estimate, deltas)
    rendered_images = jax.vmap(b.RENDERER.render, in_axes=(0, None))(
        proposals[:, None, ...], jnp.array([0])
    )
    scores = likelihood(gt_image, rendered_images, *LIKELIHOOD_PARAMS)
    return proposals[jnp.argmax(scores)]


def update_pose_estimate(pose_estimate, gt_image):
    """One tracking step: refine the pose against a single observed frame.

    First searches the translation proposals, then the rotation proposals, each time
    keeping the arg-max likelihood candidate.

    Returns ``(carry, output)`` — both the refined pose — in the shape jax.lax.scan
    expects, so the pose is threaded to the next frame and also recorded per frame.
    """
    pose_estimate = _best_proposal(pose_estimate, translation_deltas, gt_image)
    pose_estimate = _best_proposal(pose_estimate, rotation_deltas, gt_image)
    return pose_estimate, pose_estimate

# Scan update_pose_estimate over the observed frames, threading the pose through as
# the carry and returning the stacked per-frame estimates (the scan's second output).
inference_program = jax.jit(lambda p, x: jax.lax.scan(update_pose_estimate, p, x)[1])

# Warm-up call: triggers JIT compilation so the timed run below measures execution only.
# JAX dispatches asynchronously, so block_until_ready() is needed to ensure the warm-up
# has actually finished before the timer starts. The pose is seeded with the
# ground-truth pose of the first frame.
inferred_poses = inference_program(poses[0], observed_images).block_until_ready()

# Without block_until_ready() the timer could stop while the computation is still running.
start = time.perf_counter()
pose_estimates_over_time = inference_program(poses[0], observed_images).block_until_ready()
end = time.perf_counter()
elapsed = end - start
print("Time elapsed:", elapsed)
print("FPS:", observed_images.shape[0] / elapsed)


max_depth = 10.0  # NOTE(review): not read in this cell — confirm whether a later cell uses it, else drop.

# Re-render the tracked poses so they can be compared frame-by-frame with the observations.
rerendered_images = b.RENDERER.render_many(
    jnp.expand_dims(pose_estimates_over_time, axis=1), jnp.array([0])
)


def _depth_panel(frame):
    """Depth visualisation of channel 2 of ``frame``, passed through scale_image(…, 5.0)."""
    return b.viz.scale_image(b.viz.get_depth_image(frame[:, :, 2]), 5.0)


viz_images = []
for inferred_frame, observed_frame in zip(rerendered_images, observed_images):
    inferred_panel = _depth_panel(inferred_frame)
    observed_panel = _depth_panel(observed_frame)
    viz_images.append(
        b.viz.multi_panel(
            [observed_panel, inferred_panel, b.viz.overlay_image(inferred_panel, observed_panel)],
            ["Ground Truth", "Inferred Reconstruction", "Overlay"],
        )
    )

b.make_gif_from_pil_images(viz_images, "demo.gif")

Time elapsed: 1.5517244338989258
FPS: 77.33331858317338
