In [1]:
# Automatically reload edited modules (e.g. bayes3d) before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
#|default_exp gaussian_sensor_model

# Gaussian Renderer

In [2]:
#|export
import bayes3d as b3d
import trimesh
import os
from bayes3d._mkl.utils import *
import matplotlib.pyplot as plt
import numpy as np
import jax
from jax import jit, vmap
import jax.numpy as jnp
from functools import partial
from bayes3d.camera import Intrinsics, K_from_intrinsics, camera_rays_from_intrinsics
from bayes3d.transforms_3d import transform_from_pos_target_up, add_homogenous_ones, unproject_depth
import tensorflow_probability as tfp
from tensorflow_probability.substrates.jax.math import lambertw


In [5]:
#|export
from bayes3d._mkl.types import *
from bayes3d._mkl.gaussian_renderer import (
    weighted_arrival_intersection, 
    weighted_argmax_intersection,
    discrete_arrival_probabilities
)


In [6]:
#|export
# Short aliases for the standard-normal helpers from jax.scipy.stats.norm.
normal_cdf, normal_pdf, normal_logpdf = (
    jax.scipy.stats.norm.cdf,
    jax.scipy.stats.norm.pdf,
    jax.scipy.stats.norm.logpdf,
)

# Shorthand for the (batched) matrix inverse.
inv = jnp.linalg.inv

# Module-level PRNG key seeded with 0.
# NOTE(review): this cell is exported, so importing the module creates this key;
# callers should split it (or create their own) rather than reuse it repeatedly.
key = jax.random.PRNGKey(0)

For reference, here is the ray-casting routine from
> `bayes3d/_mkl/gaussian_renderer.py`
```python
    def _cast_ray(v, mus, precisions, colors, weights, zmax=2.0, bg_color=jnp.array([1.,1.,1.,1.])):
        # Render a single ray with direction `v`, cast from the origin, through a
        # mixture of Gaussians. Returns (mean_depth, mean_color, op), where `op`
        # is the probability that the ray hits something (opacity).
        #
        # Args (shapes inferred from the indexing below -- TODO confirm):
        #   v         -- ray direction; presumably shape (3,), since it is passed
        #                alongside `jnp.zeros(3)` and broadcast as `v[None,:]`.
        #   mus, precisions, weights
        #             -- per-Gaussian parameters, mapped over their leading axis.
        #   colors    -- per-Gaussian colors, reordered with the same depth order.
        #   zmax      -- depth assigned to the "no hit" (background) outcome.
        #   bg_color  -- color assigned to the background outcome (4 entries).
        #
        # TODO: Deal with negative intersections behind the camera
        #       (for now they only receive zero alpha via the `ts > 0` mask below)
        # TODO: Maybe switch to log probs?

        # Compute fuzzy intersections with the Gaussians: ray parameters `ts` and
        # their function values `sigmas`; then sort them front-to-back along the ray.
        ts, sigmas = vmap(weighted_argmax_intersection, (0,0,0,None,None))(
                            mus, precisions, weights, jnp.zeros(3), v)
        order  = jnp.argsort(ts)
        ts     = ts[order]
        sigmas = sigmas[order]
        # Intersection points `xs` along the ray (the ray origin is zero).
        xs     = ts[:,None]*v[None,:]

        # TODO: Ensure that alphas are in [0,1]
        # TODO: Should we reset the color opacity to `op`?
        # Alternatively we can set `alphas = (1 - jnp.exp(-sigmas*1.0))` -- cf. Fuzzy Metaballs paper
        # Intersections at or behind the origin (t <= 0) get zero alpha.
        alphas = sigmas * (ts > 0)
        # Presumably one arrival probability per intersection plus a final entry
        # for "ray passes through everything" (background) -- inferred from the
        # `[:-1]` / `[-1]` indexing below; TODO confirm.
        arrival_probs = discrete_arrival_probabilities(alphas)
        op = 1 - arrival_probs[-1] # Opacity
        # Expected depth: z-coordinate of each hit weighted by its arrival
        # probability, with the background outcome contributing `zmax`.
        # Expected color is formed the same way, with `bg_color` for background.
        mean_depth = jnp.sum(arrival_probs[:-1]*xs[:,2]) \
                        + arrival_probs[-1]*zmax
        mean_color = jnp.sum(arrival_probs[:-1,None]*colors[order], axis=0) \
                        + arrival_probs[-1]*bg_color    

        return mean_depth, mean_color, op


    # Batched, jitted version: maps over rays `v` (axis 0); all Gaussian
    # parameters and rendering options are shared across rays.
    # NOTE(review): `in_axes` has 7 entries, so callers likely must pass all seven
    # arguments (including zmax and bg_color) explicitly -- confirm.
    cast_rays = jit(vmap(_cast_ray, (0,None,None,None,None,None,None)))
```