## Cognitive Battery Introduction: Jax-3DP3

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from jax3dp3.viz.img import save_depth_image
from jax3dp3.utils import depth_to_coords_in_camera
from jax3dp3.transforms_3d import transform_from_pos
from jax3dp3.shape import (
    get_rectangular_prism_shape,
)
from jax3dp3.likelihood import threedp3_likelihood
import jax.numpy as jnp

import jax
from scipy.spatial.transform import Rotation as R
from jax3dp3.rendering import render_planes_multiobject
from jax3dp3.enumerations import make_translation_grid_enumeration
from jax3dp3.enumerations_procedure import enumerative_inference_single_frame

Set the number of frames, the path to the data, and the pinhole camera intrinsics (image size, focal lengths, principal point):

In [None]:
# Number of video frames and the directory that holds the frames/, depths/ and
# segmented/ sub-folders read below.
num_frames = 103
data_path = "data/videos"

# Pinhole camera parameters: image size, focal lengths and principal point.
width, height = 300, 300
fx, fy = 150, 150
cx, cy = 150, 150

fx_fy = jnp.array([fx, fy])
cx_cy = jnp.array([cx, cy])

# 3x3 intrinsics matrix  [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
K = jnp.array(
    [
        [fx, 0.0, cx],
        [0.0, fy, cy],
        [0.0, 0.0, 1.0],
    ]
)

Load the ground-truth RGB image, depth map, and segmentation map for every frame.

In [None]:
# Index-aligned per-frame data: PIL RGB frames, the same frames as NumPy arrays,
# depth maps, and segmentation maps.
rgb_images_pil = []
rgb_images, depth_images, seg_maps = [], [], []

for frame_no in range(num_frames):
    frame_rgb = Image.open(os.path.join(data_path, f"frames/frame_{frame_no}.jpeg"))
    rgb_images_pil.append(frame_rgb)
    rgb_images.append(np.array(frame_rgb))

    depth_images.append(np.load(os.path.join(data_path, f"depths/frame_{frame_no}.npy")))
    seg_maps.append(np.load(os.path.join(data_path, f"segmented/frame_{frame_no}.npy")))

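Back-project each depth image into camera-space coordinates, then mask the coordinate and segmentation images so that only the relevant part of the scene remains (i.e., the box above the table); everything outside that box is set to 0.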

In [None]:
# Region of interest in camera coordinates (the box above the table), given as
# open (low, high) intervals for the x, y and z channels. Pixels whose
# back-projected point falls outside it are zeroed in both the coordinate and
# segmentation images.
ROI_BOUNDS = ((-0.5, 1.0), (-0.5, 0.28), (1.2, 4.0))


def outside_roi(coord_image, bounds=ROI_BOUNDS):
    """Return a boolean mask that is True where a pixel lies outside `bounds`.

    Args:
        coord_image: NumPy array whose last axis holds per-pixel camera
            coordinates; channel i is tested against bounds[i].
        bounds: sequence of (low, high) pairs, one per tested channel. Both ends
            are exclusive, so points exactly on a bound count as outside.

    Returns:
        Boolean array of shape coord_image.shape[:-1].
    """
    inside = np.ones(coord_image.shape[:-1], dtype=bool)
    for axis, (low, high) in enumerate(bounds):
        channel = coord_image[..., axis]
        inside &= (channel > low) & (channel < high)
    return ~inside


# NOTE(review): an unused per-frame count `k` (5 for frames 5..18, else 4) was
# dropped here; its comment said "4 objects" for those frames -- confirm the
# intended object count if it is needed again.
masked_coord_images = []   # per-frame coordinate images, zeroed outside the ROI
masked_seg_images = []     # per-frame segmentation images, zeroed outside the ROI

for frame_idx in range(num_frames):
    coord_image, _ = depth_to_coords_in_camera(depth_images[frame_idx], K)
    # Copy so masking does not overwrite the raw frames kept in `seg_maps`.
    segmentation_image = seg_maps[frame_idx].copy()
    mask = outside_roi(coord_image)
    coord_image[mask, :] = 0.0
    segmentation_image[mask, :] = 0.0
    masked_coord_images.append(coord_image)
    masked_seg_images.append(segmentation_image)

masked_coord_images = np.stack(masked_coord_images)
masked_seg_images = np.stack(masked_seg_images)

In [None]:
start_t = 0

coord_image = masked_coord_images[start_t]
seg_image = masked_seg_images[start_t]
# The frames are NumPy arrays (built with np.stack), so use np.unique: jnp.unique
# would copy the data to the device and make every mask and point array below a
# jax array for no benefit.
obj_ids = np.unique(seg_image[..., 0])

# For every segmented object in frame `start_t`, fit an axis-aligned box to its
# camera-space points and derive an initial pose and prism shape from that box.
shape_planes, shape_dims, init_poses = [], [], []
for obj_id in obj_ids:
    # 0 is the value the ROI-masking step writes for discarded pixels
    # (it may also be the raw background ID -- confirm).
    if obj_id == 0:
        continue
    obj_mask = seg_image[..., 0] == obj_id
    # Boolean indexing already selects only this object's pixels, so there is no
    # need to multiply the whole image by the mask first.
    object_points = coord_image[obj_mask]
    maxs = object_points.max(axis=0)
    mins = object_points.min(axis=0)
    dims = maxs - mins                  # box extent along each coordinate channel
    center_of_box = (maxs + mins) / 2   # box center, passed to transform_from_pos

    init_pose = transform_from_pos(center_of_box)
    init_poses.append(init_pose)

    shape, dim = get_rectangular_prism_shape(dims)
    shape_planes.append(shape)
    shape_dims.append(dim)

In [None]:
# Inspect a single object of frame `start_t` in isolation.
start_t = 0
obj_id = 172  # NOTE(review): hard-coded segmentation ID; confirm it appears in this frame

coord_image = masked_coord_images[start_t]
seg_image = masked_seg_images[start_t]

obj_mask = seg_image[..., 0] == obj_id
# Coordinate / segmentation images with every pixel outside this object zeroed.
masked_coord_image = coord_image * obj_mask[:, :, None]
masked_seg_image = seg_image * obj_mask[:, :, None]

# Camera-space points of the object only (one row per object pixel).
object_points = masked_coord_image[obj_mask]


In [None]:
# (number of object pixels, coordinate channels per pixel)
np.shape(object_points)