In [None]:
import bayes3d as b
import bayes3d.genjax
import joblib
from tqdm import tqdm
import os
import jax.numpy as jnp
import jax
import numpy as np
import genjax
import matplotlib.pyplot as plt

In [None]:
# Initialise the bayes3d visualizer (presumably required by the b.show_* /
# b.viz_trace_meshcat calls later in the notebook).
b.setup_visualizer()

In [None]:
# Small 50x50 pinhole camera with the principal point at the image centre.
intrinsics = b.Intrinsics(
    height=50, width=50,
    fx=100.0, fy=100.0,
    cx=25.0, cy=25.0,
    near=0.01, far=1.0,
)
b.setup_renderer(intrinsics)

model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models")
meshes = []  # NOTE(review): never populated or read in the visible notebook

# YCB-Video models obj_000001.ply .. obj_000021.ply, scaled by 1/1000
# (presumably millimetres -> metres).
for obj_number in range(1, 22):
    ply_file = os.path.join(model_dir, f"obj_{obj_number:06d}.ply")
    b.RENDERER.add_mesh_from_file(ply_file, scaling_factor=1.0 / 1000.0)

# A cube scaled by 1e-9, i.e. effectively invisible. Added last, so it is
# presumably mesh index 21, which the ground-truth traces use as the root
# object ("id_0": 21).
cube_file = os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj")
b.RENDERER.add_mesh_from_file(cube_file, scaling_factor=1e-9)


In [None]:

importance_jit = jax.jit(b.model.importance)

# One enumerator per contact-parameter slot 0..4. Each jointly enumerates that
# slot's contact params together with "variance" and "outlier_prob".
contact_enumerators = [
    b.make_enumerator([f"contact_params_{slot}", "variance", "outlier_prob"])
    for slot in range(5)
]
add_object_jit = jax.jit(b.add_object)


def c2f_contact_update(trace_, key, number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID):
    """Perform one coarse-to-fine step on ``contact_params_{number}``.

    The delta grid is centred on the trace's current contact params. Every
    (contact params, variance, outlier_prob) combination is scored with the
    enumerator's scoring function (element [3]). The trace is then set to the
    best-scoring combination with the enumerator's update function (element
    [0]), and that trace is returned.
    """
    enumerator = contact_enumerators[number]
    score_grid, set_values = enumerator[3], enumerator[0]

    candidates = trace_[f"contact_params_{number}"] + contact_param_deltas
    scores = score_grid(trace_, key, candidates, VARIANCE_GRID, OUTLIER_GRID)

    best_c, best_v, best_o = jnp.unravel_index(jnp.argmax(scores), scores.shape)
    return set_values(
        trace_, key,
        candidates[best_c], VARIANCE_GRID[best_v], OUTLIER_GRID[best_o],
    )


# `number` selects a Python-side enumerator, so it must be static under jit.
c2f_contact_update_jit = jax.jit(c2f_contact_update, static_argnames=("number",))

In [None]:
OUTLIER_VOLUME = 1.0

# Candidate values for the likelihood hyperparameters.
VARIANCE_GRID = jnp.array([0.0001, 0.001, 0.01])
OUTLIER_GRID = jnp.array([0.00001, 0.0001, 0.001])

# Coarse-to-fine schedule entries: (half-width for the first two contact
# coordinates, half-range for the third, presumably angular coordinate,
# number of grid points per coordinate).
grid_params = [
    (0.3, jnp.pi, (11, 11, 11)),
    (0.2, jnp.pi, (11, 11, 11)),
    (0.1, jnp.pi, (11, 11, 11)),
    (0.05, jnp.pi / 3, (11, 11, 11)),
    (0.02, jnp.pi, (5, 5, 51)),
    (0.01, jnp.pi / 5, (11, 11, 11)),
    (0.01, 0.0, (21, 21, 1)),
    (0.05, 0.0, (21, 21, 1)),
]

