In [1]:
import os
import time

import jax
import jax.numpy as jnp
from IPython import embed
from scipy.spatial.transform import Rotation as R

import bayes3d as b

# Can be helpful for debugging:
# jax.config.update('jax_enable_checks', True)

# Directory holding the demo assets (must contain sample_objs/bunny.obj).
# os.getenv returns None when the variable is unset, which would otherwise
# surface later as an opaque TypeError inside os.path.join, so fail early
# with an actionable message.
assets_dir = os.getenv("B3D_ASSETS_PATH")
if not assets_dir:
    raise RuntimeError(
        "Environment variable B3D_ASSETS_PATH is not set; "
        "point it at the directory that contains sample_objs/bunny.obj."
    )

# Camera model for a 100x100 image; these intrinsics are handed to the renderer.
intrinsics = b.Intrinsics(
    height=100, width=100, fx=50.0, fy=50.0, cx=50.0, cy=50.0, near=0.001, far=6.0
)

# Initialize the renderer and register the single mesh used by this demo.
# It is the first mesh added, and the rest of the cell refers to it as
# object index 0 — presumably indices follow insertion order; verify.
b.setup_renderer(intrinsics)
b.RENDERER.add_mesh_from_file(os.path.join(assets_dir, "sample_objs/bunny.obj"))

num_frames = 60

# Constant per-frame motion: a small rotation (Euler "zyx", in degrees)
# combined with a small translation.
delta_pose = b.t3d.transform_from_rot_and_pos(
    R.from_euler("zyx", [-1.0, 0.1, 2.0], degrees=True).as_matrix(),
    jnp.array([0.09, 0.05, 0.02]),
)

# Ground-truth trajectory: start from a pure translation and right-multiply
# the same delta once per subsequent frame.
poses = [b.t3d.transform_from_pos(jnp.array([-3.0, 0.0, 3.5]))]
for t in range(1, num_frames):
    poses.append(poses[t - 1].dot(delta_pose))
poses = jnp.stack(poses)
print("Number of frames: ", poses.shape[0])

# Render one observation per frame; the extra axis inserted at position 1 and
# the index array [0] select the single registered mesh.
# NOTE(review): axis 1 presumably enumerates objects in the scene — confirm
# against render_many.
observed_images = b.RENDERER.render_many(poses[:, None, ...], jnp.array([0]))
print("observed_images.shape", observed_images.shape)

# Candidate translation perturbations enumerated on a grid.
# NOTE(review): the arguments look like (min xyz, max xyz, points per axis),
# i.e. a 5x5x5 grid spanning [-0.2, 0.2] on each axis — confirm against
# make_translation_grid_enumeration.
translation_deltas = b.utils.make_translation_grid_enumeration(
    -0.2, -0.2, -0.2, 0.2, 0.2, 0.2, 5, 5, 5
)

# 100 random rotation perturbations, one per PRNG key derived from a fixed
# seed so the candidate set is reproducible across runs.
rotation_keys = jax.random.split(jax.random.PRNGKey(30), 100)
rotation_deltas = jax.vmap(
    lambda rng_key: b.distributions.gaussian_vmf_zero_mean(rng_key, 0.00001, 800.0)
)(rotation_keys)

# Likelihood batched over its second argument only (the rendered candidates);
# the observed image and the five trailing arguments are shared by every
# candidate.
likelihood = jax.vmap(
    b.threedp3_likelihood_old,
    in_axes=(None, 0) + (None,) * 5,
)


def update_pose_estimate(pose_estimate, gt_image):
    """Refine the pose estimate against one observed frame.

    Two searches run back to back around the current estimate: first over
    ``translation_deltas``, then over ``rotation_deltas``. In each search every
    delta is right-multiplied onto the current estimate, each candidate pose is
    rendered, the renders are scored against ``gt_image`` with ``likelihood``,
    and the best-scoring candidate becomes the new estimate.

    Args:
        pose_estimate: Current pose, a 2-D transform matrix (it is contracted
            as ``ij`` in the einsum below).
        gt_image: Observed image for the frame being processed.

    Returns:
        A ``(pose, pose)`` pair holding the refined estimate twice, matching
        the ``(carry, output)`` shape a ``jax.lax.scan`` step function returns.
    """
    object_ids = jnp.array([0])
    render_candidates = jax.vmap(b.RENDERER.render, in_axes=(0, None))

    # Translation first, then rotation; this Python loop is unrolled at trace
    # time, so each stage starts from the winner of the previous one.
    for deltas in (translation_deltas, rotation_deltas):
        candidates = jnp.einsum("ij,ajk->aik", pose_estimate, deltas)
        renders = render_candidates(candidates[:, None, ...], object_ids)
        # NOTE(review): the trailing constants are hyperparameters of
        # b.threedp3_likelihood_old; their meaning is not visible here —
        # confirm before tuning.
        scores = likelihood(gt_image, renders, 0.05, 0.1, 10**3, 0.1, 3)
        pose_estimate = candidates[jnp.argmax(scores)]

    return pose_estimate, pose_estimate


