In [None]:
import bayes3d as b
import bayes3d as j
import glob
import os
import jax.numpy as jnp
import jax
from tqdm import tqdm
import machine_common_sense as mcs
import numpy as np

In [None]:
# Glob pattern (not a regular expression, despite the variable name) matching
# the scene JSON files under <assets>/mcs_scene_jsons/eval_6_validation.
scene_regex = os.path.join(
    b.utils.get_assets_dir(),
    "mcs_scene_jsons",
    "eval_6_validation",
    "passive_physics_spatio_temporal_continuity*",
)
# scene_regex = os.path.join(j.utils.get_assets_dir(), "mcs_scene_jsons", "eval_6_validation", "passive_physics_object*")

# Sorted so that an index into `files` refers to the same scene on every run.
files = glob.glob(scene_regex)
files.sort()
files

In [None]:
def load_mcs_scene_data(scene_path):
    """Return the RGB-D frames of an MCS scene, simulating it only on a cache miss.

    The frames are cached as ``<assets>/mcs_cache/<scene file name>.npz``. When
    that file exists it is loaded and returned directly; otherwise the scene is
    stepped with the "Pass" action until the controller returns ``None``, every
    step is converted with ``b.RGBD.construct_from_step_metadata`` and the
    result is written to the cache.

    Args:
        scene_path: Path of the MCS scene JSON file.

    Returns:
        The frames of the scene: the array stored under ``"arr_0"`` on a cache
        hit, or the freshly built Python list on a cache miss.
    """
    cache_dir = os.path.join(b.utils.get_assets_dir(), "mcs_cache")
    scene_name = os.path.basename(scene_path)
    cache_filename = os.path.join(cache_dir, f"{scene_name}.npz")

    if os.path.exists(cache_filename):
        # NOTE: allow_pickle=True unpickles arbitrary objects, so only cache
        # files written by this function should ever be placed in cache_dir.
        # The context manager closes the underlying archive once the array has
        # been read.
        with np.load(cache_filename, allow_pickle=True) as cached:
            return cached["arr_0"]

    controller = mcs.create_controller(
        os.path.join(b.utils.get_assets_dir(), "mcs_scene_jsons",  "config_level2.ini")
    )
    scene_data = mcs.load_scene_json_file(scene_path)

    # Collect the metadata of every step; the controller signals the end of
    # the scene by returning None.
    step_metadatas = [controller.start_scene(scene_data)]
    while True:
        step_metadata = controller.step("Pass")
        if step_metadata is None:
            break
        step_metadatas.append(step_metadata)

    images = [
        b.RGBD.construct_from_step_metadata(step_metadata)
        for step_metadata in tqdm(step_metadatas)
    ]

    # np.savez does not create missing directories, so make sure the cache
    # directory exists before the first write.
    os.makedirs(cache_dir, exist_ok=True)
    np.savez(cache_filename, images)

    return images

In [None]:
# Scene to analyse (index into `files`) and the first frame to keep; frames
# before START_FRAME are dropped from the sequence.
#
# The previous version of this cell loaded files[0] and then immediately
# overwrote the result with files[1]; only the final selection is loaded now,
# which avoids simulating or deserialising a scene that is never used.
SCENE_INDEX = 1
START_FRAME = 85

# Frame windows recorded in earlier (commented-out) revisions of this cell,
# written against the older `jax3dp3.physics.load_mcs_scene_data` loader:
#   files[0]: [102:150]    files[1]: [90:]       files[2]: [102:]
#   files[3]: [102:152]    files[4]: [102:152]   files[6]: [85:152]
#   files[7]: [85:152]     files[8]: [90:152]    files[9]: [90:152]
# With the current loader, [85:] was also used for files[0], files[2] and files[3].

filename = files[SCENE_INDEX]
images = load_mcs_scene_data(filename)[START_FRAME:]


In [None]:
# File name of the selected scene without its directory; it is used later as
# the stem of the output GIF. os.path.basename honours the platform's path
# separator, whereas splitting on "/" returns the whole path on Windows.
scene_name = os.path.basename(filename)
scene_name

In [None]:
# Quick visual sanity check: show the RGB channel of frame 20 of the loaded
# sequence. NOTE(review): the index is hard-coded and raises IndexError when
# `images` holds 20 frames or fewer.
b.get_rgb_image(images[20].rgb)

In [None]:
# Depth used (plus a 0.1 margin) as the last argument of the Intrinsics built
# below -- presumably the far clipping plane; TODO confirm against j.Intrinsics.
WALL_Z = 14.5

