In [None]:
import jax.numpy as jnp
import bayes3d as b
import trimesh
import os
import numpy as np
import trimesh
import jax


# --- Locate the YCB-Video model directory inside the bayes3d assets folder.
# The package is bound to `b` by the import cell; the former `j` alias is never
# defined in this notebook, so `j.` raised NameError on a fresh kernel.
# NOTE(review): assumes bayes3d exposes the same attribute paths the old alias
# did (`utils`, `ycb_loader`, `Intrinsics`) — confirm against the installed version.
model_dir = os.path.join(b.utils.get_assets_dir(), "ycb_video_models/models")
print(f"{model_dir} exists: {os.path.exists(model_dir)}")
model_names = b.ycb_loader.MODEL_NAMES
IDX = 13  # position in model_names of the single object used by this notebook
name = model_names[IDX]
print(name)


# Load one BOP YCB-V test frame; below, only its camera intrinsics are read.
bop_ycb_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv")
rgbd, gt_ids, gt_poses, masks = b.ycb_loader.get_test_img('52', '1', bop_ycb_dir)

# Intrinsics used for rendering the synthetic dataset.
# NOTE(review): fy deliberately reuses fx and the principal point is forced to
# the image centre, i.e. an idealised symmetric camera rather than the sensor's
# own fy/cx/cy — presumably required by the renderer; confirm before changing.
intrinsics = b.Intrinsics(
    height=rgbd.intrinsics.height,
    width=rgbd.intrinsics.width,
    fx=rgbd.intrinsics.fx, fy=rgbd.intrinsics.fx,
    cx=rgbd.intrinsics.width/2.0, cy=rgbd.intrinsics.height/2.0,
    near=0.001, far=2.0
)

In [None]:
# Instantiate the DenseFusion wrapper. The submodule is imported explicitly here,
# and the constructor is reached through that same `bayes3d` name; the previous
# `jax3dp3.` prefix referred to a module this notebook never imports.
import bayes3d.posecnn_densefusion
densefusion = bayes3d.posecnn_densefusion.DenseFusion()

In [None]:
# Path to the textured mesh of the selected YCB object, then load it.
# Uses the `b` alias bound by the import cell (`j` is undefined in this notebook).
mesh_path = os.path.join(model_dir, name, "textured.obj")
print(mesh_path)
mesh = b.mesh.load_mesh(mesh_path)

In [None]:
# Sample NUM_IMAGES_PER_ITER object poses from a seeded PRNG, then pin every
# pose's translation column to FIXED_TRANSLATION so only the rotation varies.
NUM_IMAGES_PER_ITER = 5
FIXED_TRANSLATION = jnp.array([0.0, 0.08324493, 1.0084537])
_seed = 1222
key = jax.random.PRNGKey(_seed)

# One independent subkey per image. The lambda argument is named `subkey` so it
# no longer shadows the outer `key`. `b` is the bayes3d alias from the import
# cell (`j` is undefined in this notebook).
# NOTE(review): the meaning of the two numeric arguments to gaussian_vmf
# (0.00001, 0.001) is not visible here — presumably variance/concentration; confirm.
object_poses = jax.vmap(lambda subkey: b.distributions.gaussian_vmf(subkey, 0.00001, 0.001))(
    jax.random.split(key, NUM_IMAGES_PER_ITER)
)
# Overwrite rows 0..2 of column 3 of every sampled matrix with the fixed translation.
object_poses = object_poses.at[:, :3, 3].set(FIXED_TRANSLATION)

In [None]:
## fetch dataset
# The dataset file name is keyed by seed and image count, so a cached file always
# corresponds to the poses sampled above.
DATASET_FILENAME = f"dataset_{_seed}_{NUM_IMAGES_PER_ITER}.npz"  # npz file
DATASET_FILE = os.path.join(b.utils.get_assets_dir(), f"datasets/{DATASET_FILENAME}")

# Set to True to reuse a previously rendered dataset instead of re-rendering.
load_from_existing = False

