In [71]:
import b3d.chisight.gen3d.image_kernel as image_kernel
import b3d.chisight.gen3d.inference as inference
import b3d.chisight.gen3d.inference_moves as inference_moves
import b3d.chisight.gen3d.transition_kernels as transition_kernels
import jax
import jax.numpy as jnp
import pytest
import matplotlib.pyplot as plt

In [2]:
# Camera clipping planes and image resolution handed to the per-vertex image kernel.
near, far = 0.001, 5.0
image_height, image_width = 480, 640
img_model = image_kernel.NoOcclusionPerVertexImageKernel(
    near, far, image_height, image_width
)

# Color-drift settings. `color_transiton_scale` is the scale of the Laplace
# component; `p_resample_color` is the second entry of the array passed next to
# the two component kernels below (presumably the mixture weights -- confirm
# against `MixtureDriftKernel`).
# The spelling `color_transiton_scale` is kept so existing references keep working.
color_transiton_scale = 0.05
p_resample_color = 0.005

# This parameter is needed for the inference hyperparameters.
# See the `InferenceHyperparams` docstring in `inference.py` for details.
effective_color_transition_scale = color_transiton_scale + p_resample_color / 2
inference_hyperparams = inference.InferenceHyperparams(
    n_poses=6000,
    pose_proposal_std=0.04,
    pose_proposal_conc=1000.0,
    effective_color_transition_scale=effective_color_transition_scale,
)

# One transition kernel per latent quantity, plus the image kernel built above.
hyperparams = dict(
    pose_kernel=transition_kernels.UniformPoseDriftKernel(max_shift=0.1),
    color_kernel=transition_kernels.MixtureDriftKernel(
        [
            transition_kernels.LaplaceNotTruncatedColorDriftKernel(
                scale=color_transiton_scale
            ),
            transition_kernels.UniformDriftKernel(
                max_shift=0.15, min_val=jnp.zeros(3), max_val=jnp.ones(3)
            ),
        ],
        jnp.array([1 - p_resample_color, p_resample_color]),
    ),
    visibility_prob_kernel=transition_kernels.DiscreteFlipKernel(
        resample_probability=0.1, support=jnp.array([0.01, 0.99])
    ),
    depth_nonreturn_prob_kernel=transition_kernels.DiscreteFlipKernel(
        resample_probability=0.1, support=jnp.array([0.01, 0.99])
    ),
    depth_scale_kernel=transition_kernels.DiscreteFlipKernel(
        resample_probability=0.1,
        support=jnp.array([0.0025, 0.01, 0.02, 0.1, 0.4, 1.0]),
    ),
    color_scale_kernel=transition_kernels.DiscreteFlipKernel(
        resample_probability=0.1, support=jnp.array([0.05, 0.1, 0.15, 0.3, 0.8])
    ),
    image_kernel=img_model,
)

In [66]:
# Kernels pulled out of `hyperparams`; `get_sample` reads these as globals.
obs_rgbd_kernel = hyperparams["image_kernel"].get_rgbd_vertex_kernel()
color_kernel = hyperparams["color_kernel"]
visibility_prob_kernel = hyperparams["visibility_prob_kernel"]
depth_nonreturn_prob_kernel = hyperparams["depth_nonreturn_prob_kernel"]

# Default scalar inputs for a single point.
color_scale, depth_scale = 0.01, 0.00
latent_depth = 1.0
previous_color = jnp.array([0.1, 0.2, 0.3])
# First entry of the depth-nonreturn kernel's discrete support.
previous_dnrp = depth_nonreturn_prob_kernel.support[0]

