In [1]:
import math
import torch
from diffusion_edf.dist import *
from diffusion_edf import transforms
import plotly.graph_objs as go

# SE(3) score test: analytic vs. finite-difference score

In [2]:
# Configuration shared by every cell below.
dtype = torch.float64
device = 'cpu'
# Fix the RNG so Restart-&-Run-All is reproducible.
# NOTE(review): assumes sample_igso3 / random_quaternions draw from torch's global RNG — confirm.
torch.manual_seed(0)

eps = 0.1   # passed as `eps` to sample_igso3 / igso3 / se3_isotropic_gaussian_score
std = 0.1   # passed as `std` to the R^3 isotropic Gaussian helpers
N=3         # number of sampled poses

q = sample_igso3(eps = eps, N=N, dtype=dtype, device=device)
x = torch.randn(N,3, dtype=dtype, device=device)
T = torch.cat([q,x], dim=-1)  # pose stored as [quaternion | translation]

# log p(T) = (R^3 isotropic Gaussian log-density of x) + log IGSO(3)(q).
# Both terms are reused below for the perturbation that leaves them unchanged.
log_prob_trans_part = r3_log_isotropic_gaussian(x, std)
log_prob_rot_part = torch.log(igso3(q, eps))
log_prob = log_prob_trans_part + log_prob_rot_part
score = se3_isotropic_gaussian_score(T, eps=eps, std=std)
print(f"Anaylitic score: {score}")

dt = 0.0000001  # forward-difference step (1e-7)

# Row i is the i-th unit axis; built once instead of once per loop iteration.
unit_axes = torch.eye(3, device=q.device, dtype=q.dtype)
log_prob_rot = []
log_prob_trans = []
for axis in unit_axes:
    # Right-multiplied rotation perturbation: q * exp(dt * e_i); translation unchanged.
    qrot = transforms.quaternion_multiply(q, transforms.axis_angle_to_quaternion(dt * axis))
    log_prob_rot.append(log_prob_trans_part + torch.log(igso3(qrot, eps)))
    # Translation perturbation expressed in the rotated frame: R(q) (dt * e_i) + x; rotation unchanged.
    x_step = transforms.quaternion_apply(q, dt * axis) + x
    log_prob_trans.append(r3_log_isotropic_gaussian(x_step, std) + log_prob_rot_part)

# Columns: 3 rotational, then 3 translational directional derivatives.
log_prob_perturbed = torch.stack(log_prob_rot + log_prob_trans, dim=-1)
numerical_score_at_T = (log_prob_perturbed - log_prob[..., None]) / dt  # forward difference, computed once
print(f"Numerical score: {numerical_score_at_T}")
print(f"Allclose: {torch.allclose(numerical_score_at_T, score, atol=0, rtol=1e-2)}")

Anaylitic score: tensor([[  -0.2929,    0.4965,    1.8783,   22.7208, -153.7417,   75.0322],
        [   1.1517,    3.1551,   -3.0236,   15.4715,   50.8923,  -68.7748],
        [   0.7387,   -3.6302,   -2.6324,   11.9674,   -3.8494,   16.2496]],
       dtype=torch.float64)
Numerical score: tensor([[  -0.2929,    0.4965,    1.8783,   22.7208, -153.7417,   75.0322],
        [   1.1517,    3.1551,   -3.0236,   15.4714,   50.8923,  -68.7748],
        [   0.7387,   -3.6302,   -2.6324,   11.9673,   -3.8494,   16.2496]],
       dtype=torch.float64)
Allclose: True


# Adjoint $T_{sg} = T_{se} T_{eg}$ -> $T_{se}$

In [3]:
# Random reference pose T_ref = T_eg: random rotation, translation scaled by 10.
# (Quaternions are drawn before translations, matching the original draw order.)
q_eg = transforms.random_quaternions(N, dtype=dtype, device=device)
t_eg = 10 * torch.randn(N, 3, dtype=dtype, device=device)
T_ref = torch.cat([q_eg, t_eg], dim=-1)
TT_ref = transforms.multiply_se3(T, T_ref)  # composed pose T T_ref

