In [None]:
import numpy as np
import jax.numpy as jnp
import jax
import bayes3d as b
import time
from PIL import Image
from scipy.spatial.transform import Rotation as R
import matplotlib.pyplot as plt
import cv2
import trimesh
import os
import glob
import bayes3d.neural
import pickle
# Can be helpful for debugging:
# jax.config.update('jax_enable_checks', True) 
from bayes3d.neural.segmentation import carvekit_get_foreground_mask
import genjax

In [None]:
# Start the 3D visualizer used by the later b.show_* / b.viz_trace_meshcat cells.
b.setup_visualizer()

In [None]:
# JIT-compile the model's importance entry point once; it is called below with
# (key, choice_map, model_args) to build the initial trace.
importance_jit = jax.jit(b.model.importance)

# Fixed seed so the stochastic cells are reproducible.
PRNG_SEED = 10
key = jax.random.PRNGKey(PRNG_SEED)

In [None]:
# Load the recorded scans. glob() returns files in arbitrary (filesystem)
# order, so sort to make "the first file" reproducible, and close the handle
# deterministically with a context manager.
# NOTE: pickle.load can execute arbitrary code — only load trusted scan files.
paths = sorted(glob.glob(
    "panda_scans_v6/*.pkl"
))
assert paths, "no .pkl files found under panda_scans_v6/"
with open(paths[0], "rb") as scan_file:
    all_data = pickle.load(scan_file)
IDX = 1  # which recorded entry of the file to use
data = all_data[IDX]

In [None]:
# Unpack the recorded camera frame: intrinsics matrix, RGB, depth and pose.
camera_image = data["camera_image"]
print(camera_image.keys())
K = camera_image['camera_matrix'][0]
rgb = camera_image['rgbPixels']
depth = camera_image['depthPixels']
# Convert the recorded pose with the PyBullet pose helper into a transform.
camera_pose = b.t3d.pybullet_pose_to_transform(camera_image['camera_pose'])
fx, fy, cx, cy = K[0,0],K[1,1],K[0,2],K[1,2]
h,w = depth.shape
# Clipping planes: `near` used to be defined but ignored (the literal was
# repeated inline); pass the named values so a change here takes effect.
near = 0.001
far = 10000.0
rgbd_original = b.RGBD(rgb, depth, camera_pose, b.Intrinsics(h,w,fx,fy,cx,cy,near,far))
b.get_rgb_image(rgbd_original.rgb)

In [None]:
# Visualize the full-resolution observed depth, colour scale capped at 1.5.
b.get_depth_image(rgbd_original.depth,max=1.5)

In [None]:
# Downscale the RGB-D frame; the renderer and the observed point cloud below
# are built at this reduced resolution.
scaling_factor = 0.23
rgbd_scaled_down = b.RGBD.scale_rgbd(rgbd_original, scaling_factor)


In [None]:
# Fit the dominant plane (shown as "table" below) to the scaled-down point cloud.
scaled_cloud = b.unproject_depth(
    rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics
).reshape(-1, 3)
plane_pose, plane_dims = b.utils.find_plane_and_dims(
    scaled_cloud,
    ransac_threshold=0.001,
    inlier_threshold=0.001,
    segmentation_threshold=0.1,
)

In [None]:
# Post-multiply by a rotation of pi about the local x-axis (presumably so the
# plane's normal points out of the table — confirm).
flip_about_x = b.transform_from_axis_angle(jnp.array([1.0, 0.0, 0.0]), jnp.pi)
plane_pose = plane_pose @ flip_about_x

In [None]:
# Show the full scaled-down point cloud together with the fitted table pose.
b.clear()
scene_cloud = b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics)
b.show_cloud("1", scene_cloud.reshape(-1, 3))
b.show_pose("table", plane_pose)

In [None]:
# Foreground mask from the full-resolution frame, cast to float (x * 1.0) and
# downscaled by the same factor as rgbd_scaled_down.
full_res_mask = carvekit_get_foreground_mask(rgbd_original) * 1.0
mask = b.utils.scale(full_res_mask, scaling_factor)

