In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from condorgmm.condor.model.distributions.variance_priors import (
    get_my_inverse_gamma_widget,
)

# Interactive explorer for the inverse-gamma variance prior.
# `plot_sqrt_x=True` presumably plots against sqrt(variance), i.e. the
# standard deviation — TODO confirm against the widget implementation.
get_my_inverse_gamma_widget(plot_sqrt_x=True)

## Test the prior on RGB (mean, variance)

In [None]:
from genjax import gen
import genjax
import jax.numpy as jnp
from condorgmm.condor.model.distributions.variance_priors import my_inverse_gamma
import jax

In [None]:

In [None]:
import numpy as np

import matplotlib.pyplot as plt

# One panel per prior sample: the Normal(val, sqrt(var)) density that a channel
# with that sampled (mean, variance) assigns to values. Red lines mark the
# 8-bit range [0, 255]; the green line marks the sampled mean.
fig, axes = plt.subplots(6, 6, figsize=(15, 15))

# zip stops at the shorter sequence, so fewer samples than panels leaves the
# extra panels empty instead of raising IndexError.
for ax, var, val in zip(axes.flatten(), var_samples, val_samples):
    std = np.sqrt(var)
    # Evaluate the density over +/- 3 standard deviations around the mean.
    x = np.linspace(val - 3*std, val + 3*std, 100)
    y = jnp.exp(jax.vmap(lambda xi: genjax.normal.logpdf(xi, val, std))(x))
    ax.plot(x, y)
    ax.set_title(f'Var: {var:.2f}, Val: {val:.2f}')
    ax.axvline(x=0, color='r', linestyle='--')
    ax.axvline(x=255, color='r', linestyle='-')
    ax.axvline(x=val, color='g', linestyle='-')
    # The subplots do not share an x-axis, so the limit must be set on each
    # panel; setting it once after the loop only affected the last panel.
    ax.set_xlim([-10, 265])
    ax.grid(True)

plt.tight_layout()
plt.show()

In [None]:
from condorgmm.condor.model.distributions.variance_priors import (
    get_my_inverse_gamma_widget,
)

# NOTE(review): this cell is identical to the inverse-gamma widget cell at the
# top of the notebook; consider keeping only one of them.
get_my_inverse_gamma_widget(plot_sqrt_x=True)

## Test the prior on XYZ (mean, covariance)

In [None]:
from condorgmm.condor.model.distributions.variance_priors import my_inverse_wishart

In [None]:
import condorgmm
# Presumably starts a rerun recording named "conjugate_priors_00" via the
# project helper — confirm against condorgmm.rr_init.
# NOTE(review): nothing is logged before the sampling cell calls
# condorgmm.rr_init("conjugate_priors_01"), so this recording may stay empty.
condorgmm.rr_init("conjugate_priors_00")

In [None]:
# Hyperparameters of the prior on a 3D cluster's (mean, covariance):
# cov ~ inverse-Wishart, mean ~ MvNormal(center, cov / n_pseudo_obs).
xyz_cov_n_pseudo_obs = 1.
# NOTE(review): this was previously commented "mm -> m". Scaling a covariance
# by 1e-3 does not convert mm^2 to m^2 (that would be 1e-6). As written, this
# is a pseudo-sample covariance of 5e-3 per axis: std ~= 0.07 if the scene is
# in metres. TODO confirm the intended value.
xyz_pseudo_sample_cov = jnp.array([
    [5, 0, 0],
    [0, 5, 0],
    [0, 0, 5]
]) * 1e-3
xyz_mean_n_pseudo_obs = .01
xyz_mean_center = jnp.array([0, 0, 0], dtype=float)