contact_param_gridding_schedule = []
for half_width, half_angle, grid_shape in grid_params:
    contact_param_gridding_schedule.append(
        b.utils.make_translation_grid_enumeration_3d(
            -half_width, -half_width, -half_angle,
            half_width, half_width, half_angle,
            *grid_shape,
        )
    )


In [None]:
# With HIERARCHICAL_BAYES, enumerate the full hyperparameter grids. Otherwise,
# pin variance/outlier_prob to the single grid entries chosen below.
V_VARIANT = 0
O_VARIANT = 0
HIERARCHICAL_BAYES = True

V_GRID, O_GRID = (
    (VARIANCE_GRID, OUTLIER_GRID)
    if HIERARCHICAL_BAYES
    else (jnp.array([VARIANCE_GRID[V_VARIANT]]), jnp.array([OUTLIER_GRID[O_VARIANT]]))
)

print(V_GRID, O_GRID)

In [None]:
# Fixed PRNG seed so sampled scenes and proposals are reproducible.
key = jax.random.PRNGKey(502)

In [None]:
# Base camera: placed at (0, 0.4, 0.2), looking at the origin, with +z as "up".
camera_pose = b.t3d.transform_from_pos_target_up(
    jnp.array([0.0, 0.4, 0.2]),  # position
    jnp.array([0.0, 0.0, 0.0]),  # target
    jnp.array([0.0, 0.0, 1.0]),  # up
)

# Orbit the base camera about the world z-axis. This gives 69 evenly spaced
# angles in [-pi, pi), dropping the duplicate +pi endpoint.
z_axis = jnp.array([0.0, 0.0, 1.0])
orbit_angles = jnp.linspace(-jnp.pi, jnp.pi, 70)[:-1]
camera_poses = jnp.array(
    [b.t3d.transform_from_axis_angle(z_axis, theta) @ camera_pose for theta in orbit_angles]
)

In [None]:
# Draw every orbit camera pose in the visualizer, labelled by its index.
for pose_index, orbit_pose in enumerate(camera_poses):
    b.show_pose(str(pose_index), orbit_pose)

In [None]:
# Enumerator over the "camera_pose" address. Element [3] is used later to
# score a batch of candidate camera poses against a trace.
camera_pose_enumerators = b.make_enumerator(["camera_pose"])

In [None]:
# Derive a fresh PRNG key from `x`.
# Fix: the original `jax.random.split(x, 1)[1]` indexed past the end of a
# length-1 split and only worked because JAX clamps out-of-bounds indices to
# the last element. Indexing [0] explicitly gives the same key stream without
# relying on that clamping.
split_key = jax.jit(lambda x: jax.random.split(x, 1)[0])

In [None]:
# Translation proposal grid: 11 x 11 x 11 offsets spanning [-w, w] per axis.
w = 1.0
translation_deltas = b.utils.make_translation_grid_enumeration(-w, -w, -w, w, w, w, 11, 11, 11)

# Batched sampler of random pose perturbations: one zero-mean sample per key,
# with the variance `v` and concentration `c` shared across the batch (the
# name suggests a Gaussian-translation / von Mises-Fisher-rotation
# distribution — confirm in bayes3d).
get_rotation_deltas = jax.jit(
    jax.vmap(b.distributions.gaussian_vmf_zero_mean, in_axes=(0, None, None))
)

In [None]:

def likelihoood(obs, render):
    """Return the negative Frobenius norm of the per-pixel relative depth error.

    Only channel 2 (the depth / z channel) of ``obs`` and ``render`` is used.
    The error at each pixel is (obs_z - render_z) / obs_z. The score is 0 for a
    perfect match and more negative otherwise.

    NOTE(review): pixels with obs_z == 0 produce inf/nan. Confirm the
    observation never contains zero depth.
    """
    observed_depth = obs[:, :, 2]
    relative_error = (observed_depth - render[:, :, 2]) / observed_depth
    return -jnp.linalg.norm(relative_error)

