In [1]:
import jax
import jax.numpy as jnp
import bayes3d as b
import bayes3d.icp
import cv2
import numpy as np

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7007/static/


In [28]:
def pixel_reconstruction(structure, posevec, intrinsics):
    pose = b.t3d.transform_from_posevec(posevec)
    reconstruction = b.project_cloud_to_pixels(b.t3d.apply_transform(structure, b.t3d.inverse_pose(pose)), intrinsics)
    return reconstruction

def loss(structure, posevecs, intrinsics, obs_pixels):
    return ((jax.vmap(pixel_reconstruction, in_axes=(None, 0, None))(structure, posevecs, intrinsics) - obs_pixels)**2).mean()

loss_jit = jax.jit(loss)

gradient_structure_func_jit = jax.jit(jax.grad(loss, argnums=0))
grad_poses_func_jit = jax.jit(jax.grad(loss, argnums=1))

In [19]:
intrinsics = b.Intrinsics(
    300,
    300,
    200.0,200.0,
    150.0,150.0,
    0.001, 50.0
)

key = jax.random.PRNGKey(10)

pose_random = b.distributions.gaussian_vmf_sample(key, b.t3d.transform_from_pos(jnp.array([0.0, 0.0, 1.0])), 0.1, 100.0)
# Create a random 3D cloud representing an object
structure = b.t3d.apply_transform(jax.random.uniform(key, shape=(30,3)) * 0.2, pose_random)

# Sample two camera poses to view the structure from
key = jax.random.split(key, 1)[0]
pose_1 = jnp.eye(4)
pose_2 = b.t3d.inverse_pose(b.distributions.gaussian_vmf_sample(key, b.t3d.transform_from_pos(jnp.array([0.0, 0.0, 0.0])), 0.1, 100.0))

b.clear()
b.show_cloud("structure", structure)
b.show_pose("1", pose_1)
b.show_pose("2", pose_2)

pixels_1 = pixel_reconstruction(structure, b.t3d.transform_to_posevec(pose_1),intrinsics)
pixels_2 = pixel_reconstruction(structure, b.t3d.transform_to_posevec(pose_2),intrinsics)

In [22]:
K = jnp.array([
    [intrinsics.fx, 0.0, intrinsics.cx],
    [0.0, intrinsics.fy,  intrinsics.cy],
    [0.0, 0.0, 1.0],
])

E, _ = cv2.findEssentialMat(np.array(pixels_1), np.array(pixels_2), np.array(K))
points, R, t, mask = cv2.recoverPose(
    E, np.array(pixels_1), np.array(pixels_2), np.array(K), 10000.0
)
baseline_transform = b.t3d.inverse_pose(b.t3d.transform_from_rot_and_pos(R,t))

P_0 = K @ jnp.hstack([jnp.eye(3), jnp.zeros((3,1))])
P_1 = K @ jnp.hstack([R, t.reshape(-1,1)])

in_1 = np.array(pixels_1).transpose()
in_2 = np.array(pixels_2).transpose()
points_homogenous = cv2.triangulatePoints(
    np.array(P_0), np.array(P_1), np.array(pixels_1).transpose(), np.array(pixels_2).transpose()
).transpose()
points_3d = points_homogenous[:,:3] / points_homogenous[:,3].reshape(-1,1)

print(jnp.abs(pixel_reconstruction(points_3d, b.t3d.transform_to_posevec(jnp.eye(4)),intrinsics) - pixels_1).sum())
print(jnp.abs(pixel_reconstruction(points_3d, b.t3d.transform_to_posevec(baseline_transform),intrinsics) - pixels_2).sum())

b.clear()
b.show_cloud("structure", points_3d)
b.show_pose("1", jnp.eye(4))
b.show_pose("2", baseline_transform)

0.0006713867
0.0011138916


In [37]:
# This must be confusing why i multiply by baseline_transform. But its because in pose_estimates,
# i can't have the identity pose. It is problematic for gradients for some reason.
structure_estimate = b.t3d.apply_transform(jnp.array(points_3d), baseline_transform)
pose_estimates = jnp.array(
    [
        b.t3d.transform_to_posevec(jnp.eye(4) @ baseline_transform),
        b.t3d.transform_to_posevec(baseline_transform @ baseline_transform)
    ]
)
obs_pixels = jnp.stack([pixels_1, pixels_2])

print(loss_jit(structure_estimate, pose_estimates, intrinsics, obs_pixels))
print( gradient_structure_func_jit(structure_estimate, pose_estimates, intrinsics, obs_pixels))
print(grad_poses_func_jit(structure_estimate, pose_estimates, intrinsics, obs_pixels))

