In [1]:
import os
import io
import trimesh
import b3d
import b3d.bayes3d as bayes3d
import fire
import genjax
import jax
import jax.numpy as jnp
import numpy as np
import rerun as rr
from PIL import Image
from b3d import Pose
from genjax import Pytree
from b3d import Mesh, Pose
from tqdm import tqdm
import h5py
genjax.pretty()

In [2]:
# Stream everything logged below to a Rerun viewer running on this machine.
RERUN_VIEWER_ADDR = "127.0.0.1:8812"
rr.init("demo")
rr.connect(RERUN_VIEWER_ADDR)

In [3]:
# def unproject_pixels(pts, depth, cam_matrix, vfov=55, near_plane=0.1, far_plane=100):
#     '''
#     pts: [N, 2] pixel coords
#     depth: [N, ] depth values
#     returns: [N, 3] world coords
#     '''

    
#     camera_matrix = np.linalg.inv(cam_matrix.reshape((4, 4)))

#     # Different from real-world camera coordinate system.
#     # OpenGL uses negative z axis as the camera front direction.
#     # x axes are same, hence y axis is reversed as well.
#     # Source: https://learnopengl.com/Getting-started/Camera
#     rot = np.array([[1, 0, 0, 0],
#                     [0, -1, 0, 0],
#                     [0, 0, -1, 0],
#                     [0, 0, 0, 1]])
#     camera_matrix = np.dot(camera_matrix, rot)


#     height = depth.shape[0]
#     width = depth.shape[1]

#     img_pixs = pts[:, [1, 0]].T
#     img_pix_ones = np.concatenate((img_pixs, np.ones((1, img_pixs.shape[1]))))

#     # Calculate the intrinsic matrix from vertical_fov.
#     # Notice that hfov and vfov are different if height != width
#     # We can also get the intrinsic matrix from opengl's perspective matrix.
#     # http://kgeorge.github.io/2014/03/08/calculating-opengl-perspective-matrix-from-opencv-intrinsic-matrix
#     vfov = vfov / 180.0 * np.pi
#     tan_half_vfov = np.tan(vfov / 2.0)
#     tan_half_hfov = tan_half_vfov * width / float(height)
#     fx = width / 2.0 / tan_half_hfov  # focal length in pixel space
#     fy = height / 2.0 / tan_half_vfov
#     intrinsics = np.array([[fx, 0, width/ 2.0],
#                            [0, fy, height / 2.0],
#                            [0, 0, 1]])
#     img_inv = np.linalg.inv(intrinsics[:3, :3])
#     cam_img_mat = np.dot(img_inv, img_pix_ones)

#     points_in_cam = np.multiply(cam_img_mat, depth.reshape(-1))
#     points_in_cam = np.concatenate((points_in_cam, np.ones((1, points_in_cam.shape[1]))), axis=0)
#     points_in_world = np.dot(camera_matrix, points_in_cam)
#     points_in_world = points_in_world[:3, :].T#.reshape(3, height, width)
#     points_in_cam = points_in_cam[:3, :].T#.reshape(3, height, width)
    
#     return points_in_world


In [4]:
# Physion assets bundled with b3d, and the single trial recording analysed below.
physion_assets_path = os.path.join(b3d.get_root_path(), "assets/physion/")
hdf5_file_path = os.path.join(
    physion_assets_path,
    "hdf5s/pilot_dominoes_0mid_d3chairs_o1plants_tdwroom_0012.hdf5",
)

In [5]:
# Projection parameters for the Physion camera.
vfov = 55  # vertical field of view, in degrees (kept in degrees; radians live in vfov_rad)
near_plane = 0.1
far_plane = 100

