In [1]:
import bayes3d as b
import bayes3d.genjax
import joblib
from tqdm import tqdm
import os
import jax.numpy as jnp
import jax
import numpy as np

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Start the bayes3d visualizer; the URL to open is printed in this cell's output.
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7012/static/


In [3]:
# Camera model handed to the renderer: a 100x100 image with the principal
# point at its centre and a near/far clipping range of 0.01 to 20.0.
intrinsics = b.Intrinsics(
    height=100, width=100,
    fx=500.0, fy=500.0,
    cx=50.0, cy=50.0,
    near=0.01, far=20.0,
)
b.setup_renderer(intrinsics)

# Register the 21 YCB-V meshes obj_000001.ply .. obj_000021.ply, in id order.
# NOTE(review): scaling_factor=1/1000 presumably converts millimetres to metres -- confirm.
model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models")
meshes = []  # never filled in the visible cells; kept because it is a kernel global
for idx in range(1, 22):
    mesh_name = "obj_" + str(idx).zfill(6) + ".ply"
    mesh_path = os.path.join(model_dir, mesh_name)
    b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)

# The cube is registered last (renderer index 21) with a 1e-9 scale factor.
# NOTE(review): it looks like a placeholder/background object -- confirm.
cube_path = os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj")
b.RENDERER.add_mesh_from_file(cube_path, scaling_factor=1.0 / 1000000000.0)


[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (128, 128, 1024)


In [4]:
# JIT-compiled wrapper around the model's `importance` entry point.
importance_jit = jax.jit(b.genjax.model.importance)
# PRNG key passed, unchanged, to every `importance_jit` call in the trace-loading cell.
key = jax.random.PRNGKey(10)

In [5]:
def load_traces(filename_template, num_traces):
    """Rebuild one trace per scene from saved `importance` arguments.

    Args:
        filename_template: path containing a ``{scene_id}`` placeholder,
            e.g. ``"data/trace_{scene_id}.joblib"``.
        num_traces: number of scenes; scene ids run over ``range(num_traces)``.

    Returns:
        A list of traces, one per scene id, in scene-id order.

    Every call uses the notebook-level `key` and `importance_jit`.
    """
    traces = []
    for scene_id in tqdm(range(num_traces)):
        filename = filename_template.format(scene_id=scene_id)
        # NOTE(security): joblib.load unpickles, which can execute arbitrary
        # code. Only load files produced by a trusted run of this project.
        saved_args = joblib.load(filename)
        # `importance` returns (key, (weight, trace)); keep only the trace.
        traces.append(importance_jit(key, *saved_args)[1][1])
    return traces


def inferred_trace_template(hierarchical_bayes, v_variant, o_variant):
    """Return the filename template for one family of inferred traces.

    With `hierarchical_bayes` set the hierarchical-Bayes result files are
    used; otherwise the files for the fixed (v_variant, o_variant) pair.
    The returned string still contains a ``{scene_id}`` placeholder.
    """
    if hierarchical_bayes:
        return "data/inferred_hb_{scene_id}.joblib"
    return f"data/inferred_{v_variant}_{o_variant}_{{scene_id}}.joblib"


NUM_TRACES = 40

# Ground-truth traces.
gt_traces = load_traces("data/trace_{scene_id}.joblib", NUM_TRACES)

V_VARIANT = 0
O_VARIANT = 0

# Traces inferred with hierarchical Bayes.
HIERARCHICAL_BAYES = True
hb_traces = load_traces(
    inferred_trace_template(HIERARCHICAL_BAYES, V_VARIANT, O_VARIANT), NUM_TRACES
)

# Traces inferred with the fixed (0, 0) variant.
HIERARCHICAL_BAYES = False
variant_0_0_traces = load_traces(
    inferred_trace_template(HIERARCHICAL_BAYES, V_VARIANT, O_VARIANT), NUM_TRACES
)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:06<00:00,  6.42it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 26.64it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 31.14it/s]


In [6]:
# Classification accuracy: a scene counts as correct when the *set* of
# predicted object ids equals the ground-truth set (order is ignored).
prediction_sets = [hb_traces]
wrong_prediction = []  # scene ids with a mismatch, accumulated over all prediction sets
for pred_set in prediction_sets:
    mismatched_scenes = [
        scene_id
        for scene_id in range(NUM_TRACES)
        if set(np.array(b.genjax.get_indices(pred_set[scene_id])))
        != set(np.array(b.genjax.get_indices(gt_traces[scene_id])))
    ]
    # For debugging, pred_set[scene_id]["variance"] and ["outlier_prob"]
    # can be printed for the mismatched scenes.
    wrong_prediction.extend(mismatched_scenes)
    correct = NUM_TRACES - len(mismatched_scenes)
    print(correct)

31


In [7]:
# Scene ids whose predicted object-id set differs from the ground-truth set.
wrong_prediction

