In [1]:
import jax.numpy as jnp
import bayes3d as b
import os
import jax
import functools
from jax.scipy.special import logsumexp
from functools import partial
from tqdm import tqdm
import matplotlib.pyplot as plt
import bayes3d.genjax
import genjax
import pathlib

from tensorflow_probability.substrates import jax as tfp

In [2]:
# Launch the bayes3d visualizer; it prints the local URL where it can be opened.
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7006/static/


In [3]:
# Camera intrinsics for the 100x100 renders used throughout this notebook.
intrinsics = b.Intrinsics(
    height=100, width=100,
    fx=200.0, fy=200.0,
    cx=50.0, cy=50.0,
    near=0.0001, far=2.0,
)

# NOTE(review): the saved run of this cell hit "CUDA error: out of memory". If JAX and
# the renderer share one GPU, consider disabling JAX's GPU preallocation
# (XLA_PYTHON_CLIENT_PREALLOCATE=false) before jax is first imported — confirm.
b.setup_renderer(intrinsics)

# Register the 21 YCB-V meshes obj_000001.ply ... obj_000021.ply,
# scaled by 1/1000 (presumably millimetres -> metres).
model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models")
meshes = []
for idx in range(1, 22):
    mesh_path = os.path.join(model_dir, f"obj_{idx:06d}.ply")
    b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)

# Register a cube shrunk by a factor of 1e9 so it is effectively invisible.
# NOTE(review): presumably this becomes mesh index 21, which get_init_trace uses for the
# root object ("id_0") — confirm against add_mesh_from_file's indexing.
cube_path = os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj")
b.RENDERER.add_mesh_from_file(cube_path, scaling_factor=1.0 / 1000000000.0)


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Fixed pose for the root object: the inverse of a transform built from position
# (0, 0.8, 0.15), target at the origin, and +z as the up vector.
_table_look_at = b.t3d.transform_from_pos_target_up(
    jnp.array([0.0, 0.8, 0.15]),
    jnp.array([0.0, 0.0, 0.0]),
    jnp.array([0.0, 0.0, 1.0]),
)
table_pose = b.t3d.inverse_pose(_table_look_at)

# JIT-compiled importance sampler, used by get_init_trace.
importance_jit = jax.jit(b.model.importance)

In [None]:
# Dense grid of contact-parameter offsets used for the final posterior.
# The first two axes span ±width and the third spans ±ang (column 2 is what
# get_posterior_samples later reads as the angle).
width = 0.03
ang = jnp.pi
num_position_grids = 51
num_angle_grids = 51
contact_param_deltas = b.utils.make_translation_grid_enumeration_3d(
    *(-width, -width, -ang),
    *(width, width, ang),
    *(num_position_grids, num_position_grids, num_angle_grids),
)

In [None]:
# Enumeration helpers over the "contact_params_1" address.
# NOTE(review): later cells call enumerators[3](trace, key, grid) to get one score per
# grid row, and enumerators[0](trace, key, value) to get an updated trace — confirm the
# meaning of these indices against b.make_enumerator.
enumerators = b.make_enumerator(["contact_params_1"])

In [None]:
# NOTE(review): this key is reassigned in later cells before any cell uses it.
key = jax.random.PRNGKey(100)

In [None]:
def make_orientation_posterior_viz(observation, gt_contact, sampled_contacts):
    """Plot an observation next to a top view of sampled orientations.

    Args:
        observation: image-like array whose ``[..., 2]`` channel is rendered as depth.
        gt_contact: ground-truth contact parameters. Entries 0-2 are printed in the
            title, and entry 2 is plotted as the true angle.
        sampled_contacts: 2-D array of sampled contact parameters. Column 2 holds
            the sampled angles.

    Returns:
        The matplotlib Figure. The caller is responsible for showing or closing it.
    """
    fig = plt.figure(constrained_layout=True)
    spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=[1, 1], height_ratios=[2])

    # Left panel: observed depth.
    ax = fig.add_subplot(spec[0, 0])
    ax.imshow(jnp.array(b.get_depth_image(observation[..., 2], max=1.4)))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title(
        f"Observation (params {gt_contact[0]:0.2f}, {gt_contact[1]:0.2f}, {gt_contact[2]:0.2f})"
    )

    # Right panel: each angle theta is drawn on the unit circle at (-sin theta, cos theta).
    ax = fig.add_subplot(spec[0, 1])
    ax.set_aspect(1.0)
    circ = plt.Circle((0, 0), radius=1, edgecolor='black', facecolor='None', linestyle="--", linewidth=0.5)
    ax.add_patch(circ)
    ax.set_xlim(-1.1, 1.1)
    ax.set_ylim(-1.1, 1.1)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.scatter(-jnp.sin(sampled_contacts[:, 2]), jnp.cos(sampled_contacts[:, 2]), color='red', label="Posterior Samples", alpha=0.5, s=30)
    ax.scatter(-jnp.sin(gt_contact[2]), jnp.cos(gt_contact[2]), label="Actual", alpha=0.9, s=25)
    ax.set_title("Posterior on Orientation (top view)")
    return fig