def render_at_camera_pose(camera_pose, poses, indices):
    """Render the scene as seen from ``camera_pose``.

    The object ``poses`` are re-expressed relative to the camera by
    left-multiplying with the inverse camera pose, then rendered with
    ``indices``. Only the first three output channels are kept; the notebook
    treats these as an (x, y, z) point image.
    """
    poses_in_camera_frame = b.inverse_pose(camera_pose) @ poses
    return b.RENDERER.render(poses_in_camera_frame, indices)[:, :, :3]
    
def score_camera_pose_(camera_pose, obs, poses, indices):
    """Score a candidate camera pose by re-rendering the scene from it and
    comparing the render with ``obs`` using ``likelihoood`` (higher is better)."""
    return likelihoood(obs, render_at_camera_pose(camera_pose, poses, indices))


score_camera_pose = jax.jit(score_camera_pose_)
# Batched over a leading axis of camera poses; obs, poses and indices are shared.
score_camera_pose_parallel = jax.jit(
    jax.vmap(score_camera_pose, in_axes=(0, None, None, None))
)

def update_pose_estimate(current_pose_estimate, key, obs, v, c, poses, indices,
                         num_rounds=3, num_samples=10000):
    """Refine a camera pose by greedy stochastic hill climbing.

    Each round does the following:
    - draws ``num_samples`` random pose perturbations via
      ``get_rotation_deltas`` (parameters ``v`` and ``c``);
    - right-multiplies them onto the current estimate;
    - scores every proposal with ``score_camera_pose_parallel``;
    - moves to the best proposal only if it beats the current estimate's score.

    Args:
        current_pose_estimate: current camera pose (assumed 4x4 — TODO confirm).
        key: PRNG key; a distinct subkey is used for every round.
        obs: observed image compared against renders.
        v, c: parameters forwarded to ``get_rotation_deltas``.
        poses, indices: scene object poses and mesh indices used for rendering.
        num_rounds: number of accept/reject rounds (static under jit).
        num_samples: number of proposals per round (static under jit).

    Returns:
        The refined camera pose.
    """
    # Bug fix: the original reused the same `key` in every round, so all
    # rounds drew the identical set of perturbations.
    round_keys = jax.random.split(key, num_rounds)
    for r in range(num_rounds):
        deltas = get_rotation_deltas(jax.random.split(round_keys[r], num_samples), v, c)
        proposals = jnp.einsum("ij,ajk->aik", current_pose_estimate, deltas)
        weights = score_camera_pose_parallel(proposals, obs, poses, indices)
        accept = weights.max() > score_camera_pose(current_pose_estimate, obs, poses, indices)
        # Select with jnp.where rather than arithmetic blending, so a NaN in
        # the rejected branch cannot leak into the result (0 * NaN == NaN).
        current_pose_estimate = jnp.where(
            accept, proposals[jnp.argmax(weights)], current_pose_estimate
        )
    return current_pose_estimate


# Round/sample counts determine array shapes, so they must be static under jit.
update_pose_estimate_jit = jax.jit(
    update_pose_estimate, static_argnames=("num_rounds", "num_samples")
)

In [None]:
key = jax.random.split(key, 1)[0]


def _gt_constraints(camera_pose_value):
    """Build the choice-map constraints for one ground-truth scene.

    - Object 0 is the root: parent -1, mesh id 21, identity pose.
    - Objects 1-3 have parent 0, with face_parent 2 and face_child 3.
    - variance and outlier_prob are fixed.
    - Only the camera pose varies between scenes.
    """
    return genjax.choice_map({
        "parent_0": -1,
        "parent_1": 0,
        "parent_2": 0,
        "parent_3": 0,
        "id_0": jnp.int32(21),
        "camera_pose": camera_pose_value,
        "root_pose_0": jnp.eye(4),
        "face_parent_1": 2,
        "face_parent_2": 2,
        "face_parent_3": 2,
        "face_child_1": 3,
        "face_child_2": 3,
        "face_child_3": 3,
        "variance": 0.0001,
        "outlier_prob": 0.1,
    })


