In [1]:
import sys
import jax
import genjax
import bayes3d as b
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
sys.path.append("../")
from viz import *
from mcs_utils import *
from PIL import Image
import bayes3d.transforms_3d as t3d
from jax.debug import print as jprint
from tqdm import tqdm
# Presumably enables genjax's rich pretty-printing for displayed objects (e.g. traces) — confirm.
console = genjax.pretty()

In [2]:
# Load the MCS scene, preprocess it at reduced resolution, and register every
# extracted object mesh with the renderer. Ends by showing the observed frames.
SCALE = 0.1
observations = load_observations_npz('passive_physics_validation_shape_constancy_0001_01')
(
    gt_images,
    gt_images_bg,
    gt_images_obj,
    intrinsics,
    registered_objects,
) = preprocess_mcs_physics_scene(observations, MIN_DIST_THRESH=0.6, scale=SCALE)

b.setup_renderer(intrinsics)
for registered_obj in registered_objects:
    b.RENDERER.add_mesh(registered_obj['mesh'])

video_from_rendered(gt_images, scale=int(1 / SCALE), framerate=30)

Extracting Meshes


 57%|█████▊    | 92/160 [00:04<00:01, 41.53it/s]

Adding review
Review passed, added to init queue
Adding new mesh at t = {} 95


100%|██████████| 160/160 [00:29<00:00,  5.35it/s]


Extracting downsampled data


100%|██████████| 160/160 [00:05<00:00, 26.88it/s]
[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)
Centering mesh with translation [0.01000004 0.01249996 0.01249991]


In [3]:
# Setup for inference.
# MODEL_ARGS follows the signature of `mcs_single_object`:
#   (prev_state, t_inits, init_poses, gt_images_bg, pose_update_params, variance, outlier_prob)
T = gt_images.shape[0]  # number of frames in the scene
num_registered_objects = len(registered_objects)

# Initial carried state: (image, poses, active flags, time step).
# Each placeholder pose is the 4x4 identity with entry [3, 3] set to 1e5.
# NOTE(review): [3, 3] is the homogeneous entry, not the translation column; if the
# intent is to park not-yet-initialised objects far away, the translation may be
# what was meant — confirm.
initial_state = (
    gt_images[0],
    jnp.tile(jnp.eye(4).at[3, 3].set(1e+5)[None, ...], (num_registered_objects, 1, 1)),
    jnp.zeros(num_registered_objects, dtype=bool),
    0,
)

MODEL_ARGS = (
    initial_state,
    # Read each object's fields through the comprehension variable `r`; a name
    # leaked from an earlier cell would repeat a single object's values.
    jnp.array([r['t'] for r in registered_objects]),
    jnp.array([r['pose'] for r in registered_objects]),
    gt_images_bg,
    jnp.array([0, 1e+20]),  # pose_update_params, forwarded to b.gaussian_vmf_pose
    0.01,                   # variance, forwarded to b.image_likelihood
    None,                   # outlier_prob, forwarded to b.image_likelihood
)

# Constrain the "depth" choice of a single model step to the first observed frame.
# TODO: for the whole sequence, build an indexed choice map over jnp.arange(T).
full_chm = genjax.choice_map({
    'depth': gt_images[0]
})