def generate_mean_cov(key):
    """Sample one (mean, covariance) pair from the XYZ prior.

    cov ~ my_inverse_wishart(xyz_cov_n_pseudo_obs, xyz_pseudo_sample_cov)
    xyz ~ MvNormal(xyz_mean_center, cov / xyz_mean_n_pseudo_obs)

    Args:
        key: JAX PRNG key; split into one sub-key per draw.

    Returns:
        (xyz, cov): the sampled mean and covariance.
    """
    cov_key, mean_key = jax.random.split(key)
    cov = my_inverse_wishart.sample(cov_key, xyz_cov_n_pseudo_obs, xyz_pseudo_sample_cov)
    # The mean's covariance is the sampled covariance inflated by 1/n_pseudo_obs.
    xyz = genjax.mv_normal.sample(mean_key, xyz_mean_center, cov / xyz_mean_n_pseudo_obs)
    return xyz, cov

In [None]:
from jax.scipy.spatial.transform import Rotation as Rot

def cov_to_isotropic_and_quaternion(cov): # (3, 3)
    """Decompose a 3x3 covariance into principal variances and an orientation.

    Despite the name, the returned variances are per-axis (anisotropic): they
    are the eigenvalues of ``cov``, clipped at 0. The quaternion is scalar-last
    (as returned by ``Rotation.as_quat``). Its rotation matrix ``R`` maps the
    coordinate axes onto the corresponding eigenvectors, so
    ``R @ diag(vars) @ R.T`` reconstructs ``cov``.

    Args:
        cov: symmetric (3, 3) covariance matrix.

    Returns:
        (vars, quat): arrays of shape (3,) and (4,).
    """
    eigvals, eigvecs = jnp.linalg.eigh(cov)

    # Clip negative eigenvalues (e.g. from round-off) to zero.
    variances = jnp.maximum(eigvals, 0)

    # eigh only guarantees an orthonormal eigenbasis, which may be a reflection
    # (det = -1). Rotation.from_matrix cannot represent a reflection and would
    # return an orientation whose axes no longer match the eigenvectors.
    # Negating one eigenvector keeps it an eigenvector and makes det = +1.
    flip = jnp.where(jnp.linalg.det(eigvecs) < 0, -1.0, 1.0)
    rotation = eigvecs.at[:, -1].multiply(flip)

    # Convert rotation matrix to quaternion
    quat = Rot.from_matrix(rotation).as_quat()

    return variances, quat

def isotropic_and_quaternion_to_cov(vars, quat): # (3,); (4,)
    """Rebuild a 3x3 covariance from principal variances and an orientation.

    Args:
        vars: per-axis variances, shape (3,).
        quat: scalar-last orientation quaternion, shape (4,).

    Returns:
        ``R @ diag(vars) @ R.T``, where ``R`` is the quaternion's rotation matrix.
    """
    rotation_matrix = Rot.from_quat(quat).as_matrix()
    scaled_axes = rotation_matrix @ jnp.diag(vars)
    return scaled_axes @ rotation_matrix.T

In [None]:
import rerun as rr

condorgmm.rr_init("conjugate_priors_01")

# Draw 40 (mean, covariance) samples from the XYZ prior.
means, covs = jax.vmap(generate_mean_cov)(
    jax.random.split(jax.random.PRNGKey(0), 40)
)
# Principal-axis form of each covariance: per-axis variances + orientation.
vars, quats = jax.vmap(cov_to_isotropic_and_quaternion)(covs)

# One ellipsoid per sample; half-sizes are the per-axis standard deviations.
ellipses = rr.Ellipsoids3D(
    centers=means,
    half_sizes=jnp.sqrt(vars),
    quaternions=quats,
)
rr.log("samples", ellipses)

# Black unit sphere at the origin for scale. [0, 0, 0, 1] is the identity in
# scalar-last (xyzw) order, matching Rot.as_quat used for the samples.
rr.log(
    "reference",
    rr.Ellipsoids3D(
        centers=jnp.array([[0, 0, 0]]),
        half_sizes=jnp.array([[1, 1, 1]]),
        quaternions=jnp.array([[0, 0, 0, 1]]),
        colors=jnp.array([[0, 0, 0]]),
    ),
)