In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#|default_exp gaussian_particle_system_genjax

In [6]:
#|export
import bayes3d as b3d
import trimesh
import os
from bayes3d._mkl.utils import *
import matplotlib.pyplot as plt
import numpy as np
import jax
from jax import jit, vmap
import jax.numpy as jnp
import jaxlib
from jax.scipy.spatial.transform import Rotation as Rot
from functools import partial
import genjax
from bayes3d.camera import Intrinsics, K_from_intrinsics, camera_rays_from_intrinsics
from bayes3d.transforms_3d import transform_from_pos_target_up, add_homogenous_ones, unproject_depth
import tensorflow_probability as tfp
from tensorflow_probability.substrates.jax.math import lambertw
from typing import Any, NamedTuple

# Normal-distribution helpers (elementwise; loc/scale default to 0/1).
normal_cdf, normal_pdf, normal_logpdf = (
    jax.scipy.stats.norm.cdf,
    jax.scipy.stats.norm.pdf,
    jax.scipy.stats.norm.logpdf,
)

# Short names for common linear-algebra / array operations.
inv, concat = jnp.linalg.inv, jnp.concatenate

# Fixed root PRNG key so that re-running the notebook reproduces the same samples.
key = jax.random.PRNGKey(0)

In [7]:
#|export
Array = np.ndarray | jax.Array
Shape = int | tuple[int, ...]
Bool = Array
Float = Array
Int = Array

In [188]:
def logit(x):
    """Return the log-odds of probability ``x``, i.e. ``log(x / (1 - x))``, elementwise."""
    odds = x / (1 - x)
    return jnp.log(odds)

    
def pack_homogenous_matrix(x: "Position",q: "Quaternion") -> "HomogenousMatrix":
    """Assemble a 4x4 homogeneous transform ``[[R, x], [0, 0, 0, 1]]``.

    Args:
        x: Translation. It is reshaped to a column vector, so presumably it has 3 entries.
        q: Quaternion passed to ``jax.scipy.spatial.transform.Rotation.from_quat``
           (scalar-last ``(x, y, z, w)`` layout).

    Returns:
        The 4x4 matrix whose top-left 3x3 block is the rotation matrix of ``q``
        and whose last column holds ``x``.
    """
    rotation = Rot.from_quat(q).as_matrix()
    translation = x.reshape(-1, 1)
    bottom_row = jnp.array([[0., 0., 0., 1.]])
    top_rows = concat([rotation, translation], axis=-1)
    return concat([top_rows, bottom_row])

In [8]:
def multiply_quaternions(q1, q2):
    """Hamilton product ``q1 * q2`` of two scalar-last ``(x, y, z, w)`` quaternions.

    Components are read by index along the leading axis. The product is returned
    in the same ``(x, y, z, w)`` layout.
    """
    # Unpack each operand into its vector part (x, y, z) and scalar part (w).
    ax, ay, az, aw = q1[0], q1[1], q1[2], q1[3]
    bx, by, bz, bw = q2[0], q2[1], q2[2], q2[3]

    scalar = aw * bw - ax * bx - ay * by - az * bz
    vec_x = aw * bx + ax * bw + ay * bz - az * by
    vec_y = aw * by - ax * bz + ay * bw + az * bx
    vec_z = aw * bz + ax * by - ay * bx + az * bw

    return jnp.array([vec_x, vec_y, vec_z, scalar])

In [32]:
def make_constant_model(x, address="_ignored"):
    """Wrap a fixed value ``x`` as a GenJAX static generative function.

    The returned model ignores its arguments and always returns ``x``. It makes
    exactly one random choice, a Bernoulli with logit ``+inf`` traced at
    ``address``. Presumably this is so the model has a traced choice at all;
    TODO confirm the intent.

    Args:
        x: Value returned by the model.
        address: Trace address of the dummy Bernoulli choice.

    Returns:
        A ``genjax.Static`` generative function.
    """

    @genjax.Static
    def constant_model(*args):
        # Note that genjax's bernoulli takes logits!
        # A logit of +inf corresponds to probability 1, so the choice is always True.
        _ = genjax.bernoulli(jnp.inf) @ address
        return x
        
    return constant_model

## Gaussian particle system (GPS) prior

