 procrustes_3d进行对齐

In [6]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def procrustes_3d(X, Y):
    """
    Align source points to target points with a similarity transform.

    Both point sets are centered, the proper rotation (det = +1) that best
    maps the centered source onto the centered target in the least-squares
    sense is found via SVD, the optimal uniform scale is applied, and the
    result is moved onto the target centroid. The inputs are not modified.

    Args:
    - X (np.array): Source points, shape (N, 3)
    - Y (np.array): Target points, shape (N, 3)

    Returns:
    - Z (np.array): Aligned source points, shape (N, 3)

    Raises:
    - ValueError: If X and Y are not non-empty 2-D arrays of the same shape.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    if X.ndim != 2 or X.shape != Y.shape or X.shape[0] == 0:
        raise ValueError(
            "X and Y must be non-empty 2-D arrays of identical shape, got "
            f"{X.shape} and {Y.shape}"
        )

    # Center both sets on their centroids (creates new arrays).
    Y_centroid = Y.mean(axis=0)
    X = X - X.mean(axis=0)
    Y = Y - Y_centroid

    # Cross-covariance for the row-vector convention Z = X @ R.
    # It has to be X.T @ Y: decomposing Y.T @ X instead yields the
    # transposed (inverse) rotation.
    covariance_matrix = np.dot(X.T, Y)

    # Singular Value Decomposition
    U, _, Vt = np.linalg.svd(covariance_matrix)

    # Optimal orthogonal matrix; flip the axis of the smallest singular value
    # if it is a reflection so that a proper rotation is returned.
    R = np.dot(U, Vt)
    if np.linalg.det(R) < 0:
        U[:, -1] *= -1
        R = np.dot(U, Vt)

    # Apply the rotation to X
    Z = np.dot(X, R)

    # Least-squares scale for the rotated points: <Z, Y> / <Z, Z>.
    denominator = np.sum(Z * Z)
    # A source with zero spread has no defined scale; keep it at 1 so the
    # output collapses onto the target centroid instead of becoming NaN.
    scale = np.sum(Z * Y) / denominator if denominator > 0 else 1.0
    Z *= scale

    # Translate back to the target's centroid
    Z += Y_centroid

    return Z
import open3d as o3d

def visualize_3d_joints_open3d(pred_joints, gt_joints, aligned_joints):
    """
    Visualize 3D joints before and after Procrustes alignment using Open3D.

    Each input is converted to a float64 NumPy array before being handed to
    Open3D, so lists and other array-likes are accepted as well as arrays.
    The call opens an Open3D viewer showing the three point sets together.

    Args:
    - pred_joints (array-like): Predicted joint positions, shape (K, 3), drawn in red
    - gt_joints (array-like): Ground truth joint positions, shape (K, 3), drawn in green
    - aligned_joints (array-like): Aligned predicted joint positions, shape (K, 3), drawn in blue
    """
    colored_sets = (
        (pred_joints, [1, 0, 0]),     # Red
        (gt_joints, [0, 1, 0]),       # Green
        (aligned_joints, [0, 0, 1]),  # Blue
    )

    point_clouds = []
    for joints, color in colored_sets:
        pcd = o3d.geometry.PointCloud()
        # Vector3dVector is documented for float64 (n, 3) arrays; coerce up
        # front so every caller goes through the same conversion.
        pcd.points = o3d.utility.Vector3dVector(np.asarray(joints, dtype=np.float64))
        pcd.paint_uniform_color(color)
        point_clouds.append(pcd)

    # Visualize the point clouds
    o3d.visualization.draw_geometries(point_clouds)



# Open the archives holding the predicted and the ground-truth joints.
data_pred = np.load(r'E:\WorkSpace\inbed_pose_repos\CLIFF\slp_sample\cliff_2djoint.npz')
data_gt = np.load(r'E:\WorkSpace\inbed_pose_repos\CLIFF\slp_sample\p102.npz')

# Work on entry 0 only; the prediction is cut down to its first 24 rows.
gt_joints_sample = data_gt['gt_3D_joints'][0]
pred_joints_sample = data_pred['pred_joints'][0, :24, :]

# Align the prediction onto the ground truth with the NumPy routine.
aligned_pred_joints_sample = procrustes_3d(pred_joints_sample, gt_joints_sample)

# Show raw prediction, ground truth and aligned prediction together.
visualize_3d_joints_open3d(
    pred_joints_sample,
    gt_joints_sample,
    aligned_pred_joints_sample,
)


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [17]:
import torch

def batch_compute_similarity_transform_torch(S1, S2):
    '''
    Computes, for every batch element, a similarity transform (sR, t) that
    takes a set of points S1 closest (in the least-squares sense) to a set
    of points S2, where R is a rotation matrix with det(R) = 1, t a
    translation and s a scale, i.e. solves the orthogonal Procrustes problem.

    Args:
        S1: source points, shape (B, N, D) or (B, D, N).
        S2: target points, same shape and layout as S1.

    The inputs are read as (B, N, D) and transposed internally, except when
    dim 1 is 2 or 3 and dim 2 is not, in which case they are taken to be
    (B, D, N) already. The batch size plays no part in this decision.

    Returns:
        S1_hat: transformed S1, in the same layout as the input.
        (scale, R, t): scale of shape (B,), R of shape (B, D, D) and
            t of shape (B, D, 1).

    Raises:
        ValueError: if S1 and S2 are not 3-D tensors of identical shape.
    '''
    if S1.dim() != 3 or S1.shape != S2.shape:
        raise ValueError(
            "S1 and S2 must be 3-D tensors of identical shape, got "
            f"{tuple(S1.shape)} and {tuple(S2.shape)}"
        )

    # Decide the layout from the per-sample dims, never from the batch dim:
    # a batch of 2 or 3 samples must be handled like any other batch.
    already_points_last = S1.shape[1] in (2, 3) and S1.shape[2] not in (2, 3)
    transposed = not already_points_last
    if transposed:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)

    # 1. Remove mean.
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)

    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = (X1 ** 2).sum(dim=(1, 2))

    # 3. The outer product of X1 and X2.
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. Solution that maximizes trace(R'K) is R=U*V', where U, V are
    # singular vectors of K.
    U, _, V = torch.svd(K)

    # Construct Z that fixes the orientation of R to get det(R)=1.
    # Z must share dtype and device with U, otherwise bmm rejects the
    # product for non-float32 inputs.
    Z = torch.eye(U.shape[1], dtype=U.dtype, device=U.device).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))

    # Construct R.
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Recover scale: batched trace(R K) / var1.
    scale = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(dim=-1) / var1
    scale_b = scale.unsqueeze(-1).unsqueeze(-1)

    # 6. Recover translation.
    t = mu2 - scale_b * R.bmm(mu1)

    # 7. Apply the transform to the (uncentered) source points.
    S1_hat = scale_b * R.bmm(S1) + t

    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)

    return S1_hat, (scale, R, t)


def visualize_3d_joints_open3d(pred_joints, gt_joints, aligned_joints):
    """
    Show three joint sets together in an Open3D viewer.

    Every input is first passed through np.asarray, then wrapped in its own
    uniformly colored point cloud.

    Args:
    - pred_joints (np.array): Predicted joint positions, shape (K, 3), drawn in red
    - gt_joints (np.array): Ground truth joint positions, shape (K, 3), drawn in green
    - aligned_joints (np.array): Aligned predicted joint positions, shape (K, 3), drawn in blue
    """
    def _colored_cloud(points, rgb):
        # One Open3D point cloud per joint set, painted in a single color.
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(points)
        cloud.paint_uniform_color(rgb)
        return cloud

    # Coerce all inputs to NumPy arrays before any Open3D object is built.
    pred_arr = np.asarray(pred_joints)
    gt_arr = np.asarray(gt_joints)
    aligned_arr = np.asarray(aligned_joints)

    geometries = [
        _colored_cloud(pred_arr, [1, 0, 0]),     # Red
        _colored_cloud(gt_arr, [0, 1, 0]),       # Green
        _colored_cloud(aligned_arr, [0, 0, 1]),  # Blue
    ]

    # Render everything in one window.
    o3d.visualization.draw_geometries(geometries)


# Open the ground-truth and prediction archives.
data_gt = np.load(r'E:\WorkSpace\inbed_pose_repos\CLIFF\slp_sample\p102.npz')
data_pred = np.load(r'E:\WorkSpace\inbed_pose_repos\CLIFF\slp_sample\cliff_2djoint.npz')

# Batched run: every entry is used, predictions cut to their first 24 rows.
pred_joints_sample = torch.from_numpy(data_pred['pred_joints'][:, :24, :])
gt_joints_sample = torch.from_numpy(data_gt['gt_3D_joints'])

# Align the whole batch with the torch routine; only the aligned points are kept.
aligned_pred_joints, _ = batch_compute_similarity_transform_torch(
    pred_joints_sample,
    gt_joints_sample,
)

# Index of the entry to inspect.
i = 1

# Pull that single entry back out as NumPy arrays.
pred_joints_sample = data_pred['pred_joints'][i, :24, :]
gt_joints_sample = data_gt['gt_3D_joints'][i]

# Align the same entry with the NumPy routine.
aligned_pred_joints_sample = procrustes_3d(pred_joints_sample, gt_joints_sample)

# Compare the two alignments: raw prediction (red), NumPy result (green,
# passed in the ground-truth slot) and torch result (blue).
visualize_3d_joints_open3d(
    pred_joints_sample,
    aligned_pred_joints_sample,
    aligned_pred_joints[i],
)