In [None]:
def c2f_contact_update(trace_, key, contact_param_deltas):
    """One coarse-to-fine step over the "contact_params_1" address.

    Scores a grid of candidates centred on the trace's current contact parameters,
    then returns the trace updated to the best-scoring candidate.

    Args:
        trace_: current trace with a "contact_params_1" choice.
        key: PRNG key, passed unchanged to both enumerator calls.
        contact_param_deltas: offsets added to the current contact parameters.

    Returns:
        The result of enumerators[0] at the argmax candidate.
    """
    candidates = trace_["contact_params_1"] + contact_param_deltas
    best_index = enumerators[3](trace_, key, candidates).argmax()
    best_params = candidates[best_index]
    return enumerators[0](trace_, key, best_params)


c2f_contact_update_jit = jax.jit(c2f_contact_update)

In [None]:
# NOTE(review): `key` is reassigned again before the evaluation loops, and `key2`
# is not used by any later cell shown here.
key = jax.random.PRNGKey(100)
key2 = jax.random.PRNGKey(1000)

In [None]:
# Coarse-to-fine schedule for the contact parameters. Each entry is
# (translation half-width, angle half-width, grid counts). The first two grid axes
# span ±translation half-width and the third spans ±angle half-width.
# NOTE(review): the last stage widens the translation window again (0.05 after 0.01);
# confirm this is intentional.
grid_params = [
    (0.3, jnp.pi, (15, 15, 15)),
    (0.2, jnp.pi, (15, 15, 15)),
    (0.1, jnp.pi, (15, 15, 15)),
    (0.05, jnp.pi / 3, (15, 15, 15)),
    (0.02, jnp.pi, (9, 9, 51)),
    (0.01, jnp.pi / 5, (15, 15, 15)),
    (0.01, 0.0, (31, 31, 1)),
    (0.05, 0.0, (31, 31, 1)),
]
contact_param_gridding_schedule = [
    b.utils.make_translation_grid_enumeration_3d(
        -half_width, -half_width, -half_angle,
        half_width, half_width, half_angle,
        *counts,
    )
    for half_width, half_angle, counts in grid_params
]

In [None]:
def get_init_trace(key, img):
    """Build the starting trace for a two-object scene constrained to observe `img`.

    Object 0 is mesh id 21 with parent -1, fixed at `table_pose`. Object 1 is mesh id 13
    with parent 0, attached via face_parent 2 / face_child 3, with contact parameters
    initialised to zero. The noise settings (variance 1e-4, outlier prob 1e-3) are fixed.
    The observed image is the depth `img` unprojected with `intrinsics`.

    NOTE(review): the positional model arguments are forwarded unchanged. They look
    like object indices, mesh ids, pose/contact bounds, box dims and two scalars, but
    their meaning is defined by b.model — confirm there.

    Args:
        key: PRNG key for importance sampling.
        img: depth image compatible with `intrinsics`.

    Returns:
        The trace from importance_jit (the importance weight is discarded).
    """
    constraints = genjax.choice_map({
        "parent_0": -1,
        "parent_1": 0,
        "id_0": jnp.int32(21),
        "id_1": jnp.int32(13),
        "camera_pose": jnp.eye(4),
        "root_pose_0": table_pose,
        "face_parent_1": 2,
        "face_child_1": 3,
        "variance": 0.0001,
        "outlier_prob": 0.001,
        "image": b.unproject_depth(img, intrinsics),
        "contact_params_1": jnp.zeros(3),
    })
    model_args = (
        jnp.arange(2),
        jnp.arange(22),
        jnp.array([-jnp.ones(3) * 100.0, jnp.ones(3) * 100.0]),
        jnp.array([jnp.array([-0.5, -0.5, -2 * jnp.pi]), jnp.array([0.5, 0.5, 2 * jnp.pi])]),
        b.RENDERER.model_box_dims,
        1.0,
        intrinsics.fx,
    )
    _weight, trace = importance_jit(key, constraints, model_args)
    return trace

