In [None]:
%%time
import bayes3d as b
import jax.numpy as jnp
import jax
import numpy as np
from tqdm import tqdm
import bayes3d.o3d_viz
import os

In [None]:
# Start the bayes3d visualizer; presumably the viewer later used by b.show_cloud /
# b.viz_trace_meshcat — confirm.
b.setup_visualizer()

In [None]:
# Full-resolution camera used to generate the synthetic scenes.
original_intrinsics = b.Intrinsics(
    height=500, width=500,
    fx=500.0, fy=500.0,
    cx=250.0, cy=250.0,
    near=0.001, far=6.0,
)

# Object library, in renderer-index order: 0 = small cube, 1 = thin square pad.
meshes = [
    b.mesh.make_cuboid_mesh(jnp.array(box_dims))
    for box_dims in ([0.1, 0.1, 0.1], [0.5, 0.5, 0.02])
]
b.setup_renderer(original_intrinsics, num_layers=1024)
for library_mesh in meshes:
    b.RENDERER.add_mesh(library_mesh)

# Large flat slab used only to draw the table in the visualizer.
table_mesh = b.mesh.make_cuboid_mesh(jnp.array([5.0, 5.0, 0.01]))

In [None]:
# Offscreen visualizer (presumably Open3D-backed, per the module name) used below to
# draw the table and objects and capture RGB-D frames with the original intrinsics.
viz = b.o3d_viz.O3DVis(original_intrinsics)

In [None]:
# Table plane pose. This is the inverse of a look-at pose from (0, 1.5, 1) towards the
# origin with +z up. Images below are captured with an identity camera pose, so this is
# presumably the table frame expressed in camera coordinates — confirm.
contact_plane = b.t3d.inverse_pose(b.t3d.transform_from_pos_target_up(
    jnp.array([0.0, 1.5, 1.0]),
    jnp.array([0.0, 0.0, 0.0]),
    jnp.array([0.0, 0.0, 1.0]),
))

# Object poses relative to the plane from contact parameters.
# The parallel variant maps over (contact params, box dims) together.
contact_poses_parallel_jit = jax.jit(
    jax.vmap(
        b.scene_graph.relative_pose_from_edge,
        in_axes=(0, None, 0),
    )
)
# This variant maps over contact params only, for a single box size.
contact_poses_jit = jax.jit(
    jax.vmap(
        b.scene_graph.relative_pose_from_edge,
        in_axes=(0, None, None),
    )
)

# TODO: Different shapes

distinct_colors = b.distinct_colors(3)
# Per-object mesh index into `meshes` (0 = cube, 1 = pad) and palette index into distinct_colors.
ids = jnp.array([1, 1, 1, 0, 0, 0, 0, 0, 0])
color = jnp.array([0, 1, 2, 0, 1, 2, 2, 0, 1])

# all_contact_params[frame, object] = contact parameters of that object on the table.
# Presumably (x, y, in-plane angle), matching the grid axes searched during inference.
# Objects 0-2 are pads and objects 3-8 are cubes.
# Each later frame moves one cube onto the xy of a pad.
all_contact_params = jnp.array([
    # Frame 1: initial layout.
    [
        [-0.3, -0.3, 0.0],
        [0.3, -0.3, 0.0],
        [0.3, 0.3, 0.0],
        [-0.15, 0.2, 0.0],
        [-0.3, 0.2, 0.0],
        [-0.45, 0.2, 0.0],
        [-0.15, 0.45, 0.0],
        [-0.3, 0.45, 0.0],
        [-0.45, 0.45, 0.0],
    ],
    # Frame 2: cube 3 moved to pad 0's position (-0.3, -0.3).
    [
        [-0.3, -0.3, 0.0],
        [0.3, -0.3, 0.0],
        [0.3, 0.3, 0.0],
        [-0.3, -0.3, 0.0],
        [-0.3, 0.2, 0.0],
        [-0.45, 0.2, 0.0],
        [-0.15, 0.45, 0.0],
        [-0.3, 0.45, 0.0],
        [-0.45, 0.45, 0.0],
    ],
    # Frame 3: additionally cube 4 moved to pad 1's position (0.3, -0.3).
    [
        [-0.3, -0.3, 0.0],
        [0.3, -0.3, 0.0],
        [0.3, 0.3, 0.0],
        [-0.3, -0.3, 0.0],
        [0.3, -0.3, 0.0],
        [-0.45, 0.2, 0.0],
        [-0.15, 0.45, 0.0],
        [-0.3, 0.45, 0.0],
        [-0.45, 0.45, 0.0],
    ],
    # Frame 4: additionally cube 5 moved to pad 2's position (0.3, 0.3).
    [
        [-0.3, -0.3, 0.0],
        [0.3, -0.3, 0.0],
        [0.3, 0.3, 0.0],
        [-0.3, -0.3, 0.0],
        [0.3, -0.3, 0.0],
        [0.3, 0.3, 0.0],
        [-0.15, 0.45, 0.0],
        [-0.3, 0.45, 0.0],
        [-0.45, 0.45, 0.0],
    ],
])

