In [4]:
import sys
from functools import partial
sys.path.append("../")
import genjax
import numpy as np
from jax_tqdm import scan_tqdm
import bayes3d as b
import jax.numpy as jnp
import bayes3d.genjax
import jax
import jax.tree_util as jtu
from jax.debug import print as jprint
from utils import *
from viz import *
from models import *
from renderer_setup import *
from inference import *
from genjax.inference.importance_sampling import sampling_importance_resampling
from genjax._src.core.transforms.incremental import NoChange
from genjax._src.core.transforms.incremental import UnknownChange
from genjax._src.core.transforms.incremental import Diff
from genjax._src.core.pytree.utilities import *


# Presumably enables genjax's rich pretty-printing of traces / choice maps in notebook output.
# NOTE(review): `console` is not referenced elsewhere in the visible notebook — confirm it is needed.
console = genjax.pretty()

In [5]:
# Load a ground-truth scene and resolve the model / renderer-setup versions it was generated with.
gt_path = "../ground_truths/genjax_generated/scene_8.pkl"
# NOTE(review): load_metadata presumably unpickles gt_path — only load trusted files.
metadata = load_metadata(gt_path)
gt_images = metadata["rendered"]
# Resolve versioned names with an explicit namespace lookup instead of eval(), so a
# malformed version value stored in the metadata file cannot execute arbitrary code.
model = globals()["model_v{}".format(metadata["model_version"])]
RENDERER_ARGS = metadata["RENDERER_ARGS"]
setup_renderer_and_meshes = globals()[
    "setup_renderer_and_meshes_v{}".format(metadata["renderer_setup_version"])
]
setup_renderer_and_meshes(**RENDERER_ARGS)

Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)


