In [223]:
import torch
import numpy as np
from liegroups.torch import SE3, SO3
import sys
sys.path.insert(0,'..')
from utils import *

In [249]:
#Compute DxD covariance from NxD samples
def sample_covariance(samples, assume_zero_mean=False):
    """Estimate a DxD covariance matrix from an NxD matrix of samples.

    Args:
        samples: 2-D tensor with one sample per row (N rows, D columns).
        assume_zero_mean: if True, the samples are used as-is (no mean is
            subtracted); otherwise the per-column sample mean is removed
            first.

    Returns:
        DxD tensor ``X^T X / (N - 1)`` where ``X`` is the (optionally
        centred) sample matrix.

    Note:
        The normaliser is ``N - 1`` in both modes, so a single-row input
        divides by zero.
    """
    if assume_zero_mean:
        # Nothing to subtract; this works for any D rather than only D == 3.
        centered = samples
    else:
        centered = samples - samples.mean(dim=0, keepdim=True)

    return centered.transpose(0, 1).mm(centered) / (samples.shape[0] - 1)

def frob_norm(A, B):
    """Return ``(A - B).norm()``, the distance between ``A`` and ``B``.

    For torch tensors the default ``norm()`` is the Frobenius norm, i.e. the
    square root of the sum of squared entries of the difference.
    """
    difference = A - B
    return difference.norm()

In [250]:
#Mean phi
# Monte-Carlo check of the naive estimator: average the log-coordinates of the
# samples and map the average back to a rotation.
n_samples = 25
n_repeat = 1000
degree_std = 10

error_deg = 0.
covar_err = 0.

# The noise model is identical for every trial, so build it once.
covar = (degree_std*(np.pi/180.))**2 * torch.diag(torch.tensor([1., 1.5, 2.]))
m = torch.distributions.MultivariateNormal(torch.zeros(3), covar)

for i in range(n_repeat):
    # Random ground-truth rotation; samples are left-perturbations of it.
    R_mean = SO3.exp(np.pi*torch.rand(3) - np.pi/2.)
    R_samples = SO3.exp(m.rsample([n_samples])).dot(R_mean)

    # Per-sample log-coordinates (one row per sample) and their mean.
    phi_samples = R_samples.log()
    R_mean_est = SO3.exp(phi_samples.mean(dim=0))
    error_deg += R_mean.dot(R_mean_est.inv()).log().norm()*(180./np.pi)

    # Residuals of each sample about the estimated mean. Defined here so the
    # cell no longer depends on a variable leaked from another cell.
    phi_residuals = R_samples.dot(R_mean_est.inv()).log()
    covar_err += frob_norm(sample_covariance(phi_residuals), covar)

print('Average deg error: {:.3f} | Covar err: {:.3f}'.format(error_deg/n_repeat, covar_err/n_repeat))


Average deg error: 3.905 | Covar err: 0.035


In [291]:
def compute_quat_stats(R_samples, R_mean):
    """Estimate a mean rotation and a covariance via quaternion averaging.

    Args:
        R_samples: batched SO3 object holding the sampled rotations.
        R_mean: SO3 object used as the reference for the covariance.

    Returns:
        (R_mean_est, Sigma): the rotation built from the normalised
        arithmetic mean of the sample quaternions, and the covariance of the
        per-sample differences returned by ``quat_log_diff``.

    Note:
        The covariance is taken about the supplied ``R_mean``, not about the
        estimated mean.
    """
    q_mean = R_mean.to_quaternion()
    # NOTE(review): set_quat_sign comes from utils; presumably it resolves the
    # q / -q sign ambiguity before averaging and expects a leading extra
    # dimension (hence unsqueeze/squeeze) - confirm against utils.
    q_samples = set_quat_sign(R_samples.to_quaternion().unsqueeze(0)).squeeze(0)
    q_mean_est = q_samples.mean(dim=0)
    # Re-normalise: the arithmetic mean of unit quaternions is not unit length.
    R_mean_est = SO3.from_quaternion(q_mean_est/q_mean_est.norm())
    # Tile the reference quaternion to the actual batch size instead of relying
    # on the notebook-level ``n_samples`` global.
    phi_diff = quat_log_diff(q_samples, q_mean.repeat([q_samples.shape[0], 1]))
    return R_mean_est, sample_covariance(phi_diff, assume_zero_mean=True)

