# THIS IS FROM COLAB SO3 DIFFUSION EXAMPLE

In [1]:
import torch
torch.set_default_dtype(torch.float64)
import numpy as np
np.random.seed(42)
import matplotlib.pyplot as plt 
from scipy.spatial.transform import Rotation


scipy.spatial.transform._rotation.Rotation

In [2]:
def _so3_generators():
    """Build the standard basis (generators) of the Lie algebra so(3).

    Entry ``[i]`` is the skew-symmetric matrix of an infinitesimal rotation
    about coordinate axis ``i``; equivalently ``G[i, j, k] = -eps_{ijk}``
    where ``eps`` is the Levi-Civita symbol.
    """
    gens = torch.zeros(3, 3, 3)
    # For each cyclic triple (i, j, k): generator i holds -1 at (j, k)
    # and +1 at the transposed position (k, j).
    for i, j, k in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
        gens[i, j, k] = -1.0
        gens[i, k, j] = 1.0
    return gens


basis = _so3_generators()

In [3]:
# hat map from vector space R^3 (axis-angle vector) to Lie algebra so(3), convert to skew-symmetric matrices
def hat(v):
    """
    Compute the hat operator of a batch of 3D vectors.

    Maps vectors in R^3 to the corresponding skew-symmetric matrices in so(3),
    i.e. the linear combination ``sum_i v[..., i] * basis[i]``.

    Args:
        v: Batch of vectors of shape `(..., 3)` (e.g. `(N, 3)`)

    Returns:
        Batch of skew-symmetric matrices of shape `(..., 3, 3)`,
        i.e. each matrix looks like:
        [[0, -z, y],
         [z, 0, -x],
         [-y, x, 0]]
    """
    # einsum needs both operands on the same dtype/device; `basis` is a
    # module-level CPU tensor in the default dtype, so match it to `v`
    # (otherwise float32 or CUDA inputs raise).
    return torch.einsum('...i,ijk->...jk', v, basis.to(v))

In [4]:

# logarithmic map from SO(3) to R^3 (i.e. rotation vector)
def Log(x):
    """
    Logarithmic map from SO(3) to R^3: rotation matrices -> rotation vectors.

    Each output is the axis-angle (rotation) vector of the corresponding input
    matrix, i.e. `vee` of its matrix logarithm. At a rotation angle of pi the
    rotation vector is only determined up to sign.

    The conversion is delegated to scipy, so it runs on CPU and is not
    differentiable (the autograd history of `x` is dropped).

    Args:
        x: rotation matrices of shape `(..., 3, 3)` (e.g. `(N, 3, 3)` or a
           single `(3, 3)` matrix)

    Returns:
        rotation vectors of shape `(..., 3)` on `x`'s device; floating-point
        inputs keep their dtype, other inputs use the default dtype
    """
    batch_shape = x.shape[:-2]
    # scipy expects a CPU numpy stack of shape (N, 3, 3): drop autograd,
    # move off the GPU and flatten any leading batch dimensions.
    mats = x.detach().cpu().reshape(-1, 3, 3).numpy()
    rotvec = Rotation.from_matrix(mats).as_rotvec()
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    return torch.as_tensor(rotvec, dtype=dtype, device=x.device).reshape(*batch_shape, 3)

In [5]:
# logarithmic map from SO(3) to so(3), this is the matrix logarithm
def log(x):
    """
    Matrix logarithm on SO(3): rotation matrices -> skew-symmetric matrices.

    Args:
        x: batch of rotation matrices, e.g. of shape `(N, 3, 3)`

    Returns:
        batch of skew-symmetric matrices in so(3), shape `(N, 3, 3)`
    """
    rotation_vectors = Log(x)  # SO(3) -> R^3
    return hat(rotation_vectors)  # R^3 -> so(3)