In [None]:
# Keep depth where the mask is on; push background pixels to the far plane.
# NOTE(review): if b.utils.scale interpolates, edge pixels get fractional mask
# values and a blended (phantom) depth between object and far plane — confirm.
far_plane = rgbd_scaled_down.intrinsics.far
observed_depth = rgbd_scaled_down.depth * mask + far_plane * (1.0 - mask)

In [None]:
# Show the masked (foreground-only) point cloud with the table pose.
b.clear()
masked_cloud = b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics)
b.show_cloud("1", masked_cloud.reshape(-1, 3))
b.show_pose("table", plane_pose)

In [None]:
# Inspect the YCB model name at index 10 (presumably the obj_000011 = 10+1 mesh loaded below — confirm).
b.utils.ycb_loader.MODEL_NAMES[10]

In [None]:
# Inspect the YCB model name at index 9 (exploratory; index 9 is not loaded below).
b.utils.ycb_loader.MODEL_NAMES[9]

In [None]:
# Renderer at the scaled-down resolution. Meshes are added in this order
# (indices presumably follow insertion order): toy_plane.ply, YCB obj_000014,
# YCB obj_000011, then a cube scaled by 1e-9 (effectively invisible — confirm intent).
b.setup_renderer(rgbd_scaled_down.intrinsics)
b.RENDERER.add_mesh_from_file("toy_plane.ply")
model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models")
for ycb_index in (13, 10):
    # BOP file names are 1-based and zero-padded to six digits; the 1/1000
    # scale presumably converts millimetres to metres.
    mesh_path = os.path.join(model_dir, f"obj_{ycb_index + 1:06d}.ply")
    b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)

cube_path = os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj")
b.RENDERER.add_mesh_from_file(cube_path, scaling_factor=1.0/1000000000.0)


In [None]:
# Display every loaded renderer mesh, labelled by its index.
for mesh_index, mesh in enumerate(b.RENDERER.meshes):
    b.show_trimesh(f"mesh_{mesh_index}", mesh)

In [None]:
# Coarse-to-fine schedule for the contact-parameter search. Each entry is
# (translation half-width, angle half-width, per-axis grid counts); the three
# contact params appear to be (x, y, angle) — TODO confirm.
grid_params = [
    (0.4, jnp.pi, (11, 11, 11)),
    (0.3, jnp.pi, (11, 11, 11)),
    (0.2, jnp.pi, (11, 11, 11)),
    (0.05, jnp.pi/3, (11, 11, 11)),
    (0.02, jnp.pi, (5, 5, 51)),
    (0.01, jnp.pi/5, (11, 11, 11)),
    (0.01, 0.0, (21, 21, 1)),
    (0.01, 0.0, (21, 21, 1)),
    (0.01, jnp.pi/10, (5, 5, 21)),
    (0.01, jnp.pi/20, (5, 5, 21)),
]
contact_param_gridding_schedule = []
for half_width, half_angle, counts in grid_params:
    # Symmetric box [-half, +half] around zero on each axis.
    contact_param_gridding_schedule.append(
        b.utils.make_translation_grid_enumeration_3d(
            -half_width, -half_width, -half_angle,
            half_width, half_width, half_angle,
            *counts,
        )
    )


