In [1]:
import sys
sys.path.append("../")
import genjax
import numpy as np
import bayes3d as b
import jax.numpy as jnp
import bayes3d.genjax
import jax
from utils import *
from viz import *
from models import *
from renderer_setup import *
from genjax.inference.importance_sampling import sampling_importance_resampling
from genjax._src.core.transforms.incremental import NoChange
from genjax._src.core.transforms.incremental import UnknownChange
from genjax._src.core.transforms.incremental import Diff

console = genjax.pretty()

In [2]:
gt_path = "../ground_truths/genjax_generated/scene_2.pkl"
metadata = load_metadata(gt_path)
gt_images = metadata["rendered"]
model = eval("model_v{}".format(metadata["model_version"]))
RENDERER_ARGS = metadata["RENDERER_ARGS"]
setup_renderer_and_meshes = eval("setup_renderer_and_meshes_v{}".format(metadata["renderer_setup_version"]))
setup_renderer_and_meshes(**RENDERER_ARGS)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)
Centering mesh with translation [ 8.9965761e-07  2.3238501e-02 -3.4500263e-06]
Centering mesh with translation [0.        0.0063132 0.       ]
Centering mesh with translation [ 5.9977174e-07  2.3238501e-02 -3.4500263e-06]
Centering mesh with translation [0.         0.05167415 0.        ]


In [3]:
def make_schedule_translation_3d(grid_widths, grid_nums):
    sched = []

    for (grid_width, grid_num) in zip(grid_widths, grid_nums):
        grid = b.utils.make_translation_grid_enumeration(
            -grid_width, -grid_width, -grid_width, 
            +grid_width, +grid_width, +grid_width, 
            *grid_num,  # *grid_num is num_x, num_y, num_z
        )
        sched.append(grid)
    return sched

grid_widths = [0.2,0.1,0.05]
grid_nums = [(3,3,3),(3,3,3),(3,3,3)]
gridding_schedule = make_schedule_translation_3d(grid_widths, grid_nums)



In [10]:
def unfold_with_proposals(T, proposal, unfold_vector):
    """
    Note that T starts from 1, where 0 is before the first run of 
    proposals: N x 4 x 4 of N proposed velocity vectors
    """
    return unfold_vector.at[T,...].set(proposal)

unfold_with_proposals_vmap = jax.jit(jax.vmap(unfold_with_proposals, in_axes = (None, 0, None)))
    

def c2f_pose_update(trace_, key, T, unfold_array, pose_grid, enumerator):
    
    proposed_unfold_vectors = unfold_with_proposals_vmap(T, pose_grid, unfold_array)
    scores = enumerator.score_vmap(trace_, key, proposed_unfold_vectors)
    return enumerator.enum_f(
        trace_, key,
        proposed_unfold_vectors[scores.argmax()]
    )
c2f_pose_update_jit = jax.jit(c2f_pose_update, static_argnames=("enumerator",))

In [11]:
def velocity_chm_builder(addresses, args):
    print(addresses, args)
    chm = genjax.choice_map({
                addresses[0][0]:genjax.index_choice_map(
                    jnp.arange(2),genjax.choice_map({
                        addresses[0][1]: args[0][:2]
            }))
        })
    return chm
    

def inference_approach_C(model, gt, metadata):
    """
    Greedy Grid Enumeration of T=0 to T=1
    """
    chm = genjax.choice_map(metadata["CHOICE_MAP_ARGS"])
    # force new constaints values to take over
    chm = chm.unsafe_merge(genjax.choice_map(
        {"depths" : genjax.vector_choice_map(genjax.choice_map({
                "depths": gt
        })),
        "init_pose" : metadata["init_pose"] # assume init pose is known
        }) 
    )
    
    # make 3d translation grid: list of N x 4 x 4 poses
    grid_widths = [0.1,0.05,0.025]
    grid_nums = [(3,3,3),(3,3,3),(3,3,3)]
    gridding_schedule = make_schedule_translation_3d(grid_widths, grid_nums)

    # make initial sample:
    key = jax.random.PRNGKey(metadata["key_number"])
    _, trace = model.importance(key, chm, tuple(metadata["MODEL_ARGS"].values()))
    # return trace

    # do inference by updating the T=1 slice of the velocity address
    # first make the chm builder:

    enumerator = b.make_enumerator([("dynamics_1", "velocity")], chm_builder = velocity_chm_builder)
    # then update trace over all the proposals
    velocity_vector = trace["dynamics_1", "velocity"]
    for grid in gridding_schedule:
        trace = c2f_pose_update_jit(trace, key, 1, velocity_vector, grid, enumerator)
    return trace

In [13]:
tr = inference_approach_C(model, gt_images, metadata)

[('dynamics_1', 'velocity')] (Traced<ShapedArray(float32[100,4,4])>with<BatchTrace(level=3/0)> with
  val = Traced<ShapedArray(float32[27,100,4,4])>with<DynamicJaxprTrace(level=2/0)>
  batch_dim = 0,)
