In [1]:
import bayes3d as b
import bayes3d.genjax
import joblib
from tqdm import tqdm
import os
import jax.numpy as jnp
import jax
import numpy as np

In [2]:
# Start the meshcat visualizer; the cell output prints the URL to open it in a browser.
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7004/static/


In [3]:
# Small 100x100 camera used for all rendering in this notebook.
intrinsics = b.Intrinsics(
    height=100,
    width=100,
    fx=500.0, fy=500.0,
    cx=50.0, cy=50.0,
    near=0.01, far=20.0
)

b.setup_renderer(intrinsics)

# YCB-V object models obj_000001.ply .. obj_000021.ply, scaled by 1/1000
# (presumably millimetres -> metres; confirm against the BOP dataset docs).
NUM_YCBV_OBJECTS = 21
model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models")
for idx in range(1, NUM_YCBV_OBJECTS + 1):
    mesh_path = os.path.join(model_dir, f"obj_{idx:06d}.ply")
    b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)

# The cube is added LAST (mesh index NUM_YCBV_OBJECTS). Its 1e-9 scale makes it
# vanishingly small — NOTE(review): presumably a placeholder root object; confirm.
# The search loop relies on it being last: it iterates range(len(meshes) - 1).
b.RENDERER.add_mesh_from_file(
    os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"),
    scaling_factor=1.0 / 1000000000.0,
)


[E rasterize_gl.cpp:121] OpenGL version reported as 4.6
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Increasing frame buffer size to (width, height, depth) = (128, 128, 1024)


In [4]:
# Fixed PRNG key reused by the trace-loading cells that follow.
key = jax.random.PRNGKey(10)
# JIT-compiled `importance` of the bayes3d genjax scene model.
importance_jit = jax.jit(b.genjax.model.importance)

In [5]:
def _load_trace(filename):
    """Rebuild a trace from the saved ``importance`` arguments stored in ``filename``.

    ``importance_jit`` returns ``(key, (_, trace))``; only the trace is kept.
    The module-level ``key`` is passed on every call.
    NOTE: ``joblib.load`` unpickles the file — only load trusted data files.
    """
    _, (_, loaded_trace) = importance_jit(key, *joblib.load(filename))
    return loaded_trace


NUM_TRACES = 100
gt_traces = []
for scene_id in tqdm(range(NUM_TRACES)):
    filename = f"data/trace_{scene_id}.joblib"
    gt_traces.append(_load_trace(filename))

# Which inferred traces to evaluate: the hierarchical-Bayes run, or the run
# selected by V_VARIANT / O_VARIANT.
V_VARIANT = 0
O_VARIANT = 0
HIERARCHICAL_BAYES = True

hb_traces = []  # NOTE: also holds the V/O-variant traces when HIERARCHICAL_BAYES is False
for scene_id in tqdm(range(NUM_TRACES)):
    if HIERARCHICAL_BAYES:
        filename = f"data/inferred_hb_{scene_id}.joblib"
    else:
        filename = f"data/inferred_{V_VARIANT}_{O_VARIANT}_{scene_id}.joblib"
    hb_traces.append(_load_trace(filename))

  0%|                                                                                                                                                                                                               | 0/100 [00:02<?, ?it/s]


NotImplementedError: MLIR translation rule for primitive 'render_multiple_140286986571936' not found for platform cpu

In [None]:
# Classification accuracy: a scene counts as correct when the predicted trace
# contains exactly the same *set* of object indices as the ground-truth trace
# (ordering and duplicate indices are ignored by the set comparison).
prediction_sets = [hb_traces]
wrong_prediction = []  # scene ids misclassified (accumulated across all prediction sets)
for pred_set in prediction_sets:
    correct = 0
    for scene_id in range(NUM_TRACES):
        gt_ids = np.array(b.genjax.get_indices(gt_traces[scene_id]))
        pred_ids = np.array(b.genjax.get_indices(pred_set[scene_id]))
        if set(pred_ids) == set(gt_ids):
            correct += 1
        else:
            wrong_prediction.append(scene_id)
            print(gt_ids, pred_ids)
    print(f"{correct}/{NUM_TRACES} scenes correct")

In [None]:
# Scene ids whose predicted object set differed from the ground truth.
wrong_prediction

In [None]:
# Scene to visualize from the first prediction set. Made explicit: this cell
# previously relied on `scene_id` leaking out of the accuracy loop (its last
# value, NUM_TRACES - 1), which silently changed if other cells ran in between.
INSPECT_SCENE_ID = NUM_TRACES - 1
b.genjax.viz_trace_meshcat(prediction_sets[0][INSPECT_SCENE_ID])

