In [9]:
import numpy as np
import torch
from IPython import embed
import torch.nn.functional as F

# compute the relative pose
def normalize_vector( v):
    batch=v.shape[0]
    v_mag = torch.sqrt(v.pow(2).sum(1))# batch
    v_mag = torch.max(v_mag, torch.autograd.Variable(torch.FloatTensor([1e-8])))
    v_mag = v_mag.view(batch,1).expand(batch,v.shape[1])
    v = v/v_mag
    return v

def compute_quaternions_from_rotation_matrices(matrices):
    batch=matrices.shape[0]
    
    w=torch.sqrt(torch.max(1.0 + matrices[:,0,0] + matrices[:,1,1] + matrices[:,2,2], torch.zeros(1))) / 2.0
    w = torch.max (w , torch.autograd.Variable(torch.zeros(batch))+1e-8) #batch
    w4 = 4.0 * w
    x= (matrices[:,2,1] - matrices[:,1,2]) / w4
    y= (matrices[:,0,2] - matrices[:,2,0]) / w4
    z= (matrices[:,1,0] - matrices[:,0,1]) / w4
    quats = torch.cat((w.view(batch,1), x.view(batch, 1),y.view(batch, 1), z.view(batch, 1) ), 1   )
    quats = normalize_vector(quats)
    return quats

def compute_rotation_matrix_from_quaternion( quaternion, n_flag=True):
    batch=quaternion.shape[0]
    if n_flag:
        quat = normalize_vector(quaternion)
    else:
        quat = quaternion
    qw = quat[...,0].view(batch, 1)
    qx = quat[...,1].view(batch, 1)
    qy = quat[...,2].view(batch, 1)
    qz = quat[...,3].view(batch, 1)

    # Unit quaternion rotation matrices computatation  
    xx = qx*qx
    yy = qy*qy
    zz = qz*qz
    xy = qx*qy
    xz = qx*qz
    yz = qy*qz
    xw = qx*qw
    yw = qy*qw
    zw = qz*qw

    row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
    row1 = torch.cat((2*xy+ 2*zw,  1-2*xx-2*zz, 2*yz-2*xw  ), 1) #batch*3
    row2 = torch.cat((2*xz-2*yw,   2*yz+2*xw,   1-2*xx-2*yy), 1) #batch*3
    
    matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
    
    return matrix

def rot_err_q(est_pose, gt_pose):
                 
    est_pose_q = F.normalize(est_pose, p=2, dim=1)
    gt_pose_q = F.normalize(gt_pose, p=2, dim=1)
    inner_prod = torch.bmm(est_pose_q.view(est_pose_q.shape[0], 1, est_pose_q.shape[1]),
                           gt_pose_q.view(gt_pose_q.shape[0], gt_pose_q.shape[1], 1)) 
    # if torch.abs(inner_prod) <= 1:
    orient_err = 2 * torch.acos(torch.abs(inner_prod)) * 180 / torch.pi
    # else:
    #     origin = torch.abs(torch.abs(inner_prod) - int(torch.abs(inner_prod)) - 1)
    #     orient_err = 2 * torch.acos(origin) * 180 / torch.pi
    return orient_err

def rot_err_R(est_pose, gt_pose):
    est_R = compute_rotation_matrix_from_quaternion(est_pose)
    gt_R = compute_rotation_matrix_from_quaternion(gt_pose)
    rot = torch.matmul(est_R.transpose(1, 2), gt_R)
    U, S, Vh = torch.linalg.svd(rot)
    V = Vh.mH
    log_rot = U @ torch.diag(torch.log(S)) @ V
    rot_err = torch.mean(torch.abs(log_rot)) / torch.pi
    return rot_err

def translation_err(est_pose, gt_pose):
    """
    Calculate the position error given the estimated and ground truth pose(s).
    :param est_pose: (torch.Tensor) a batch of estimated poses (Nx7, N is the batch size)
    :param gt_pose: (torch.Tensor) a batch of ground-truth poses (Nx7, N is the batch size)
    :return: position error(s)
    """
    posit_err = torch.norm(est_pose[:, 0:3] - gt_pose[:, 0:3], dim=1)
    return posit_err

In [10]:
pose_gt = torch.tensor([[-0.0086188981315046,0.9658827547330426,8.32611692498036,0.9971564052053462,0.0562843980563516,-0.0499919500231132,-0.0034604950906209]])
pose_est = torch.tensor([[0.04115517437458038,0.9318070411682129,8.151130676269531,0.9971417784690857,0.05427778884768486,-0.052424173802137375,-0.0037253866903483868]])

In [11]:
pose_gt_rot = pose_gt[:, 3:]
pose_est_rot = pose_est[:, 3:]

In [13]:
rot_err_q(pose_est_rot, pose_gt_rot)

tensor([[[0.3648]]])

In [14]:
est_R = compute_rotation_matrix_from_quaternion(pose_est_rot)
gt_R = compute_rotation_matrix_from_quaternion(pose_gt_rot)
rot = torch.matmul(est_R.transpose(1, 2), gt_R)
U, S, Vh = torch.linalg.svd(rot)

In [16]:
U.shape

torch.Size([1, 3, 3])

In [17]:
S.shape

torch.Size([1, 3])

In [18]:
Vh.shape

torch.Size([1, 3, 3])

In [19]:
V = Vh.mH

In [37]:
S_1 = torch.eye(3, dtype=torch.float32)

In [38]:
S_1[0,0] = torch.log(S[0,0])
S_1[1,1] = torch.log(S[0,1])
S_1[2,2] = torch.log(S[0,2])

In [39]:
S_1

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -5.9605e-08,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00, -5.9605e-08]])

In [40]:
log_rot = U @ S_1 @ V

In [41]:
log_rot

tensor([[[-5.6584e-08,  1.8733e-08,  0.0000e+00],
         [-1.7736e-08, -5.3842e-08,  0.0000e+00],
         [ 6.0380e-09,  1.7402e-08,  0.0000e+00]]])

In [42]:
rot_err = torch.mean(torch.abs(log_rot)) / torch.pi

In [43]:
rot_err

tensor(6.0243e-09)