In [33]:
class Pose(NamedTuple):
    """Rigid pose: a position plus an orientation quaternion.

    Fields may carry leading batch axes, e.g. when the pose is produced by a
    ``genjax.Map`` over ``pose_prior``.
    """
    position:   Array  # translation; presumably shape (..., 3) — TODO confirm
    quaternion: Array  # unit quaternion, presumably scalar-last (x, y, z, w) — TODO confirm


class Cam(NamedTuple):
    """Camera state: its pose and its intrinsics."""
    pose:       Pose
    intrinsics: Array  # format not fixed here; make_hgps_model defaults it to jnp.array([0])


class ParticleSystem(NamedTuple):
    """Gaussian particle system with a fixed number of (possibly masked-out) slots."""
    poses:       Pose   # per-particle poses, stored as one batched Pose
    covariances: Array  # per-particle covariance matrices (emb @ emb.T in gps_prior)
    colors:      Array  # per-particle colours, sampled uniformly in [0, 1]
    weights:     Array  # per-particle weights (sampled at "transparencies" in gps_prior)
    mask:        Array  # which particle slots are active


@genjax.Static
def gaussian_embedding_prior(size_bounds):
    """Samples a 3x3 embedding matrix with independent uniform entries.

    Args:
        size_bounds: ``(low, high)`` bounds shared by all nine entries, with shape
            ``(2,)``. In ``gps_prior`` this is one row of the tiled ``embedding_bounds``.

    Returns:
        A ``(3, 3)`` matrix ``emb`` sampled at address ``"embedding_matrix"``.
        ``gps_prior`` turns it into a covariance via ``emb @ emb.T``.
    """
    # Broadcast each scalar bound to its own (3, 3) block, giving shape (2, 3, 3).
    # Tiling the raw (2,) vector with reps (1, 3, 3) would promote it to (1, 1, 2)
    # and give shape (1, 3, 6), which unpacks into a single argument.
    bounds = jnp.tile(size_bounds[:, None, None], (1, 3, 3))
    emb = genjax.uniform(bounds[0], bounds[1]) @ "embedding_matrix"
    return emb


@genjax.Static
def pose_prior(position_bounds):
    """Samples a random pose whose position lies inside an axis-aligned box.

    Args:
        position_bounds: ``(low, high)`` pair. ``position_bounds[0]`` and
            ``position_bounds[1]`` bound the position coordinates.

    Returns:
        ``Pose`` with a uniform position (address ``"x"``) and a unit quaternion
        obtained by normalising a standard-normal 4-vector (address ``"q"``).
    """
    lower, upper = position_bounds[0], position_bounds[1]
    position = genjax.uniform(lower, upper) @ "x"

    # A normalised isotropic Gaussian 4-vector is uniformly distributed on the unit sphere.
    raw_q = genjax.normal(jnp.zeros(4), jnp.ones(4)) @ "q"
    unit_q = raw_q / jnp.linalg.norm(raw_q)

    return Pose(position, unit_q)


