In [1]:
import sys
import jax
import time
import genjax
import bayes3d as b
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
sys.path.append("../")
from viz import *
from utils import *
from mcs_utils import *
from PIL import Image
import bayes3d.transforms_3d as t3d
from jax.debug import print as jprint
from tqdm import tqdm
import jax.tree_util as jtu
from genjax._src.core.transforms.incremental import NoChange, UnknownChange, Diff
console = genjax.pretty()

In [2]:
# Loading and preprocessing all data and renderer
# Load one MCS passive-physics scene, preprocess it into per-frame images, and set up the renderer.
SCALE = 0.1  # passed as `scale` to preprocessing; the preview video is upscaled by int(1/SCALE)
cam_pose = CAM_POSE_CV2  # camera pose (from a star import above); `cam_pose @ pose` maps camera-frame poses to world frame
inverse_cam_pose = jnp.linalg.inv(CAM_POSE_CV2)
observations = load_observations_npz('passive_physics_validation_object_permanence_0001_01')
# gt_images are the observed frames; gt_images_bg / gt_images_obj are presumably background-only
# and object-only versions of each frame (TODO confirm against preprocess_mcs_physics_scene).
gt_images, gt_images_bg, gt_images_obj, intrinsics, registered_objects = preprocess_mcs_physics_scene(observations, MIN_DIST_THRESH=0.6, scale=SCALE)
b.setup_renderer(intrinsics)
# Register each detected object's mesh with the renderer, in order (the model renders mesh ids 0..n-1).
for registered_obj in registered_objects:
    b.RENDERER.add_mesh(registered_obj['mesh'])
video_from_rendered(gt_images, scale = int(1/SCALE), framerate=30)

Extracting Meshes


 55%|█████▌    | 132/240 [00:08<00:02, 44.82it/s]

Adding review
Review passed, added to init queue


 57%|█████▋    | 136/240 [00:21<01:25,  1.22it/s]

Adding new mesh for t = {} 132


100%|██████████| 240/240 [00:35<00:00,  6.78it/s]


Extracting downsampled data


[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)
Centering mesh with translation [ 0.00250025  0.00500003 -0.00500049]


In [3]:
# Model definition: physics helpers and the per-step generative model.

def get_bottom_most_height(i, world_pose):
    """Lowest z-coordinate of object ``i``'s oriented bounding box at ``world_pose``.

    The box size comes from ``b.RENDERER.model_box_dims[i]``. The box is centered on the
    pose's translation and oriented by the pose's 3x3 rotation block.

    Returns:
        (bottom_z, center_to_bottom): the minimum z over the 8 box corners, and the
        vertical distance from the box center down to that lowest corner.
    """
    rot = world_pose[:3, :3]
    origin = world_pose[:3, 3]
    half_extent = b.RENDERER.model_box_dims[i] / 2

    # Sign pattern of each corner relative to the center (x varies fastest);
    # the first four rows are the -z face, the last four the +z face.
    corner_signs = jnp.array([
        [-1, -1, -1],
        [ 1, -1, -1],
        [-1,  1, -1],
        [ 1,  1, -1],
        [-1, -1,  1],
        [ 1, -1,  1],
        [-1,  1,  1],
        [ 1,  1,  1],
    ])
    box_corners = corner_signs * half_extent

    # Rotate each local corner, then shift it by the box center.
    corners_world = jnp.stack([origin + rot @ c for c in box_corners])

    lowest_z = jnp.min(corners_world[:, 2])
    return lowest_z, origin[2] - lowest_z

def update_vel_friction(vel_world, friction):
    """Damp the horizontal (x, y) translation of a 4x4 velocity transform.

    Each of ``vel_world[0, 3]`` and ``vel_world[1, 3]`` is reduced by the fraction
    ``friction``. A component whose magnitude ends up at or below 5e-4 is snapped
    to exactly 0. Every other entry of ``vel_world`` is returned unchanged.
    """
    planar = vel_world[:2, 3]
    damped = planar - friction * planar
    # Snap tiny residual motion to a full stop.
    damped = jnp.where(jnp.abs(damped) <= 5e-4, 0.0, damped)
    return vel_world.at[:2, 3].set(damped)

