In [1]:
import sys
from functools import partial
sys.path.append("../")
import genjax
import numpy as np
import bayes3d as b
import jax.numpy as jnp
import bayes3d.genjax
import jax
from utils import *
from viz import *
from models import *
from renderer_setup import *
from genjax.inference.importance_sampling import sampling_importance_resampling
from genjax._src.core.transforms.incremental import NoChange
from genjax._src.core.transforms.incremental import UnknownChange
from genjax._src.core.transforms.incremental import Diff

console = genjax.pretty()

In [2]:
gt_path = "../ground_truths/genjax_generated/scene_3.pkl"
metadata = load_metadata(gt_path)
gt_images = metadata["rendered"]
model = eval("model_v{}".format(metadata["model_version"]))
RENDERER_ARGS = metadata["RENDERER_ARGS"]
setup_renderer_and_meshes = eval("setup_renderer_and_meshes_v{}".format(metadata["renderer_setup_version"]))
setup_renderer_and_meshes(**RENDERER_ARGS)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)
Centering mesh with translation [ 8.9965761e-07  2.3238501e-02 -3.4500263e-06]
Centering mesh with translation [0.        0.0063132 0.       ]
Centering mesh with translation [ 5.9977174e-07  2.3238501e-02 -3.4500263e-06]
Centering mesh with translation [0.         0.05167415 0.        ]


In [5]:
def velocity_chm_builder(addresses, args):
    print(addresses, args)
    chm = genjax.choice_map({
                addresses[0][0]:genjax.index_choice_map(
                    jnp.arange(args[0].shape[0]),genjax.choice_map({
                        addresses[0][1]: args[0]
            }))
        })
    return chm
    

def inference_approach_D(model, gt, metadata):
    """
    Greedy Grid Enumeration of T=0 to T=5
    # I think this might just be the 2-step model
    """
    T = metadata['MODEL_ARGS']['T_vec'].shape[0]
    # make 3d translation grid: list of N x 4 x 4 poses
    # grid_widths = [0.1,0.05,0.025]
    # grid_nums = [(3,3,3),(3,3,3),(3,3,3)]
    # gridding_schedule = make_schedule_translation_3d(grid_widths, grid_nums)

    # OR use 3d translation and rotation grid
    grid_widths = [0.2,0.1,0.05]
    grid_nums = [(3,3,3),(3,3,3),(3,3,3)]
    # sched = make_schedule_3d([0.1],[(1,1,1)], [-jnp.pi/6, jnp.pi/6],50,10,jnp.pi/40)
    gridding_schedule = make_schedule_3d(grid_widths,grid_nums, [-jnp.pi/12, jnp.pi/12],10,10,jnp.pi/40)

    base_chm = genjax.choice_map(metadata["CHOICE_MAP_ARGS"])
    # first make the chm builder:
    enumerator = b.make_enumerator([("dynamics_1", "velocity")], chm_builder = velocity_chm_builder)
    key = jax.random.PRNGKey(metadata["key_number"])
    # make initial sample:
    _, trace = model.importance(key, base_chm, tuple(metadata["MODEL_ARGS"].values()))

    for t in range(1,T+1):
        print("t = ", t)
        velocity_vector = trace["dynamics_1", "velocity"]
        # force new constaints values to take over
        chm = base_chm.unsafe_merge(genjax.choice_map(
            {"depths" : genjax.index_choice_map(jnp.arange(t+1),genjax.choice_map({
                    "depths": gt[:t+1]
            })),
            "init_pose" : metadata["init_pose"], # assume init pose is known
            "dynamics_1":genjax.index_choice_map(
                jnp.arange(velocity_vector.shape[0]),genjax.choice_map({
                    "velocity": velocity_vector
            }))
            }) 
        )

        # RESORTING to model.importance as I am having issues with update and choicemaps with unfolds & maps
        _, trace = model.importance(key, chm, tuple(metadata["MODEL_ARGS"].values()))
        # put index number as one hot encoded
        t_arr = jnp.zeros(T+1).at[t].set(1)
        # then update trace over all the proposals

        for i, grid in enumerate(gridding_schedule):
            print("Grid #",i+1)
            trace = c2f_pose_update_jit(trace, key, t_arr, velocity_vector, grid, enumerator)
    return trace

In [6]:
tr = inference_approach_D(model, gt_images, metadata)

t =  1
Grid # 1
[('dynamics_1', 'velocity')] (Traced<ShapedArray(float32[100,4,4])>with<BatchTrace(level=3/0)> with
  val = Traced<ShapedArray(float32[2700,100,4,4])>with<DynamicJaxprTrace(level=2/0)>
  batch_dim = 0,)
[('dynamics_1', 'velocity')] (Traced<ShapedArray(float32[100,4,4])>with<DynamicJaxprTrace(level=2/0)>,)
Grid # 2
[('dynamics_1', 'velocity')] (Traced<ShapedArray(float32[100,4,4])>with<BatchTrace(level=3/0)> with
  val = Traced<ShapedArray(float32[2700,100,4,4])>with<DynamicJaxprTrace(level=2/0)>
  batch_dim = 0,)
[('dynamics_1', 'velocity')] (Traced<ShapedArray(float32[100,4,4])>with<DynamicJaxprTrace(level=2/0)>,)
Grid # 3
t =  2
Grid # 1
Grid # 2
Grid # 3
t =  3
Grid # 1
Grid # 2
Grid # 3
t =  4
Grid # 1
Grid # 2
Grid # 3
t =  5
Grid # 1
Grid # 2
Grid # 3
t =  6
Grid # 1
Grid # 2
Grid # 3
t =  7
Grid # 1
Grid # 2
Grid # 3
t =  8
Grid # 1
Grid # 2
Grid # 3
t =  9
Grid # 1
Grid # 2
Grid # 3
t =  10
Grid # 1
Grid # 2
Grid # 3
t =  11
Grid # 1
Grid # 2
Grid # 3
t =  12
Gr

In [11]:
print(tr.score)
print(metadata['score'])

118898.53
210822.4


In [20]:
# video_from_trace(tr, framerate=5, use_retval=True)
# video_from_rendered(gt_images, framerate=5,)
video_comparison_from_trace(tr,framerate = 5, scale = 4)