# Table colour is loop-invariant: RGBA scaled to [0, 1].
TABLE_COLOR = np.array([221, 174, 126, 255.0]) / 255.0

rgbd_images = []
all_poses = []
for frame_idx in range(len(all_contact_params)):
    contact_params = all_contact_params[frame_idx]
    # The `3` selects the contact face/edge type; meaning defined by
    # relative_pose_from_edge — TODO confirm.
    poses = contact_plane @ contact_poses_parallel_jit(
        contact_params,
        3,
        b.RENDERER.model_box_dims[ids]
    )
    all_poses.append(poses)

    viz.clear()
    viz.make_trimesh(table_mesh, contact_plane, TABLE_COLOR)
    # A distinct inner index avoids shadowing the frame loop variable.
    for obj_idx in range(len(poses)):
        viz.make_trimesh(
            b.RENDERER.meshes[ids[obj_idx]],
            poses[obj_idx],
            np.array([*distinct_colors[color[obj_idx]], 1.0]),
        )

    rgbd = viz.capture_image(original_intrinsics, jnp.eye(4))
    rgbd_images.append(rgbd)



In [None]:
# Persist only the first frame's capture; the loader reads it back under key "arr_0".
first_frame = rgbd_images[0]
np.savez("rgbd.npz", first_frame)

# Side-by-side preview of every generated frame.
frame_rgb_images = [b.get_rgb_image(frame.rgb) for frame in rgbd_images]
b.hstack_images(frame_rgb_images)


In [None]:
# Reload the saved frame. The archive holds a pickled RGBD object, so allow_pickle is
# required. That is only acceptable because this notebook wrote the file itself; never
# load untrusted .npz files this way. The `with` closes the NpzFile handle.
with np.load("rgbd.npz", allow_pickle=True) as rgbd_archive:
    rgbd_original = rgbd_archive["arr_0"].item()

# Downscale the observation (and its intrinsics) to speed up inference.
SCALING_FACTOR = 0.3
rgbd = b.scale_rgbd(rgbd_original, SCALING_FACTOR)

In [None]:
# intrinsics = rgbd.intrinsics
# observed_point_cloud_image = b.t3d.unproject_depth(rgbd.depth, intrinsics)[:,:,:3]
# observed_point_cloud_image = (
#     observed_point_cloud_image *
#     (b.t3d.apply_transform(observed_point_cloud_image, b.t3d.inverse_pose(contact_plane))[:,:,2]>0.02)[...,None]
# )
# observed_point_cloud_image = b.t3d.unproject_depth(observed_point_cloud_image[:,:,2], intrinsics)


# b.clear()
# b.show_cloud("1", observed_point_cloud_image[:,:,:3].reshape(-1,3))

In [None]:
# Rebuild the renderer at the downscaled resolution and reload the object library.
intrinsics = rgbd.intrinsics
b.setup_renderer(intrinsics)
for library_mesh in meshes:
    b.RENDERER.add_mesh(library_mesh)