In [None]:
# Initial trace: object 0 is the root (parent -1) with mesh index 3 (presumably
# the tiny cube added last in the renderer cell), posed at the fitted table
# plane; the observation is the masked, scaled-down point cloud.
initial_choices = genjax.choice_map({
    "parent_0": -1,
    "parent_1": 0,
    "id_0": jnp.int32(3),
    "camera_pose": jnp.eye(4),
    "root_pose_0": plane_pose,
    "face_parent_1": 2,
    "face_child_1": 3,
    "image": b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics),
    "variance": 0.001,
    "outlier_prob": 0.0001,
    "contact_params_1": jnp.array([0.0, 0.0, 0.0])
})
model_args = (
    jnp.arange(1),
    # NOTE(review): 22 candidate ids, but only len(b.RENDERER.meshes) meshes
    # are loaded above — confirm this range is intended.
    jnp.arange(22),
    jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),
    jnp.array([jnp.array([-0.6, -0.6, -4*jnp.pi]), jnp.array([0.6, 0.6, 4*jnp.pi])]),
    b.RENDERER.model_box_dims,
    1.0,
    1.0,
)
weight, trace = importance_jit(key, initial_choices, model_args)
b.viz_trace_meshcat(trace)
print(trace.get_score())

In [None]:
# Object slot number -> renderer mesh index passed to b.add_object_jit.
# Slot 0 is the root object set in the initial trace, hence the None entry.
object_number_to_id = [None, 2, 1, 0]
# Single-object alternative: object_number_to_id = [None, 1]

In [None]:
# Slot number of the next object to add (slot 0 is the root in the initial trace).
OBJECT_NUMBER = 1

In [None]:
# Add the next object, parented to object 0 with face indices 2/3 (matching
# the face_parent/face_child values in the initial choice map), and build
# enumerators over its contact parameters.
object_id = object_number_to_id[OBJECT_NUMBER]
address = f"contact_params_{OBJECT_NUMBER}"
trace = b.add_object_jit(trace, key, object_id, 0, 2, 3)
enumerators = b.make_enumerator([address])
b.viz_trace_meshcat(trace)

In [None]:
# Coarse-to-fine search: each schedule entry is a grid of offsets centred on
# the current contact params; the arg-max candidate becomes the new centre.
# (From usage, enumerators[3] scores a batch of candidates and enumerators[0]
# applies a single value — confirm against b.make_enumerator.)
traces = []
for idx, contact_param_deltas in enumerate(contact_param_gridding_schedule):
    contact_param_grid = contact_param_deltas + trace[address]
    scores = enumerators[3](trace, key, contact_param_grid)
    i = jnp.unravel_index(scores.argmax(), scores.shape)
    best_contact_params = contact_param_grid[i]
    trace = enumerators[0](trace, key, best_contact_params)
    traces.append(trace)
    b.viz_trace_meshcat(trace)
b.get_depth_image(b.get_rendered_image(trace)[...,2], max=1.0)

In [None]:
# Advance to the next object slot.
# NOTE(review): one of five scattered increments; in a top-to-bottom run they
# fire without matching add-object cells, so OBJECT_NUMBER drifts past the
# objects in the trace (and past len(object_number_to_id) == 4).
OBJECT_NUMBER += 1

In [None]:
# Stack four panels at renderer resolution: observed depth, reconstructed
# depth, render channel 3 (presumably segmentation), and the reconstruction
# overlaid on RGB.
# NOTE(review): observed depth uses max=1.5 but the reconstruction max=1.0,
# so their colours are not directly comparable — confirm intended.
render_h = b.RENDERER.intrinsics.height
render_w = b.RENDERER.intrinsics.width
depth_viz = b.viz.resize_image(b.get_depth_image(rgbd_original.depth,max=1.5), render_h, render_w)
depth_reconstruction_viz = b.get_depth_image(b.get_rendered_image(trace)[...,2], max=1.0)
seg_render = b.RENDERER.render(b.get_poses(trace), b.get_indices(trace))
seg_viz = b.get_depth_image(seg_render[:,:,3], max=5.0)
rgb_viz = b.resize_image(b.get_rgb_image(rgbd_original.rgb), render_h, render_w)
resized_reconstruction = b.viz.resize_image(depth_reconstruction_viz, rgb_viz.height, rgb_viz.width)
overlay_viz = b.overlay_image(resized_reconstruction, rgb_viz)
panels = [depth_viz, depth_reconstruction_viz, seg_viz, overlay_viz]
b.vstack_images(panels)
    

