In [1]:
import os
import io
import trimesh
import b3d
import genjax
import jax
import b3d.bayes3d as bayes3d
import jax.numpy as jnp
import numpy as np
import rerun as rr
from PIL import Image
import matplotlib.pyplot as plt
import h5py
from tqdm import tqdm
from genjax import Pytree
import b3d.chisight.dense.dense_model
import b3d.chisight.dense.likelihoods.laplace_likelihood
genjax.pretty()


## define helper functions

In [2]:
def scale_mesh(vertices, scale_factor):
    vertices[:, 0] *= scale_factor[0]
    vertices[:, 1] *= scale_factor[1]
    vertices[:, 2] *= scale_factor[2]
    return vertices

def euler_angles_to_quaternion(euler: np.ndarray) -> np.ndarray:
    """
    Convert Euler angles to a quaternion.

    Source: https://pastebin.com/riRLRvch

    :param euler: The Euler angles vector.

    :return: The quaternion representation of the Euler angles.
    """
    pitch = np.radians(euler[0] * 0.5)
    cp = np.cos(pitch)
    sp = np.sin(pitch)

    yaw = np.radians(euler[1] * 0.5)
    cy = np.cos(yaw)
    sy = np.sin(yaw)

    roll = np.radians(euler[2] * 0.5)
    cr = np.cos(roll)
    sr = np.sin(roll)

    x = sy * cp * sr + cy * sp * cr
    y = sy * cp * cr - cy * sp * sr
    z = cy * cp * sr - sy * sp * cr
    w = cy * cp * cr + sy * sp * sr
    return np.array([x, y, z, w])

def get_mask_area(color, seg_img):
    arr = seg_img == color
    arr = arr.min(-1).astype('float')
    arr = arr.reshape((arr.shape[-1], arr.shape[-1]))
    return arr.astype(bool)

def unproject_pixels(mask, depth_map, cam_matrix, fx, fy, vfov=54.43222, near_plane=0.1, far_plane=100):
    '''
    pts: [N, 2] pixel coords
    depth: [N, ] depth values
    returns: [N, 3] world coords
    '''
    depth = depth_map[mask]
    pts = np.array([[x,y] for x,y in zip(np.nonzero(mask)[0], np.nonzero(mask)[1])])
    camera_matrix = np.linalg.inv(cam_matrix.reshape((4, 4)))

    # Different from real-world camera coordinate system.
    # OpenGL uses negative z axis as the camera front direction.
    # x axes are same, hence y axis is reversed as well.
    # Source: https://learnopengl.com/Getting-started/Camera
    rot = np.array([[1, 0, 0, 0],
                    [0, -1, 0, 0],
                    [0, 0, -1, 0],
                    [0, 0, 0, 1]])
    camera_matrix = np.dot(camera_matrix, rot)


    height = depth_map.shape[0]
    width = depth_map.shape[1]

    img_pixs = pts[:, [1, 0]].T
    img_pix_ones = np.concatenate((img_pixs, np.ones((1, img_pixs.shape[1]))))

    # Calculate the intrinsic matrix from vertical_fov.
    # Motice that hfov and vfov are different if height != width
    # We can also get the intrinsic matrix from opengl's perspective matrix.
    # http://kgeorge.github.io/2014/03/08/calculating-opengl-perspective-matrix-from-opencv-intrinsic-matrix
    intrinsics = np.array([[fx, 0, width/ 2.0],
                           [0, fy, height / 2.0],
                           [0, 0, 1]])
    img_inv = np.linalg.inv(intrinsics[:3, :3])
    cam_img_mat = np.dot(img_inv, img_pix_ones)

    points_in_cam = np.multiply(cam_img_mat, depth.reshape(-1))
    points_in_cam = np.concatenate((points_in_cam, np.ones((1, points_in_cam.shape[1]))), axis=0)
    points_in_world = np.dot(camera_matrix, points_in_cam)
    points_in_world = points_in_world[:3, :].T#.reshape(3, height, width)
    points_in_cam = points_in_cam[:3, :].T#.reshape(3, height, width)
    
    return points_in_world