def make_gps_prior(max_particles: int, p_active: float = 0.5):
    """Builds a naive prior over Gaussian particle systems with a fixed number of slots.

    Args:
        max_particles: Number of particle slots. Each slot is switched on or off by
            a Bernoulli draw at address ``"mask"``.
        p_active: Probability that a slot is active (default 0.5).

    Returns:
        A ``genjax.Static`` generative function
        ``gps_prior(particle_bounds, embedding_bounds) -> ParticleSystem``.
    """
    # genjax's bernoulli is parameterised by logits, not probabilities, so convert
    # p_active to log-odds. For the default 0.5 this is exactly log(1) == 0.
    mask_logit = jnp.log(p_active / (1.0 - p_active))

    @genjax.Static
    def gps_prior(particle_bounds, embedding_bounds):
        """Naive prior over Gaussian particle systems.

        Args:
            particle_bounds: ``(low, high)`` position bounds, tiled to every slot and
                passed to ``pose_prior``.
            embedding_bounds: ``(low, high)`` bounds for the embedding-matrix entries,
                tiled to every slot and passed to ``gaussian_embedding_prior``.

        Returns:
            ``ParticleSystem`` holding per-slot poses, covariances ``emb @ emb.T``,
            colours, weights (sampled at ``"transparencies"``) and the slot mask.
            Entries of inactive slots come from ``masking_combinator`` and should be
            ignored. TODO: confirm what genjax fills them with.
        """
        particle_mask = genjax.Map(genjax.bernoulli, in_axes=(0,))(
                    jnp.full((max_particles,), mask_logit)) @ "mask"

        # Each per-particle component below is gated on particle_mask via masking_combinator.
        poses = genjax.Map(genjax.masking_combinator(pose_prior), in_axes=(0,(0,)))(
                    particle_mask,
                    (jnp.tile(particle_bounds, (max_particles,1,1)),)) @ "poses"

        embs = genjax.Map(genjax.masking_combinator(gaussian_embedding_prior), in_axes=(0,(0,)))(
                    particle_mask,
                    (jnp.tile(embedding_bounds, (max_particles,1)),)) @ "embedding_matrices"
        # The Gram matrix emb @ emb.T is symmetric positive semi-definite, so it is a valid covariance.
        covs = vmap(lambda emb: emb@emb.T)(embs.value)

        cols = genjax.Map(genjax.masking_combinator(genjax.uniform), in_axes=(0,(0,0)))(
                    particle_mask,
                    (jnp.zeros((max_particles, 3)), jnp.ones((max_particles, 3)),)) @ "colors"

        alphas = genjax.Map(genjax.masking_combinator(genjax.uniform), in_axes=(0,(0,0)))(
                    particle_mask,
                    (jnp.zeros(max_particles), jnp.ones(max_particles),)) @ "transparencies"

        return ParticleSystem(poses.value, covs, cols.value, alphas.value, particle_mask)

    return gps_prior


In [34]:
# Draw one Gaussian particle system with 10 slots.
gps_prior = make_gps_prior(10)
tr = gps_prior.simulate(
    key,
    (
        jnp.array([jnp.zeros(3), jnp.ones(3)]),  # particle_bounds: positions in the unit cube
        jnp.array([0., 1.]),                     # embedding_bounds: entries in [0, 1]
    ),
)
gps = tr.get_retval()

## Hierarchical Gaussian particle system (GPS) model

In [35]:
class Clustering(NamedTuple):
    """Assignment of particles to clusters, each cluster carrying its own pose."""
    poses:       "Pose"  # per-cluster poses, stored as one batched Pose
    assignments: Array   # per-particle cluster index


@genjax.Static
def motion_model(p: Pose, std_position, std_quaternion):
    """Hacked motion model for elements in SE(3).

    Adds independent Gaussian noise to a pose, then rescales the quaternion to unit norm.

    Args:
        p: Current pose.
        std_position: Std. dev. of the Gaussian step on the position (address ``"x"``).
        std_quaternion: Std. dev. of the Gaussian step on the raw quaternion
            (address ``"q"``).

    Returns:
        The perturbed ``Pose``.
    """
    # Jittering quaternion components and renormalising is not a proper SO(3)
    # random walk, hence "hacked".
    next_position = genjax.normal(p.position, std_position) @ "x"
    noisy_q = genjax.normal(p.quaternion, std_quaternion) @ "q"
    return Pose(next_position, noisy_q / jnp.linalg.norm(noisy_q))