# Work at a quarter of the original camera resolution, then rebuild the
# intrinsics with every field kept except the last one, which is replaced by
# WALL_Z + 0.1.
original_intrinsics = images[0].intrinsics
intrinsics = j.camera.scale_camera_parameters(original_intrinsics, 0.25)
intrinsics = j.Intrinsics(
    intrinsics.height,
    intrinsics.width,
    intrinsics.fx,
    intrinsics.fy,
    intrinsics.cx,
    intrinsics.cy,
    intrinsics.near,
    WALL_Z + 0.1,
)

# Half-extent of the widest translation grid along each axis.
dx = 0.7
dy = 0.7
dz = 0.7

# Three translation grids with the same 21 x 15 x 15 resolution whose extent
# shrinks from the full half-extent to 1/2 and then 1/10 of it.
# (Argument order assumed to be min x/y/z, max x/y/z, counts -- TODO confirm.)
gridding = [
    j.make_translation_grid_enumeration(
        -dx / shrink, -dy / shrink, -dz / shrink,
        dx / shrink, dy / shrink, dz / shrink,
        21, 15, 15,
    )
    for shrink in (1.0, 2.0, 10.0)
]

In [None]:
def get_object_mask(point_cloud_image, segmentation, segmentation_ids):
    """Select the segments whose bounding box has an object-like size.

    For every id in ``segmentation_ids`` the points of that segment are boxed
    with ``j.utils.aabb``. A segment is discarded (the original code calls such
    a segment an "occluder") when its box is thinner than 0.1 or wider than 1.1
    along either of the first two axes, or deeper than 2.1 along the third.

    Args:
        point_cloud_image: Per-pixel 3D points; the first two dimensions index
            the pixels.
        segmentation: Per-pixel segment ids, compared against each id.
        segmentation_ids: Iterable of segment ids to examine.

    Returns:
        ``(object_ids, object_mask)``: the list of kept ids (in input order)
        and a boolean per-pixel mask that is True on every kept segment.
    """
    kept_ids = []
    coverage = jnp.zeros(point_cloud_image.shape[:2])

    for seg_id in segmentation_ids:
        in_segment = segmentation == seg_id
        extent, _ = j.utils.aabb(point_cloud_image[in_segment])

        too_thin = (extent[0] < 0.1) | (extent[1] < 0.1)
        too_large = (extent[1] > 1.1) | (extent[0] > 1.1) | (extent[2] > 2.1)
        if too_thin | too_large:
            continue

        coverage = coverage + in_segment
        kept_ids.append(seg_id)

    return kept_ids, coverage > 0

def prior3(new_pose, prev_pose, prev_prev_pose, bbox_dims):
    """Log-score a candidate pose against a simple constant-velocity motion model.

    The position predicted for the new frame is the previous position plus the
    previous displacement, modified as follows:
      * the z component of the displacement is scaled by 0.25;
      * the x component is reduced by 0.01 in magnitude and y is increased by
        0.02 (the original code names this term "gravity");
      * the x component is zeroed unless its magnitude exceeds 0.1.
    The score is the log-density of the candidate position under an isotropic
    Gaussian (variance 0.02 per axis) centred on that prediction, minus 100
    when ``y + bbox_dims[1] / 2`` exceeds 1.5.

    Args:
        new_pose: 4x4 candidate pose; only the translation column is used.
        prev_pose: 4x4 pose at the previous frame.
        prev_prev_pose: 4x4 pose two frames back.
        bbox_dims: Object box dimensions; only index 1 is used.

    Returns:
        Scalar log-score.
    """
    position_now = new_pose[:3, 3]
    position_prev = prev_pose[:3, 3]
    position_prev_prev = prev_prev_pose[:3, 3]

    # Last observed displacement, with the z component damped to a quarter.
    step = (position_prev - position_prev_prev) * jnp.array([1.0, 1.0, 0.25])
    # Constant per-frame adjustment: slow down along x, add 0.02 along y.
    step = step + jnp.array([-jnp.sign(step[0]) * 0.01, 0.02, 0.0])
    # Suppress x motion entirely unless it is larger than 0.1.
    keep_x = 1.0 * (jnp.abs(step[0]) > 0.1)
    step = step * jnp.array([keep_x, 1.0, 1.0])

    expected_position = position_prev + step
    motion_logpdf = jax.scipy.stats.multivariate_normal.logpdf(
        position_now, expected_position, jnp.diag(jnp.array([0.02, 0.02, 0.02]))
    )

    # Penalise candidates whose box extends past y = 1.5.
    box_edge_y = position_now[1] + bbox_dims[1] / 2.0
    floor_penalty = -100.0 * (box_edge_y > 1.5)

    return 0.0 + motion_logpdf + floor_penalty