# The observation is rendered synthetically from frame 0's ground-truth poses rather
# than unprojected from the captured depth. The first three channels are kept;
# presumably xyz per pixel.
observed_point_cloud_image = b.RENDERER.render_multiobject(all_poses[0], ids)[:, :, :3]
b.clear()
b.show_cloud("1", observed_point_cloud_image.reshape(-1, 3))
b.get_depth_image(observed_point_cloud_image[:, :, 2])


In [None]:
# Coarse-to-fine schedule. Each entry is:
#   (half-width of the x/y window, half-width of the angle window,
#    (num_x, num_y, num_angle) samples)
grid_params = [
    (0.5, jnp.pi, (11, 11, 11)),
    (0.2, jnp.pi / 3, (11, 11, 11)),
    (0.1, jnp.pi / 5, (11, 11, 1)),
    (0.05, jnp.pi / 5, (11, 11, 11)),
]


def _make_stage_grid(half_xy, half_ang, nums):
    """Enumerate (x, y, angle) offsets for one coarse-to-fine stage, centred on zero.

    An axis sampled only once is collapsed to a zero-width window so its single sample
    is exactly 0. NOTE(review): if the grid helper samples endpoints inclusively
    (linspace-like), a single sample over [-w, w] would land on -w and silently offset
    every candidate (e.g. rotate by -pi/5). For axes with more than one sample, the
    arguments are identical to the previous (-w, ..., w, ...) call.
    """
    half_widths = [
        width if count > 1 else 0.0
        for width, count in zip((half_xy, half_xy, half_ang), nums)
    ]
    lows = [-width for width in half_widths]
    return b.make_translation_grid_enumeration_3d(*lows, *half_widths, *nums)


contact_param_gridding_schedule = [_make_stage_grid(*stage) for stage in grid_params]



In [None]:
# Per-pixel 3DP3 likelihood, vectorised over a hierarchy of hypotheses. A direct call
# later in the notebook passes (observation, rendered, variance, outlier_prob?,
# outlier_volume?, filter_size). Judging by that call, the vmaps map, innermost first:
#   arg 3, then arg 2 (variance), then arg 1 (rendered image).
# NOTE(review): this function is not used anywhere below — candidate for removal.
_map_arg3 = jax.vmap(
    b.threedp3_likelihood_per_pixel_jit,
    in_axes=(None, None, None, 0, None, None),
)
_map_arg2 = jax.vmap(_map_arg3, in_axes=(None, None, 0, None, None, None))
_map_arg1 = jax.vmap(_map_arg2, in_axes=(None, 0, None, None, None, None))
threedp3_likelihood_full_hierarchical_bayes_per_pixel_jit = jax.jit(
    _map_arg1, static_argnames=('filter_size',)
)

In [None]:
# Hyperparameter grids searched by `refine` (single-point grids for now) and the
# fixed outlier volume shared by every trace.
VARIANCE_GRID = jnp.array([1e-4])
OUTLIER_GRID = jnp.array([1e-2])
OUTLIER_VOLUME = 1e3

In [None]:
def refine(trace, init_contact_param, i, obj_id):
    """One coarse-to-fine step for placing a new object of mesh type `obj_id`.

    Steps:
    - Shift stage `i` of `contact_param_gridding_schedule` to `init_contact_param`.
    - Turn every candidate contact parameter into a pose on `contact_plane`.
    - Append that pose to the poses already in `trace`.
    - Score all candidates across VARIANCE_GRID x OUTLIER_GRID.

    Returns:
        (best_contact_param, best_trace): the winning contact parameter, and the
        candidate trace (with the new object appended) at the best-scoring
        (candidate, variance, outlier) index.
    """
    candidate_params = contact_param_gridding_schedule[i] + init_contact_param
    candidate_poses = contact_plane @ contact_poses_jit(
        candidate_params,
        3,
        b.RENDERER.model_box_dims[obj_id],
    )
    num_candidates = candidate_poses.shape[0]

    # Existing objects are repeated for every candidate; the new object is appended last.
    repeated_existing = jnp.tile(trace.poses[:, None, ...], (1, num_candidates, 1, 1))
    stacked_poses = jnp.concatenate([repeated_existing, candidate_poses[None, ...]])
    stacked_ids = jnp.concatenate([trace.ids, jnp.array([obj_id])])

    candidate_traces = b.Traces(
        stacked_poses, stacked_ids, VARIANCE_GRID, OUTLIER_GRID,
        trace.outlier_volume, trace.observation
    )
    # Scores are presumably shaped (num_candidates, |VARIANCE_GRID|, |OUTLIER_GRID|).
    scores = b.score_traces(candidate_traces)

    best_cand, best_var, best_out = jnp.unravel_index(scores.argmax(), scores.shape)
    return candidate_params[best_cand], candidate_traces[best_cand, best_var, best_out]

