In [None]:
import bayes3d as b
from bayes3d.viz.open3dviz import Open3DVisualizer
import os
import jax.numpy as jnp
import open3d as o3d
import jax
import bayes3d.genjax
from tqdm import tqdm
import genjax
# Incremental-update argdiff types used by `trace.update`; their public
# location moved between genjax releases, so try the newer path first.
try:
    from genjax.incremental import Diff, UnknownChange
except ImportError:
    from genjax._src.core.transforms.incremental import Diff, UnknownChange

In [None]:
# Start the bayes3d visualizer (presumably the target of the b.show_* / b.clear calls below).
b.setup_visualizer()

In [None]:
# YCB-Video (BOP) model files are numbered from 1 and zero-padded to six digits
# (obj_000001.ply, ...); obj_idx is zero-based, hence the +1.
model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models")
obj_idx = 4
mesh_filename = os.path.join(model_dir, f"obj_{obj_idx + 1:06d}.ply")
SCALING_FACTOR = 1.0 / 1000.0  # presumably model units (mm) -> metres; confirm against the asset files

In [None]:
# Small 100x100 camera with the principal point at the image centre.
intrinsics = b.Intrinsics(
    height=100, width=100,
    fx=500.0, fy=500.0,
    cx=50.0, cy=50.0,
    near=0.01, far=50.0,
)

# Renderer for these intrinsics, holding the single object mesh
# (referenced as index 0 in the render calls below).
b.setup_renderer(intrinsics)
b.RENDERER.add_mesh_from_file(mesh_filename, scaling_factor=SCALING_FACTOR)

In [None]:
# NUM_VIEWS object poses seen from one fixed camera. The camera is built with
# transform_from_pos_target_up at position (0, 0.6, 0.6), target (0, 0, 0) and
# up (0, 0, 1). The object is rotated about its z-axis at NUM_VIEWS evenly
# spaced angles over [-pi, pi). The endpoint pi is dropped because it
# duplicates -pi.
NUM_VIEWS = 6
camera_pose = b.t3d.transform_from_pos_target_up(
    jnp.array([0.0, 0.6, 0.6]),
    jnp.array([0.0, 0.0, 0.0]),
    jnp.array([0.0, 0.0, 1.0]),
)
# Loop-invariant: invert the camera pose once instead of once per angle.
world_in_camera = b.t3d.inverse_pose(camera_pose)
view_angles = jnp.linspace(-jnp.pi, jnp.pi, NUM_VIEWS + 1)[:-1]
object_poses = jnp.array([
    world_in_camera @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)
    for angle in view_angles
])
# One render per pose: object_poses[:, None, ...] gives one object (mesh 0) per view.
observations = b.RENDERER.render_many(object_poses[:, None, ...], jnp.array([0]))

In [None]:
# Depth channel (z, index 2) of every rendered view, tiled left to right.
b.hstack_images([b.get_depth_image(obs[..., 2]) for obs in observations])

In [None]:
# Regular lattice of candidate voxel centres around the object origin, spanning
# x, y in [-0.1, 0.1] and z in [-0.2, 0.2].
# GRID_RESOLUTION is presumably the number of points per axis; 100 was also
# tried and gives a finer but more expensive grid.
GRID_RESOLUTION = 60
grid = b.utils.make_translation_grid_enumeration_3d(
    -0.1, -0.1, -0.2,
    0.1, 0.1, 0.2,
    GRID_RESOLUTION, GRID_RESOLUTION, GRID_RESOLUTION,
)
b.show_cloud("grid", grid)

