In [1]:
import bayes3d as b
import os
import jax.numpy as jnp
import jax
import bayes3d.genjax
import genjax
import matplotlib
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7007/static/


In [3]:
intrinsics = b.Intrinsics(
    height=100,
    width=100,
    fx=50.0, fy=50.0,
    cx=50.0, cy=50.0,
    near=0.01, far=20.0
)

b.setup_renderer(intrinsics)
model_dir = os.path.join(b.utils.get_assets_dir(),"bop/ycbv/models")
meshes = []
for idx in range(1,22):
    mesh_path = os.path.join(model_dir,"obj_" + "{}".format(idx).rjust(6, '0') + ".ply")
    b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0/1000.0)
b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), scaling_factor=1.0/1000000000.0)


[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (128, 128, 1024)


In [4]:
table_pose = b.t3d.inverse_pose(
    b.t3d.transform_from_pos_target_up(
        jnp.array([0.0, 0.2, .05]),
        jnp.array([0.0, 0.0, 0.0]),
        jnp.array([0.0, 0.0, 1.0]),
    )
)


importance_jit = jax.jit(b.genjax.model.importance)
update_jit = jax.jit(b.genjax.model.update)

In [5]:
# del importance_jit
# del update_jit

In [6]:
b.utils.ycb_loader.MODEL_NAMES[13]

'025_mug'

In [7]:
b.RENDERER.model_box_dims[13]

Array([0.116966, 0.093075, 0.081384], dtype=float32)

In [9]:
contact_enumerators = [b.genjax.make_enumerator([f"contact_params_{i}", "variance", "outlier_prob"]) for i in range(5)]
single_enumerators = b.genjax.make_enumerator([f"contact_params_1"])

def c2f_contact_update(trace_, key,  number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID):
    contact_param_grid = contact_param_deltas + trace_[f"contact_params_{number}"]
    scores = contact_enumerators[number][3](trace_, key, contact_param_grid, VARIANCE_GRID, OUTLIER_GRID)
    i,j,k = jnp.unravel_index(scores.argmax(), scores.shape)
    return contact_enumerators[number][0](
        trace_, key,
        contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]
    )
c2f_contact_update_jit = jax.jit(c2f_contact_update, static_argnames=("number",))

In [10]:
key = jax.random.PRNGKey(100)

In [11]:
VARIANCE_GRID = jnp.linspace(0.0000398, 0.0000398, 3)
OUTLIER_GRID = jnp.array([0.00251])
OUTLIER_VOLUME = 10.0

grid_params = [
    (0.3, jnp.pi, (11,11,11)), (0.2, jnp.pi/2, (11,11,11)), (0.1, jnp.pi/2, (11,11,11)),
    (0.05, jnp.pi/3, (11,11,11)), (0.02, jnp.pi, (5,5,51)), (0.01, jnp.pi/5, (11,11,11)), (0.01, 0.0, (21,21,1)),(0.01, 0.0, (21,21,1))
]
contact_param_gridding_schedule = [
    b.utils.make_translation_grid_enumeration_3d(
        -x, -x, -ang,
        x, x, ang,
        *nums
    )
    for (x,ang,nums) in grid_params
]

width = 0.04
ang = jnp.pi
final_contact_param_deltas = b.utils.make_translation_grid_enumeration_3d(
    -width, -width, -ang,
    width, width, ang,
    21,21,300
)

def get_depth_image(image):
    mval = image[image < image.max()].max()
    return b.get_depth_image(image, max=mval)



