In [1]:
from PIL import Image
import jax3dp3 as j
import os
import numpy as np
import jax.numpy as jnp
from tqdm import tqdm
import jax

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [10]:
j.meshcat.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7002/static/


In [2]:
f = open(os.path.join(j.utils.get_assets_dir(), f"tum/livingRoom1.gt.freiburg"),"r")
data = f.readlines()
data = [d.strip('\n') for d in data]
poses = [jnp.eye(4)]

xyzw_to_rotation_matrix = jax.jit(j.t3d.xyzw_to_rotation_matrix)
transform_from_rot_and_pos = jax.jit(j.t3d.transform_from_rot_and_pos)
for i in tqdm(range(len(data))):
    xyzq = list(map(float,data[i].split(" ")))[1:]
    pos = jnp.array([xyzq[:3]])
    rot = xyzw_to_rotation_matrix(jnp.array(xyzq[3:]))
    pose = transform_from_rot_and_pos(rot, pos)
    poses.append(pose)
poses = jnp.array(poses)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 965/965 [00:00<00:00, 1237.37it/s]


In [3]:
IDX = 0
rgbs = []
original_depths = []
for IDX in tqdm(range(len(poses))):
    depth_filename = os.path.join(j.utils.get_assets_dir(), f"tum/depth/{IDX}.png")
    rgb_filename = os.path.join(j.utils.get_assets_dir(), f"tum/rgb/{IDX}.png")
    rgbs.append(jnp.array(Image.open(rgb_filename)))
    original_depths.append(np.array(Image.open(depth_filename)) / 5000)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 966/966 [00:14<00:00, 65.93it/s]


In [5]:
unproject_depth = jax.jit(j.t3d.unproject_depth)
apply_transform = jax.jit(j.t3d.apply_transform)

In [6]:
original_intrinsics = j.Intrinsics(
    original_depths[0].shape[0], original_depths[1].shape[1],
    481.20, -480.00,319.50,239.50,0.001, 1000.0
)
intrinsics = j.camera.scale_camera_parameters(original_intrinsics, 0.2)
depths = [j.utils.resize(d, intrinsics.height, intrinsics.width) for d in original_depths]

In [523]:
T1 = 250
point_cloud_1 = apply_transform(unproject_depth(depths[T1], intrinsics), poses[T1])

T2 = 255
point_cloud_2 = apply_transform(unproject_depth(depths[T2], intrinsics), poses[T2])

In [530]:
point_cloud_1 = unproject_depth(depths[T1], intrinsics)
point_cloud_2 = unproject_depth(depths[T2], intrinsics)

correction_transform = j.t3d.inverse_pose(poses[T1]) @ poses[T2]
point_cloud_2_corrected = apply_transform(unproject_depth(depths[T2], intrinsics),  correction_transform)

In [531]:
j.meshcat.clear()
j.meshcat.show_cloud("1", point_cloud_1.reshape(-1,3))
j.meshcat.show_cloud("2", point_cloud_2.reshape(-1,3), color=j.RED)

R = 0.001
OUTLIER_PROB = 0.1
OUTLIER_VOLUME = 1.0
j.threedp3_likelihood_jit(point_cloud_1, point_cloud_2, R, OUTLIER_PROB, OUTLIER_VOLUME)

Array(-20916.332, dtype=float32)

In [532]:
j.meshcat.clear()
j.meshcat.show_cloud("1", point_cloud_1.reshape(-1,3))
j.meshcat.show_cloud("2", point_cloud_2_corrected.reshape(-1,3), color=j.RED)

j.meshcat.show_pose("pose", correction_transform)
j.threedp3_likelihood_jit(point_cloud_1, j.t3d.apply_transform(point_cloud_2, 
    correction_transform), R, OUTLIER_PROB, OUTLIER_VOLUME)

Array(-24087.125, dtype=float32)