depth_arr = []
image_arr = []
with h5py.File(hdf5_file_path, "r") as f:
    # Per-frame depth map and encoded RGB image from camera 0.
    # The loop variable is deliberately NOT called `key`: that name holds the JAX PRNG key
    # later in the notebook, and re-running this cell would otherwise clobber it with a string.
    for frame_key in f['frames'].keys():
        frame_images = f['frames'][frame_key]['images']
        depth_arr.append(jnp.array(frame_images['_depth_cam0']))
        image_arr.append(jnp.array(Image.open(io.BytesIO(frame_images['_img_cam0'][:]))))
    depth_arr = jnp.asarray(depth_arr)
    image_arr = jnp.asarray(image_arr)
    # image_arr is indexed (frame, row, column, ...), so its shape gives frame count and image size.
    FINAL_T, height, width = image_arr.shape[0], image_arr.shape[1], image_arr.shape[2]

    # Extract camera info (camera matrix of the first frame).
    camera_pose = np.array(f['azimuth']['cam_0'])
    camera_matrix = np.array(f['frames']['0000']['camera_matrices']['camera_matrix_cam0'])
    camera_matrix = np.linalg.inv(camera_matrix.reshape((4, 4)))

    # Different from real-world camera coordinate system.
    # OpenGL uses negative z axis as the camera front direction.
    # x axes are same, hence y axis is reversed as well.
    # Source: https://learnopengl.com/Getting-started/Camera
    rot = np.array([[1, 0, 0, 0],
                    [0, -1, 0, 0],
                    [0, 0, -1, 0],
                    [0, 0, 0, 1]])
    camera_matrix = np.dot(camera_matrix, rot)

    # Calculate the intrinsic matrix from the vertical field of view.
    # Notice that hfov and vfov are different if height != width.
    # We can also get the intrinsic matrix from OpenGL's perspective matrix.
    # http://kgeorge.github.io/2014/03/08/calculating-opengl-perspective-matrix-from-opencv-intrinsic-matrix
    vfov_rad = np.deg2rad(vfov)
    tan_half_vfov = np.tan(vfov_rad / 2.0)
    tan_half_hfov = tan_half_vfov * width / float(height)
    fx = width / 2.0 / tan_half_hfov  # focal length in pixel space
    fy = height / 2.0 / tan_half_vfov

    # Extract per-object info from the 'static' group.
    object_ids = np.array(f['static']['object_ids'])
    model_names = np.array(f['static']['model_names'])
    distractors = np.array(f['static']['distractors'])
    occluders = np.array(f['static']['occluders'])
    initial_position = np.array(f['static']['initial_position'])
    initial_rotation = np.array(f['static']['initial_rotation'])
    scales = np.array(f['static']['scale'])
    meshes_faces = [np.array(f['static']['mesh'][f'faces_{idx}']) for idx in range(len(object_ids))]
    meshes_vertices = [np.array(f['static']['mesh'][f'vertices_{idx}']) for idx in range(len(object_ids))]

In [6]:
# for i in initial_rotation:
#     print(i)

In [7]:
# Renderer intrinsics from the HDF5 camera parameters.
# Order matters: width goes to image_width and height to image_height. The renderer is
# constructed as b3d.Renderer(image_width, image_height, ...), and a swap silently breaks
# non-square images.
image_width, image_height = int(width), int(height)
fx, fy, cx, cy, near, far = (
    float(fx),
    float(fy),
    float(width / 2.0),  # principal point at the image centre
    float(height / 2.0),
    float(near_plane),
    float(far_plane),
)

In [8]:
# scales

In [9]:
def scale_mesh(vertices, scale_factor):
    """Return a copy of ``vertices`` with the x, y and z columns scaled per axis.

    Args:
        vertices: array of vertex coordinates with at least three columns. Only
            the first three columns are scaled; any extra columns are copied as-is.
        scale_factor: per-axis scale ``[sx, sy, sz]``. Extra entries are ignored.

    Returns:
        A new array. ``vertices`` itself is left untouched. Integer input is
        promoted to a floating dtype so fractional scales are not truncated.
    """
    # Work on a copy: scaling in place mutated the caller's arrays (the vertices
    # loaded from the HDF5 file), so re-running the calling cell rescaled them again.
    vertices = np.asarray(vertices)
    factors = np.asarray(scale_factor)[:3]
    scaled = vertices.astype(np.promote_types(vertices.dtype, np.float32), copy=True)
    scaled[:, :3] *= factors
    return scaled

In [10]:
# Split the scene into tracked objects vs. distractors/occluders, then gather the initial
# pose, scale and (scaled) mesh of each tracked object.
# np.isin tests set membership. The previous `model_names == distractors` was an elementwise
# comparison that only works when the arrays broadcast (e.g. exactly one distractor).
is_excluded = np.isin(model_names, distractors) | np.isin(model_names, occluders)
excluded_model_ids = np.flatnonzero(is_excluded)
included_model_names = [model_names[idx] for idx in range(len(object_ids)) if not is_excluded[idx]]
# NOTE(review): assumes object_ids are 1-based and aligned with the per-object HDF5 arrays — confirm.
included_model_ids = [object_ids[idx] - 1 for idx in range(len(object_ids)) if not is_excluded[idx]]
included_id_set = {int(i) for i in included_model_ids}
object_initial_positions = [pos for idx, pos in enumerate(initial_position) if idx in included_id_set]
object_initial_rotations = [rot for idx, rot in enumerate(initial_rotation) if idx in included_id_set]
object_scales = [scale for idx, scale in enumerate(scales) if idx in included_id_set]
# Index the unfiltered `scales` with the object's own index. `object_scales` is already
# filtered, so `object_scales[idx]` picked another object's scale (or raised IndexError)
# as soon as an excluded object preceded an included one.
object_meshes = [
    (scale_mesh(vertices, scales[idx]), faces)
    for idx, (faces, vertices) in enumerate(zip(meshes_faces, meshes_vertices))
    if idx in included_id_set
]