In [None]:
# Clean test split: 'arr_0' -> test images, 'arr_1' -> test labels.
# jnp.load returns numpy's NpzFile for .npz archives, which keeps the file open until
# closed. Read inside `with`; indexing loads each array fully into memory.
with jnp.load('test_data_noxy.npz') as test_data_file:
    test_imgs = test_data_file['arr_0']
    test_labels = test_data_file['arr_1']
N_TEST = test_imgs.shape[0]

In [None]:
# Noisy test split: 'arr_0' -> test images, 'arr_1' -> test labels.
# jnp.load returns numpy's NpzFile for .npz archives, which keeps the file open until
# closed. Read inside `with`; indexing loads each array fully into memory.
with jnp.load('test_data_noxy_noisy.npz') as test_data_file_noisy:
    test_imgs_noisy = test_data_file_noisy['arr_0']
    test_labels_noisy = test_data_file_noisy['arr_1']

In [None]:
# Seed for the evaluation loops. Both loops split this same key into N_TEST keys,
# so the clean and noisy runs use the same key for image i.
key = jax.random.PRNGKey(0)

In [None]:
def von_mises_cross_entropy(gt_theta, sampled_thetas):
    """Negative log-likelihood of `gt_theta` under a von Mises fitted to `sampled_thetas`.

    The fit is delegated to tfp's ``VonMises.experimental_fit``.
    """
    fitted = tfp.distributions.VonMises.experimental_fit(sampled_thetas)
    nll = -fitted.log_prob(gt_theta)
    return nll

In [None]:
def fit_von_mises(thetas, num_iters=100, max_concentration=1e8):
    """Fit a von Mises distribution to a 1-D array of angles.

    Follows Banerjee et al. (2005),
    https://www.jmlr.org/papers/volume6/banerjee05a/banerjee05a.pdf :
    - The mean direction is the direction of the resultant of the unit vectors.
    - The concentration kappa solves A(kappa) = I1(kappa) / I0(kappa) = rbar, via Newton's method.

    Args:
        thetas: 1-D array of angles in radians.
        num_iters: number of Newton iterations (default 100).
        max_concentration: upper clamp on kappa (default 1e8).

    Returns:
        (mean, conc, rbar): mean direction in (-pi, pi], concentration kappa >= 0, and the
        mean resultant length rbar in [0, 1].
    """
    n = thetas.shape[0]
    r = jnp.stack([jnp.cos(thetas).sum(), jnp.sin(thetas).sum()])
    rbar = jnp.linalg.norm(r) / n
    # arctan2 on the unnormalized resultant gives the same angle as on the unit mean
    # vector. It stays finite (0) when the resultant is exactly zero, instead of 0/0 = nan.
    mean = jnp.arctan2(r[1], r[0])

    # i1e/i0e equals i1/i0, but the exponentially scaled versions do not overflow.
    # Plain i0/i1 reach inf for kappa above ~88 in float32, so inf/inf = nan. That
    # happens whenever the samples are tightly concentrated (rbar near 1).
    def f(x):
        return jax.scipy.special.i1e(x) / jax.scipy.special.i0e(x) - rbar

    df = jax.grad(f)

    def newton_step(_, conc):
        # The von Mises concentration must be non-negative; also cap it from above.
        return jnp.clip(conc - f(conc) / (df(conc) + 1e-6), 0.0, max_concentration)

    conc = jax.lax.fori_loop(0, num_iters, newton_step, jnp.ones_like(rbar))
    return mean, conc, rbar

