In [1]:
import sys
from functools import partial
sys.path.append("../")
import genjax
import time
import numpy as np
from jax_tqdm import scan_tqdm
import bayes3d as b
import jax.numpy as jnp
import bayes3d.genjax
import jax
import jax.tree_util as jtu
from jax.debug import print as jprint
from utils import *
from viz import *
from models import *
from renderer_setup import *
from inference import *
from genjax.inference.importance_sampling import sampling_importance_resampling
from genjax._src.core.transforms.incremental import NoChange
from genjax._src.core.transforms.incremental import UnknownChange
from genjax._src.core.transforms.incremental import Diff
from genjax._src.core.pytree.utilities import *
from dataclasses import dataclass
from genjax.generative_functions.distributions import ExactDensity

import plotly.graph_objs as go
from scipy.spatial.transform import Rotation as R
import imageio
import os
import tempfile

# Enable genjax's pretty printing for this session; the returned object is kept
# as `console`. NOTE(review): exact behavior of genjax.pretty() is not visible here.
console = genjax.pretty()

In [2]:
# Camera intrinsics for the synthetic scene.
# NOTE(review): cx=25.0 with width=80 puts the principal point left of centre
# (centre would be 40.0); confirm this matches how the ground truth was rendered.
intrinsics = b.Intrinsics(
    height=50, width=80,
    fx=250.0, fy=250.0,
    cx=25.0, cy=25.0,
    near=0.1, far=20.0,
)

b.setup_renderer(intrinsics)

# Mesh index follows insertion order: 0 -> cube, 1 -> occluder.
assets_dir = b.utils.get_assets_dir()
for mesh_relpath in ("sample_objs/cube.obj", "sample_objs/occulder.obj"):
    b.RENDERER.add_mesh_from_file(os.path.join(assets_dir, mesh_relpath), scaling_factor=0.05)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (96, 64, 1024)


In [3]:
gt_path = "../ground_truths/genjax_generated/physics_simple.pkl"
metadata = load_metadata(gt_path)

# Keep the first three entries of the last axis; this slicing is specific to
# the physics_simple ground truth.
gt_images = metadata["rendered"][..., :3]

# The model is pinned to v5d rather than chosen from metadata["model_version"],
# and the renderer is configured by hand rather than from metadata["RENDERER_ARGS"].
model = model_v5d

video_from_rendered(gt_images)

In [4]:
def pose_update_v4(key, trace_, pose_grid, enumerator):
    """Score every candidate in `pose_grid` and commit the highest-scoring one.

    Returns whatever `enumerator.update_choices_with_weight` returns (unpacked),
    followed by the winning candidate from `pose_grid`.
    """
    scores = enumerator.enumerate_choices_get_scores(trace_, key, pose_grid)
    best_pose = pose_grid[scores.argmax()]
    return (*enumerator.update_choices_with_weight(trace_, key, best_pose), best_pose)

pose_update_v4_jit = jax.jit(pose_update_v4, static_argnames=("enumerator",))


def c2f_pose_update_v4(key, trace_, gridding_schedule_stacked, enumerator, t):
    """Coarse-to-fine update of the "velocity" choice at time step `t`.

    Each level along axis 0 of `gridding_schedule_stacked` is a batch of delta
    matrices. Candidates for a level are `reference_vel @ delta` for every delta.
    The best candidate is committed to the trace via `pose_update_v4_jit`, and
    that winner becomes the reference for the next level.

    Args:
        key: PRNG key, passed unchanged to every level.
        trace_: trace to update.
        gridding_schedule_stacked: array indexed as [level, candidate, j, k]; the
            j axis must match reference_vel's trailing dimension (presumably
            4x4 homogeneous deltas -- confirm).
        enumerator: enumerator configured for this time step (static under jit).
        t: time step. At t == 1 the reference is trace_.args[1][1], otherwise
            the previous step's "velocity" choice.

    Returns:
        (weight, trace) produced by the final refinement level.

    Raises:
        ValueError: if the schedule has no levels (previously this surfaced as
            an UnboundLocalError on `weight`).
    """
    num_levels = gridding_schedule_stacked.shape[0]
    if num_levels == 0:
        raise ValueError("gridding_schedule_stacked must contain at least one level")

    reference_vel = jax.lax.cond(
        jnp.equal(t, 1),
        lambda: trace_.args[1][1],
        lambda: trace_["velocity"][t - 1],
    )
    for level in range(num_levels):
        # candidates[a] = reference_vel @ gridding_schedule_stacked[level][a]
        updated_grid = jnp.einsum("ij,ajk->aik", reference_vel, gridding_schedule_stacked[level])
        weight, trace_, reference_vel = pose_update_v4_jit(key, trace_, updated_grid, enumerator)

    return weight, trace_

