In [None]:
import jax.numpy as jnp
import bayes3d as b
import os
import jax
import functools
from jax.scipy.special import logsumexp
from functools import partial
from tqdm import tqdm
import matplotlib.pyplot as plt
import bayes3d.genjax
import genjax
import pathlib
import numpy as np

In [None]:
# Camera intrinsics for a 100x100 image; (cx, cy) sits at the image center.
intrinsics = b.Intrinsics(
    height=100, width=100,
    fx=200.0, fy=200.0,
    cx=50.0, cy=50.0,
    near=1e-4, far=5.0,
)

In [None]:
b.setup_renderer(intrinsics)
model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models")
# BOP model files are named obj_XXXXXX.ply, with the object id zero-padded to six digits.
mesh_path = os.path.join(model_dir, f"obj_{14:06d}.ply")
# Scale the mesh by 1/1000 (presumably millimetres -> metres; TODO confirm).
b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=1.0 / 1000.0)

In [None]:
# Inverse of the pose built from (position, target, up) vectors.
table_pose = b.t3d.inverse_pose(
    b.t3d.transform_from_pos_target_up(
        jnp.array([0.0, 0.5, 0.05]),
        jnp.zeros(3),
        jnp.array([0.0, 0.0, 1.0]),
    )
)


def _table_contact_to_pose(contact_params):
    """Map one contact-parameter vector to a pose composed with `table_pose`."""
    return table_pose @ b.scene_graph.relative_pose_from_edge(
        contact_params, 3, b.RENDERER.model_box_dims[0]
    )


# Batched over the leading axis of the contact parameters, then JIT-compiled.
contact_params_to_pose = jax.jit(jax.vmap(_table_contact_to_pose))
pose = contact_params_to_pose(jnp.zeros((1, 3)))
img = b.RENDERER.render_many(pose[:, None, ...], jnp.array([0]))[0]
b.get_depth_image(img[..., 2])

In [None]:
# `poses` is only created in the next cell, so on a fresh kernel it does not exist yet;
# inspect the single pose computed above instead.
pose.shape

In [None]:
# 100 contact-parameter vectors, each entry uniform in [0, 1) scaled by 0.1.
cp = 0.1 * jax.random.uniform(jax.random.PRNGKey(10), (100, 3))
poses = contact_params_to_pose(cp)
imgs = b.RENDERER.render_many(poses[:, None, ...], jnp.array([0]))

In [None]:

# `plt.imshow()` needs data; show the depth channel (index 2, as used for `img`)
# of the first image from the batch rendered in the previous cell.
plt.imshow(imgs[0][..., 2]);

In [None]:
depth = img[:, :, 2]
# Pixels at the maximum depth (presumably background) become inf so they can be masked when plotted.
depth_ = jnp.where(depth >= depth.max(), jnp.inf, depth)

In [None]:
# NOTE(review): stray inspection cell — it only displays the repr of the `bayes3d`
# module (imported as `b`) and can be deleted.
b

In [None]:
def preprocess(img):
    """Return a NumPy copy of a depth image with max-depth pixels set to inf.

    Pixels at (or above) the image maximum — presumably background — are set to
    ``np.inf`` so plotting code can treat them as invalid.

    Args:
        img: depth image array (NumPy or JAX; any shape).

    Returns:
        np.ndarray: a new floating-point array; the input is never modified.
        Integer input is promoted to float64 because inf is not representable
        in integer dtypes. Empty input is returned unchanged (as an empty array).
    """
    depth_np = np.array(img)  # np.array copies, so the caller's array is untouched
    if not np.issubdtype(depth_np.dtype, np.floating):
        depth_np = depth_np.astype(np.float64)
    if depth_np.size == 0:
        # .max() raises on empty arrays; there is nothing to mask anyway.
        return depth_np
    depth_np[depth_np >= depth_np.max()] = np.inf
    return depth_np

In [None]:
fig, ax = plt.subplots()
# Hide both axes so only the image is shown.
for axis in (ax.xaxis, ax.yaxis):
    axis.set_visible(False)
ax.imshow(preprocess(img[:, :, 2]))

In [None]:
%%time
import copy
my_cmap = copy.copy(plt.cm.get_cmap('turbo'))
my_cmap.set_bad(alpha=0)
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.imshow(depth_,cmap=my_cmap)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)

fig.savefig("fig.png")
# ,bbox_inches='tight', pad_inches=0)
fig

In [None]:
# Re-display `fig` (when run top to bottom: the figure from the colormap/savefig cell).
fig

In [None]:
# Save through the figure object: pyplot's "current figure" here is typically a new,
# blank figure (the inline backend closes figures after each cell), so
# `plt.savefig` would overwrite fig.png with an empty image.
fig.savefig("fig.png")

In [None]:
# Smoke test of the plotting setup: a flat line at y = 0 on the current axes.
plt.plot(jnp.zeros((10,)))
plt.show()

In [None]:
%matplotlib notebook

import random
import numpy as np

import matplotlib
import matplotlib.pyplot as plt

import matplotlib.animation as animation