def make_hgps_model(max_clusters, max_particles, max_time_steps=10, camera_intrinsics=jnp.array([0])):
    """Build a hierarchical Gaussian-particle-system model unrolled over time.

    The returned model does the following:
    - samples an initial particle system (``gps_prior``);
    - samples per-particle cluster assignments, per-cluster coordinate frames and
      an initial camera pose;
    - runs ``kernel`` through ``genjax.Unfold``.

    Args:
        max_clusters: Number of cluster slots (coordinate frames).
        max_particles: Number of particle slots, passed to ``make_gps_prior``.
        max_time_steps: Maximum chain length given to ``genjax.Unfold``.
        camera_intrinsics: Stored on the ``Cam`` and not otherwise used here. The
            default ``jnp.array([0])`` looks like a placeholder.

    Returns:
        ``hgps_model(T, particle_bounds, embedding_bounds)``, a ``genjax.Static``
        generative function that returns the unfolded chain's output.
    """
    gps_prior = make_gps_prior(max_particles)


    @genjax.Static
    def kernel(state):
        """One transition step that jitters the camera, cluster and (active) particle poses.

        ``state`` is ``(t, gps, clustering, cam)``. Returns the same tuple with
        ``t + 1`` and the updated poses. Motion noise is hard-coded to unit std.
        """
        t, gps, clustering, cam = state

        new_cam_pose       = motion_model(cam.pose, jnp.ones(3), jnp.ones(4)) @ "camera_pose"
        # TODO: should empty clusters be masked out?
        new_cluster_poses  = genjax.Map(motion_model, in_axes=(0,None,None))(
                                    clustering.poses, jnp.ones(3), jnp.ones(4)) @ "cluster_poses"

        # NOTE(review): the address says "relative", but gps.poses is perturbed and
        # replaced directly, without composing with the cluster frames. Presumably
        # particle poses are meant to be relative to their cluster; TODO confirm.
        new_particle_poses = genjax.Map(genjax.masking_combinator(motion_model), in_axes=(0,(0,None,None)))(
                                    gps.mask,
                                    (gps.poses, jnp.ones(3), jnp.ones(4))) @ "relative_particle_poses"
        new_particle_poses = new_particle_poses.value
        # TODO: Put an "energy" constraint on the particle system: moving the relative
        #       pose of particles should cost more than moving the cluster. One option is
        #       to compute the relative pose updates and sample a value from, e.g., a
        #       normal distribution. Constraining that value to zero would push the
        #       updates towards zero. Should we put a prior over its contribution?

        cam        = cam._replace(pose=new_cam_pose)
        gps        = gps._replace(poses=new_particle_poses)
        clustering = clustering._replace(poses=new_cluster_poses)

        # obs = observation_model(cam, gps, clustering) @ "observation"
        # Placeholder until an observation model exists; `obs` is currently unused.
        obs = jnp.zeros(1)

        return (t+1, gps, clustering, cam)

    
    # Unrolls `kernel` over time; presumably T (passed below) may not exceed max_time_steps.
    unfolded_kernel = genjax.Unfold(kernel, max_time_steps)


    @genjax.Static
    def hgps_model(T, particle_bounds, embedding_bounds):
        """Sample an initial state, then run the unfolded kernel with ``T`` as its step count."""

        gps = gps_prior(particle_bounds, embedding_bounds) @ "initial_particle_system"


        # Equal scores for every cluster. If genjax's categorical takes logits (as its
        # bernoulli does), this is a uniform assignment; TODO confirm.
        zs = genjax.Map(genjax.masking_combinator(genjax.categorical), in_axes=(0,(0,)))(
                        gps.mask,
                        (jnp.tile(jnp.ones(max_clusters), (max_particles, 1)),)) @ "initial_assignments"
        zs = zs.value

        # TODO: should empty clusters be masked out?
        # Cluster frames reuse the particle position bounds.
        qs = genjax.Map(pose_prior, in_axes=(0,))(
                        jnp.tile(particle_bounds, (max_clusters,1,1))) @ "initial_coordinate_frames"
        clustering = Clustering(qs, zs)
        
        cam_pose = pose_prior(particle_bounds) @ "initial_camera_pose"
        cam = Cam(cam_pose, camera_intrinsics)

        state0 = (0, gps, clustering, cam)
        states = unfolded_kernel(T, state0) @ "chain"

        return states
        

    return hgps_model

In [36]:
# Model size: cluster slots, particle slots, and the longest chain Unfold supports.
max_clusters, max_particles, max_T = 5, 7, 10
hgps_model = make_hgps_model(max_clusters, max_particles, max_T)

# Simulation arguments: a single time step, positions in [-1, 1]^3 and
# embedding-entry bounds [0, 1].
T = 1
particle_bounds = jnp.array([-jnp.ones(3), jnp.ones(3)])
embedding_bounds = jnp.array([0., 1.])

tr = hgps_model.simulate(key, (T, particle_bounds, embedding_bounds))

In [191]:
# Unpack the hierarchical model's return value (the output of its unfolded chain)
# into (t, gps, clustering, cam). Presumably each component is stacked over time
# steps by genjax.Unfold; TODO confirm.
t, gps, clustering, cam = tr.get_retval()