In [1]:
# imports

import sys
import jax
import os
import time
import pickle
import genjax
import bayes3d as b
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from viz import *
from utils import *
from mcs.mcs_utils import *
from PIL import Image
from io import BytesIO
import bayes3d.transforms_3d as t3d
from jax.debug import print as jprint
from tqdm import tqdm
from dataclasses import dataclass
from genjax.generative_functions.distributions import ExactDensity
import jax.tree_util as jtu
from genjax._src.core.transforms.incremental import NoChange, UnknownChange, Diff
console = genjax.pretty()
%matplotlib inline
import pybullet as p
import pybullet_data

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


pybullet build time: Nov 28 2023 23:52:03


In [2]:
# pybullet helper/util functions
def object_pose_in_camera_frame(object_id, view_matrix):
    """Return the 4x4 pose of a PyBullet body expressed in the camera frame.

    The body's world-frame base pose is left-multiplied by the world->camera
    transform decoded from ``view_matrix``. Rows 1 and 2 (the camera's y and z
    axes) are then negated — presumably to convert from PyBullet's OpenGL-style
    camera axes to the renderer's convention; TODO confirm.

    Args:
        object_id: PyBullet body unique id.
        view_matrix: flat 16-element view matrix as produced by
            ``p.computeViewMatrix*`` (decoded here as column-major).

    Returns:
        np.ndarray of shape (4, 4): object pose in the camera frame.
    """
    world_pos, world_quat = p.getBasePositionAndOrientation(object_id)
    rotation = np.reshape(p.getMatrixFromQuaternion(world_quat), (3, 3))

    object_in_world = np.eye(4)
    object_in_world[:3, :3] = rotation
    object_in_world[:3, 3] = world_pos

    # The flat view matrix is column-major, hence the transpose after reshape.
    world_to_cam = np.array(view_matrix).reshape([4, 4]).T
    object_in_cam = world_to_cam @ object_in_world
    # Flip the camera's y and z axes.
    object_in_cam[1:3] *= -1
    return object_in_cam

def cam_pose_from_view_matrix(view_matrix):
    """Decode a flat PyBullet view matrix into a camera-to-world pose.

    Args:
        view_matrix: 16-element world->camera matrix in column-major order
            (as returned by ``p.computeViewMatrix*``).

    Returns:
        np.ndarray (4, 4): camera pose in the world frame, with the camera's
        y and z axes flipped relative to the view matrix's camera frame.
    """
    # np.array(...) copies, so the in-place flip below never touches the input.
    # Column-major decode: ravel, then refill the 4x4 in Fortran order.
    world2cam = np.array(view_matrix).ravel().reshape((4, 4), order="F")
    world2cam[1:3, :] = -world2cam[1:3, :]
    return np.linalg.inv(world2cam)

def view_matrix_from_cam_pose(cam_pose):
    """Encode a camera-to-world pose as a flat PyBullet view matrix.

    This is the inverse of ``cam_pose_from_view_matrix``: invert the pose, flip
    the camera's y and z axes back, and emit the result column-major.

    Args:
        cam_pose: 4x4 camera pose in the world frame.

    Returns:
        tuple of 16 numbers in column-major order.
    """
    # inv() returns a fresh array, so negating rows in place is safe.
    world2cam = np.linalg.inv(cam_pose)
    world2cam[1:3, :] = -world2cam[1:3, :]
    return tuple(world2cam.flatten(order="F"))

In [3]:
# pybullet simulation code

# Initialize the PyBullet physics simulation.
# If an earlier run of this cell failed before p.disconnect(), a stale client
# would still be connected; drop it so re-running the cell starts clean.
if p.isConnected():
    p.disconnect()
p.connect(p.DIRECT)
p.setAdditionalSearchPath(pybullet_data.getDataPath())
# initialize list to store object_ids
object_ids = []

####################### PARAMETERS (NOTE to Eric: modify params here) ###################################

# Sim params
sim_time = 1.5 # simulation time (seconds)
fps = 240 # physics engine FPS --> must be a multiple of fps_data
fps_data = 60 # ground truth saving FPS
friction_coefficient = 0.0  # floor lateral friction; adjust as needed
gravity = -9.81

# camera intrinsics params
width = 480
height = 360
field_of_view = 60  # degrees, measured along the image height (see focal below)
near = 0.1
far = 100