In [None]:
def voxel_occupied_occluded_free(camera_pose, depth_image, grid, intrinsics, tolerance):
    """Label each grid point as occupied, occluded, or free w.r.t. one depth image.

    Args:
        camera_pose: pose (presumably 4x4 homogeneous). Its inverse is applied
            to `grid` to bring the points into the camera frame.
        depth_image: 2-D array of observed depths, indexed [row, col].
        grid: (N, 3) points in the frame that `camera_pose` is expressed in.
        intrinsics: camera intrinsics. `width`, `height` and `far` are read
            here, and the object is passed to `b.project_cloud_to_pixels`.
        tolerance: maximum |observed depth - point depth| for a point to
            count as occupied.

    Returns:
        (N,) float array with one of three values per point:
            1.0  occupied: the point's depth is within `tolerance` of the
                 observed depth.
            0.5  occluded: the point lies behind the observed surface.
            0.0  free: the point is in front of the surface. This is also,
                 typically, the value for points projecting outside the
                 image; such points are assigned a depth of far + 1.

    NOTE(review): if the depth image stores 0 for background pixels, every
    point behind such a pixel is labelled occluded. Confirm the renderer's
    background convention.
    """
    grid_in_cam_frame = b.apply_transform(grid, b.t3d.inverse_pose(camera_pose))
    # Use floor rather than astype's truncation toward zero. Otherwise
    # coordinates in (-1, 0) would land on pixel 0 and pass the bounds check,
    # even though they are outside the image. Floor also matches how positive
    # coordinates are binned.
    pixels = jnp.floor(b.project_cloud_to_pixels(grid_in_cam_frame, intrinsics)).astype(jnp.int32)
    cols, rows = pixels[:, 0], pixels[:, 1]
    valid_pixels = (cols >= 0) & (rows >= 0) & (cols < intrinsics.width) & (rows < intrinsics.height)
    # Out-of-image points get a depth beyond the far plane. Their gathered
    # value is discarded, so JAX's index clamping/wrapping there is harmless.
    real_depth_vals = jnp.where(valid_pixels, depth_image[rows, cols], intrinsics.far + 1.0)

    projected_depth_vals = grid_in_cam_frame[:, 2]
    occupied = jnp.abs(real_depth_vals - projected_depth_vals) < tolerance
    occluded = (real_depth_vals < projected_depth_vals) & ~occupied
    return jnp.where(occupied, 1.0, jnp.where(occluded, 0.5, 0.0))
voxel_occupied_occluded_free_parallel = jax.jit(jax.vmap(voxel_occupied_occluded_free, in_axes=(0, 0, None, None, None)))

In [None]:
# One occupancy map per view. Each object pose places the object in the camera
# frame, so its inverse is the camera pose in the object/grid frame.
occupancies = voxel_occupied_occluded_free_parallel(
    b.inverse_pose(object_poses),
    observations[..., 2],
    grid,
    intrinsics,
    0.001,  # depth tolerance for "occupied"
)
print(occupancies.sum())

In [None]:
# Show every grid point labelled occupied (value 1.0) in at least one view.
# The 0.6 threshold separates occupied (1.0) from occluded (0.5).
b.clear()
b.show_cloud("grid", grid[(occupancies > 0.6).any(axis=0)])

In [None]:
# Voxelise the points occupied in any view (> 0.6 keeps 1.0, excludes 0.5), using 0.005 as the voxel size argument.
mesh = b.utils.make_voxel_mesh_from_point_cloud(grid[(occupancies > 0.6).any(axis=0)], 0.005)

In [None]:
# Display the voxel mesh on its own.
b.clear()
b.show_trimesh("mesh", mesh)

In [None]:
# Sample one pose around (0, 0, 1.4) with gaussian_vmf noise (both parameters
# 0.01) and render the object there as the synthetic observation.
key = jax.random.PRNGKey(10)
pose_center = b.transform_from_pos(jnp.array([0.0, 0.0, 1.4]))
random_pose = b.distributions.gaussian_vmf_jit(key, pose_center, 0.01, 0.01)
observation = b.RENDERER.render(random_pose[None, ...], jnp.array([0]))[..., :3]
b.get_depth_image(observation[..., 2])

In [None]:
# Batched sampler: maps over the keys (axis 0); mean pose and both noise parameters are shared.
_batched_gaussian_vmf = jax.vmap(b.distributions.gaussian_vmf_jit, in_axes=(0, None, None, None))
sample_gaussian_vmf_jit = jax.jit(_batched_gaussian_vmf)

In [None]:
@genjax.gen
def single_object_model(variance, outlier_prob, outlier_volume):
    """Generative model of one object (mesh 0) at an unknown pose.

    Draws "pose" from b.genjax.uniform_pose with bounds [-10, 10] on each
    axis. The object is rendered there, and the rendered xyz image is passed
    with (variance, outlier_prob, outlier_volume) to
    b.genjax.image_likelihood at address "image". Returns the rendered xyz
    image.
    """
    lower_bound = jnp.array([-10.0, -10.0, -10.0])
    upper_bound = jnp.array([10.0, 10.0, 10.0])
    pose = b.genjax.uniform_pose(lower_bound, upper_bound) @ "pose"
    rendered_xyz = b.RENDERER.render(pose[None, ...], jnp.array([0]))[..., :3]
    b.genjax.image_likelihood(rendered_xyz, variance, outlier_prob, outlier_volume) @ "image"
    return rendered_xyz

