In [1]:
import os
import io
import trimesh
import b3d
import genjax
import jax
import b3d.bayes3d as bayes3d
import jax.numpy as jnp
import numpy as np
import rerun as rr
from PIL import Image
import matplotlib.pyplot as plt
import h5py
import json
import random
from copy import deepcopy
import itertools
import pycocotools.mask as mask_util
from scipy.spatial.transform import Rotation
from tqdm import tqdm
from functools import reduce
from genjax import Pytree
from physion_utils import *
import b3d.chisight.dense.dense_model
import b3d.chisight.dense.likelihoods.laplace_likelihood
import collections
from genjax._src.core.serialization.msgpack import msgpack_serialize

In [2]:
# Open a rerun recording for this notebook and stream it to a locally running viewer.
rerun_app_id = "demo"
rerun_viewer_addr = "127.0.0.1:8812"
rr.init(rerun_app_id)
rr.connect(rerun_viewer_addr)
# Set an up-axis: log the coordinate convention once at the root, as static data.
rr.log("/", rr.ViewCoordinates.LEFT_HAND_Y_UP, static=True)

In [3]:
# Input locations, all resolved relative to the b3d root directory.
b3d_root = b3d.get_root_path()
physion_assets_path = os.path.join(b3d_root, "assets/physion/")
resnet_inference_path = os.path.join(b3d_root, "resnet_results/")

# Stimulus to process; its name selects both the HDF5 recording and the JSON file.
stim_name = 'pilot_dominoes_0mid_d3chairs_o1plants_tdwroom_0001'

hdf5_file_path = os.path.join(physion_assets_path, f"{stim_name}.hdf5")
mesh_file_path = os.path.join(physion_assets_path, "all_flex_meshes/core")
json_file_path = os.path.join(resnet_inference_path, f"{stim_name}.json")

# Working image resolution in pixels.
im_width = 350
im_height = 350

In [4]:
# Load frames, camera parameters and static object metadata from the stimulus HDF5 file.
vfov = 54.43222  # vertical field of view in degrees; converted to radians below
near_plane = 0.1
far_plane = 100
depth_arr = []
image_arr = []
# Optional base/attachment metadata; every name stays None unless the file provides it.
base_id, base_type, attachment_id, attachent_type, attachment_fixed, use_attachment, use_base, use_cap = None, None, None, None, None, None, None, None
with h5py.File(hdf5_file_path, "r") as f:
    frames_group = f['frames']
    static_group = f['static']
    first_frame = frames_group['0000']
    target_size = (im_width, im_height)  # PIL sizes are (width, height)

    # extract depth info
    for key in frames_group.keys():
        frame_images = frames_group[key]['images']
        # Depth is stored as a raw array; the RGB image is stored as encoded bytes.
        depth = jnp.array(Image.fromarray(np.array(frame_images['_depth_cam0'])).resize(target_size, Image.BICUBIC))
        depth_arr.append(depth)
        image = jnp.array(Image.open(io.BytesIO(frame_images['_img_cam0'][:])).resize(target_size, Image.BICUBIC))
        image_arr.append(image)
    depth_arr = jnp.asarray(depth_arr)
    # NOTE(review): dividing by 255 assumes 8-bit colour channels -- confirm.
    image_arr = jnp.asarray(image_arr)/255
    FINAL_T, height, width = image_arr.shape[0], image_arr.shape[1], image_arr.shape[2]

    # extract camera info
    camera_azimuth = np.array(f['azimuth']['cam_0'])
    camera_matrix = np.array(first_frame['camera_matrices']['camera_matrix_cam0']).reshape((4, 4))
    projection_matrix = np.array(first_frame['camera_matrices']['projection_matrix_cam0']).reshape((4, 4))
    # The id image encodes one colour per object, so it is resized with NEAREST:
    # an interpolating filter would blend colours at object boundaries and
    # produce values that correspond to no object.
    im_seg = np.array(Image.open(io.BytesIO(first_frame['images']['_id_cam0'][:])).resize(target_size, Image.NEAREST))

    # Calculate the intrinsic matrix from vertical_fov.
    # Notice that hfov and vfov are different if height != width
    # We can also get the intrinsic matrix from opengl's perspective matrix.
    # http://kgeorge.github.io/2014/03/08/calculating-opengl-perspective-matrix-from-opencv-intrinsic-matrix
    vfov = vfov / 180.0 * np.pi
    tan_half_vfov = np.tan(vfov / 2.0)
    tan_half_hfov = tan_half_vfov * width / float(height)
    fx = width / 2.0 / tan_half_hfov  # focal length in pixel space
    fy = height / 2.0 / tan_half_vfov

    # extract object info
    object_ids = np.array(static_group['object_ids'])
    model_names = np.array(static_group['model_names'])
    assert len(object_ids) == len(model_names)

    # Read each optional name list once; an empty dataset is represented as None.
    distractors = np.array(static_group['distractors'])
    distractors = distractors if distractors.size != 0 else None
    occluders = np.array(static_group['occluders'])
    occluders = occluders if occluders.size != 0 else None
    # Test against None explicitly: the truth value of a multi-element NumPy
    # array is ambiguous and raises ValueError.
    distractor_ids = (
        np.concatenate([np.where(model_names==distractor)[0] for distractor in distractors], axis=0).tolist()
        if distractors is not None else []
    )
    occluder_ids = (
        np.concatenate([np.where(model_names==occluder)[0] for occluder in occluders], axis=0).tolist()
        if occluders is not None else []
    )
    excluded_model_ids = distractor_ids+occluder_ids
    excluded_lookup = set(excluded_model_ids)
    included_model_ids = [idx for idx in range(len(object_ids)) if idx not in excluded_lookup]
    # From here on object_ids holds positions into model_names (distractors and
    # occluders removed), not the id values stored in the file.
    object_ids = included_model_ids

    object_segmentation_colors = np.array(static_group['object_segmentation_colors'])
    initial_position = np.array(static_group['initial_position'])
    initial_rotation = np.array(static_group['initial_rotation'])
    scales = np.array(static_group['scale'])
    # Membership is tested on the HDF5 group itself rather than on a NumPy
    # coercion of it.
    if "base_id" in static_group and "attachment_id" in static_group:
        base_id = np.array(static_group['base_id'])
        base_type = np.array(static_group['base_type'])
        attachment_id = np.array(static_group['attachment_id'])
        # The key spelling 'attachent_type' is kept exactly as the original code reads it.
        attachent_type = np.array(static_group['attachent_type'])
        attachment_fixed = np.array(static_group['attachment_fixed'])
        use_attachment = np.array(static_group['use_attachment'])
        use_base = np.array(static_group['use_base'])
        use_cap = np.array(static_group['use_cap'])
        assert attachment_id.size==1
        assert base_id.size==1
        attachment_id = attachment_id.item()
        base_id = base_id.item()
        print(base_id, base_type, attachment_id, attachent_type, attachment_fixed, use_attachment, use_base, use_cap)