# `i` indexes a Python list and `obj_id` builds a fixed-shape id array, so both are static.
refine_jit = jax.jit(refine, static_argnames=("i", "obj_id",))

In [None]:
# Ground-truth trace, used to sanity-check the likelihood. The observation was rendered
# from frame 0 (`all_poses[0]`), so the GT trace must use those poses. `poses` still
# holds the last generated frame and would not match the observation.
gt_trace = b.Trace(
    all_poses[0], ids, VARIANCE_GRID[0], OUTLIER_GRID[0], OUTLIER_VOLUME,
    observed_point_cloud_image
)
print(b.score_trace(gt_trace))
b.show_cloud("rerender", b.render_image(gt_trace)[:,:,:3].reshape(-1,3),color=b.RED)

# Empty starting trace for greedy inference: no objects yet.
trace = b.Trace(
    jnp.zeros((0,4,4)), jnp.array([],dtype=jnp.int32),
    VARIANCE_GRID[0], OUTLIER_GRID[0], OUTLIER_VOLUME,
    observed_point_cloud_image
)
b.viz_trace_meshcat(trace)

In [None]:
%%time
for _ in range(10):
    all_paths = []
    for obj_id in tqdm(range(len(b.RENDERER.meshes))):
        contact_param = jnp.zeros(3)
        p = None
        trace_path = []
        for c2f_iter in range(len(contact_param_gridding_schedule)):
            contact_param, trace_ = refine_jit(trace, contact_param, c2f_iter, obj_id)
            trace_path.append(trace_)

        all_paths.append(
            trace_path
        )


    scores = jnp.array([b.score_trace(t[-1]) for t in all_paths])
    normalized_scores = b.utils.normalize_log_scores(scores)
    # print(["{:0.3f}".format(n) for n in normalized_scores])
    order = jnp.argsort(-scores)
    # print(order)
    new_trace = all_paths[jnp.argmax(scores)][-1]
    trace = new_trace
    b.viz_trace_meshcat(trace)


In [None]:
# `renderer` was never defined (stale API); every other call uses the global renderer.
b.viz_trace_meshcat(new_trace)

In [None]:
# Show the current inferred trace and the mesh type of each inferred object.
b.viz_trace_meshcat(trace)
trace.ids

In [None]:
# Inspect the coarsest search step for mesh type 0 from the last greedy iteration.
t = all_paths[0][0]
b.viz_trace_meshcat(t)
b.score_trace(t)

In [None]:
# Score and hyperparameters at each coarse-to-fine stage for mesh type 0.
[(b.score_trace(t), t.variance, t.outlier_prob, t.outlier_volume) for t in all_paths[0]]

In [None]:
# Plain-text dump of the coarsest-stage trace for mesh type 0.
mesh0_path = all_paths[0]
print(mesh0_path[0])

In [None]:
# Re-render the inferred scene (global renderer; `renderer` was never defined) and show depth.
reconstruction = b.render_image(trace)
b.get_depth_image(reconstruction[:,:,2])