# NOTE: the model has to be recompiled for each different number of objects; that is acceptable for now.
def determine_next_pose(all_poses, t, i, friction, gravity, dt=1. / 20):
    """Predict object ``i``'s camera-frame pose at step ``t`` with a constant-velocity model.

    Only translation is propagated; the rotation of ``all_poses[t-1]`` is kept as-is.

    Args:
        all_poses: stacked per-step 4x4 camera-frame poses of this object; entries
            ``t-2`` and ``t-1`` are read.
        t: current step index (may be a traced value).
        i: object index, used to look up its bounding box for the ground check.
        friction: fraction of horizontal (world x/y) velocity removed per step while the
            object's lowest bbox corner is within 0.05 of z = 0 (see update_vel_friction).
        gravity: gravitational acceleration; ``gravity * dt`` is subtracted from the
            world-z velocity every step.
        dt: time step used in the gravity update (default is the previous hard-coded 1/20).

    Returns:
        The predicted 4x4 camera-frame pose.
    """
    prev_pose = all_poses[t - 1]
    # Per-step displacement in the camera frame. Rotations are ignored, so this is the plain
    # translation difference (P_{t-2}^{-1} P_{t-1} would express it in the object's local
    # frame, which is wrong once it is added to a camera-frame translation below).
    disp_cam = prev_pose[:3, 3] - all_poses[t - 2][:3, 3]

    # Rotate -- but do NOT translate -- the displacement into the world frame. Multiplying by
    # the full cam_pose adds the camera position to the velocity, so friction and the
    # near-zero cutoff would act on the camera offset instead of the object's motion.
    cam_rot = cam_pose[:3, :3]
    vel_world = jnp.eye(4).at[:3, 3].set(cam_rot @ disp_cam)
    # NOTE(review): if the velocity is a per-step displacement, the per-step gravity
    # increment would normally be g * dt**2; g * dt is kept to preserve current tuning.
    vel_world = vel_world.at[2, 3].add(-gravity * dt)

    # Apply friction only when the object's lowest bbox corner is (nearly) on the ground.
    prev_pose_world = cam_pose @ prev_pose
    vel_world = jax.lax.cond(
        jnp.less_equal(get_bottom_most_height(i, prev_pose_world)[0], 0.05),
        update_vel_friction,
        lambda *_: vel_world,
        vel_world, friction,
    )

    # Rotate the world-frame velocity back into the camera frame (linear part only).
    vel_cam = inverse_cam_pose[:3, :3] @ vel_world[:3, 3]
    # TODO: ground collision (clamping the box bottom to z >= 0) is not applied here.
    return prev_pose.at[:3, 3].add(vel_cam)  # translation only, no rotation update