fps = 30
nSeconds = 5
snapshots = [ np.random.rand(5,5) for _ in range( nSeconds * fps ) ]

# First set up the figure, the axis, and the plot element we want to animate
fig = plt.figure( figsize=(8,8) )

a = snapshots[0]
im = plt.imshow(a, interpolation='none', aspect='auto', vmin=0, vmax=1)

def animate_func(i):
    if i % fps == 0:
        print( '.', end ='' )

    im.set_array(snapshots[i])
    return [im]

anim = animation.FuncAnimation(
                               fig, 
                               animate_func, 
                               frames = nSeconds * fps,
                               interval = 1000 / fps, # in ms
                               )

anim.save('test_anim.mp4', fps=fps)

print('Done!')

In [None]:
# Half-width of the square window: windows are (2*FILTER_SIZE + 1) pixels per side.
FILTER_SIZE = 3


@functools.partial(
    jnp.vectorize,
    signature='(m)->()',
    excluded=(0, 2),
)
def compute_score_vectorize(observed_xyz, latent_filter_xyz, variance):
    """Score one latent point (a length-m vector) against ``observed_xyz``.

    Vectorized over ``latent_filter_xyz`` only; ``observed_xyz`` (arg 0) and
    ``variance`` (arg 2) are excluded and passed through unchanged.

    FIXME: the scoring rule is not implemented. The previous body was the bare
    expression ``latent_filter_xyz``, so the function returned ``None`` instead
    of a score and callers failed far from the cause. Fail loudly here instead.

    Raises:
        NotImplementedError: always, until a scoring rule is written.
    """
    raise NotImplementedError(
        "compute_score_vectorize: scoring of latent_filter_xyz against "
        "observed_xyz is not implemented yet"
    )

    
@functools.partial(
    jnp.vectorize,
    signature='(m)->(j,j)',
    excluded=(1, 2, 3),
)
def convolutional_filter_vectorize(
    ij,
    observed_xyz: jnp.ndarray,
    rendered_xyz_padded: jnp.ndarray,
    variance,
):
    """Score the square window of ``rendered_xyz_padded`` anchored at ``ij``.

    Vectorized over ``ij`` (row index first, column index second); the other
    three arguments are excluded from vectorization and passed through as-is.
    The window is (2*FILTER_SIZE + 1) x (2*FILTER_SIZE + 1) pixels and covers the
    first 3 entries of the last axis. It starts at (row, col) of the padded
    image, so it is centred on (row, col) of the unpadded image when the pad
    width is FILTER_SIZE.

    NOTE(review): ``observed_xyz`` is forwarded whole rather than indexed at
    ``ij``; confirm this is intended.
    """
    row, col = ij[0], ij[1]
    window = 2 * FILTER_SIZE + 1
    patch = jax.lax.dynamic_slice(rendered_xyz_padded, (row, col, 0), (window, window, 3))
    return compute_score_vectorize(observed_xyz, patch, variance)

def likelihood(observed_xyz, rendered_xyz, variance):
    """Per-pixel score of ``observed_xyz`` against windows of ``rendered_xyz``.

    ``rendered_xyz`` is padded by FILTER_SIZE on its first two axes, using
    ``intrinsics.far`` as the pad value, so every pixel gets a full
    (2*FILTER_SIZE + 1)^2 window. Each window is scored by
    ``convolutional_filter_vectorize`` and reduced with ``min`` over both window axes.

    Args:
        observed_xyz: image indexed as [row, col, ...]; its first two dims set the output size.
        rendered_xyz: 3-axis image (the pad config has three entries), presumably (H, W, 3) — TODO confirm.
        variance: forwarded to the window scorer. (Previously the global
            ``focal_length`` was forwarded instead, so this argument was ignored.)

    Returns:
        Array of shape observed_xyz.shape[:2]: one score per pixel.
    """
    pad_config = (
        (FILTER_SIZE, FILTER_SIZE, 0),
        (FILTER_SIZE, FILTER_SIZE, 0),
        (0, 0, 0),
    )
    rendered_xyz_padded = jax.lax.pad(rendered_xyz, intrinsics.far, pad_config)
    jj, ii = jnp.meshgrid(jnp.arange(observed_xyz.shape[1]), jnp.arange(observed_xyz.shape[0]))
    indices = jnp.stack([ii, jj], axis=-1)  # (rows, cols, 2) pixel coordinates

    log_probabilities = convolutional_filter_vectorize(
        indices, observed_xyz,
        rendered_xyz_padded,
        variance,
    )
    # NOTE(review): min keeps the worst score in each window; confirm vs. max/logsumexp.
    return log_probabilities.min(-1).min(-1)

In [None]:
# Sanity check of the point-wise scorer: identical points, then distant points.
# `compute_score` has no definition in this notebook (likely a deleted cell); the
# vectorized scorer takes the same (observed, latent, variance) arguments.
center_1 = jnp.array([1.0, 1.0, 1.0])
for center_2 in (center_1, jnp.array([0.1, 0.1, 1.0])):
    print(compute_score_vectorize(center_1, center_2, 200.0))