def set_axes_equal(ax):
    """
    Make axes of 3D plot have equal scale so that spheres appear as spheres,
    cubes as cubes, etc.

    Input
      ax: a matplotlib axis, e.g., as output from plt.gca().
    """

    x_limits = ax.get_xlim3d()
    y_limits = ax.get_ylim3d()
    z_limits = ax.get_zlim3d()

    x_range = abs(x_limits[1] - x_limits[0])
    x_middle = np.mean(x_limits)
    y_range = abs(y_limits[1] - y_limits[0])
    y_middle = np.mean(y_limits)
    z_range = abs(z_limits[1] - z_limits[0])
    z_middle = np.mean(z_limits)

    # The plot bounding box is a sphere in the sense of the infinity
    # norm, hence I call half the max range the plot radius.
    plot_radius = 0.5*max([x_range, y_range, z_range])

    ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
    ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
    ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])


## load basic information from hdf5

In [3]:
# paths for reading physion metadata
physion_assets_path = os.path.join(
    b3d.get_root_path(),
    "assets/physion/",)

# stim_name = 'lf_0/dominoes_all_movies/pilot_dominoes_0mid_d3chairs_o1plants_tdwroom_0012'
stim_name = '0012_dominoes'

hdf5_file_path = os.path.join(physion_assets_path,
    f"{stim_name}.hdf5",
)

mesh_file_path = os.path.join(physion_assets_path,
    f"all_flex_meshes/all",
)

In [4]:
vfov = 54.43222 
near_plane = 0.1
far_plane = 100
depth_arr = []
image_arr = []
with h5py.File(hdf5_file_path, "r") as f:
    # extract depth info
    for key in f['frames'].keys():
        depth = jnp.flip(jnp.array(f['frames'][key]['images']['_depth_cam0']), 0)
        depth_arr.append(depth)
        image = jnp.array(Image.open(io.BytesIO(f['frames'][key]['images']['_img_cam0'][:])))
        image_arr.append(image)
    depth_arr = jnp.asarray(depth_arr)
    image_arr = jnp.asarray(image_arr)/255
    FINAL_T, height, width = image_arr.shape[0], image_arr.shape[1], image_arr.shape[2]

    # extract camera info
    camera_azimuth = np.array(f['azimuth']['cam_0'])
    camera_matrix = np.array(f['frames']['0000']['camera_matrices']['camera_matrix_cam0']).reshape((4, 4))
    projection_matrix = np.array(f['frames']['0000']['camera_matrices']['projection_matrix_cam0']).reshape((4, 4))
    im_seg = np.array(Image.open(io.BytesIO(f['frames']['0000']['images']['_id_cam0'][:])))

    vfov = vfov / 180.0 * np.pi
    tan_half_vfov = np.tan(vfov / 2.0)
    tan_half_hfov = tan_half_vfov * width / float(height)
    fx = width / 2.0 / tan_half_hfov  # focal length in pixel space
    fy = height / 2.0 / tan_half_vfov

    # extract object info
    object_ids = np.array(f['static']['object_ids'])
    model_names = np.array(f['static']['model_names'])
    assert len(object_ids) == len(model_names)
    object_segmentation_colors = np.array(f['static']['object_segmentation_colors'])
    initial_position = np.array(f['static']['initial_position'])
    initial_rotation = np.array(f['static']['initial_rotation'])
    scales = jnp.array(f['static']['scale'])

In [5]:
# print(scales)

## b3d modeling

In [6]:
counter = np.unique(im_seg.reshape(-1, im_seg.shape[2]), axis=0)
num_obj = counter.shape[0]-1

