In [37]:
import bayes3d as b
import bayes3d.genjax
import joblib
from tqdm import tqdm
import os
import jax.numpy as jnp
import jax
import numpy as np

In [18]:
intrinsics = b.Intrinsics(
    height=100,
    width=100,
    fx=500.0, fy=500.0,
    cx=50.0, cy=50.0,
    near=0.01, far=20.0
)

b.setup_renderer(intrinsics)
model_dir = os.path.join(b.utils.get_assets_dir(),"bop/ycbv/models")
meshes = []
for idx in range(1,22):
    mesh_path = os.path.join(model_dir,"obj_" + "{}".format(idx).rjust(6, '0') + ".ply")
    b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)

b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), scaling_factor=1.0/1000000000.0)


Increasing frame buffer size to (width, height, depth) = (128, 128, 1024)


[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


In [23]:
importance_jit = jax.jit(b.genjax.model.importance)
key = jax.random.PRNGKey(10)

In [45]:
gt_traces = []
for scene_id in tqdm(range(200)):
    filename = f"data/trace_{scene_id}.joblib"
    gt_traces.append(importance_jit(key, *joblib.load(filename))[1][1])

hb_traces = []

V_VARIANT = 0
O_VARIANT = 0
HIERARCHICAL_BAYES = True

for scene_id in tqdm(range(200)):
    if HIERARCHICAL_BAYES:
        filename = f"data/inferred_hb_{scene_id}.joblib"
    else:
        filename = f"data/inferred_{V_VARIANT}_{O_VARIANT}_{scene_id}.joblib"
    hb_traces.append(importance_jit(key, *joblib.load(filename))[1][1])

variant_0_0_traces = []

V_VARIANT = 0
O_VARIANT = 0
HIERARCHICAL_BAYES = False

for scene_id in tqdm(range(200)):
    if HIERARCHICAL_BAYES:
        filename = f"data/inferred_hb_{scene_id}.joblib"
    else:
        filename = f"data/inferred_{V_VARIANT}_{O_VARIANT}_{scene_id}.joblib"
    variant_0_0_traces.append(importance_jit(key, *joblib.load(filename))[1][1])


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 33.40it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 31.31it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 33.65it/s]


In [52]:
# Classification accuracy
prediction_sets = [hb_traces]
for pred_set in prediction_sets:
    correct = 0
    for scene_id in range(200):
        gt_ids = set(np.array(b.genjax.get_indices(gt_traces[scene_id])))
        pred_ids = set(np.array(b.genjax.get_indices(pred_set[scene_id])))
        print(pred_set[scene_id]["variance"])
        print(pred_set[scene_id]["outlier_prob"])
        if pred_ids == gt_ids:
            correct +=1
        else:
            print(gt_ids, pred_ids)
    print(correct)    

1e-05
0.01
1e-05
0.01
1e-05
0.001
1e-05
0.01
{13, 2, 21, 15} {13, 21, 15}
1e-05
0.01
{2, 4, 21, 6} {2, 11, 21, 6}
1e-05
0.001
{9, 5, 21} {9, 5, 21, 17}
1e-05
0.01
1e-05
0.01
1e-05
0.001
1e-05
0.001
1e-05
0.01
1e-05
0.001
{20, 21, 7} {17, 20, 21, 7}
1e-05
0.01
1e-05
0.01
1e-05
0.01
1e-05
0.01
1e-05
0.001
1e-05
0.01
1e-05
0.001
1e-05
0.01
{9, 4, 21, 15} {1, 21, 9, 15}
1e-05
0.001
1e-05
0.001
1e-05
0.001
1e-05
0.001
1e-05
0.001
1e-05
0.001
{5, 3, 21} {17, 5, 3, 21}
1e-05
0.01
{16, 1, 18, 21} {1, 18, 21, 15}
1e-05
0.01
{0, 13, 21, 15} {0, 2, 21, 15}
1e-05
0.01
1e-05
0.01
1e-05
0.01
1e-05
0.001
1e-05
0.001
1e-05
0.01
1e-05
0.001
{10, 5, 4, 21} {10, 4, 21, 7}
1e-05
0.01
1e-05
0.01
1e-05
0.01
1e-05
0.001
1e-05
0.01
1e-05
0.001
{21, 6, 7} {17, 21, 6, 7}
1e-05
0.01
1e-05
0.01
1e-05
0.01
{19, 11, 21} {18, 11, 21}
1e-05
0.001
1e-05
0.01
{18, 11, 4, 21} {2, 18, 4, 21}
1e-05
0.001
{17, 3, 21, 1} {1, 3, 4, 21}
1e-05
0.01
{12, 11, 20, 21} {2, 20, 12, 21}
1e-05
0.01
1e-05
0.01
1e-05
0.01
{8, 4, 21} {8

In [30]:
scene_id = 0
a = b.genjax.get_indices(gt_traces[scene_id])

In [44]:
len(hb_traces)

400

In [33]:
set(list(a))

TypeError: unhashable type: 'ArrayImpl'

In [36]:
a.asarray()

AttributeError: 'ArrayImpl' object has no attribute 'asarray'