@genjax.gen
def mcs_single_object(prev_state, t_inits, t_fulls, init_poses, pose_update_params, dynamic_params, variance, outlier_prob):
    """
    One time step of the MCS object-tracking HMM (used as the Unfold kernel below).

    Args:
        prev_state: tuple (rendered_image, rendered_image_obj, poses, all_poses,
            active_states, t) returned by the previous step; the two images are ignored.
        t_inits: per-object step at which the pose is overwritten with init_poses[i].
        t_fulls: per-object step; once t >= t_fulls[i] + 2 the pose prior is centered on
            the physics prediction from determine_next_pose instead of the previous pose.
        init_poses: per-object 4x4 initial poses.
        pose_update_params: extra arguments forwarded to b.gaussian_vmf_pose.
        dynamic_params: (friction, gravity), forwarded to determine_next_pose.
        variance, outlier_prob: forwarded to b.image_likelihood.

    Returns:
        (rendered_image, rendered_image_obj, poses, all_poses, active_states, t + 1).
        active_states is currently passed through unchanged.

    Random choices: "pose" (per object) and "depth" (the observed image).
    """

    (_, _, poses, all_poses, active_states, t) = prev_state
    friction, gravity = dynamic_params
    num_objects = poses.shape[0]

    # for each object
    for i in range(num_objects):
        # NOTE(review): every object's pose is traced at the same address "pose", so with more
        # than one object the addresses presumably collide. The inference code addresses
        # "pose" directly, so this is only safe for single-object scenes.

        # Prior mean: physics prediction once enough history exists, else the previous pose.
        poses = poses.at[i].set(
            jax.lax.cond(
                jnp.greater_equal(t,t_fulls[i]+2),
                determine_next_pose,
                lambda *_:poses[i],
                *(all_poses[:,i,...], t, i, friction, gravity)
            )
        )
        updated_pose = b.gaussian_vmf_pose(poses[i], *pose_update_params)  @ "pose"
        poses = poses.at[i].set(updated_pose)
        # # activate object when t == t_init for that object and initialize the correct pose
        # active_states = active_states.at[i].set(jax.lax.cond(
        #     jnp.equal(t_inits[i],t), # doing t_init + 1 so in first time step, the pose is fixed 
        #     lambda:True, 
        #     lambda:active_states[i]))

        # At the object's init step, the sampled "pose" is discarded and replaced by init_poses[i]
        # (so at that step the traced choice and the rendered pose differ).
        poses = poses.at[i].set(jax.lax.cond(
            jnp.equal(t_inits[i],t), # init pose at the correct time step
            lambda:init_poses[i], 
            lambda:poses[i]))
        # jprint("t = {}, pose is {}",t, poses[0][:3,3])

    all_poses = all_poses.at[t].set(poses)
    # Render all objects (mesh ids 0..num_objects-1) and keep the first three channels.
    rendered_image_obj = b.RENDERER.render(
        poses, jnp.arange(num_objects))[...,:3]

    # NOTE: gt_images_bg is a global variable here as it consumes too much memory for the trace
    rendered_image = splice_image(rendered_image_obj, gt_images_bg[t])

    # Observation model; the returned sample is not used further.
    sampled_image = b.image_likelihood(rendered_image, variance, outlier_prob) @ "depth"

    return (rendered_image, rendered_image_obj, poses, all_poses, active_states, t+1)

In [4]:
def pose_update_v4(key, trace_, pose_grid, enumerator):
    """Score every candidate in ``pose_grid`` and move the trace to the best one.

    The choice is greedy (argmax of the enumerator's scores), not sampled.

    Returns:
        The values returned by ``enumerator.update_choices_with_weight`` for the chosen
        pose, followed by the chosen pose itself.
    """
    scores = enumerator.enumerate_choices_get_scores(trace_, key, pose_grid)
    best_pose = pose_grid[scores.argmax()]
    update_result = enumerator.update_choices_with_weight(trace_, key, best_pose)
    return (*update_result, best_pose)


pose_update_v4_jit = jax.jit(pose_update_v4, static_argnames=("enumerator",))