In [7]:
# skipping the object id inference step, assuming that we already know which segment mask corresponds to which mesh
START_T = 0
pose_mesh_color_scale_from_point_cloud = []
for (model_name, color, scale) in zip(model_names, object_segmentation_colors, scales):
    area = get_mask_area(im_seg, color)
    point_cloud = jnp.asarray(unproject_pixels(area, depth_arr[START_T], camera_matrix, fx, fy))
    point_cloud_centroid = point_cloud.mean(0)
    point_cloud_bottom = min(point_cloud[:,1])
    object_pose = b3d.Pose.from_translation(jnp.array([point_cloud_centroid[0], point_cloud_bottom, point_cloud_centroid[2]]))
    # print("gt pose: ", object_pose)
    object_colors = jnp.asarray(image_arr[START_T][area])
    mean_object_colors = np.mean(object_colors, axis=0)
    trim = trimesh.load(os.path.join(mesh_file_path, f"{model_name.decode('UTF-8')}.obj"))
    bounding_box = trim.bounding_box
    bbox_corners = bounding_box.vertices
    original_size = jnp.array((max(bbox_corners[:,0])-min(bbox_corners[:,0]), max(bbox_corners[:,1])-min(bbox_corners[:,1]), max(bbox_corners[:,2])-min(bbox_corners[:,2])))
    point_cloud_size = jnp.array((max(point_cloud[:,0])-min(point_cloud[:,0]), max(point_cloud[:,1])-min(point_cloud[:,1]), max(point_cloud[:,2])-min(point_cloud[:,2])))
    object_scale = jnp.abs(point_cloud_size/original_size)
    mesh_info = (scale_mesh(trim.vertices, object_scale), trim.faces, jnp.ones(trim.vertices.shape)*mean_object_colors)
    # mesh_info = (scale_mesh(trim.vertices, scale), trim.faces, jnp.ones(trim.vertices.shape)*mean_object_colors)
    # print(object_scale)
    pose_mesh_color_scale_from_point_cloud.append((object_pose, mesh_info))

In [8]:
# %matplotlib widget

# for idx, (_, mesh) in enumerate(pose_mesh_color_scale_from_point_cloud):
#       fig = plt.figure()
#       ax = fig.add_subplot(projection='3d')
#       ax.set_box_aspect([1,1,1])
#       ax.plot_trisurf(mesh[0][:, 0], mesh[0][:,2], mesh[0][:,1], triangles=mesh[1], color=np.mean(mesh[2], axis=0).tolist())
#       ax.set_title(idx)
#       set_axes_equal(ax)

In [9]:
# all_object_poses_gt = []
# for idx in range(len(object_ids)):
#     object_pose = b3d.Pose(jnp.asarray(initial_position[idx]), jnp.asarray(euler_angles_to_quaternion(initial_rotation[idx])))
#     all_object_poses_gt.append(object_pose)
# all_object_poses_gt

In [10]:
object_library = []
for (_, obj) in pose_mesh_color_scale_from_point_cloud:
    object_library.append(b3d.Mesh(obj[0], obj[1], obj[2]))

print(f"{len(object_library)} objects in library")

5 objects in library


In [11]:
R = camera_matrix[:3,:3]
T = camera_matrix[0:3, 3]
a = np.array([-R[0,:], -R[1,:], -R[2,:]])
b = np.array(T)
camera_position_from_matrix = np.linalg.solve(a, b)
camera_rotation_from_matrix = -np.transpose(R)
camera_pose = b3d.Pose(
    camera_position_from_matrix,
    b3d.Rot.from_matrix(camera_rotation_from_matrix).as_quat()
)

In [12]:
# Defines the enumeration schedule.
scaling_factor = 1.0
renderer = b3d.renderer.renderer_original.RendererOriginal(
    width * scaling_factor,
    height * scaling_factor,
    fx * scaling_factor,
    fy * scaling_factor,
    (width/2) * scaling_factor,
    (height/2) * scaling_factor,
    near_plane,
    far_plane,
)

