In [None]:
import sys
sys.path.append("..")
import numpy as np 
import torch
from relie.utils.so3_tools import so3_hat, so3_vee, so3_exp, so3_log
from relie.utils.se3_tools import se3_hat, se3_vee, se3_exp, se3_log, se3_inv

In [None]:
def det_jac_so3(x):
    """Closed-form determinant of the SO(3) exp-map Jacobian.

    For a rotation vector ``x`` with angle ``theta = |x|`` the determinant is
    ``2 * (1 - cos(theta)) / theta**2``. Because ``2 * (1 - cos(theta)) ==
    4 * sin(theta/2)**2``, this equals ``(sin(theta/2) / (theta/2))**2``. That
    half-angle form is evaluated here: it avoids the catastrophic cancellation
    in ``1 - cos(theta)`` for small angles, and it returns the analytic limit
    ``1`` at ``theta == 0`` instead of ``nan``.

    Args:
        x: tensor whose last dimension holds the rotation-vector components
           (presumably size 3 — not checked here).

    Returns:
        Tensor of shape ``x.shape[:-1]`` with one determinant per vector.
    """
    half = x.norm(2, -1) / 2
    is_zero = half == 0
    # Use a dummy denominator on zero entries so 0/0 is never evaluated,
    # then substitute the analytic limit sin(t)/t -> 1 there.
    safe_half = torch.where(is_zero, torch.ones_like(half), half)
    sinc_half = torch.where(is_zero, torch.ones_like(half), torch.sin(safe_half) / safe_half)
    return sinc_half ** 2

def so3_inv(el):
    """Invert (a batch of) rotation matrices by swapping their last two axes.

    Uses R^{-1} == R^T, which holds for orthogonal matrices. The input is not
    checked for orthogonality.
    """
    return torch.transpose(el, -1, -2)

def relative_error(x_hat, x):
    """Signed relative error ``(x - x_hat) / x`` of an estimate against a reference.

    No absolute value is taken. For positive ``x`` the result is positive when
    the estimate undershoots. Entries where ``x == 0`` give ``inf``/``nan``.
    """
    difference = x - x_hat
    return difference / x
    

In [None]:
def compute_approx_jacobian(points, vee, exp, log, inv, eps = 0.01):
    """Forward-difference estimate of an exp-map Jacobian determinant.

    For each point ``x`` and each basis direction ``e_i``, the vector
    ``vee(log(inv(exp(x)) @ exp(x + eps * e_i)))`` approximates ``eps`` times
    the derivative of ``d -> vee(log(exp(x)^{-1} exp(x + d)))`` along ``e_i``.
    Stacking the ``dim`` vectors gives a ``(dim, dim)`` matrix. Its determinant
    divided by ``eps**dim`` is the estimate, whose error is first order in
    ``eps``.

    The computation stays in torch and keeps the dtype and device of
    ``points``, so tensors that require grad and non-CPU tensors are
    supported. Use float64 inputs for small ``eps``.

    Args:
        points: tensor of shape ``(..., dim)`` of Lie-algebra coordinates.
        vee, exp, log, inv: group operations. ``exp`` maps ``(..., dim)``
            vectors to matrices, ``log``/``vee`` map back, and ``inv`` inverts
            group elements. All must broadcast over leading dimensions.
        eps: finite-difference step size.

    Returns:
        Tensor with one estimated determinant per point, e.g. shape ``(n,)``
        for ``points`` of shape ``(n, dim)``.
    """
    points = points.unsqueeze(-2)  # (..., 1, dim): broadcasts against the basis below
    dim = points.shape[-1]
    basis = torch.eye(dim, dtype=points.dtype, device=points.device).unsqueeze(0)
    # One perturbed group element per basis direction: (..., dim, G, G).
    group_delta = exp(basis * eps + points)
    points_inv = inv(exp(points))  # (..., 1, G, G)
    normal_coord = vee(log(points_inv @ group_delta))  # (..., dim, dim)
    return torch.det(normal_coord) / eps ** dim

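### SO(3)

Compare the finite-difference estimate from `compute_approx_jacobian` with the closed form $\det J(x) = \frac{2(1-\cos\theta)}{\theta^2}$, $\theta = \|x\|$, for several step sizes $\epsilon$.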

In [None]:
def approximate_so3_jacobian(points, eps = 0.01):
    """Finite-difference estimate of the SO(3) exp-map Jacobian determinant.

    Thin wrapper around ``compute_approx_jacobian`` using the SO(3) operations
    (``so3_vee``, ``so3_exp``, ``so3_log``, ``so3_inv``). ``eps`` defaults to
    0.01, the same default as ``compute_approx_jacobian``. Existing callers
    that pass ``eps`` are unaffected.
    """
    return compute_approx_jacobian(points, so3_vee, so3_exp, so3_log, so3_inv, eps)

In [None]:
# Compare the finite-difference estimate with the closed form on random
# rotation vectors, for a range of step sizes.
SEED = 0
dim = 3
n_points = 10
l_eps = [1, 1e-1, 1e-2, 1e-3]

# Fixed seed so Restart & Run All reproduces the same numbers.
rng = np.random.default_rng(SEED)
center = torch.tensor(rng.uniform(-1, 1, (n_points, dim)), dtype=torch.float64)

for eps in l_eps:
    estimated_det_jac = approximate_so3_jacobian(center, eps=eps)
    # relative_error is signed; take |.| so log10 of the mean/std is defined.
    err = relative_error(estimated_det_jac, det_jac_so3(center)).abs()
    print(f"eps={eps:g}  log10 mean|rel err|={torch.log10(err.mean()).item():.3f}  "
          f"log10 std={torch.log10(err.std()).item():.3f}")

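### SE(3)

Same comparison for SE(3). Here the closed form is the square of the SO(3) value, with $\theta$ the norm of the first three components of $z$: $\det J(z) = \left(\frac{2(1-\cos\theta)}{\theta^2}\right)^2$.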

In [None]:
def approximate_se3_jacobian(points, eps = 0.01):
    """Finite-difference estimate of the SE(3) exp-map Jacobian determinant.

    Delegates to ``compute_approx_jacobian`` with the SE(3) group operations
    (``se3_vee``, ``se3_exp``, ``se3_log``, ``se3_inv``).
    """
    return compute_approx_jacobian(
        points,
        vee=se3_vee,
        exp=se3_exp,
        log=se3_log,
        inv=se3_inv,
        eps=eps,
    )

def det_jac_se3(z):
    """Closed-form determinant of the SE(3) exp-map Jacobian.

    Only the first three components of ``z`` enter. With ``theta`` their norm,
    the determinant is ``(2 * (1 - cos(theta)) / theta**2)**2``, the square of
    the SO(3) value. It is evaluated through the equivalent half-angle form
    ``(sin(theta/2) / (theta/2))**4``. This gives accurate values for small
    rotations and the limit ``1`` at ``theta == 0`` instead of ``nan``.

    Args:
        z: tensor whose last dimension has size 6 (split as 3 + 3).

    Returns:
        Tensor of shape ``z.shape[:-1]``.
    """
    rot, _ = z.split([3, 3], -1)
    half = rot.norm(2, -1) / 2
    is_zero = half == 0
    # Dummy denominator on zero entries avoids evaluating 0/0; the analytic
    # limit sin(t)/t -> 1 is substituted there.
    safe_half = torch.where(is_zero, torch.ones_like(half), half)
    sinc_half = torch.where(is_zero, torch.ones_like(half), torch.sin(safe_half) / safe_half)
    return sinc_half ** 4

In [None]:
# Same comparison for SE(3) on standard-normal Lie-algebra vectors.
SEED = 0
dim = 6
n_points = 10
l_eps = [1, 1e-1, 1e-2, 1e-3]

# Fixed seed so Restart & Run All reproduces the same numbers.
rng = np.random.default_rng(SEED)
center = torch.tensor(rng.normal(0, 1, (n_points, dim)), dtype=torch.float64)

for eps in l_eps:
    estimated_det_jac = approximate_se3_jacobian(center, eps=eps)
    # relative_error is signed; take |.| so log10 of the mean/std is defined.
    err = relative_error(estimated_det_jac, det_jac_se3(center)).abs()
    print(f"eps={eps:g}  log10 mean|rel err|={torch.log10(err.mean()).item():.3f}  "
          f"log10 std={torch.log10(err.std()).item():.3f}")

In [None]:
# Signed difference between estimate and closed form at eps = 1e-3, recomputed
# explicitly rather than relying on a loop variable left over from another cell.
print(approximate_se3_jacobian(center, eps=1e-3) - det_jac_se3(center))

In [None]:
# Signed difference between estimate and closed form at eps = 0.1.
diff_eps_0_1 = approximate_se3_jacobian(center, eps=0.1) - det_jac_se3(center)
print(diff_eps_0_1)