In [None]:

# One enumerator per object slot: each jointly enumerates contact_params_{i},
# "variance" and "outlier_prob". c2f_contact_update can only refine objects
# whose index is below MAX_CONTACT_OBJECTS.
MAX_CONTACT_OBJECTS = 5
contact_enumerators = [
    b.genjax.make_enumerator([f"contact_params_{i}", "variance", "outlier_prob"])
    for i in range(MAX_CONTACT_OBJECTS)
]
add_object_jit = jax.jit(b.genjax.add_object)

def c2f_contact_update(trace_, key, number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID):
    """Run one coarse-to-fine step for object ``number``.

    Candidate contact params are the trace's current ``contact_params_{number}``
    offset by ``contact_param_deltas`` (presumably one delta per row — confirm).
    The enumerator's scoring entry returns a 3-D score array over
    (candidate, variance, outlier_prob). Its argmax is written back through the
    enumerator's update entry, and the updated trace is returned.
    """
    enumerator = contact_enumerators[number]
    score_all, apply_choice = enumerator[3], enumerator[0]

    candidates = contact_param_deltas + trace_[f"contact_params_{number}"]
    scores = score_all(trace_, key, candidates, VARIANCE_GRID, OUTLIER_GRID)

    best_flat_index = scores.argmax()
    cand_idx, var_idx, out_idx = jnp.unravel_index(best_flat_index, scores.shape)
    best_contact = candidates[cand_idx]
    best_variance = VARIANCE_GRID[var_idx]
    best_outlier = OUTLIER_GRID[out_idx]
    return apply_choice(trace_, key, best_contact, best_variance, best_outlier)


# `number` selects a Python-level enumerator, so it must be static under jit.
c2f_contact_update_jit = jax.jit(c2f_contact_update, static_argnames=("number",))

In [None]:
OUTLIER_VOLUME = 1000.0
VARIANCE_GRID = jnp.array([0.0001, 0.001, 0.01])
OUTLIER_GRID = jnp.array([0.00001, 0.0001, 0.001])

# Coarse-to-fine schedule entries: (half_width, half_angle, per-axis grid counts).
# Each grid spans [-half_width, half_width] on the first two axes and
# [-half_angle, half_angle] on the third (presumably x, y, rotation — confirm).
# NOTE(review): the last two entries go 0.01 -> 0.05, i.e. coarser after finer.
grid_params = [
    (0.3, jnp.pi, (11, 11, 11)),
    (0.2, jnp.pi, (11, 11, 11)),
    (0.1, jnp.pi, (11, 11, 11)),
    (0.05, jnp.pi / 3, (11, 11, 11)),
    (0.02, jnp.pi, (5, 5, 51)),
    (0.01, jnp.pi / 5, (11, 11, 11)),
    (0.01, 0.0, (21, 21, 1)),
    (0.05, 0.0, (21, 21, 1)),
]

contact_param_gridding_schedule = []
for half_width, half_angle, grid_counts in grid_params:
    lower_bounds = (-half_width, -half_width, -half_angle)
    upper_bounds = (half_width, half_width, half_angle)
    contact_param_gridding_schedule.append(
        b.utils.make_translation_grid_enumeration_3d(*lower_bounds, *upper_bounds, *grid_counts)
    )

key = jax.random.PRNGKey(500)


In [None]:
# Scene loaded as ground truth for the single-scene search cells that follow.
scene_id = 4

In [20]:
V_VARIANT = 0
O_VARIANT = 0
HIERARCHICAL_BAYES = True

# With HIERARCHICAL_BAYES the search scores over the full variance / outlier
# grids; otherwise each grid is pinned to the single value selected by
# V_VARIANT / O_VARIANT.
if HIERARCHICAL_BAYES:
    V_GRID = VARIANCE_GRID
    O_GRID = OUTLIER_GRID
else:
    V_GRID, O_GRID = jnp.array([VARIANCE_GRID[V_VARIANT]]), jnp.array([OUTLIER_GRID[O_VARIANT]])

print(V_GRID, O_GRID)

gt_trace = importance_jit(key, *joblib.load(f"data/trace_{scene_id}.joblib"))[1][1]
print(b.genjax.get_indices(gt_trace))
b.genjax.viz_trace_meshcat(gt_trace)

