In [1]:
import math
import torch
from diffusion_edf.dist import IgSO3Dist
from diffusion_edf.transforms import random_quaternions, quaternion_to_axis_angle, quaternion_apply, axis_angle_to_quaternion, quaternion_multiply
import plotly.graph_objs as go

In [2]:
# Distribution object used throughout the notebook: `dist.sample(eps=..., N=...)`
# draws rotations (as quaternions) and `dist.isotropic_gaussian_so3(angle, eps=...)`
# evaluates a density. Presumably the isotropic Gaussian on SO(3) (IGSO(3)),
# judging by the class/method names — TODO confirm against diffusion_edf.dist.
dist = IgSO3Dist()

In [None]:
# Draw N rotations from `dist` with scale `eps` and apply each one to the +z unit
# vector; the rotated copies of +z visualize how spread out the rotations are.

# Fixed seed so "Restart Kernel -> Run All" reproduces the sampled points and the
# Brownian-motion simulation further down.
# NOTE(review): assumes `dist.sample` draws from torch's global RNG — TODO confirm.
torch.manual_seed(0)

eps = 0.051
# Duration of the Brownian-motion simulation below. Its increments have
# per-axis variance dt, so for small angles the accumulated rotation vector has
# density ~exp(-|w|^2 / (2 * time)). Matching the small-angle form
# exp(-angle^2 / (4 * eps)) used in the final cell gives time = 2 * eps.
time = eps * 2

N = 10000
q = dist.sample(eps=eps, N=N)
z = torch.tensor([0., 0., 1.]).expand(N, 3)  # +z axis broadcast (as a view) to N rows

zrot = quaternion_apply(quaternion=q, point=z)

In [None]:
# 3D scatter of +z rotated by the samples drawn from `dist`.
# The tensor is converted to NumPy explicitly instead of relying on plotly's
# implicit array conversion, which fails for CUDA tensors or tensors that
# require grad.
z_vis = zrot.detach().cpu().numpy()

layout = go.Layout(
    title='Samples from dist applied to +z',
    width=600,
    height=600,
    scene=dict(
        camera=dict(eye=dict(x=1.25, y=1.25, z=1.25)),  # the default values are 1.25, 1.25, 1.25
        aspectmode='manual',  # use the explicit aspectratio below ('data', 'cube', 'auto', 'manual')
        aspectratio=dict(x=1, y=1, z=1),
        # Fixed [-1, 1] ranges so this plot is directly comparable with the diffusion plot.
        xaxis=dict(title='x', range=[-1, 1]),
        yaxis=dict(title='y', range=[-1, 1]),
        zaxis=dict(title='z', range=[-1, 1]),
    ),
)
fig = go.Figure(
    data=go.Scatter3d(
        x=z_vis[..., 0],
        y=z_vis[..., 1],
        z=z_vis[..., 2],
        mode='markers',
        marker=dict(size=3),
    ),
    layout=layout,
)
fig.show()

In [None]:
# Simulate Brownian motion on SO(3) for `time` units using `n_iter` equal steps.
# Each step right-multiplies the current rotation by the quaternion of a random
# axis-angle increment dw ~ N(0, dt * I_3). The resulting rotations of +z can be
# compared visually with the samples drawn from `dist` above.
n_iter = 100
dt = time / n_iter          # constant step size, hoisted out of the loop
step_std = math.sqrt(dt)    # per-axis standard deviation of each increment

# Identity rotation for every sample.
# NOTE(review): assumes real-first (w, x, y, z) quaternion convention — TODO confirm.
q_diffuse = torch.tensor([1., 0., 0., 0.]).expand(N, 4)
for _ in range(n_iter):
    dw = torch.randn(N, 3) * step_std
    q_diffuse = quaternion_multiply(q_diffuse, axis_angle_to_quaternion(dw))

# Repeated products slowly drift off the unit sphere in floating point;
# renormalize so `quaternion_apply` acts as a pure rotation.
q_diffuse = q_diffuse / q_diffuse.norm(dim=-1, keepdim=True)

z_diffuse = quaternion_apply(quaternion=q_diffuse, point=z)

In [None]:
# 3D scatter of +z rotated by the simulated Brownian-motion rotations.
# Explicit tensor -> NumPy conversion avoids relying on plotly's implicit array
# conversion, which fails for CUDA tensors or tensors that require grad.
z_vis = z_diffuse.detach().cpu().numpy()

layout = go.Layout(
    title='Brownian motion on SO(3) applied to +z',
    width=600,
    height=600,
    scene=dict(
        camera=dict(eye=dict(x=1.25, y=1.25, z=1.25)),  # the default values are 1.25, 1.25, 1.25
        aspectmode='manual',  # use the explicit aspectratio below ('data', 'cube', 'auto', 'manual')
        aspectratio=dict(x=1, y=1, z=1),
        # Fixed [-1, 1] ranges so this plot is directly comparable with the sampling plot.
        xaxis=dict(title='x', range=[-1, 1]),
        yaxis=dict(title='y', range=[-1, 1]),
        zaxis=dict(title='z', range=[-1, 1]),
    ),
)
fig = go.Figure(
    data=go.Scatter3d(
        x=z_vis[..., 0],
        y=z_vis[..., 1],
        z=z_vis[..., 2],
        mode='markers',
        marker=dict(size=3),
    ),
    layout=layout,
)
fig.show()

In [None]:
# Spot check at one angle: the unnormalized small-angle (Gaussian) log-density
# -angle^2 / (4 * eps) versus the log of `dist.isotropic_gaussian_so3` at the
# same angle and eps.
# NOTE(review): whether isotropic_gaussian_so3 is normalized is not visible here;
# if it is, the two printed values are expected to differ by an additive constant.
eps = 1.
angle = torch.tensor([0.1])

approx = angle.square().mul(-0.25).div(eps)  # gaussian
exact = dist.isotropic_gaussian_so3(angle, eps=eps).log()

print(f"Approx: {approx}")
print(f"Exact: {exact}")