# Compiled single-candidate scorer.
prior_jit = jax.jit(prior3)
# Compiled batched scorer: vectorised over candidate poses only; the pose
# history and the box dimensions are shared by the whole batch.
prior_parallel_jit = jax.jit(
    jax.vmap(prior3, in_axes=(0, None, None, None))
)

def update_object_positions(OBJECT_POSES, ALL_OBJECT_POSES):
    """Refine the pose of every tracked object by searching translation grids.

    For each object the grids in the notebook-level ``gridding`` list are
    applied one after another: each grid transform is left-multiplied onto the
    current estimate, every resulting proposal is rendered and scored with the
    3DP3 likelihood plus the motion prior, and the best proposal becomes the
    estimate for the next grid. On the first grid the proposals are
    additionally seeded from the centre pose of every segment in ``object_ids``.

    Reads the notebook-level names ``gridding``, ``object_ids``,
    ``point_cloud_image``, ``point_cloud_image_complement``, ``segmentation``,
    ``renderer``, ``R``, ``OUTLIER_PROB`` and ``OUTLIER_VOLUME``.

    Args:
        OBJECT_POSES: Array of shape (num_objects, 4, 4) with the current
            estimates.
        ALL_OBJECT_POSES: List of pose arrays of earlier frames, newest last;
            used for the motion prior.

    Returns:
        The updated pose array (same shape as ``OBJECT_POSES``).
    """
    num_objects = OBJECT_POSES.shape[0]
    if num_objects == 0:
        # Nothing is tracked yet; skip the per-segment bounding boxes below.
        return OBJECT_POSES

    # The segment centre poses depend only on the current frame, so compute
    # them once instead of once per tracked object.
    segment_center_poses = [
        j.utils.aabb(point_cloud_image[segmentation == seg_id])[1]
        for seg_id in object_ids
    ]

    for known_id in range(num_objects):
        # Pose history for the motion prior; identical for every grid and
        # batch of this object, so it is looked up once here.
        prev_pose = ALL_OBJECT_POSES[-1][known_id]
        if len(ALL_OBJECT_POSES) < 2 or ALL_OBJECT_POSES[-2].shape[0] <= known_id:
            # No older observation of this object: assume zero velocity.
            prev_prev_pose = prev_pose
        else:
            prev_prev_pose = ALL_OBJECT_POSES[-2][known_id]
        box_dims = renderer.model_box_dims[known_id]

        current_pose_estimate = OBJECT_POSES[known_id, :, :]

        for gridding_iter, grid in enumerate(gridding):
            anchor_poses = [current_pose_estimate]
            if gridding_iter == 0:
                anchor_poses.extend(segment_center_poses)
            all_pose_proposals = jnp.vstack(
                [jnp.einsum("aij,jk->aik", grid, anchor) for anchor in anchor_poses]
            )

            all_weights = []
            # Proposals are scored in three batches.
            for batch in jnp.array_split(all_pose_proposals, 3):
                rendered_images = renderer.render_parallel(batch, known_id)[..., :3]
                rendered_images_spliced = j.splice_image_parallel(
                    rendered_images, point_cloud_image_complement
                )

                weights = j.threedp3_likelihood_parallel_jit(
                    point_cloud_image, rendered_images_spliced, R, OUTLIER_PROB, OUTLIER_VOLUME, 3
                ).reshape(-1)
                weights += prior_parallel_jit(
                    batch, prev_pose, prev_prev_pose, box_dims
                ).reshape(-1)

                all_weights.append(weights)
            all_weights = jnp.hstack(all_weights)

            current_pose_estimate = all_pose_proposals[all_weights.argmax()]

        OBJECT_POSES = OBJECT_POSES.at[known_id].set(current_pose_estimate)
    return OBJECT_POSES