In [None]:
def get_posterior_samples(key, img, num_samples=1000, num_chunks=60):
    """Draw posterior samples of the contact angle for one depth image.

    Steps:
    1. Build the initial trace with get_init_trace.
    2. Run every stage of `contact_param_gridding_schedule` via c2f_contact_update_jit.
    3. Score the dense `contact_param_deltas` grid around the resulting contact parameters.
    4. Resample grid points with probabilities from b.utils.normalize_log_scores.

    Args:
        key: PRNG key.
        img: depth image; reshaped to (intrinsics.height, intrinsics.width).
        num_samples: number of posterior samples to draw (default 1000).
        num_chunks: number of pieces the dense grid is split into for scoring
            (default 60; presumably to bound peak memory — confirm).

    Returns:
        Array of shape (num_samples,) holding column 2 (the angle) of the sampled
        contact parameters.
    """
    img = img.reshape(int(intrinsics.height), int(intrinsics.width))
    tr = get_init_trace(key, img)

    # The key sequence (split(key, 1)[0] after each use) is kept unchanged so results
    # stay reproducible with earlier runs.
    key = jax.random.split(key, 1)[0]
    for grid in contact_param_gridding_schedule:
        tr = c2f_contact_update_jit(tr, key, grid)
        key = jax.random.split(key, 1)[0]

    contact_param_grid = tr["contact_params_1"] + contact_param_deltas
    weights = jnp.concatenate(
        [enumerators[3](tr, key, chunk) for chunk in jnp.array_split(contact_param_grid, num_chunks)],
        axis=0,
    )

    key = jax.random.split(key, 1)[0]
    normalized_weights = b.utils.normalize_log_scores(weights)
    sampled_indices = jax.random.choice(
        key, jnp.arange(normalized_weights.shape[0]), shape=(num_samples,), p=normalized_weights
    )
    return contact_param_grid[sampled_indices][:, 2]

In [None]:
# Posterior angle samples for every clean test image: row i of `thetass` holds the
# 1000 samples for image i. Progress is checkpointed to thetass.npy every 20 images.
thetass = jnp.empty((0, 1000))
thetass_buffer = []
for (i, key_) in tqdm(zip(range(N_TEST), jax.random.split(key, N_TEST)), total=N_TEST):
    thetass_buffer.append(get_posterior_samples(key_, test_imgs[i, :, :, 0]))
    # Also flush after the last image, so a partial final batch (N_TEST not divisible
    # by 20) is not silently dropped.
    if (i + 1) % 20 == 0 or i + 1 == N_TEST:
        thetass = jnp.vstack([thetass, *thetass_buffer])
        jnp.save('thetass.npy', thetass)
        thetass_buffer = []

In [None]:
# Posterior angle samples for every noisy test image: row i of `thetass_noisy` holds
# the 1000 samples for image i. Progress is checkpointed to thetass_noisy.npy every
# 20 images.
thetass_noisy = jnp.empty((0, 1000))
thetass_noisy_buffer = []
for (i, key_) in tqdm(zip(range(N_TEST), jax.random.split(key, N_TEST)), total=N_TEST):
    thetass_noisy_buffer.append(get_posterior_samples(key_, test_imgs_noisy[i, :, :, 0]))
    # Also flush after the last image, so a partial final batch (N_TEST not divisible
    # by 20) is not silently dropped.
    if (i + 1) % 20 == 0 or i + 1 == N_TEST:
        thetass_noisy = jnp.vstack([thetass_noisy, *thetass_noisy_buffer])
        jnp.save('thetass_noisy.npy', thetass_noisy)
        thetass_noisy_buffer = []

In [None]:
# Fit a von Mises distribution to each row of clean-set posterior angle samples.
means, concs, rbars = jax.jit(jax.vmap(fit_von_mises))(thetass)

In [None]:
# Batched negative log-likelihood: entry i scores label i under the von Mises
# with loc mean[i] and concentration conc[i].
von_mises = tfp.distributions.VonMises


def _von_mises_nll(mean, conc, label):
    """Negative log-density of `label` under VonMises(loc=mean, concentration=conc)."""
    return -von_mises(loc=mean, concentration=conc).log_prob(label)


calc_loss = jax.jit(jax.vmap(_von_mises_nll))

In [None]:
# Total negative log-likelihood of the clean test labels under the per-image von Mises
# fits held in `means` / `concs`.
# NOTE(review): assumes test_labels holds exactly one angle per test image, so that
# flatten() lines up with the rows of `means` — TODO confirm.
test_loss = calc_loss(means, concs, test_labels.flatten()).sum()
test_loss

In [None]:
# Fit von Mises distributions to the noisy-set posterior samples and score the labels.
# Distinct names keep the clean-set `means` / `concs` / `rbars` intact, so re-running
# the clean test-loss cell after this one still reports the clean result.
means_noisy, concs_noisy, rbars_noisy = jax.jit(jax.vmap(fit_von_mises))(thetass_noisy)
test_loss_noisy = calc_loss(means_noisy, concs_noisy, test_labels_noisy.flatten()).sum()
test_loss_noisy