In [5]:
import numpy as np
import jax.numpy as jnp
import jax
import bayes3d as b
import time
from PIL import Image
from scipy.spatial.transform import Rotation as R
import matplotlib.pyplot as plt
import cv2
import trimesh
import os
import glob
import bayes3d.neural
import pickle
# Can be helpful for debugging:
# jax.config.update('jax_enable_checks', True) 
from bayes3d.neural.segmentation import carvekit_get_foreground_mask
import genjax

In [6]:
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7007/static/


In [7]:
# cloud = np.load("plane_pc_video_capture.npy")
# b.show_cloud("1", cloud)

In [8]:
importance_jit = jax.jit(b.model.importance)
key = jax.random.PRNGKey(10)
enumerators = b.make_enumerator([f"contact_params_1"])

In [9]:
grid_params = [
    (0.5, jnp.pi, (11,11,21)), (0.4, jnp.pi/2, (11,11,21)), (0.1, jnp.pi/2, (11,11,11)),
    (0.05, jnp.pi/3, (11,11,11)), (0.02, jnp.pi/3, (5,5,51)), (0.01, jnp.pi/5, (11,11,11)), (0.1, 0.0, (21,21,1)),(0.1, 0.0, (21,21,1))
]
contact_param_gridding_schedule = [
    b.utils.make_translation_grid_enumeration_3d(
        -x, -x, -ang,
        x, x, ang,
        *nums
    )
    for (x,ang,nums) in grid_params
]

In [77]:
# paths = glob.glob(
#     # "panda_dataset/*.pkl"
#     "panda_scans_v5/*.pkl"
# )
# all_data = pickle.load(open(paths[2], "rb"))
# IDX = 0
# data = all_data[IDX]

paths = glob.glob(
    # "panda_dataset/*.pkl"
    "/home/nishadgothoskar/*.pkl"
)
paths = sorted(paths)
print(paths)
input_path = paths[-1]
output_path = input_path + "result.pkl"
if os.path.exists(output_path):
    print("result already exists for this input.")
all_data = pickle.load(open(input_path, "rb"))
data = all_data[0]

['/home/nishadgothoskar/1695309847.223737.pkl', '/home/nishadgothoskar/1695309847.223737.pklresult.pkl', '/home/nishadgothoskar/1695309957.3550975.pkl', '/home/nishadgothoskar/1695309957.3550975.pklresult.pkl', '/home/nishadgothoskar/1695310026.53716.pkl', '/home/nishadgothoskar/1695310026.53716.pklresult.pkl', '/home/nishadgothoskar/1695310090.890176.pkl', '/home/nishadgothoskar/1695310090.890176.pklresult.pkl', '/home/nishadgothoskar/1695310137.9889367.pkl', '/home/nishadgothoskar/1695310137.9889367.pklresult.pkl', '/home/nishadgothoskar/1695310331.5396512.pkl', '/home/nishadgothoskar/1695310331.5396512.pklresult.pkl', '/home/nishadgothoskar/1695310465.8956923.pkl']


In [78]:
K = data["camera_image"]['camera_matrix'][0]
rgb = data["camera_image"]['rgbPixels']
depth = data["camera_image"]['depthPixels']
camera_pose = data["camera_image"]['camera_pose']
camera_pose = b.t3d.pybullet_pose_to_transform(camera_pose)
fx, fy, cx, cy = K[0,0],K[1,1],K[0,2],K[1,2]
h,w = depth.shape
near = 0.001
rgbd_original = b.RGBD(rgb, depth, camera_pose, b.Intrinsics(h,w,fx,fy,cx,cy,0.001,10000.0))

In [79]:
scaling_factor = 0.2
rgbd_scaled_down = b.RGBD.scale_rgbd(rgbd_original, scaling_factor)
# b.get_rgb_image(rgbd_original.rgb)

In [80]:
if b.RENDERER is None:
    b.setup_renderer(rgbd_scaled_down.intrinsics)
    # b.RENDERER.add_mesh_from_file("toy_final.ply")
    cloud = np.load("plane_pc_video_capture.npy")
    mesh  = b.utils.make_voxel_mesh_from_point_cloud(cloud, 0.007)
    b.show_trimesh("2", mesh)
    b.RENDERER.add_mesh(mesh)

    b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), scaling_factor=1.0/1000000000.0)

   

In [81]:
plane_pose, plane_dims = b.utils.infer_table_plane(
    b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics), jnp.eye(4), rgbd_scaled_down.intrinsics, 
    ransac_threshold=0.001, inlier_threshold=0.001, segmentation_threshold=0.1
)
plane_pose = camera_pose @ plane_pose

In [82]:

table_pose = b.inverse_pose(camera_pose) @ plane_pose
mask = b.utils.scale(carvekit_get_foreground_mask(rgbd_original)*1.0, scaling_factor)
observed_depth = (rgbd_scaled_down.depth * mask) + (1.0 - mask)* rgbd_scaled_down.intrinsics.far
b.clear()
b.show_cloud("1", b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics).reshape(-1,3))
b.show_pose("table", table_pose)