In [None]:
print(trace.variance, trace.outlier_prob, trace.outlier_volume)
# Per-pixel likelihood of the observation under the reconstruction, called with
# 0.0 / 1.0 in the (presumably) outlier-prob / outlier-volume slots.
p = b.threedp3_likelihood_per_pixel_jit(
    trace.observation,
    reconstruction[:, :, :3],
    trace.variance,
    0.0,
    1.0,
    3,
)
# NOTE(review): this volume is hard-coded rather than taken from trace.outlier_volume
# (printed above) — confirm which is intended.
DEBUG_OUTLIER_VOLUME = 0.0005
outlier_density = jnp.log(trace.outlier_prob) - jnp.log(DEBUG_OUTLIER_VOLUME)
# 1 where the outlier explanation beats the inlier likelihood, else 0.
outlier_mask = 1.0 * (outlier_density > p)
b.get_depth_image(outlier_mask, min=0.0, max=1.0)

In [None]:
# NOTE(review): `all_traces` was never defined in this notebook. It presumably meant
# the final trace of each candidate mesh type from the last greedy step — confirm.
all_traces = [path[-1] for path in all_paths]
potential_new_trace = all_traces[0]
# NOTE(review): this assignment mutates the trace object that is also stored in
# all_paths[0]. `poses` is the last generated frame, not the observed frame 0 — confirm intent.
potential_new_trace.poses = potential_new_trace.poses.at[-1].set(poses[4])
b.viz_trace_meshcat(potential_new_trace)

In [None]:
# Compare the inferred trace with the hand-edited candidate (global renderer).
print(b.score_trace(trace))
print(b.score_trace(potential_new_trace))

In [None]:
# NOTE(review): `all_traces` was never defined. This shows the final trace of
# candidate mesh type 1 from the last greedy step, presumably what was meant — confirm.
b.viz_trace_meshcat(all_paths[1][-1])

In [None]:
b.clear()
# Channel 3 of the rendered image; used below as a per-pixel segmentation
# (object k compared against value k, starting at 1).
seg = b.render_image(trace)[:,:,3]
# b.show_cloud("rerender", b.render_image(trace)[:,:,:3].reshape(-1,3),color=b.RED)

In [None]:
# Visualise the rendered segmentation channel of the inferred trace.
b.get_depth_image(seg)

In [None]:
# Variance stored on the current inferred trace.
trace.variance

In [None]:
# Infer each traced object's palette colour by majority vote over its segmented pixels.
# Segment values start at 1 (0 is presumably background), so object k uses value k+1.
inferred_colors = []
distinct_colors = jnp.array(distinct_colors)
for i in range(1, len(trace.ids) + 1):
    seg_colors = rgbd.rgb[seg == i, :3]
    # L1 distance from each pixel colour (scaled to [0, 1]) to each palette colour.
    distances = jnp.abs(seg_colors[:, None, :] / 255.0 - distinct_colors[None, ...]).sum(-1)
    values, counts = np.unique(jnp.argmin(distances, axis=-1), return_counts=True)
    # A fully occluded or off-screen object has no pixels. Record None rather than
    # crashing on argmax of an empty array.
    inferred_colors.append(values[counts.argmax()] if counts.size else None)
inferred_colors

In [None]:
# Mesh-type index of each inferred object (compare with ground-truth `ids`).
trace.ids

In [None]:
# Ground-truth palette index of each generated object (index into distinct_colors).
color

In [None]:
# Ground-truth mesh-type index of each generated object.
ids

In [None]:
# Redundant with the conversion in the colour-inference cell; re-running it only copies the array.
distinct_colors = jnp.array(distinct_colors)

In [None]:
# Pixel-to-palette distances left over from the last iteration of the colour-inference loop.
distances

In [None]:
# Pixel colours of the last object processed by the colour-inference loop.
seg_colors

In [None]:
## b.score_trace(gt_trace, renderer, filter_size=i)

In [None]:
# Re-render the ground-truth trace (global renderer; `renderer` was never defined)
# and show it as a point cloud.
x = b.render_image(gt_trace)[:,:,:3]
b.clear()
b.show_cloud("1", x.reshape(-1,3))

In [None]:
# Leftover inspection cell: displays the bayes3d module object.
b

In [None]:
# Final-stage scores from the last greedy step, one per candidate mesh type.
scores

In [None]:
# object_types: cube, sphere, pyramid, pad
# Output
# List of (object_type, color, contact_params)