In [1]:
import sys
from functools import partial
sys.path.append("../")
import genjax
import time
import numpy as np
from jax_tqdm import scan_tqdm
import bayes3d as b
import jax.numpy as jnp
import bayes3d.genjax
import jax
import jax.tree_util as jtu
from jax.debug import print as jprint
from utils import *
from viz import *
from models import *
from renderer_setup import *
from inference import *
from genjax.inference.importance_sampling import sampling_importance_resampling
from genjax._src.core.transforms.incremental import NoChange
from genjax._src.core.transforms.incremental import UnknownChange
from genjax._src.core.transforms.incremental import Diff
from genjax._src.core.pytree.utilities import *
from dataclasses import dataclass
from genjax.generative_functions.distributions import ExactDensity

import plotly.graph_objs as go
from scipy.spatial.transform import Rotation as R
import imageio
import os

console = genjax.pretty()

In [2]:
# Load the ground-truth scene. `load_metadata` is not defined in this notebook,
# so it presumably comes from one of the star imports above (TODO confirm which).
gt_path = "../ground_truths/genjax_generated/scenes/scene_1.pkl"
metadata = load_metadata(gt_path)
gt_images = metadata["rendered"]
# model = eval(metadata["model_name"])
# The model is pinned by hand here instead of being looked up from the metadata.
model = model_single_object
RENDERER_ARGS = metadata["RENDERER_ARGS"]
# NOTE(review): eval() runs whatever string is built from the loaded metadata.
# Only use this with scene files you trust; a dict of known setup functions
# keyed by version would avoid executing file contents.
setup_renderer_and_meshes = eval("setup_renderer_and_meshes_v{}".format(metadata["renderer_setup_version"]))
setup_renderer_and_meshes(**RENDERER_ARGS)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)


In [3]:
def pose_update_v4(key, trace_, pose_grid, enumerator):
    """
    Score every candidate in `pose_grid` and commit the best one to the trace.

    Parameters
    ----------
    key : PRNG key forwarded unchanged to both enumerator calls.
    trace_ : trace to be scored and then updated.
    pose_grid : indexable collection of candidates; entry `k` is `pose_grid[k]`.
    enumerator : object providing `enumerate_choices_get_scores` and
        `update_choices_with_weight`.

    Returns
    -------
    tuple
        Everything returned by `enumerator.update_choices_with_weight`,
        followed by the selected candidate.
    """
    scores = enumerator.enumerate_choices_get_scores(trace_, key, pose_grid)
    # Greedy selection: take the highest score instead of sampling an index.
    best_pose = pose_grid[scores.argmax()]
    update_result = enumerator.update_choices_with_weight(trace_, key, best_pose)
    return (*update_result, best_pose)

# JIT-compiled entry point for pose_update_v4. `enumerator` is declared static,
# so it must be hashable and every distinct enumerator triggers a new compilation.
pose_update_v4_jit = jax.jit(pose_update_v4, static_argnames=("enumerator",))


def c2f_pose_update_v4(key, trace_, gridding_schedule_stacked, enumerator, t):
    """
    Coarse-to-fine update of the trace, one gridding stage after another.

    The search starts from `trace_["velocity"][t-1]`. At each stage the current
    centre is composed with every entry of that stage's grid, the candidates
    are handed to `pose_update_v4_jit`, and the candidate it returns becomes
    the centre for the next stage. The same `key` is passed to every stage.

    Returns the weight and trace produced by the last stage. The schedule must
    contain at least one stage, otherwise `weight` is never assigned.
    """
    center = trace_["velocity"][t-1]
    for stage_deltas in gridding_schedule_stacked:
        # (i, j) x (a, j, k) -> (a, i, k): matrix product of the centre with each grid entry.
        candidates = jnp.einsum("ij,ajk->aik", center, stage_deltas)
        weight, trace_, center = pose_update_v4_jit(key, trace_, candidates, enumerator)

    return weight, trace_

# Batched variant: vmaps over the first two arguments (key and trace) and shares the
# schedule, enumerator and t across the batch.
# NOTE(review): not referenced anywhere in the visible cells — confirm it is still needed.
c2f_pose_update_v4_vmap_jit = jax.jit(jax.vmap(c2f_pose_update_v4, in_axes=(0,0,None,None,None)),
                                    static_argnames=("enumerator", "t"))

# Single-trace variant; `enumerator` and `t` are static arguments, so both must be hashable.
c2f_pose_update_v4_jit = jax.jit(c2f_pose_update_v4,static_argnames=("enumerator", "t"))

