In [1]:
import genjax
import numpy as np
import bayes3d as b
import jax.numpy as jnp
import bayes3d.genjax
import jax
from utils import *
from viz import *
from models import *

In [2]:
# Set up the 50x50 renderer (queried below as b.RENDERER) and load the meshes with
# ids 1..21, which appear as 'obj_1' ... 'obj_21' in b.RENDERER.mesh_names.
setup_renderer_and_meshes(
    height=50,
    width=50,
    focal_length=250,
    near=0.1,
    far=20,
    ids=range(1, 22),
)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)
Centering mesh with translation [ 8.9965761e-07  2.3238501e-02 -3.4500263e-06]
Centering mesh with translation [0.        0.0063132 0.       ]
Centering mesh with translation [ 5.9977174e-07  2.3238501e-02 -3.4500263e-06]
Centering mesh with translation [0.         0.05167415 0.        ]


In [3]:
# List the meshes registered with the renderer. A mesh's position in this list is
# presumably the index used by the model's "indices" choice (e.g. 23 -> 'bunny').
b.RENDERER.mesh_names

['obj_1',
 'obj_2',
 'obj_3',
 'obj_4',
 'obj_5',
 'obj_6',
 'obj_7',
 'obj_8',
 'obj_9',
 'obj_10',
 'obj_11',
 'obj_12',
 'obj_13',
 'obj_14',
 'obj_15',
 'obj_16',
 'obj_17',
 'obj_18',
 'obj_19',
 'obj_20',
 'obj_21',
 'occulder',
 'pyramid',
 'bunny',
 'cube',
 'icosahedron',
 'box_large',
 'diamond',
 'sphere',
 'orange',
 'box_small',
 'table']

In [5]:
import genjax
import bayes3d as b
import jax
import jax.numpy as jnp
import numpy as np
from jax.debug import print as jprint

# Maximum length handed to genjax.UnfoldCombinator.new for the dynamics kernel.
# Presumably the trajectory length T used by the model must stay within this bound.
MAX_UNFOLD_LENGTH = 100

############################### V1 ########################################

@genjax.gen
def dynamics_v1(prev):
    """One step of the velocity-driven pose dynamics.

    Args:
        prev: carried state ``(t, pose, velocity)``: a step counter plus the
            current pose and velocity. Both are matrices compatible with ``@``;
            velocity starts from ``jnp.eye(4)`` in ``model_v1``.

    Returns:
        The next state ``(t + 1, pose, velocity)``. The new velocity is traced at
        "velocity" and the new pose at "pose".
    """
    (t, pose, velocity) = prev
    # Perturb the velocity first, then apply the *updated* velocity to the pose.
    # NOTE(review): 0.005 / 10000.0 are presumably the positional noise scale and
    # rotational concentration of gaussian_vmf_pose. Confirm against bayes3d.
    velocity = b.gaussian_vmf_pose(velocity, 0.005, 10000.0) @ "velocity"
    pose = b.gaussian_vmf_pose(pose @ velocity, 0.005, 10000.0) @ "pose"
    return (t + 1, pose, velocity)

# Unrolled dynamics: applies dynamics_v1 repeatedly, with memory sized by MAX_UNFOLD_LENGTH.
dynamics_v1_unfold = genjax.UnfoldCombinator.new(dynamics_v1, MAX_UNFOLD_LENGTH)

# Per-frame depth likelihood, mapped over the leading axis of the rendered images
# (presumably time). The remaining likelihood parameters are shared by all frames.
_per_frame_depth_likelihood = genjax.gen(
    lambda *likelihood_args: b.image_likelihood(*likelihood_args) @ "depths"
)
image_likelihood_vmap = genjax.MapCombinator.new(
    _per_frame_depth_likelihood,
    in_axes=(0, None, None, None, None),
)