weight, trace = importance_jit(key, genjax.choice_map({
    "parent_0": -1,
    "parent_1": 0,
    "id_0": jnp.int32(1),
    "id_1": jnp.int32(0),
    "camera_pose": jnp.eye(4),
    "root_pose_0": table_pose,
    "face_parent_1": 2,
    "face_child_1": 3,
    "image": b.unproject_depth(observed_depth, rgbd_scaled_down.intrinsics),
    "variance": 0.0001,
    "outlier_prob": 0.0001,
    "contact_params_1": jnp.array([0.0, 0.0, jnp.pi/2])
}), (
    jnp.arange(2),
    jnp.arange(22),
    jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),
    jnp.array([jnp.array([-5.2, -5.2, -15*jnp.pi]), jnp.array([5.2, 5.2, 15*jnp.pi])]),
    b.RENDERER.model_box_dims, 1.0, 1.0)
)
# b.viz_trace_meshcat(trace)
print(trace.get_score())
for idx in range(len(contact_param_gridding_schedule)):
    contact_param_deltas = contact_param_gridding_schedule[idx]
    contact_param_grid = contact_param_deltas + trace[f"contact_params_1"]
    scores = enumerators[3](trace, key, contact_param_grid)
    i= jnp.unravel_index(scores.argmax(), scores.shape)
    trace = enumerators[0](
        trace, key,
        contact_param_grid[i]
    )
    print(trace["contact_params_1"])

-3828.0261
[ 0.         -0.19999999  1.5707964 ]
[ 1.4901161e-08 -1.9999997e-01  1.5707964e+00]
[-0.01999998 -0.23999998  1.5707964 ]
[-0.01999998 -0.23999998  1.5707964 ]
[-0.02999998 -0.23999998  1.69646   ]
[-0.02599998 -0.23799998  1.69646   ]
[-0.02599998 -0.23799998  1.69646   ]
[-0.02599997 -0.23799998  1.69646   ]


In [83]:
b.viz_trace_meshcat(trace)
b.show_pose("table", table_pose)


In [84]:
output_pose = b.get_poses(trace)[1]
output_pose_pybullet = b.transform_to_pybullet_pose(output_pose)

with open(output_path, 'wb') as handle:
    pickle.dump({"pose": [list(np.array(i)) for i in output_pose_pybullet]}, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
output

In [20]:
[list(np.array(i)) for i in output_pose_pybullet]

[[0.032561865, 0.11932532, 0.54254884],
 [0.9062123, 0.1890017, -0.057249494, 0.37338343]]

In [16]:
bbox = b.RENDERER.model_box_dims[0]


In [15]:
output_pose = b.get_poses(trace)[1]
b.transform_to_pybullet_pose(output_pose)

(Array([ 0.18599698, -0.05460909,  0.829901  ], dtype=float32),
 Array([ 0.59193957, -0.56721324,  0.37157702,  0.43413743], dtype=float32))

In [None]:
trace = b.update_address(trace, key, "contact_params_1", jnp.array([0.0, 0.0, jnp.pi]) + trace["contact_params_1"])
b.viz_trace_meshcat(trace)

In [106]:

print(trace.get_score())
for idx in range(len(contact_param_gridding_schedule)):
    contact_param_deltas = contact_param_gridding_schedule[idx]
    contact_param_grid = contact_param_deltas + trace[f"contact_params_1"]
    scores = enumerators[3](trace, key, contact_param_grid)
    i= jnp.unravel_index(scores.argmax(), scores.shape)
    trace = enumerators[0](
        trace, key,
        contact_param_grid[i]
    )
    print(trace["contact_params_1"])
    b.viz_trace_meshcat(trace)

-40490.67
[-1.78      -1.78      -7.4351025]
[-2.18     -2.18     -9.005898]
[ -2.2        -2.2       -10.5766945]
[ -2.2       -2.2      -11.623892]
[ -2.2       -2.2      -12.671089]
[ -2.2       -2.2      -13.299408]
[ -2.2       -2.2      -13.299408]
[ -2.2       -2.2      -13.299408]


In [98]:
trace["contact_params_1"]

Array([ 0.202     , -0.17199998,  2.0839233 ], dtype=float32)

In [20]:
x = 0.01
ang = 0.0
nums = (21,21,1)
contact_param_gridding_schedule = [
     b.utils.make_translation_grid_enumeration_3d(
        -x, -x, -ang,
        x, x, ang,
        *nums
    )
    for _ in range(5)
]

In [21]:
for idx in range(len(contact_param_gridding_schedule)):
    contact_param_deltas = contact_param_gridding_schedule[idx]
    contact_param_grid = contact_param_deltas + trace[f"contact_params_1"]
    scores = enumerators[3](trace, key, contact_param_grid)
    i= jnp.unravel_index(scores.argmax(), scores.shape)
    trace = enumerators[0](
        trace, key,
        contact_param_grid[i]
    )
    print(trace["contact_params_1"])
    b.viz_trace_meshcat(trace)

[0.01900001 0.03300001 0.62831867]
[0.01900001 0.03300001 0.62831867]
[0.01900001 0.03300001 0.62831867]
[0.01900001 0.03300001 0.62831867]
[0.01900001 0.03300001 0.62831867]


In [86]:
b.RENDERER.model_box_dims

Array([[2.87e-01, 2.87e-01, 1.57e-01],
       [1.00e-09, 1.00e-09, 1.00e-09]], dtype=float32)

In [87]:
b.clear()

In [155]:
b.show_trimesh("1", b.RENDERER.meshes[0])

In [22]:
cloud = np.load("plane_pc_video_capture_hi_res.npy")
b.clear()
b.show_cloud("1", cloud)

In [35]:
cloud_voxelized = b.utils.voxelize(cloud, 0.007)

In [36]:
cloud_voxelized.shape

(3447, 3)