In [12]:
for experiment_iteration in tqdm(range(2)):
    print(key)
    key = jax.random.split(key,1)[0]
    
    weight, gt_trace = importance_jit(key, genjax.choice_map({
        "parent_0": -1,
        "parent_1": 0,
        "id_0": jnp.int32(21),
        "id_1": jnp.int32(13),
        "camera_pose": jnp.eye(4),
        "root_pose_0": table_pose,
        # "contact_params_1": jnp.array([ 0.01630328 ,-0.06595182, -2.946241  ]),
        "face_parent_1": 2,
        "face_child_1": 3,
        "variance": VARIANCE_GRID[0],
        "outlier_prob": OUTLIER_GRID[0],
    }), (
        jnp.arange(2),
        jnp.arange(22),
        jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),
        jnp.array([jnp.array([-0.1, -0.1, -2*jnp.pi]), jnp.array([0.1, 0.1, 2*jnp.pi])]),
        b.RENDERER.model_box_dims, OUTLIER_VOLUME)
                                  )
    trace = gt_trace
    contact_param_grid = final_contact_param_deltas + trace[f"contact_params_1"]
    weights = jnp.concatenate([
            contact_enumerators[1][3](
            trace,
            key,
            d + trace[f"contact_params_1"],
            VARIANCE_GRID,
            OUTLIER_GRID
        ) for d in jnp.array_split(final_contact_param_deltas, 55)
    ],axis=0
    )
    
    i,j,k = jnp.unravel_index(weights.argmax(), weights.shape)
    trace= contact_enumerators[1][0](
        trace, key,
        contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]
    )
    print(trace["contact_params_1"])
    print(trace["variance"])
    
    print(trace.get_score())
    
    print(gt_trace["contact_params_1"])
    print(gt_trace["variance"])
    print(gt_trace.get_score())
    
    key2 = jax.random.PRNGKey(0)
    sampled_indices = jax.random.categorical(key2, weights.reshape(-1), shape=(1000,))
    sampled_indices = jnp.unravel_index(sampled_indices, weights.shape)[0]
    sampled_params = contact_param_grid[sampled_indices]
    actual_params = gt_trace["contact_params_1"]
    
    fig = plt.figure(constrained_layout=True)
    widths = [1, 1]
    heights = [2]
    spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths,
                              height_ratios=heights)
    
    ax = fig.add_subplot(spec[0, 0])
    ax.imshow(jnp.array(get_depth_image(trace["image"][...,2])))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title("Observed Depth")
    
    
    ax = fig.add_subplot(spec[0, 1])
    ax.set_aspect(1.0)
    circ = plt.Circle((0, 0), radius=1, edgecolor='black', facecolor='None', linestyle="--", linewidth=0.5)
    ax.add_patch(circ)
    ax.set_xlim(-2.0, 2.0)
    ax.set_ylim(-2.0, 2.0)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.scatter(-jnp.sin(sampled_params[:,2]),-jnp.cos(sampled_params[:,2]),label="Posterior Samples", alpha=0.5, s=15)
    ax.scatter(-jnp.sin(actual_params[2]),-jnp.cos(actual_params[2]), color=(1.0, 0.0, 0.0),label="Actual", alpha=0.9, s=10)
    ax.set_title("Posterior on Orientation (top view)")
    ax.legend(fontsize=7)
    # plt.show()
    plt.savefig(f'{experiment_iteration:05d}.png')
    plt.clf()

  0%|                                          | 0/2 [00:00<?, ?it/s]

[  0 100]


2023-08-23 18:04:58.135776: W external/xla/xla/service/hlo_rematerialization.cc:2202] Can't reduce memory use below 11.83GiB (12701564928 bytes) by rematerialization; only reduced to 13.80GiB (14821108912 bytes), down from 13.80GiB (14821108912 bytes) originally
2023-08-23 18:05:08.672180: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 13.80GiB (rounded to 14821082112)requested by op 
2023-08-23 18:05:08.672389: W external/tsl/tsl/framework/bfc_allocator.cc:497] *___________________________________________________________________________________________________
2023-08-23 18:05:08.672772: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 14821081912 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   146.0KiB
              constant allocation:       312B
        maybe_live_out 

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 14821081912 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   146.0KiB
              constant allocation:       312B
        maybe_live_out allocation:    28.2KiB
     preallocated temp allocation:   13.80GiB
  preallocated temp fragmentation:        64B (0.00%)
                 total allocation:   13.80GiB
              total fragmentation:    30.4KiB (0.00%)