In [511]:
NUM_SAMPLES_FOR_ESTIMATE = 5000
keys = jax.random.split(jax.random.PRNGKey(2), NUM_SAMPLES_FOR_ESTIMATE)

In [512]:
pose_estimate = jnp.eye(4)

In [513]:
def refine_pose_estimate_inner(pose_estimate, point_cloud_1, point_cloud_2, keys, var, conc):
    keys = jax.random.split(keys[0], NUM_SAMPLES_FOR_ESTIMATE)
    pose_proposals = jax.vmap(lambda key: j.distributions.gaussian_vmf_sample(
        key, pose_estimate, var, conc))(
        keys
    )

    rendered_images = jnp.einsum(
        'aij,...j->a...i',
        pose_proposals,
        jnp.concatenate([point_cloud_2, jnp.ones(point_cloud_2.shape[:-1] + (1,))], axis=-1),
    )[..., :-1]

    best_score = j.threedp3_likelihood_jit(
        point_cloud_1, j.t3d.apply_transform(point_cloud_2, 
        pose_estimate), R, OUTLIER_PROB, OUTLIER_VOLUME)
    
    weights = j.threedp3_likelihood_parallel_jit(point_cloud_1, rendered_images, R, OUTLIER_PROB, OUTLIER_VOLUME)
    weights_max = weights.max()
    print(weights_max)
    better = (weights_max > best_score)
    pose_estimate = pose_proposals[weights.argmax()] * better + pose_estimate* (1.0 - better)
    return pose_estimate, keys

refine_pose_estimate_jit = jax.jit(refine_pose_estimate_inner)

In [516]:
pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.01, 1000.0)
pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.005, 2000.0)
pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.001, 1000.0)
best_score = j.threedp3_likelihood_jit(point_cloud_1, j.t3d.apply_transform(point_cloud_2, 
    pose_estimate), R, OUTLIER_PROB, OUTLIER_VOLUME)
best_score

Array(-39583.477, dtype=float32)

In [521]:
j.meshcat.clear()
j.meshcat.show_cloud("1", point_cloud_1.reshape(-1,3))
j.meshcat.show_cloud("2", point_cloud_2.reshape(-1,3), color=j.RED)

In [522]:
j.meshcat.clear()
j.meshcat.show_cloud("1", point_cloud_1.reshape(-1,3))
j.meshcat.show_cloud("2", j.t3d.apply_transform(point_cloud_2, pose_estimate).reshape(-1,3), color=j.RED)

In [432]:
point_clouds = [
    unproject_depth(depths[t], intrinsics)
    for t in jnp.arange(200, 300, 5)
]
transforms = []
for i in tqdm(range(len(point_clouds) - 1)):
    pose_estimate = jnp.eye(4)
    point_cloud_1, point_cloud_2 = point_clouds[i], point_clouds[i+1]
    pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.01, 1000.0)
    pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.005, 2000.0)
    pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.001, 1000.0)
    pose_estimate,keys = refine_pose_estimate_jit(pose_estimate, point_cloud_1, point_cloud_2, keys, 0.001, 1000.0)
    transforms.append(pose_estimate)


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:22<00:00,  1.19s/it]


In [446]:
i = 15

In [450]:
j.meshcat.clear()
j.meshcat.show_cloud("1", point_clouds[i].reshape(-1,3))
j.meshcat.show_cloud("2", point_clouds[i+1].reshape(-1,3), color=j.RED)


In [451]:
j.meshcat.clear()
j.meshcat.show_cloud("1", point_clouds[i].reshape(-1,3))
j.meshcat.show_cloud("2", j.t3d.apply_transform(point_clouds[i+1], transforms[i]).reshape(-1,3), color=j.RED)


In [445]:
len(point_clouds)

20

In [30]:
j.meshcat.clear()
j.meshcat.show_cloud("1", point_cloud_1.reshape(-1,3))


In [327]:
j.meshcat.clear()
j.meshcat.show_cloud("2", j.t3d.apply_transform(point_cloud_2, pose_estimate).reshape(-1,3), color=j.RED)