[3, 4, 5, 11, 19, 25, 26, 27, 34]

In [8]:

# NOTE(review): `importance_jit` is bound here with the same expression as in
# an earlier cell; this creates a fresh jitted wrapper (and a fresh compile cache).
importance_jit = jax.jit(b.genjax.model.importance)

# One enumerator per object slot 0..4, each over that slot's contact
# parameters plus the shared "variance" and "outlier_prob" addresses.
# range(5) means only contact_params_0 .. contact_params_4 are supported.
contact_enumerators = [b.genjax.make_enumerator([f"contact_params_{i}", "variance", "outlier_prob"]) for i in range(5)]
# JIT-compiled version of the helper that adds an object to a trace.
add_object_jit = jax.jit(b.genjax.add_object)

def c2f_contact_update(trace_, key,  number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID):
    """Run one coarse-to-fine step for object slot `number`.

    Candidate contact parameters are the slot's current value offset by each
    entry of `contact_param_deltas`. Every combination of candidate,
    variance and outlier probability is scored, and the trace is updated to
    the highest-scoring combination.

    Args:
        trace_: current trace; read at ``contact_params_{number}``.
        key: PRNG key forwarded to the enumerator calls.
        number: index into `contact_enumerators` (static under jit).
        contact_param_deltas: offsets added to the current contact parameters.
        VARIANCE_GRID: candidate variance values.
        OUTLIER_GRID: candidate outlier-probability values.

    Returns:
        Whatever element 0 of the enumerator returns for the best combination.
        NOTE(review): presumably the updated trace -- confirm against
        `b.genjax.make_enumerator`.
    """
    enumerator = contact_enumerators[number]
    candidates = contact_param_deltas + trace_[f"contact_params_{number}"]

    # Element 3 of the enumerator scores the whole grid. The result must be
    # 3-D (contact x variance x outlier) for the three-way unpack below.
    score_grid = enumerator[3](trace_, key, candidates, VARIANCE_GRID, OUTLIER_GRID)
    best_flat = score_grid.argmax()
    contact_idx, variance_idx, outlier_idx = jnp.unravel_index(best_flat, score_grid.shape)

    # Element 0 of the enumerator applies one concrete choice of the three values.
    return enumerator[0](
        trace_,
        key,
        candidates[contact_idx],
        VARIANCE_GRID[variance_idx],
        OUTLIER_GRID[outlier_idx],
    )


# `number` picks a Python-level enumerator, so it must be a static argument.
c2f_contact_update_jit = jax.jit(c2f_contact_update, static_argnames=("number",))

In [9]:
# Candidate noise settings searched when hierarchical Bayes is enabled.
VARIANCE_GRID = jnp.array([1e-6, 1e-5, 1e-4])
OUTLIER_GRID = jnp.array([1e-4, 1e-3, 1e-2])
# Single-point alternative (disabled):
#   VARIANCE_GRID = jnp.array([0.001]); OUTLIER_GRID = jnp.array([0.0001])

# One entry per coarse-to-fine stage: (half_width, half_angle, counts).
# The grid spans [-half_width, half_width] in the first two dimensions and
# [-half_angle, half_angle] in the third, with `counts` passed as the
# per-dimension sizes.
grid_params = [
    (0.2, jnp.pi, (11, 11, 11)),
    (0.1, jnp.pi / 3, (11, 11, 11)),
    (0.05, 0.0, (11, 11, 1)),         # zero angular range: first two dimensions only
    (0.05, jnp.pi / 5, (11, 11, 11)),
    (0.02, 2 * jnp.pi, (5, 5, 51)),   # few positions, wide angular sweep
    (0.02, jnp.pi / 5, (11, 11, 11)),
]
contact_param_gridding_schedule = [
    b.utils.make_translation_grid_enumeration_3d(
        -half_width, -half_width, -half_angle,
        half_width, half_width, half_angle,
        *counts,
    )
    for half_width, half_angle, counts in grid_params
]
key = jax.random.PRNGKey(500)


In [10]:
# Scene to inspect in the cells below (4 is one of the misclassified scenes listed above).
scene_id = 4

In [11]:
# Noise settings for the single-scene run below.
V_VARIANT = 0
O_VARIANT = 0
HIERARCHICAL_BAYES = False

if not HIERARCHICAL_BAYES:
    # Fixed variant: pin each grid to the one selected entry.
    V_GRID = jnp.array([VARIANCE_GRID[V_VARIANT]])
    O_GRID = jnp.array([OUTLIER_GRID[O_VARIANT]])
else:
    # Hierarchical Bayes: search the full grids.
    V_GRID = VARIANCE_GRID
    O_GRID = OUTLIER_GRID
print(V_GRID, O_GRID)

# Ground-truth trace for the selected scene.
gt_importance_args = joblib.load(f"data/trace_{scene_id}.joblib")
gt_trace = importance_jit(key, *gt_importance_args)[1][1]

