In [1]:
import jax.numpy as jnp
import bayes3d as b
import os
import jax
import functools
from jax.scipy.special import logsumexp
from functools import partial
from tqdm import tqdm
import matplotlib.pyplot as plt
import bayes3d.genjax
import genjax
import pathlib

from tensorflow_probability.substrates import jax as tfp

In [2]:
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7005/static/


In [3]:
intrinsics = b.Intrinsics(
    height=100,
    width=100,
    fx=200.0, fy=200.0,
    cx=50.0, cy=50.0,
    near=0.0001, far=2.0
)

b.setup_renderer(intrinsics)
model_dir = os.path.join(b.utils.get_assets_dir(),"bop/ycbv/models")
meshes = []
for idx in range(1,22):
    mesh_path = os.path.join(model_dir,"obj_" + "{}".format(idx).rjust(6, '0') + ".ply")
    b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)
# b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), scaling_factor=1.0/10.0)
b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), scaling_factor=1.0/1000000000.0)


[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (128, 128, 1024)


In [4]:
table_pose = b.t3d.inverse_pose(
    b.t3d.transform_from_pos_target_up(
        jnp.array([0.0, 0.8, .15]),
        jnp.array([0.0, 0.0, 0.0]),
        jnp.array([0.0, 0.0, 1.0]),
    )
)
importance_jit = jax.jit(b.model.importance)

In [5]:
width = 0.03
ang = jnp.pi
num_position_grids = 51
num_angle_grids = 51
contact_param_deltas = b.utils.make_translation_grid_enumeration_3d(
    -width, -width, -ang,
    width, width, ang,
    num_position_grids,num_position_grids,num_angle_grids
)

In [6]:
enumerators = b.make_enumerator(["contact_params_1"])

In [7]:
key = jax.random.PRNGKey(100)

In [8]:
def make_orientation_posterior_viz(observation, gt_contact, sampled_contacts):
    fig = plt.figure(constrained_layout=True)
    # fig.suptitle(f"Variance: {variance} Outlier Prob: {outlier_prob}")
    widths = [1, 1]
    heights = [2]
    spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths,
                              height_ratios=heights)
    
    ax = fig.add_subplot(spec[0, 0])
    ax.imshow(jnp.array(b.get_depth_image(observation[...,2],max=1.4)))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title(f"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f} ,{gt_contact[2]:0.2f})")
    # ax.set_title(f"Observed Depth")
    
    
    ax = fig.add_subplot(spec[0, 1])
    ax.set_aspect(1.0)
    circ = plt.Circle((0, 0), radius=1, edgecolor='black', facecolor='None', linestyle="--", linewidth=0.5)
    ax.add_patch(circ)
    ax.set_xlim(-1.1, 1.1)
    ax.set_ylim(-1.1, 1.1)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.scatter(-jnp.sin(sampled_contacts[:,2]),jnp.cos(sampled_contacts[:,2]), color='red',label="Posterior Samples", alpha=0.5, s=30)
    ax.scatter(-jnp.sin(gt_contact[2]),jnp.cos(gt_contact[2]), label="Actual", alpha=0.9, s=25)
    ax.set_title("Posterior on Orientation (top view)")
    # ax.legend(fontsize=9)
    # plt.show()
    return fig

In [9]:
def c2f_contact_update(trace_, key, contact_param_deltas):
    contact_param_grid = contact_param_deltas + trace_[f"contact_params_1"]
    scores = enumerators[3](trace_, key, contact_param_grid)
    i = scores.argmax()
    return enumerators[0](
        trace_, key,
        contact_param_grid[i]
    )
c2f_contact_update_jit = jax.jit(c2f_contact_update)

In [10]:
key = jax.random.PRNGKey(100)
key2 = jax.random.PRNGKey(1000)

In [11]:
grid_params = [
    (0.3, jnp.pi, (15,15,15)), (0.2, jnp.pi, (15,15,15)), (0.1, jnp.pi, (15,15,15)),
    (0.05, jnp.pi/3, (15,15,15)), (0.02, jnp.pi, (9,9,51)), (0.01, jnp.pi/5, (15,15,15)), (0.01, 0.0, (31,31,1)),(0.05, 0.0, (31,31,1))
]
contact_param_gridding_schedule = [
    b.utils.make_translation_grid_enumeration_3d(
        -x, -x, -ang,
        x, x, ang,
        *nums
    )
    for (x,ang,nums) in grid_params
]

In [12]:
def get_init_trace(key, img):
    weight, trace = importance_jit(key, genjax.choice_map({
        "parent_0": -1,
        "parent_1": 0,
        "id_0": jnp.int32(21),
        "id_1": jnp.int32(13),
        "camera_pose": jnp.eye(4),
        "root_pose_0": table_pose,
        "face_parent_1": 2,
        "face_child_1": 3,
        "variance": 0.0001,
        "outlier_prob": 0.001,
        "image": b.unproject_depth(img, intrinsics),
        "contact_params_1": jnp.zeros(3)#jax.random.uniform(key, shape=(3,),minval=low, maxval=high)
    }), (
        jnp.arange(2),
        jnp.arange(22),
        jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),
        jnp.array([jnp.array([-0.5, -0.5, -2*jnp.pi]), jnp.array([0.5, 0.5, 2*jnp.pi])]),
        b.RENDERER.model_box_dims, 1.0, intrinsics.fx)
    )
    return trace

