In [1]:
import sys
sys.path.append("../")
import genjax
import numpy as np
import bayes3d as b
import jax.numpy as jnp
import bayes3d.genjax
import jax
from utils import *
from viz import *
from models import *
from renderer_setup import *
from genjax.inference.importance_sampling import sampling_importance_resampling

# Turn on genjax's pretty-printing for notebook outputs (the formatted trace /
# choice-map reprs further down). The returned object is kept as `console`;
# presumably a rich-style console usable for explicit printing — confirm.
console = genjax.pretty()

In [8]:
import dataclasses
from typing import Tuple

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from plum import dispatch

from genjax._src.core.datatypes.generative import ChoiceMap
from genjax._src.core.datatypes.generative import GenerativeFunction
from genjax._src.core.pytree.pytree import Pytree
from genjax._src.core.typing import IntArray
from genjax._src.core.typing import PRNGKey
from genjax._src.core.typing import typecheck


@dataclasses.dataclass
class BootstrapImportanceSampling(Pytree):
    """Bootstrap importance sampling for generative functions.

    Every particle is produced by the model's own `importance` method under the
    given observations (there is no separate proposal). The sampler returns the
    batched traces together with raw and normalized log weights and a log
    marginal-likelihood estimate.
    """

    # Number of particles; passed to `jax.random.split`, so it must be concrete.
    num_particles: IntArray
    # Generative function providing `importance(key, observations, model_args)`.
    model: GenerativeFunction

    def flatten(self):
        """Pytree flattening: no dynamic leaves, both fields are static aux data.

        Only the two declared fields are returned. The class has no `proposal`
        field, so referencing one would raise AttributeError whenever the
        sampler is flattened (e.g. passed through `jax.jit`). They are kept in
        declaration order (num_particles, model), matching the dataclass
        constructor — assumed to be how the Pytree base rebuilds instances.
        """
        return (), (self.num_particles, self.model)

    @typecheck
    @classmethod
    def new(
        cls,
        num_particles: IntArray,
        model: GenerativeFunction,
    ):
        """Construct a sampler with `num_particles` particles over `model`."""
        return cls(num_particles, model)

    def apply(
        self,
        key: PRNGKey,
        observations: ChoiceMap,
        model_args: Tuple,
    ):
        """Draw `num_particles` weighted traces from `model` under `observations`.

        Args:
            key: PRNG key; split into one sub-key per particle.
            observations: constraints passed to `model.importance`.
            model_args: argument tuple for the model, shared by all particles.

        Returns:
            `(trs, log_normalized_weights, log_ml_estimate, lws)` where
            `trs` are the traces batched along a leading particle axis,
            `lws` the raw per-particle log weights,
            `log_normalized_weights = lws - logsumexp(lws)`, and
            `log_ml_estimate = logsumexp(lws) - log(num_particles)`.
        """
        sub_keys = jax.random.split(key, self.num_particles)
        # Vectorize over the keys only; observations and model args are shared.
        (lws, trs) = jax.vmap(self.model.importance, in_axes=(0, None, None))(
            sub_keys,
            observations,
            model_args,
        )
        log_total_weight = jax.scipy.special.logsumexp(lws)
        log_normalized_weights = lws - log_total_weight
        # Average of the unnormalized weights, in log space.
        log_ml_estimate = log_total_weight - jnp.log(self.num_particles)
        return (trs, log_normalized_weights, log_ml_estimate, lws)

    @typecheck
    def __call__(self, key: PRNGKey, choice_map: ChoiceMap, *args):
        """Alias for `apply`.

        `apply` takes exactly one model-args tuple, so call this as
        `sampler(key, choice_map, model_args)`.
        """
        return self.apply(key, choice_map, *args)
    

@dispatch
def importance_sampling(
    model: GenerativeFunction,
    n_particles: IntArray,
):
    """Create a bootstrap importance sampler for `model`.

    Registered through plum's `dispatch` on the annotated argument types.

    Args:
        model: generative function to sample from.
        n_particles: number of particles to draw.

    Returns:
        A `BootstrapImportanceSampling` built via its `new` constructor.
    """
    sampler = BootstrapImportanceSampling.new(n_particles, model)
    return sampler