c2f_pose_update_v4_vmap_jit = jax.jit(jax.vmap(c2f_pose_update_v4, in_axes=(0,0,None,None,None)),
                                    static_argnames=("enumerator", "t"))

c2f_pose_update_v4_jit = jax.jit(c2f_pose_update_v4,static_argnames=("enumerator", "t"))

def make_new_keys(key, N_keys):
    """Split `key` into a carry key plus `N_keys` fresh subkeys.

    Returns `(carry_key, subkeys)`. When `N_keys > 1` the subkeys are returned
    as a batch; otherwise only the first subkey is returned, unwrapped.
    """
    carry_key, spawn_key = jax.random.split(key)
    subkeys = jax.random.split(spawn_key, N_keys)
    return (carry_key, subkeys) if N_keys > 1 else (carry_key, subkeys[0])


def initial_choice_map(metadata):
    """Choice map fixing, at every time index 0..metadata["T"], the per-step
    choices listed in metadata["CHOICE_MAP_ARGS"].

    Assumes each value in CHOICE_MAP_ARGS has a leading axis of length T + 1
    (as built in the configuration cell) -- TODO confirm for other callers.
    """
    time_indices = jnp.arange(0, metadata["T"] + 1)
    fixed_choices = genjax.choice_map(metadata["CHOICE_MAP_ARGS"])
    return genjax.index_choice_map(time_indices, fixed_choices)

def update_choice_map(gt, t):
    """Choice map constraining only the 'depth' address at time index `t` to the
    observation gt[t], given a leading length-1 axis to match the single index."""
    observed_depth = jnp.expand_dims(gt[t], axis=0)
    return genjax.index_choice_map([t], genjax.choice_map({'depth': observed_depth}))

def argdiffs_modelv5(trace, t):
    """
    Argdiffs specific to modelv5: mark every argument of `trace` as unchanged.

    args[0] is wrapped as a single Diff. args[1] and each remaining argument are
    wrapped leaf by leaf. `t` is not used; it is kept for call-site symmetry.
    """
    def no_change(value):
        return Diff(value, NoChange)

    args = trace.get_args()
    head_diff = no_change(args[0])
    second_diffs = jtu.tree_map(no_change, args[1])
    rest_diffs = jtu.tree_map(no_change, args[2:])
    return (head_diff, second_diffs, *rest_diffs)

def proposal_choice_map(addresses, args, chm_args):
    """Choice map placing args[0] at addresses[0], at the single time index chm_args[0].

    Presumably invoked by the enumerator from b.make_enumerator with the
    addresses/chm_args it was configured with -- confirm against bayes3d.
    """
    target_address = addresses[0]  # custom defined
    time_index = jnp.array([chm_args[0]])
    value = jnp.expand_dims(args[0], axis=0)
    return genjax.index_choice_map(time_index, genjax.choice_map({target_address: value}))