In [20]:
# Model: one HMM time step over all registered objects.
# The Python loop below unrolls over the number of objects, so the model has to be
# re-traced/recompiled whenever that number changes; acceptable for now.
@genjax.gen
def mcs_single_object(prev_state, t_inits, init_poses, gt_images_bg, pose_update_params, variance, outlier_prob):
    """
    One time step of the MCS object-tracking HMM.

    Args:
        prev_state: tuple ``(rendered_image, poses, active_states, t)`` carried from the
            previous step; the previous rendered image is ignored.
        t_inits: per-object time step at which each object is initialised.
        init_poses: per-object pose written into ``poses`` when ``t == t_inits[i]``.
        gt_images_bg: background images, indexed by the time step ``t``.
        pose_update_params: extra positional arguments for ``b.gaussian_vmf_pose``.
        variance: forwarded to ``b.image_likelihood``.
        outlier_prob: forwarded to ``b.image_likelihood``.

    Returns:
        ``(rendered_image, poses, active_states, t + 1)``, the next carried state.

    Random choices:
        "pose": the updated object pose (see NOTE on the address below).
        "depth": the observed image under ``b.image_likelihood``.
    """
    (_, poses, active_states, t) = prev_state

    num_objects = poses.shape[0]

    for i in range(num_objects):
        # JAX arrays are immutable: `.at[i].set(...)` returns a NEW array, so every
        # update must be assigned back or it is silently lost.

        # Activate the object one step after its init step (t_init + 1), so that on
        # the init step itself the pose stays fixed at init_poses[i].
        active_states = active_states.at[i].set(
            jnp.where(jnp.equal(t_inits[i] + 1, t), True, active_states[i]))

        # Initialise the pose at the object's init step.
        poses = poses.at[i].set(
            jnp.where(jnp.equal(t_inits[i], t), init_poses[i], poses[i]))

        # TODO: use active_states[i] to switch update kernels (genjax branching)
        # instead of always applying the same update; an earlier idea was to move
        # inactive objects so far away that no update could bring them into view.
        # NOTE(review): every object uses the same address "pose"; with more than one
        # object these choices collide, so a per-object address will be needed.
        updated_pose = b.gaussian_vmf_pose(poses[i], *pose_update_params) @ "pose"
        poses = poses.at[i].set(updated_pose)

    rendered_image_obj = b.RENDERER.render(
        poses, jnp.arange(num_objects))[..., :3]

    rendered_image = splice_image(rendered_image_obj, gt_images_bg[t])

    # Observation: only the traced choice matters, the sampled value is not used here.
    b.image_likelihood(rendered_image, variance, outlier_prob) @ "depth"

    return (rendered_image, poses, active_states, t + 1)

In [21]:
# JIT-compile the model's importance method, then run it once with a fixed key,
# constraining the observation with `full_chm`.
mcs_single_object_importance_jit = jax.jit(mcs_single_object.importance)
key = jax.random.PRNGKey(236786782323467)
_, tr = mcs_single_object_importance_jit(key, full_chm, MODEL_ARGS)

In [22]:
# Same call with identical arguments as the previous cell. jax.jit reuses the
# already-compiled function here, so this presumably exists to observe the
# post-compilation run time — consider `%time` to make that explicit.
_, tr = mcs_single_object_importance_jit(key, full_chm, MODEL_ARGS)

In [23]:
# Display the trace returned by the importance call.
tr




└── [1mBuiltinTrace[0m
    ├── gen_fn
    │   └── [1mBuiltinGenerativeFunction[0m
    │       └── source
    │           └── <function mcs_single_object>
    ├── args
    │   └── [1mtuple[0m
    │       ├── [1mtuple[0m
    │       │   ├──  f32[40,60,3]
    │       │   ├──  f32[1,4,4]
    │       │   ├──  bool[1]
    │       │   └──  i32[]
    │       ├──  i32[1]
    │       ├──  f32[1,4,4]
    │       ├──  f32[160,40,60,3]
    │       ├──  f32[2]
    │       ├──  f32[]
    │       └── (const) None
    ├── retval
    │   └── [1mtuple[0m
    │       ├──  f32[40,60,3]
    │       ├──  f32[1,4,4]
    │       ├──  bool[1]
    │       └──  i32[]
    ├── choices
    │   └── [1mTrie[0m
    │       ├── [1m:pose[0m
    │       │   └── [1mDistributionTrace[0m
    │       │       ├── gen_fn
    │       │       │   └── [1mGaussianVMFPose[0m
    │       │       ├── args
    │       │       │   └── [1mtuple[0m
    │       │       │       ├──  f32[4,4]
    │       │       │     