In [1]:
import bayes3d as b
import jax.numpy as jnp
import jax

In [2]:
# Start the bayes3d visualizer; the URL to open it is printed as this cell's output.
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7020/static/


In [3]:

# Camera intrinsics: square 100x100 image, fx = fy = 50, principal point
# (cx, cy) at the image center, near/far clipping distances 0.001 and 16.0.
IMAGE_SIZE = 100
intrinsics = b.Intrinsics(
    height=IMAGE_SIZE,
    width=IMAGE_SIZE,
    fx=50.0,
    fy=50.0,
    cx=IMAGE_SIZE / 2,
    cy=IMAGE_SIZE / 2,
    near=0.001,
    far=16.0,
)

import trimesh
def as_mesh(scene_or_mesh):
    """
    Convert a possible scene to a mesh.

    A ``trimesh.Scene`` is flattened by concatenating all of its geometries
    into one ``trimesh.Trimesh``. In that case the returned mesh has only
    vertex and face data (textures and other visuals are dropped). An empty
    scene yields ``None``. A ``trimesh.Trimesh`` is returned unchanged.

    Args:
        scene_or_mesh: a ``trimesh.Scene`` or ``trimesh.Trimesh``, such as
            the return value of ``trimesh.load``.

    Returns:
        trimesh.Trimesh or None: the (possibly concatenated) mesh, or
        ``None`` if the scene contains no geometry.

    Raises:
        TypeError: if ``scene_or_mesh`` is neither a Scene nor a Trimesh.
    """
    if isinstance(scene_or_mesh, trimesh.Scene):
        if len(scene_or_mesh.geometry) == 0:
            return None  # empty scene
        # Rebuild each geometry from vertices/faces only; this is where
        # texture information is lost.
        return trimesh.util.concatenate([
            trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
            for g in scene_or_mesh.geometry.values()
        ])
    # Validate the argument itself: checking the not-yet-assigned result
    # variable here would raise UnboundLocalError for every Trimesh input.
    if not isinstance(scene_or_mesh, trimesh.Trimesh):
        raise TypeError(
            "expected trimesh.Scene or trimesh.Trimesh, got "
            f"{type(scene_or_mesh).__name__}"
        )
    return scene_or_mesh
# Load the room model and collapse it to a single Trimesh (trimesh.load may
# return a multi-geometry Scene for OBJ files).
mesh = as_mesh(trimesh.load('InteriorTest.obj'))
if mesh is None:
    raise ValueError("InteriorTest.obj contains no geometry")
# Flip the y axis. Multiplying the float64 numpy vertices by a plain list keeps
# the arithmetic in numpy float64; multiplying by a jnp array hands the op to
# jax, which (in its default 32-bit mode) truncates the vertices to float32.
# NOTE(review): a reflection like this reverses triangle winding order (face
# normals flip) — confirm the renderer does not depend on winding/culling.
mesh.vertices = mesh.vertices * [1.0, -1.0, 1.0]
b.setup_renderer(intrinsics)
b.RENDERER.add_mesh(mesh)