In [5]:
def inference_approach_F2(model, gt, gridding_schedule, init_chm, T, model_args, init_state, key):
    """
    Sequential Importance Sampling on the unfolded HMM model
    with 'dumb' 3D pose enumeration proposal

    WITH JUST ONE PARTICLE

    Args:
        model: unfolded generative function exposing `importance` and `update`.
        gt: observed frames indexed by time; frame t is constrained as 'depth'
            at step t (see update_choice_map).
        gridding_schedule: stacked candidate-delta levels for the coarse-to-fine
            proposal (c2f_pose_update_v4_jit).
        init_chm: constraints for the initial importance call.
        T: last time index; the scan runs t = 1..T (static under jit).
        model_args: remaining model arguments, appended after (T, init_state).
        init_state: initial state, passed as the model's second argument.
        key: JAX PRNG key.

    Returns:
        (final_log_weight, particle, rendered), where `rendered` is the first
        component of the final particle's return value.
    """
    # Independent keys for the initial importance call and the t=0 observation
    # update; reusing one key for both correlates their randomness.
    key, init_key = make_new_keys(key, 1)
    key, init_update_key = make_new_keys(key, 1)

    # define functions for SIS/SMC
    init_fn = jax.jit(model.importance)
    update_fn = jax.jit(model.update)
    proposal_fn = c2f_pose_update_v4_jit

    # initialize SMC/SIS: importance-sample under init_chm, then constrain depth at t=0
    importance_log_weight, init_particle = init_fn(init_key, init_chm, (T, init_state, *model_args))
    argdiffs = argdiffs_modelv5(init_particle, 0)
    _, obs_log_weight, init_particle, _ = update_fn(
            init_update_key, init_particle, update_choice_map(gt, 0), argdiffs)
    # The update weight is relative to the input trace (the same assumption
    # smc_body makes when it sums weights), so accumulate instead of overwrite.
    init_log_weight = importance_log_weight + obs_log_weight

    def smc_body(state, t):
        # per-step progress, printed from inside the compiled scan
        jprint("t = {}",t)
        key, log_weight, particle = state
        key, update_key = make_new_keys(key, 1)
        key, proposal_key = make_new_keys(key, 1)

        argdiffs = argdiffs_modelv5(particle, t)

        # make enumerator for this time step (affects the proposal choice map)
        enumerator = b.make_enumerator([("velocity")], 
                                        chm_builder = proposal_choice_map,
                                        argdiff_f=lambda x: argdiffs,
                                        chm_args = [t])

        # update model to new depth observation
        _, update_log_weight, updated_particle, _ = update_fn(
            update_key, particle, update_choice_map(gt, t), argdiffs)

        # propose good poses based on proposal
        proposal_log_weight, new_particle = proposal_fn(
            proposal_key, updated_particle, gridding_schedule, enumerator, t)

        # accumulate the incremental weights of this step
        new_log_weight = log_weight + proposal_log_weight + update_log_weight

        return (key, new_log_weight, new_particle), None

    (_, final_log_weight, particle), _ = jax.lax.scan(
        smc_body, (key, init_log_weight, init_particle), jnp.arange(1, T+1))
    rendered = particle.get_retval()[0]
    return final_log_weight, particle, rendered

In [32]:
# Candidate-delta grids for the coarse-to-fine proposal. c2f_pose_update_v4
# iterates over the leading axis of `gridding_schedule`, one refinement level
# per entry:
#   level 0: translation grid from make_schedule_translation_3d(grid_widths, grid_nums)
#   level 1: rotation grid from make_rotation_grid_enumeration(25, 40, -pi/12, pi/12, pi/12)
# (Earlier commented-out multi-level and sampled-rotation schedules were removed
# as dead code.)
grid_widths = [0.01]
grid_nums = [(10,10,10)]
gridding_schedule_trans = make_schedule_translation_3d(grid_widths, grid_nums)
gridding_schedule_rot = [b.utils.make_rotation_grid_enumeration(25, 40, -jnp.pi/12, jnp.pi/12, jnp.pi/12)]
# jnp.stack requires every level to hold the same number of candidates
# (presumably 10*10*10 = 25*40 = 1000 here -- confirm if the grids change).
gridding_schedule = jnp.stack(gridding_schedule_trans + gridding_schedule_rot)

# Last time index; inference steps run over t = 1..T.
T = metadata['T']

MODEL_ARGS = {
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.0005, 10000.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0]),
    'occ_pose' : metadata['occ_poses'][0]
}


# Fixed per-step choices for time indices 0..T.
CHOICE_MAP_ARGS = {
    'variance': jnp.repeat(0.0001,T+1),
    'outlier_prob': jnp.repeat(0.0001,T+1),
    'indices': jnp.tile(jnp.array([[0]]), (T+1,1))
}


metadata["CHOICE_MAP_ARGS"] = CHOICE_MAP_ARGS

init_chm = initial_choice_map(metadata)

