In [3]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules before each cell runs.
%load_ext autoreload
%autoreload 2

In [5]:
import jax.numpy as jnp
import jax
from jax import jit, vmap
import numpy as np
import genjax
import trimesh
import matplotlib.pyplot as plt

from jax import grad, jacfwd, jacrev

#|export
from jax.scipy.spatial.transform import Rotation
from scipy.stats import truncnorm as scipy_truncnormal

# Short aliases for the JAX normal / truncated-normal (log-)density functions.
normal_logpdf    = jax.scipy.stats.norm.logpdf
normal_pdf       = jax.scipy.stats.norm.pdf
truncnorm_logpdf = jax.scipy.stats.truncnorm.logpdf
truncnorm_pdf    = jax.scipy.stats.truncnorm.pdf

# Short aliases for matrix inversion and log-space accumulation helpers.
inv       = jnp.linalg.inv
logaddexp = jnp.logaddexp
logsumexp = jax.scipy.special.logsumexp

# Notebook-level PRNG key with a fixed seed of 0.
key = jax.random.PRNGKey(0)

In [6]:
def generic_viewpoint(key, cam, n, sig_x, sig_hd):
    """Sample `n` perturbed copies of the camera pose `cam` with their log-densities.

    Each perturbation combines a Gaussian offset in the x- and y-coordinates
    (the z offset is held at zero) with a Gaussian rotation angle about the
    "y" Euler axis. Sample 0 is always the unperturbed pose.

    Args:
        key: PRNG key, forwarded to the `keysplit` helper.
        cam: Base camera pose; it is left-multiplied onto every perturbation
            (presumably a 4x4 pose matrix -- confirm against `pack_pose`).
        n: Number of poses to generate.
        sig_x: Standard deviation of the positional offsets.
        sig_hd: Standard deviation of the rotation angle.

    Returns:
        A pair `(ps, logps)`: the perturbed camera poses `cam @ p_i`, and an
        array of shape `(n,)` with the normal log-density of each sampled
        offset and angle.
    """
    # TODO: Make a version that varies rot and pitch and potentially roll.

    # `keysplit` is defined elsewhere; only its second return value is used,
    # indexed as [0] for the angles and [1] for the offsets.
    _, subkeys = keysplit(key, 1, 2)
    angle_key, offset_key = subkeys[0], subkeys[1]

    # Rotation angles: Gaussian draws, with sample 0 pinned to "no rotation".
    angles = jax.random.normal(angle_key, (n,)) * sig_hd
    angles = angles.at[0].set(0.0)
    rot_mats = vmap(lambda angle: Rotation.from_euler("y", angle))(angles).as_matrix()

    # Positional offsets: Gaussian draws, sample 0 pinned to the origin and
    # the z column zeroed for every sample.
    offsets = jax.random.normal(offset_key, (n, 3)) * sig_x
    offsets = offsets.at[0, :].set(0.0).at[:, 2].set(0.0)

    # Pack each (offset, rotation) pair into a pose and apply it to the camera.
    local_poses = vmap(pack_pose)(offsets, rot_mats)
    cam_poses = cam @ local_poses

    # Log-weights: angle log-density plus the offset log-density summed over
    # all three coordinates.
    # NOTE(review): the zeroed z column still contributes a constant term to
    # every weight -- confirm this offset is intended.
    log_weights = (
        normal_logpdf(angles, loc=0.0, scale=sig_hd)
        + normal_logpdf(offsets, loc=0.0, scale=sig_x).sum(-1)
    )

    return cam_poses, log_weights


In [7]:
def generic_contact(key, p0, n, sig_x, sig_hd):
    """Sample `n` perturbed copies of the pose `p0` with their log-densities.

    Each perturbation combines a Gaussian offset in the x- and y-coordinates
    (the z offset is held at zero) with a Gaussian rotation angle about the
    "z" Euler axis. Sample 0 is always the unperturbed pose.

    Args:
        key: PRNG key, forwarded to the `keysplit` helper.
        p0: Base pose; it is left-multiplied onto every perturbation
            (presumably a 4x4 pose matrix -- confirm against `pack_pose`).
        n: Number of poses to generate.
        sig_x: Standard deviation of the positional offsets.
        sig_hd: Standard deviation of the rotation angle.

    Returns:
        A pair `(generic_ps, logps)`: the perturbed poses `p0 @ p_i`, and an
        array of shape `(n,)` with the normal log-density of each sampled
        offset and angle.
    """
    # `keysplit` is defined elsewhere; only its second return value is used,
    # indexed as [0] for the angles and [1] for the offsets.
    _, subkeys = keysplit(key, 1, 2)
    angle_key, offset_key = subkeys[0], subkeys[1]

    # Rotation angles, drawn as an (n, 1) column; sample 0 is pinned to zero.
    angles = jax.random.normal(angle_key, (n, 1)) * sig_hd
    angles = angles.at[0, :].set(0.0)
    rot_mats = vmap(lambda angle: Rotation.from_euler("z", angle))(angles).as_matrix()

    # Positional offsets: z column zeroed for every sample, sample 0 at the origin.
    offsets = jax.random.normal(offset_key, (n, 3)) * sig_x
    offsets = offsets.at[:, 2].set(0.0).at[0, :].set(0.0)

    # Pack each (offset, rotation) pair into a pose and apply it to `p0`.
    contact_poses = vmap(pack_pose)(offsets, rot_mats)
    object_poses = p0 @ contact_poses

    # Log-weights: angle log-density plus the offset log-density summed over
    # all three coordinates.
    # NOTE(review): the zeroed z column still contributes a constant term to
    # every weight -- confirm this offset is intended.
    log_weights = (
        normal_logpdf(angles[:, 0], loc=0.0, scale=sig_hd)
        + normal_logpdf(offsets, loc=0.0, scale=sig_x).sum(-1)
    )

    return object_poses, log_weights