@genjax.gen
def model_v1(T_vec, N_total_vec, N_vec, all_box_dims, pose_bounds, outlier_volume, focal_length):
    """
    Single Object

    Generative model of object trajectories under ``dynamics_v1``, rendered and
    scored with a per-frame depth likelihood.

    Args:
        T_vec: only its leading dimension is used; it sets the number of
            dynamics steps T.
        N_total_vec: passed to ``b.uniform_discrete_array`` (presumably the
            candidate mesh indices).
        N_vec: its leading dimension sets how many trajectories are sampled.
            It is also passed to ``b.uniform_discrete_array``.
        all_box_dims: unused.
        pose_bounds: ``[lower, upper]`` bounds for the uniform initial pose.
        outlier_volume: forwarded unchanged to the image likelihood.
        focal_length: forwarded unchanged to the image likelihood.

    Returns:
        ``(rendered_images, poses)``, where ``poses`` is the trajectory sampled
        in the last loop iteration.
    """
    T = T_vec.shape[0]
    # Sample the initial pose and velocity.
    # NOTE(review): every trajectory in the loop below starts from this same state.
    pose = b.uniform_pose(jnp.array(pose_bounds[0]), jnp.array(pose_bounds[1])) @ "init_pose"
    velocity = b.gaussian_vmf_pose(jnp.eye(4), 0.01, 10000.0) @ "init_velocity"

    all_poses = []
    for i in range(N_vec.shape[0]):
        # Sample dynamics over T time steps.
        dynamics = dynamics_v1_unfold(T, (0, pose, velocity)) @ f"dynamics_{i+1}"
        # Element 1 of the (t, pose, velocity) state is the pose. The unfold memory
        # is presumably sized by MAX_UNFOLD_LENGTH, so keep only the first T + 1 entries.
        poses = jax.lax.slice_in_dim(dynamics[1], 0, T + 1, axis=0)
        all_poses.append(poses)

    # Stack trajectories along axis 1, giving (presumably) (time, object, ...).
    all_poses = jnp.stack(all_poses, axis=1)

    indices = b.uniform_discrete_array(N_total_vec, N_vec) @ "indices"

    # Keep only the first three channels of the rendered output.
    rendered_images = b.RENDERER.render_many(all_poses, indices)[..., :3]

    variance = genjax.distributions.tfp_uniform(0.00000000001, 10000.0) @ "variance"
    # NOTE(review): this prior's support (-0.01, 10000) extends outside [0, 1],
    # although the value is used as a probability. Kept as-is so scores are unchanged.
    outlier_prob = genjax.distributions.tfp_uniform(-0.01, 10000.0) @ "outlier_prob"

    # Observation. Only the trace at "depths" matters; the sampled value is not reused.
    image_likelihood_vmap(rendered_images, variance, outlier_prob, outlier_volume, focal_length) @ "depths"
    # NOTE(review): this returns the last trajectory, not the stacked all_poses.
    # The content is the same for the single-object case.
    return rendered_images, poses

# JIT-compiled entry points for forward simulation and constrained importance sampling.
model_v1_simulate_jit, model_v1_importance_jit = map(
    jax.jit, (model_v1.simulate, model_v1.importance)
)


In [6]:
# Ground-truth scene: one object over a T-step trajectory, with noise parameters pinned.
T = 5
pose_bounds = jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]])
key = jax.random.PRNGKey(31415)

# Look the object up by name instead of hard-coding its renderer index (23 == 'bunny').
GT_OBJECT_NAME = "bunny"
gt_object_index = b.RENDERER.mesh_names.index(GT_OBJECT_NAME)

gt_constraints = genjax.choice_map({
    "variance": 0.0001,
    "outlier_prob": 0.0001,
    "indices": jnp.array([gt_object_index]),
})

# Positional arguments of model_v1:
# (T_vec, N_total_vec, N_vec, all_box_dims, pose_bounds, outlier_volume, focal_length)
# NOTE(review): focal_length=1.0 here differs from the renderer's focal_length=250. Confirm intent.
gt_model_args = (
    jnp.zeros(T),                         # T_vec: only its length (T) is used
    jnp.arange(len(b.RENDERER.meshes)),   # N_total_vec
    jnp.zeros(1),                         # N_vec: length 1 -> one trajectory
    b.RENDERER.model_box_dims,            # all_box_dims
    pose_bounds,
    jnp.float32(1000.0),                  # outlier_volume
    jnp.float32(1.0),                     # focal_length
)
weight, gt_trace = model_v1_importance_jit(key, gt_constraints, gt_model_args)

In [7]:
# Confirm the ground-truth trace holds the constrained object index.
gt_trace["indices"]

Array([23], dtype=int32)

In [8]:
# Play the frames stored at address ("depths", "depths") of the ground-truth trace at 5 fps.
video_from_trace(
    gt_trace,
    rendered_addr=("depths", "depths"),
    framerate=5,
)