def c2f_pose_update_v4(key, trace_, gridding_schedule_stacked, enumerator, t, addr):
    """Coarse-to-fine pose enumeration for object 0 at step ``t``.

    The search starts from object 0's pose at step ``t-1``, read from the trace's return
    value. Each stage of ``gridding_schedule_stacked`` is applied around the current
    reference pose. Candidates whose bounding box intersects the floor box are replaced by
    the reference pose. The best-scoring candidate (via ``pose_update_v4_jit``) becomes the
    reference for the next, finer stage.

    Args:
        key: PRNG key forwarded to the enumerator.
        trace_: model trace to update.
        gridding_schedule_stacked: sequence of stacked (N, 4, 4) pose offsets, coarse to fine.
        enumerator: enumerator built with ``b.make_enumerator``.
        t: time step being updated.
        addr: choice address (currently unused; the reference is read from the retval).

    Returns:
        (weight, trace) from the last stage.

    Raises:
        ValueError: if ``gridding_schedule_stacked`` is empty.
    """
    if len(gridding_schedule_stacked) == 0:
        raise ValueError("gridding_schedule_stacked must contain at least one grid")
    # TODO: only object 0 is handled; retval poses have shape (num_objects, 4, 4) per step.
    reference = trace_.get_retval()[2][t - 1][0]
    # Floor modeled as a 100 x 100 x 20 box centered at z = -10 in the world frame
    # (top face at z = 0). Loop-invariant, so built once.
    floor_dims = (100, 100, 20)
    floor_pose = jnp.eye(4).at[:3, 3].set([0, 0, -10])
    object_dims = b.RENDERER.model_box_dims[0]
    for grid in gridding_schedule_stacked:
        candidates = jnp.einsum("ij,ajk->aik", reference, grid)
        # Reject candidates that would penetrate the floor (checked in world frame).
        valid = jnp.logical_not(are_bboxes_intersecting_many_jit(
            floor_dims,
            object_dims,
            floor_pose,
            jnp.einsum("ij,ajk->aik", cam_pose, candidates),
        ))
        # Invalid candidates fall back to the current reference pose.
        valid_grid = jnp.where(valid[:, None, None], candidates, reference[None, ...])
        weight, trace_, reference = pose_update_v4_jit(key, trace_, valid_grid, enumerator)
    return weight, trace_


# vmap over (key, trace_); schedule, enumerator, t and addr are shared across the batch.
# in_axes needs one entry per positional parameter (six); with five, vmap rejects every call.
c2f_pose_update_v4_vmap_jit = jax.jit(jax.vmap(c2f_pose_update_v4, in_axes=(0, 0, None, None, None, None)),
                                    static_argnames=("enumerator", "t", "addr"))

c2f_pose_update_v4_jit = jax.jit(c2f_pose_update_v4, static_argnames=("enumerator", "t", "addr"))

def make_new_keys(key, N_keys):
    """Split ``key`` into a carry key plus ``N_keys`` fresh subkeys.

    Returns:
        (carry_key, subkeys): ``subkeys`` is the stacked array of ``N_keys`` keys when
        ``N_keys > 1``, otherwise the single key itself (no leading batch axis).
    """
    carry_key, spawn_key = jax.random.split(key)
    subkeys = jax.random.split(spawn_key, N_keys)
    return carry_key, (subkeys if N_keys > 1 else subkeys[0])


def update_choice_map(gt_depths, constant_choices, t):
    """Build the choice map constraining step ``t`` of the unfolded model.

    ``gt_depths[t]`` is placed at address ``'depth'`` alongside ``constant_choices``. It gets
    a leading length-1 axis to match the single index ``[t]``.

    ``constant_choices`` is copied, never modified. Writing into it would let a
    module-level dict capture the latest observation and, when called inside
    ``jax.lax.scan``, a traced value that leaks out of the trace.
    """
    choices = {**constant_choices, 'depth': gt_depths[t][None, ...]}
    return genjax.index_choice_map([t], genjax.choice_map(choices))

def argdiffs_modelv6(trace, t):
    """
    Argdiffs specific to modelv6.

    The Unfold's first argument becomes ``t`` and is marked ``UnknownChange``. The
    initial state (argument 1) and every remaining argument are reused from ``trace``,
    with each leaf marked ``NoChange``.
    """
    def _unchanged(leaf):
        return Diff(leaf, NoChange)

    previous_args = trace.get_args()
    state_diff = jtu.tree_map(_unchanged, previous_args[1])
    rest_diffs = jtu.tree_map(_unchanged, previous_args[2:])
    return (Diff(t, UnknownChange), state_diff, *rest_diffs)

def proposal_choice_map(addresses, args, chm_args):
    """Choice map placing the proposed value ``args[0]`` at ``addresses[0]`` for time index ``chm_args[0]``.

    Used as an enumerator ``chm_builder``. The value gets a leading length-1 axis to match
    the single index in the index choice map.
    """
    target_addr = addresses[0]
    indices = jnp.array([chm_args[0]])
    proposed = jnp.expand_dims(args[0], axis=0)
    return genjax.index_choice_map(indices, genjax.choice_map({target_addr: proposed}))