[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (128, 128, 1024)
Centering mesh with translation [-0.09099948 -1.499594   -0.01582694]


In [4]:
# Camera trajectory: 31 poses starting at the identity. Each pose is the
# previous one right-multiplied by a fixed relative transform built from a
# -pi/100 rotation about the y axis and position [0.1, 0, 0].
step_rotation = b.rotation_from_axis_angle(jnp.array([0.0, 1.0, 0.0]), -jnp.pi / 100.0)
transform = b.transform_from_rot_and_pos(step_rotation, jnp.array([0.1, 0.0, 0.0]))
sequence = [jnp.eye(4)]
for _ in range(30):
    sequence.append(sequence[-1] @ transform)
camera_poses = jnp.stack(sequence)

# Render mesh 0 from every camera. The [:, None, ...] inserts a singleton axis
# (presumably the per-scene object axis render_many expects — confirm), and
# only the first 3 channels of each rendered image are kept.
inverse_camera_poses = b.inverse_pose(camera_poses)
images = b.RENDERER.render_many(inverse_camera_poses[:, None, ...], jnp.array([0]))[..., :3]

# Channel 2 is visualized as depth (presumably the z coordinate — confirm).
viz_images = [b.get_depth_image(frame[..., 2], max=10.0) for frame in images]
b.make_gif_from_pil_images(viz_images, 'room.gif')

b.clear()
for idx, pose in enumerate(sequence):
    b.show_pose(f"{idx}", pose)

In [5]:
# Candidate pose perturbations consumed by grid_and_select (left-multiplied
# onto the current pose there, so each set is a stack of 4x4 matrices).
# Grid arguments presumably are (min_x, min_y, min_z, max_x, max_y, max_z,
# n_x, n_y, n_z) — confirm against bayes3d.
translation_deltas = b.utils.make_translation_grid_enumeration(
    -0.2, -0.2, -0.2,
    0.2, 0.2, 0.2,
    25, 25, 25,
)


def _sample_rotation_delta(key):
    # One zero-mean Gaussian/vMF pose sample; the meaning of 0.0001 and 800.0
    # is defined by b.distributions.gaussian_vmf_zero_mean — confirm there.
    return b.distributions.gaussian_vmf_zero_mean(key, 0.0001, 800.0)


# 100 samples from a fixed seed, so rotation_deltas are reproducible.
rotation_keys = jax.random.split(jax.random.PRNGKey(30), 100)
rotation_deltas = jax.vmap(_sample_rotation_delta)(rotation_keys)

In [6]:
def grid_and_select(current_pose, diffs, obs_img, variance):
    """Score a batch of perturbed poses against an observation and keep the best.

    Candidate ``a`` is ``diffs[a] @ current_pose``. Every candidate is
    inverted, rendered (mesh 0) and scored against ``obs_img`` with
    ``b.threedp3_likelihood(obs_img, rendered, variance, 0.0)``.

    Args:
        current_pose: 4x4 pose to perturb.
        diffs: stack of 4x4 perturbations, shape (A, 4, 4).
        obs_img: observed image, compared against each rendering's first
            3 channels.
        variance: passed through to ``b.threedp3_likelihood``.

    Returns:
        (best_pose, best_score): the highest-scoring candidate pose and the
        maximum score.
    """
    candidate_poses = jnp.einsum("aij,jk->aik", diffs, current_pose)
    # Singleton axis per candidate, presumably the object axis of render_many.
    render_poses = b.inverse_pose(candidate_poses)[:, None, ...]
    rendered = b.RENDERER.render_many(render_poses, jnp.array([0]))[..., :3]
    # Batch only over the rendered images; observation and parameters are shared.
    # The fixed 0.0 is presumably the outlier probability — confirm.
    batched_likelihood = jax.vmap(b.threedp3_likelihood, in_axes=(None, 0, None, None))
    scores = batched_likelihood(obs_img, rendered, variance, 0.0)
    best_pose = candidate_poses[jnp.argmax(scores)]
    return best_pose, scores.max()


grid_and_select_jit = jax.jit(grid_and_select)

In [33]:
# Reset the estimate to the first camera pose and show it beside the
# ground-truth camera pose for frame T.
T = 1
current_pose = camera_poses[0]
b.clear()
for label, pose in (("current_pose", current_pose), ("ground_truth", camera_poses[T])):
    b.show_pose(label, pose)

In [55]:
# Narrower +/-0.1 translation grid for a finer search pass; overwrites the
# +/-0.2 translation_deltas from In [5].
# NOTE(review): the next cell reassigns translation_deltas again (+/-0.05), and
# the execution counts (In [53] before In [55]) show these ran out of order:
# the recorded In [56] presumably used this grid, while a top-to-bottom re-run
# ends with the +/-0.05 grid.
translation_deltas = b.utils.make_translation_grid_enumeration(-0.1, -0.1, -0.1, 0.1, 0.1, 0.1, 25, 25, 25)

In [53]:
# Finest +/-0.05 translation grid; overwrites any earlier translation_deltas.
# NOTE(review): this cell (In [53]) was executed before the +/-0.1 cell above
# (In [55]), so the saved session and a Restart & Run All use different grids.
translation_deltas = b.utils.make_translation_grid_enumeration(-0.05, -0.05, -0.05, 0.05, 0.05, 0.05, 25, 25, 25)

In [56]:
# One refinement step toward frame T. This overwrites current_pose with the best
# candidate, so re-running the cell keeps refining from the previous result.
# To search over rotations instead, set search_deltas = rotation_deltas.
T = 1
obs_variance = 0.02
search_deltas = translation_deltas
b.clear()
current_pose, score = grid_and_select_jit(current_pose, search_deltas, images[T], obs_variance)
print(score)
for label, pose in (("current_pose", current_pose), ("ground_truth", camera_poses[T])):
    b.show_pose(label, pose)

23.425


In [31]:
# Print the refined estimate followed by the ground-truth pose for frame T.
for pose in (current_pose, camera_poses[T]):
    print(pose)

[[ 1.0000000e+00  0.0000000e+00  0.0000000e+00  6.6666678e-02]
 [ 0.0000000e+00  1.0000000e+00  0.0000000e+00  1.8626451e-08]
 [ 0.0000000e+00  0.0000000e+00  1.0000000e+00 -4.1666478e-03]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  1.0000000e+00]]
[[ 0.99950653  0.         -0.03141076  0.1       ]
 [ 0.          1.          0.          0.        ]
 [ 0.03141076  0.          0.99950653  0.        ]
 [ 0.          0.          0.          1.        ]]


In [29]:
# Sanity check: score the frame-T observation against itself with the same
# variance/fourth argument used in grid_and_select. The result (50.0) is a
# reference for the search score printed above (23.425); presumably the best
# achievable value — confirm.
b.threedp3_likelihood(images[T],images[T], 0.02, 0.0)

Array(50., dtype=float32)

In [52]:
# Re-export the depth GIF of the trajectory (same call as in In [4]).
b.make_gif_from_pil_images(viz_images, 'room.gif')


In [21]:
# Display the renderer's first registered mesh in the visualizer under name "1".
# NOTE(review): the recorded output below, (11, 100, 100, 4), does not look like
# it comes from this call — possibly stale output from an edited/deleted cell;
# re-run to confirm.
b.show_trimesh("1",b.RENDERER.meshes[0])

(11, 100, 100, 4)