if not load_from_existing:
    # Render all poses of the single object (leading axis of size 1 added to the
    # pose array), then persist everything needed to reproduce the run.
    rgbds = b.kubric_interface.render_multiobject_parallel(
        [mesh_path], object_poses[None, ...],
        intrinsics, scaling_factor=1.0, lighting=1.0,
    )  # multi img singleobj
    # np.savez does not create missing parent directories.
    os.makedirs(os.path.dirname(DATASET_FILE), exist_ok=True)
    np.savez(DATASET_FILE, rgbds=rgbds, poses=object_poses, id=IDX, name=model_names[IDX], intrinsics=intrinsics, mesh_path=mesh_path)
else:
    # SECURITY: allow_pickle=True executes arbitrary pickled code on load; only
    # open dataset files that were produced by this notebook.
    data = np.load(DATASET_FILE, allow_pickle=True)
    rgbds = data["rgbds"]
    object_poses = data["poses"]
    # Stored object index; kept under a name that does not shadow the builtin `id`.
    dataset_idx = data["id"].item()

# Tile the RGB frames in a single row, one column per image, then show and save.
rgb_images = b.hvstack_images([b.get_rgb_image(r.rgb) for r in rgbds], 1, NUM_IMAGES_PER_ITER)
rgb_images.show()
rgb_images.save(f"dataset_{NUM_IMAGES_PER_ITER}_seed_{_seed}.png")


In [None]:
# run densefusion on dataset
# Loop variables are given scene-specific names so this cell no longer overwrites
# the reference test frame `rgbd` loaded in the first cell, nor leaves a stale
# `results` behind for later cells to pick up by accident.
all_results = []
for scene_idx, scene_rgbd in enumerate(rgbds):
    # scene_name is the stringified scene index.
    # NOTE(review): presumably this also names the visualisation image that the
    # later cell reads back as "<scene_idx>.png" — confirm in DenseFusion wrapper.
    scene_results = densefusion.get_densefusion_results(
        scene_rgbd.rgb, scene_rgbd.depth, scene_rgbd.intrinsics, scene_name=str(scene_idx)
    )
    # get_densefusion_results returns an iterable; flatten it into one list.
    all_results.extend(scene_results)

In [None]:
# process densefusion results
import pickle

# Persist the raw DenseFusion outputs alongside the dataset image.
with open(f"dataset_{NUM_IMAGES_PER_ITER}_seed_{_seed}.pkl", 'wb') as f:
    pickle.dump(all_results, f)

# Mean signed translation error of the tracked object across all results.
# Every rendered pose had its translation set to FIXED_TRANSLATION, so that is
# the ground truth subtracted from each prediction. (Previously the raw
# predictions were summed, so the "error" was really the mean prediction.)
# NOTE(review): assumes the predicted 'tr' is expressed in the same frame and
# units as FIXED_TRANSLATION — confirm against the DenseFusion wrapper.
translation_err = jnp.zeros((1, 3))
for scene_result in all_results:
    pred_transl = scene_result[name]['tr']
    translation_err += pred_transl - FIXED_TRANSLATION

# Signed per-axis mean: positive and negative errors can cancel out.
avg_translation_err = translation_err / len(all_results)
avg_translation_err

In [None]:
## Visualize densefusion outputs

from PIL import Image

# Directory (relative to the current working directory) from which the
# per-scene result images are read.
VIZ_DIR = os.path.join(os.getcwd(), "Densefusion_iterative_result/")

# One image per scene, named "<scene_idx>.png".
densefusion_vizs = [
    Image.open(os.path.join(VIZ_DIR, f"{scene_idx}.png"))
    for scene_idx in range(len(rgbds))
]

# Single-row grid with one column per image; the width follows
# NUM_IMAGES_PER_ITER instead of a hard-coded 5. `b` is the bayes3d alias from
# the import cell (`j` is undefined in this notebook).
densefusion_result_viz = b.hvstack_images(densefusion_vizs, 1, NUM_IMAGES_PER_ITER)
densefusion_result_viz.show()

In [None]:
# Re-display the rendered input images so they can be compared with the
# DenseFusion visualisation shown by the previous cell.
rgb_images.show()  # original dataset for comparison