In [5]:
def inference_approach_F3(model, gt, gridding_schedule, init_chm, model_args, init_state, key, constant_choices, T, addr):
    """
    Sequential importance sampling on the unfolded HMM model, with a coarse-to-fine
    3D pose enumeration proposal. Uses a single particle.

    Args:
        model: unfolded generative model (e.g. a ``genjax.UnfoldCombinator``).
        gt: observed images indexed by step; ``gt[t]`` constrains ``"depth"`` at step t.
        gridding_schedule: coarse-to-fine list of pose-offset grids for the proposal.
        init_chm: choice map constraining step 0.
        model_args: remaining model arguments, passed after the initial state.
        init_state: initial HMM state.
        key: PRNG key.
        constant_choices: extra choices merged into every step's choice map.
        T: number of steps (frames). Step 0 is handled by ``model.importance``; steps
            ``1 .. T-1`` by the scan.
        addr: address of the pose choice being proposed.

    Returns:
        (final_log_weight, particle, rendered), where ``rendered`` is the first element of
        the particle's return value.
    """
    key, init_key = make_new_keys(key, 1)

    # Functions for SIS/SMC.
    init_fn = model.importance
    update_fn = jax.jit(model.update)
    proposal_fn = c2f_pose_update_v4_jit

    # Step 0: importance-sample the first step (first model argument = 0).
    init_log_weight, init_particle = init_fn(init_key, init_chm, (0, init_state, *model_args))

    def smc_body(state, t):
        # Printed when scan traces this body, not once per step.
        print("jit compiling")
        key, log_weight, particle = state
        key, update_key = make_new_keys(key, 1)
        key, proposal_key = make_new_keys(key, 1)

        # Extend the model to step t; all other arguments are unchanged.
        argdiffs = argdiffs_modelv6(particle, t)

        # Enumerator for this step: its choice maps target index t at `addr`.
        enumerator = b.make_enumerator([(addr)],
                                        chm_builder = proposal_choice_map,
                                        argdiff_f=lambda x: argdiffs,
                                        chm_args = [t])

        # Condition the particle on the observation at step t.
        _, update_log_weight, updated_particle, _ = update_fn(
            update_key, particle, update_choice_map(gt, constant_choices, t), argdiffs)

        # Coarse-to-fine pose enumeration proposal.
        proposal_log_weight, new_particle = proposal_fn(
            proposal_key, updated_particle, gridding_schedule, enumerator, t, addr)

        new_log_weight = log_weight + proposal_log_weight + update_log_weight

        return (key, new_log_weight, new_particle), None

    # Steps 1 .. T-1 (step 0 was handled above). Scanning to T inclusive would also visit
    # t = T, one past the last frame: gt[T] is out of bounds (JAX silently clamps the index)
    # and that step presumably lies beyond an Unfold built with length T.
    (_, final_log_weight, particle), _ = jax.lax.scan(
        smc_body, (key, init_log_weight, init_particle), jnp.arange(1, T))
    print("SCAN finished")
    rendered = particle.get_retval()[0]
    return final_log_weight, particle, rendered

In [6]:
# Coarse-to-fine enumeration grid: three translation stages with widths shrinking
# 1 -> 0.2 -> 0.04, each using grid_nums (7, 7, 7).
# TODO: ADAPTIVE GRID SIZING
grid_widths = [1, 0.2, 0.04]
# alternative widths: [0.5, 0.1, 0.02]
grid_nums = [(7, 7, 7)] * 3
gridding_schedule_trans = make_schedule_translation_3d(grid_widths, grid_nums)
# Rotation stage: built, but currently not part of the schedule used below.
gridding_schedule_rot = [b.utils.make_rotation_grid_enumeration(10, 15, -jnp.pi/12, jnp.pi/12, jnp.pi/12)]
# To include rotations, append gridding_schedule_rot[0] as a fourth stage.
gridding_schedule = [gridding_schedule_trans[stage] for stage in range(3)]

