In [1]:
import bayes3d as b
import jax.numpy as jnp
import jax
from tqdm import tqdm

In [2]:
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7041/static/


In [3]:

intrinsics = b.Intrinsics(
    height=100,
    width=100,
    fx=50.0, fy=50.0,
    cx=50.0, cy=50.0,
    near=0.001, far=16.0
)
from bayes3d.rendering.nvdiffrast_jax.jax_renderer import Renderer as JaxRenderer
jax_renderer = JaxRenderer(intrinsics)


In [4]:

import trimesh
def as_mesh(scene_or_mesh):
    """
    Convert a possible scene to a mesh.

    If conversion occurs, the returned mesh has only vertex and face data.
    """
    if isinstance(scene_or_mesh, trimesh.Scene):
        if len(scene_or_mesh.geometry) == 0:
            mesh = None  # empty scene
        else:
            # we lose texture information here
            mesh = trimesh.util.concatenate(
                tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
                    for g in scene_or_mesh.geometry.values()))
    else:
        assert(isinstance(mesh, trimesh.Trimesh))
        mesh = scene_or_mesh
    return mesh
mesh  =as_mesh(trimesh.load('InteriorTest.obj'))
mesh.vertices  = mesh.vertices * jnp.array([1.0, -1.0, 1.0]) + jnp.array([0.0, 1.0, 0.0])
vertices = mesh.vertices
faces = mesh.faces

b.show_trimesh("1",mesh)

In [81]:
sequence = [jnp.eye(4)]
transform = b.transform_from_rot_and_pos(b.rotation_from_axis_angle(jnp.array([0.0, 1.0, 0.0]), -jnp.pi/100000.0), jnp.array([0.0, 0.0, 0.1]))
for _ in range(25): sequence.append(sequence[-1] @ transform)
camera_poses = jnp.stack(sequence)
b.clear()
for i in range(len(sequence)):
    b.show_pose(f"{i}", sequence[i])
b.show_trimesh("1",mesh)


In [82]:
gt_images = [
    jax_renderer.render(vertices, faces, b.inverse_pose(p), intrinsics)[0][0,...]
    for p in camera_poses
]
b.make_gif_from_pil_images([b.get_depth_image(img, remove_max=False) for img in gt_images], "gt.gif")

In [75]:
def loss(trans, q, gt_img):
    camera_pose = b.translation_and_quaternion_to_pose_matrix(trans, q)
    img = jax_renderer.render(vertices, faces, b.inverse_pose(camera_pose), intrinsics)[0][0,...]
    return (jnp.abs(img - gt_img)).mean()

value_and_grad_jit = jax.jit(jax.value_and_grad(loss, argnums=(0,1,)))

In [76]:
b.clear()
b.show_pose("actual", camera_poses[1])
tr,q = b.pose_matrix_to_translation_and_quaternion(camera_poses[0])
b.show_pose("inferred", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)

In [77]:
print("start " , value_and_grad_jit(tr, q, gt_images[1]))
poses = []
pbar = tqdm(range(200))
timestep = 1
for _  in pbar:
    loss, (g1, g2) = value_and_grad_jit(tr, q, gt_images[timestep])
    tr -= g1 * 0.01
    q -= g2 * 0.01
    pbar.set_description(f"{loss}")
    # poses.append(b.translation_and_quaternion_to_pose_matrix(tr,q))
b.show_pose("inferred", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)

start  (Array(0.39068013, dtype=float32), (Array([ 0.10627548,  0.13037287, -0.3495799 ], dtype=float32), Array([ 0.        ,  0.10087404,  0.48157406, -0.34359443], dtype=float32)))


0.11451186239719391: 100%|██████████| 200/200 [00:00<00:00, 384.84it/s] 


In [78]:
b.clear()
b.show_pose("actual", camera_poses[1])
tr,q = b.pose_matrix_to_translation_and_quaternion(camera_poses[0])
b.show_pose("inferred", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)

In [79]:
print("start " , value_and_grad_jit(tr, q, gt_images[1]))
poses = []
pbar2 = tqdm(range(len(gt_images)))
for timestep  in pbar2:
    pbar = tqdm(range(200))
    b.show_pose("2", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)
    for _  in pbar:
        loss, (g1, g2) = value_and_grad_jit(tr, q, gt_images[timestep])
        tr -= g1 * 0.01
        q -= g2 * 0.01
        pbar.set_description(f"{loss}")
    b.show_pose("actual", camera_poses[timestep])
    b.show_pose("inferred", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)
    poses.append(b.translation_and_quaternion_to_pose_matrix(tr,q))

start  (Array(0.39068013, dtype=float32), (Array([ 0.10627548,  0.13037287, -0.3495799 ], dtype=float32), Array([ 0.        ,  0.10087411,  0.48157406, -0.3435945 ], dtype=float32)))


