In [51]:
import sys
import jax
import genjax
import bayes3d as b
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
sys.path.append("../")
from viz import *
from PIL import Image

In [48]:
class MCS_Observation:
    """Plain container for one frame of scene data.

    The class only stores the four values it is given. It is deliberately a
    regular class with ordinary instance attributes so that instances
    previously pickled into the ``.npz`` archives can still be restored.
    """

    def __init__(self, rgb, depth, intrinsics, segmentation):
        """Store one RGB-D frame together with its camera parameters.

        Args:
            rgb (np.array): RGB image.
            depth (np.array): Depth image.
            intrinsics: Camera intrinsics. ``npz_to_data`` in this notebook
                reads it by string key ("height", "width", "fx", "fy", "cx",
                "cy", "near", "far"), so a mapping is expected there.
                NOTE(review): confirm the concrete type stored in the archives.
            segmentation (np.array): Segmentation image.
        """
        self.rgb = rgb
        self.depth = depth
        self.intrinsics = intrinsics
        self.segmentation = segmentation

def load_observations(x, data_dir='val7_physics_npzs'):
    """Load the observation array stored in ``<data_dir>/<x>.npz``.

    Args:
        x (str): Scene name, i.e. the archive's file name without the
            ``.npz`` extension.
        data_dir (str): Directory that holds the archives. Defaults to the
            directory this notebook originally hard-coded.

    Returns:
        np.ndarray: The array saved under the key ``"arr_0"``.

    Raises:
        FileNotFoundError: If the archive does not exist.
        KeyError: If the archive has no ``"arr_0"`` entry.
    """
    archive_path = data_dir + "/{}.npz".format(x)
    # SECURITY: allow_pickle=True lets the archive execute arbitrary code while
    # it is being unpickled. Only load archives from a trusted source.
    # The with-block closes the underlying file handle; the requested array is
    # fully read into memory before the archive is closed.
    with np.load(archive_path, allow_pickle=True) as archive:
        return archive["arr_0"]

def npz_to_data(observations, scale = 0.5):
    """Downscale a sequence of observations and unproject their depth maps.

    Args:
        observations: Indexable, iterable collection of objects exposing
            ``rgb``, ``depth``, ``segmentation`` and ``intrinsics``
            attributes. Only the first element's ``intrinsics`` is read, and
            it is accessed by string key ("height", "width", "fx", "fy",
            "cx", "cy", "near", "far").
        scale (float): Factor applied to the camera parameters and to the
            first two axes of every image.

    Returns:
        tuple: ``(gt_images, depths, segs, rgbs)``. ``gt_images`` is whatever
        ``b.unproject_depth_vmap_jit`` returns for the stacked, resized depth
        maps; the other three are Python lists holding one resized array per
        observation.
    """
    def _shrink(image, *trailing_dims):
        # Nearest-neighbour resize of the two leading axes; the target size is
        # truncated with int(), and any trailing dims are passed through as-is.
        target_shape = (int(image.shape[0] * scale),
                        int(image.shape[1] * scale),
                        *trailing_dims)
        return jax.image.resize(image, target_shape, 'nearest')

    # Build full-resolution intrinsics from the first observation, then scale.
    raw = observations[0].intrinsics
    field_order = ("height", "width", "fx", "fy", "cx", "cy", "near", "far")
    full_size_intrinsics = b.Intrinsics(*(raw[field] for field in field_order))
    intrinsics = b.scale_camera_parameters(full_size_intrinsics, scale)

    depths = [_shrink(frame.depth) for frame in observations]
    segs = [_shrink(frame.segmentation) for frame in observations]
    # RGB keeps a fixed trailing axis of size 3.
    rgbs = [_shrink(frame.rgb, 3) for frame in observations]

    stacked_depths = jnp.stack(depths)
    gt_images = b.unproject_depth_vmap_jit(stacked_depths, intrinsics)

    return gt_images, depths, segs, rgbs


# def video_from_segmentation(segs):
#     np.random.seed(0)
#     segs_stacked = np.stack(segs)
#     palette = np.random.randint(0, 256, size=(len(np.unique(segs_stacked)), 3))
#     label_to_index = {label: index for index, label in enumerate(np.unique(segs_stacked))}

#     # Iterate over each frame
#     frames = []
#     for i, seg in enumerate(segs_stacked):

#         color_image = np.zeros((*seg.shape, 3), dtype=np.uint8)

#         # Map each label to its corresponding color
#         for label, index in label_to_index.items():
#             color_image[seg == label] = palette[index]

#         # Convert to PIL image
#         frames.append(Image.fromarray(color_image))

#     return display_video(frames)


In [56]:
# Scene archive to inspect (file name without the .npz extension).
scene_name = 'passive_physics_validation_object_permanence_0001_24'
observations = load_observations(scene_name)

# Downscale every frame by 0.25 and unproject the depth maps.
gt_images, depths, segs, rgbs = npz_to_data(observations, scale=0.25)

# Kept as the trailing expression so the notebook displays its return value.
video_from_rendered(gt_images, scale=4)

In [58]:
# Sanity-check the array returned by npz_to_data. The recorded output
# (240, 100, 150, 3) presumably reads as (frames, height, width, xyz) for
# scale = 0.25 — TODO confirm the axis meanings.
gt_images.shape

(240, 100, 150, 3)