b3d.reload(b3d.chisight.dense.dense_model)
b3d.reload(b3d.chisight.dense.likelihoods.laplace_likelihood)
likelihood_func = b3d.chisight.dense.likelihoods.laplace_likelihood.likelihood_func
model, viz_trace, info_from_trace = (
    b3d.chisight.dense.dense_model.make_dense_multiobject_model(
        renderer, likelihood_func
    )
)
importance_jit = jax.jit(model.importance)

likelihood_args = {
    "fx": renderer.fx,
    "fy": renderer.fy,
    "cx": renderer.cx,
    "cy": renderer.cy,
    "image_width": Pytree.const(renderer.width),
    "image_height": Pytree.const(renderer.height),
}


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [13]:
num_pose_grid = 9

# Gridding on translation only.
translation_deltas = b3d.Pose.concatenate_poses(
    [
        jax.vmap(lambda p: b3d.Pose.from_translation(p))(
            jnp.stack(
                jnp.meshgrid(
                    jnp.linspace(-0.1, 0.1, num_pose_grid),
                    jnp.linspace(-0.1, 0.1, num_pose_grid),
                    jnp.linspace(-0.1, 0.1, num_pose_grid),
                ),
                axis=-1,
            ).reshape(-1, 3)
        ),
        b3d.Pose.identity()[None, ...],
    ]
)
# Sample orientations from a VMF to define a "grid" over orientations.
rotation_deltas = b3d.Pose.concatenate_poses(
    [
        jax.vmap(b3d.Pose.sample_gaussian_vmf_pose, in_axes=(0, None, None, None))(
            jax.random.split(jax.random.PRNGKey(0), num_pose_grid * num_pose_grid * num_pose_grid),
            b3d.Pose.identity(),
            0.0001,
            100.0,
        ),
        b3d.Pose.identity()[None, ...],
    ]
)
all_deltas = b3d.Pose.stack_poses([translation_deltas, rotation_deltas])

num_scale_grid = 5
scale_deltas = dict(
    [("x", jnp.array([[x,1,1] for x in jnp.linspace(0.6, 1.1, num_scale_grid)]+[[1,1,1]]).reshape(-1, 3)), 
     ("y", jnp.array([[1,y,1] for y in jnp.linspace(0.6, 1.1, num_scale_grid)]+[[1,1,1]]).reshape(-1, 3)), 
     ("z", jnp.array([[1,1,z] for z in jnp.linspace(0.6, 1.1, num_scale_grid)]+[[1,1,1]]).reshape(-1, 3))])

In [14]:
rgbds = jnp.concatenate([jnp.flip(image_arr,2), jnp.reshape(jnp.flip(depth_arr,2), depth_arr.shape+(1,))], axis=-1)
# rgbds.shape

In [15]:
# fig = plt.figure(figsize=[10, 5])
# ax = fig.add_subplot(121)
# ax.imshow(rgbds[START_T][..., 0:3])

# ax = fig.add_subplot(122)
# ax.imshow(rgbds[START_T][..., -1])

In [21]:
rr.init("demo")
rr.connect("127.0.0.1:8812")
# b3d.rr_init("demo")
rr.log("/", rr.ViewCoordinates.LEFT_HAND_Y_UP, static=True)  # Set an up-axis

In [22]:
# Initial trace for timestep 0
choice_map = dict([("camera_pose", camera_pose), 
                   ("rgbd", rgbds[START_T]), 
                   ("depth_noise_variance", 0.01),
                   ("color_noise_variance", 0.08),
                   ("outlier_probability", 0.1)] + 
                   [(f"object_pose_{idx}", pose_mesh_color_scale_from_point_cloud[idx][0]) for idx in range(num_obj)]
                   + [(f"object_scale_{idx}", jnp.array([1.0, 1.0, 1.0])) for idx in range(num_obj)]
                   )

