In [None]:
import bayes3d as b
import os
import jax.numpy as jnp
import jax
import bayes3d.genjax
import genjax
import matplotlib
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# Initialise the bayes3d visualizer for this session.
b.setup_visualizer()

In [None]:
# Camera model for a 100x100 render target with the principal point at the image centre.
intrinsics = b.Intrinsics(
    height=100, width=100,
    fx=50.0, fy=50.0,
    cx=50.0, cy=50.0,
    near=0.01, far=20.0,
)
b.setup_renderer(intrinsics)

# Register the 21 YCB-Video (BOP) object models, obj_000001.ply .. obj_000021.ply,
# scaled by 1/1000 (presumably mm -> m; confirm against the asset files).
model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models")
meshes = []  # NOTE(review): never populated or read in the visible cells.
for idx in range(1, 22):
    mesh_path = os.path.join(model_dir, f"obj_{idx:06d}.ply")
    b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)

# A cube shrunk by 1e-9 is registered last (it becomes mesh index 21, which the
# scenes below use as "id_0"); presumably an effectively invisible support object.
b.RENDERER.add_mesh_from_file(
    os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"),
    scaling_factor=1.0 / 1000000000.0,
)


In [None]:
# Pose of the root object ("root_pose_0" below): the inverse of a look-at transform
# from position (0, 0.2, 0.05) toward the origin with +z as the up vector.
table_pose = b.t3d.inverse_pose(
    b.t3d.transform_from_pos_target_up(
        jnp.array([0.0, 0.2, 0.05]),  # position
        jnp.zeros(3),                 # target
        jnp.array([0.0, 0.0, 1.0]),   # up
    )
)

# JIT-compiled entry points into the bayes3d genjax scene model.
importance_jit, update_jit = map(jax.jit, (b.genjax.model.importance, b.genjax.model.update))

In [None]:
# del importance_jit
# del update_jit

In [None]:
# Name of model 13, the child object ("id_1") in the scenes below.
# Assumes MODEL_NAMES is ordered like the meshes registered with the renderer; confirm.
b.utils.ycb_loader.MODEL_NAMES[13]

In [None]:
# Box dimensions the renderer stores for mesh 13 (the child object "id_1" below);
# presumably its bounding-box extents.
b.RENDERER.model_box_dims[13]

In [None]:
# One enumerator per contact address "contact_params_0" .. "contact_params_4",
# each built over that address together with the shared "variance" and "outlier_prob".
contact_enumerators = [
    b.genjax.make_enumerator([f"contact_params_{obj}", "variance", "outlier_prob"])
    for obj in range(5)
]
# Enumerator over object 1's contact parameters only.
single_enumerators = b.genjax.make_enumerator(["contact_params_1"])

def c2f_contact_update(trace_, key,  number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID):
    """One coarse-to-fine step over ``contact_params_{number}``.

    Candidates are the trace's current contact params offset by each row of
    ``contact_param_deltas``. Every (candidate, variance, outlier_prob)
    combination is scored with ``contact_enumerators[number]``. The trace is
    then updated to the highest-scoring combination and returned.

    ``number`` indexes a Python list and builds an address string, so it must
    be static under ``jax.jit``.
    """
    enumerator = contact_enumerators[number]
    candidates = trace_[f"contact_params_{number}"] + contact_param_deltas
    scores = enumerator.score_vmap(trace_, key, candidates, VARIANCE_GRID, OUTLIER_GRID)
    # scores has one axis per enumerated quantity: (contact, variance, outlier).
    best_contact, best_variance, best_outlier = jnp.unravel_index(jnp.argmax(scores), scores.shape)
    return enumerator.enum_f(
        trace_,
        key,
        candidates[best_contact],
        VARIANCE_GRID[best_variance],
        OUTLIER_GRID[best_outlier],
    )
# `number` selects a list entry and builds an address string, so it must be a static (Python) argument.
c2f_contact_update_jit = jax.jit(c2f_contact_update, static_argnames=("number",))

In [None]:
# Root PRNG key. The experiment loop derives a new key from it each iteration;
# the interactive callback uses it directly.
key = jax.random.PRNGKey(100)