In [None]:
# 100 random points, each coordinate in roughly [-0.05, 0.05).
cloud = 0.1 * (jax.random.uniform(jax.random.PRNGKey(10), shape=(100, 3)) - 0.5)


def render_img(pose):
    """Render the global ``cloud`` under ``pose`` and unproject its depth channel."""
    rendered = b.render_point_cloud(b.apply_transform(cloud, pose), intrinsics)
    depth = rendered[:, :, 2]
    return b.unproject_depth(depth, intrinsics)


# NOTE(review): pose2's 10-unit y offset at depth 3 likely puts the cloud outside the
# 100x100 view — confirm that an (almost) empty second image is intended.
pose1, pose2 = (
    b.transform_from_pos(jnp.array(position))
    for position in ([0.0, 0.0, 3.0], [0.0, 10.0, 3.0])
)
img1, img2 = render_img(pose1), render_img(pose2)
focal_length = intrinsics.fx
b.viz.scale_image(b.get_depth_image(img1[..., 2]), 5.0)

In [None]:
likelihoods1 = likelihood(img1, img1, 200.0)
likelihoods2 = likelihood(img1, img2, 200.0)
diff = likelihoods1 - likelihoods2
print(logsumexp(likelihoods1), logsumexp(likelihoods2))
print(jnp.abs(diff).sum())

# Side-by-side scores, then their difference, each on its own figure.
for panel in (jnp.hstack([likelihoods1, likelihoods2]), diff):
    plt.matshow(panel)
    plt.colorbar()

In [None]:
# Pixel where the self-match (img1 vs img1) beats the cross-match by the largest margin.
i, j = jnp.unravel_index(diff.argmax(), diff.shape)
print(img1[i, j], img2[i, j], likelihoods1[i, j], likelihoods2[i, j], sep="\n")

In [None]:
# `convolutional_filter_vectorize` slices a (2*FILTER_SIZE + 1)^2 window starting at
# (i, j) of a *padded* image (as `likelihood` prepares it). Passing the unpadded img1
# shifts the window, and jax.lax.dynamic_slice clamps it at the border.
img1_padded = jax.lax.pad(
    img1,
    intrinsics.far,
    ((FILTER_SIZE, FILTER_SIZE, 0), (FILTER_SIZE, FILTER_SIZE, 0), (0, 0, 0)),
)
print(convolutional_filter_vectorize(jnp.array([[i, j]]), img1, img1_padded, 200.0).sum())

In [None]:
# Display the self-match score at the pixel (i, j) chosen by diff.argmax().
likelihoods1[i,j]

In [None]:
# Self-match and cross-match scores at (i, j), one per line.
print(likelihoods1[i, j], likelihoods2[i, j], sep="\n")

In [None]:
# `convolutional_filter` has no definition in this notebook (likely a deleted cell).
# Use the vectorized filter instead: a single (i, j) pair yields one
# (2*FILTER_SIZE + 1)^2 window. It expects a padded rendered image, as in `likelihood`.
pad_config = ((FILTER_SIZE, FILTER_SIZE, 0), (FILTER_SIZE, FILTER_SIZE, 0), (0, 0, 0))
ij_center = jnp.array([i, j])
filter1 = convolutional_filter_vectorize(
    ij_center, img1, jax.lax.pad(img1, intrinsics.far, pad_config), 200.0
)
filter2 = convolutional_filter_vectorize(
    ij_center, img1, jax.lax.pad(img2, intrinsics.far, pad_config), 200.0
)

print(filter1.sum(), filter2.sum())
plt.matshow(filter1)

In [None]:
print(*(window.sum() for window in (filter1, filter2)))

In [None]:
# Centre of the (2*FILTER_SIZE + 1)-wide window; the literal 3 was only correct for FILTER_SIZE == 3.
filter1[FILTER_SIZE, FILTER_SIZE]

In [None]:
# Centre of the (2*FILTER_SIZE + 1)-wide window; the literal 3 was only correct for FILTER_SIZE == 3.
filter2[FILTER_SIZE, FILTER_SIZE]

In [None]:
scores = likelihood(img1, img1, 200.0)
# Drop a 10-pixel border before plotting.
interior = (slice(10, -10), slice(10, -10))
plt.matshow(scores[interior])
plt.colorbar()

In [None]:
# Total (log-sum-exp) score of the self-match, then the cross-match, one per line.
print(
    logsumexp(likelihood(img1, img1, 200.0)),
    logsumexp(likelihood(img1, img2, 200.0)),
    sep="\n",
)

In [None]:
likelihoods1, likelihoods2 = (likelihood(img1, other, 200.0) for other in (img1, img2))
diff = likelihoods1 - likelihoods2
print(jnp.abs(diff).sum())

In [None]:
side_by_side = jnp.hstack((likelihoods1, likelihoods2))
plt.matshow(side_by_side)
plt.colorbar()

In [None]:
# NOTE(review): overwrites `diff` (the likelihood-difference array) with the scalar 0;
# cells that use `diff` without recomputing it will fail if re-run after this one.
diff = 0