# camera_pose (define it explicitly or leave it as None)
# If cam_pose is None, it is derived from computeViewMatrixFromYawPitchRoll below
cam_pose = None
# cam_pose = np.array([
#     [1,0,0,0],
#     [0,1,0,0],
#     [0,0,1,0],
#     [0,0,0,1]
# ])

## NOTE TO ERIC: This is where you make objects and give it an initial condition
# Create object (custom code)
box_mass = 1
box_position = [-3.25, 0, 0.501]
box_start_velocity = [6, 0, 0]
mesh_scale = [1,1,1]
box_shape = p.createCollisionShape(shapeType=p.GEOM_MESH, fileName=os.path.join(b.utils.get_assets_dir(),"sample_objs/cube.obj"), meshScale=mesh_scale)
box_id = p.createMultiBody(box_mass, box_shape, basePosition=box_position)
p.resetBaseVelocity(box_id, box_start_velocity, [0, 0, 0]) # rot vel
p.changeDynamics(box_id, -1, restitution = 1)
object_ids.append(box_id)
####################### END OF PARAMETERS ###################################

# The saving stride must be a whole number of physics steps (and non-zero).
assert fps % fps_data == 0, f"fps ({fps}) must be a multiple of fps_data ({fps_data})"

# Set up the simulation environment (floor, gravity and friction)
p.setGravity(0, 0, gravity)
floor_id = p.loadURDF("plane.urdf")
p.changeDynamics(floor_id, -1, lateralFriction=friction_coefficient)

# Arrays for serialization
rgbs = []
depths = []
gt_poses = []

if cam_pose is None:
    view_matrix = p.computeViewMatrixFromYawPitchRoll(cameraTargetPosition=[0, 0, 0], distance=6, yaw=0, pitch=-5, roll=0,
                                                        upAxisIndex=2)
    cam_pose = cam_pose_from_view_matrix(view_matrix)
else:
    view_matrix = view_matrix_from_cam_pose(cam_pose)
proj_matrix = p.computeProjectionMatrixFOV(fov=field_of_view, aspect=float(width)/height, nearVal=near, farVal=far)
# manually save pose of object that remains constant (like floor, occluders etc.)
floor_cam_pose = object_pose_in_camera_frame(floor_id, view_matrix)