In [11]:
# object_meshes

In [12]:
# path = os.path.join(
#     b3d.get_root_path(),
#     "assets/shared_data_bucket/input_data/shout_on_desk.r3d.video_input.npz",
# )
# video_input = b3d.io.VideoInput.load(path)

# # Get intrinsics
# image_width, image_height, fx, fy, cx, cy, near, far = np.array(
#     video_input.camera_intrinsics_depth
# )
# image_width, image_height = int(image_width), int(image_height)
# fx, fy, cx, cy, near, far = (
#     float(fx),
#     float(fy),
#     float(cx),
#     float(cy),
#     float(near),
#     float(far),
# )

# # Get RGBS and Depth
# rgbs = video_input.rgb[::4] / 255.0
# xyzs = video_input.xyz[::4]

In [13]:
# image_width, image_height

In [14]:
# fx, fy, cx, cy, near, far

In [15]:
# object_initial_positions

In [16]:
# # Make empty library
# object_library = bayes3d.MeshLibrary.make_empty_library()

# for (vertices, faces) in object_meshes:
#     object_library.add_object(vertices, faces)
# print(f"{object_library.get_num_objects()} objects in library")

In [17]:
# Build the mesh library from the unscaled Flex meshes of the tracked objects,
# one library entry per name in included_model_names (names are stored as bytes).
object_library = bayes3d.MeshLibrary.make_empty_library()
for name_bytes in included_model_names:
    mesh_file = f"all_flex_meshes/{name_bytes.decode('UTF-8')}.obj"
    object_library.add_trimesh(trimesh.load(os.path.join(physion_assets_path, mesh_file)))

print(f"{object_library.get_num_objects()} objects in library")

3 objects in library


In [18]:
# Candidate pose perturbations for enumerative moves.
DELTA_GRID_SIZE = 11            # samples per translation axis
TRANSLATION_HALF_RANGE = 0.01   # translations span [-0.01, 0.01] on each axis

# Translation-only deltas: a regular 3D grid, plus the identity (no-move) pose.
axis_offsets = jnp.linspace(-TRANSLATION_HALF_RANGE, TRANSLATION_HALF_RANGE, DELTA_GRID_SIZE)
grid_translations = jnp.stack(
    jnp.meshgrid(axis_offsets, axis_offsets, axis_offsets), axis=-1
).reshape(-1, 3)
translation_deltas = Pose.concatenate_poses(
    [jax.vmap(Pose.from_translation)(grid_translations), Pose.identity()[None, ...]]
)

# Orientation deltas: as many samples as the translation grid, drawn around the identity
# with Pose.sample_gaussian_vmf_pose, plus the identity pose.
# NOTE(review): the two scalars are presumably the translation variance and the vMF
# concentration — confirm against Pose.sample_gaussian_vmf_pose.
vmf_keys = jax.random.split(jax.random.PRNGKey(0), DELTA_GRID_SIZE ** 3)
sampled_rotations = jax.vmap(Pose.sample_gaussian_vmf_pose, in_axes=(0, None, None, None))(
    vmf_keys, Pose.identity(), 0.00001, 1000.0
)
rotation_deltas = Pose.concatenate_poses([sampled_rotations, Pose.identity()[None, ...]])

all_deltas = Pose.stack_poses([translation_deltas, rotation_deltas])


In [19]:
# all_deltas

In [20]:
# all_deltas[1]

In [21]:
# Renderer, generative model and a jitted importance function for building traces.
key = jax.random.PRNGKey(0)
renderer = b3d.Renderer(image_width, image_height, fx, fy, cx, cy, near, far)
model = bayes3d.model_multiobject_gl_factory(renderer)
importance_jit = jax.jit(model.importance)