# Model arguments in MODEL_ARGS insertion order. Inference overrides
# 'vel_params' by key rather than by tuple position, so reordering
# MODEL_ARGS cannot silently redirect the override to another argument.
model_args = tuple(MODEL_ARGS.values())
margs = tuple({**MODEL_ARGS, 'vel_params': jnp.array([1, 0.00001])}.values())

# NOTE(review): init_state[1] becomes the model's args[1][1], which
# c2f_pose_update_v4 reads as its reference *velocity* at t == 1, yet it is
# computed (and named) as a pose: inv(P0^-1 @ P1) @ P0. Confirm which is intended.
prev_pose = jnp.linalg.inv(jnp.linalg.solve(metadata["poses"][0], metadata["poses"][1])) @ metadata["poses"][0]

init_state = (gt_images[0], prev_pose, metadata["poses"][0])
# Fixed inference seed (not read from metadata).
key = jax.random.PRNGKey(45675456)

model_unfold = genjax.UnfoldCombinator.new(model, T + 1)
inference_approach_F2_jit = jax.jit(inference_approach_F2, static_argnames=("T",))
start = time.time()
lw, tr, rendered = inference_approach_F2_jit(model_unfold, gt_images, gridding_schedule, init_chm, T, margs, init_state, key)
# JAX dispatches asynchronously: block on the result so the timing covers the
# actual computation. The first call also includes JIT compilation time.
rendered.block_until_ready()
print("FPS:", rendered.shape[0] / (time.time() - start))
# Frame 0 is not refined by the proposal (the scan starts at t = 1); show the
# observation there for side-by-side comparisons.
rendered = rendered.at[0].set(gt_images[0])

t = 1
t = 2
t = 3
t = 4
t = 5
t = 6
t = 7
t = 8
t = 9
t = 10
t = 11
t = 12
t = 13
t = 14
t = 15
t = 16
t = 17
t = 18
t = 19
t = 20
t = 21
t = 22
t = 23
t = 24
t = 25
t = 26
t = 27
t = 28
t = 29
t = 30
t = 31
t = 32
t = 33
t = 34
t = 35
t = 36
t = 37
t = 38
t = 39
t = 40
t = 41
t = 42
t = 43
t = 44
t = 45
t = 46
t = 47
t = 48
t = 49
t = 50
t = 51
FPS: 1.1006864327994268


In [21]:
# Side-by-side video of the inferred renders and the observed frames.
video_comparison_from_images(rendered, gt_images, scale=8, framerate=10)

In [33]:
# Re-render ground-truth and inferred object poses without the occluder, overlay
# the occluder's depth image on every frame, and play them side by side.
gt_poses = metadata['poses'][1:]
# NOTE(review): this drops the first AND last inferred entries, while gt_poses
# drops only the first. If the full arrays have equal length, inferred_poses is
# one shorter than gt_poses, and the comprehension below (which iterates over
# gt_poses.shape[0]) indexes one past its end -- depending on JAX's indexing
# rules this either errors or silently repeats the last frame. Confirm alignment.
inferred_poses = tr.get_retval()[1][1:-1]
occ_pose = metadata['occ_poses'][0]

# Mesh 0 is the cube and mesh 1 the occluder (order they were added to the renderer).
# [...,2] keeps channel 2 of the render output -- presumably depth/z, TODO confirm.
gt_images_no_occ = b.RENDERER.render_many(gt_poses[:,None,...],  jnp.array([0]))[...,2]
inferred_images_no_occ = b.RENDERER.render_many(inferred_poses[:,None,...],  jnp.array([0]))[...,2]
occ_image = b.RENDERER.render(occ_pose[None,...],  jnp.array([1]))[...,2]
gt_ims = [b.overlay_image(b.get_depth_image(gt_images_no_occ[i]), b.get_depth_image(occ_image), alpha = 0.3) for i in range(gt_poses.shape[0])]
inferred_images = [b.overlay_image(b.get_depth_image(inferred_images_no_occ[i]), b.get_depth_image(occ_image), alpha = 0.3) for i in range(gt_poses.shape[0])]