trace, _ = importance_jit(
    jax.random.PRNGKey(0),
    genjax.ChoiceMap.d(
        choice_map
    ),
    (
        {
            "num_objects": Pytree.const(num_obj),
            "meshes": object_library,
            "likelihood_args": likelihood_args,
            # "check_interp": Pytree.const(False),
        },
    ),
)
original_trace = trace
viz_trace(trace, 0, cloud=True)

In [18]:
print(trace.get_score())

58654.45


In [19]:
trace = original_trace
num_inference_step = 5
for seed in range(num_inference_step):
    print(seed)
    key = jax.random.PRNGKey(seed)
    for idx in range(num_obj):
        print(f"obj {idx}")
        # trace, key = bayes3d.enumerate_and_select_best_move_new(
        #     trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['x']]
        #     # trace, Pytree.const((f"object_pose_{idx}",)), key, [all_deltas]
        # )
        trace, key = bayes3d.enumerate_and_select_best_move(
            trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['x']]
            # trace, Pytree.const((f"object_pose_{idx}",)), key, [all_deltas]
        )
        trace, key = bayes3d.enumerate_and_select_best_move(
            trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['y']]
        )
        trace, key = bayes3d.enumerate_and_select_best_move(
            trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['z']]
        )
        print(trace.get_score())
    viz_trace(trace, seed+1, cloud=True)

0
obj 0


2024-09-18 20:02:58.566745: W external/xla/xla/service/hlo_rematerialization.cc:2948] Can't reduce memory use below 16.67GiB (17895516261 bytes) by rematerialization; only reduced to 17.54GiB (18838890356 bytes), down from 17.54GiB (18838913116 bytes) originally


