In [1]:
import sys
from functools import partial
sys.path.append("../")
import genjax
import numpy as np
import bayes3d as b
import jax.numpy as jnp
import bayes3d.genjax
import jax
from utils import *
from viz import *
from models import *
from renderer_setup import *
from genjax.inference.importance_sampling import sampling_importance_resampling
from genjax._src.core.transforms.incremental import NoChange
from genjax._src.core.transforms.incremental import UnknownChange
from genjax._src.core.transforms.incremental import Diff

console = genjax.pretty()

In [2]:
gt_path = "../ground_truths/genjax_generated/scene_6.pkl"
metadata = load_metadata(gt_path)
gt_images = metadata["rendered"]
model = eval("model_v{}".format(metadata["model_version"]))
RENDERER_ARGS = metadata["RENDERER_ARGS"]
setup_renderer_and_meshes = eval("setup_renderer_and_meshes_v{}".format(metadata["renderer_setup_version"]))
setup_renderer_and_meshes(**RENDERER_ARGS)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)


In [3]:
def c2f_pose_update_v2(trace_, key, pose_grid, enumerator):
    
    scores = enumerator.enumerate_choices_get_scores(trace_, key, pose_grid)
    return enumerator.update_choices(
        trace_, key,
        pose_grid[scores.argmax()]
    )
c2f_pose_update_v2_jit = jax.jit(c2f_pose_update_v2, static_argnames=("enumerator",))


# TODO: UPDATE for 
def c2f_pose_update_v3(trace_, key, pose_grid, enumerator):
    
    scores = enumerator.enumerate_choices_get_scores(trace_, key, pose_grid)
    return enumerator.update_choices(
        trace_, key,
        pose_grid[scores.argmax()]
    )
c2f_pose_update_v3_jit = jax.jit(c2f_pose_update_v3, static_argnames=("enumerator",))

In [None]:
def velocity_chm_builder(addresses, args):
    chm = genjax.choice_map({
                addresses[0][0]:genjax.index_choice_map(
                    jnp.arange(args[0].shape[0]),genjax.choice_map({
                        addresses[0][1]: args[0]
            }))
        })
    return chm


def inference_approach_F(model, gt, metadata):
    """
    2-step model with NO unfold
    HMM-style
    """
    T = metadata['MODEL_ARGS']['T_vec'].shape[0]

    # OR use 3d translation and rotation grid
    grid_widths = [0.2,0.1,0.05]
    grid_nums = [(3,3,3),(3,3,3),(3,3,3)]
    gridding_schedule = make_schedule_3d(grid_widths,grid_nums, [-jnp.pi/12, jnp.pi/12],10,10,jnp.pi/40)

    base_chm = genjax.choice_map(metadata["CHOICE_MAP_ARGS"])
    # first make the chm builder:
    enumerators = [b.make_enumerator([(f"velocity")]) for i in range(T)]

    key = jax.random.PRNGKey(metadata["key_number"])
    # make initial sample:
    _, trace = model.importance(key, base_chm, tuple(metadata["MODEL_ARGS"].values()))

    for t in range(1,T+1):
        print("t = ", t)
        # force new constaints values to take over
        chm = base_chm.unsafe_merge(genjax.choice_map(
            {"depths" : genjax.index_choice_map(jnp.arange(t+1),genjax.choice_map({
                    "depths": gt[:t+1]
            })),
            "init_pose" : metadata["init_pose"], # assume init pose is known
            **{f"velocity_{i+1}" : trace[f"velocity_{i+1}"] for i in range(t-1)}
            }) 
        )

        # RESORTING to model.importance as I am having issues with update and choicemaps with unfolds &/or maps
        _, trace = model.importance(key, chm, tuple(metadata["MODEL_ARGS"].values()))

        # trace = trace.update(key, chm, b.make_unknown_change_argdiffs(trace))

        # then update trace over all the proposals
        for i, grid in enumerate(gridding_schedule):
            print("Grid #",i+1)
            trace = c2f_pose_update_v2_jit(trace, key, grid, enumerators[t-1])
    return trace

In [25]:
tr = inference_approach_F(model, gt_images, metadata)

t =  1
Grid # 1
Grid # 2
Grid # 3
t =  2
Grid # 1
Grid # 2
Grid # 3
t =  3
Grid # 1
Grid # 2
Grid # 3
t =  4
Grid # 1
Grid # 2
Grid # 3
t =  5
Grid # 1
Grid # 2
Grid # 3


In [26]:
print(tr.score)
print(metadata['score'])

36070.402
48653.72


In [28]:
# video_from_trace(tr, framerate=5, use_retval=True)
# video_from_rendered(gt_images, framerate=5,)
video_comparison_from_trace(tr,framerate = 5, scale = 4)

In [28]:
a = [*(i for i in range(4))]
a

[1m[[0m[1;36m0[0m, [1;36m1[0m, [1;36m2[0m, [1;36m3[0m[1m][0m