def add_new_objects(OBJECT_POSES):
    """Start tracking every segment that the current render fails to explain.

    A segment from ``object_ids`` becomes a new object when all of these hold:
      * its mean per-pixel log-probability is at most -10;
      * it covers at least 14 pixels;
      * it does not touch any image border.
    For such a segment the visible front face is voxelised, repeated along z
    to give it thickness, meshed with marching cubes and added to ``renderer``;
    its pose is appended to the returned pose array.

    Reads the notebook-level names ``object_ids``, ``pixelwise_probs``,
    ``segmentation``, ``point_cloud_image``, ``intrinsics`` and ``renderer``.

    Args:
        OBJECT_POSES: Array of shape (num_objects, 4, 4).

    Returns:
        The pose array, extended by one 4x4 pose per newly added object.
    """
    for seg_id in object_ids:
        average_probability = jnp.mean(pixelwise_probs[segmentation == seg_id])
        print(seg_id, average_probability)

        # Already explained well enough by the objects tracked so far.
        if average_probability > -10.0:
            print("avererage_probibilty ", average_probability)
            continue

        num_pixels = jnp.sum(segmentation == seg_id)
        if num_pixels < 14:
            print("num_pixels", num_pixels)
            continue

        # Distance (in pixels) from the segment to the nearest image border.
        # The last valid row/column index is size - 1; measuring against
        # `size` itself (as before) could never yield 0, so segments touching
        # the bottom or right border were never rejected.
        # Assumes `segmentation` has shape (intrinsics.height, intrinsics.width).
        rows, cols = jnp.where(segmentation == seg_id)
        last_row = intrinsics.height - 1
        last_col = intrinsics.width - 1
        distance_to_edge_1 = min(jnp.abs(rows - 0).min(), jnp.abs(rows - last_row).min())
        distance_to_edge_2 = min(jnp.abs(cols - 0).min(), jnp.abs(cols - last_col).min())

        BUFFER = 1

        if distance_to_edge_1 < BUFFER or distance_to_edge_2 < BUFFER:
            print("distance to edge ", distance_to_edge_1, " ", distance_to_edge_2)
            continue

        point_cloud_segment = point_cloud_image[segmentation == seg_id]
        # Pose of the visible points only; used for the log line below.
        _, pose = j.utils.aabb(point_cloud_segment)

        # Voxelise the visible points, keep those within 20 voxels of the
        # nearest depth, and stack shifted copies of that front face along z
        # to give the shape thickness.
        resolution = 0.01
        voxelized = jnp.rint(point_cloud_segment / resolution).astype(jnp.int32)
        min_z = voxelized[:,2].min()
        depth = voxelized[:,2].max() - voxelized[:,2].min()

        front_face = voxelized[voxelized[:,2] <= min_z+20, :]
        slices = [front_face]
        for i in range(depth):
            slices.append(front_face + jnp.array([0.0, 0.0, i]))
        full_shape = jnp.vstack(slices) * resolution

        print("Seg ID: ", seg_id, "Prob: ", average_probability, " Pixels: ",num_pixels, " dists: ", distance_to_edge_1, " ", distance_to_edge_2, " Pose: ", pose[:3, 3])

        # Mesh the shape in the frame of its own bounding box, so that `pose`
        # places the mesh back into the scene.
        _, pose = j.utils.aabb(full_shape)
        mesh = j.mesh.make_marching_cubes_mesh_from_point_cloud(
            j.t3d.apply_transform(full_shape, j.t3d.inverse_pose(pose)),
            0.075
        )

        renderer.add_mesh(mesh)
        print("Adding new mesh")

        OBJECT_POSES = jnp.concatenate([OBJECT_POSES, pose[None, ...]], axis=0)
    return OBJECT_POSES

In [None]:
# Settings passed to the 3DP3 likelihood calls made while tracking.
R, OUTLIER_PROB, OUTLIER_VOLUME = 0.01, 0.01, 100.0

# Pose history, one (num_objects, 4, 4) array per processed frame. It starts
# with an empty stack because no object is tracked before the first frame.
ALL_OBJECT_POSES = [jnp.zeros(shape=(0, 4, 4))]
t = 0

# Fresh renderer at the down-scaled intrinsics; object meshes are added to it
# as new objects are discovered.
renderer = j.Renderer(intrinsics)