# Re-score the ground-truth choices with new leading model args: jnp.arange(1)
# and the ids of every renderer mesh (previously a hard-coded 22). The resulting
# `trace` is the starting point for the greedy search cell.
# NOTE(review): the final model arg is replaced by 100.0, while OUTLIER_VOLUME
# is 1000.0 — confirm which value is intended.
choices = gt_trace.get_choices()
all_mesh_ids = jnp.arange(len(b.RENDERER.meshes))
key, (_, trace) = importance_jit(
    key, choices, (jnp.arange(1), all_mesh_ids, *gt_trace.get_args()[2:-1], 100.0)
)
print(trace.get_score())

[1.e-04 1.e-03 1.e-02] [1.e-05 1.e-04 1.e-03]
[21  2  4  6]
-925.0466


In [21]:


# Greedy object-by-object search. Each round tries adding every candidate mesh
# to the current `trace`, refines the new object's contact params (jointly with
# variance / outlier_prob) through the coarse-to-fine schedule, and keeps the
# highest-scoring result.
# NOTE: this cell reads AND overwrites the global `trace`, so re-running it
# continues from the already-augmented trace; re-run the previous cell first.
# NOTE(review): the same `key` is passed to every add/update call (never split).
NUM_OBJECTS_TO_ADD = 3

all_all_paths = []  # all_all_paths[round][obj_id] -> traces after each c2f step
for _ in range(NUM_OBJECTS_TO_ADD):
    all_paths = []
    # The last renderer mesh is excluded from the candidates.
    for obj_id in tqdm(range(len(b.RENDERER.meshes) - 1)):
        # NOTE(review): the meaning of the literal args (0, 2, 3) is not visible here.
        trace_ = add_object_jit(trace, key, obj_id, 0, 2, 3)
        # Index of the newly added object (presumably appended last — confirm).
        number = b.genjax.get_contact_params(trace_).shape[0] - 1
        path = [trace_]
        for contact_param_deltas in contact_param_gridding_schedule:
            trace_ = c2f_contact_update_jit(
                trace_, key, number, contact_param_deltas, V_GRID, O_GRID
            )
            path.append(trace_)
        all_paths.append(path)
    all_all_paths.append(all_paths)

    scores = jnp.array([candidate_path[-1].get_score() for candidate_path in all_paths])
    print(scores)
    # int(): index the Python list with a plain int rather than a JAX scalar.
    trace = all_paths[int(jnp.argmax(scores))][-1]
    b.genjax.viz_trace_meshcat(trace)

print(b.genjax.get_indices(gt_trace))
print(b.genjax.get_indices(trace))
b.genjax.viz_trace_meshcat(trace)

  0%|                                                                                                                                                                                                                         | 0/21 [00:00<?, ?it/s]2023-07-18 19:29:36.522162: W external/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:561] libdevice is required by this HLO module but was not found at /opt/conda/envs/bayes3d/lib/python3.9/site-packages/nvidia/cuda_nvcc/nvvm/libdevice/libdevice.10.bc
2023-07-18 19:29:36.525514: W external/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:561] libdevice is required by this HLO module but was not found at /opt/conda/envs/bayes3d/lib/python3.9/site-packages/nvidia/cuda_nvcc/nvvm/libdevice/libdevice.10.bc
2023-07-18 19:29:36.527243: W external/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:561] libdevice is required by this HLO module but was not found at /opt/conda/envs/bayes3d/lib/python3.9/site-packages/nvidia/cuda_nvcc/

XlaRuntimeError: INTERNAL: libdevice not found at /opt/conda/envs/bayes3d/lib/python3.9/site-packages/nvidia/cuda_nvcc/nvvm/libdevice/libdevice.10.bc

In [183]:
# Summarize the best trace found by the search (score, variance, outlier probability).
b.genjax.print_trace(trace)



    SCORE: 446075.1250000
    VARIANCE: 0.0000100
    OUTLIER_PROB 0.0010000
    


In [184]:
# Visualize the ground-truth scene for comparison with the search result.
b.genjax.viz_trace_meshcat(gt_trace)

In [186]:
# Compare two alternative candidates (object ids 18 and 19) from search round 2.
# all_all_paths[round][obj_id] is that candidate's c2f path; [-1] is its final trace.
SEARCH_ROUND = 2
trace_alternate, trace_alternate2 = (
    all_all_paths[SEARCH_ROUND][alt_obj_id][-1] for alt_obj_id in (18, 19)
)
for alt_trace in (trace_alternate, trace_alternate2):
    b.genjax.print_trace(alt_trace)
for alt_trace in (trace_alternate, trace_alternate2):
    b.genjax.viz_trace_meshcat(alt_trace)


    SCORE: 446075.1250000
    VARIANCE: 0.0000100
    OUTLIER_PROB 0.0010000
    

    SCORE: 445532.4375000
    VARIANCE: 0.0000100
    OUTLIER_PROB 0.0010000
    