# Setup for inference
T = gt_images.shape[0]  # number of frames
num_registered_objects = len(registered_objects)
# Initial HMM state: (rendered_image, rendered_image_obj, poses, all_poses, active_states, t).
# Poses start far away (z = 1e5), presumably so objects render out of view until initialized.
INIT_STATE = (
        gt_images[0],
        gt_images_obj[0],
        jnp.tile(jnp.eye(4).at[2,3].set(1e+5)[None,...],(num_registered_objects,1,1)),
        jnp.zeros((T,num_registered_objects,4,4)),
        jnp.zeros(num_registered_objects, dtype=bool),
        0
)
# Per-object fields are read from each object `r` itself (not from a loop variable left
# over from an earlier cell), so every object gets its own t_init / t_full / pose.
MODEL_ARGS = (
     jnp.array([r['t_init'] for r in registered_objects]),  # t_inits
     jnp.array([r['t_full'] for r in registered_objects]),  # t_fulls
     jnp.array([r['pose'] for r in registered_objects]),    # init_poses
     jnp.array([5e-0, 5e-1]),  # pose_update_params, forwarded to b.gaussian_vmf_pose
     (0.0, 9.81),              # dynamic_params: (friction, gravity)
     0.1,                      # variance for b.image_likelihood
     None                      # outlier_prob for b.image_likelihood
)
CONSTANT_CHOICES = {}

init_chm = update_choice_map(gt_images, CONSTANT_CHOICES, 0)
key = jax.random.PRNGKey(45675456)

model_unfold = genjax.UnfoldCombinator.new(mcs_single_object, T)  # unfold the one-step kernel over T steps
# The outer jit is disabled: the function is called directly (its update and proposal steps are jitted inside).
# inference_approach_F3_jit = jax.jit(inference_approach_F3, static_argnames=("T", "addr"))
inference_approach_F3_jit = inference_approach_F3

start = time.time()
lw, tr, rendered = inference_approach_F3_jit(model_unfold, gt_images, 
    gridding_schedule, init_chm, MODEL_ARGS, INIT_STATE, key, CONSTANT_CHOICES, T, "pose")
# NOTE: this throughput includes tracing/compilation time, not only inference.
print ("FPS:", rendered.shape[0] / (time.time() - start))

jit compiling
SCAN finished
FPS: 1.423121048574752


In [None]:
# Debug: run a single (non-unfolded) model step with the first observation constrained.
chm = genjax.choice_map({
    'depth' : gt_images[0]
})

# NOTE(review): rebinds the global `tr`, replacing the inference trace used by later cells.
_, tr = mcs_single_object.importance(key, chm, (INIT_STATE, *MODEL_ARGS))

In [None]:
# Debug: score 1000 candidate poses (all identity here) at address "pose" on the single-step trace.
# GenerativeFunction.update expects (key, trace, choice_map, argdiffs), not a raw pose array,
# so the candidates are scored through the enumerator instead.
grid_points = jnp.tile(jnp.eye(4)[None,...],(1000,1,1))
enumerator = b.make_enumerator(["pose"])
scores = enumerator.enumerate_choices_get_scores(tr, key, grid_points)
scores

In [None]:
# Rerun the inference loop with T=1 (short horizon), presumably as a quick compile/trace check.
# NOTE(review): overwrites lw, tr and rendered from the full run.
lw, tr, rendered = inference_approach_F3_jit(model_unfold, gt_images, 
    gridding_schedule, init_chm, MODEL_ARGS, INIT_STATE, key, CONSTANT_CHOICES, 1, "pose")


In [7]:
# Side-by-side video per frame: observed, full render, object-only render
# (channel 2 of each image, via b.get_depth_image, upscaled 6x).
rendered_frames = tr.get_retval()[0]
rendered_obj_frames = tr.get_retval()[1]
images = []
for t in range(T):
    panels = [b.scale_image(b.get_depth_image(img[..., 2]), 6)
              for img in (gt_images[t], rendered_frames[t], rendered_obj_frames[t])]
    images.append(b.multi_panel(panels, labels=['gt/sampled', 'rendered', 'rendered_obj']))