In [None]:
# Main tracking loop. Frame 0 is represented by the initial empty entry of
# ALL_OBJECT_POSES, so iteration starts at frame 1 and appends exactly one
# pose array per frame.
#
# The names assigned here (point_cloud_image, segmentation, object_ids,
# point_cloud_image_complement, pixelwise_probs, ...) are read as globals by
# update_object_positions and add_new_objects, so they must keep these names.
for t in tqdm(range(1, len(images))):
    print(f"{t}/{len(images)}")
    image = images[t]

    # Bring the observation down to the renderer resolution.
    depth = j.utils.resize(image.depth, intrinsics.height, intrinsics.width)
    point_cloud_image = j.t3d.unproject_depth(depth, intrinsics)
    segmentation = j.utils.resize(image.segmentation, intrinsics.height, intrinsics.width)
    segmentation_ids = jnp.unique(segmentation)
    object_ids, object_mask = get_object_mask(point_cloud_image, segmentation, segmentation_ids)

    # Scene without the object candidates: their pixels are pushed to the far
    # plane. (A `j.get_depth_image(1.0 * object_mask)` call used to sit here;
    # its result was discarded, so it has been dropped.)
    depth_complement = depth * (1.0 - object_mask) + intrinsics.far * (object_mask)
    point_cloud_image_complement = j.t3d.unproject_depth(depth_complement, intrinsics)

    # Start from the previous frame's poses and refine them once; earlier
    # revisions of this cell had two further refinement passes commented out.
    OBJECT_POSES = jnp.array(ALL_OBJECT_POSES[-1])
    OBJECT_POSES = update_object_positions(OBJECT_POSES, ALL_OBJECT_POSES)

    # Per-pixel scores of the refined scene; add_new_objects uses them to find
    # segments that no tracked object explains.
    rerendered = renderer.render_multiobject(OBJECT_POSES, jnp.arange(OBJECT_POSES.shape[0]))[...,:3]
    rerendered_spliced = j.splice_image_parallel(jnp.array([rerendered]), point_cloud_image_complement)[0]
    pixelwise_probs = j.threedp3_likelihood_per_pixel_jit(point_cloud_image, rerendered_spliced,  R, 0.0, 1.0, 5)

    OBJECT_POSES = add_new_objects(OBJECT_POSES)

    ALL_OBJECT_POSES.append(OBJECT_POSES)

In [None]:
# Replay the tracked sequence: re-render every frame with its inferred poses
# and score each object's motion with the prior, collecting everything needed
# for the visualisation cells in `data`.
data = []
for t in tqdm(range(len(images))):
    image = images[t]
    depth = j.utils.resize(image.depth, intrinsics.height, intrinsics.width)
    point_cloud_image = j.t3d.unproject_depth(depth, intrinsics)
    segmentation = j.utils.resize(image.segmentation, intrinsics.height, intrinsics.width)
    segmentation_ids = jnp.unique(segmentation)

    # Recompute the object mask for THIS frame. The call used to be commented
    # out, which silently reused the mask left over from the last iteration of
    # the tracking loop for every frame.
    object_ids, object_mask = get_object_mask(point_cloud_image, segmentation, segmentation_ids)
    depth_complement = depth * (1.0 - object_mask) + intrinsics.far * (object_mask)
    point_cloud_image_complement = j.t3d.unproject_depth(depth_complement, intrinsics)

    OBJECT_POSES = ALL_OBJECT_POSES[t]
    rerendered = renderer.render_multiobject(OBJECT_POSES, jnp.arange(OBJECT_POSES.shape[0]))
    rerendered_spliced = j.splice_image_parallel(jnp.array([rerendered[...,:3]]), point_cloud_image_complement)[0]
    pixelwise_probs = j.threedp3_likelihood_per_pixel_jit(point_cloud_image, rerendered_spliced, R, 0.0, 1.0, 5)

    # Motion-prior score of every object that already existed in both of the
    # two preceding frames; younger objects are not scored.
    weights = []
    if t >= 2:
        for known_id in range(len(ALL_OBJECT_POSES[t])):
            if (
                ALL_OBJECT_POSES[t-1].shape[0] <= known_id
                or ALL_OBJECT_POSES[t-2].shape[0] <= known_id
            ):
                continue

            weight = prior_jit(
                ALL_OBJECT_POSES[t][known_id],
                ALL_OBJECT_POSES[t-1][known_id],
                ALL_OBJECT_POSES[t-2][known_id],
                renderer.model_box_dims[known_id]
            ).reshape(-1)
            weights.append(weight)

    data.append(
        (
            image.rgb,
            point_cloud_image,
            rerendered,
            rerendered_spliced,
            pixelwise_probs,
            weights
        )
    )

# One number per frame: the sum of that frame's prior scores (0 when no object
# was scored).
weights_over_time = [jnp.array(d[-1]).sum() for d in data]