# Likelihood arguments of the generative model. The *_error values set the inlier/outlier
# decision boundary for color and depth error; a color_error of 1e100 presumably makes
# every pixel a color inlier — confirm.
color_error = 1e100
depth_error = 0.01
inlier_score = 5.0
outlier_prob = 0.00001
color_multiplier = 10000.0
depth_multiplier = 500.0
model_args = bayes3d.ModelArgs(
    color_error,
    depth_error,
    inlier_score,
    outlier_prob,
    color_multiplier,
    depth_multiplier,
)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [22]:
def euler_angles_to_quaternion(euler: np.ndarray) -> np.ndarray:
    """
    Convert Euler angles in degrees to a quaternion in ``[x, y, z, w]`` order.

    Source: https://pastebin.com/riRLRvch

    :param euler: ``[pitch, yaw, roll]`` in degrees. In isolation, each angle rotates
                  about the x, y and z axis respectively.

    :return: The unit quaternion ``[x, y, z, w]`` for the given Euler angles.
    """
    # Half angles, converted from degrees to radians.
    pitch = np.radians(euler[0] * 0.5)
    cp = np.cos(pitch)
    sp = np.sin(pitch)

    yaw = np.radians(euler[1] * 0.5)
    cy = np.cos(yaw)
    sy = np.sin(yaw)

    roll = np.radians(euler[2] * 0.5)
    cr = np.cos(roll)
    sr = np.sin(roll)

    x = sy * cp * sr + cy * sp * cr
    y = sy * cp * cr - cy * sp * sr
    z = cy * cp * sr - sy * sp * cr
    w = cy * cp * cr + sy * sp * sr
    # Return the signed components. Taking np.abs of each component does NOT give an
    # equivalent rotation (only q and -q are equivalent). For example, -90 deg and +90 deg
    # about x would both collapse to the same quaternion.
    return np.array([x, y, z, w])

In [23]:
# Pose.identity()

In [24]:
# image_arr[0]

In [29]:
# Initial trace for timestep START_T: camera at the origin, and each tracked object at its
# recorded initial pose, bound to the library mesh with the same index.
# NOTE(review): initial_position / initial_rotation come from the HDF5 'static' group, while
# the camera is fixed at identity here (camera_matrix is computed but not used). Also confirm
# that the simulator's Euler convention matches b3d's frame.
START_T = 0
num_objects = len(object_initial_positions)

initial_constraints = {
    "camera_pose": Pose(jnp.array([0.0, 0.0, 0.0]), jnp.array([0.0, 0.0, 0.0, 1.0])),
}
for i in range(num_objects):
    initial_constraints[f"object_pose_{i}"] = Pose(
        jnp.asarray(object_initial_positions[i]),
        jnp.asarray(euler_angles_to_quaternion(object_initial_rotations[i])),
    )
for i in range(num_objects):
    initial_constraints[f"object_{i}"] = i
initial_constraints["observed_rgb_depth"] = (image_arr[START_T], depth_arr[START_T])

trace, _ = importance_jit(
    jax.random.PRNGKey(0),
    genjax.ChoiceMap.d(initial_constraints),
    (jnp.arange(num_objects), model_args, object_library),
)

In [30]:
# Log the initial trace to Rerun at timestep 0.
bayes3d.rerun_visualize_trace_t(trace, 0)

In [31]:
# Inspect the initial positions of the tracked objects.
object_initial_positions

In [32]:
# Inspect the first depth frame.
depth_arr[0]

In [28]:
# Inspect the trace object.
# NOTE(review): these cells were executed out of order (In [28] after In [32]); re-check under Restart & Run All.
trace

In [35]:
# Inspect the trace's return value.
trace.get_retval()

In [30]:
# Visualize trace (same call as the first cell in this group).
bayes3d.rerun_visualize_trace_t(trace, 0)


In [None]:
# Track the camera alone for the first ACQUISITION_T frames. Observations come from the
# Physion frames loaded above. `rgbs_resized` / `xyzs` belonged to the old r3d input and are
# undefined here.
# NOTE(review): image_arr holds raw PIL pixel values, while the r3d version divided RGB by 255 —
# confirm the range the model expects.
ACQUISITION_T = 90
# Frame ACQUISITION_T is also used later to initialize the acquired object, so it must exist.
assert ACQUISITION_T < len(depth_arr), (
    f"ACQUISITION_T={ACQUISITION_T} must be smaller than the number of frames ({len(depth_arr)})"
)
for T_observed_image in tqdm(range(ACQUISITION_T)):
    # Constrain on new RGB and Depth data (tuple address, matching the other update calls).
    trace = b3d.update_choices(
        trace,
        Pytree.const(("observed_rgb_depth",)),
        (image_arr[T_observed_image], depth_arr[T_observed_image]),
    )
    trace, key = bayes3d.enumerate_and_select_best_move(
        trace, Pytree.const(("camera_pose",)), key, all_deltas
    )
    bayes3d.rerun_visualize_trace_t(trace, T_observed_image)


