# 

In [1]:
import bayes3d as b
import os
import jax.numpy as jnp
import jax
import bayes3d.genjax
import numpy as np
import genjax
import matplotlib
import pathlib
from tqdm import tqdm
import matplotlib.pyplot as plt
import optax

In [2]:
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7006/static/


In [3]:
intrinsics = b.Intrinsics(
    height=100,
    width=100,
    fx=200.0, fy=200.0,
    cx=50.0, cy=50.0,
    near=0.0001, far=2.0
)

b.setup_renderer(intrinsics)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (128, 128, 1024)


In [4]:
import trimesh
box_width = 0.02
hammer_width = 0.05
hand_length = 0.15
b1 = trimesh.creation.box(
    np.array(jnp.array([hand_length, box_width,box_width])),
    np.array(b.transform_from_pos(jnp.array([0.0, 0.0, 0.0])))
)
b2 = trimesh.creation.box(
    np.array(jnp.array([hammer_width,hammer_width, hammer_width])),
    np.array(b.transform_from_pos(jnp.array([hand_length/2 - hammer_width/2, 0.0, 0.0])))
)
b3 = trimesh.creation.box(
    np.array(jnp.array([hammer_width,hammer_width, hammer_width])),
    np.array(b.transform_from_pos(jnp.array([-hand_length/2 + hammer_width/2, 0.0, 0.0, ])))
)
m1 = trimesh.util.concatenate([b1,b2])
m2 = trimesh.util.concatenate([b1,b2,b3])
b.show_trimesh("1", m2)

b.utils.mesh.export_mesh(m1, "m1.obj")
b.utils.mesh.export_mesh(m2, "m2.obj")
table_mesh_path = b.utils.get_assets_dir() + '/sample_objs/cube.obj'

box_mesh = b.utils.make_cuboid_mesh(jnp.array([0.1, 0.1, 0.3]))
b.RENDERER.add_mesh(m1)
b.RENDERER.add_mesh(m2)
b.RENDERER.add_mesh(box_mesh)
b.RENDERER.add_mesh_from_file(table_mesh_path, scaling_factor=1e-6)

In [5]:
table_pose = b.t3d.inverse_pose(
    b.t3d.transform_from_pos_target_up(
        jnp.array([0.0, 0.7, 0.5]),
        jnp.array([0.0, 0.0, 0.0]),
        jnp.array([0.0, 0.0, 1.0]),
    )
)
importance_jit = jax.jit(b.model.importance)
update_jit = jax.jit(b.model.update)
enumerators = b.make_enumerator(["contact_params_2"])

In [6]:
key = jax.random.PRNGKey(10000)

In [7]:
FORK_IDX, KNIFE_IDX, CHEESEITZ_BOX_IDX, TABLE_IDX = 0, 1, 2, 3
SHIFT_MIN = -0.2
SHIFT_SCALE = 0.4
CHEESEITZ_BOX_CONTACT_PARAMS = jnp.array([0.0, 0.1, 0.0])

contact_param_grid = jax.vmap(
    lambda shift: jnp.array([shift, 0, jnp.pi/2])
)(jnp.linspace(SHIFT_MIN, SHIFT_MIN+SHIFT_SCALE, 2000))



In [8]:
0.4/25

0.016

In [9]:
def c2f_contact_update(trace_, key, contact_param_deltas):
    contact_param_grid = contact_param_deltas + trace_[f"contact_params_2"]
    scores = enumerators[3](trace_, key, contact_param_grid)
    i = scores.argmax()
    return enumerators[0](
        trace_, key,
        contact_param_grid[i]
    )
c2f_contact_update_jit = jax.jit(c2f_contact_update)

grid_params = [
    (0.2, jnp.pi, (25,25,25)), (0.016, jnp.pi/25, (21, 21,21)), #(0.0, jnp.pi/32, (21,21,21)),
    #(0.05, jnp.pi/3, (20,20,20)), (0.02, jnp.pi, (10,10,51)), (0.01, jnp.pi/5, (15,15,15)), (0.01, 0.0, (31,31,1)),(0.05, 0.0, (31,31,1))
]
contact_param_gridding_schedule = [
    b.utils.make_translation_grid_enumeration_3d(
        -x, -x, -ang,
        x, x, ang,
        *nums
    )
    for (x,ang,nums) in grid_params
]



In [10]:
def fork_spoon_generator(x, y, theta, is_fork):

    #indices = jnp.array([TABLE_IDX, CHEESEITZ_BOX_IDX, FORK_IDX if is_fork else KNIFE_IDX])
    indices = jax.lax.cond(is_fork,
                           lambda: jnp.array([TABLE_IDX, CHEESEITZ_BOX_IDX, FORK_IDX]),
                           lambda: jnp.array([TABLE_IDX, CHEESEITZ_BOX_IDX, KNIFE_IDX]))

    box_dims = b.RENDERER.model_box_dims[indices]
    root_poses = jnp.array([table_pose, table_pose, table_pose])
    parents = jnp.array([-1, 0, 0])
    contact_params = jnp.array([[0.0, 0.0, 0.0],
                                [*CHEESEITZ_BOX_CONTACT_PARAMS],
                                [x, y, theta]])
    faces_parents = jnp.array([0, 2, 2])
    faces_child = jnp.array([0, 3, 3])
    poses = b.scene_graph.poses_from_scene_graph(
        root_poses, box_dims, parents, contact_params, faces_parents, faces_child)
    camera_pose = jnp.eye(4)
    rendered = b.RENDERER.render(
        jnp.linalg.inv(camera_pose) @ poses , indices
    )[...,:3]
    return rendered