# Two-panel frames (inferred | observed), upscaled 6x for viewing.
images = [
    b.viz.multi_panel(
        [
            b.scale_image(inferred_im, 6),
            b.scale_image(gt_im, 6)
        ],
        labels=["Inferred", "Observed"],
        label_fontsize=20)
    for (inferred_im, gt_im) in zip(inferred_images, gt_ims)
]
display_video(images, framerate=10)

In [None]:
# ALL POST PROCESSING
gt_poses = metadata['poses']
inferred_poses = tr.get_retval()[1].at[0].set(metadata["INIT_POSE"])
plot_3d_poses_x([gt_poses, inferred_poses], ["Ground Truth", "Inferred"], name = "default", fps = 10, save = True)

plot_bayes3d_likelihood(gt_images, rendered, fps = 10, ll_f = b.threedp3_likelihood_per_pixel_old, ll_f_args = (0.001,0.001,1000,3))
# # video_comparison_from_images(rendered, gt_images)

In [31]:
# Compare gaussian_vmf log-densities around gt_poses[2] for two candidates: the
# identity and gt_poses[3]. The (1, 0.00001) parameters are the same values
# used for the 'vel_params' override in the inference configuration.
for candidate_pose in (jnp.eye(4), gt_poses[3]):
    print(b.distributions.gaussian_vmf_logpdf(candidate_pose, gt_poses[2], 1, 0.00001))

-7.791864
-5.739464


In [37]:
# First inferred pose (a 4x4 transform, as the displayed output shows).
inferred_poses[0, ...]