In [None]:
# Pixels the current trace cannot explain, as reported by get_rgb_depth_inliers_from_trace
# (presumably the AND of the RGB and depth outlier masks — confirm).
(
    _inliers,
    _color_inliers,
    _depth_inliers,
    outliers,
    _undecided,
    _valid_data_mask,
) = bayes3d.get_rgb_depth_inliers_from_trace(trace)
outlier_mask = outliers
outlier_mask_rgb = jnp.tile((outlier_mask * 1.0)[..., None], (1, 1, 3))
rr.log("outliers", rr.Image(outlier_mask_rgb))

# Back-project the observed depth and keep only the outlier pixels (with their colors).
rgb, depth = trace.get_choices()["observed_rgb_depth"]
point_cloud = b3d.xyz_from_depth(depth, fx, fy, cx, cy)[outlier_mask]
point_cloud_colors = rgb[outlier_mask]

# Segment the outlier cloud and keep cluster 0, presumably the largest cluster.
assignment = b3d.segment_point_cloud(point_cloud)
in_kept_cluster = assignment == 0
point_cloud = point_cloud.reshape(-1, 3)[in_kept_cluster]
point_cloud_colors = point_cloud_colors.reshape(-1, 3)[in_kept_cluster]

# Mesh the cluster. The per-point resolution is proportional to depth (z / fx * 2).
mesh_resolution = point_cloud[:, 2] / fx * 2.0
vertices, faces, vertex_colors, _face_colors = b3d.make_mesh_from_point_cloud_and_resolution(
    point_cloud, point_cloud_colors, mesh_resolution
)

# Store the mesh relative to its vertex centroid. object_pose places it back in the scene.
# NOTE(review): not idempotent — every run adds another object to object_library.
object_pose = Pose.from_translation(vertices.mean(0))
vertices = object_pose.inverse().apply(vertices)
object_library.add_object(vertices, faces, vertex_colors)

In [None]:
# Inspect the mesh library after adding the acquired object.
object_library

In [None]:
# Inspect the library's `ranges` attribute (presumably per-object index ranges into the packed mesh data — confirm).
object_library.ranges

In [None]:
# Length of `ranges`; presumably one entry per object in the library.
len(object_library.ranges)

In [None]:
# Re-initialize the trace at frame ACQUISITION_T with two slots:
# object 0 keeps its current pose, and slot 1 holds the newly acquired mesh.
# The new mesh's pose is object_pose (camera frame) composed with the camera pose.
single_object_trace = trace
previous_choices = single_object_trace.get_choices()

# The acquired mesh was the last object added to the library (assumes add_object appends —
# confirm). Hard-coding id 1 selected one of the Physion meshes instead.
new_object_id = object_library.get_num_objects() - 1

assert ACQUISITION_T < len(depth_arr), (
    f"ACQUISITION_T={ACQUISITION_T} must be smaller than the number of frames ({len(depth_arr)})"
)
trace, _ = importance_jit(
    jax.random.PRNGKey(0),
    genjax.ChoiceMap.d(
        dict(
            [
                ("camera_pose", previous_choices["camera_pose"]),
                ("object_pose_0", previous_choices["object_pose_0"]),
                ("object_pose_1", previous_choices["camera_pose"] @ object_pose),
                ("object_0", 0),
                ("object_1", new_object_id),
                (
                    "observed_rgb_depth",
                    # Physion frames; `rgbs_resized` / `xyzs` are undefined in this notebook.
                    (image_arr[ACQUISITION_T], depth_arr[ACQUISITION_T]),
                ),
            ]
        )
    ),
    (jnp.arange(2), model_args, object_library),
)

In [None]:
# Visualize trace: log the re-initialized two-object trace to Rerun at the acquisition frame.
bayes3d.rerun_visualize_trace_t(trace, ACQUISITION_T)

In [None]:
# Track the camera and the object in slot 1 over the remaining Physion frames.
# `xyzs` / `rgbs_resized` belonged to the old r3d input and are undefined here.
FINAL_T = len(depth_arr)
for T_observed_image in tqdm(range(ACQUISITION_T, FINAL_T)):
    # Constrain on new RGB and Depth data.
    trace = b3d.update_choices(
        trace,
        Pytree.const(("observed_rgb_depth",)),
        (image_arr[T_observed_image], depth_arr[T_observed_image]),
    )
    # Coordinate-wise enumerative updates: first the camera, then object slot 1.
    trace, key = bayes3d.enumerate_and_select_best_move(
        trace, Pytree.const(("camera_pose",)), key, all_deltas
    )
    trace, key = bayes3d.enumerate_and_select_best_move(
        trace, Pytree.const(("object_pose_1",)), key, all_deltas
    )
    bayes3d.rerun_visualize_trace_t(trace, T_observed_image)