In [19]:
gt_path = "../ground_truths/genjax_generated/scene_2.pkl"
# NOTE(review): the path is a pickle; `load_metadata` presumably unpickles it,
# which can execute code — only load trusted files.
metadata = load_metadata(gt_path)
gt_images = metadata["rendered"]

# Resolve the versioned model / renderer-setup functions by name lookup instead
# of `eval`, so a malformed version string in the metadata cannot run code.
# A missing version now raises KeyError naming the looked-up symbol.
model = globals()["model_v{}".format(metadata["model_version"])]
RENDERER_ARGS = metadata["RENDERER_ARGS"]
setup_renderer_and_meshes = globals()[
    "setup_renderer_and_meshes_v{}".format(metadata["renderer_setup_version"])
]
setup_renderer_and_meshes(**RENDERER_ARGS)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)
Centering mesh with translation [ 8.9965761e-07  2.3238501e-02 -3.4500263e-06]
Centering mesh with translation [0.        0.0063132 0.       ]
Centering mesh with translation [ 5.9977174e-07  2.3238501e-02 -3.4500263e-06]
Centering mesh with translation [0.         0.05167415 0.        ]


In [11]:
# Assumes the renderer has already been initialized by an earlier cell.
def inference_approach_B(model, gt, metadata, num_particles):
    """
    SIR: Get a bunch of importance samples and sample using their weights
    Over the FULL T timesteps

    Builds constraints from metadata["CHOICE_MAP_ARGS"] with the observed `gt`
    frames placed at "depths"/"depths". It then runs genjax's
    sampling_importance_resampling with `num_particles` particles and returns
    the resampled trace.
    """
    base_constraints = genjax.choice_map(metadata["CHOICE_MAP_ARGS"])
    observed_depths = genjax.vector_choice_map(
        genjax.choice_map({"depths": gt})
    )
    # Merge so the observed depth constraint values take over.
    constraints = base_constraints.unsafe_merge(
        genjax.choice_map({"depths": observed_depths})
    )

    prng_key = jax.random.PRNGKey(metadata["key_number"])
    sir = sampling_importance_resampling(model, num_particles)
    model_args = tuple(metadata["MODEL_ARGS"].values())
    resampled_trace, _log_weights, _log_ml = sir.apply(
        prng_key, constraints, model_args
    )
    return resampled_trace

# Assumes the renderer has already been initialized by an earlier cell.
def inference_approach_A(model, gt, metadata, num_particles):
    """
    IS + MLE: Get a bunch of importance samples and use MLE
    Over the FULL T timesteps

    "MLE" here means: keep the single particle with the largest importance
    weight.

    Args:
        model: generative function to sample from.
        gt: observed frames, constrained at the "depths"/"depths" address.
        metadata: dict providing "CHOICE_MAP_ARGS", "key_number" and
            "MODEL_ARGS" (whose values, in dict order, form the model args).
        num_particles: number of importance samples.

    Returns:
        The highest-weight trace (one particle sliced out of the batch).
    """
    chm = genjax.choice_map(metadata["CHOICE_MAP_ARGS"])
    # Force the observed depth values to take over any existing constraints.
    chm = chm.unsafe_merge(genjax.choice_map(
        {"depths": genjax.vector_choice_map(genjax.choice_map({
                "depths": gt
        }))}))

    key = jax.random.PRNGKey(metadata["key_number"])
    imp = importance_sampling(model, num_particles)
    # Only the batched traces and normalized log weights are needed here.
    trs, log_normalized_weights, _, _ = imp.apply(
        key, chm, tuple(metadata["MODEL_ARGS"].values()))

    # Compute the winning particle index once, not once per pytree leaf.
    best = jnp.argmax(log_normalized_weights)
    return jax.tree_util.tree_map(lambda leaf: leaf[best], trs)

In [49]:
# Approach A with 30 particles: keep the highest-weight importance sample.
tr = inference_approach_A(model, gt_images, metadata, num_particles=30)

In [50]:
# Display the `score` of the trace selected above (shown as this cell's output).
tr.score