def make_new_keys(key, N_keys):
    """
    Advance `key` and derive `N_keys` fresh PRNG keys from it.

    Parameters
    ----------
    key : JAX PRNG key.
    N_keys : int, number of keys to derive.

    Returns
    -------
    (key, new_keys)
        `key` is the advanced carry key; `new_keys` is the result of
        `jax.random.split(other_key, N_keys)`, so it always has a leading axis
        of length `N_keys`, including when `N_keys == 1`.
    """
    key, other_key = jax.random.split(key)
    # The former `if N_keys > 1` check returned the same tuple on both branches.
    return key, jax.random.split(other_key, N_keys)


# def initial_choice_map(metadata):
#     return genjax.index_choice_map(
#             jnp.arange(0,metadata["T"]+1), 
#             genjax.choice_map(metadata["CHOICE_MAP_ARGS"]).unsafe_merge(
#                 genjax.choice_map({"velocity" : np.tile(metadata['INIT_VELOCITY'][None,...],(metadata['T']+1, 1, 1))})
#             )
#         )

# def initial_choice_map(metadata):
#     return genjax.index_choice_map(
#             [0],
#             genjax.choice_map(metadata["CHOICE_MAP_ARGS"])
#         )

# def initial_choice_map(metadata, gt):
#     return genjax.index_choice_map(
#             jnp.arange(0,metadata["T"]+1), genjax.choice_map(
#                 {"depth" : gt,
#                 **metadata["CHOICE_MAP_ARGS"]}
#             )
#         )

def init_choice_map(gt_depths, constant_choices, init_state):
    """
    Build the choice map for index 0.

    Adds two entries to `constant_choices` and wraps the result in an indexed
    choice map at index `[0]`:
      * 'depth'    - `gt_depths[0]` with a new leading axis,
      * 'velocity' - the last element of `init_state` with a new leading axis.

    `constant_choices` is modified in place, so the caller's dict still holds
    the 'depth' and 'velocity' entries after this call.
    """
    first_depth = gt_depths[0][None, ...]
    constant_choices['depth'] = first_depth
    initial_velocity = jnp.expand_dims(init_state[-1], axis=0)
    constant_choices['velocity'] = initial_velocity

    frame_choices = genjax.choice_map(constant_choices)
    return genjax.index_choice_map([0], frame_choices)

def update_choice_map(gt_depths, constant_choices, t):
    """
    Build the choice map for index `t`.

    The map holds every entry of `constant_choices` plus 'depth', set to
    `gt_depths[t]` with a new leading axis, wrapped in an indexed choice map
    at index `[t]`.

    `constant_choices` is left untouched. Assigning into the caller's dict
    made repeated calls depend on each other, and stored values traced inside
    `jax.lax.scan` on an object that lives outside the scan body.
    """
    # Overriding 'depth' in a copy keeps the key order of the original dict.
    frame_choices = {**constant_choices, 'depth': gt_depths[t][None,...]}
    return genjax.index_choice_map(
            [t], genjax.choice_map(
                frame_choices
            )
        )

def argdiffs_modelv6(trace, t):
    """
    Argdiffs specific to modelv6.

    The first argument is replaced by `t` and tagged `UnknownChange`. Every
    leaf of the remaining arguments, taken from `trace.get_args()` from
    position 1 onwards, is tagged `NoChange`.
    """
    def _unchanged(tree):
        # Wrap every leaf of `tree` in a NoChange diff, preserving its structure.
        return jtu.tree_map(lambda leaf: Diff(leaf, NoChange), tree)

    model_args = trace.get_args()
    time_diff = Diff(t, UnknownChange)
    return (time_diff, _unchanged(model_args[1]), *_unchanged(model_args[2:]))

def proposal_choice_map(addresses, args, chm_args):
    """
    Build the indexed choice map used by the proposal.

    Only the first address and the first proposed value are used. The value
    gets a new leading axis and is stored under that address at the index
    given by `chm_args[0]`.
    """
    address = addresses[0]  # custom defined
    index = jnp.array([chm_args[0]])
    proposed_value = jnp.expand_dims(args[0], axis=0)
    return genjax.index_choice_map(index, genjax.choice_map({address: proposed_value}))