[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


In [6]:
def pose_update_v3(key, trace_, pose_grid, enumerator):
    """Update ``trace_`` to the highest-scoring candidate in ``pose_grid``.

    Every candidate in ``pose_grid`` is scored with
    ``enumerator.enumerate_choices_get_scores``; the trace is then updated to the
    argmax candidate through ``enumerator.update_choices``.
    """
    candidate_scores = enumerator.enumerate_choices_get_scores(trace_, key, pose_grid)
    best_index = candidate_scores.argmax()
    best_candidate = pose_grid[best_index]
    return enumerator.update_choices(trace_, key, best_candidate)


# ``enumerator`` is a Python object rather than an array, so it must be static under jit.
pose_update_v3_jit = jax.jit(pose_update_v3, static_argnames=("enumerator",))

def c2f_pose_update_v3(key, trace_, gridding_schedule_stacked, enumerator):
    """Coarse-to-fine pose search.

    Applies ``pose_update_v3_jit`` once per grid along the leading axis of
    ``gridding_schedule_stacked`` (in the order given), feeding each step's
    updated trace into the next. The same ``key`` is reused for every level.
    """
    n_levels = gridding_schedule_stacked.shape[0]
    for level in range(n_levels):
        level_grid = gridding_schedule_stacked[level]
        trace_ = pose_update_v3_jit(key, trace_, level_grid, enumerator)
    return trace_


c2f_pose_update_v3_vmap_jit = jax.jit(
    jax.vmap(c2f_pose_update_v3, in_axes=(0, 0, None, None)),
    static_argnames=("enumerator",),
)

c2f_pose_update_v3_jit = jax.jit(c2f_pose_update_v3, static_argnames=("enumerator",))

def make_new_keys(key, N_keys):
    """Split ``key`` into a carried-forward key and ``N_keys`` fresh subkeys.

    Returns ``(key, new_keys)`` where ``new_keys`` has a leading axis of size ``N_keys``.
    """
    carried_key, spawn_key = jax.random.split(key)
    return carried_key, jax.random.split(spawn_key, N_keys)


def initial_choice_map(metadata):
    """Choice map applying the initial constraints at every index ``0..metadata["T"]``.

    ``metadata["CHOICE_MAP_ARGS"]`` is wrapped with ``genjax.choice_map`` and
    attached to the indices ``0, 1, ..., metadata["T"]``.
    """
    step_indices = jnp.arange(metadata["T"] + 1)
    constraints = genjax.choice_map(metadata["CHOICE_MAP_ARGS"])
    return genjax.index_choice_map(step_indices, constraints)

def update_choice_map(gt, t):
    """Choice map constraining only ``'depth'`` at index ``t`` to ``gt[t]``.

    The observation gets a leading length-1 axis so it lines up with the single index ``[t]``.
    """
    depth_obs = jnp.expand_dims(gt[t], axis=0)
    return genjax.index_choice_map([t], genjax.choice_map({'depth': depth_obs}))


def argdiffs_modelv5(trace, t):
    """Build update argdiffs for a model_v5 trace advanced to step ``t``.

    The first argument is replaced by ``t`` and tagged ``UnknownChange``; every
    remaining argument of ``trace`` is passed through with its current value and
    tagged ``NoChange``.
    """
    def unchanged(v):
        return Diff(v, NoChange)

    trace_args = trace.get_args()
    return (
        Diff(t, UnknownChange),
        jtu.tree_map(unchanged, trace_args[1]),
        *jtu.tree_map(unchanged, trace_args[2:]),
    )

def proposal_choice_map(addresses, args, chm_args):
    """Choice-map builder used by the pose enumerator.

    Constrains the first address in ``addresses`` to ``args[0]`` at the single
    index ``chm_args[0]``; the value gets a leading length-1 axis to match that index.
    """
    address = addresses[0]  # only the first address is used (custom for this proposal)
    index = jnp.array([chm_args[0]])
    value = jnp.expand_dims(args[0], axis=0)
    return genjax.index_choice_map(index, genjax.choice_map({address: value}))

In [25]:
def inference_approach_G(model, gt, metadata, num_particles=1):
    """
    Sequential Importance Sampling on the unfolded HMM model
    with 'dumb' 3D pose enumeration proposal.

    For each time step t = 1..T, every particle is
      1. updated to condition on the depth observation ``gt[t]``, then
      2. moved to the best candidate found by coarse-to-fine enumeration over
         ``velocity`` at step t.

    Args:
        model: unfolded generative model (e.g. ``genjax.UnfoldCombinator.new(model, T+1)``).
        gt: sequence of ground-truth depth images indexed by time step.
        metadata: scene metadata dict. Uses "T", "MODEL_ARGS", "INIT_POSE",
            "INIT_VELOCITY" and "key_number", plus "CHOICE_MAP_ARGS" via
            ``initial_choice_map``.
        num_particles: number of particles.

    Returns:
        ``(particles, rendered_videos)``: the final stacked particle traces and
        ``particles.get_retval()[0]``.

    NOTE(review): importance weights are not tracked (the update/proposal weight
    increments are discarded), so no resampling happens. TODO: add weighting.
    """
    T = metadata["T"]
    model_args = tuple(metadata["MODEL_ARGS"].values())
    init_state = (gt[0], metadata["INIT_POSE"], metadata["INIT_VELOCITY"])
    key = jax.random.PRNGKey(metadata["key_number"])
    key, init_keys = make_new_keys(key, num_particles)

    # Batched importance sampling for initialization; per-particle update (jit only).
    init_fn = jax.jit(jax.vmap(model.importance, in_axes=(0, None, None)))
    update_fn = jax.jit(model.update)
    proposal_fn = c2f_pose_update_v3_jit

    # Coarse-to-fine 3D translation and rotation grid.
    grid_widths = [0.2, 0.1, 0.05, 0.025, 0.0125]
    grid_nums = [(3, 3, 3)] * len(grid_widths)
    gridding_schedule = make_schedule_3d(
        grid_widths, grid_nums, [-jnp.pi / 12, jnp.pi / 12], 10, 10, jnp.pi)
    gridding_schedule_stacked = jnp.stack(gridding_schedule)

    # Initialize particles (log weights currently unused; see NOTE above).
    _init_log_weights, particles = init_fn(
        init_keys, initial_choice_map(metadata), (0, init_state, *model_args))

    # A plain Python loop instead of jax.lax.scan has two reasons:
    # - the enumerator is a Python object rebuilt every step and passed to jit as a
    #   static argument;
    # - particles are processed one at a time.
    # Neither of these can live inside a traced scan body.
    for t in range(1, T + 1):
        key, update_keys = make_new_keys(key, num_particles)
        key, proposal_keys = make_new_keys(key, num_particles)

        # Argdiffs are taken from particle 0. All particles were initialized with
        # the same (un-vmapped) arguments, so they presumably share them.
        argdiffs = argdiffs_modelv5(jtu.tree_map(lambda v: v[0], particles), t)

        # Enumerator for this time step; chm_args=[t] puts the proposal at index t.
        # The lambda is consumed within this iteration, so late binding of
        # `argdiffs` is harmless.
        enumerator = b.make_enumerator(["velocity"],
                                       chm_builder=proposal_choice_map,
                                       argdiff_f=lambda x: argdiffs,
                                       chm_args=[t])

        obs_chm = update_choice_map(gt, t)
        particle_list = tree_unstack(particles)

        # 1. Condition each particle on the depth observation at step t.
        #    Element [2] of update's result is the updated trace.
        updated_particles = [
            update_fn(update_keys[i], particle, obs_chm, argdiffs)[2]
            for i, particle in enumerate(particle_list)
        ]
        # 2. Coarse-to-fine enumerative pose proposal on each updated particle.
        new_particles = [
            proposal_fn(proposal_keys[i], particle, gridding_schedule_stacked, enumerator)
            for i, particle in enumerate(updated_particles)
        ]
        particles = tree_stack(new_particles)

    rendered_videos = particles.get_retval()[0]
    return particles, rendered_videos


In [26]:
# Run SIS on the unfolded model (T+1 steps) with a single particle.
model_unfold = genjax.UnfoldCombinator.new(model, metadata['T'] + 1)
particles, rendered_videos = inference_approach_G(
    model_unfold, gt_images, metadata, num_particles=1
)

In [9]:
# Side-by-side video: rendered frames of the first entry of `rendered_videos`
# (presumably particle 0) vs. the ground-truth images.
video_comparison_from_images(rendered_videos[0], gt_images)

In [10]:
# Debug: re-run the observation update for step DEBUG_T on the first particle
# under 10 different random keys (presumably to inspect update variability).
DEBUG_T = 1
argdiffs = argdiffs_modelv5(jtu.tree_map(lambda v: v[0], particles), DEBUG_T)
p1 = tree_unstack(particles)[0]

key = jax.random.PRNGKey(metadata["key_number"])
key, nkeys = make_new_keys(key, 10)

# `particles` are traces of the unfolded model, so the update must go through
# `model_unfold`, not the single-step `model`.
_, _, tr, _ = jax.vmap(jax.jit(model_unfold.update), in_axes=(0, None, None, None))(
    nkeys, p1, update_choice_map(gt_images, DEBUG_T), argdiffs
)

In [None]:
# Split a batched trace `trs` (leading axis = particle) into a list of per-particle traces.
# FIXME: `trs` is not defined above this cell on a fresh kernel. It was presumably produced
# by the commented-out `inference_approach_G(..., 7)` call, and the 7 here is hard-coded
# to that particle count.
haha = [jax.tree_util.tree_map(lambda v: v[i], trs) for i in range(7)]

In [None]:
# Round-trip check: unstack the batched traces into per-particle traces, then restack them.
# tree_unstack / tree_stack are already in scope from the genjax pytree-utilities star
# import in the first cell, so no re-import is needed here.
# FIXME: `trs` must be a batched trace here; it is not defined above this cell on a fresh kernel.
unst = tree_unstack(trs)

tree_stack(unst)

In [None]:
# Compare the total of the scores stored in the ground-truth metadata with the
# summed scores of the inferred traces.
# NOTE(review): here `trs` must be an iterable of traces (as returned by
# inference_approach_F further below). What `trs` holds depends on which cell last assigned it.
print(np.sum(metadata["scores"]))
print(np.sum([tr.get_score() for tr in trs]))

In [None]:
# Side-by-side video of inference_approach_F's rendered frames (`rend`) vs. the ground truth.
# FIXME: `rend` is only assigned by the inference_approach_F run cell further below,
# so this cell fails on a fresh Restart & Run All.
video_comparison_from_images(rend, gt_images,framerate = 5, scale = 8)

In [None]:
def inference_approach_F(model, gt, metadata):
    """
    2-step model with NO unfold, HMM-style greedy filtering.

    For each time step t = 1..T, ``model.importance`` is run with the depth
    constraint ``gt[t]`` and the previous step's pose/velocity as model args.
    The trace is then refined by coarse-to-fine pose enumeration over
    ``velocity``, and the refined trace's ``retval[1]`` supplies the
    pose/velocity for the next step.

    Args:
        model: single-step (non-unfolded) generative model.
        gt: sequence of ground-truth depth images indexed by time step.
        metadata: scene metadata dict. Uses "CHOICE_MAP_ARGS", "INIT_POSE",
            "INIT_VELOCITY", "T", "MODEL_ARGS" and "key_number". It is not modified.

    Returns:
        ``(traces, rendered)``: the per-step traces for t = 1..T, and a stack of
        ``gt[0]`` followed by each trace's ``retval[0]``.
    """
    # Coarse-to-fine 3D translation and rotation grid.
    grid_widths = [0.2, 0.1, 0.05, 0.025, 0.0125]
    grid_nums = [(3, 3, 3)] * len(grid_widths)
    gridding_schedule = make_schedule_3d(
        grid_widths, grid_nums, [-jnp.pi / 3, jnp.pi / 3], 20, 20, jnp.pi)

    key = jax.random.PRNGKey(metadata["key_number"] + 71)
    base_chm = genjax.choice_map(metadata["CHOICE_MAP_ARGS"])
    enumerator = b.make_enumerator(["velocity"])
    pose = metadata["INIT_POSE"]
    velocity = metadata["INIT_VELOCITY"]
    T = metadata["T"]
    traces = []
    # Copy so that writing the per-step pose/velocity below does not mutate the
    # caller's metadata["MODEL_ARGS"]. dict() preserves insertion order, so
    # tuple(model_args.values()) keeps the same argument order.
    model_args = dict(metadata["MODEL_ARGS"])

    for t in range(1, T + 1):
        print("t = ", t)
        # Fresh randomness per time step: JAX PRNG keys must not be reused.
        key, step_key = jax.random.split(key)
        # Force the new depth constraint to take over the base constraints.
        chm = base_chm.unsafe_merge(genjax.choice_map({
            "depth": gt[t]
        }))

        model_args["pose"] = pose
        model_args["velocity"] = velocity
        # model.importance is used instead of update because of issues with
        # update and choice maps for unfold/map combinators.
        _, trace = model.importance(step_key, chm, tuple(model_args.values()))

        # Refine the trace with each grid of the coarse-to-fine schedule in turn.
        for grid in gridding_schedule:
            trace = c2f_pose_update_v2_jit(trace, step_key, grid, enumerator)
        pose, velocity = trace.get_retval()[1]
        traces.append(trace)

    # The first gt image can be assumed known because the initial pose is given.
    rendered = jnp.stack([gt[0]] + [tr.get_retval()[0] for tr in traces])
    return traces, rendered

In [None]:
# Run the no-unfold HMM-style inference with overridden `vel_params`. A modified
# copy of the metadata is passed instead of mutating the shared `metadata` dict,
# which other cells also read via metadata["MODEL_ARGS"].
metadata_F = {
    **metadata,
    "MODEL_ARGS": {**metadata["MODEL_ARGS"], "vel_params": jnp.array([0.0005, 100.0])},
}
trs, rend = inference_approach_F(model, gt_images, metadata_F)

In [None]:
# Score of every trace in `trs`.
# NOTE(review): the name suggests a run without rotation search — confirm.
no_rot = list(map(lambda trace: trace.score, trs))

In [None]:
# Scores stored in the ground-truth scene metadata.
metadata['scores']

In [None]:
# Shape of a single ground-truth depth image.
gt_images[0].shape

In [None]:
# Per-pixel 3DP3 likelihood of each ground-truth frame against the corresponding
# rendered frame from inference_approach_F. zip() pairs the frames instead of
# assuming exactly 51 of them.
# NOTE(review): the meaning of the numeric arguments (0.0001, 0.0001, 1000, 3) is
# defined by b.threedp3_likelihood_per_pixel_old — confirm before tuning.
data = [
    b.threedp3_likelihood_per_pixel_old(gt_image, rendered_image, 0.0001, 0.0001, 1000, 3)
    for gt_image, rendered_image in zip(gt_images, rend)
]

In [None]:
import plotly.graph_objs as go
import plotly.offline as pyo

# Animate the per-frame likelihood maps in `data` as a heatmap with Play/Pause buttons.
N = len(data)  # number of frames to animate (was hard-coded to 10, silently dropping the rest)

frames = [go.Frame(data=go.Heatmap(z=data[i], colorscale='Viridis')) for i in range(N)]

fig = go.Figure(
    data=go.Heatmap(z=data[0], colorscale='Viridis'),
    layout=go.Layout(
        title='Bayes3D Likelihood',
        width=600,   # equal width/height keeps the plot square
        height=600,
        updatemenus=[{
            'buttons': [{
                'args': [None, {'frame': {'duration': 500, 'redraw': True}, 'fromcurrent': True}],
                'label': 'Play',
                'method': 'animate'
            }, {
                'args': [[None], {'frame': {'duration': 0, 'redraw': True}, 'mode': 'immediate'}],
                'label': 'Pause',
                'method': 'animate'
            }],
            'direction': 'left',
            'pad': {'r': 10, 't': 87},
            'showactive': False,
            'type': 'buttons',
            'x': 0.1,
            'xanchor': 'right',
            'y': 0,
            'yanchor': 'top'
        }]
    ),
    frames=frames
)

# Display the plot
pyo.iplot(fig)


In [None]:
# Scratch check: shape of a 4x4 identity with a leading axis added, i.e. (1, 4, 4).
jnp.expand_dims(jnp.eye(4), axis = 0).shape