In [None]:
# Display the observed-depth panel from the comparison cell on its own.
depth_viz

In [None]:
# Render the trace's poses/ids and show channel 3 (presumably a segmentation map), scale capped at 5.0.
b.get_depth_image(b.RENDERER.render(b.get_poses(trace), b.get_indices(trace))[:,:,3], max=5.0)

In [None]:
# Advance to the next object slot.
# NOTE(review): scattered increment; see the other OBJECT_NUMBER += 1 cells —
# in a linear run this drifts without a matching add-object call.
OBJECT_NUMBER += 1

In [None]:
# Reconstruction depth (render channel 2) and the full-resolution RGB image for the overlay below.
depth_reconstruction_viz = b.get_depth_image(b.get_rendered_image(trace)[...,2], max=1.0)
rgb_viz = b.get_rgb_image(rgbd_original.rgb)

In [None]:
# Overlay the reconstruction (resized to the RGB frame size) on the full-resolution RGB image.
b.overlay_image(b.viz.resize_image(depth_reconstruction_viz, rgb_viz.height,rgb_viz.width), rgb_viz)

In [None]:
# Size of the reconstruction upscaled by 1/scaling_factor, to compare against rgb_viz.size.
b.viz.scale_image(depth_reconstruction_viz, 1/scaling_factor).size

In [None]:
# Full-resolution RGB size, for comparison with the upscaled reconstruction.
rgb_viz.size

In [None]:
# Advance to the next object slot.
# NOTE(review): scattered increment; in a linear run this drifts without a matching add-object call.
OBJECT_NUMBER += 1

In [None]:
# NOTE(review): `imgs` is never used — the preview cell below collects into `images`.
imgs = list()

In [None]:
# Coarsest candidate grid (schedule entry 0) centred on the current value of `address`,
# used by the random-preview cells below.
idx = 0
contact_param_deltas = contact_param_gridding_schedule[idx]
contact_param_grid = contact_param_deltas + trace[address]


In [None]:
# Advance the PRNG key (keep the first of two split keys) before sampling preview candidates.
key = jax.random.split(key,2)[0]

In [None]:
# Candidate grid around the current contact params of `address`.
contact_param_deltas = contact_param_gridding_schedule[idx]
contact_param_grid = contact_param_deltas + trace[address]
# Draw 50 distinct candidates for the 10 x 5 preview mosaic below. The default
# (replace=True) can pick the same candidate twice and show duplicate tiles.
indices_in_contact_param_grid = jax.random.choice(
    key, contact_param_grid.shape[0], shape=(50,), replace=False
)

In [None]:
# Render a depth image for each sampled candidate, without changing `trace`.
images = []
for i in indices_in_contact_param_grid:
    candidate = contact_param_grid[i]
    trace_ = enumerators[0](trace, key, candidate)
    rendered_depth = b.get_rendered_image(trace_)[..., 2]
    images.append(b.get_depth_image(rendered_depth, max=1.5))

In [None]:
# Tile the 50 preview renderings into a grid (arguments 10, 5).
b.hvstack_images(images, 10,5)

In [None]:
# Re-run the coarse-to-fine contact search for `address`, recording each step.
traces = []
for idx, contact_param_deltas in enumerate(contact_param_gridding_schedule):
    contact_param_grid = contact_param_deltas + trace[address]
    scores = enumerators[3](trace, key, contact_param_grid)
    i = jnp.unravel_index(scores.argmax(), scores.shape)
    best_contact_params = contact_param_grid[i]
    trace = enumerators[0](trace, key, best_contact_params)
    traces.append(trace)
    b.viz_trace_meshcat(trace)

In [None]:
# Reconstructed depth (render channel 2) of the searched trace, upscaled 5x for viewing.
b.viz.scale_image(b.get_depth_image(b.get_rendered_image(trace)[...,2], max=1.0),5)