RuntimeError: OpenGL error: GL_INVALID_VALUE 1281[glTexImage3D(GL_TEXTURE_2D_ARRAY, 0, GL_RGBA32F, s.width, s.height, s.depth, 0, GL_RGBA, GL_UNSIGNED_BYTE, 0);]
Exception raised from rasterizeResizeBuffers at /home/hlwang_ipe_genjax/b3d/src/b3d/renderer/nvdiffrast_jax/nvdiffrast/common/rasterize_gl.cpp:413 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xb2 (0x7f69bf54ef12 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/torch/lib/../../../../libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xd2 (0x7f69bf4f9958 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/torch/lib/../../../../libc10.so)
frame #2: rasterizeResizeBuffers(int, RasterizeGLState&, bool&, int, int, int, int, int) + 0xbdd (0x7f6dbdc6a468 in /home/hlwang_ipe_genjax/.cache/torch_extensions/py312_cu120/nvdiffrast_plugin_original_gl/nvdiffrast_plugin_original_gl.so)
frame #3: _rasterize_fwd_gl(CUstream_st*, RasterizeGLStateWrapper&, float const*, int const*, std::vector<int, std::allocator<int> >, std::vector<int, std::allocator<int> >, float*, float*) + 0x1b0 (0x7f6dbdc911f4 in /home/hlwang_ipe_genjax/.cache/torch_extensions/py312_cu120/nvdiffrast_plugin_original_gl/nvdiffrast_plugin_original_gl.so)
frame #4: jax_rasterize_fwd_gl(CUstream_st*, void**, char const*, unsigned long) + 0x279 (0x7f6dbdc91551 in /home/hlwang_ipe_genjax/.cache/torch_extensions/py312_cu120/nvdiffrast_plugin_original_gl/nvdiffrast_plugin_original_gl.so)
frame #5: <unknown function> + 0x3cee0fe (0x7f6c9ac190fe in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #6: <unknown function> + 0x3cf0c15 (0x7f6c9ac1bc15 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #7: <unknown function> + 0x4a9a3dd (0x7f6c9b9c53dd in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #8: <unknown function> + 0x4a9c0c0 (0x7f6c9b9c70c0 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #9: <unknown function> + 0x4a9c81f (0x7f6c9b9c781f in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #10: <unknown function> + 0x70d7d16 (0x7f6c9e002d16 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #11: <unknown function> + 0x1555791 (0x7f6c98480791 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #12: <unknown function> + 0x15561ea (0x7f6c984811ea in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #13: <unknown function> + 0x14d2add (0x7f6c983fdadd in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #14: <unknown function> + 0x14d3f35 (0x7f6c983fef35 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #15: <unknown function> + 0x14d62e4 (0x7f6c984012e4 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #16: <unknown function> + 0x15b6330 (0x7f6c984e1330 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #17: <unknown function> + 0x14569b9 (0x7f6c983819b9 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #18: <unknown function> + 0x1458512 (0x7f6c98383512 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #19: <unknown function> + 0xa0aa79 (0x7f6c97935a79 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #20: <unknown function> + 0x1601bec (0x7f6c9852cbec in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #21: PyObject_Vectorcall + 0x4f (0x556a0de48fcf in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #22: <unknown function> + 0x113f65 (0x556a0dd3af65 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #23: _PyObject_FastCallDictTstate + 0x1fa (0x556a0de36c1a in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #24: _PyObject_Call_Prepend + 0xe9 (0x556a0de62899 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #25: <unknown function> + 0x311693 (0x556a0df38693 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #26: _PyObject_Call + 0xb5 (0x556a0de652e5 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #27: <unknown function> + 0x114ca6 (0x556a0dd3bca6 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #28: PyObject_Vectorcall + 0x4f (0x556a0de48fcf in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #29: <unknown function> + 0xa7d848 (0x7f6c979a8848 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #30: <unknown function> + 0xa7e3cc (0x7f6c979a93cc in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #31: <unknown function> + 0x114ca6 (0x556a0dd3bca6 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #32: <unknown function> + 0x25aeec (0x556a0de81eec in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #33: <unknown function> + 0x25a9fd (0x556a0de819fd in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #34: _PyObject_Call + 0x123 (0x556a0de65353 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #35: <unknown function> + 0x114ca6 (0x556a0dd3bca6 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #36: PyObject_Vectorcall + 0x4f (0x556a0de48fcf in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #37: <unknown function> + 0xa7d848 (0x7f6c979a8848 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #38: <unknown function> + 0xa7e3cc (0x7f6c979a93cc in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/jaxlib/xla_extension.so)
frame #39: PyObject_Vectorcall + 0x4f (0x556a0de48fcf in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #40: <unknown function> + 0x113f65 (0x556a0dd3af65 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #41: PyEval_EvalCode + 0xae (0x556a0deeeece in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #42: <unknown function> + 0x2e3edd (0x556a0df0aedd in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #43: <unknown function> + 0x115857 (0x556a0dd3c857 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #44: <unknown function> + 0x2de852 (0x556a0df05852 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #45: <unknown function> + 0x2dfb87 (0x556a0df06b87 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #46: <unknown function> + 0x116206 (0x556a0dd3d206 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #47: <unknown function> + 0x25aeec (0x556a0de81eec in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #48: <unknown function> + 0x25a9fd (0x556a0de819fd in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #49: _PyObject_Call + 0x123 (0x556a0de65353 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #50: <unknown function> + 0x114ca6 (0x556a0dd3bca6 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #51: <unknown function> + 0x2de852 (0x556a0df05852 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #52: <unknown function> + 0x8200 (0x7f6dd93f1200 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/lib-dynload/_asyncio.cpython-312-x86_64-linux-gnu.so)
frame #53: <unknown function> + 0x89b4 (0x7f6dd93f19b4 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/lib/python3.12/lib-dynload/_asyncio.cpython-312-x86_64-linux-gnu.so)
frame #54: <unknown function> + 0x22beaa (0x556a0de52eaa in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #55: <unknown function> + 0x35de9c (0x556a0df84e9c in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #56: <unknown function> + 0x1ca7b3 (0x556a0ddf17b3 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #57: <unknown function> + 0x222246 (0x556a0de49246 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #58: <unknown function> + 0x114ca6 (0x556a0dd3bca6 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #59: PyEval_EvalCode + 0xae (0x556a0deeeece in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #60: <unknown function> + 0x2e3edd (0x556a0df0aedd in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #61: <unknown function> + 0x222246 (0x556a0de49246 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #62: PyObject_Vectorcall + 0x4f (0x556a0de48fcf in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)
frame #63: <unknown function> + 0x113f65 (0x556a0dd3af65 in /home/hlwang_ipe_genjax/b3d/.pixi/envs/gpu/bin/python)


In [23]:
trace = original_trace
num_inference_step = 5
for seed in range(num_inference_step):
    print(seed)
    key = jax.random.PRNGKey(seed)
    for idx in range(num_obj):
        print(f"obj {idx}")
        trace, key = bayes3d.enumerate_and_select_best_move_new(
            trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['x']]
        )
        trace, key = bayes3d.enumerate_and_select_best_move_new(
            trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['y']]
        )
        trace, key = bayes3d.enumerate_and_select_best_move_new(
            trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['z']]
        )
        # trace, key = bayes3d.enumerate_and_select_best_move(
        #     trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['x']]
        #     # trace, Pytree.const((f"object_pose_{idx}",)), key, [all_deltas]
        # )
        # trace, key = bayes3d.enumerate_and_select_best_move(
        #     trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['y']]
        # )
        # trace, key = bayes3d.enumerate_and_select_best_move(
        #     trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['z']]
        # )
        print(trace.get_score())
    viz_trace(trace, seed+1, cloud=True)

