In [1]:
import bayes3d as b
import jax.numpy as jnp
import jax
import os
import matplotlib.pyplot as plt
import matplotlib

In [2]:
# Interactive check: displays the setup_renderer function object (its signature shows in the output);
# nothing is called here.
b.setup_renderer

<function bayes3d.renderer.setup_renderer(intrinsics, num_layers=1024)>

In [3]:

# 1000x1000 pinhole camera with the principal point at the image centre.
intrinsics = b.Intrinsics(
    height=1000,
    width=1000,
    fx=500.0, fy=500.0,
    cx=500.0, cy=500.0,
    near=0.01, far=10.0
)

b.setup_renderer(intrinsics)

# Load the YCB-Video models obj_000001.ply ... obj_000021.ply in order, so mesh i in
# b.RENDERER.meshes is the (i+1)-th file.
# NOTE(review): IDX in a later cell assumes meshes[i] lines up with b.ycb_loader.MODEL_NAMES[i] — confirm.
NUM_YCBV_OBJECTS = 21
model_dir = os.path.join(b.utils.get_assets_dir(), "bop/ycbv/models")
for idx in range(1, NUM_YCBV_OBJECTS + 1):
    mesh_path = os.path.join(model_dir, f"obj_{idx:06d}.ply")
    # Uniform scale of 10/1000 applied to the model files' coordinates.
    b.RENDERER.add_mesh_from_file(mesh_path, scaling_factor=10.0 / 1000.0)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (1024, 1024, 1024)


In [4]:
#b.setup_visualizer()

In [5]:
# Show which YCB model sits at index 11 (the same index used as IDX in the next cell).
b.ycb_loader.MODEL_NAMES[11]

'021_bleach_cleanser'

In [65]:
IDX = 11        # mesh index into b.RENDERER.meshes (MODEL_NAMES[11] is '021_bleach_cleanser')
frames = 40     # number of animation frames
devices = 1     # >1 splits the frames across devices with jax.pmap; frames must divide evenly
dots = 250      # number of dots alive in every frame
pc = jnp.array(b.RENDERER.meshes[IDX].vertices)
lifetime = 5    # each frame replaces dots // lifetime dots, i.e. keeps 1 - 1/lifetime of them
point_rad = 5   # disk radius in pixels passed to circles() when rendering each dot

if frames % devices != 0:
    raise ValueError(f"frames ({frames}) must be divisible by devices ({devices})")

num_replaced = dots // lifetime

# Initial dot set and one batch of replacement dots per frame, drawn from the mesh vertices.
pc_subsample_start = pc[jax.random.choice(jax.random.PRNGKey(10), jnp.arange(pc.shape[0]), shape=(dots,))]
pc_replacements = pc[jax.random.choice(jax.random.PRNGKey(0), jnp.arange(pc.shape[0]), shape=(frames, num_replaced))]

# Build the sequence on a flat (frames, dots, ...) array; each frame copies the previous one
# and swaps out num_replaced dots.
pc_subsamples = jnp.zeros((frames, *pc_subsample_start.shape))
pc_subsamples = pc_subsamples.at[0].set(pc_subsample_start)
for i in range(1, frames):
    # Without replacement: exactly num_replaced distinct dots change. Duplicate indices in a
    # scatter would replace fewer dots and leave the winning update implementation-defined.
    sampled_indices = jax.random.choice(
        jax.random.PRNGKey(i), jnp.arange(dots), shape=(num_replaced,), replace=False
    )
    pc_subsamples = pc_subsamples.at[i].set(pc_subsamples[i - 1].at[sampled_indices].set(pc_replacements[i]))