In [6]:
def inference_approach_G(model, gt, gridding_schedule, init_chm, model_args, init_state, key, constant_choices, T, num_particles = 1):
    """
    Sequential Monte Carlo on the unfolded model, resampling at every step and
    using the coarse-to-fine 3D pose enumeration proposal.

    Parameters
    ----------
    model : generative function exposing `importance` and `update`.
    gt : observed depth frames, indexed by time step.
    gridding_schedule : sequence of grids passed to the coarse-to-fine proposal.
    init_chm : unused; kept so existing call sites keep working.
    model_args : remaining model arguments, appended after `(0, init_state)`.
    init_state : initial state; its last element seeds the 'velocity' choice.
    key : JAX PRNG key.
    constant_choices : dict of choices shared by every time step. It is passed
        to `init_choice_map`, which adds entries to it in place.
    T : last time step; steps 1..T are processed after initialisation at 0.
    num_particles : number of particles (default 1).

    Returns
    -------
    (final_log_weights, particles, rendered_particles)
        `rendered_particles` is element 0 of `particles.get_retval()`.
    """
    # Function-scope import: the notebook's import cell does not provide logsumexp.
    from jax.scipy.special import logsumexp

    key, init_keys = make_new_keys(key, num_particles)

    # define functions for SIS/SMC
    init_fn = jax.vmap(model.importance, in_axes=(0,None,None))
    update_fn = model.update
    proposal_fn = c2f_pose_update_v4_jit

    # initialize SMC/SIS
    init_log_weights, init_particles = init_fn(init_keys, init_choice_map(gt,constant_choices,init_state), (0, init_state, *model_args))


    def SMC_step(state, t):
        # get new keys
        key, log_weights, particles = state
        key, resample_key = jax.random.split(key)
        key, update_key = jax.random.split(key)
        key, proposal_key = jax.random.split(key)

        # Argdiffs are built from the first particle and shared by all particles.
        argdiffs = argdiffs_modelv6(jax.tree_util.tree_map(lambda v: v[0], particles),t)

        # make enumerator for this time step (affects the proposal choice map)
        enumerator = b.make_enumerator([("velocity")], 
                                        chm_builder = proposal_choice_map,
                                        argdiff_f=lambda x: argdiffs,
                                        chm_args = [t])
        
        
        # Resampling at every time step
        sampled_indices = jax.random.categorical(resample_key, log_weights, shape=(num_particles,))
        resampled_particles = jtu.tree_map(lambda v: v[sampled_indices], particles)
        # Slot i now holds old particle sampled_indices[i], so the old per-slot
        # weights no longer belong to it. After resampling every particle gets
        # the same weight, the log of the mean of the old weights.
        resampled_log_weights = jnp.zeros_like(log_weights) + (
            logsumexp(log_weights) - jnp.log(num_particles)
        )
        
        def updater(key, particle):
            key, new_key = jax.random.split(key)
            return new_key, update_fn(new_key, particle, update_choice_map(gt,constant_choices,t), argdiffs)[1:3]
        _, (update_weights, updated_particles) = jax.lax.scan(updater, update_key, resampled_particles)


        def proposer(key, particle):
            key, new_key = jax.random.split(key)
            return new_key, proposal_fn(new_key, particle, gridding_schedule, enumerator, t)
        _, (proposal_weights, proposed_particles) = jax.lax.scan(proposer, proposal_key, updated_particles)

        # Incremental weights of this step on top of the equalised weights.
        new_log_weights = resampled_log_weights + update_weights + proposal_weights

        return (key, new_log_weights, proposed_particles), None

    (_, final_log_weights, particles), _ = jax.lax.scan(
        SMC_step, (key, init_log_weights, init_particles), jnp.arange(1, T+1))
    # Under jax.jit this print runs while tracing, not on every execution.
    print("SCAN finished")
    rendered_particles = particles.get_retval()[0]
    return final_log_weights, particles, rendered_particles

In [8]:
# Proposal grids: three translation stages followed by one rotation stage.
grid_widths = [0.02, 0.01,0.005]
grid_nums = [(3,3,3),(3,3,3),(3,3,3)]
gridding_schedule_trans = make_schedule_translation_3d(grid_widths, grid_nums)
# gridding_schedule_rot = [jnp.concatenate((b.utils.make_rotation_grid_enumeration(100, 21, -jnp.pi/4, jnp.pi/4, jnp.pi/4), jnp.stack([gridding_schedule_trans[-1][-1]]*97, axis=0)), axis=0)]
gridding_schedule_rot = [b.utils.make_rotation_grid_enumeration(10, 15, -jnp.pi/12, jnp.pi/12, jnp.pi/12)]
gridding_schedule = [gridding_schedule_trans[0], gridding_schedule_trans[1], gridding_schedule_trans[2], gridding_schedule_rot[0]]
# gridding_schedule = jnp.stack(gridding_schedule_trans + gridding_schedule_rot)
# gridding_schedule = jnp.stack(gridding_schedule_rot)

T = metadata['T']

