In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE = False
import os
import genjax
import numpy as np
import bayes3d as b
from genjax.generative_functions.distributions import ExactDensity
import jax.numpy as jnp
import bayes3d.genjax
import jax
from jax.debug import print as jprint
import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from IPython.display import HTML
import PIL

def display_video(frames, framerate=30):
    if type(frames[0]) == PIL.Image.Image:
      frames = [np.array(frames[i]) for i in range(len(frames))]
    height, width, _ = frames[0].shape
    dpi = 70
    orig_backend = matplotlib.get_backend()
    matplotlib.use('Agg')  # Switch to headless 'Agg' to inhibit figure rendering.
    fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
    matplotlib.use(orig_backend)  # Switch back to the original backend.
    ax.set_axis_off()
    ax.set_aspect('equal')
    ax.set_position([0, 0, 1, 1])
    im = ax.imshow(frames[0])
    def update(frame):
      im.set_data(frame)
      return [im]
    interval = 1000/framerate
    anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                   interval=interval, blit=True, repeat=True)
    return HTML(anim.to_html5_video())

def video_from_trace(trace, scale = 8, use_retval = False):
    if use_retval:
        depths = trace.get_retval()[0]
    else:
        depths = trace["depths", "depths"]
    images = [b.scale_image(b.get_depth_image(depths[i,...,2]),scale) for i in range(depths.shape[0])]
    return images

console = genjax.pretty()

env: XLA_PYTHON_CLIENT_PREALLOCATE=False


In [2]:
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/


In [3]:
intrinsics = b.Intrinsics(
    height=50,
    width=50,
    fx=250.0, fy=250.0,
    cx=25.0, cy=25.0,
    near=0.01, far=20.0
)

b.setup_renderer(intrinsics)
b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj"), scaling_factor=0.1)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)


In [4]:
T = 10

@genjax.gen
def dynamics_model_v1(prev):
    (t, pose, velocity) = prev
    velocity = b.gaussian_vmf_pose(velocity, 0.005, 10000.0)  @ f"velocity"
    pose = b.gaussian_vmf_pose(pose @ velocity, 0.005, 10000.0)  @ f"pose"
    return (t + 1, pose, velocity)

dynamics_model_v1_unfold = genjax.UnfoldCombinator.new(dynamics_model_v1, T+1)
image_likelihood_vmap = genjax.MapCombinator.new(genjax.gen(lambda *x: b.image_likelihood(*x) @ "depths"), in_axes=(0,None,None,None,None))

In [5]:
@genjax.gen
def model_single_object(T_vec, outlier_volume, focal_length):
    T = T_vec.shape[0]
    pose = b.uniform_pose(jnp.array([-0.1,-0.1,1.5]), jnp.array([0.1,0.1,2])) @ "init_pose"
    velocity = b.gaussian_vmf_pose(jnp.eye(4), 0.01, 10000.0) @ "init_velocity"
    dynamics = dynamics_model_v1_unfold(T,(0, pose, velocity)) @ "dynamics"
    poses = jax.lax.slice_in_dim(dynamics[1], 0, T+1, axis = 0) 

    # poses = jax.lax.cond(poses.shape[0] < 4, lambda x: poses[:,None,...], lambda x: poses, (None,))
    indices = jnp.array([0])

    rendered_images = b.RENDERER.render_many(
        poses[:,None,...], indices 
    )[...,:3]

    variance = genjax.distributions.tfp_uniform(0.00000000001, 10000.0) @ "variance"
    outlier_prob  = genjax.distributions.tfp_uniform(-0.01, 10000.0) @ "outlier_prob"
    
    # images = b.image_likelihood(rendered_images[0,...], variance, outlier_prob, outlier_volume, focal_length) @ "images"
    images = image_likelihood_vmap(rendered_images, variance, outlier_prob, outlier_volume, focal_length) @ "depths"
    return rendered_images, poses


model_single_object_simulate_jit = jax.jit(model_single_object.simulate)
model_single_object_importance_jit = jax.jit(model_single_object.importance)


In [6]:
# key = jax.random.PRNGKey(31415)
# gt_tr = model_single_object_simulate_jit(key, (jnp.zeros(T), jnp.float64(1000.0), jnp.float64(1.0)))
# poses = tr["dynamics"]["pose"]
# for i in range(poses.shape[0]):
#     b.show_pose(f"{i}", poses[i])
# display_video(video_from_trace(gt_tr, 0.7),  framerate=24)