# Model arguments shared by every ground-truth trace.
gt_model_args = (
    jnp.arange(4),
    jnp.arange(22),
    jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),
    jnp.array([jnp.array([-0.1, -0.1, -1 * jnp.pi]), jnp.array([0.1, 0.1, 1 * jnp.pi])]),
    b.RENDERER.model_box_dims,
    OUTLIER_VOLUME,
    1.0,
)

# One ground-truth trace per orbit camera. importance_jit returns a pair whose
# second element is the trace.
gt_traces = []
for gt_camera_pose in camera_poses:
    _, gt_trace_for_pose = importance_jit(key, _gt_constraints(gt_camera_pose), gt_model_args)
    gt_traces.append(gt_trace_for_pose)

poses = b.get_poses(gt_traces[0])
indices = b.get_indices(gt_traces[0])

In [None]:
# Initial camera-pose guess: the first orbit camera.
current_pose_estimate = camera_poses[0]

In [None]:
# Select ground-truth frame T and render the observation from its camera.
# Then show ground truth vs. the current estimate.
T = 4
gt_pose = camera_poses[T]
obs = render_at_camera_pose(gt_pose, poses, indices)
observed_depth = obs[:, :, 2]

b.clear()
b.show_pose("gt", gt_pose)
b.show_pose("pred", current_pose_estimate, size=0.07)
b.show_cloud("cloud", obs.reshape(-1, 3))
b.get_depth_image(observed_depth)

In [None]:
# Coarse-to-fine refinement schedule of (variance, concentration) per stage:
# variance decreases and concentration increases from stage to stage.
refinement_schedule = [
    (0.05, 2029.293),
    (0.01, 5029.293),
    (0.01, 10029.293),
    (0.001, 20029.293),
    (0.0001, 50029.293),
]
for stage, (stage_variance, stage_concentration) in enumerate(refinement_schedule):
    if stage > 0:
        key = split_key(key)  # fresh key for every stage after the first
    current_pose_estimate = update_pose_estimate_jit(
        current_pose_estimate, key, obs, stage_variance, stage_concentration, poses, indices
    )

b.show_pose("pred", current_pose_estimate, size=0.07)
print(score_camera_pose(current_pose_estimate, obs, poses, indices))

In [None]:
# Re-render from the refined estimate and overlay it on the observation.
reconstruction = render_at_camera_pose(current_pose_estimate, poses, indices)
b.show_cloud("cloud", obs.reshape(-1, 3))
b.show_cloud("reconstruction", reconstruction.reshape(-1, 3), color=b.RED)

# Per-pixel depth residual: reconstruction minus observation.
depth_residual = reconstruction[:, :, 2] - obs[:, :, 2]
plt.matshow(depth_residual)
plt.colorbar()

In [None]:
# Element-wise difference between the refined estimate and the ground-truth pose.
current_pose_estimate - gt_pose

In [None]:
# Log-density of one pose delta evaluated on a (variance x concentration) grid.
# The inner vmap sweeps concentrations (4th argument) and the outer vmap
# sweeps variances (3rd argument), so scores[variance_idx, concentration_idx].
_vmf_over_concentrations = jax.vmap(
    b.distributions.gaussian_vmf_logpdf_jit, in_axes=(None, None, None, 0)
)
vmf_score = jax.jit(jax.vmap(_vmf_over_concentrations, in_axes=(None, None, 0, None)))

In [None]:
# Grid-search the (variance, concentration) that best explains the residual
# transform between the estimate and ground truth.
delta = b.inverse_pose(current_pose_estimate) @ gt_pose
variances = jnp.linspace(0.001, 0.1, 100)
concentrations = jnp.linspace(2000.0, 100000.0, 200)
scores = vmf_score(delta, jnp.eye(4), variances, concentrations)