@jax.jit
def get_sample(
    key,
    observed_rgbd_for_point,
    previous_visibility_prob,
    previous_color,
    latent_depth,
    previous_dnrp,
    color_scale,
    depth_scale,
):
    """Propose one set of attributes for a single point.

    Thin jitted wrapper around `inference_moves._propose_a_points_attributes`.
    The kernels and `inference_hyperparams` are taken from the notebook's
    globals rather than passed in.

    Args:
        key: PRNG key for this draw.
        observed_rgbd_for_point: Observed RGBD value for the point.
        previous_visibility_prob: Visibility probability at the previous step.
        previous_color: RGB color at the previous step.
        latent_depth: Latent depth of the point.
        previous_dnrp: Depth-nonreturn probability at the previous step.
        color_scale: Color scale forwarded to the proposal.
        depth_scale: Depth scale forwarded to the proposal.

    Returns:
        The first three entries of the proposal's result, unpacked as
        `(rgb, visibility_prob, dnr_prob)`; anything after them (returned
        because `return_metadata=True`) is dropped.
    """
    # Argument order here follows the callee's positional signature, which
    # differs from this wrapper's own parameter order.
    proposal = inference_moves._propose_a_points_attributes(
        key,
        observed_rgbd_for_point,
        latent_depth,
        previous_color,
        previous_visibility_prob,
        previous_dnrp,
        depth_nonreturn_prob_kernel,
        visibility_prob_kernel,
        color_kernel,
        obs_rgbd_kernel,
        color_scale,
        depth_scale,
        inference_hyperparams,
        return_metadata=True,
    )
    rgb, visibility_prob, dnr_prob = proposal[:3]
    return rgb, visibility_prob, dnr_prob

# Vectorize over the PRNG keys only (axis 0 of the first argument); the other
# seven arguments are passed unchanged to every draw.
get_samples = jax.vmap(get_sample, in_axes=(0,) + (None,) * 7)

In [67]:
# Fixed seed, split into one PRNG key per draw.
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 1000)

# Observation and previous-step state for the point being examined.
observed_rgbd_for_point = jnp.array([0.1, 0.2, 0.3, 0.4])
previous_color = jnp.array([0.1, 0.2, 0.3])
# NOTE(review): 0.1 is not one of the support values ([0.01, 0.99]) declared for
# the visibility / depth-nonreturn kernels in `hyperparams` -- confirm this is intended.
previous_visibility_prob = 0.1
previous_dnrp = 0.1
color_scale = 0.1
depth_scale = 0.005

# `latent_depth` is not set in this cell; it comes from an earlier cell.
samples = get_samples(
    keys,
    observed_rgbd_for_point,
    previous_visibility_prob,
    previous_color,
    latent_depth,
    previous_dnrp,
    color_scale,
    depth_scale,
)

In [58]:
from matplotlib.gridspec import GridSpec

def plot_samples(samples, observed_rgbd_for_point):
    """Plot the proposed attribute samples on a new 10x10 figure.

    Args:
        samples: Triple `(rgb, visibility_prob, dnr_prob)`; `rgb` is indexed as
            `rgb[:, channel]`, so it needs a sample axis followed by a channel axis.
        observed_rgbd_for_point: Observation, shown in the figure title.

    The top row holds bar charts of the unique visibility and depth-nonreturn
    probability values and their counts. The second row holds one histogram
    per color channel, binned on 100 evenly spaced edges over [0, 1].
    """
    rgb, visibility_prob, dnr_prob = samples

    fig = plt.figure(layout="constrained", figsize=(10, 10))
    fig.suptitle(f"Observed RGBD: {observed_rgbd_for_point}", fontsize=16)
    grid = GridSpec(3, 3, figure=fig)

    # Discrete-valued attributes: one bar per distinct sampled value.
    discrete_panels = (
        (visibility_prob, "Visibility Probability Samples"),
        (dnr_prob, "Depth Nonreturn Probability Samples"),
    )
    for column, (draws, title) in enumerate(discrete_panels):
        panel = fig.add_subplot(grid[0, column])
        values, counts = jnp.unique(draws, return_counts=True)
        panel.bar(values, counts)
        panel.set_xticks(values)
        panel.set_title(title)

    # Continuous color channels: shared bin edges for the three histograms.
    bin_edges = jnp.linspace(0, 1, 100)
    channel_panels = (("r", "R Samples"), ("g", "G Samples"), ("b", "B Samples"))
    for channel, (bar_color, title) in enumerate(channel_panels):
        panel = fig.add_subplot(grid[1, channel])
        panel.hist(rgb[:, channel], bin_edges, color=bar_color)
        panel.set_title(title)