In [None]:
# Advance to the next object slot.
# NOTE(review): scattered increment; in a linear run this drifts without a matching add-object call.
OBJECT_NUMBER += 1

In [None]:
# Depth of the first search step (traces[0]) with white-background colouring, upscaled 5x.
# NOTE(review): get_depth_image_alternate is defined in a later cell; on a fresh kernel
# this raises NameError unless that cell has been run first (move the definition up).
b.viz.scale_image(b.get_rgb_image(get_depth_image_alternate(b.get_rendered_image(traces[0])[...,2], 1.0)),5)

In [None]:
# Re-create the renderer at full (unscaled) resolution and reload the same
# meshes in the same order as the scaled setup.
# NOTE(review): this replaces the global b.RENDERER. The contact-search cells
# further down then render at full resolution, while the trace's observed
# "image" came from the scaled-down frame — re-run the scaled setup first (confirm).
b.setup_renderer(rgbd_original.intrinsics)
b.RENDERER.add_mesh_from_file("toy_plane.ply")
model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models")
for ycb_index in (13, 10):
    # BOP file names are 1-based and zero-padded to six digits.
    mesh_path = os.path.join(model_dir, f"obj_{ycb_index + 1:06d}.ply")
    b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)

cube_path = os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj")
b.RENDERER.add_mesh_from_file(cube_path, scaling_factor=1.0/1000000000.0)

In [None]:
# Render the current trace's poses/ids with the full-resolution renderer set up above.
img = b.RENDERER.render(b.get_poses(trace), b.get_indices(trace))

In [None]:
# Depth channel (index 2) of the full-resolution render, scale capped at 1.5.
b.get_depth_image(img[:,:,2],max=1.5)

In [None]:
# Advance to the next object slot.
# NOTE(review): scattered increment; in a linear run OBJECT_NUMBER is now 6 while the
# trace only has objects 0 and 1.
OBJECT_NUMBER += 1

In [None]:
# Counter used to name the PNGs saved by the random-perturbation cell.
counter = 0

In [None]:
# Perturb the object at `address`: sample random contact params from the
# coarsest grid (centred at zero) and a random mesh id among 0..2, then save
# a render as "<counter>.png".
# NOTE: requires get_depth_image_alternate, which is defined in a later cell.

# Independent subkeys for each random draw (the original reused one key).
key, id_key, param_key = jax.random.split(key, 3)
new_object_idx = jax.random.choice(id_key, 3)
contact_param_grid = contact_param_gridding_schedule[0] + jnp.zeros(3)
contact_param_random = contact_param_grid[jax.random.choice(param_key, contact_param_grid.shape[0]),:]
print(contact_param_random)
# Derive the id address from `address` so both updates target the same object.
# OBJECT_NUMBER has been incremented by several cells since `address` was set
# and no longer names an object present in the trace.
id_address = address.replace("contact_params_", "id_")
trace_ = b.update_address(trace, key, address, contact_param_random)
trace_ = b.update_address(trace_, key, id_address, new_object_idx)
counter += 1
b.get_rgb_image(get_depth_image_alternate(b.get_rendered_image(trace_)[...,2], 1.0)).save(f"{counter}.png")


In [None]:
# Inspect the contact parameters of the perturbed trace.
trace_[address]

In [None]:
# Coarse-to-fine contact search for object 1, using the enumerators built for
# "contact_params_1" in the add-object cell.
for idx, contact_param_deltas in enumerate(contact_param_gridding_schedule):
    contact_param_grid = contact_param_deltas + trace["contact_params_1"]
    scores = enumerators[3](trace, key, contact_param_grid)
    i = jnp.unravel_index(scores.argmax(), scores.shape)
    best_contact_params = contact_param_grid[i]
    trace = enumerators[0](trace, key, best_contact_params)
    b.viz_trace_meshcat(trace)