best_flat_index = jnp.argmax(scores)
i, j = jnp.unravel_index(best_flat_index, scores.shape)
print(variances[i], concentrations[j])

plt.matshow(scores)
plt.colorbar()

In [None]:
# One manual greedy step, written to `current_camera_pose` so the refined
# estimate is left untouched for inspection.
# Bug fix: the original called einsum with a single operand and omitted the
# `poses`/`indices` arguments, so the cell raised on every run.
# NOTE(review): the perturbation parameters reuse the final refinement stage
# (0.0001, 50029.293); adjust as needed.
rotation_deltas = get_rotation_deltas(jax.random.split(key, 10000), 0.0001, 50029.293)
proposals = jnp.einsum("ij,ajk->aik", current_pose_estimate, rotation_deltas)
weights = score_camera_pose_parallel(proposals, obs, poses, indices)
current_score = score_camera_pose(current_pose_estimate, obs, poses, indices)
# `bool` shadows the builtin; the name is kept because a later cell displays it.
bool = weights.max() > current_score
current_camera_pose = jnp.where(bool, proposals[jnp.argmax(weights)], current_pose_estimate)
print(current_pose_estimate)
print(weights.max())
current_score

In [None]:
# Score of the best proposal (equal to weights.max()).
weights[jnp.argmax(weights)]

In [None]:
# Element-wise difference between the current estimate and the best proposal.
current_pose_estimate - proposals[jnp.argmax(weights)]

In [None]:
# Score of the current estimate against the observation.
# Bug fix: score_camera_pose also requires the scene `poses` and `indices`;
# the original two-argument call raised a TypeError.
score_camera_pose(current_pose_estimate, obs, poses, indices)

In [None]:
# Whether the best proposal beat the current score.
# NOTE(review): this global `bool` shadows the Python builtin.
bool

In [None]:
# Depth image of every ground-truth frame, assembled into an animated GIF.
viz_images = []
for gt_frame in gt_traces:
    frame_depth = gt_frame["image"][:, :, 2]
    viz_images.append(b.get_depth_image(frame_depth))
b.make_gif_from_pil_images(viz_images, "sweep.gif")

In [None]:
# Reset the camera-pose guess to the first orbit camera.
current_camera_pose = camera_poses[0]

In [None]:
# Put the guessed camera pose into ground-truth frame T's trace and compare
# the trace scores before and after the swap.
T = 2
gt_trace_T = gt_traces[T]
gt_pose = gt_trace_T["camera_pose"]
trace = b.update_address(gt_trace_T, key, "camera_pose", current_camera_pose)

print(gt_trace_T.get_score())
print(trace.get_score())

b.viz_trace_meshcat(trace)
b.show_pose("gt_pose", gt_pose, size=0.05)

In [None]:
# Iteratively refine the trace's camera pose against frame T.
# Bug fix: update_pose_estimate_jit takes and returns a camera *pose* (plus
# obs, v, c, poses, indices), not a trace; the original call
# `update_pose_estimate_jit(trace, key)` raised a TypeError. The fix refines
# trace["camera_pose"] and writes it back with b.update_address.
# NOTE(review): v/c reuse the second refinement stage (0.01, 5029.293); tune as needed.
REFINE_VARIANCE, REFINE_CONCENTRATION = 0.01, 5029.293
frame_poses = b.get_poses(gt_traces[T])
frame_indices = b.get_indices(gt_traces[T])
frame_obs = render_at_camera_pose(gt_traces[T]["camera_pose"], frame_poses, frame_indices)

for _ in range(100):
    key = split_key(key)
    print(key)
    print(trace.get_score())
    refined_pose = update_pose_estimate_jit(
        trace["camera_pose"], key, frame_obs,
        REFINE_VARIANCE, REFINE_CONCENTRATION,
        frame_poses, frame_indices,
    )
    trace = b.update_address(trace, key, "camera_pose", refined_pose)
    b.show_pose("gt_pose", gt_traces[T]["camera_pose"], size=0.05)
    b.show_pose("estimated_pose", trace["camera_pose"], size=0.07)