if devices > 1:
    # Row-major reshape, matching the poses reshape: frame f -> (f // per_device, f % per_device).
    pc_subsamples = pc_subsamples.reshape((devices, frames // devices, *pc_subsamples.shape[1:]))

In [66]:
#b.show_cloud("1", pc_subsamples[0])

In [67]:
# Fixed camera: inverse of the look-at transform from (0, 2, 0) toward the origin with +z as up.
# This is loop-invariant, so compute it once.
camera_pose = b.t3d.inverse_pose(b.t3d.transform_from_pos_target_up(
    jnp.array([0.0, 2.0, 0.0]),
    jnp.array([0.0, 0.0, 0.0]),
    jnp.array([0.0, 0.0, 1.0]),
))

# Spin the object about z through two full turns over `frames` frames. endpoint=False because the
# angle 4*pi gives the same pose as 0; including it would show one pose twice when the GIF loops.
angles = jnp.linspace(0.0, 4 * jnp.pi, frames, endpoint=False)
poses = jnp.array([
    camera_pose @ b.t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)
    for angle in angles
])

if devices > 1:
    poses = poses.reshape((devices, frames // devices, poses.shape[1], poses.shape[2]))


In [68]:
def circles(flips_xy, radius):
    """Dilate the nonzero pixels of a 2-D image into filled disks.

    Output pixel (r, c) is True iff some pixel (r', c') with flips_xy[r', c'] > 0 lies at
    Euclidean distance strictly less than `radius` from it. Every nonzero pixel is used; the
    previous nearest-keypoint version silently ignored centres beyond the first 5000.

    Args:
        flips_xy: 2-D array (H, W); entries > 0 mark disk centres.
        radius: disk radius in pixels. Must be a concrete Python number, because it fixes the
            stencil size at trace time.

    Returns:
        Boolean array of shape (H, W, 1).
    """
    centers = jnp.asarray(flips_xy) > 0
    height, width = centers.shape
    if radius <= 0:
        # No pixel is at distance < radius <= 0 from anything.
        return jnp.zeros((height, width, 1), dtype=bool)

    # Offsets with dy^2 + dx^2 < radius^2 have |dy|, |dx| <= int(radius).
    reach = int(radius)
    radius_sq = radius * radius
    padded = jnp.pad(centers, reach)  # pads with False
    covered = jnp.zeros((height, width), dtype=bool)
    # OR together shifted copies of the centre mask, one per in-disk offset. The disk is symmetric,
    # so shifting the centres by d is the same as looking up each output pixel's neighbour at d.
    for dy in range(-reach, reach + 1):
        for dx in range(-reach, reach + 1):
            if dy * dy + dx * dx < radius_sq:
                covered = covered | padded[reach + dy:reach + dy + height, reach + dx:reach + dx + width]
    return covered[..., None]

def render_point_light(pose, pc_to_render, key, match_tol=0.05, noise_prob=0.0005):
    """Render one point-light frame of mesh IDX, without and with background noise dots.

    Args:
        pose: pose applied to the dots and passed to the mesh renderer (one frame of `poses`).
        pc_to_render: dot positions before `pose` is applied (one frame of `pc_subsamples`).
        key: PRNG key for the background noise.
        match_tol: max |depth difference| between a rendered dot and the rendered mesh for the
            dot to be kept.
        noise_prob: per-pixel probability of a spurious dot outside the object mask.

    Returns:
        (final_no_noise, final_with_noise): boolean disk images from circles() at radius point_rad.
    """
    pc_in_camera_frame = b.t3d.apply_transform(pc_to_render, pose)
    point_depth = b.render_point_cloud(pc_in_camera_frame, intrinsics)[:, :, 2]
    mesh_depth = b.RENDERER.render_single_object(pose, jnp.int32(IDX))[:, :, 2]
    # Pixels covered by the mesh (presumably background renders at depth >= far — confirm).
    object_mask = mesh_depth < intrinsics.far

    # Keep dots whose depth agrees with the rendered mesh surface; presumably this drops dots
    # hidden behind the visible surface.
    matches = jnp.abs(point_depth - mesh_depth) < match_tol

    flips = jax.random.uniform(key, shape=matches.shape) < noise_prob

    final_no_noise = circles(object_mask * matches, point_rad)
    # Noise dots are only added outside the object mask.
    final_with_noise = circles(object_mask * matches + (1.0 - object_mask) * flips, point_rad)

    return final_no_noise, final_with_noise

# GPU devices visible to JAX (jax.devices raises if no GPU backend is available).
gpus = jax.devices('gpu')

# Batched renderer: maps render_point_light over the leading axis of every positional
# argument (poses, dot clouds, keys) and compiles the whole batch once.
render_point_light_parallel_jit = jax.jit(
    jax.vmap(render_point_light, in_axes=0)
)

In [69]:
# Render every frame. One fixed base key is split into an independent key per frame.
key = jax.random.PRNGKey(100)
if devices > 1:
    # poses and pc_subsamples are laid out as (devices, frames // devices, ...):
    # pmap over the device axis, vmap over the per-device frames.
    keys = jax.random.split(key, poses.shape[0] * poses.shape[1])
    keys = keys.reshape((poses.shape[0], poses.shape[1], *keys.shape[1:]))
    render_point_light_parallel = jax.vmap(render_point_light, in_axes=(0, 0, 0))
    pmapout = jax.pmap(render_point_light_parallel, in_axes=0)
    images_no_noise, images = pmapout(poses, pc_subsamples, keys)
    # Flatten (devices, frames // devices, ...) back to (frames, ...) so the visualization
    # cells below see the same layout as in the single-device branch.
    images_no_noise = images_no_noise.reshape((-1, *images_no_noise.shape[2:]))
    images = images.reshape((-1, *images.shape[2:]))
else:
    keys = jax.random.split(key, poses.shape[0])
    images_no_noise, images = render_point_light_parallel_jit(poses, pc_subsamples, keys)

In [70]:
# Noisy sequence: invert each boolean dot mask (the * 1.0 casts it to float), colour with the
# 'Greys' colormap, and write the frames out as an animated GIF.
viz = [
    b.get_depth_image(1.0 - frame_mask * 1.0, cmap=matplotlib.colormaps['Greys'])
    for frame_mask in images
]
b.make_gif_from_pil_images(viz, "out_noise.gif")

In [71]:
# Save the first frame of the current `viz` list as a still PNG. When cells run top to bottom,
# this is the noisy sequence built in the previous cell (`viz` is reassigned in later cells).
viz[0].save('out_frame.png')

In [72]:
# Clean sequence (no background noise), rendered and saved the same way as the noisy GIF.
viz = [
    b.get_depth_image(1.0 - clean_mask * 1.0, cmap=matplotlib.colormaps['Greys'])
    for clean_mask in images_no_noise
]
b.make_gif_from_pil_images(viz, "out_clean.gif")

In [73]:
# Hold the first noisy frame for the whole sequence; used as a frozen reference panel.
# images[:1] keeps the leading axis, so the result has `frames` copies stacked on axis 0.
static = jnp.repeat(images[:1], frames, axis=0)

In [74]:
# Three panels joined along axis 2 (image width): frozen frame | clean | noisy.
viz = [
    b.get_depth_image(1.0 - panel * 1.0, cmap=matplotlib.colormaps['Greys'])
    for panel in jnp.concatenate((static, images_no_noise, images), axis=2)
]
b.make_gif_from_pil_images(viz, "out_merge.gif")