[1;35mArray[0m[1m([0m[1;36m1757.2998[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m

In [38]:
# Stack of 100 identical 4x4 identity matrices: shape (100, 4, 4).
x = jnp.tile(jnp.eye(4)[None, :, :], (100, 1, 1))
# NOTE(review): 2 indices are paired with 100 leading rows of values; the next
# cell uses matching sizes (2 indices, 2 rows). Confirm which is intended.
ichm = genjax.index_choice_map(jnp.array([0, 1]), {"velocity" : x})

In [47]:
# Index choice map pairing indices [0, 1] with the two rows of a (2, 3) array
# stored under "velocity"; the last expression displays its tree structure.
velocity_rows = jnp.ones((2, 3))
ichm = genjax.index_choice_map(jnp.array([0, 1]), {"velocity": velocity_rows})
ichm




└── [1m(Index, i32[2])[0m
    └── [1m:velocity[0m
        └──  f32[2,3]

In [52]:
# Play the ground-truth rendered frames at 5 fps.
# To view the inferred trace instead:
#   video_from_trace(tr, framerate=5, use_retval=True)
video_from_rendered(gt_images, framerate=5)

In [20]:
def make_schedule_translation_3d(grid_widths, grid_nums):
    """Build a coarse-to-fine list of 3D translation grids.

    Stage i enumerates translations over the cube
    [-grid_widths[i], +grid_widths[i]]^3 with grid_nums[i] = (num_x, num_y, num_z)
    points per axis, via bayes3d's `make_translation_grid_enumeration`.
    Pairs are taken with `zip`, so surplus entries in the longer list are ignored.

    Returns:
        A list with one grid per stage, in input order (each presumably an
        N x 4 x 4 stack of poses — confirm against bayes3d).
    """
    return [
        b.utils.make_translation_grid_enumeration(
            -half_width, -half_width, -half_width,
            half_width, half_width, half_width,
            *counts,  # num_x, num_y, num_z
        )
        for half_width, counts in zip(grid_widths, grid_nums)
    ]

# Module-level coarse-to-fine schedule: three 3x3x3 translation grids with
# half-widths 0.2 -> 0.1 -> 0.05. (inference_approach_C builds its own schedule.)
grid_widths = [0.2, 0.1, 0.05]
grid_nums = [(3, 3, 3)] * 3
gridding_schedule = make_schedule_translation_3d(grid_widths, grid_nums)



In [21]:
def inference_approach_C(
    model,
    gt,
    metadata,
    grid_widths=(0.1, 0.05, 0.025),
    grid_nums=((3, 3, 3),) * 3,
):
    """
    Greedy Grid Enumeration of T=0 to T=1

    Coarse-to-fine greedy search over the T=1 entry of the
    ("dynamics_1", "velocity") address, with the initial pose constrained to
    the known metadata["init_pose"].

    Each stage enumerates a translation grid centered on the best velocity
    found so far (identity for the first stage) and commits the best-scoring
    candidate. Without re-centering, every finer stage would re-search around
    identity and discard the previous stage's result.

    Args:
        model: generative function whose `importance(key, chm, args)` returns
            (weight, trace).
        gt: observed frames, constrained at "depths"/"depths".
        metadata: dict providing "CHOICE_MAP_ARGS", "init_pose", "key_number"
            and "MODEL_ARGS".
        grid_widths: per-stage half-width of the translation cube (coarse to fine).
        grid_nums: per-stage (num_x, num_y, num_z) grid point counts.

    Returns:
        The trace after the final refinement stage.
    """
    chm = genjax.choice_map(metadata["CHOICE_MAP_ARGS"])
    # Force the new constraint values (observed depths, known init pose) to take over.
    chm = chm.unsafe_merge(genjax.choice_map(
        {"depths": genjax.vector_choice_map(genjax.choice_map({
                "depths": gt
        })),
        "init_pose": metadata["init_pose"]  # assume init pose is known
        })
    )

    # Make the 3D translation grids: list of (presumably) N x 4 x 4 poses per stage.
    gridding_schedule = make_schedule_translation_3d(grid_widths, grid_nums)

    # Make the initial sample.
    key = jax.random.PRNGKey(metadata["key_number"])
    _, trace = model.importance(key, chm, tuple(metadata["MODEL_ARGS"].values()))

    # Greedily update the T=1 slice of the velocity address.
    enumerator = b.make_enumerator([("dynamics_1", "velocity")])
    t_index = 1
    # Center of the current search grid; the first stage searches around identity.
    center = jnp.eye(4)
    for grid in gridding_schedule:
        velocity_vector = trace["dynamics_1", "velocity"]
        # (4, 4) @ (N, 4, 4) broadcasts to N candidate poses around `center`.
        centered_grid = center @ grid
        trace = c2f_pose_update_jit(
            trace, key, t_index, velocity_vector, centered_grid, enumerator
        )
        # Refine the next (finer) stage around the pose just committed.
        center = trace["dynamics_1", "velocity"][t_index]
    return trace

In [22]:
# NOTE(review): inference_approach_C calls c2f_pose_update_jit, which is defined
# in a later cell; on Restart & Run All this line raises NameError. Move that
# cell above this one.
tr = inference_approach_C(model=model, gt=gt_images, metadata=metadata)

In [16]:
def unfold_with_proposals(T, proposal, unfold_vector):
    """Return a copy of `unfold_vector` with entry `T` replaced by `proposal`.

    `unfold_vector` is indexed by time step. Index 0 is the state before the
    first proposal, so proposals start at T=1. `proposal` is a single entry
    (a 4x4 matrix in this notebook); batching over N proposals is done by
    `unfold_with_proposals_vmap` below.
    """
    replaced = unfold_vector.at[T].set(proposal)
    return replaced


# Batched over the proposals only (N x 4 x 4 in); T and unfold_vector are shared.
# Output: N copies of unfold_vector, the i-th with slot T set to proposal i.
unfold_with_proposals_vmap = jax.jit(
    jax.vmap(unfold_with_proposals, in_axes=(None, 0, None))
)
    

def c2f_pose_update(trace_, key, T, unfold_array, pose_grid, enumerator):
    """Greedily set slot `T` of an unfolded pose array to its best grid candidate.

    For every pose in `pose_grid`, builds a full candidate array equal to
    `unfold_array` with slot `T` replaced. All candidates are scored with
    `enumerator[3]`, and the highest-scoring one is committed to the trace with
    `enumerator[0]`. (Presumably index 3 is bayes3d's batched scoring function
    and index 0 its single-choice update — confirm against b.make_enumerator.)

    Returns:
        Whatever `enumerator[0]` returns (the updated trace).
    """
    candidates = unfold_with_proposals_vmap(T, pose_grid, unfold_array)
    candidate_scores = enumerator[3](trace_, key, candidates)
    winner = candidates[jnp.argmax(candidate_scores)]
    return enumerator[0](trace_, key, winner)


# `enumerator` holds functions, not arrays, so it must be a static (hashable) argument.
c2f_pose_update_jit = jax.jit(c2f_pose_update, static_argnames=("enumerator",))

In [16]:
unfolded_arr = jnp.ones((100, 4, 4))
xs = 0.3 * jnp.ones((23, 4, 4))


def _compose_into_slot(x, idx):
    """Copy of `unfolded_arr` with slot `idx` set to `unfolded_arr[idx - 1] @ x`."""
    previous = unfolded_arr[idx - 1]
    return unfolded_arr.at[idx].set(previous @ x)


# Map over the 23 matrices in `xs`; the slot index (3) is shared.
r = jax.vmap(jax.jit(_compose_into_slot), in_axes=(0, None))(xs, 3)


In [15]:
import jax
import jax.numpy as jnp
from collections import namedtuple
from functools import partial

# Define jitted functions
@jax.jit
def add(a, b):
    """Elementwise sum of `a` and `b`."""
    return a + b

@jax.jit
def multiply(a, b):
    """Elementwise product of `a` and `b`."""
    return a * b

# Named tuple bundling the jitted functions.
JittedFunctions = namedtuple('JittedFunctions', ['func1', 'func2'])
jitted_functions = JittedFunctions(func1=add, func2=multiply)


# A Python callable is not a valid JAX array argument, so it has to be passed as
# a *static* argument. It is hashed and baked into the compiled trace, and each
# distinct function gets its own compilation.
@partial(jax.jit, static_argnums=0)
def process(add, x, y):
    """Apply the binary function `add` to `x` and `y` inside a jitted trace.

    `add` may be any hashable callable, e.g. `jitted_functions.func2`.
    """
    return add(x, y)


# Pass a jitted function into the jitted 'process' function
result = process(add, 3, 4)
print(result)  # Output: 7