# Step through the simulation.
# round() guards against float error in the product (e.g. int(100 * 0.29) == 28).
num_timesteps = int(round(fps * sim_time))
save_data_frequency = int(fps // fps_data)
p.setTimeStep(1.0/fps)
for step in range(num_timesteps):
    p.stepSimulation()
    # Record every `save_data_frequency`-th step. Frames are captured after
    # stepping, so the first saved frame is one physics step in.
    if step % save_data_frequency != 0:
        continue

    (_, _, px, d, _) = p.getCameraImage(width=width, height=height, viewMatrix=view_matrix,
                                        projectionMatrix=proj_matrix, renderer=p.ER_BULLET_HARDWARE_OPENGL)
    rgb_array = np.reshape(np.array(px, dtype=np.uint8), (height, width, 4))[:, :, :3] # remove alpha channel
    # depths.append(np.array(d))
    rgbs.append(rgb_array)

    # One camera-frame pose per tracked object. The floor is static; append
    # floor_cam_pose here if it should be part of the ground truth.
    gt_poses.append([object_pose_in_camera_frame(obj_id, view_matrix) for obj_id in object_ids])

gt_poses = jnp.array(gt_poses)
# depths = jnp.array(depths)
rgbs = jnp.array(rgbs)

p.disconnect()

# FOV is based on height and it scales proportionately to width
focal = (height/2) / np.tan(np.deg2rad(field_of_view) / 2.0)
intrinsics = b.Intrinsics(
    height,
    width,
    focal,
    focal,
    width/2,
    height/2,
    near,
    far
)

In [4]:
# view GT rgb video
# Plays back the RGB frames collected by the simulation cell (one frame saved
# every `save_data_frequency` physics steps, alpha channel already dropped).
# display_video comes from the star-imported viz/utils helpers — presumably
# renders the frame sequence inline; TODO confirm.
display_video(rgbs)

In [5]:
# Shrink the camera so rendering fits in GPU memory.
SCALE = 0.2
scaled_intrinsics = b.scale_camera_parameters(intrinsics, SCALE)
b.setup_renderer(scaled_intrinsics)

# Meshes must be registered before rendering; the cube is the first (and only)
# mesh added, presumably giving it id 0 in render_many below.
cube_mesh_path = os.path.join(b.utils.get_assets_dir(), "sample_objs/cube.obj")
b.RENDERER.add_mesh_from_file(cube_mesh_path, scaling_factor = 1)

# Render the ground-truth poses (the floor is not added to the renderer here).
gt_images = b.RENDERER.render_many(gt_poses, jnp.array([0]))[..., :3]
T = gt_images.shape[0]

# Show channel 2 of every frame as a depth image, upscaled back to full size.
upscale = int(1.0 / SCALE)
vid = [b.scale_image(b.get_depth_image(frame[..., 2]), upscale) for frame in gt_images]
display_video(vid)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (96, 96, 1024)


In [6]:
# inference helper/util functions

def pose_update_v5(key, trace_, pose_grid, enumerator):
    """Score every candidate pose in ``pose_grid`` and commit the best one.

    Args:
        key: JAX PRNG key forwarded to both enumerator calls.
        trace_: trace to update.
        pose_grid: stacked candidate poses; axis 0 indexes candidates.
        enumerator: object providing ``enumerate_choices_get_scores`` and
            ``update_choices_with_weight`` (built with ``b.make_enumerator``
            in this notebook). Static under ``pose_update_v5_jit``.

    Returns:
        ``(*enumerator.update_choices_with_weight(trace_, key, best_pose), best_pose)``;
        callers unpack this as ``(weight, updated_trace, best_pose)``.
    """
    # Score in chunks of at most 400 candidates to bound GPU memory usage.
    num_splits = (pose_grid.shape[0] // 400) + 1
    chunk_scores = [
        enumerator.enumerate_choices_get_scores(trace_, key, chunk)
        for chunk in jnp.array_split(pose_grid, num_splits)
    ]
    # Concatenate once instead of re-allocating with hstack after every chunk.
    all_weights = jnp.concatenate(chunk_scores)
    # Greedy (MAP) choice; use jax.random.categorical(key, all_weights) to sample instead.
    sampled_idx = all_weights.argmax()
    best_pose = pose_grid[sampled_idx]
    return *enumerator.update_choices_with_weight(trace_, key, best_pose), best_pose


pose_update_v5_jit = jax.jit(pose_update_v5, static_argnames=("enumerator",))


def c2f_pose_update_v5(key, trace_, reference, gridding_schedule, enumerator, obj_id):
    """Coarse-to-fine grid search over one object's pose.

    Each stage of ``gridding_schedule`` is a stack of relative transforms. The
    stack is composed with the current best pose (``reference``) to form a
    candidate grid. ``pose_update_v5_jit`` then commits the best candidate,
    which becomes the reference for the next, finer stage.

    Args:
        key: JAX PRNG key (the same key is used for every stage).
        trace_: trace to update.
        reference: starting pose matrix for this object.
        gridding_schedule: sequence of candidate-transform stacks, coarse to fine.
        enumerator: enumerator for this object's pose address (static under jit).
        obj_id: not used in the body; accepted (static under jit) so callers
            can pass the object index.

    Returns:
        ``(weight, trace_)``: the weight reported by the *last* stage and the
        updated trace.

    Raises:
        ValueError: if ``gridding_schedule`` is empty (there is no weight to return).
    """
    if len(gridding_schedule) == 0:
        raise ValueError("gridding_schedule must contain at least one stage")
    # TODO: gibbs sampling across objects
    for stage_grid in gridding_schedule:
        # reference @ each grid transform -> candidate poses for this stage
        candidate_poses = jnp.einsum("ij,ajk->aik", reference, stage_grid)
        weight, trace_, reference = pose_update_v5_jit(key, trace_, candidate_poses, enumerator)
    return weight, trace_

# Batched over (key, trace_). in_axes needs one entry per positional parameter
# (six), otherwise jax.vmap rejects the call.
c2f_pose_update_v5_vmap_jit = jax.jit(jax.vmap(c2f_pose_update_v5, in_axes=(0, 0, None, None, None, None)),
                                    static_argnames=("enumerator", "obj_id"))

c2f_pose_update_v5_jit = jax.jit(c2f_pose_update_v5, static_argnames=("enumerator", "obj_id"))

def make_new_keys(key, N_keys):
    """Advance ``key`` and derive ``N_keys`` fresh subkeys from it.

    Returns:
        ``(next_key, subkeys)``, where ``subkeys`` has leading dimension ``N_keys``.
    """
    next_key, seed_key = jax.random.split(key)
    return next_key, jax.random.split(seed_key, N_keys)

def update_choice_map_no_unfold(gt_depths, constant_choices, t):
    """Build the observation choice map for time step ``t``.

    Args:
        gt_depths: observations indexed by time along axis 0.
        constant_choices: choices shared by every time step. Not modified.
        t: time index (a traced value when called inside ``lax.scan``).

    Returns:
        genjax choice map containing every entry of ``constant_choices`` plus
        ``'depth'`` set to ``gt_depths[t]``.
    """
    # Build a fresh dict rather than assigning into ``constant_choices``. An
    # in-place write mutates the caller's shared dict (e.g. CONSTANT_CHOICES)
    # and, under lax.scan, leaves a leaked tracer stored in it.
    choices = {**constant_choices, 'depth': gt_depths[t]}
    return genjax.choice_map(choices)


def argdiffs_modelv7(trace):
    """
    Argdiffs for the non-unfolded single-step model (``simple_model`` here).

    The first argument (the previous state) is marked ``UnknownChange``. Every
    remaining argument (the model args) is marked ``NoChange``.
    """
    first, *rest = trace.get_args()
    prev_state_diff = jtu.tree_map(lambda v: Diff(v, UnknownChange), first)
    unchanged_diffs = jtu.tree_map(lambda v: Diff(v, NoChange), tuple(rest))
    return (prev_state_diff, *unchanged_diffs)


def proposal_choice_map_no_unfold(addresses, args, chm_args):
    """Choice map that constrains the first enumerated address to ``args[0]``.

    ``chm_args`` is accepted to fit the enumerator's ``chm_builder`` calling
    convention (presumably; see ``b.make_enumerator``) and is not used.
    """
    target_address = addresses[0]  # custom defined
    constrained_value = args[0]
    return genjax.choice_map({target_address: constrained_value})

In [7]:
def inference_approach(model, gt, gridding_schedules, model_args, init_state, key, constant_choices, T, addr, n_particles, verbose=True):
    """
    Sequential Importance Sampling on the non-unfolded HMM model
    with 3D pose enumeration proposal with multiple particles.

    Args:
        model: genjax generative function called as ``model(prev_state, *model_args)``.
        gt: observations indexed by time along axis 0.
        gridding_schedules: one coarse-to-fine schedule per object, indexed by object id.
        model_args: remaining model arguments, shared by all particles.
        init_state: ``(None, poses, t)``; ``poses`` has one entry per object.
        key: JAX PRNG key.
        constant_choices: choices constrained at every time step (not modified).
        T: number of time steps.
        addr: address prefix; object ``i`` is enumerated at ``f"{addr}_{i}"``.
        n_particles: number of particles.
        verbose: if True (default), print when the scan body is traced, the
            step index at run time, and a message when the scan returns.

    Returns:
        ``(final_log_weight, rendered, inferred_poses, particles)``: per-particle
        log weights after the last step, and the rendered images, poses and
        traces stacked by ``lax.scan`` (time on axis 0, particle on axis 1).
    """
    num_objects = init_state[1].shape[0]

    def get_next_state(particle):
        # Drop the rendered image; keep (poses, t) to seed the next step.
        return (None, *particle.get_retval()[1:])
    get_next_state_vmap = jax.vmap(get_next_state, in_axes = (0,))

    # broadcast init_state to number of particles
    init_states = jax.vmap(lambda _: init_state, in_axes=(0,))(jnp.arange(n_particles))

    # define functions for SIS/SMC
    init_fn = jax.jit(jax.vmap(model.importance, in_axes=(0,None,0)))
    proposal_fn = c2f_pose_update_v5_jit

    def smc_body(carry, t):
        if verbose:
            print("jit compiling")  # runs once, while scan traces this body
            jprint("t = {}", t)  # runs at every step

        key, log_weights, states = carry
        key, importance_keys = make_new_keys(key, n_particles)
        key, proposal_key = jax.random.split(key)

        full_args = jax.vmap(lambda x, y: (x, *y), in_axes=(0, None))(states, model_args)

        # initialize particles from last time step's state, constrained on gt[t]
        importance_log_weights, particles = init_fn(importance_keys, update_choice_map_no_unfold(gt, constant_choices, t), full_args)

        # propose good poses based on grid enum proposal, one particle at a time
        def proposer(carry, particle):
            key, idx = carry
            proposal_log_weight = 0
            # argdiff and enumerator
            argdiffs = argdiffs_modelv7(particle)
            enumerators = [b.make_enumerator([(addr + f'_{i}')],
                                        chm_builder = proposal_choice_map_no_unfold,
                                        argdiff_f=lambda x: argdiffs
                                        ) for i in range(num_objects)]
            for obj_id in range(num_objects):
                key, obj_key = jax.random.split(key)
                w, particle = proposal_fn(obj_key, particle, states[1][idx][obj_id], gridding_schedules[obj_id], enumerators[obj_id], obj_id)
                proposal_log_weight += w
            # Carry the advanced `key`, not `obj_key`. obj_key was already
            # consumed by proposal_fn, so splitting it for the next particle
            # would reuse randomness (and it is unbound when num_objects == 0).
            return (key, idx + 1), (proposal_log_weight, particle)
        _, (proposal_log_weights, proposed_particles) = jax.lax.scan(proposer, (proposal_key, 0), particles)

        # get weights of particles
        new_log_weight = log_weights + importance_log_weights + proposal_log_weights
        next_states = get_next_state_vmap(proposed_particles)

        return (key, new_log_weight, next_states), proposed_particles

    (_, final_log_weight, _), particles = jax.lax.scan(
        smc_body, (key, jnp.zeros(n_particles), init_states), jnp.arange(0, T))
    rendered = particles.get_retval()[0]
    inferred_poses = particles.get_retval()[1]
    if verbose:
        print("SCAN finished")
    return final_log_weight, rendered, inferred_poses, particles

In [8]:
@genjax.gen
def simple_model(prev_state, pose_update_params, variance, outlier_prob):
    """
    Simple Multiple Object Model HMM

    One transition step: perturb every object's pose, render the scene, and
    score the rendered image under the image likelihood.

    Args:
        prev_state: ``(_, poses, t)`` from the previous step; the first entry
            (the previous rendered image) is ignored.
        pose_update_params: unpacked after the previous pose into
            ``b.gaussian_vmf_pose``. In this notebook it is a length-2 array,
            presumably (position spread, rotation concentration); TODO confirm.
        variance: forwarded to ``b.image_likelihood``.
        outlier_prob: forwarded to ``b.image_likelihood``.

    Returns:
        ``(rendered_image, poses, t + 1)``, the same layout as ``prev_state``.
    """

    (_, poses, t) = prev_state

    num_objects = poses.shape[0]
    # for each object: sample a new pose around the previous one at address
    # "pose_{i}" (the address the enumerative proposal in inference_approach targets)
    for i in range(num_objects):        
        updated_pose = b.gaussian_vmf_pose(poses[i], *pose_update_params)  @ f"pose_{i}"
        poses = poses.at[i].set(updated_pose)

    # Render all objects with mesh ids 0..num_objects-1 (assumed to match the
    # order meshes were added to b.RENDERER) and keep the first 3 channels.
    rendered_image = b.RENDERER.render(
        poses, jnp.arange(num_objects))[...,:3]

    # Observation site: constrained to the observed image through the "depth" address.
    sampled_image = b.image_likelihood(rendered_image, variance, outlier_prob) @ "depth"

    return (rendered_image, poses, t+1)

In [9]:
# Setup for inference

# gridding schedule for EACH object (translation only, coarse -> fine)
grid_widths = [1, 0.5, 0.25]
grid_nums = [(5, 5, 5), (5, 5, 5), (5, 5, 5)]
# gt_poses is (T, num_objects, 4, 4); inference_approach indexes
# gridding_schedules[obj_id], so build one schedule per tracked object.
num_objects = gt_poses.shape[1]
gridding_schedules = [make_schedule_translation_3d(grid_widths, grid_nums) for _ in range(num_objects)]
# note to eric: this does not include code for rotation proposals. You can look at demo.py for that

T = gt_images.shape[0]
# assume you know the first time step pose
INIT_STATE = (
        None,
        gt_poses[0],
        0
)
MODEL_ARGS = (
     jnp.array([5e-0, 5e-1]),  # pose_update_params
     0.1,                       # variance
     None                       # outlier_prob
)
CONSTANT_CHOICES = {}

# Fixed seed so Restart & Run All reproduces the same particles.
# For a fresh random run use e.g. SEED = np.random.randint(0, 56567567).
SEED = 0
key = jax.random.PRNGKey(SEED)

# Note to eric: if you dont need particles, just set it to 1
n_particles = 10

start = time.time()
lw, rendered, inferred_poses, trace = inference_approach(simple_model, gt_images,
    gridding_schedules, MODEL_ARGS, INIT_STATE, key, CONSTANT_CHOICES, T, "pose", n_particles)
# JAX dispatches asynchronously; wait for the results so the timing is real.
# (The measured time still includes JIT compilation.)
rendered.block_until_ready()
print("FPS:", rendered.shape[0] / (time.time() - start))

jit compiling
t = 0
t = 1
t = 2
t = 3
t = 4
t = 5
t = 6
t = 7
t = 8
t = 9
t = 10
t = 11
t = 12
t = 13
t = 14
t = 15
t = 16
t = 17
t = 18
t = 19
t = 20
t = 21
t = 22
t = 23
t = 24
t = 25
t = 26
t = 27
t = 28
t = 29
t = 30
t = 31
t = 32
t = 33
t = 34
t = 35
t = 36
t = 37
t = 38
t = 39
t = 40
t = 41
t = 42
t = 43
t = 44
t = 45
t = 46
t = 47
t = 48
t = 49
t = 50
t = 51
t = 52
t = 53
t = 54
t = 55
t = 56
t = 57
t = 58
t = 59
t = 60
t = 61
t = 62
t = 63
t = 64
t = 65
t = 66
t = 67
t = 68
t = 69
t = 70
t = 71
t = 72
t = 73
t = 74
t = 75
t = 76
t = 77
t = 78
t = 79
t = 80
t = 81
t = 82
t = 83
t = 84
t = 85
t = 86
t = 87
t = 88
t = 89
SCAN finished
FPS: 4.046250697881029


In [10]:
# Reconstruction for a single particle (particle 0 by default).
# Note: building one multi-panel image per frame can be slow.
particle_id = 0
upscale = int(1. / SCALE)
images = [
    b.multi_panel(
        [
            b.scale_image(b.get_depth_image(gt_images[t, ..., 2]), upscale),
            b.scale_image(b.get_depth_image(rendered[t, particle_id, ..., 2]), upscale),
        ],
        # Label follows particle_id so it stays correct when a different particle is chosen.
        labels=['gt/observed', f'inferred (particle {particle_id})'],
    )
    for t in tqdm(range(T))
]
display_video(images, framerate=30)

100%|██████████| 90/90 [00:01<00:00, 52.01it/s]


In [11]:
# Visualize all particles as dots, shown side by side with the ground truth.
# Note: this is even slower than the single-particle reconstruction.
p_images = get_particle_images(intrinsics, inferred_poses, T = T)
# Optional: overlay particles on the GT depth instead of a side-by-side panel:
# blended_images = [b.overlay_image(p_images[i],b.get_depth_image(gt_images[i][...,2])) for i in range(len(p_images))]
gt_upscale = int(1. / SCALE)
images = [
    b.multi_panel(
        [
            b.scale_image(b.get_depth_image(gt_images[t, ..., 2]), gt_upscale),
            b.scale_image(p_images[t], 1),
        ],
        labels=['gt/observed', 'particles'],
    )
    for t in tqdm(range(T))
]
display_video(images, framerate=30)

[Open3D INFO] EGL headless mode enabled.
FEngine (64 bits) created at 0x5587ecda2240 (threading is enabled)
EGL(1.5)
OpenGL(4.1)


100%|██████████| 90/90 [00:02<00:00, 30.78it/s]
100%|██████████| 90/90 [00:01<00:00, 86.91it/s]