# Starting trace: same choices as ground truth, but the first two model
# arguments are replaced by arange(1) and arange(22).
# NOTE(review): presumably "one object in the scene" and "22 candidate mesh
# ids" (21 YCB meshes + cube) -- confirm against the model signature.
choices = gt_trace.get_choices()
initial_args = (jnp.arange(1), jnp.arange(22), *gt_trace.get_args()[2:])
# `key` is rebound to the key returned by `importance`, so re-running this
# cell continues from a different key.
key, (_, trace) = importance_jit(key, choices, initial_args)
print(trace.get_score())

[1.e-06] [1.e-04]
-35821.44


In [12]:


# Greedy search: three rounds, each adding the single best-scoring object to `trace`.
# NOTE(review): `trace` is overwritten at the end of every round, so re-running this
# cell without re-running the cell that builds the initial trace adds three MORE objects.
# NOTE(review): the same `key` is passed to every call below; it is never split.
all_all_paths = []  # one `all_paths` list per round
for _ in range(3):
    all_paths = []  # one coarse-to-fine path per candidate object in this round
    # The last registered mesh is excluded, so candidates are ids 0 .. len(meshes)-2.
    for obj_id in tqdm(range(len(b.RENDERER.meshes)-1)):
        path = []
        # NOTE(review): the meaning of the trailing arguments (0, 2, 3) is not visible
        # here; presumably parent object index and contact faces -- confirm.
        trace_ = add_object_jit(trace, key, obj_id, 0, 2,3)
        # Index of the newly added object: the last row of the contact parameters.
        number = b.genjax.get_contact_params(trace_).shape[0] - 1
        path.append(trace_)
        # Refine the new object's contact parameters through every schedule stage,
        # recording the trace after each one.
        for c2f_iter in range(len(contact_param_gridding_schedule)):
            trace_ = c2f_contact_update_jit(trace_, key, number,
                contact_param_gridding_schedule[c2f_iter], V_GRID, O_GRID)
            path.append(trace_)
        # Disabled second pass over the full variance/outlier grids:
        # for c2f_iter in range(len(contact_param_gridding_schedule)):
        #     trace_ = c2f_contact_update_jit(trace_, key, number,
        #         contact_param_gridding_schedule[c2f_iter], VARIANCE_GRID, OUTLIER_GRID)
        all_paths.append(
            path
        )
    all_all_paths.append(all_paths)
    
    # Final-stage score of every candidate; the best candidate's final trace
    # becomes the starting point of the next round.
    scores = jnp.array([t[-1].get_score() for t in all_paths])
    print(scores)
    # NOTE(review): `normalized_scores` is computed but not used in the visible cells.
    normalized_scores = b.utils.normalize_log_scores(scores)
    trace = all_paths[jnp.argmax(scores)][-1]

# Compare ground-truth and inferred object ids, then show the final trace.
print(b.genjax.get_indices(gt_trace))
print(b.genjax.get_indices(trace))
b.genjax.viz_trace_meshcat(trace)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:26<00:00,  1.26s/it]


[140788.95  139039.81  155061.39  136737.61  151185.1   130344.41
 145694.47  129783.81  134647.02  128445.47  140451.97  145338.31
 123780.305 134878.27  122159.52  146830.38  127155.83  128167.47
 127229.41  126534.625 129810.05 ]


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:29<00:00,  1.42s/it]


[262425.3  244837.44 263188.16 266169.75 274830.38 262714.72 276297.22
 262107.2  265580.72 259876.86 245313.86 268077.56 249170.81 263501.38
 246880.9  259424.11 258437.   261482.69 256368.16 253637.3  262010.97]


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:33<00:00,  1.58s/it]


[299096.7  281531.44 299859.6  302841.16 311501.8  296588.66 295357.53
 296822.1  302252.12 295779.47 281985.28 304748.97 285842.25 300172.75
 282242.56 296095.53 295131.03 298773.6  291990.94 290279.22 299083.97]
[21  2  4  6]
[21  2  6  4]


In [13]:
# Show the current `trace` in the meshcat visualizer.
b.genjax.viz_trace_meshcat(trace)

In [83]:
# Final coarse-to-fine trace of candidate obj_id 4 from the LAST search round
# (`all_paths` only holds the most recent round).
# NOTE(review): in the recorded run above, index 4 is the argmax of the last round,
# so on a clean run this is the same trace as `trace`. This cell's execution count
# (83) is also out of sequence, so its saved output likely reflects an earlier
# kernel state -- re-check after Restart & Run All.
trace_alternate = all_paths[4][-1]
b.genjax.print_trace(trace)
b.genjax.print_trace(trace_alternate)


    SCORE: 152959.3750000
    VARIANCE: 0.0000100
    OUTLIER_PROB 0.0100000
    

    SCORE: 152818.5625000
    VARIANCE: 0.0000100
    OUTLIER_PROB 0.0100000
    