In [13]:
#train_data_file = jnp.load('train_data.npz')
test_data_file = jnp.load('test_data_noxy.npz')
test_imgs = test_data_file['arr_0']
test_labels = test_data_file['arr_1']
N_TEST = test_imgs.shape[0]

In [14]:
key = jax.random.PRNGKey(0)

In [15]:
def von_mises_cross_entropy(gt_theta, sampled_thetas):
    von_mises = tfp.distributions.VonMises.experimental_fit(sampled_thetas)
    return -von_mises.log_prob(gt_theta)

In [44]:
def fit_von_mises(thetas):
    # https://www.jmlr.org/papers/volume6/banerjee05a/banerjee05a.pdf 
    n = thetas.shape[0]
    unit_vecs = jax.vmap(lambda theta: jnp.array([jnp.cos(theta), jnp.sin(theta)]))(thetas)
    r = unit_vecs.sum(axis=0)
    normr = jnp.linalg.norm(r)
    rbar = normr / n 
    vec_mean = r / normr
    conc0 = 1/(2 * (1 - rbar) + 1e-6)
    f = lambda x: jax.scipy.special.i1(x)/(jax.scipy.special.i0(x) + 1e-6) - rbar
    df = jax.grad(f)
    conc = jax.lax.fori_loop(0, 1,
                             lambda i, conc: conc - f(conc)/(df(conc) + 1e-6),
                             conc0)
    conc = conc0
    mean = jnp.arctan2(vec_mean[1], vec_mean[0])
    return mean, conc, rbar

In [17]:
def get_posterior_samples(key, img):
    img = img.reshape(100, 100)
    tr = get_init_trace(key, img)
    
    key = jax.random.split(key, 1)[0]
    path = []
    path.append(tr)
    for c2f_iter in range(len(contact_param_gridding_schedule)):
        tr = c2f_contact_update_jit(tr, key, contact_param_gridding_schedule[c2f_iter])
        path.append(tr)
        key = jax.random.split(key, 1)[0]

        
    contact_param_grid = tr["contact_params_1"] + contact_param_deltas
    weights =  jnp.concatenate([
        enumerators[3](tr, key, cp)
        for cp in jnp.array_split(contact_param_grid, 60)
    ],axis=0)
    
    key = jax.random.split(key, 1)[0]
    normalized_weights = b.utils.normalize_log_scores(weights)
    sampled_indices = jax.random.choice(key ,jnp.arange(normalized_weights.shape[0]), shape=(1000,), p=normalized_weights)
    sampled_contact_params = contact_param_grid[sampled_indices]

    thetas = sampled_contact_params[:, 2]

    return thetas

In [18]:
thetass = []
for (i, key_) in tqdm(zip(range(N_TEST), jax.random.split(key, N_TEST))):
    thetass.append(get_posterior_samples(key_, test_imgs[i, :, :, 0]))
thetass = jnp.vstack(thetass)

296it [11:21,  2.30s/it]


KeyboardInterrupt: 

In [21]:
thetass = jnp.vstack(thetass)
thetass.shape

(296, 1000)

In [45]:
means, concs, rbars = jax.jit(jax.vmap(fit_von_mises))(thetass)

In [46]:
concs

Array([8.21899033e+00, 1.00000000e+06, 7.33503723e+00, 3.06838242e+04,
       1.00000000e+06, 1.00000000e+06, 1.91145075e+06, 5.78768750e+03,
       5.24013424e+00, 1.31305725e+06, 1.00000000e+06, 1.04937793e+04,
       8.52483177e+00, 3.22590552e+03, 7.36578875e+05, 1.26070811e+04,
       1.00000000e+06, 1.14901514e+01, 5.97527695e+04, 7.57136011e+00,
       1.00000000e+06, 8.81280041e+00, 1.31305725e+06, 7.06380320e+00,
       1.08767185e+01, 1.00000000e+06, 1.58159658e+04, 3.06838242e+04,
       6.01814453e+04, 5.89134766e+04, 7.83148861e+00, 1.00000000e+06,
       3.11394297e+04, 2.59597607e+03, 1.56730089e+01, 7.94835986e+03,
       8.93487938e+05, 1.04937793e+04, 1.18600025e+01, 5.89134766e+04,
       7.14086723e+00, 8.93487938e+05, 1.31305725e+06, 6.01814453e+04,
       1.25504951e+04, 1.15371656e+01, 1.00000000e+06, 6.10575195e+04,
       1.31305725e+06, 5.80370459e+03, 4.91024609e+03, 1.00000000e+06,
       3.10242656e+04, 9.06495703e+03, 8.93487938e+05, 1.14216290e+01,
      

In [23]:
von_mises = tfp.distributions.VonMises
jax.jit(jax.vmap(lambda mean, conc, label: -von_mises(loc=mean, concentration=conc).log_prob(label)))(means, concs, test_labels.flatten()[:296]).sum()

Array(-271.65796, dtype=float32)

In [25]:
aa = _

In [26]:
aa / 296 * 2000

Array(-1835.5267, dtype=float32)