Peak buffers:
	Buffer 1:
		Size: 13.17GiB
		Operator: op_name="jit(enumerator_score)/jit(main)/reduce_max[axes=(4, 5)]" source_file="/home/nishadgothoskar/bayes3d/bayes3d/likelihood.py" source_line=185
		XLA Label: fusion
		Shape: f32[2406,3,100,100,7,7]
		==========================

	Buffer 2:
		Size: 367.13MiB
		Operator: op_name="jit(enumerator_score)/jit(main)/vmap(vmap(vmap(jit(_render_custom_call))))/jit(_render_custom_call)/render_multiple_140129822588640" source_file="/home/nishadgothoskar/bayes3d/bayes3d/renderer.py" source_line=156
		XLA Label: custom-call
		Shape: f32[2406,100,100,4]
		==========================

	Buffer 3:
		Size: 275.34MiB
		Operator: op_name="jit(enumerator_score)/jit(main)/reduce_max[axes=(4, 5)]" source_file="/home/nishadgothoskar/bayes3d/bayes3d/likelihood.py" source_line=185
		XLA Label: fusion
		Shape: f32[2406,3,100,100]
		==========================

	Buffer 4:
		Size: 117.2KiB
		Operator: op_name="jit(enumerator_score)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3, 4), collapsed_slice_dims=(), start_index_map=(0, 1, 2)) slice_sizes=(1, 1, 3) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/home/nishadgothoskar/bayes3d/bayes3d/likelihood.py" source_line=182
		XLA Label: fusion
		Shape: f32[10000,1,1,3]
		==========================

	Buffer 5:
		Size: 117.2KiB
		Entry Parameter Subshape: f32[100,100,3]
		==========================

	Buffer 6:
		Size: 28.2KiB
		Entry Parameter Subshape: f32[2406,3]
		==========================

	Buffer 7:
		Size: 28.2KiB
		Operator: op_name="jit(enumerator_score)/jit(main)/add" source_file="<@beartype(genjax._src.generative_functions.builtin.builtin_gen_fn.BuiltinGenerativeFunction.update) at 0x7f72a29818b0>" source_line=92
		XLA Label: fusion
		Shape: f32[2406,3,1]
		==========================

	Buffer 8:
		Size: 264B
		Entry Parameter Subshape: f32[22,3]
		==========================

	Buffer 9:
		Size: 64B
		Entry Parameter Subshape: f32[4,4]
		==========================

	Buffer 10:
		Size: 64B
		Entry Parameter Subshape: f32[4,4]
		==========================

	Buffer 11:
		Size: 64B
		Entry Parameter Subshape: f32[4,4]
		==========================

	Buffer 12:
		Size: 56B
		XLA Label: tuple
		Shape: (s32[], f32[2406,2,4,4], s32[2], f32[2,3], s32[2], /*index=5*/f32[2406,2,3], s32[2])
		==========================

	Buffer 13:
		Size: 48B
		Operator: op_name="jit(enumerator_score)/jit(main)/vmap(vmap(vmap(while)))/body/concatenate[dimension=1]" source_file="/home/nishadgothoskar/bayes3d/bayes3d/transforms_3d.py" source_line=29
		XLA Label: fusion
		Shape: (f32[16,4], f32[16,4], f32[16,4], f32[16,4], f32[16,4], /*index=5*/f32[16,4])
		==========================

	Buffer 14:
		Size: 36B
		XLA Label: constant
		Shape: f32[3,3]
		==========================

	Buffer 15:
		Size: 36B
		XLA Label: constant
		Shape: f32[3,3]
		==========================



In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, FloatSlider, IntSlider, Button, Output, HBox, VBox, FloatLogSlider

out = Output(layout={'border': '5px solid black', "height" : '100px'})