In [185]:
mesh = j.mesh.make_voxel_mesh_from_point_cloud(unproject_depth(depths[T1], intrinsics).reshape(-1,3), 0.05)

In [190]:
renderer = j.Renderer(intrinsics)
renderer.add_mesh(mesh)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (128, 96, 1024)


In [206]:
recontruction = renderer.render_single_object(j.t3d.inverse_pose(poses[T2]) @ poses[T1], 0)

In [207]:
j.meshcat.clear()
j.meshcat.show_trimesh("!", mesh)

In [143]:
j.meshcat.clear()
j.meshcat.show_cloud("1", point_cloud_1.reshape(-1,3))
j.meshcat.show_cloud("2", point_cloud_2.reshape(-1,3), color=j.RED)

In [None]:
T_WIDTH = 0.01
translation_grid = j.make_translation_grid_enumeration(
    -T_WIDTH,-T_WIDTH,-T_WIDTH,
    T_WIDTH,T_WIDTH,T_WIDTH,
    11,11,11
)
rotation_grid = j.make_rotation_grid_enumeration(
    50, 40, -jnp.pi/40, jnp.pi/40, jnp.pi/40
)


pose_estimate = jnp.eye(4)
best_score = -1000000.0

In [None]:
pose_proposals = jnp.einsum(
    'aij,jk->aik',
    translation_grid,
    pose_estimate,
    
)
rendered_images = jnp.einsum(
    'aij,...j->a...i',
    pose_proposals,
    jnp.concatenate([point_cloud_2, jnp.ones(point_cloud_2.shape[:-1] + (1,))], axis=-1),
)[..., :-1]

weights = j.threedp3_likelihood_parallel_jit(point_cloud_1, rendered_images, R, OUTLIER_PROB, OUTLIER_VOLUME)
weights_max = weights.max()
better = (weights_max > best_score)
pose_estimate = pose_proposals[weights.argmax()] * better + pose_estimate* (1.0 - better)
best_score = weights_max * better + best_score * (1.0 - better)
print(best_score)

pose_proposals = jnp.einsum(
    'ij,ajk->aik',
    pose_estimate,
    rotation_grid
)
rendered_images = jnp.einsum(
    'aij,...j->a...i',
    pose_proposals,
    jnp.concatenate([point_cloud_2, jnp.ones(point_cloud_2.shape[:-1] + (1,))], axis=-1),
)[..., :-1]


weights = j.threedp3_likelihood_parallel_jit(point_cloud_1, rendered_images, R, OUTLIER_PROB, OUTLIER_VOLUME)
weights_max = weights.max()
better = (weights_max > best_score)
pose_estimate = pose_proposals[weights.argmax()] * better + pose_estimate *(1.0 - better)
best_score = weights_max * better + best_score * (1.0 - better)
print(best_score)


In [93]:
j.meshcat.clear()
for i in range(1000):
    j.meshcat.show_pose(f"{i}", poses[i], size=0.01)

In [97]:
j.meshcat.clear()
j.meshcat.show_cloud("1", point_cloud_1.reshape(-1,3))


In [125]:
j.meshcat.clear()
j.meshcat.show_cloud("1", point_cloud_1.reshape(-1,3))
j.meshcat.show_cloud("2", point_cloud_2.reshape(-1,3), color=j.RED)

In [79]:
poses[1]

Array([[ 1.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  1.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  1.  , -2.25],
       [ 0.  ,  0.  ,  0.  ,  1.  ]], dtype=float32)

In [90]:
poses[100]

Array([[ 0.9232352 ,  0.25634083, -0.2862283 ,  0.0436907 ],
       [-0.25016412,  0.9664283 ,  0.05860607,  0.0111501 ],
       [ 0.2916422 ,  0.01749687,  0.95636773, -2.1024    ],
       [ 0.        ,  0.        ,  0.        ,  1.        ]],      dtype=float32)