0
obj 0
60722.027
obj 1
72385.97
obj 2
90382.625
obj 3
92462.89
obj 4
111269.84
1
obj 0
111335.3
obj 1
111335.3
obj 2
111557.016
obj 3
111552.56
obj 4
112445.72
2
obj 0
112445.72
obj 1
112445.72
obj 2
112445.72
obj 3
112445.72
obj 4
113298.53
3
obj 0
113298.53
obj 1
113298.53
obj 2
113298.53
obj 3
113298.53
obj 4
113298.53
4
obj 0
113298.53
obj 1
113298.53
obj 2
113298.53
obj 3
113298.53
obj 4
113298.53


In [20]:
# for idx in range(num_obj):
#     addr = f"object_scale_{idx}"
#     current_scale = trace.get_choices()[addr]
#     print(current_scale) 

In [21]:
# new_scale = jnp.array([0.1,0.1,0.1])
# new_trace = trace.update(
#         jax.random.PRNGKey(0),
#         genjax.ChoiceMap.d({'object_scale_3': new_scale,}),
#     )[0]
# viz_trace(new_trace, START_T, cloud=True)

In [22]:
# print(new_trace.get_score())

In [23]:
# viz_trace(new_trace, START_T, cloud=True)

In [24]:
# viz_trace(trace, START_T)

In [1]:
FINAL_T = len(image_arr)
for T_observed_image in tqdm(range(FINAL_T)):
    # Constrain on new RGB and Depth data.
    trace = b3d.update_choices(
        trace,
        Pytree.const(("rgbd",)),
        rgbds[T_observed_image],
    )
 
    for idx in range(num_obj):
        trace, key = bayes3d.enumerate_and_select_best_move_new(
            trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['x']]
        )
        trace, key = bayes3d.enumerate_and_select_best_move_new(
            trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['y']]
        )
        trace, key = bayes3d.enumerate_and_select_best_move_new(
            trace, Pytree.const((f"object_pose_{idx}", f"object_scale_{idx}",)), key, [all_deltas, scale_deltas['z']]
        )
    viz_trace(trace, T_observed_image+num_inference_step, cloud=True)

NameError: name 'image_arr' is not defined