In [None]:
# Current reconstruction with white-background depth colouring, upscaled 5x.
# NOTE(review): get_depth_image_alternate is defined in a later cell (NameError on a fresh kernel).
b.viz.scale_image(b.get_rgb_image(get_depth_image_alternate(b.get_rendered_image(trace)[...,2], 1.0)),5)

In [None]:
# Enumerators for the next object's contact params, then add it with mesh
# index 1, parented to object 0 (face indices 2/3).
contact_address_2 = "contact_params_2"
enumerators = b.make_enumerator([contact_address_2])
trace = b.add_object_jit(trace, key, 1, 0, 2, 3)
b.viz_trace_meshcat(trace)

In [None]:
# Coarse-to-fine contact search for the object at "contact_params_2".
for idx, contact_param_deltas in enumerate(contact_param_gridding_schedule):
    contact_param_grid = contact_param_deltas + trace["contact_params_2"]
    scores = enumerators[3](trace, key, contact_param_grid)
    i = jnp.unravel_index(scores.argmax(), scores.shape)
    best_contact_params = contact_param_grid[i]
    trace = enumerators[0](trace, key, best_contact_params)
    b.viz_trace_meshcat(trace)

In [None]:
# Enumerators for the next object's contact params, then add it with mesh
# index 0, parented to object 0 (face indices 2/3).
contact_address_3 = "contact_params_3"
enumerators = b.make_enumerator([contact_address_3])
trace = b.add_object_jit(trace, key, 0, 0, 2, 3)
b.viz_trace_meshcat(trace)

In [None]:
# Coarse-to-fine contact search for the object at "contact_params_3".
for idx, contact_param_deltas in enumerate(contact_param_gridding_schedule):
    contact_param_grid = contact_param_deltas + trace["contact_params_3"]
    scores = enumerators[3](trace, key, contact_param_grid)
    i = jnp.unravel_index(scores.argmax(), scores.shape)
    best_contact_params = contact_param_grid[i]
    trace = enumerators[0](trace, key, best_contact_params)
    b.viz_trace_meshcat(trace)

In [None]:
def get_depth_image_alternate(depth, maxval=None):
    """Colourize a depth map via ``b.get_depth_image`` with a white background.

    Pixels at the maximum depth value (treated as far/background) are set to
    NaN before colourizing, and any output pixel whose channels sum to zero
    (pure black) is painted white.

    Args:
        depth: depth map (anything ``jnp.asarray`` accepts; the call sites in
            this notebook pass 2-D arrays).
        maxval: upper end of the colour scale. Defaults to the largest depth
            strictly below the far (maximum) value.

    Returns:
        ``np.ndarray`` copy of the colourized image.
    """
    # Accept NumPy input too: `.at[...]` below exists only on JAX arrays.
    depth = jnp.asarray(depth)
    far = jnp.max(depth)
    nearest = jnp.min(depth)
    # Lower end of the scale skips the global minimum (presumably invalid/zero
    # pixels). Fall back to the minimum itself when nothing lies above it
    # (e.g. a constant map) instead of reducing over an empty array.
    above_nearest = depth[depth > nearest]
    minval = jnp.min(above_nearest) if above_nearest.size > 0 else nearest
    if maxval is None:
        below_far = depth[depth < far]
        maxval = jnp.max(below_far) if below_far.size > 0 else far
    depth = depth.at[depth >= far].set(jnp.nan)
    viz_img = np.array(b.get_depth_image(
       depth, min=minval,  max=maxval
    ))
    # Pure-black pixels (presumably the NaN background) become white.
    viz_img[viz_img.sum(-1) == 0,:] = 255.0
    return viz_img

In [None]:
# Final reconstruction with white-background depth colouring, upscaled 5x.
b.viz.scale_image(b.get_rgb_image(get_depth_image_alternate(b.get_rendered_image(trace)[...,2], 1.0)),5)

In [None]:
# Observed full-resolution depth with the same white-background colouring, for comparison.
b.get_rgb_image(get_depth_image_alternate(jnp.array(rgbd_original.depth),1.0))