In [17]:
# vee map from so(3) to R^3: convert skew-symmetric matrices back to rotation vectors (inverse of hat)
def vee(h):
    """
    Inverse of `hat`: extract the 3-vector from skew-symmetric matrices.

    For h = [[0, -z, y], [z, 0, -x], [-y, x, 0]] this returns (x, y, z).

    Args:
        h: batch of skew-symmetric tensors of shape `(..., 3, 3)`

    Returns:
        Batch of 3d vectors of shape `(..., 3)`

    Raises:
        ValueError: if `h` is not (numerically) skew-symmetric.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not torch.allclose(h, -h.transpose(-1, -2)):
        raise ValueError("Input h must be skew symmetric")
    x = h[..., 2, 1]
    y = h[..., 0, 2]
    z = h[..., 1, 0]
    return torch.stack((x, y, z), dim=-1)



In [18]:
# Sanity check: vee should invert hat on the so(3) generators, so vee(basis)
# is expected to be the 3x3 identity (see the output below).
vee(basis)

x: tensor([1., 0., 0.])


tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

In [13]:
# basis stacks the three so(3) generator matrices: shape (3, 3, 3).
basis.shape

torch.Size([3, 3, 3])

In [None]:
# exponential map from so(3) to SO(3), this is the matrix exponential

In [26]:
def exp(r):
    """
    Compute the SO(3) exponential map via the matrix exponential.

    Args:
        r: batch of skew-symmetric matrices in so(3) of shape `(..., 3, 3)`
           (e.g. the output of `hat`). Raw rotation vectors of shape `(N, 3)`
           are NOT accepted -- apply `hat` first (a `(3, 3)` stack of three
           vectors would otherwise be misread as a single matrix).

    Returns:
        batch of rotation matrices of shape `(..., 3, 3)`
    """
    return torch.linalg.matrix_exp(r)


In [33]:
# calc the rotation angle of SO(3)
def omega(r):
    """
    Rotation angle of rotation matrices.

    Uses trace(R) = 1 + 2 cos(angle), i.e. angle = arccos((trace(R) - 1) / 2).

    Args:
        r: batch of rotation matrices of shape `(..., 3, 3)`

    Returns:
        angles in [0, pi] of shape `(...)`
    """
    trace = torch.diagonal(r, dim1=-2, dim2=-1).sum(dim=-1)
    # Floating-point round-off can push the cosine slightly outside [-1, 1]
    # (e.g. for R close to the identity), where arccos returns NaN; clamp it.
    cos_angle = ((trace - 1) / 2).clamp(-1.0, 1.0)
    return torch.arccos(cos_angle)

# calc the relative rotation angle of r0, r1
def angle(r0, r1):
    """
    Calculate the relative rotation angle from r0 to r1, i.e. the rotation
    angle of r0^T r1.

    Args:
        r0: batch of rotation matrices of shape `(..., 3, 3)`
        r1: batch of rotation matrices, broadcastable against `r0`

    Returns:
        batch of angles, as computed by `omega`
    """
    # r0 is orthogonal, so r0^T is its inverse and r0^T r1 is the rotation
    # taking r0 to r1.
    relative = torch.matmul(r0.transpose(-2, -1), r1)
    # The original computed this value but never returned it (always None).
    return omega(relative)

# Define IGSO3 density, the geodesic random walk and check their agreement

In [None]:
# power series expansion in the IGSO3 density
def f_igso3(omega, t, L=500):
    """
    Truncated power-series expansion of the IGSO(3) density in the rotation angle.

    This is the density WITHOUT the (1 - cos(omega)) / pi factor, which must be
    multiplied in separately (required) to get the density of the rotation angle.
    With this reparameterization, IGSO(3) agrees with the Brownian motion on
    SO(3) with t=sigma^2 when defined for the canonical inner product on SO3,
    <u, v>_SO3 = Trace(u v^T)/2

    Computes
        sum_{l=0}^{L-1} (2l + 1) exp(-l (l + 1) t / 2)
                        * sin((l + 1/2) omega) / sin(omega / 2)

    Args:
        omega: rotation angle(s); a scalar or a tensor of any shape `(...)`
        t: sigma^2; a Python scalar or 0-d tensor (not broadcast per angle)
        L: number of series terms kept (truncation order)

    Returns:
        density of rotation angle, shape `(...)` matching `omega`
    """
    omega = torch.as_tensor(omega)
    ls = torch.arange(L, device=omega.device)  # [L]
    w = omega[..., None]  # [..., 1], broadcasts against the L series terms
    num = torch.sin(w * (ls + 1 / 2))
    den = torch.sin(w / 2)
    # At omega = 0 the ratio is 0/0 (NaN); its limit is 2l + 1. Substitute the
    # limit there and divide by a dummy 1 so no NaN/inf is produced.
    at_zero = den == 0
    ratio = torch.where(
        at_zero,
        (2 * ls + 1).to(num.dtype),
        num / torch.where(at_zero, torch.ones_like(den), den),
    )
    weights = (2 * ls + 1) * torch.exp(-ls * (ls + 1) * t / 2)  # [L]
    return (weights * ratio).sum(dim=-1)
    
 