[('dynamics_1', 'velocity')] (Traced<ShapedArray(float32[100,4,4])>with<DynamicJaxprTrace(level=2/0)>,)
[('dynamics_1', 'velocity')] (Traced<ShapedArray(float32[100,4,4])>with<BatchTrace(level=3/0)> with
  val = Traced<ShapedArray(float32[27,100,4,4])>with<DynamicJaxprTrace(level=2/0)>
  batch_dim = 0,)
[('dynamics_1', 'velocity')] (Traced<ShapedArray(float32[100,4,4])>with<DynamicJaxprTrace(level=2/0)>,)


In [15]:
# video_from_trace(tr, framerate=5, use_retval=True)
video_from_rendered(gt_images, framerate=5,)

In [13]:
gt_vel = tr["dynamics_1","velocity"]
en = b.make_enumerator([("dynamics_1", "velocity")], chm_builder = velocity_chm_builder)

In [8]:
key = jax.random.PRNGKey(234523)
chm = velocity_chm_builder([('dynamics_1', 'velocity')], gt_vel)

[('dynamics_1', 'velocity')] [[[ 9.9868363e-01  4.6832435e-02  2.0925160e-02  1.2540886e-02]
  [-4.6936616e-02  9.9888760e-01  4.5154467e-03 -7.2484016e-03]
  [-2.0690415e-02 -5.4916586e-03  9.9977082e-01  1.3298560e-03]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  1.0000000e+00]]

 [[ 9.9697751e-01  5.9375841e-02  5.0104488e-02  1.5674656e-02]
  [-5.9408851e-02  9.9823326e-01 -8.3156227e-04 -8.8680061e-03]
  [-5.0065342e-02 -2.1476008e-03  9.9874353e-01  3.8148614e-03]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  1.0000000e+00]]

 [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]]

 ...

 [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]

In [26]:
# genjax.choice_map({"variance": 0.01})
# x = tr.update(key, chm, b.make_unknown_change_argdiffs(tr))
ichm = genjax.choice_map({
            "dynamics_1":genjax.index_choice_map(jnp.arange(gt_vel.shape[0]),genjax.choice_map({
                "velocity": gt_vel
        }))
    })

x = tr.update(key, ichm, tuple(map(lambda v: Diff(v, NoChange), tr.args)))
# x = tr.update(key, genjax.choice_map({"init_velocity": jnp.eye(4)}), tuple(map(lambda v: Diff(v, NoChange), tr.args)))



In [20]:
x[2].score

[1;35mArray[0m[1m([0m[1;36m16228.155[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m

In [15]:
argdiff = b.make_unknown_change_argdiffs(tr)

In [16]:
argdiff


[1m([0m
    [1;35mDiff[0m[1m([0m[33mprimal[0m=[1;35mArray[0m[1m([0m[1m[[0m[1;36m0[0m.[1m][0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m, [33mtangent[0m=[1;35m_UnknownChange[0m[1m([0m[1m)[0m[1m)[0m,
    [1;35mDiff[0m[1m([0m
        [33mprimal[0m=[1;35mArray[0m[1m([0m[1m[[0m [1;36m0[0m,  [1;36m1[0m,  [1;36m2[0m,  [1;36m3[0m,  [1;36m4[0m,  [1;36m5[0m,  [1;36m6[0m,  [1;36m7[0m,  [1;36m8[0m,  [1;36m9[0m, [1;36m10[0m, [1;36m11[0m, [1;36m12[0m, [1;36m13[0m, [1;36m14[0m, [1;36m15[0m, [1;36m16[0m,
       [1;36m17[0m, [1;36m18[0m, [1;36m19[0m, [1;36m20[0m, [1;36m21[0m, [1;36m22[0m, [1;36m23[0m, [1;36m24[0m, [1;36m25[0m, [1;36m26[0m, [1;36m27[0m, [1;36m28[0m, [1;36m29[0m, [1;36m30[0m, [1;36m31[0m[1m][0m,      [33mdtype[0m=[35mint32[0m[1m)[0m,
        [33mtangent[0m=[1;35m_UnknownChange[0m[1m([0m[1m)[0m
    [1m)[0m,
    [1;35mDiff[0m[1m([0m[33mprimal[0m=[1;35mArray[0m[

In [15]:
tr["init_velocity"]


[1;35mArray[0m[1m([0m[1m[[0m[1m[[0m [1;36m0.99992085[0m, [1;36m-0.00174896[0m,  [1;36m0.01246393[0m,  [1;36m0.0140807[0m [1m][0m,
       [1m[[0m [1;36m0.00168465[0m,  [1;36m0.9999852[0m ,  [1;36m0.00516792[0m, [1;36m-0.00165914[0m[1m][0m,
       [1m[[0m[1;36m-0.01247278[0m, [1;36m-0.00514652[0m,  [1;36m0.9999089[0m ,  [1;36m0.00492127[0m[1m][0m,
       [1m[[0m [1;36m0[0m.        ,  [1;36m0[0m.        ,  [1;36m0[0m.        ,  [1;36m1[0m.        [1m][0m[1m][0m,      [33mdtype[0m=[35mfloat32[0m[1m)[0m