In [None]:
import math
import torch
from diffusion_edf.dist import sample_igso3

from matplotlib import pyplot as plt

import numpy as np
# Seed torch's global RNG so the sampled statistics below are reproducible on
# Restart & Run All. std_calculator draws from torch.randn directly; that
# sample_igso3 also uses torch's global RNG is an assumption -- TODO confirm.
SEED = 0
torch.manual_seed(SEED)

np.set_printoptions(precision=4, floatmode="fixed")

In [None]:
# Reference: a very wide IG_SO(3) (eps=100), used below as the "converged" baseline.
# NOTE(review): sample_igso3 is assumed to return quaternions with the real part in
# column 0 (as the 2*acos(q[:, 0]) formula used elsewhere implies) -- TODO confirm.
sample_q_max = sample_igso3(eps=100., N=100000)
# Rotation angle w = 2*atan2(|q_vec|, q_0): equal to 2*acos(q_0) for unit quaternions,
# but never NaN when rounding pushes |q_0| past 1, and accurate for small angles
# where acos is ill-conditioned.
sample_w_max = 2 * torch.atan2(sample_q_max[:, 1:4].norm(dim=-1), sample_q_max[:, 0])
sample_w_degree_max = sample_w_max / torch.pi * 180.

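By $t=6$ (i.e. $\epsilon = t/2 = 3$), the rotation-angle distribution of $\mathcal{IG}_{SO(3)}$ has almost converged: its histogram nearly matches the $\epsilon=100$ reference.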

In [None]:
t = 6.
eps = t/2  # convention used throughout this notebook: eps = t / 2

sample = sample_igso3(eps=eps, N=100000)
# Stable rotation angle: 2*atan2(|q_vec|, q_0) instead of 2*acos(q_0), which turns into
# NaN if rounding pushes |q_0| past 1 (assumes real part in column 0 -- TODO confirm).
sample_w = 2 * torch.atan2(sample[:, 1:4].norm(dim=-1), sample[:, 0])
sample_w_degree = sample_w / torch.pi * 180.

# Overlay the t=6 angles on the eps=100 reference so convergence is visible.
fig, ax = plt.subplots()
ax.hist([sample_w_degree, sample_w_degree_max], bins=100,
        label=[f"t = {t:g}", "eps = 100 (reference)"])
ax.set(xlabel="rotation angle w [degree]", ylabel="count",
       title="IG_SO(3) rotation-angle histogram")
ax.legend()
plt.show()

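For small $t$, the standard deviation of the rotation angle $w$ scales as $\operatorname{std}(w) \approx \sqrt{t}\times 38.5^\circ \approx 0.672\sqrt{t}\ \textrm{rad}$ (checked empirically in the next cell).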

In [None]:
# Empirical check of the scaling rule: std(w) / sqrt(t) should be roughly constant for small t.
for t in [1e-6, 1e-5, 1e-4, 1e-3, 1e-2]:
    eps = t/2

    sample = sample_igso3(eps=eps, N=100000)
    # For these tiny angles q_0 is ~1, where 2*acos(q_0) loses most of its precision;
    # 2*atan2(|q_vec|, q_0) is the same angle computed stably (real part assumed in column 0).
    sample_w = 2 * torch.atan2(sample[:, 1:4].norm(dim=-1), sample[:, 0])
    sample_w_degree = sample_w / torch.pi * 180.

    # Report both units for every t, instead of relying on loop variables leaking out of the loop.
    ratio_rad = sample_w.std().item() / math.sqrt(t)
    ratio_deg = sample_w_degree.std().item() / math.sqrt(t)
    print(f"t = {t:.0e}:  std(w)/sqrt(t) = {ratio_rad:.4f} rad = {ratio_deg:.2f} degree")

# Playground: rotation-angle histogram at a single $t$

In [None]:
t = 1e-2
eps = t/2

sample = sample_igso3(eps=eps, N=100000)
# Stable rotation angle (equivalent to 2*acos(q_0) for unit quaternions, without its
# NaN / small-angle precision problems); real part assumed in column 0 -- TODO confirm.
sample_w = 2 * torch.atan2(sample[:, 1:4].norm(dim=-1), sample[:, 0])
sample_w_degree = sample_w / torch.pi * 180.
print(sample_w_degree.std())

fig, ax = plt.subplots()
ax.hist(sample_w_degree, bins=100)
ax.set(xlabel="rotation angle w [degree]", ylabel="count",
       title=f"IG_SO(3) rotation angle, t = {t:g}")
plt.show()

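# Noise-scale calculator

For each diffusion time $t$ in a schedule, `std_calculator` reports the std, median, mode and $p$-quantile ("max") of the linear noise magnitude (norm of a 3-D Gaussian) and of the $\mathcal{IG}_{SO(3)}$ rotation angle in degrees.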

In [None]:
def _rotation_angle_deg(quats):
    """Rotation angle, in degrees, of each quaternion row of ``quats``.

    Assumes rows are quaternions with the real part in column 0 (the convention implied
    by the 2*acos(q[:, 0]) formula) -- TODO confirm against sample_igso3.
    2*atan2(|q_vec|, q_0) equals 2*acos(q_0) for unit quaternions, but cannot return NaN
    when rounding pushes |q_0| past 1 and stays accurate for small angles, where acos is
    ill-conditioned.
    """
    angle = 2 * torch.atan2(quats[:, 1:4].norm(dim=-1), quats[:, 0])
    return angle / torch.pi * 180.


def _sorted_quantile(values, p):
    """Return ``sorted(values)[int(p * N)]`` with the index clamped to [0, N - 1].

    Same statistic as ``values.sort().values[int(p * N)]``, except that p == 1 yields the
    maximum instead of raising IndexError, and no full sort is needed.
    """
    n = values.numel()
    k = min(max(int(p * n), 0), n - 1)
    return values.kthvalue(k + 1).values  # kthvalue is 1-indexed


def _hist_mode(values, upper, n_edges=100):
    """Centre of the most populated of the ``n_edges - 1`` equal-width bins on [0, upper]."""
    upper = float(upper)
    if not upper > 0.:
        # Degenerate spread (e.g. lin_mult == 0 gives all-zero samples): nothing to bin.
        return torch.zeros((), dtype=values.dtype)
    edges = torch.linspace(0., upper, steps=n_edges, dtype=values.dtype)
    # bucketize (right=False): idx == i  <=>  edges[i-1] < v <= edges[i].
    # Fold v == 0 (idx 0) into the first bin, and drop v > upper (idx n_edges): that is an
    # unbounded overflow bucket, not a histogram bin, and indexing edges with it would fail.
    idx = torch.bucketize(values, edges).clamp(min=1)
    idx = idx[idx < n_edges]
    if idx.numel() == 0:
        return torch.full((), float("nan"), dtype=values.dtype)
    i = idx.mode().values
    return (edges[i - 1] + edges[i]) / 2


def _stack(rows):
    """Stack per-t 0-d statistics into a 1-D tensor; an empty schedule gives an empty tensor."""
    return torch.stack(rows) if rows else torch.empty(0)


def std_calculator(t_schedule, lin_mult, ang_mult = 1., p = 0.95, N = 100000):
    """Summarise linear and angular noise magnitudes over a schedule of diffusion times.

    For every t in ``t_schedule``:
      * linear: N draws of ``randn(3).norm() * sqrt(3 t) * lin_mult``, i.e. norms of
        isotropic 3-D Gaussian vectors with per-component scale ``sqrt(3 t) * lin_mult``;
      * angular: N draws from ``sample_igso3(eps=t/2 * ang_mult**2)``, converted to
        rotation angles in degrees.

    Args:
        t_schedule: iterable of diffusion times t.
        lin_mult: multiplier on the linear scale.
        ang_mult: multiplier on the angular scale (enters eps squared).
        p: quantile reported as the "max" statistic; p == 1 gives the sample maximum.
        N: number of samples per t.

    Returns:
        Tuple of eight 1-D tensors with one entry per t, in this order:
        lin_stds, ang_stds, lin_maxs, ang_maxs, lin_meds, ang_meds, lin_mods, ang_mods.
        Note that lin_stds is the scale parameter ``sqrt(3 t) * lin_mult``, whereas
        ang_stds is the empirical std of the sampled angles. Modes are the centre of the
        most populated of 99 bins spanning [0, 1.2 * quantile].
    """
    lin_stds, ang_stds = [], []
    lin_maxs, ang_maxs = [], []
    lin_meds, ang_meds = [], []
    lin_mods, ang_mods = [], []

    for t in t_schedule:
        eps = t/2 * (ang_mult**2)
        lin_std = torch.tensor(math.sqrt(t*3)) * lin_mult
        # Linear samples are drawn before calling sample_igso3 to keep the global-RNG
        # consumption order fixed.
        lin_samples = torch.randn(N, 3).norm(dim=-1) * lin_std
        ang_samples = _rotation_angle_deg(sample_igso3(eps=eps, N=N))

        lin_max = _sorted_quantile(lin_samples, p)
        ang_max = _sorted_quantile(ang_samples, p)

        lin_stds.append(lin_std)
        ang_stds.append(ang_samples.std())
        lin_maxs.append(lin_max)
        ang_maxs.append(ang_max)
        lin_meds.append(lin_samples.median())
        ang_meds.append(ang_samples.median())
        # The mode histogram extends 20% past the p-quantile.
        lin_mods.append(_hist_mode(lin_samples, lin_max * 1.2))
        ang_mods.append(_hist_mode(ang_samples, ang_max * 1.2))

    return tuple(_stack(rows) for rows in (
        lin_stds, ang_stds, lin_maxs, ang_maxs, lin_meds, ang_meds, lin_mods, ang_mods))

In [None]:
# Schedule t = 1e-4 ... 1 with linear multiplier 20 and angular multiplier 2.5.
(lin_stds, ang_stds, lin_maxs, ang_maxs,
 lin_meds, ang_meds, lin_mods, ang_mods) = std_calculator(
    [1e-4, 1e-3, 1e-2, 1e-1, 1.],
    lin_mult=20.,
    ang_mult=2.5,
    p=0.95,
)

# Angular statistics are in degrees.
print(
    f"angular std: {ang_stds.numpy()}",
    f"angular median: {ang_meds.numpy()}",
    f"angular mode: {ang_mods.numpy()}",
    f"angular max: {ang_maxs.numpy()}",
    "========================================",
    f"linear std: {lin_stds.numpy()}",
    f"linear median: {lin_meds.numpy()}",
    f"linear mode: {lin_mods.numpy()}",
    f"linear max: {lin_maxs.numpy()}",
    sep="\n",
)

In [None]:
# Denser schedule with linear multiplier 15 and angular multiplier 2.5.
(lin_stds, ang_stds, lin_maxs, ang_maxs,
 lin_meds, ang_meds, lin_mods, ang_mods) = std_calculator(
    [1e-4, 5e-4, 5e-3, 3e-2, 1e-1, 1.],
    lin_mult=15.,
    ang_mult=2.5,
    p=0.95,
)

# Angular statistics are in degrees.
print(
    f"angular std: {ang_stds.numpy()}",
    f"angular median: {ang_meds.numpy()}",
    f"angular mode: {ang_mods.numpy()}",
    f"angular max: {ang_maxs.numpy()}",
    "========================================",
    f"linear std: {lin_stds.numpy()}",
    f"linear median: {lin_meds.numpy()}",
    f"linear mode: {lin_mods.numpy()}",
    f"linear max: {lin_maxs.numpy()}",
    sep="\n",
)

In [None]:
# Three-step schedule with linear multiplier 15 and a larger angular multiplier of 5.
(lin_stds, ang_stds, lin_maxs, ang_maxs,
 lin_meds, ang_meds, lin_mods, ang_mods) = std_calculator(
    [0.003, 0.03, 0.1],
    lin_mult=15.,
    ang_mult=5.,
    p=0.95,
)

# Angular statistics are in degrees.
print(
    f"angular std: {ang_stds.numpy()}",
    f"angular median: {ang_meds.numpy()}",
    f"angular mode: {ang_mods.numpy()}",
    f"angular max: {ang_maxs.numpy()}",
    "========================================",
    f"linear std: {lin_stds.numpy()}",
    f"linear median: {lin_meds.numpy()}",
    f"linear mode: {lin_mods.numpy()}",
    f"linear max: {lin_maxs.numpy()}",
    sep="\n",
)