In [None]:
import matplotlib.pyplot as plt
import io
import numpy as np
from PIL import Image
# NOTE(review): `t` is reassigned by the loops in the following cells; this
# value looks like a leftover from interactive use -- confirm before removing.
t = 50

def make_plot(x, y, xlabel):
    """Draw a line plot of ``y`` over ``x`` and return it as a PIL image.

    The current matplotlib figure is cleared and reused. The x axis always
    spans ``0 .. len(images)`` (``images`` is the notebook-level frame list)
    and the y axis is fixed to ``-800 .. 100`` so that successive frames of an
    animation share the same scale.

    Args:
        x: X coordinates.
        y: Y coordinates.
        xlabel: Label of the x axis (previously ignored in favour of a
            hard-coded "Time").

    Returns:
        A ``PIL.Image`` decoded from an in-memory PNG of the figure.
    """
    plt.clf()
    color = np.array([229, 107, 111])/255.0
    plt.plot(x,y, color=color)
    plt.xlim(0, len(images))
    plt.ylim(-800.0, 100.0)
    plt.xlabel(xlabel,fontsize=20)
    plt.ylabel("Log Probability",fontsize=20)
    plt.tight_layout()
    # Render into memory rather than to a file on disk.
    img_buf = io.BytesIO()
    plt.savefig(img_buf, format='png')
    im = Image.open(img_buf)
    return im


In [None]:
# Inspect the per-frame sums of the motion-prior scores computed above.
weights_over_time

In [None]:
# Build one four-panel frame per time step (observation, inferred objects,
# overlay, running score plot) and write them out as "<scene_name>.gif".
viz_panels = []
for t in tqdm(range(len(images))):
    rgb, point_cloud_image, rerendered, rerendered_spliced, pixelwise_probs, weights = data[t]

    # Score curve up to (not including) the current frame, scaled to the
    # height of the RGB panel.
    plots = make_plot(np.arange(t), weights_over_time[:t], "Time")
    factor = rgb.shape[0] / plots.height

    # Render each panel once and reuse it for the overlay (both used to be
    # rendered twice per frame). The depth render is up-scaled by 4 to undo
    # the 0.25 down-scaling of the intrinsics.
    rgb_panel = j.get_rgb_image(rgb)
    inferred_panel = j.scale_image(j.get_depth_image(rerendered[:,:,2], min=4.0,max=15.0),4)

    v = j.multi_panel([
        rgb_panel,
        inferred_panel,
        j.overlay_image(inferred_panel, rgb_panel),
        j.scale_image(plots, factor)
    ],
    ["Observed RGB", "Inferred Objects", "Overlay", "Probability"],
        label_fontsize=50,
    )
    viz_panels.append(v)
j.make_gif(viz_panels, f"{scene_name}.gif")

In [None]:
# Number of entries in the pose history (the initial empty entry plus one per
# frame processed by the tracking loop).
len(ALL_OBJECT_POSES)

In [None]:
# Start the meshcat visualizer used by the mesh inspection cell below.
j.meshcat.setup_visualizer()

In [None]:
# Show the mesh stored at index 1 of the renderer under the channel name "1".
# NOTE(review): the index is hard-coded; this fails when fewer than two meshes
# have been added to the renderer.
j.meshcat.show_trimesh("1",renderer.meshes[1])

In [None]:
# NOTE(review): this appends the current OBJECT_POSES once more to the history.
# The cell is not idempotent -- every execution makes ALL_OBJECT_POSES one
# entry longer than the number of frames -- and looks like a leftover from
# interactive use; confirm whether it is still needed.
ALL_OBJECT_POSES.append(OBJECT_POSES)


In [None]:
# Perturb every point of the current point cloud with isotropic Gaussian noise
# (per-pixel covariance r * I_3 with r = 0.005) using a fixed PRNG key, then
# re-render the noisy cloud as a depth image enlarged 10x for inspection.
r = 0.005 * jnp.ones_like(point_cloud_image[:, :, 2])
key = jax.random.PRNGKey(10)

per_pixel_covariance = r[:, :, None, None] * jnp.eye(3)[None, None, :, :]
noisy_point_cloud_image = jax.random.multivariate_normal(
    key,
    point_cloud_image[:, :, :3],
    per_pixel_covariance,
    shape=r.shape,
)

img = j.render_point_cloud(noisy_point_cloud_image.reshape(-1, 3), intrinsics)
j.scale_image(j.get_depth_image(img[:, :, 2]), 10)

In [None]:
# Plot the per-frame motion-prior score over the whole sequence.
plt.plot(weights_over_time)