In [1]:
import jax

import jax.numpy as jnp
import numpy as np
import transforms3d as t3d

from tqdm import tqdm
from utils import read_canonical_model, load_pc, visualize_icp_result

In [18]:
def Kabsch(z, m):
    """Best-fit rigid transform aligning associated point pairs (Kabsch algorithm).

    Finds the proper rotation ``R`` (det(R) = +1) and translation ``p`` that
    minimise ``sum_i ||m_i - (R @ z_i + p)||^2``, where row ``i`` of ``z`` is
    already paired with row ``i`` of ``m``.

    Args:
        z: (N, 3) source point cloud.
        m: (N, 3) target point cloud; ``m[i]`` corresponds to ``z[i]``.

    Returns:
        R: (3, 3) rotation matrix.
        p: (3,) translation such that ``m ≈ (R @ z.T).T + p``.

    Raises:
        ValueError: if ``z`` and ``m`` are not both (N, 3) arrays with the same N > 0.
    """
    z = np.asarray(z, dtype=float)
    m = np.asarray(m, dtype=float)
    if z.ndim != 2 or z.shape[1] != 3 or z.shape != m.shape or len(z) == 0:
        raise ValueError(
            f"Kabsch expects two non-empty (N, 3) arrays of equal shape, got {z.shape} and {m.shape}"
        )

    z_bar = z.mean(axis=0)
    m_bar = m.mean(axis=0)
    delta_z = z - z_bar
    delta_m = m - m_bar

    # Cross-covariance Q = sum_i delta_m[i] delta_z[i]^T, computed as one matrix product
    # instead of a Python loop over every point.
    Q = delta_m.T @ delta_z

    # NOTE: numpy returns V^T (not V) as the third SVD output.
    U, _, Vt = np.linalg.svd(Q)
    # Flip the last singular direction if needed so R is a rotation, not a reflection.
    d = np.linalg.det(U @ Vt)
    R = U @ np.diag([1.0, 1.0, d]) @ Vt
    p = m_bar - R @ z_bar

    return R, p

In [None]:
def _closest_indices(queries, points, chunk_size=256):
    """Index of the nearest row of ``points`` (Euclidean) for every row of ``queries``.

    Works on ``queries`` in chunks so the (chunk_size, len(points)) distance matrix
    stays bounded in memory, while the heavy lifting is a BLAS matrix product.
    """
    points_sq = (points ** 2).sum(axis=1)
    nearest = np.empty(len(queries), dtype=np.intp)
    for start in range(0, len(queries), chunk_size):
        chunk = queries[start:start + chunk_size]
        # ||q - x||^2 = ||q||^2 - 2 q.x + ||x||^2; ||q||^2 is constant along each row,
        # so it does not change the argmin and is omitted.
        d2 = points_sq[None, :] - 2.0 * (chunk @ points.T)
        nearest[start:start + chunk_size] = np.argmin(d2, axis=1)
    return nearest


def ICP(z, m, R_0=None, p_0=None, max_iter=20, tol=None, verbose=True):
    """Iterative Closest Point: estimate the rigid transform taking ``z`` onto ``m``.

    Each iteration associates every target point with its nearest source point
    under the current pose estimate, then re-solves the full z -> m transform with
    ``Kabsch`` on those pairs.

    Args:
        z: (N, 3) source point cloud.
        m: (M, 3) target point cloud.
        R_0: optional (3, 3) initial rotation guess; identity if None.
        p_0: optional (3,) initial translation guess; zeros if None.
        max_iter: maximum number of ICP iterations (default 20).
        tol: if given, stop early once the mean correspondence error changes by
            less than ``tol`` between consecutive iterations.
        verbose: print the mean correspondence error after every iteration.

    Returns:
        R: (3, 3) rotation and p: (3,) translation with ``m ≈ (R @ z.T).T + p``.
        With ``max_iter == 0`` the initial guess is returned unchanged.

    Raises:
        ValueError: if either point cloud is empty.
    """
    z = np.asarray(z, dtype=float)
    m = np.asarray(m, dtype=float)
    if len(z) == 0 or len(m) == 0:
        raise ValueError("ICP needs non-empty source and target point clouds")

    R = np.eye(3) if R_0 is None else np.asarray(R_0, dtype=float)
    p = np.zeros(3) if p_0 is None else np.asarray(p_0, dtype=float)

    prev_err = None
    for _ in tqdm(range(max_iter)):
        # Source cloud expressed in the target frame under the current estimate.
        m_hat = z @ R.T + p
        # Pair each target point with the *untransformed* source point whose
        # transformed copy is nearest, so Kabsch recovers the absolute transform.
        z_assoc = z[_closest_indices(m, m_hat)]
        R, p = Kabsch(z_assoc, m)

        err = np.linalg.norm(m - (z_assoc @ R.T + p), axis=1).mean()
        if verbose:
            print(f"Avg. error: {err}")
        if tol is not None and prev_err is not None and abs(prev_err - err) < tol:
            break
        prev_err = err

    return R, p

