In [93]:
import b3d.chisight.gen3d.image_kernel as image_kernel
import b3d.chisight.gen3d.inference as inference
import b3d.chisight.gen3d.inference_moves as inference_moves
import b3d.chisight.gen3d.transition_kernels as transition_kernels
import jax
import b3d
import jax.numpy as jnp
import pytest
import matplotlib.pyplot as plt

In [90]:
# Arguments handed to the per-vertex image kernel constructor.
near, far = 0.001, 5.0
image_height, image_width = 480, 640
img_model = image_kernel.NoOcclusionPerVertexImageKernel(
    near, far, image_height, image_width
)

# Color drift settings; both values are also consumed when the color
# transition kernel is built for the model hyperparameters.
color_transiton_scale = 0.05
p_resample_color = 0.005

# The inference hyperparameters need a single "effective" color transition
# scale derived from the two values above. The `InferenceHyperparams`
# docstring in `inference.py` describes how it is used.
effective_color_transition_scale = color_transiton_scale + p_resample_color / 2
inference_hyperparams = inference.InferenceHyperparams(
    n_poses=6000,
    pose_proposal_std=0.04,
    pose_proposal_conc=1000.0,
    effective_color_transition_scale=effective_color_transition_scale,
)

# Color transition kernel: a mixture of a Laplace drift and a uniform drift.
# The second argument is presumably the per-component mixture weights
# (1 - p_resample_color for the Laplace drift, p_resample_color for the
# uniform one) -- TODO confirm against `MixtureDriftKernel`.
_color_drift_mixture = transition_kernels.MixtureDriftKernel(
    [
        transition_kernels.LaplaceNotTruncatedColorDriftKernel(
            scale=color_transiton_scale
        ),
        transition_kernels.UniformDriftKernel(
            max_shift=0.15, min_val=jnp.zeros(3), max_val=jnp.ones(3)
        ),
    ],
    jnp.array([1 - p_resample_color, p_resample_color]),
)

# Model hyperparameters: one transition kernel per latent quantity, plus the
# image kernel built earlier in the notebook. Every discrete kernel below uses
# the same resample probability (0.1) over its own support grid.
hyperparams = {
    "pose_kernel": transition_kernels.UniformPoseDriftKernel(max_shift=0.1),
    "color_kernel": _color_drift_mixture,
    "visibility_prob_kernel": transition_kernels.DiscreteFlipKernel(
        resample_probability=0.1,
        support=jnp.array([0.01, 0.99]),
    ),
    "depth_nonreturn_prob_kernel": transition_kernels.DiscreteFlipKernel(
        resample_probability=0.1,
        support=jnp.array([0.01, 0.99]),
    ),
    "depth_scale_kernel": transition_kernels.DiscreteFlipKernel(
        resample_probability=0.1,
        support=jnp.array([0.0025, 0.01, 0.02, 0.1, 0.4, 1.0]),
    ),
    "color_scale_kernel": transition_kernels.DiscreteFlipKernel(
        resample_probability=0.1,
        support=jnp.array([0.05, 0.1, 0.15, 0.3, 0.8]),
    ),
    "image_kernel": img_model,
}

In [98]:
# Reload the visualization module *before* binding the function name. If the
# `from ... import` ran first, the name would keep pointing at the function
# object from before the reload, so source edits would only show up on the
# second run of this cell.
# NOTE(review): assumes `b3d.reload` re-executes the module in place, like
# `importlib.reload` -- confirm.
import b3d.chisight.gen3d.visualization

b3d.reload(b3d.chisight.gen3d.visualization)
from b3d.chisight.gen3d.visualization import create_interactive_visualization

# Single observed value fed to the visualization; presumably laid out as
# (r, g, b, depth) -- TODO confirm against `create_interactive_visualization`.
observed_rgbd_for_point = jnp.array([0.1, 0.2, 0.3, 0.4])
create_interactive_visualization(
    observed_rgbd_for_point,
    hyperparams,
    inference_hyperparams,
)

interactive(children=(ToggleButtons(description='Prev Vis Prob:', options=('0.01', '0.99'), value='0.01'), Tog…