In [None]:
# Likelihood hyperparameters used during enumeration.
# NOTE(review): start == stop, so VARIANCE_GRID holds three identical values. Every
# candidate is scored three times with no effect on the argmax; a single-element
# array would cut the scoring work by 3x.
VARIANCE_GRID = jnp.linspace(0.0000398, 0.0000398, 3)
OUTLIER_GRID = jnp.array([0.00251])
OUTLIER_VOLUME = 10.0

# Coarse-to-fine schedule over (x, y, angle) contact parameters. Each entry is
# (x/y half-width, angle half-range, number of grid points per axis).
grid_params = [
    (0.3,  jnp.pi,     (11, 11, 11)),
    (0.2,  jnp.pi / 2, (11, 11, 11)),
    (0.1,  jnp.pi / 2, (11, 11, 11)),
    (0.05, jnp.pi / 3, (11, 11, 11)),
    (0.02, jnp.pi,     (5, 5, 51)),
    (0.01, jnp.pi / 5, (11, 11, 11)),
    (0.01, 0.0,        (21, 21, 1)),
    (0.01, 0.0,        (21, 21, 1)),
]
contact_param_gridding_schedule = [
    b.utils.make_translation_grid_enumeration_3d(
        -half_width, -half_width, -half_angle,
        half_width, half_width, half_angle,
        *counts,
    )
    for half_width, half_angle, counts in grid_params
]

# Dense single-shot grid: +/-4 cm in x/y, full +/-pi in angle, 21 x 21 x 300 points.
width = 0.04
ang = jnp.pi
final_contact_param_deltas = b.utils.make_translation_grid_enumeration_3d(
    -width, -width, -ang,
    width, width, ang,
    21, 21, 300,
)

def get_depth_image(image):
    """Visualize a depth map with the colour range capped below the background depth.

    The maximum value in ``image`` is treated as background (presumably pixels at
    the far plane; confirm). ``b.get_depth_image`` is called with ``max`` set to
    the largest depth strictly below that value, so the foreground keeps its
    contrast.

    If every pixel has the same depth there is nothing below the maximum. In that
    case the maximum itself is used rather than failing on an empty reduction.

    Args:
        image: array of (floating-point) depth values.

    Returns:
        Whatever ``b.get_depth_image`` returns for ``image``.
    """
    background = image.max()
    is_foreground = image < background
    # Mask with -inf instead of boolean indexing. Indexing yields an empty array
    # (and a zero-size-reduction error) when all pixels are equal, and its output
    # shape depends on the data.
    foreground_max = jnp.where(is_foreground, image, -jnp.inf).max()
    mval = jnp.where(is_foreground.any(), foreground_max, background)
    return b.get_depth_image(image, max=mval)