In [30]:
obj_name = 'liq_container'  # 'drill' or 'liq_container'
num_pc = 4  # number of target point clouds to load for this object
target_idx = 2  # which loaded target cloud to register the canonical model against

source_pc = read_canonical_model(obj_name)
targets = [load_pc(obj_name, i) for i in range(num_pc)]

# Use one explicitly chosen cloud for ICP, the shape printout and the visualisation,
# so the estimated pose is always displayed against the cloud it was estimated on.
target_pc = targets[target_idx]

print(source_pc.shape, target_pc.shape)
target_R_source, target_p_source = ICP(source_pc, target_pc)

# Homogeneous 4x4 pose mapping canonical-model (source) points into the target frame;
# the bottom row [0, 0, 0, 1] already comes from np.eye.
pose = np.eye(4)
pose[:3, :3] = target_R_source
pose[:3, 3] = target_p_source

# Visualize the estimated result
visualize_icp_result(source_pc, target_pc, pose)
print(target_R_source, target_p_source)

(46563, 3) (19544, 3)


  5%|▌         | 1/20 [00:03<01:06,  3.51s/it]

Avg. error: 0.05262744785269032


 10%|█         | 2/20 [00:06<00:59,  3.33s/it]

Avg. error: 0.020087199931730564


 15%|█▌        | 3/20 [00:10<00:59,  3.50s/it]

Avg. error: 0.017016891463245854


 20%|██        | 4/20 [00:14<01:01,  3.85s/it]

Avg. error: 0.016121753656216802


 25%|██▌       | 5/20 [00:18<00:58,  3.87s/it]

Avg. error: 0.015739882603833407


 30%|███       | 6/20 [00:22<00:55,  3.96s/it]

Avg. error: 0.015494175283264737


 35%|███▌      | 7/20 [00:26<00:49,  3.78s/it]

Avg. error: 0.01529330206658057


 40%|████      | 8/20 [00:29<00:42,  3.53s/it]

Avg. error: 0.015111259225429066


 45%|████▌     | 9/20 [00:33<00:41,  3.75s/it]

Avg. error: 0.014944869494093191


 50%|█████     | 10/20 [00:37<00:38,  3.90s/it]

Avg. error: 0.014796672802768454


 55%|█████▌    | 11/20 [00:41<00:34,  3.83s/it]

Avg. error: 0.01466188869959768


 60%|██████    | 12/20 [00:45<00:32,  4.01s/it]

Avg. error: 0.014541144037242966


 65%|██████▌   | 13/20 [00:49<00:26,  3.85s/it]

Avg. error: 0.01443234108994424


 70%|███████   | 14/20 [00:53<00:23,  3.88s/it]

Avg. error: 0.014329892766660692


 75%|███████▌  | 15/20 [00:56<00:18,  3.77s/it]

Avg. error: 0.014238738007229445


 80%|████████  | 16/20 [01:00<00:15,  3.82s/it]

Avg. error: 0.014154409753779951


 85%|████████▌ | 17/20 [01:04<00:11,  3.96s/it]

Avg. error: 0.014077003268170369


 90%|█████████ | 18/20 [01:08<00:07,  3.69s/it]

Avg. error: 0.014004409404169676


 95%|█████████▌| 19/20 [01:11<00:03,  3.73s/it]

Avg. error: 0.013939580590321401


100%|██████████| 20/20 [01:16<00:00,  3.81s/it]

Avg. error: 0.013878787136280572
[[ 0.89379808 -0.34736047 -0.28366477]
 [ 0.34716301  0.93632311 -0.05269602]
 [ 0.2839064  -0.05137832  0.9574745 ]] [ 0.54841938 -0.21126736 -0.00369447]



