In [None]:
import bayes3d as b
import bayes3d.genjax
import genjax
import jax.numpy as jnp
import jax
import os
from tqdm import tqdm
console = genjax.pretty(show_locals=False)
from genjax._src.core.transforms.incremental import NoChange
from genjax._src.core.transforms.incremental import UnknownChange
from genjax._src.core.transforms.incremental import Diff

In [None]:
# Set up the bayes3d visualizer before any visualization cells run.
b.setup_visualizer()

In [None]:
# Camera intrinsics for a 100x100 render with the principal point (cx, cy)
# at the image centre.
intrinsics = b.Intrinsics(
    height=100,
    width=100,
    fx=500.0, fy=500.0,
    cx=50.0, cy=50.0,
    near=0.01, far=20.0
)

b.setup_renderer(intrinsics)

# Register the 21 YCB-V meshes obj_000001.ply ... obj_000021.ply with the
# renderer. NOTE(review): presumably mesh index 0 used by the model below is
# obj_000001 (registration order) and the 1/1000 scaling converts the model
# files' millimetres to metres -- confirm.
model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models")
for idx in range(1, 22):
    mesh_path = os.path.join(model_dir, f"obj_{idx:06d}.ply")
    b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)

In [None]:
# Initial PRNG key. The generic `b.genjax.model.importance` sampler that used
# to be jitted here was reassigned in the model cell below before ever being
# called, so it is no longer built.
key = jax.random.PRNGKey(5)

In [None]:
@genjax.gen
def single_object_model():
    """Single-object generative model: random pose -> render -> observed image.

    The pose is sampled at address "pose" from ``b.genjax.uniform_pose`` with
    box corners [-0.01, -0.01, 1.5] and [0.01, 0.01, 3.5] (presumably bounds on
    the translation -- confirm). Mesh 0 is rendered at that pose, the first
    three output channels are kept, and that rendering is scored under
    ``b.genjax.image_likelihood`` at address "image".

    Returns:
        The rendered image (first three channels of the renderer output).
    """
    lower_corner = jnp.array([-0.01, -0.01, 1.5])
    upper_corner = jnp.array([0.01, 0.01, 3.5])
    pose = b.genjax.uniform_pose(lower_corner, upper_corner) @ "pose"

    mesh_index = 0
    pose_batch = pose[None, ...]          # renderer takes a batch of poses
    mesh_ids = jnp.array([mesh_index])
    rendered = b.RENDERER.render(pose_batch, mesh_ids)[..., :3]

    # Observation site: inference cells constrain this "image" address.
    b.genjax.image_likelihood(rendered, 0.01, 0.01, 1.0) @ "image"
    return rendered

importance_jit = jax.jit(single_object_model.importance)
key = jax.random.PRNGKey(5)

In [None]:
# Sample a synthetic "ground truth" trace: the empty choice map leaves every
# address unconstrained. importance_jit also returns an updated key.
key, (_, gt_trace) = importance_jit(
    key,
    genjax.choice_map({}),
    (),
)
print(gt_trace.get_score())
# Display channel 2 of the sampled observation via b.get_depth_image.
b.get_depth_image(gt_trace["image"][..., 2])

In [None]:
# Batched importance sampling: map over a leading axis of PRNG keys while the
# constraints and the (empty) model arguments are shared by every particle.
importance_parallel = jax.jit(
    jax.vmap(
        single_object_model.importance,
        in_axes=(0, None, None),
    )
)

In [None]:
# Run 1000 importance-sampling particles conditioned on the ground-truth image.
# Split off a subkey first so `key` is not consumed here and then reused by the
# later cells that draw from it.
key, particle_key = jax.random.split(key)
keys = jax.random.split(particle_key, 1000)
keys, (weights, traces) = importance_parallel(keys, genjax.choice_map({"image": gt_trace["image"]}), ());

In [None]:
# Shape of the stacked per-particle renderings (leading axis presumably the
# particle count -- confirm).
jnp.shape(traces.get_retval())