fork_spoon_generator_jit = jax.jit(fork_spoon_generator)

In [11]:
#train_data_file = jnp.load('train_data.npz')
test_data_file = jnp.load('test_data.npz')
test_imgs = test_data_file['arr_0'][:200]
test_labels = test_data_file['arr_1'[:200]]
N_TEST = test_imgs.shape[0]

In [12]:
#train_data_file = jnp.load('train_data.npz')
test_data_file_noisy = jnp.load('test_data_noisy.npz')
test_imgs_noisy = test_data_file_noisy['arr_0'][:200]
test_labels_noisy = test_data_file_noisy['arr_1'][:200]
N_TEST_NOISY = test_imgs_noisy.shape[0]

In [13]:
def make_init_trace(key, img):
    weight, trace = importance_jit(key, genjax.choice_map({
        "parent_0": -1,
        "parent_1": 0,
        "parent_2": 0,
        "id_0": jnp.int32(3),
        "id_1": jnp.int32(CHEESEITZ_BOX_IDX),
        "id_2": jnp.int32(FORK_IDX),# if is_fork else jnp.int32(KNIFE_IDX)),
        "camera_pose": jnp.eye(4),
        "root_pose_0": table_pose,
        "face_parent_1": 2,
        "face_parent_2": 2,
        "face_child_1": 3,
        "face_child_2": 3,
        "variance": 0.00001, #0.00001, 
        "outlier_prob": 0.5, #0.1,
        "image": b.unproject_depth(img, intrinsics),
        "contact_params_1": CHEESEITZ_BOX_CONTACT_PARAMS,
        "contact_params_2": jnp.array([0.0, 0.0, 0.0])
    }), (
        jnp.arange(3),
        jnp.arange(4),
        jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),
        jnp.array([jnp.array([-0.2, -0.2, -2*jnp.pi]), jnp.array([0.2, 0.2, 2*jnp.pi])]),
        b.RENDERER.model_box_dims, 1.0, intrinsics.fx)
    )
    return trace

In [14]:
def do_c2f(key, trace):
    for c2f_iter in range(len(contact_param_gridding_schedule)):
        trace = c2f_contact_update_jit(trace, key, contact_param_gridding_schedule[c2f_iter])
    return trace

In [15]:
def do_inference(key, img):
    original_trace = make_init_trace(key, img)
    trace1 = update_jit(key, original_trace, genjax.choice_map({"id_2": KNIFE_IDX}), b.make_unknown_change_argdiffs(original_trace))[2]
    best_trace1 = do_c2f(key, trace1)
    trace2 = update_jit(key, original_trace, genjax.choice_map({"id_2": FORK_IDX}), b.make_unknown_change_argdiffs(original_trace))[2]
    best_trace2 = do_c2f(key, trace2)
    z = jnp.minimum(best_trace1.get_score(), best_trace2.get_score())
    log_scores = jnp.array([best_trace1.get_score(), best_trace2.get_score()]) - z
    return log_scores, best_trace1, best_trace2

In [16]:
#cross_entropies = jnp.zeros(N_TEST)
log_scoress = jnp.zeros((N_TEST, 2))
log_scoress_noisy = jnp.zeros((N_TEST, 2))
for (idx, key_) in tqdm(enumerate(jax.random.split(key, N_TEST))):
    log_scoress = log_scoress.at[idx, :].set(do_inference(key_, test_imgs[idx,...,-1])[0])
    log_scoress_noisy = log_scoress_noisy.at[idx, :].set(do_inference(key_, test_imgs_noisy[idx,...,-1])[0])
    #cross_entropies = cross_entropies.at[idx].set(optax.softmax_cross_entropy(log_scores, true_scores))

200it [06:23,  1.92s/it]


In [17]:
losses = optax.softmax_cross_entropy(logits=log_scoress, labels=test_labels)
print(losses.sum())
losses_noisy = optax.softmax_cross_entropy(logits=log_scoress_noisy, labels=test_labels_noisy)
print(losses.sum())

TypeError: mul got incompatible shapes for broadcasting: (2000, 2), (200, 2).

In [None]:
21/100 * 2000

In [None]:
logsumexp = jax.scipy.special.logsumexp

In [None]:
cel_loss = lambda logits, labels: -((logits - logsumexp(logits, axis=1).reshape(logits.shape[0], 1)) * labels).sum()

In [None]:
cel_loss(log_scoress, test_labels)

In [None]:
logsumexp(log_scoress, axis=1)

In [None]:
test_labels

In [None]:
losses = optax.softmax_cross_entropy(log_scoress, test_labels)
print(losses.sum())
jnp.argsort(losses)

In [None]:
idx = 97
log_scores, bt1, bt2 = do_inference(key, test_imgs[idx,:, :, 0])
true_scores = jnp.array([1, 0]) if test_labels[idx].flatten()[0] else jnp.array([0, 1])
print(jnp.exp(log_scores - jax.scipy.special.logsumexp(log_scores)))
print(true_scores)
b.viz.scale_image(
    b.viz.hstack_images([b.get_depth_image(test_imgs[idx, :, :, 0]),
                         b.get_depth_image(fork_spoon_generator_jit(*bt1['contact_params_2'], False)[:, :, 2]),
                         b.get_depth_image(fork_spoon_generator_jit(*bt2['contact_params_2'], True)[:, :, 2])]), 2)

In [None]:
optax.softmax_cross_entropy(log_scores, test_labels[idx])

In [None]:
jnp.save('b3d_scores_no_noise.npy', log_scoress)