importance_jit = jax.jit(single_object_model.importance)
key = jax.random.PRNGKey(5)
# Enumerates/updates the "pose" address for batches of poses (used below).
enumerator = b.genjax.make_enumerator(["pose"])

In [None]:
# Importance-sample a trace with "image" constrained to the observation, with
# model args (variance, outlier_prob, outlier_volume) = (0.001, 0.001, 1000.0).
# [1][1] picks the trace out of importance's nested return value
# (presumably (key, (weight, trace)) in this genjax version; verify).
constraints = genjax.choice_map({"image": observation})
model_args = (0.001, 0.001, 1000.0)
trace = importance_jit(key, constraints, model_args)[1][1]

In [None]:
# 1000 candidate poses scattered around the same (0, 0, 1.4) centre.
keys = jax.random.split(key, 1000)
poses = sample_gaussian_vmf_jit(
    keys, b.transform_from_pos(jnp.array([0.0, 0.0, 1.4])), 0.01, 0.01
)

In [None]:
# Score each candidate pose by substituting it into the constrained trace.
scores = enumerator.enumerate_choices_get_scores(trace, key, poses)

In [None]:
# Write the highest-scoring candidate pose back into the trace.
best_pose = poses[scores.argmax()]
trace = enumerator.update_choices(trace, key, best_pose)

In [None]:
# Overlay the observed point cloud and the model's render (in red).
b.clear()
observed_points = trace["image"].reshape(-1, 3)
b.show_cloud("obs", observed_points)
b.show_cloud("render", trace.get_retval().reshape(-1, 3), color=b.RED)

In [None]:
# A larger batch: 5000 candidate poses around (0, 0, 1.4).
keys = jax.random.split(key, 5000)
pose_mean = b.transform_from_pos(jnp.array([0.0, 0.0, 1.4]))
poses = sample_gaussian_vmf_jit(keys, pose_mean, 0.01, 0.01)

In [None]:
def _update_pose(trace, key, pose):
    """Return `trace.update(...)` with the "pose" choice replaced by `pose`.

    Every model argument is wrapped as Diff(arg, UnknownChange), which marks
    it as possibly changed for the update. Diff and UnknownChange come from
    genjax's incremental module; they were previously never imported, so
    calling grid_over_pose raised NameError.
    """
    argdiffs = tuple(Diff(arg, UnknownChange) for arg in trace.args)
    return trace.update(key, genjax.choice_map({"pose": pose}), argdiffs)

# Vectorised over poses only: the trace and key are shared across the batch.
grid_over_pose = jax.jit(jax.vmap(_update_pose, in_axes=(None, None, 0)))

In [None]:
# Update the trace at all candidate poses in one jitted, vmapped call.
grid_over_pose(trace, key, poses)

In [None]:
# Open3D-based visualizer built with the same camera intrinsics as the renderer.
viz = Open3DVisualizer(intrinsics)

In [None]:
# Load the textured model file into the Open3D scene. It is kept under its own
# name so the voxel `mesh` built earlier is not overwritten; otherwise
# re-running the show_trimesh cell would receive an Open3D model.
viz.clear()
ycb_model = o3d.io.read_triangle_model(mesh_filename)
# Apply the renderer's SCALING_FACTOR about the origin to every sub-mesh, so
# multi-part models stay consistent (not just meshes[0]).
for sub_mesh in ycb_model.meshes:
    sub_mesh.mesh.scale(SCALING_FACTOR, jnp.array([0.0, 0.0, 0.0]))
viz.render.scene.add_model("1", ycb_model)

In [None]:
# Capture an RGB-D image of the Open3D scene for each object pose. The inverse
# of each object pose is passed, presumably the camera pose expected by
# capture_image. Iterating the pose array directly gives tqdm a length, so it
# shows a real progress total; the unused enumerate index is gone.
images = [
    viz.capture_image(intrinsics, b.t3d.inverse_pose(pose))
    for pose in tqdm(object_poses)
]
b.hstack_images([b.get_rgb_image(rgbd.rgb) for rgbd in images])