In [None]:
# Resample 10 particles (with replacement) in proportion to their importance
# weights; jax.random.categorical treats `weights` as unnormalized log-probs.
# A fresh subkey avoids reusing the key already consumed by the particle split.
key, resample_key = jax.random.split(key)
sampled_indices = jax.random.categorical(resample_key, weights, shape=(10,))
print(sampled_indices)
print(weights[sampled_indices])
# Channel 2 of each resampled rendering, displayed via b.get_depth_image.
images = [b.get_depth_image(img[:, :, 2]) for img in traces.get_retval()[sampled_indices]]

In [None]:
# Tile the resampled depth images into one figure. The title is derived from
# the number of images so it stays correct if the resample count changes.
b.multi_panel(images, title=f"{len(images)} Posterior Samples", title_fontsize=20).convert("RGB")

In [None]:
# Spread parameters intended for the Gaussian-vMF pose proposal in the next
# cell (presumably a position variance and a rotation concentration -- confirm
# against b.distributions.gaussian_vmf).
variance = 0.001
concentration = 0.001

In [None]:
def importance_sampling_with_proposal(key, trace, variance, concentration, pose_mean=None):
    """Produce one importance-weighted particle using a Gaussian-vMF pose proposal.

    A pose is sampled with ``b.distributions.gaussian_vmf_jit`` around
    ``pose_mean`` and written into ``trace`` at the model's "pose" address via
    ``trace.update``. The trace's other choices, including the observed
    "image", are left as they are.

    Args:
        key: JAX PRNG key for this particle.
        trace: a ``single_object_model`` trace whose "pose" choice is replaced.
        variance: first spread parameter forwarded to the proposal.
        concentration: second spread parameter forwarded to the proposal.
        pose_mean: centre of the proposal. Defaults to ``trace["pose"]``, the
            trace's current pose.

    Returns:
        ``(new_trace, log_weight)``. ``log_weight`` is the updated trace's score
        minus the proposal log-density of the sampled pose.
    """
    if pose_mean is None:
        # The previous fixed centre (third coordinate 1.0) lay outside the
        # [1.5, 3.5] range the model's uniform_pose prior uses for that
        # coordinate, so nearly every proposal had zero prior support.
        pose_mean = trace["pose"]
    # Independent keys for the proposal draw and for the trace update.
    proposal_key, update_key = jax.random.split(key)
    pose = b.distributions.gaussian_vmf_jit(proposal_key, pose_mean, variance, concentration)
    proposal_logpdf = b.distributions.gaussian_vmf_logpdf_jit(pose, pose_mean, variance, concentration)
    # single_object_model samples its pose at address "pose" (not "root_pose_0").
    new_trace = trace.update(
        update_key,
        genjax.choice_map({"pose": pose}),
        b.genjax.make_unknown_change_argdiffs(trace),
    )[1][2]
    return new_trace, new_trace.get_score() - proposal_logpdf

# Vectorize over keys only; trace and proposal parameters are shared.
importance_sampling_with_proposal_vmap = jax.vmap(
    importance_sampling_with_proposal, in_axes=(0, None, None, None)
)

In [None]:
# Draw 100 proposal particles by updating the ground-truth trace. Use the
# proposal spread defined earlier instead of repeating the literals, and split
# off a fresh subkey rather than reusing `key`.
key, proposal_key = jax.random.split(key)
traces, weights = importance_sampling_with_proposal_vmap(
    jax.random.split(proposal_key, 100), gt_trace, variance, concentration
)

In [None]:
# Resample 10 proposal particles by importance weight and display them.
# A fresh subkey is used because `key` itself was already consumed above;
# reusing it would repeat the earlier categorical noise.
key, resample_key = jax.random.split(key)
sampled_indices = jax.random.categorical(resample_key, weights, shape=(10,))
print(sampled_indices)
print(weights[sampled_indices])
images = [b.get_depth_image(img[:, :, 2]) for img in b.genjax.get_rendered_image(traces)[sampled_indices]]
b.multi_panel(images, title=f"{len(images)} Posterior Samples", title_fontsize=20).convert("RGB")