# def c2f_pose_update(trace_, key,  number, contact_param_deltas, VARIANCE_GRID, OUTLIER_GRID):
#     contact_param_grid = contact_param_deltas + trace_[f"contact_params_{number}"]
#     scores = contact_enumerators[number][3](trace_, key, contact_param_grid, VARIANCE_GRID, OUTLIER_GRID)
#     i,j,k = jnp.unravel_index(scores.argmax(), scores.shape)
#     return contact_enumerators[number][0](
#         trace_, key,
#         contact_param_grid[i], VARIANCE_GRID[j], OUTLIER_GRID[k]
#     )

In [6]:
key = jax.random.PRNGKey(31415)
subkey = jax.random.split(key, 1)[0]
weight, gt_trace = model_single_object_importance_jit(subkey, genjax.choice_map({
    "variance": 0.0001,
    "outlier_prob": 0.0001,
}), (jnp.zeros(T), jnp.float32(1000.0), jnp.float32(1.0)))

In [7]:
display_video(video_from_trace(gt_trace, 4, use_retval = True),  framerate=24)

  matplotlib.use(orig_backend)  # Switch back to the original backend.


In [8]:
gt_trace.get_score()

[1;35mArray[0m[1m([0m[1;36m89422.77[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m

In [9]:
subkey = jax.random.split(subkey, 5446)[0]
subkey = jax.random.PRNGKey(5436456)
chm = genjax.choice_map(
    {"depths" : genjax.vector_choice_map(
        genjax.choice_map({
            "depths":gt_trace.get_choices()["depths", "depths"]
        })
    ),
    "variance": 0.0001,
    "outlier_prob": 0.0001}
)
# chm = genjax.vector_choice_map(genjax.choice_map({("depths", "depths"):gt_trace.get_choices()["depths", "depths"]}))
w,trace = model_single_object_importance_jit(subkey, chm, (jnp.zeros(T), jnp.float32(1000.0), jnp.float32(1.0)))
w

[1;35mArray[0m[1m([0m[1;36m16312.931[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m

In [10]:
display_video(video_from_trace(trace, 4, use_retval = True),  framerate=24)

  matplotlib.use(orig_backend)  # Switch back to the original backend.


In [12]:
# viz_images = []
# max_depth = 10.0

# # inferred_poses_with_occ = jnp.stack([inferred_poses, occ1_poses], axis = 1)
# occ_image = b.viz.get_depth_image(b.RENDERER.render(occ1_pose[None,...], jnp.array([1]))[:,:,2])

# pred_images = b.RENDERER.render_many(inferred_poses[:,None, ...], jnp.array([0]))

# pred_with_occ_images = [b.overlay_image(b.viz.get_depth_image(pred_images[i,:,:,2]), 
# occ_image, alpha=0.4) for i in range(pred_images.shape[0])]

# gt_images = b.RENDERER.render_many(gt_poses[:,None, ...], jnp.array([0]))
# gt_with_occ_images = [b.overlay_image(b.viz.get_depth_image(gt_images[i,:,:,2]), 
# occ_image, alpha=0.5) for i in range(pred_images.shape[0])]

# viz_images = [
#     b.viz.multi_panel(
#         [g, b.viz.get_depth_image(p[:,:,2]), po],
#         labels = ["Ground Truth", "Reconstruction w/o Occluder", "Reconstruction w Occluder"],
#         title = "Scene 11A",
#         # bottom_text = "3DP3 + Physics Prior v1"
#     )
#     for (g, p, po) in zip(gt_with_occ_images, pred_images, pred_with_occ_images)
# ]
# display_video(viz_images)

In [11]:
imp = genjax.inference.importance_sampling.sampling_importance_resampling(model_single_object, 100)
(tr, lnw, log_ml_estimate) = imp.apply(subkey, chm, (jnp.zeros(T), jnp.float32(1000.0), jnp.float32(1.0)))

In [12]:
display_video(video_from_trace(tr, 4, use_retval = True),  framerate=24)

  matplotlib.use(orig_backend)  # Switch back to the original backend.


In [13]:
tr.get_score()

[1;35mArray[0m[1m([0m[1;36m43208.86[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m

In [14]:
log_ml_estimate

[1;35mArray[0m[1m([0m[1;36m42709.805[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m