# log P(T T_ref): translational Gaussian term plus IGSO(3) term of the composed pose.
log_prob_ref = (
    r3_log_isotropic_gaussian(TT_ref[..., 4:], std)
    + torch.log(igso3(TT_ref[..., :4], eps))
)

# Pull the score at T T_ref back to a score w.r.t. T: Ad_{T_ref^-1}^T score(T T_ref).
# angular_first=True presumably matches the (rotation, translation) column order used above.
score_at_TT_ref = se3_isotropic_gaussian_score(TT_ref, eps=eps, std=std)
score_ref_perturb = adjoint_inv_tr_se3_score(T_ref=T_ref, score=score_at_TT_ref, angular_first=True)

print(f"Anaylitic score: {score_ref_perturb}")

Anaylitic score: tensor([[ -582.4981,  -329.0069,  -495.7784,  -345.8799,   163.2239,   300.5076],
        [  990.4015,  2141.9007,  1804.8203, -3200.0395,  1188.2145,   385.4164],
        [ -165.8116,   223.2308,   163.5129,  1211.6343,  1034.0490,  -267.9058]],
       dtype=torch.float64)


In [4]:
# Finite-difference check of the pulled-back score: perturb T on the right by dT and
# evaluate log P(T dT T_ref) in closed form, with T = (R, x) and T_ref = (R', x').
unit_axes = torch.eye(3, device=q.device, dtype=q.dtype)  # row i = i-th unit axis
q_eg = T_ref[..., :4]  # R'
t_eg = T_ref[..., 4:]  # x'

# Loop-invariant pieces of the translational perturbation, computed once.
R_t_eg = transforms.quaternion_apply(q, t_eg)                                        # R x'
log_igso3_RR_ref = torch.log(igso3(transforms.quaternion_multiply(q, q_eg), eps))   # log IGSO3(R R')

log_prob_rot_ref = []
log_prob_trans_ref = []
for axis in unit_axes:
    q_perturb = transforms.quaternion_multiply(q, transforms.axis_angle_to_quaternion(dt * axis))  # R dR
    x_perturb = transforms.quaternion_apply(q, dt * axis) + x                                      # R dx + x

    # (R dR, x)(R', x') = (R dR R', R dR x' + x)
    log_prob_rot_ref.append(
        r3_log_isotropic_gaussian(transforms.quaternion_apply(q_perturb, t_eg) + x, std=std)
        + torch.log(igso3(transforms.quaternion_multiply(q_perturb, q_eg), eps))
    )

    # (R, R dx + x)(R', x') = (R R', R x' + R dx + x)
    log_prob_trans_ref.append(
        r3_log_isotropic_gaussian(R_t_eg + x_perturb, std=std)
        + log_igso3_RR_ref
    )


# Columns: 3 rotational, then 3 translational directional derivatives.
log_prob_perturbed_ref = torch.stack(log_prob_rot_ref + log_prob_trans_ref, dim=-1)

numerical_score_ref = (log_prob_perturbed_ref - log_prob_ref[..., None]) / dt  # forward difference
print(f"Numerical score: {numerical_score_ref}")
print(f"Allclose: {torch.allclose(numerical_score_ref, score_ref_perturb, atol=0, rtol=1e-2)}")

Numerical score: tensor([[ -582.4976,  -329.0064,  -495.7776,  -345.8799,   163.2239,   300.5076],
        [  990.4009,  2141.9000,  1804.8196, -3200.0395,  1188.2144,   385.4164],
        [ -165.8112,   223.2309,   163.5130,  1211.6343,  1034.0490,  -267.9059]],
       dtype=torch.float64)
Allclose: True


# Adjoint $T_{se}$ -> $T_{sg}$