def func(x,y,ang,variance, outlier_prob, outlier_volume):

    VARIANCE_GRID = jnp.array([variance])
    OUTLIER_GRID = jnp.array([outlier_prob])
    OUTLIER_VOLUME = outlier_volume

    weight, gt_trace = importance_jit(key, genjax.choice_map({
        "parent_0": -1,
        "parent_1": 0,
        "id_0": jnp.int32(21),
        "id_1": jnp.int32(13),
        "camera_pose": jnp.eye(4),
        "root_pose_0": table_pose,
        "contact_params_1": jnp.array([ x,y,ang  ]),
        "face_parent_1": 2,
        "face_child_1": 3,
        "variance": VARIANCE_GRID[0],
        "outlier_prob": OUTLIER_GRID[0],
    }), (
        jnp.arange(2),
        jnp.arange(22),
        jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),
        jnp.array([jnp.array([-0.1, -0.1, -2*jnp.pi]), jnp.array([0.1, 0.1, 2*jnp.pi])]),
        b.RENDERER.model_box_dims, OUTLIER_VOLUME)
                                  )
    trace = gt_trace
    contact_param_grid = final_contact_param_deltas + trace[f"contact_params_1"]
    weights = jnp.concatenate([
            contact_enumerators[1][3](
            trace,
            key,
            d + trace[f"contact_params_1"],
            VARIANCE_GRID,
            OUTLIER_GRID
        ) for d in jnp.array_split(final_contact_param_deltas, 55)
    ],axis=0
    )
    
    i,j,k = jnp.unravel_index(weights.argmax(), weights.shape)
    trace= contact_enumerators[1][0](
        trace, key,
        contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]
    )
    print(trace["contact_params_1"])
    print(trace["variance"])
    
    print(trace.get_score())
    
    print(gt_trace["contact_params_1"])
    print(gt_trace["variance"])
    print(gt_trace.get_score())
    
    key2 = jax.random.PRNGKey(0)
    sampled_indices = jax.random.categorical(key2, weights.reshape(-1), shape=(1000,))
    sampled_indices = jnp.unravel_index(sampled_indices, weights.shape)[0]
    sampled_params = contact_param_grid[sampled_indices]
    actual_params = gt_trace["contact_params_1"]
    
    fig = plt.figure(constrained_layout=True)
    widths = [1, 1]
    heights = [2]
    spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths,
                              height_ratios=heights)
    
    ax = fig.add_subplot(spec[0, 0])
    ax.imshow(jnp.array(get_depth_image(trace["image"][...,2])))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title("Observed Depth")
    
    
    ax = fig.add_subplot(spec[0, 1])
    ax.set_aspect(1.0)
    circ = plt.Circle((0, 0), radius=1, edgecolor='black', facecolor='None', linestyle="--", linewidth=0.5)
    ax.add_patch(circ)
    ax.set_xlim(-2.0, 2.0)
    ax.set_ylim(-2.0, 2.0)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.scatter(-jnp.sin(sampled_params[:,2]),-jnp.cos(sampled_params[:,2]),label="Posterior Samples", alpha=0.5, s=15)
    ax.scatter(-jnp.sin(actual_params[2]),-jnp.cos(actual_params[2]), color=(1.0, 0.0, 0.0),label="Actual", alpha=0.9, s=10)
    ax.set_title("Posterior on Orientation (top view)")
    ax.legend(fontsize=7)
    # plt.show()
    # plt.savefig(f'{experiment_iteration:05d}.png')
    # plt.clf()
    
    with out: 
        out.clear_output()
        display(f"variance   = {variance}")
        display(f"outlier_prob = {outlier_prob}")
        display(f"outlier_volume = {outlier_volume}")

w = interactive(func, 
    x = FloatSlider(min=-0.1, max=0.1, value=0.0, description=" x:"),
    y = FloatSlider(min=-0.1, max=0.1, value=0.0, description=" y:"),
    ang = FloatSlider(min=-jnp.pi, max=jnp.pi, value=jnp.pi, description=" ang:"),
    variance = FloatLogSlider(base=10.0, min=-9, max=1, value=0.000501, description="variance:"),
    outlier_prob = FloatLogSlider(base=10.0, min=-4, max=0, value=0.631, description="outlier_prob:"),
    outlier_volume = FloatLogSlider(base=10.0, min=1, max=5, value=10.0, description="outlier_volume:")
);
display(VBox([w,out]))