# Choices held fixed at every time step.
rend_idx = 0
# metadata["CHOICE_MAP_ARGS"]['indices'] = jnp.tile(jnp.array([[rend_idx]]), (T+1,1))
CONSTANT_CHOICES = {
    'variance': jnp.array([0.01]),
    'outlier_prob': jnp.array([0.0001]),
    'indices': jnp.array([[rend_idx]])
}

init_chm = update_choice_map(gt_images, CONSTANT_CHOICES, 0)

# Model arguments from the metadata, with the third one overridden by hand.
model_args = tuple(metadata["MODEL_ARGS"].values())
margs = list(model_args)
# margs[2] = jnp.array([1, 0.0000000001])
margs[2] = jnp.array([1e+20, 0])
margs = tuple(margs)

init_state = (gt_images[0], metadata["INIT_POSE"], metadata["INIT_VELOCITY"])
# init_state = (gt_images[0], metadata["INIT_POSE"])
# key = jax.random.PRNGKey(metadata["key_number"])
key = jax.random.PRNGKey(45675456)

model_unfold = genjax.UnfoldCombinator.new(model, metadata['T']+1)
inference_approach_G_jit = jax.jit(inference_approach_G, static_argnames=("T", "num_particles"))
start = time.time()
lw, tr, rendered_particles = inference_approach_G_jit(model_unfold, gt_images, gridding_schedule, init_chm, margs, init_state, key, CONSTANT_CHOICES, T,20)
# JAX dispatches asynchronously: wait for the result before stopping the clock.
# The first call also includes compilation time.
lw.block_until_ready()
# Parenthesised on purpose: `T+1 / elapsed` evaluates as `T + (1 / elapsed)`.
print ("FPS:", (T+1) / (time.time() - start))

running model jprint
(20, 1, 1, 4, 4) and (1,)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
running model jprint
(20, 1, 1, 4, 4) and (20, 1)
run

In [9]:
# 'depth' choices stored in the trace, at leading index 0
# (presumably the first particle — TODO confirm the axis order).
xx = tr.strip()['depth'][0]
# Element 0 of the trace's return value, again at leading index 0.
yy = tr.get_retval()[0][0]

In [10]:
# Side-by-side video of the trace's return value (yy) against its stored 'depth' choices (xx).
video_comparison_from_images(yy,xx)

In [None]:
# NOTE(review): `rendered` is only assigned by the single-particle run in the next cell;
# on a fresh kernel run top to bottom this cell raises NameError. Move it below that cell.
video_comparison_from_images(rendered, gt_images, framerate=10, scale = 8)

In [None]:
# Re-run inference with the default num_particles (1); this overwrites `lw` and `tr`
# from the 20-particle run above.
start = time.time()
lw, tr, rendered = inference_approach_G_jit(model_unfold, gt_images, gridding_schedule, init_chm, margs, init_state, key, CONSTANT_CHOICES, T)
# NOTE(review): JAX dispatches asynchronously, so this may stop the clock before the
# computation has finished; consider blocking on the result first.
print ("FPS:", rendered.shape[0] / (time.time() - start))

In [None]:
# ALL POST PROCESSING
gt_poses = metadata['poses']
# inferred_poses = tr.get_retval()[1].at[0].set(metadata["INIT_POSE"])
inferred_poses = tr.get_retval()[1]
plot_3d_poses_x([gt_poses, inferred_poses], ["Ground Truth", "Inferred"], name = "default", fps = 10, save = False)

# plot_bayes3d_likelihood(gt_images, rendered, fps = 10, ll_f = b.threedp3_likelihood_per_pixel_old, ll_f_args = (0.001,0.001,1000,3))
# # video_comparison_from_images(rendered, gt_images)

In [None]:
# Scratch check: display what jax.random.split returns when asked for a single key.
a = jax.random.split(key, 1)
a