7.600951e-10
[[-2.84359849e-05  4.60372439e-06  2.05823217e-05]
 [-3.02022854e-05  9.60086163e-06  1.84139190e-05]
 [-2.66648003e-05  3.16381943e-06  1.70179337e-05]
 [-2.35313910e-05  1.15076309e-05  1.97183854e-05]
 [-2.06757031e-05  1.13164906e-05  1.52605990e-05]
 [-2.65833387e-05  2.89150626e-06  1.76866433e-05]
 [-2.63409183e-05  8.12571500e-07  1.31392471e-05]
 [-3.24703942e-05  4.04622142e-06  1.59098727e-05]
 [-2.94261063e-05  4.85079090e-06  2.23064781e-05]
 [-1.10832116e-05  4.52422410e-06  8.70693293e-06]
 [-2.23968964e-05  1.16567471e-05  2.06933182e-05]
 [-1.86983380e-05  1.01043734e-05  1.54906411e-05]
 [-3.69652917e-05  5.68599489e-06  2.02190131e-05]
 [-3.40309234e-05  6.81734537e-06  1.72832497e-05]
 [-3.05951326e-05  1.46924804e-05  2.73677124e-05]
 [-2.03685340e-05  1.12851258e-05  1.44735150e-05]
 [-2.65355975e-05  6.39584459e-06  1.63240675e-05]
 [-2.36113647e-05  1.24313756e-05  2.01727344e-05]
 [-2.53183243e-05  3.32914374e-06  1.48325389e-05]
 [-3.41088271e-05 

In [39]:
learning_rate= 0.0001
for i in range(5):
    loss_value = loss_jit(structure_estimate, pose_estimates, intrinsics, obs_pixels)
    grad_structure = gradient_structure_func_jit(structure_estimate, pose_estimates, intrinsics, obs_pixels)
    structure_estimate -= learning_rate * grad_structure
    grad_poses = grad_poses_func_jit(structure_estimate, pose_estimates, intrinsics, obs_pixels)
    pose_estimates -= learning_rate * grad_poses
    loss_value = loss_jit(structure_estimate, pose_estimates, intrinsics, obs_pixels)
    print(loss_value)
    if loss_value < 0.0001:
        print("Loss is low enough!")
        break

1.7350734e-09
Loss is low enough!


In [None]:
jnp.array([b.t3d.transform_to_posevec(p) for p in [pose_1, pose_2]]).shape

(2, 6)

In [9]:
valid, rvec, tvec= cv2.solvePnP(np.array(b.t3d.apply_transform(structure, pose_1)), np.array(pixels_2), K, None)
transform = b.t3d.transform_from_rot_and_pos(b.t3d.rotation_from_rodrigues(rvec), tvec)
transform

Array([[ 0.8831218 , -0.3343354 , -0.32911363,  0.07618109],
       [ 0.3666367 ,  0.92952377,  0.03953705, -0.04631384],
       [ 0.29270032, -0.15558115,  0.94346225,  0.8632487 ],
       [ 0.        ,  0.        ,  0.        ,  1.        ]],      dtype=float32)

In [10]:
K = jnp.array([
    [intrinsics.fx, 0.0, intrinsics.cx],   
    [0.0, intrinsics.fy, intrinsics.cy],   
    [0.0, 0.0, 1.0],
])

0.2368308
0.23676616
0.23670565
0.23664394
0.23658368
0.23652315
0.23646268
0.23639977
0.23633814
0.236277
0.23621714
0.23615582
0.23609531
0.23603669
0.23597154
0.23591217
0.23585293
0.23579033
0.23572978
0.23566617
0.23560831
0.2355469
0.2354851
0.23542352
0.23536436
0.23530367
0.23524423
0.23518217
0.23512211
0.23506273
0.23500253
0.23494026
0.23488063
0.23481756
0.23475869
0.23469895
0.23463744
0.23457834
0.23451659
0.23445871
0.23439682
0.23433977
0.23427752
0.23421623
0.23415685
0.23409802
0.23403868
0.23397823
0.23391709
0.23385736
0.23379818
0.2337386
0.23368081
0.23361662
0.23355705
0.23350048
0.2334383
0.23337997
0.23332039
0.23325983
0.2332005
0.23314108
0.23308301
0.23302251
0.23296088
0.23290253
0.23284242
0.23278454
0.2327254
0.23266676
0.23260581
0.23254672
0.23248675
0.23242691
0.2323702
0.23231067
0.23225161
0.23219183
0.23213148
0.23207521
0.2320143
0.23195523
0.23189719
0.23183706
0.23177946
0.23172131
0.23166072
0.23160145
0.23154256
0.2314837
0.23142365
0.23136587


In [42]:
jax.grad(loss, argnums=1)(structure, poses, obs_pixels)

Array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]], dtype=float32)

In [None]:
b.t3init_poses[0]

In [51]:
transform_to_posvec(pose_1)

ValueError: Zero-dimensional arrays cannot be concatenated.

In [40]:
pose_2

Array([[ 0.8844032 , -0.33645874, -0.32346025,  0.07613744],
       [ 0.3701166 ,  0.9278015 ,  0.04688475, -0.04827813],
       [ 0.28433213, -0.16118303,  0.94507957,  0.8631969 ],
       [ 0.        ,  0.        ,  0.        ,  1.        ]],      dtype=float32)

In [46]:
rodriguez_to_rotation(jnp.array(rvec))

Array([[ 0.8747357 ,  0.47217974,  0.10901201],
       [-0.4841344 ,  0.8613571 ,  0.15387577],
       [-0.02124125, -0.1873771 ,  0.98205835]], dtype=float32)

In [44]:
cv2.Rodrigues(rvec)

(array([[ 0.87473574,  0.47217979,  0.10901202],
        [-0.48413447,  0.86135711,  0.15387576],
        [-0.02124125, -0.18737711,  0.98205836]]),
 array([[-0.00375773,  0.06230547, -0.23972006,  0.00407614,  0.17069488,
         -0.94268156, -0.2476511 ,  0.94167657,  0.17431572],
        [-0.06530594, -0.09847135,  0.9505523 , -0.07624571,  0.00158748,
         -0.24877579, -0.95155729, -0.24084476, -0.0665348 ],
        [ 0.47947362, -0.87148103, -0.07262487,  0.87047605,  0.47834893,
          0.06107662, -0.0948505 ,  0.00284728, -0.00150829]]))