def compute_geo_stats(R_samples, R_mean, max_iters=10, tol=1e-3):
    """Estimate a mean rotation and a covariance by iterative log-averaging.

    Starting from the first sample, the estimate is repeatedly corrected by
    the mean of the sample residuals ``log(R_i * R_s^-1)`` until the norm of
    that correction drops below ``tol`` or ``max_iters`` updates have been
    applied.

    Args:
        R_samples: batched SO3 object holding the sampled rotations.
        R_mean: unused; kept so the signature matches ``compute_quat_stats``.
        max_iters: maximum number of mean updates (default 10).
        tol: convergence threshold on the norm of the mean residual
            (default 1e-3).

    Returns:
        (R_s, Sigma): the estimated mean rotation and the covariance of the
        residuals taken about that returned mean.
    """
    R_s = SO3.from_matrix(R_samples.as_matrix()[0])
    for _ in range(max_iters):
        #Compute logs about R_s
        phi_new_samples = R_samples.dot(R_s.inv()).log().double()
        phi_delta = phi_new_samples.mean(dim=0)
        if phi_delta.norm() < tol:
            break
        R_s = SO3.exp(phi_delta).dot(R_s)

    # Compute the covariance once, about the mean that is actually returned.
    # Doing it here (not after the break check inside the loop) keeps Sigma
    # defined when the first iteration already converges and never stale.
    phi_residuals = R_samples.dot(R_s.inv()).log().double()
    Sigma = sample_covariance(phi_residuals, assume_zero_mean=True).double()
    return R_s, Sigma

def compute_errors(R_est, R_true, Sigma_est, Sigma_true):
    """Score an estimated rotation and covariance against reference values.

    Returns:
        (ang_err, covar_err): the norm of ``log(R_est * R_true^-1)`` scaled by
        ``180/pi``, and ``frob_norm(Sigma_est, Sigma_true)``.
    """
    rad_to_deg = 180./np.pi
    relative_rotation = R_est.dot(R_true.inv())
    ang_err = relative_rotation.log().norm()*rad_to_deg
    return ang_err, frob_norm(Sigma_est, Sigma_true)
    

In [292]:
#Compute both
# Monte-Carlo comparison of the iterative (GEO) and quaternion-averaging
# (QUAT) estimators on the same sample sets.
n_samples = 25
n_repeat = 1000
degree_std = 25

geo_error_deg = 0.
geo_covar_err = 0.
quat_error_deg = 0.
quat_covar_err = 0.

# The noise model and its double-precision reference do not change between
# trials, so build them once instead of on every iteration.
covar = (degree_std*(np.pi/180.))**2 * torch.diag(torch.tensor([1., 0.9, 1.1]))
m = torch.distributions.MultivariateNormal(torch.zeros(3), covar)
covar_true = covar.double()

for i in range(n_repeat):
    # Random ground-truth rotation; samples are left-perturbations of it.
    R_mean = SO3.exp(np.pi*torch.rand(3).double() - np.pi/2.)
    R_samples = SO3.exp(m.rsample([n_samples]).double()).dot(R_mean)

    R_geo, Sigma_geo = compute_geo_stats(R_samples, R_mean)
    R_quat, Sigma_quat = compute_quat_stats(R_samples, R_mean)

    geo_ang_err_i, geo_covar_err_i = compute_errors(R_geo, R_mean, Sigma_geo, covar_true)
    quat_ang_err_i, quat_covar_err_i = compute_errors(R_quat, R_mean, Sigma_quat, covar_true)

    quat_error_deg += quat_ang_err_i
    quat_covar_err += quat_covar_err_i

    geo_error_deg += geo_ang_err_i
    geo_covar_err += geo_covar_err_i

print('QUAT deg error: {:.3f} | Covar err: {:.3f}'.format(quat_error_deg/n_repeat, quat_covar_err/n_repeat))
print('GEO deg error: {:.3f} | Covar err: {:.3f}'.format(geo_error_deg/n_repeat, geo_covar_err/n_repeat))

QUAT deg error: 8.377 | Covar err: 0.134
GEO deg error: 8.259 | Covar err: 0.131