In [None]:
def plot_3d_poses_x(poses_list, poses_names = ["Ground Truth", "Inferred"], name = "default", fps = 10, save = False):
    """
    Plot N pose trajectories as 3D walks, each with a line segment along the
    rotation axis of its current pose.

    poses_list is an N-length list of T x 4 x 4 poses
    pose_names is an N-length list of names, same len as poses_list
    name is the stem of the output file; the GIF is written to '<name>.gif'
    fps is the frame rate of the GIF
    save selects the output:
        False - only the final frame (the complete walks) is built and shown;
        True  - every frame is written to 'frames/', assembled into a GIF,
                and the frame images are deleted afterwards.
    """

    assert len(poses_list) == len(poses_names)

    num_paths = len(poses_list)
    T = poses_list[0].shape[0]

    # Translation part of every pose: walks[w, t] = poses_list[w][t, :3, 3].
    walks = np.zeros((num_paths,T,3))
    for w in range(num_paths):
        for t in range(T):
            walks[w,t,:] = poses_list[w][t,:3,3]

    # First walk is red, a second one blue; with more than two walks the rest are random.
    colors = ['rgba(255,0,0,0.8)']
    if num_paths == 2:
        colors = colors + ['rgba(0,0,255,0.8)']
    elif num_paths > 2:
        colors = colors + [f'rgba({np.random.randint(0, 255)}, {np.random.randint(0, 255)}, {np.random.randint(0, 255)}, 0.8)' for _ in range(num_paths - 1)]

    # Axis ranges are fixed from all walks so every frame shares the same view.
    all_walks = walks.reshape(-1, 3)
    x_range = [all_walks[:, 0].min(), all_walks[:, 0].max()]
    y_range = [all_walks[:, 1].min(), all_walks[:, 1].max()]
    z_range = [all_walks[:, 2].min(), all_walks[:, 2].max()]

    def create_arrow(point, direction, color='red', length_scale=1, showlegend=False):
        """Return a line segment starting at `point` along `direction` (no arrowhead)."""
        # Segment length is 10% of the diagonal of the bounding box of all walks.
        length = length_scale * np.linalg.norm([x_range[1] - x_range[0], y_range[1] - y_range[0], z_range[1] - z_range[0]]) * 0.1
        direction = direction / np.linalg.norm(direction)
        return go.Scatter3d(
            x=[point[0], point[0] + direction[0] * length],
            y=[point[1], point[1] + direction[1] * length],
            z=[point[2], point[2] + direction[2] * length],
            mode='lines',
            line=dict(color=color, width=4),
            showlegend=showlegend
        )

    def build_frame(i):
        """Build the figure showing every walk up to and including time step i."""
        frame_data = []
        for w in range(num_paths):
            rotation_matrix = poses_list[w][i,:3,:3]

            # The rotation axis is the eigenvector whose eigenvalue is closest to 1.
            eigenvalues, eigenvectors = np.linalg.eig(rotation_matrix)
            index_of_one = np.argmin(np.abs(eigenvalues - 1))
            axis_of_rotation = eigenvectors[:, index_of_one]
            axis_of_rotation = axis_of_rotation / np.linalg.norm(axis_of_rotation)
            # eig returns complex values in general; keep the real part.
            arrow_direction = np.real(axis_of_rotation)

            shaft = create_arrow(walks[w, i], arrow_direction, color=colors[w], showlegend=False)

            trace = go.Scatter3d(
                x=walks[w, :i+1, 0],
                y=walks[w, :i+1, 1],
                z=walks[w, :i+1, 2],
                mode='markers+lines',
                marker=dict(size=5, color=colors[w]),
                line=dict(color=colors[w], width=2),
                name=poses_names[w],  # Name for legend
                legendgroup=poses_names[w],  # Same legendgroup for walk dots and arrows
            )
            # Add only for the first frame to avoid duplicate legend entries
            if i == 0:
                trace.legendgrouptitle = dict(text=poses_names[w])

            frame_data.extend([trace, shaft])

        return go.Figure(
            data=frame_data,
            layout=go.Layout(
                scene=dict(
                    xaxis=dict(range=x_range, autorange=False),
                    yaxis=dict(range=y_range, autorange=False),
                    zaxis=dict(range=z_range, autorange=False),
                    aspectratio=dict(x=1, y=1, z=1),
                    camera=dict(
                        eye=dict(x=1.25, y=-1.25, z=-1.25),
                        up=dict(x=0, y=-1, z=0),
                        center=dict(x=0, y=0, z=0)
                    )
                ),
                margin=dict(l=0, r=0, t=0, b=0)  # Reduce white space around the plot
            )
        )

    if not save:
        # Only the last frame is displayed, so only that one is built.
        # Nothing is written to disk in this mode.
        if T > 0:
            build_frame(T - 1).show()
        return

    # The frames directory is only needed when saving.
    os.makedirs("frames", exist_ok=True)

    image_files = []
    for i in range(T):
        img_file = f'frames/frame_{i:03d}.png'
        build_frame(i).write_image(img_file)
        image_files.append(img_file)

    # Assemble the GIF, deleting each frame image once it has been appended.
    with imageio.get_writer(f'{name}.gif', mode='I', fps=fps, loop = 0) as writer:
        for filename in image_files:
            image = imageio.imread(filename)
            writer.append_data(image)
            os.remove(filename)

    # Remove the directory only if it is empty; it may hold unrelated files.
    if not os.listdir("frames"):
        os.rmdir("frames")

    print(f"GIF saved as '{name}.gif'")