display_video(images, framerate=30)

In [None]:
# all_poses from the last step's return value, object 0 only: presumably one 4x4 camera-frame pose per frame.
poses = tr.get_retval()[3][-1][:,0,...]

In [None]:
# Object 0's pose at frame 150, mapped to the world frame.
cam_pose @ poses[150]

In [None]:
# Bounding-box dimensions per registered mesh (indexed by object id).
b.RENDERER.model_box_dims

In [None]:
# Does object 0's bbox at frame 0 (world frame) intersect the floor box (100 x 100 x 20 centered at z = -10)?
are_bboxes_intersecting_jit(b.RENDERER.model_box_dims[0],
                            (100,100,20),
                            cam_pose @ poses[0],
                            jnp.eye(4).at[:3,3].set([0,0,-10]))

In [None]:
# World-frame poses of object 0 for every frame, and a mask of those that do not intersect the floor box.
xx = jnp.einsum("ij,ajk->aik",cam_pose,poses)
valid = jnp.logical_not(are_bboxes_intersecting_many_jit(
                            (100,100,20),
                            b.RENDERER.model_box_dims[0],
                            jnp.eye(4).at[:3,3].set([0,0,-10]),
                            xx
                            ))

In [None]:
# Replace floor-penetrating world poses with an all-ones placeholder; show frame 0's result.
jnp.where(valid[:,None,None], xx,jnp.ones((4,4))[None,...])[0]

In [None]:
# Per-frame validity mask (True = no floor intersection).
valid

In [None]:
# Debug: physics prediction for frame 103 of object 0 from the poses stored at step 102,
# shown in the world frame. determine_next_pose also needs (friction, gravity); they are
# taken from the model's dynamic_params (MODEL_ARGS[4]).
all_poses = tr.get_retval()[3][102][:,0,...]
cam_pose @ determine_next_pose(all_poses, 103, 0, *MODEL_ARGS[4])

In [None]:
# Lowest bbox z and center-to-bottom distance for object 0 at frame 135 (world frame).
get_bottom_most_height(0,cam_pose @ poses[135])

In [None]:
# World-frame y translation of object 0 for frames 132..239 (132 is presumably the object's t_init).
[(cam_pose @ poses[i])[1,3] for i in range(132,240)]

In [None]:
# Inspect one frame: observed vs. full render vs. object-only render.
t = 124
b.multi_panel([
                b.scale_image(b.get_depth_image(gt_images[t][...,2]),6),
                b.scale_image(b.get_depth_image(tr.get_retval()[0][t][...,2]),6),
                b.scale_image(b.get_depth_image(tr.get_retval()[1][t][...,2]),6)
                ],labels = ['gt/sampled', 'rendered', 'rendered_obj'])

In [None]:
# Video comparing the rendered frames against the observations.
video_comparison_from_images(tr.get_retval()[0], gt_images)

In [None]:
# Score of the current trace `tr`.
tr.score

In [None]:
# Registered object metadata (includes 'mesh', 't_init', 't_full', 'pose').
registered_objects

In [None]:
# vmf param tuner: compare gaussian_vmf_pose log-densities around the identity pose for
# no offset, a (1, 1, 1) translation, and a rotation by R_zyx.
# NOTE(review): R_zyx is not defined in this notebook; presumably it comes from a star import — confirm.
params = (1e+1,1e-1)
display(b.gaussian_vmf_pose.logpdf(jnp.eye(4), jnp.eye(4).at[:3,3].set([0,0,0]), *params))
display(b.gaussian_vmf_pose.logpdf(jnp.eye(4), jnp.eye(4).at[:3,3].set([1,1,1]), *params))
b.gaussian_vmf_pose.logpdf(jnp.eye(4), jnp.eye(4).at[:3,:3].set(R_zyx), *params)

In [None]:

# Log-likelihood of frame 5 given frame 4 as the "rendered" image (variance 0.1, outlier_prob None).
b.image_likelihood.logpdf(gt_images[5],gt_images[4], 0.1, None)