In [5]:
# Walk mesh_file_path and load every entry whose name ends in '.obj', keyed by
# the name with its last four characters (the extension) removed.
all_meshes = {}
for path, dirs, files in os.walk(mesh_file_path):
    # Directory names are checked as well as file names.
    obj_names = [name for name in (files + dirs) if name.endswith('.obj')]
    for name in obj_names:
        mesh = trimesh.load(os.path.join(path, name))
        all_meshes[name[:-4]] = mesh
# Same table, with keys in sorted order.
ordered_all_meshes = collections.OrderedDict(
    (mesh_name, all_meshes[mesh_name]) for mesh_name in sorted(all_meshes)
)

In [6]:
# Build the renderer and the dense multi-object model used for inference.
scaling_factor = 1.0

# Image size and intrinsics, each multiplied by scaling_factor; the principal
# point is placed at the image centre.
scaled_width = width * scaling_factor
scaled_height = height * scaling_factor
scaled_fx = fx * scaling_factor
scaled_fy = fy * scaling_factor
scaled_cx = (width/2) * scaling_factor
scaled_cy = (height/2) * scaling_factor
renderer = b3d.renderer.renderer_original.RendererOriginal(
    scaled_width, scaled_height,
    scaled_fx, scaled_fy,
    scaled_cx, scaled_cy,
    near_plane, far_plane,
)

# Reload both modules before looking up their attributes.
b3d.reload(b3d.chisight.dense.dense_model)
b3d.reload(b3d.chisight.dense.likelihoods.laplace_likelihood)
likelihood_func = b3d.chisight.dense.likelihoods.laplace_likelihood.likelihood_func
model, viz_trace, info_from_trace = b3d.chisight.dense.dense_model.make_dense_multiobject_model(
    renderer, likelihood_func
)
importance_jit = jax.jit(model.importance)

# Arguments handed to the likelihood; the image dimensions are wrapped as
# Pytree constants.
likelihood_args = dict(
    fx=renderer.fx,
    fy=renderer.fy,
    cx=renderer.cx,
    cy=renderer.cy,
    image_width=Pytree.const(renderer.width),
    image_height=Pytree.const(renderer.height),
)


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [7]:
num_pose_grid = 11
position_search_thr = 0.05

# Pose that leaves an object unchanged, appended to each family of deltas.
identity_delta = b3d.Pose.identity()[None, ...]

# Gridding on translation only: a cubic grid of offsets in
# [-position_search_thr, position_search_thr] along each axis.
position_axis = jnp.linspace(-position_search_thr, position_search_thr, num_pose_grid)
position_offsets = jnp.stack(
    jnp.meshgrid(position_axis, position_axis, position_axis),
    axis=-1,
).reshape(-1, 3)
translation_deltas = b3d.Pose.concatenate_poses(
    [jax.vmap(b3d.Pose.from_translation)(position_offsets), identity_delta]
)