# Jitted tracker: scan update_pose_estimate over the frames, carrying the pose
# estimate forward, and keep only the stacked per-frame outputs (index 1 of
# the scan result).
inference_program = jax.jit(lambda p, x: jax.lax.scan(update_pose_estimate, p, x)[1])

# Warm-up call so compilation is excluded from the timing below. JAX dispatches
# work asynchronously, so block until the result is ready; otherwise the timed
# run would queue behind the still-running warm-up.
inferred_poses = inference_program(poses[0], observed_images)
inferred_poses.block_until_ready()

# Timed run. Without block_until_ready() the clock would stop as soon as the
# computation is dispatched, not when it finishes, inflating the reported FPS.
start = time.perf_counter()
pose_estimates_over_time = inference_program(poses[0], observed_images)
pose_estimates_over_time.block_until_ready()
end = time.perf_counter()
print("Time elapsed:", end - start)
print("FPS:", poses.shape[0] / (end - start))

# Re-render the scene at the inferred poses so it can be compared with the
# observations frame by frame.
rerendered_images = b.RENDERER.render_many(
    pose_estimates_over_time[:, None, ...], jnp.array([0])
)

# One side-by-side panel per frame: observation on the left, re-render on the
# right. Channel index 2 of each image is what gets passed to get_depth_image
# (presumably the z/depth channel — confirm against the renderer output).
viz_images = [
    b.viz.multi_panel(
        [
            b.viz.scale_image(b.viz.get_depth_image(observed[:, :, 2]), 3),
            b.viz.scale_image(b.viz.get_depth_image(rerendered[:, :, 2]), 3),
        ],
        labels=["Observed", "Rerendered"],
        label_fontsize=20,
    )
    for rerendered, observed in zip(rerendered_images, observed_images)
]

# The output directory may not exist relative to the current working
# directory (e.g. on a fresh checkout), so create it before writing the GIF.
gif_path = "assets/demo.gif"
os.makedirs(os.path.dirname(gif_path), exist_ok=True)
b.make_gif_from_pil_images(viz_images, gif_path)


[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (128, 128, 1024)


RuntimeError: Cuda error: 801[cudaGraphicsGLRegisterImage(&s.cudaColorBuffer[i], s.glColorBuffer[i], GL_TEXTURE_3D, cudaGraphicsRegisterFlagsReadOnly);]
Exception raised from _setup at bayes3d/rendering/nvdiffrast/common/rasterize_gl.cpp:398 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xba (0x7fff210b9a9a in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xd2 (0x7fff210611dc in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #2: _setup(CUstream_st*, RasterizeGLStateWrapper&, int, int, int) + 0xb1e (0x7ffdf6395dde in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/bayes3d/rendering/nvdiffrast/nvdiffrast_plugin_gl.cpython-311-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x43c0a0f (0x7fff6b5c0a0f in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #4: <unknown function> + 0x43c1215 (0x7fff6b5c1215 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #5: <unknown function> + 0x505b843 (0x7fff6c25b843 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #6: <unknown function> + 0x5058f39 (0x7fff6c258f39 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #7: <unknown function> + 0x505763d (0x7fff6c25763d in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #8: <unknown function> + 0x5056faf (0x7fff6c256faf in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #9: <unknown function> + 0x7414ab6 (0x7fff6e614ab6 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #10: <unknown function> + 0x11147e5 (0x7fff683147e5 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #11: <unknown function> + 0x1115138 (0x7fff68315138 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #12: <unknown function> + 0x107d238 (0x7fff6827d238 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #13: <unknown function> + 0x107f9a7 (0x7fff6827f9a7 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #14: <unknown function> + 0x1081eb0 (0x7fff68281eb0 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #15: <unknown function> + 0x1343c0d (0x7fff68543c0d in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #16: <unknown function> + 0x1019218 (0x7fff68219218 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #17: <unknown function> + 0x101a78c (0x7fff6821a78c in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #18: <unknown function> + 0x64c1c5 (0x7fff6784c1c5 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #19: <unknown function> + 0x64bffd (0x7fff6784bffd in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #20: <unknown function> + 0x138660c (0x7fff6858660c in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
<omitting python frames>
frame #35: <unknown function> + 0x6a8fee (0x7fff678a8fee in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #36: <unknown function> + 0x6a75a3 (0x7fff678a75a3 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #51: <unknown function> + 0x6a8fee (0x7fff678a8fee in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #52: <unknown function> + 0x6a75a3 (0x7fff678a75a3 in /nix/store/v3j9fv95hvmmj6bp7x9sq4vcx8z7fphk-python3-3.11.9-env/lib/python3.11/site-packages/jaxlib/xla_extension.so)