print(trace.get_score())


In [None]:
# Greedy translation step. Every translation perturbation of the current
# camera pose is scored with the camera-pose enumerator, and the best one is
# kept only if it beats the current trace score.
current_pose = trace["camera_pose"]
proposals = jnp.einsum("ij,ajk->aik", current_pose, translation_deltas)
weights = camera_pose_enumerators[3](trace, key, proposals)

bool = weights.max() > trace.get_score()
best_proposal = proposals[jnp.argmax(weights)]
current_camera_pose = current_pose * (1.0 - bool) + bool * best_proposal
trace = b.update_address(trace, key, "camera_pose", current_camera_pose)

In [None]:
# Best enumerator score from the preceding step.
weights.max()

In [None]:
# Draw every candidate camera pose, labelled by its index.
for proposal_index, proposal_pose in enumerate(proposals):
    b.show_pose(str(proposal_index), proposal_pose)

In [None]:
# Clear the visualizer (presumably removes everything drawn so far).
b.clear()

In [None]:
# Manually shift the camera by 0.05 along its own x-axis (right-multiplication
# applies the offset in the camera's frame), then re-score.
nudge = b.transform_from_pos(jnp.array([0.05, 0.0, 0.0]))
current_camera_pose = trace["camera_pose"] @ nudge
trace = b.update_address(trace, key, "camera_pose", current_camera_pose)
print(trace.get_score())

In [None]:
# Residual transform gt @ inverse(estimate): the identity when the estimate matches.
gt_traces[T]["camera_pose"] @ b.inverse_pose(trace["camera_pose"])

In [None]:
# Greedy rotation step: score random pose perturbations with the camera-pose
# enumerator and jump to the best one (unconditionally, as in the original).
# Bug fix: `rotation_deltas` only existed as a local inside
# update_pose_estimate, so this cell raised NameError on a fresh kernel.
# The perturbations are now sampled here explicitly.
# NOTE(review): v/c reuse the second refinement stage (0.01, 5029.293); tune as needed.
rotation_deltas = get_rotation_deltas(jax.random.split(key, 10000), 0.01, 5029.293)
proposals = jnp.einsum("ij,ajk->aik", trace["camera_pose"], rotation_deltas)
weights = camera_pose_enumerators[3](trace, key, proposals)
print(weights.shape)
current_camera_pose = proposals[jnp.argmax(weights)]
trace = b.update_address(trace, key, "camera_pose", current_camera_pose)
print(trace.get_score())
b.viz_trace_meshcat(trace)
b.show_pose("gt_pose", gt_traces[T]["camera_pose"], size=0.05)

In [None]:
# Element-wise difference between the chosen camera pose and ground truth.
current_camera_pose - gt_traces[T]["camera_pose"]

In [None]:
# Depth image of frame T's scene rendered from orbit camera 10.
# Bug fix: the original referenced the undefined names `camera_pose_enumerator`
# and `gt_trace`. The enumerator tuple is `camera_pose_enumerators`. Element
# [0] is assumed to set the address and return the updated trace, as
# contact_enumerators[...][0] does in c2f_contact_update.
b.get_depth_image(
    camera_pose_enumerators[0](gt_traces[T], key, camera_poses[10])["image"][:, :, 2]
)

In [None]:
# import imageio
# images = [imageio.imread(f'mug_results/{experiment_iteration:05d}.png') for experiment_iteration in tqdm(range(50))]
# imageio.mimsave("mug_results.gif", images, 'GIF', duration=1200)

In [None]:
T = 1
# Bug fix: the original line was an unterminated call (SyntaxError), and it
# omitted the `key` argument that b.update_address takes everywhere else.
# NOTE(review): the assumed intent is to move `trace` to frame T's camera pose.
new_image = b.update_address(trace, key, "camera_pose", camera_poses[T])

