In [2]:
import jax.numpy as jnp
import bayes3d as b
import os
import jax
import functools
from jax.scipy.special import logsumexp
from functools import partial
from tqdm import tqdm
import matplotlib.pyplot as plt
import bayes3d.genjax
import genjax
import pathlib


In [3]:
(image, latent) = jnp.load("../mug/hello.npz.npy")

In [5]:
import jax.numpy as jnp
import jax
import numpy as np
import functools
from functools import partial
from jax.scipy.special import logsumexp, erf

logerf  = lambda x: jnp.logaddexp(0.0, jnp.log(2) + jax.scipy.stats.norm.logcdf(x * jnp.sqrt(2)))

def single_patch_likelihood(p: jnp.ndarray,
                            bottom_left: jnp.ndarray,
                            top_right: jnp.ndarray,
                            sigma: float):
    x, y, z = bottom_left
    X, Y, _ = top_right
    C = (2*jnp.pi)**(-3/2)  * sigma**(-3)
    z_term = jnp.exp(- (p[2] - z)**2/sigma**2)
    x_term = jnp.sqrt(jnp.pi) * sigma / 2 * (erf((p[0] - 2*x) / sigma) + erf((- p[0] + x + X) / sigma))
    y_term = jnp.sqrt(jnp.pi) * sigma / 2 * (erf((p[1] - 2*y) / sigma) + erf((- p[1] + y + Y) / sigma))
    return C * x_term * y_term * z_term


@functools.partial(
    jnp.vectorize,
    signature='(m)->()',
    excluded=(1,2,3,4,5,6,7,),
)
def gausssian_mixture_vectorize(
    ij,
    observed_xyz: jnp.ndarray,
    rendered_xyz_padded: jnp.ndarray,
    variance,
    outlier_prob: float,
    outlier_volume: float,
    focal_length,
    filter_size: int,
):
    p = observed_xyz[ij[0], ij[1], :3]
    filter_latent = jax.lax.dynamic_slice(rendered_xyz_padded, (ij[0], ij[1], 0), (2*filter_size + 1, 2*filter_size + 1, 3))

    half_widths = (filter_latent[:, :, 2] / focal_length) / 2.0
    width_observed = (p[2] / focal_length)

    delta = jnp.stack([half_widths, half_widths, jnp.zeros_like(half_widths)],axis=-1)
    bottom_lefts = filter_latent - delta
    top_rights =  filter_latent + delta

    sigma = jnp.sqrt(variance)
    C = (2*jnp.pi)**(-3/2)  * sigma**(-3)

    z_term = - (p[2] - bottom_lefts[:,:,2])**2/sigma**2
    x_term = jnp.log(jnp.sqrt(jnp.pi) * sigma / 2) +  jnp.logaddexp(
                                                        logerf((p[0] - 2*bottom_lefts[:,:,0]) / sigma), 
                                                    logerf((- p[0] + bottom_lefts[:,:,0] + top_rights[:,:,0]) / sigma))
    y_term = jnp.log(jnp.sqrt(jnp.pi) * sigma / 2) + jnp.logaddexp(logerf((p[1] - 2*bottom_lefts[:,:,1]) / sigma),logerf((- p[1] + bottom_lefts[:,:,1] + top_rights[:,:,1]) / sigma))

    log_probabilities_per_latent = x_term + y_term + z_term - jnp.log(((2*half_widths)**2).sum())
    log_probability = logsumexp(log_probabilities_per_latent) + 2*jnp.log(width_observed)
    return log_probability

def threedp3_likelihood_per_pixel(
    observed_xyz: jnp.ndarray,
    rendered_xyz: jnp.ndarray,
    variance,
    outlier_prob,
    outlier_volume,
    focal_length,
    filter_size
):
    rendered_xyz_padded = jax.lax.pad(rendered_xyz,  -100.0, ((filter_size,filter_size,0,),(filter_size,filter_size,0,),(0,0,0,)))
    jj, ii = jnp.meshgrid(jnp.arange(observed_xyz.shape[1]), jnp.arange(observed_xyz.shape[0]))
    indices = jnp.stack([ii,jj],axis=-1)
    log_probabilities = gausssian_mixture_vectorize(
        indices, observed_xyz,
        rendered_xyz_padded,
        variance, outlier_prob, outlier_volume,
        focal_length,
        filter_size
    )
    return log_probabilities

def threedp3_likelihood(
    observed_xyz: jnp.ndarray,
    rendered_xyz: jnp.ndarray,
    variance,
    outlier_prob,
    outlier_volume,
    focal_length,
    filter_size
):
    log_probabilities_per_pixel = threedp3_likelihood_per_pixel(
        observed_xyz, rendered_xyz, variance,
        outlier_prob, outlier_volume,
        focal_length,
        filter_size
    )
    return log_probabilities_per_pixel.sum()

threedp3_likelihood_jit = jax.jit(threedp3_likelihood,static_argnames=('filter_size',))
threedp3_likelihood_per_pixel_jit = jax.jit(threedp3_likelihood_per_pixel, static_argnames=('filter_size',))


In [13]:
args = (0.001,
    0.0,
    1.0,
    200.0,
    3)
scores_exact = threedp3_likelihood_per_pixel_jit(
    image,
    image, 
    *args)
scores_hyp = threedp3_likelihood_per_pixel_jit(
    image,
    latent, 
    *args)
print(scores_exact.sum(), scores_hyp.sum())

-52037.71 -52058.17