In [5]:
# Reference value: the SE(3) isotropic Gaussian score evaluated directly at T T_ref.
analytic_score = se3_isotropic_gaussian_score(TT_ref, std=std, eps=eps)
analytic_msg = f"Anaylitic score: {analytic_score}"
print(analytic_msg)

Anaylitic score: tensor([[ 1.5007e+00, -1.8175e+00, -1.0158e+00, -2.1189e+02,  4.1604e+02,
          1.3635e+02],
        [-1.0099e+01, -1.3386e+01, -1.7408e+01,  3.1217e+02, -2.4232e+03,
         -2.4148e+03],
        [-6.1486e+00, -1.0377e+01,  9.7393e-02,  1.0636e+03,  7.1638e+02,
          9.8219e+02]], dtype=torch.float64)


In [6]:
# Map the finite-difference score w.r.t. T back to T T_ref. Passing se3_invert(T_ref)
# to adjoint_inv_tr_se3_score undoes the Ad_{T_ref^-1}^T pull-back done earlier.
T_ref_inv = transforms.se3_invert(T_ref)
numerical_score = adjoint_inv_tr_se3_score(
    T_ref=T_ref_inv,
    score=numerical_score_ref,
    angular_first=True,
)
print(f"Numerical score: {numerical_score}")
scores_match = torch.allclose(numerical_score, analytic_score, atol=0, rtol=1e-2)
print(f"Allclose: {scores_match}")

Numerical score: tensor([[ 1.5004e+00, -1.8169e+00, -1.0167e+00, -2.1189e+02,  4.1604e+02,
          1.3635e+02],
        [-1.0101e+01, -1.3385e+01, -1.7409e+01,  3.1217e+02, -2.4232e+03,
         -2.4148e+03],
        [-6.1482e+00, -1.0376e+01,  9.7226e-02,  1.0636e+03,  7.1638e+02,
          9.8219e+02]], dtype=torch.float64)
Allclose: True


# Isotropic Adjoint Test

In [7]:
# Check that the translation-only ("isotropic") adjoint helpers agree with the general
# SE(3) adjoints when T_ref is built from a translation alone via se3_from_r3
# (presumably identity rotation — its inverse is then compared against -x_ref).
# Fresh names keep this cell from clobbering N / T_ref / score used by the cells above,
# so those cells still give correct results if re-run afterwards.
N_iso = 5
# Use the notebook's float64 dtype: rtol=1e-7 with atol=0 sits at float32 resolution (~1.2e-7).
x_ref_iso = torch.randn(N_iso, 3, dtype=dtype, device=device)
T_ref_iso = transforms.se3_from_r3(x_ref_iso)
T_ref_iso_inv = transforms.se3_invert(T_ref_iso)
score_iso = torch.randn(N_iso, 6, dtype=dtype, device=device)

iso_checks = (
    torch.allclose(adjoint_inv_tr_isotropic_se3_score(x_ref=x_ref_iso, score=score_iso),
                   adjoint_inv_tr_se3_score(T_ref=T_ref_iso, score=score_iso), atol=0, rtol=1e-7),
    torch.allclose(adjoint_inv_tr_isotropic_se3_score(x_ref=-x_ref_iso, score=score_iso),
                   adjoint_inv_tr_se3_score(T_ref=T_ref_iso_inv, score=score_iso), atol=0, rtol=1e-7),
    torch.allclose(adjoint_isotropic_se3_score(x_ref=x_ref_iso, score=score_iso),
                   adjoint_se3_score(T_ref=T_ref_iso, score=score_iso), atol=0, rtol=1e-7),
    torch.allclose(adjoint_isotropic_se3_score(x_ref=-x_ref_iso, score=score_iso),
                   adjoint_se3_score(T_ref=T_ref_iso_inv, score=score_iso), atol=0, rtol=1e-7),
)
print(*iso_checks)
# Fail loudly instead of relying on a reader to spot a printed False.
assert all(iso_checks), f"isotropic adjoint mismatch: {iso_checks}"

True True True True