In [None]:

# YCB object to search for in the single-object experiment below.
OBJECT_INDEX = 2
object_name = b.utils.ycb_loader.MODEL_NAMES[OBJECT_INDEX]
print(f"Searching for object index {OBJECT_INDEX} {object_name}")

In [None]:
# Planar grid: 50 x 50 points over [-bounds, bounds] in the first two
# coordinates, with the third coordinate fixed at 0 (a single grid point).
# NOTE(review): `grid` is not used anywhere in the visible notebook.
bounds = 1.5
grid_lower = (-bounds, -bounds, -0.0)
grid_upper = (bounds, bounds, 0.0)
grid_shape = (50, 50, 1)
grid = b.utils.make_translation_grid_enumeration_3d(*grid_lower, *grid_upper, *grid_shape)

In [None]:
# Bug fix: `gt_trace` was never defined (only the list `gt_traces`), so this
# cell raised NameError on a fresh kernel. Use ground-truth frame T.
gt_trace = gt_traces[T]

# Rebuild the trace from the same choices and trailing arguments, but with the
# first model argument reduced from arange(4) to arange(1) (presumably a
# single object).
_, trace = importance_jit(
    key, gt_trace.get_choices(),
    (jnp.arange(1), jnp.arange(22), *gt_trace.get_args()[2:]),
)

# Add OBJECT_INDEX (parent 0 with faces 2/3, matching the ground-truth
# constraints) and refine its contact params coarse-to-fine, recording the
# trace after every stage.
path = []
obj_id = OBJECT_INDEX
trace_ = add_object_jit(trace, key, obj_id, 0, 2, 3)
number = b.get_contact_params(trace_).shape[0] - 1
path.append(trace_)
for stage_grid in contact_param_gridding_schedule:
    trace_ = c2f_contact_update_jit(trace_, key, number, stage_grid, V_GRID, O_GRID)
    path.append(trace_)

b.viz_trace_meshcat(trace_)

In [None]:
# Re-display the most recently refined trace.
b.viz_trace_meshcat(trace_)

In [None]:
NUM_SEARCH_ROUNDS = 3


def _c2f_search_object(base_trace, search_key, object_id):
    """Add ``object_id`` to ``base_trace`` (parent 0, faces 2/3) and refine
    its contact params through every stage of ``contact_param_gridding_schedule``.

    Returns the traces after insertion and after each stage; the last element
    is the fully refined trace.
    """
    candidate = add_object_jit(base_trace, search_key, object_id, 0, 2, 3)
    slot = b.get_contact_params(candidate).shape[0] - 1
    history = [candidate]
    for stage_grid in contact_param_gridding_schedule:
        candidate = c2f_contact_update_jit(candidate, search_key, slot, stage_grid, V_GRID, O_GRID)
        history.append(candidate)
    return history


# Greedy sequential search. Each round does the following:
# - try every mesh except the last one (the tiny cube added last at setup);
# - keep the best-scoring refined trace;
# - build on that trace in the next round.
all_all_paths = []
for _ in range(NUM_SEARCH_ROUNDS):
    num_candidates = len(b.RENDERER.meshes) - 1
    all_paths = [_c2f_search_object(trace, key, obj_id) for obj_id in tqdm(range(num_candidates))]
    all_all_paths.append(all_paths)

    scores = jnp.array([candidate_path[-1].get_score() for candidate_path in all_paths])
    print(scores)
    normalized_scores = b.utils.normalize_log_scores(scores)  # not used below; kept for inspection
    trace = all_paths[int(jnp.argmax(scores))][-1]
    b.viz_trace_meshcat(trace)

In [None]:
# Compare the mesh indices in ground-truth frame T with the search result.
# Bug fix: the notebook never defines `gt_trace` (only the list `gt_traces`),
# so this cell raised NameError on a fresh kernel. Index the list explicitly.
print(b.get_indices(gt_traces[T]))
print(b.get_indices(trace))