0.3816789984703064: 100%|██████████| 200/200 [00:00<00:00, 276.31it/s]
0.3915194272994995: 100%|██████████| 200/200 [00:00<00:00, 287.50it/s]
0.22637736797332764: 100%|██████████| 200/200 [00:00<00:00, 271.57it/s]
0.2539404630661011: 100%|██████████| 200/200 [00:00<00:00, 280.13it/s]
0.03320816904306412: 100%|██████████| 200/200 [00:00<00:00, 280.94it/s]
0.03555193543434143: 100%|██████████| 200/200 [00:00<00:00, 280.71it/s]
0.15251809358596802: 100%|██████████| 200/200 [00:00<00:00, 266.10it/s]
0.17262689769268036: 100%|██████████| 200/200 [00:00<00:00, 277.07it/s]
0.19562497735023499: 100%|██████████| 200/200 [00:00<00:00, 283.27it/s]
0.07484032958745956: 100%|██████████| 200/200 [00:00<00:00, 259.38it/s]
0.025297656655311584: 100%|██████████| 200/200 [00:00<00:00, 262.83it/s]
0.051032885909080505: 100%|██████████| 200/200 [00:00<00:00, 250.05it/s]
0.012134687043726444: 100%|██████████| 200/200 [00:00<00:00, 282.20it/s]
0.032516807317733765: 100%|██████████| 200/200 [00:00<00:00, 281

In [80]:
b.clear()
for i in range(len(poses)):
    b.show_pose(f"{i}", poses[i])
    b.show_pose(f"{i}_actual", camera_poses[i])

In [None]:
b.show_pose("2", b.translation_and_quaternion_to_pose_matrix(tr,q), size=0.1)


In [56]:

# b.show_pose("2", 

In [18]:
value_and_grad_jit(*b.pose_matrix_to_translation_and_quaternion(camera_poses[0]), gt_images[1])

(Array(0.43455184, dtype=float32),
 (Array([-0.15023103, -0.33511856,  0.05152562], dtype=float32),
  Array([ 0.        ,  1.6910034 , -0.15435895,  0.6346966 ], dtype=float32)))

In [5]:
translation_deltas = b.utils.make_translation_grid_enumeration(-0.2, -0.2, -0.2, 0.2, 0.2, 0.2, 25, 25, 25)
rotation_deltas = jax.vmap(lambda key: b.distributions.gaussian_vmf_zero_mean(key, 0.0001, 800.0))(
    jax.random.split(jax.random.PRNGKey(30), 100)
)

In [6]:
def grid_and_select(current_pose, diffs, obs_img, variance):
    possible_poses = jnp.einsum("aij,jk->aik", diffs, current_pose)
    rendered_imgs = b.RENDERER.render_many(b.inverse_pose(possible_poses)[:,None,...], jnp.array([0]))[...,:3]
    scores = jax.vmap(b.threedp3_likelihood, in_axes=(None, 0, None, None))(obs_img, rendered_imgs, variance, 0.0)
    return possible_poses[jnp.argmax(scores)], scores.max()
grid_and_select_jit = jax.jit(grid_and_select)

In [33]:
T = 1
current_pose = camera_poses[0]
b.clear()
b.show_pose(f"current_pose", current_pose)
b.show_pose(f"ground_truth", camera_poses[T])

In [55]:
translation_deltas = b.utils.make_translation_grid_enumeration(-0.1, -0.1, -0.1, 0.1, 0.1, 0.1, 25, 25, 25)

In [53]:
translation_deltas = b.utils.make_translation_grid_enumeration(-0.05, -0.05, -0.05, 0.05, 0.05, 0.05, 25, 25, 25)

In [56]:
T = 1
b.clear()
current_pose,score = grid_and_select_jit(current_pose, translation_deltas, images[T], 0.02)
# current_pose,score = grid_and_select_jit(current_pose, rotation_deltas, images[T], 0.02)
print(score)
b.show_pose(f"current_pose", current_pose)
b.show_pose(f"ground_truth", camera_poses[T])

23.425


In [31]:
print(current_pose)
print(camera_poses[T])

[[ 1.0000000e+00  0.0000000e+00  0.0000000e+00  6.6666678e-02]
 [ 0.0000000e+00  1.0000000e+00  0.0000000e+00  1.8626451e-08]
 [ 0.0000000e+00  0.0000000e+00  1.0000000e+00 -4.1666478e-03]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  1.0000000e+00]]
[[ 0.99950653  0.         -0.03141076  0.1       ]
 [ 0.          1.          0.          0.        ]
 [ 0.03141076  0.          0.99950653  0.        ]
 [ 0.          0.          0.          1.        ]]


In [29]:
b.threedp3_likelihood(images[T],images[T], 0.02, 0.0)

Array(50., dtype=float32)

In [52]:
b.make_gif_from_pil_images(viz_images, 'room.gif')


In [21]:
b.show_trimesh("1",b.RENDERER.meshes[0])

(11, 100, 100, 4)