In [None]:
# Small batch of synthetic experiments. Each iteration:
#   1. samples a two-object scene;
#   2. recovers object 1's contact parameters by dense grid enumeration;
#   3. saves a figure of the observed depth next to posterior samples of its orientation.
for experiment_iteration in tqdm(range(2)):
    print(key)
    # Derive a new key each iteration so every experiment draws a different scene.
    # NOTE(review): the same `key` is also passed to the enumeration calls below.
    key = jax.random.split(key,1)[0]

    # Sample the ground-truth scene. Object 1 (mesh 13) is attached to object 0
    # (mesh 21, posed at `table_pose`) via face/contact parameters.
    # "contact_params_1" is left unconstrained, so it is sampled.
    weight, gt_trace = importance_jit(key, genjax.choice_map({
        "parent_0": -1,
        "parent_1": 0,
        "id_0": jnp.int32(21),
        "id_1": jnp.int32(13),
        "camera_pose": jnp.eye(4),
        "root_pose_0": table_pose,
        # "contact_params_1": jnp.array([ 0.01630328 ,-0.06595182, -2.946241  ]),
        "face_parent_1": 2,
        "face_child_1": 3,
        "variance": VARIANCE_GRID[0],
        "outlier_prob": OUTLIER_GRID[0],
    }), (
        jnp.arange(2),
        jnp.arange(22),
        jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),
        jnp.array([jnp.array([-0.1, -0.1, -2*jnp.pi]), jnp.array([0.1, 0.1, 2*jnp.pi])]),
        b.RENDERER.model_box_dims, OUTLIER_VOLUME)
    )
    trace = gt_trace

    # Score every (contact, variance, outlier) combination on the dense grid centred
    # on the true contact params. The grid is processed in 55 chunks, presumably to
    # bound the peak memory of each vmapped scoring call.
    contact_param_grid = final_contact_param_deltas + trace["contact_params_1"]
    weights = jnp.concatenate([
        contact_enumerators[1].score_vmap(
            trace,
            key,
            d + trace["contact_params_1"],
            VARIANCE_GRID,
            OUTLIER_GRID
        ) for d in jnp.array_split(final_contact_param_deltas, 55)
    ], axis=0)

    # MAP estimate: move the trace to the highest-scoring grid point.
    i, j, k = jnp.unravel_index(weights.argmax(), weights.shape)
    trace = contact_enumerators[1].enum_f(
        trace, key,
        contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]
    )
    print(trace["contact_params_1"])
    print(trace["variance"])
    print(trace.get_score())

    print(gt_trace["contact_params_1"])
    print(gt_trace["variance"])
    print(gt_trace.get_score())

    # Use the scores as logits and draw 1000 grid points; keep only each draw's
    # contact-parameter index (axis 0). The fixed seed makes the draws depend
    # only on `weights`.
    key2 = jax.random.PRNGKey(0)
    sampled_indices = jax.random.categorical(key2, weights.reshape(-1), shape=(1000,))
    sampled_indices = jnp.unravel_index(sampled_indices, weights.shape)[0]
    sampled_params = contact_param_grid[sampled_indices]
    actual_params = gt_trace["contact_params_1"]

    fig = plt.figure(constrained_layout=True)
    spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=[1, 1],
                            height_ratios=[2])

    # Left panel: observed depth (channel 2 of the trace image).
    ax = fig.add_subplot(spec[0, 0])
    ax.imshow(jnp.array(get_depth_image(trace["image"][...,2])))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title("Observed Depth")

    # Right panel: each contact angle (third component) drawn as a point
    # (-sin, -cos) on the unit circle.
    ax = fig.add_subplot(spec[0, 1])
    ax.set_aspect(1.0)
    circ = plt.Circle((0, 0), radius=1, edgecolor='black', facecolor='None', linestyle="--", linewidth=0.5)
    ax.add_patch(circ)
    ax.set_xlim(-2.0, 2.0)
    ax.set_ylim(-2.0, 2.0)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.scatter(-jnp.sin(sampled_params[:,2]),-jnp.cos(sampled_params[:,2]),label="Posterior Samples", alpha=0.5, s=15)
    ax.scatter(-jnp.sin(actual_params[2]),-jnp.cos(actual_params[2]), color=(1.0, 0.0, 0.0),label="Actual", alpha=0.9, s=10)
    ax.set_title("Posterior on Orientation (top view)")
    ax.legend(fontsize=7)

    fig.savefig(f'{experiment_iteration:05d}.png')
    # Close the figure rather than only clearing it. plt.clf() left every Figure
    # registered with pyplot, leaking one per iteration.
    plt.close(fig)

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, FloatSlider, IntSlider, Button, Output, HBox, VBox, FloatLogSlider

# Output widget that the slider callback clears and writes the current hyperparameter values into.
out = Output(layout={'border': '5px solid black', "height" : '100px'})