# Sample orientations from a VMF to define a "grid" over orientations.
# One sample per translation grid point, so both families have the same length.
# NOTE(review): 0.0001 and 10.0 are presumably the positional spread and the
# VMF concentration -- confirm against sample_gaussian_vmf_pose.
rotation_keys = jax.random.split(jax.random.PRNGKey(0), num_pose_grid ** 3)
rotation_samples = jax.vmap(
    b3d.Pose.sample_gaussian_vmf_pose, in_axes=(0, None, None, None)
)(rotation_keys, b3d.Pose.identity(), 0.0001, 10.0)
rotation_deltas = b3d.Pose.concatenate_poses([rotation_samples, identity_delta])

all_deltas = b3d.Pose.stack_poses([translation_deltas, rotation_deltas])

# needs to be odd (an odd count makes the symmetric linspace contain 0)
num_scale_grid = 11
scale_search_thr = 0.1

# Cubic grid of per-axis scale offsets in [-scale_search_thr, scale_search_thr].
scale_axis = jnp.linspace(-scale_search_thr, scale_search_thr, num_scale_grid)
scale_deltas = jnp.stack(
    jnp.meshgrid(scale_axis, scale_axis, scale_axis),
    axis=-1,
).reshape(-1, 3)

In [8]:
# Append depth as one extra trailing channel after the colour channels.
rgbds = jnp.concatenate([image_arr, depth_arr[..., None]], axis=-1)

In [9]:
START_T = 0

# Per-object predictions (mask and candidate mesh types) for the first scene entry.
with open(json_file_path) as f:
    json_file = json.load(f)
pred = json_file['scene'][0]['objects']
# Index of the candidate type to use for each object; always the first one here.
sample = [0] * len(object_ids)
pose_scale_mesh = {}
for i, (o_id, idx) in enumerate(zip(object_ids, sample)):
    detection = pred[i]
    # Boolean mask of the pixels covered by this object.
    area = mask_util.decode(detection['mask']).astype(bool)
    # Mean colour of the masked pixels in the START_T frame (first three channels).
    object_colors = jnp.asarray(rgbds[START_T][..., 0:3][area])
    mean_object_colors = np.mean(object_colors, axis=0)
    # Selected mesh geometry, with every vertex coloured by that mean colour.
    template_mesh = ordered_all_meshes[detection['type'][idx]]
    pose_scale_mesh[o_id] = b3d.Mesh(
        template_mesh.vertices,
        template_mesh.faces,
        jnp.ones(template_mesh.vertices.shape)*mean_object_colors,
    )


In [10]:
# Restore a previously saved trace and visualise it as step 0.
# The file is located relative to the b3d root, like the other inputs, instead
# of through a user-specific absolute path.
# NOTE(review): assumes saved_traces/ sits directly under b3d.get_root_path() -- confirm.
saved_trace_path = os.path.join(b3d.get_root_path(), "saved_traces/test.pickle")
# Model arguments passed to the loader alongside the model.
trace_args = (
    Pytree.const(list(object_ids)),
    [pose_scale_mesh[o_id] for o_id in object_ids],
    likelihood_args,
    Pytree.const(True),
)
# The file has a .pickle extension but is read through the msgpack serializer.
with open(saved_trace_path, "rb") as output_file:
    retrieved_tr = msgpack_serialize.load(output_file, model, trace_args)
viz_trace(retrieved_tr, 0)

In [11]:
# Value stored in the loaded trace under the "object_scale_0" address.
retrieved_tr.get_choices()["object_scale_0"]

Array([0.29165232, 0.15955979, 1.7935468 ], dtype=float32)

In [12]:
# Score of the trace as loaded from disk.
retrieved_tr.get_score()

Array(-6275.1025, dtype=float32)

In [13]:
# Refine the pose of one object by repeatedly enumerating the pose deltas and
# keeping the move selected by enumerate_and_select_best_move_pose.
o_id = 0
NUM_POSE_SWEEPS = 5
trace = retrieved_tr
key = jax.random.PRNGKey(0)
# The address does not change between sweeps, so build it once.
pose_address = Pytree.const((f"object_pose_{o_id}",))
# The loop variable is named `sweep` rather than `iter` so the builtin iter()
# is not shadowed for the rest of the kernel session.
for sweep in range(NUM_POSE_SWEEPS):
    trace, key, _, _ = bayes3d.enumerate_and_select_best_move_pose(
        trace, pose_address, key, all_deltas
    )
    # Scale enumeration is currently disabled:
    # trace, key, posterior_scales, scores = bayes3d.enumerate_and_select_best_move_scale(
    #     trace, Pytree.const((f"object_scale_{o_id}",)), key, scale_deltas
    # )
    print(f"score: {trace.get_score()}")
    # Step 0 is the loaded trace, so the sweeps are logged as steps 1..NUM_POSE_SWEEPS.
    viz_trace(trace, sweep+1, cloud=True)

score: -5862.845703125
score: -5862.845703125
score: -5862.845703125
score: -5862.845703125
score: -5862.845703125


In [14]:
# Value stored in the current trace under the "object_pose_0" address.
trace.get_choices()["object_pose_0"]

Pose(position=Array([ 0.91359544, -0.14378065,  0.06828638], dtype=float32), quaternion=Array([ 0.0094165 ,  0.00671392, -0.01218714,  0.99985886], dtype=float32))