[1;35mArray[0m[1m([0m[1m[[0m[1m[[0m [1;36m0.9659258[0m ,  [1;36m0.25881904[0m,  [1;36m0[0m.        ,  [1;36m0.36999997[0m[1m][0m,
       [1m[[0m[1;36m-0.25881904[0m,  [1;36m0.9659258[0m ,  [1;36m0[0m.        ,  [1;36m0.04[0m      [1m][0m,
       [1m[[0m [1;36m0[0m.        ,  [1;36m0[0m.        ,  [1;36m1[0m.        ,  [1;36m1.99[0m      [1m][0m,
       [1m[[0m [1;36m0[0m.        ,  [1;36m0[0m.        ,  [1;36m0[0m.        ,  [1;36m1[0m.        [1m][0m[1m][0m,      [33mdtype[0m=[35mfloat32[0m[1m)[0m

In [28]:
# Ground-truth pose at index 3 (a 4x4 transform, as the displayed output shows).
gt_poses[3, ...]


[1;35mArray[0m[1m([0m[1m[[0m[1m[[0m[1;36m1[0m.        , [1;36m0[0m.        , [1;36m0[0m.        , [1;36m0.31000003[0m[1m][0m,
       [1m[[0m[1;36m0[0m.        , [1;36m1[0m.        , [1;36m0[0m.        , [1;36m0.05[0m      [1m][0m,
       [1m[[0m[1;36m0[0m.        , [1;36m0[0m.        , [1;36m1[0m.        , [1;36m2[0m.        [1m][0m,
       [1m[[0m[1;36m0[0m.        , [1;36m0[0m.        , [1;36m0[0m.        , [1;36m1[0m.        [1m][0m[1m][0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m

In [None]:
def _rotation_axis(rotation_matrix):
    """Unit rotation axis: real part of the eigenvector whose eigenvalue is closest to 1."""
    eigenvalues, eigenvectors = np.linalg.eig(rotation_matrix)
    axis = np.real(eigenvectors[:, np.argmin(np.abs(eigenvalues - 1))])
    return axis / np.linalg.norm(axis)


def _padded_range(values, rel_pad=0.05, min_pad=1e-3):
    """[min, max] of `values` widened by a margin, so extreme points are not
    clipped and a constant coordinate still yields a non-degenerate axis range."""
    lo, hi = float(values.min()), float(values.max())
    pad = max((hi - lo) * rel_pad, min_pad)
    return [lo - pad, hi + pad]


def plot_3d_poses_x(poses_list, poses_names=("Ground Truth", "Inferred"), name="default", fps=10, save=False):
    """
    Animate the 3D translation trajectories of several pose sequences. At each
    path's current position, a short segment is drawn along that pose's rotation axis.

    Args:
        poses_list: N-length list of T x 4 x 4 poses. The first sequence's length
            sets the number of frames T.
        poses_names: N-length sequence of legend names, same length as poses_list.
        name: GIF basename (without extension), used when `save` is True.
        fps: GIF frame rate, used when `save` is True.
        save: if True, render every frame and write `{name}.gif`. Otherwise only
            build and show the final frame.

    Raises:
        AssertionError: if poses_list and poses_names differ in length.
        ValueError: if the first pose sequence is empty.
    """
    assert len(poses_list) == len(poses_names)

    num_paths = len(poses_list)
    T = poses_list[0].shape[0]
    if T == 0:
        raise ValueError("poses_list must contain at least one pose per path")

    # (num_paths, T, 3) array of translation components.
    walks = np.zeros((num_paths, T, 3))
    for i in range(num_paths):
        for j in range(T):
            walks[i, j, :] = poses_list[i][j, :3, 3]

    # Fixed colors for the first two paths, random ones for any further paths.
    colors = ['rgba(255,0,0,0.8)']
    if num_paths == 2:
        colors.append('rgba(0,0,255,0.8)')
    elif num_paths > 2:
        # randint's upper bound is exclusive, so 256 makes 255 reachable.
        colors += [
            f'rgba({np.random.randint(0, 256)}, {np.random.randint(0, 256)}, {np.random.randint(0, 256)}, 0.8)'
            for _ in range(num_paths - 1)
        ]

    all_walks = walks.reshape(-1, 3)
    # Arrow length: 10% of the bounding-box diagonal of the raw (unpadded) data.
    arrow_length = np.linalg.norm(all_walks.max(axis=0) - all_walks.min(axis=0)) * 0.1
    x_range, y_range, z_range = [_padded_range(all_walks[:, k]) for k in range(3)]

    def make_frame(i):
        """Figure showing every path up to and including step i."""
        frame_data = []
        for w in range(num_paths):
            start = walks[w, i]
            tip = start + _rotation_axis(np.asarray(poses_list[w][i, :3, :3])) * arrow_length
            shaft = go.Scatter3d(
                x=[start[0], tip[0]],
                y=[start[1], tip[1]],
                z=[start[2], tip[2]],
                mode='lines',
                line=dict(color=colors[w], width=4),
                showlegend=False,
            )
            trace = go.Scatter3d(
                x=walks[w, :i+1, 0],
                y=walks[w, :i+1, 1],
                z=walks[w, :i+1, 2],
                mode='markers+lines',
                marker=dict(size=5, color=colors[w]),
                line=dict(color=colors[w], width=2),
                name=poses_names[w],
                legendgroup=poses_names[w],  # walk and its arrow share a legend group
            )
            if i == 0:
                trace.legendgrouptitle = dict(text=poses_names[w])
            frame_data.extend([trace, shaft])

        return go.Figure(
            data=frame_data,
            layout=go.Layout(
                scene=dict(
                    xaxis=dict(range=x_range, autorange=False),
                    yaxis=dict(range=y_range, autorange=False),
                    zaxis=dict(range=z_range, autorange=False),
                    aspectratio=dict(x=1, y=1, z=1),
                    camera=dict(
                        eye=dict(x=1.25, y=-1.25, z=-1.25),
                        up=dict(x=0, y=-1, z=0),
                        center=dict(x=0, y=0, z=0),
                    ),
                ),
                margin=dict(l=0, r=0, t=0, b=0),  # trim white space around the plot
            ),
        )

    if not save:
        # Only the final frame is displayed, so do not build the others.
        make_frame(T - 1).show()
        return

    # Write frames into a private temporary directory: no ./frames directory is
    # created, an existing one is never clobbered, and cleanup always happens.
    with tempfile.TemporaryDirectory() as frames_dir:
        image_files = []
        for i in range(T):
            img_file = os.path.join(frames_dir, f'frame_{i:03d}.png')
            make_frame(i).write_image(img_file)
            image_files.append(img_file)

        with imageio.get_writer(f'{name}.gif', mode='I', fps=fps, loop=0) as writer:
            for img_file in image_files:
                writer.append_data(imageio.imread(img_file))

    print(f"GIF saved as '{name}.gif'")