def func(x,y,ang,variance, outlier_prob, outlier_volume):
    """Slider callback: build a scene from slider values and show the inferred posterior.

    Args:
        x, y, ang: ground-truth "contact_params_1" of object 1.
        variance, outlier_prob: likelihood hyperparameters. The enumeration uses
            single-element grids holding exactly these values.
        outlier_volume: outlier volume passed to the scene model.

    Prints the MAP and ground-truth contact params, variance and score. Draws the
    observed depth and an orientation-posterior plot, then echoes the
    hyperparameters into the `out` widget.
    """
    VARIANCE_GRID = jnp.array([variance])
    OUTLIER_GRID = jnp.array([outlier_prob])
    OUTLIER_VOLUME = outlier_volume

    def _hide_axes(axis):
        # Hide both tick axes of a subplot.
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)

    # Fully constrained scene: object 1 (mesh 13) attached to object 0 (mesh 21).
    constraints = genjax.choice_map({
        "parent_0": -1,
        "parent_1": 0,
        "id_0": jnp.int32(21),
        "id_1": jnp.int32(13),
        "camera_pose": jnp.eye(4),
        "root_pose_0": table_pose,
        "contact_params_1": jnp.array([x, y, ang]),
        "face_parent_1": 2,
        "face_child_1": 3,
        "variance": VARIANCE_GRID[0],
        "outlier_prob": OUTLIER_GRID[0],
    })
    model_args = (
        jnp.arange(2),
        jnp.arange(22),
        jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),
        jnp.array([jnp.array([-0.1, -0.1, -2 * jnp.pi]), jnp.array([0.1, 0.1, 2 * jnp.pi])]),
        b.RENDERER.model_box_dims,
        OUTLIER_VOLUME,
    )
    _, gt_trace = importance_jit(key, constraints, model_args)

    enumerator = contact_enumerators[1]
    true_contact = gt_trace["contact_params_1"]
    contact_param_grid = final_contact_param_deltas + true_contact

    # Score the dense grid chunk by chunk, then stitch the scores back together.
    chunk_scores = [
        enumerator.score_vmap(gt_trace, key, chunk + true_contact, VARIANCE_GRID, OUTLIER_GRID)
        for chunk in jnp.array_split(final_contact_param_deltas, 55)
    ]
    weights = jnp.concatenate(chunk_scores, axis=0)

    best_contact, best_variance, best_outlier = jnp.unravel_index(weights.argmax(), weights.shape)
    trace = enumerator.enum_f(
        gt_trace, key,
        contact_param_grid[best_contact], VARIANCE_GRID[best_variance], OUTLIER_GRID[best_outlier],
    )

    for value in (
        trace["contact_params_1"], trace["variance"], trace.get_score(),
        gt_trace["contact_params_1"], gt_trace["variance"], gt_trace.get_score(),
    ):
        print(value)

    # Use the scores as logits; keep each draw's contact-parameter index only.
    flat_draws = jax.random.categorical(jax.random.PRNGKey(0), weights.reshape(-1), shape=(1000,))
    sampled_params = contact_param_grid[jnp.unravel_index(flat_draws, weights.shape)[0]]
    actual_params = gt_trace["contact_params_1"]

    fig = plt.figure(constrained_layout=True)
    spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=[1, 1], height_ratios=[2])

    depth_ax = fig.add_subplot(spec[0, 0])
    depth_ax.imshow(jnp.array(get_depth_image(trace["image"][..., 2])))
    _hide_axes(depth_ax)
    depth_ax.set_title("Observed Depth")

    orient_ax = fig.add_subplot(spec[0, 1])
    orient_ax.set_aspect(1.0)
    orient_ax.add_patch(
        plt.Circle((0, 0), radius=1, edgecolor='black', facecolor='None', linestyle="--", linewidth=0.5)
    )
    orient_ax.set_xlim(-2.0, 2.0)
    orient_ax.set_ylim(-2.0, 2.0)
    _hide_axes(orient_ax)
    orient_ax.scatter(-jnp.sin(sampled_params[:, 2]), -jnp.cos(sampled_params[:, 2]),
                      label="Posterior Samples", alpha=0.5, s=15)
    orient_ax.scatter(-jnp.sin(actual_params[2]), -jnp.cos(actual_params[2]),
                      color=(1.0, 0.0, 0.0), label="Actual", alpha=0.9, s=10)
    orient_ax.set_title("Posterior on Orientation (top view)")
    orient_ax.legend(fontsize=7)

    with out:
        out.clear_output()
        for message in (
            f"variance   = {variance}",
            f"outlier_prob = {outlier_prob}",
            f"outlier_volume = {outlier_volume}",
        ):
            display(message)

# Wire each `func` argument to a slider. The FloatLogSlider bounds are base-10
# exponents (e.g. variance spans 1e-9 .. 1e1).
w = interactive(
    func,
    **{
        "x": FloatSlider(min=-0.1, max=0.1, value=0.0, description=" x:"),
        "y": FloatSlider(min=-0.1, max=0.1, value=0.0, description=" y:"),
        "ang": FloatSlider(min=-jnp.pi, max=jnp.pi, value=jnp.pi, description=" ang:"),
        "variance": FloatLogSlider(base=10.0, min=-9, max=1, value=0.000501, description="variance:"),
        "outlier_prob": FloatLogSlider(base=10.0, min=-4, max=0, value=0.631, description="outlier_prob:"),
        "outlier_volume": FloatLogSlider(base=10.0, min=1, max=5, value=10.0, description="outlier_volume:"),
    }
)
display(VBox([w, out]))