In [73]:
from ipywidgets import interact
import ipywidgets as widgets
def f(
    previous_visibility_prob,
    previous_dnrp,
    latent_depth,
    previous_r,
    previous_g,
    previous_b,
    color_scale,
    depth_scale,
):
    """Widget callback: resample with the chosen settings and plot the result.

    Each parameter is a widget value supplied by `interact`. The two
    probabilities are converted with `float(...)` before use, and the three
    color channels are packed into one RGB array. The PRNG keys and the
    observation come from the notebook globals `keys` and
    `observed_rgbd_for_point`.
    """
    samples = get_samples(
        keys,
        observed_rgbd_for_point,
        float(previous_visibility_prob),
        jnp.array([previous_r, previous_g, previous_b]),
        latent_depth,
        float(previous_dnrp),
        color_scale,
        depth_scale,
    )
    plot_samples(samples, observed_rgbd_for_point)

def _float_slider(description, value, min_value, max_value, step, readout_format=".2f"):
    """Build a horizontal FloatSlider that only reports a value on release.

    Args:
        description: Label shown next to the slider.
        value: Initial value; converted with `float(...)` so a 0-d JAX array
            is accepted as well as a Python float.
        min_value: Lower bound of the slider.
        max_value: Upper bound of the slider.
        step: Slider increment.
        readout_format: Format spec for the numeric readout. Pick enough
            decimals to resolve `step`, otherwise distinct positions display
            the same number.

    Returns:
        The configured `widgets.FloatSlider`.
    """
    return widgets.FloatSlider(
        value=float(value),
        min=min_value,
        max=max_value,
        step=step,
        description=description,
        disabled=False,
        continuous_update=False,  # recompute only when the handle is released
        orientation="horizontal",
        readout=True,
        readout_format=readout_format,
    )


def _support_toggle(description, support):
    """Build toggle buttons offering each value of a kernel's discrete support.

    Options are `(label, value)` pairs: the label is the value rounded to two
    decimals, while the value handed to the callback is the full-precision
    float, so nothing is lost by round-tripping through the label text.
    """
    return widgets.ToggleButtons(
        options=[(f"{x:.2f}", float(x)) for x in support],
        description=description,
        disabled=False,
        button_style="",  # 'success', 'info', 'warning', 'danger' or ''
    )


# Interactive explorer: every control maps to the keyword argument of `f` with
# the same name. The trailing `;` keeps the cell from echoing the function
# object that `interact` returns.
interact(
    f,
    previous_visibility_prob=_support_toggle(
        "Prev Vis Prob:", visibility_prob_kernel.support
    ),
    previous_dnrp=_support_toggle(
        "Prev DNR Prob:", depth_nonreturn_prob_kernel.support
    ),
    # NOTE(review): the range allows negative depths while the camera is set up
    # with near=0.001 / far=5.0 -- confirm that negative values are meaningful.
    latent_depth=_float_slider(
        "Latent Depth:", observed_rgbd_for_point[3], -1.0, 1.0, 0.01
    ),
    previous_r=_float_slider(
        "Previous R:", observed_rgbd_for_point[0], 0.0, 1.0, 0.01
    ),
    previous_g=_float_slider(
        "Previous G:", observed_rgbd_for_point[1], 0.0, 1.0, 0.01
    ),
    previous_b=_float_slider(
        "Previous B:", observed_rgbd_for_point[2], 0.0, 1.0, 0.01
    ),
    # Step is 0.001, so three decimals are needed for the readout to move.
    color_scale=_float_slider(
        "Color Scale:", color_scale, 0.01, 0.1, 0.001, readout_format=".3f"
    ),
    # Range is 0.001-0.01 in steps of 0.0005; '.2f' would only ever show 0.00/0.01.
    depth_scale=_float_slider(
        "Depth Scale:", depth_scale, 0.001, 0.01, 0.0005, readout_format=".4f"
    ),
);

interactive(children=(ToggleButtons(description='Prev Vis Prob:', options=('0.01', '0.99'), value='0.01'), Tog…

<function __main__.f(previous_visibility_prob, previous_dnrp, latent_depth, previous_r, previous_g, previous_b, color_scale, depth_scale)>

In [9]:
# Debug view: display the batch of PRNG keys created with `jax.random.split` earlier in the notebook.
keys

Array([[2615604937, 1821629751],
       [ 331609114, 2540088542],
       [  62694249, 3219724671],
       ...,
       [2744197942, 3392345896],
       [3019838301,  111166463],
       [3609926923, 3750703181]], dtype=uint32)