In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import rerun as rr
import genjax
import os
import numpy as np
import jax.numpy as jnp
import jax
from b3d import Pose
import b3d
from tqdm   import tqdm

# Start a Rerun recording and stream it to a viewer at 127.0.0.1:PORT.
PORT = 8812
RERUN_APP_ID = "object_prior_test"
rr.init(RERUN_APP_ID)
rerun_address = f'127.0.0.1:{PORT}'
rr.connect(addr=rerun_address)

In [3]:
# Recording to run on (relative to the b3d assets directory). Previously tried:
#   shared_data_bucket/input_data/orange_mug_pan_around_and_pickup.r3d.video_input.npz
#   shared_data_bucket/input_data/shout_on_desk.r3d.video_input.npz
assets_root = b3d.get_assets_path()
VIDEO_RELATIVE_PATH = "shared_data_bucket/input_data/desk_ramen2_spray1.r3d.video_input.npz"
path = os.path.join(assets_root, VIDEO_RELATIVE_PATH)
video_input = b3d.VideoInput.load(path)

In [4]:
# Depth-camera intrinsics as plain Python scalars:
# image width/height as ints, then fx, fy, cx, cy, near, far as floats.
intrinsics_raw = np.array(video_input.camera_intrinsics_depth)
image_width, image_height = (int(v) for v in intrinsics_raw[:2])
fx, fy, cx, cy, near, far = (float(v) for v in intrinsics_raw[2:])

# Subsample the video in time: keep every third frame. RGB is scaled by 1/255.
FRAME_STRIDE = 3
rgbs = video_input.rgb[::FRAME_STRIDE] / 255.0
xyzs = video_input.xyz[::FRAME_STRIDE]

# Resize each RGB frame (linear interpolation) to the spatial size of the xyz
# frames, i.e. axes 1 and 2 of video_input.xyz, then clamp to [0, 1].
rgb_target_shape = (video_input.xyz.shape[1], video_input.xyz.shape[2], 3)


def _resize_rgb_frame(frame):
    """Resize one RGB frame to `rgb_target_shape` using linear interpolation."""
    return jax.image.resize(frame, rgb_target_shape, "linear")


rgbs_resized = jnp.clip(jax.vmap(_resize_rgb_frame)(rgbs), 0.0, 1.0)

# First kept frame, flattened to (num_pixels, 3) point and color arrays.
point_cloud_og = xyzs[0].reshape(-1, 3)
colors_og = rgbs_resized[0].reshape(-1, 3)

In [5]:
import b3d.model2 as model2

In [6]:
# Fixed PRNG key; later cells pass this same key to their sampling calls.
key = jax.random.PRNGKey(0)
intrinsics = (fx, fy, cx, cy, near, far)

# Grid subsampling factors along x and y, and the resulting grid size.
subx = suby = 2
w = image_width // subx
h = image_height // suby
p_occupancy = 0.4  # NOTE(review): presumably the prior probability that a voxel is present — confirm in model2

# Arguments for model2.generate_object_voxel_mesh. Every entry except p_occupancy is
# wrapped with genjax.Pytree.const (presumably so it is treated as static under jit).
C = genjax.Pytree.const
object_model_args = (
    *(C(arg) for arg in (image_width, image_height, subx, suby, intrinsics)),
    p_occupancy,
)
trace1 = model2.generate_object_voxel_mesh.simulate(key, object_model_args)

In [7]:
# Sample frame 0's depth (z of xyz) and color on a grid with stride (subx, suby).
# jnp.meshgrid's default "xy" indexing returns (x_grid, y_grid), each of shape
# (len(ys), len(xs)). Image arrays are indexed [row, col] = [y, x] (assumes frames are
# laid out (height, width, ...) — TODO confirm against b3d.VideoInput), so the y grid
# must come first. Indexing with (x_grid, y_grid) swapped rows/cols; because JAX clamps
# out-of-range gather indices, that silently read wrong pixels when width != height.
# The output shape, (len(ys), len(xs)), is unchanged.
xy_indices = jnp.meshgrid(jnp.arange(0, image_width, subx), jnp.arange(0, image_height, suby))
x_grid, y_grid = xy_indices
depth_constraint = xyzs[0, :, :, 2][y_grid, x_grid]
color_constraint = rgbs_resized[0, :, :, :][y_grid, x_grid]

In [8]:
# "inference" has never been so easy...
# Constraints for model2.generate_object_voxel_mesh: mark every voxel present and pin
# each voxel's depth and color to the subsampled frame-0 observations
# (depth_constraint / color_constraint). The next cell importance-samples the object
# model under these constraints.
# Each genjax.vector_choice_map level presumably matches one vmapped axis of the model's
# voxel grid: two levels for (w, h), plus a third for the RGB channel of "color".
# NOTE(review): the observations have shape (len(ys), len(xs)) ~ (h, w). They are
# transposed / axis-swapped to (w, h) here, which appears to be the model's grid
# layout — confirm in model2.
background_object_constraints = genjax.choice_map({
    # Presence flag for every voxel: all ones.
    "voxel_present": genjax.vector_choice_map(
        genjax.vector_choice_map(genjax.choice(jnp.ones((w, h), dtype=int)))
    ),
    # Per-voxel depth, wrapped in a Mask whose flags are all 1 (every entry active).
    "depth": genjax.vector_choice_map(
        genjax.vector_choice_map(
            genjax.Mask(
                jnp.ones((w, h), dtype=int),
                genjax.choice(depth_constraint.transpose())
            )
        )
    ),
    # Per-voxel RGB color; the observation's first two axes are swapped to (w, h, 3).
    "color": genjax.vector_choice_map(
        genjax.vector_choice_map(
            genjax.vector_choice_map(
                genjax.Mask(
                    jnp.ones((w, h, 3), dtype=int),
                    genjax.choice(color_constraint.swapaxes(0, 1))
                )
            )
        )
    )
})

In [9]:
# Importance-sample the object model under the background constraints; the
# importance weight is discarded.
trace2, _ = model2.generate_object_voxel_mesh.importance(
    key,
    background_object_constraints,
    object_model_args,
)

In [10]:
# Visualize the object mesh returned by trace2 as a timeless Rerun entity.
vertices, faces, vertex_colors = trace2.get_retval()
object_mesh = rr.Mesh3D(
    vertex_positions=vertices,
    indices=faces,
    vertex_colors=vertex_colors,
)
rr.log("/3d/mesh", object_mesh, timeless=True)

## Build a trace of the full model at t = 0

In [11]:
# Renderer configured with the depth-camera intrinsics, the multi-object model
# built on top of it, and a jit-compiled version of the model's importance method.
camera_params = (fx, fy, cx, cy, near, far)
renderer = b3d.Renderer(image_width, image_height, *camera_params)
model = model2.model_multiobject_gl_factory(renderer)
importance_jit = jax.jit(model.importance)

In [12]:
# Scalar model hyperparameters, all float32.
color_error = jnp.float32(30.0)
depth_error = jnp.float32(0.02)
inlier_score = jnp.float32(5.0)
outlier_prob = jnp.float32(0.001)
color_multiplier = jnp.float32(3000.0)
depth_multiplier = jnp.float32(3000.0)

max_n_frames_arr = jnp.arange(20)
# Last model argument: the number of timesteps (the update cell later replaces it with t).
n_step_frames = 0

full_model_args = (
    jnp.arange(1),  # NOTE(review): presumably the object indices (one object, index 0) — confirm
    color_error,
    depth_error,
    inlier_score,
    outlier_prob,
    color_multiplier,
    depth_multiplier,
    object_model_args,
    max_n_frames_arr,
    n_step_frames,
)


In [13]:
START_T = 0

# Observed RGB and depth (z channel of xyz) for frame START_T.
start_observation = genjax.choice_map({
    "observed_rgb": rgbs_resized[START_T],
    "observed_depth": xyzs[START_T, ..., 2],
})

# Camera and object 0 fixed at the identity pose; "object_0" set to 0 (presumably
# selecting mesh 0 — confirm in model2); the object mesh is constrained by
# background_object_constraints.
full_model_constraints = genjax.choice_map({
    "camera_pose": Pose.identity(),
    "object_pose_0": Pose.identity(),
    "object_0": 0,
    "obs": start_observation,
    "mesh": background_object_constraints,
})

In [14]:
# Importance-sample the full model under the t = START_T constraints; the
# importance weight is discarded.
trace_, _ = importance_jit(key, full_model_constraints, full_model_args)

In [15]:
# Log the full-model trace to Rerun across time. Alternative (presumably a single
# timestep): model2.rerun_visualize_trace_t(trace_, 0).
model2.rerun_visualize_trace_across_time(trace_)

### Inference over time

In [16]:
# Candidate pose perturbations.
# Translations: an 11 x 11 x 11 grid spanning [-0.01, 0.01] on each axis, with the
# identity pose appended as the last candidate.
offsets_1d = jnp.linspace(-0.01, 0.01, 11)
translation_grid = jnp.stack(
    jnp.meshgrid(offsets_1d, offsets_1d, offsets_1d),
    axis=-1,
).reshape(-1, 3)
translation_deltas = Pose.concatenate_poses([
    jax.vmap(Pose.from_translation)(translation_grid),
    Pose.identity()[None, ...],
])

# Rotations: 11**3 samples from Pose.sample_gaussian_vmf_pose around the identity
# (arguments 0.00001 and 1000.0 — presumably a translation variance and a vMF
# concentration; TODO confirm), with the identity appended as the last candidate.
n_rotation_samples = 11 * 11 * 11
rotation_keys = jax.random.split(jax.random.PRNGKey(0), n_rotation_samples)
sample_rotations = jax.vmap(Pose.sample_gaussian_vmf_pose, in_axes=(0, None, None, None))
rotation_deltas = Pose.concatenate_poses([
    sample_rotations(rotation_keys, Pose.identity(), 0.00001, 1000.0),
    Pose.identity()[None, ...],
])


In [17]:
# Combine the translation and rotation candidate sets with Pose.stack_poses.
candidate_delta_sets = [translation_deltas, rotation_deltas]
all_deltas = Pose.stack_poses(candidate_delta_sets)

In [18]:
# JIT-compile the full model's `update`; the next cell calls it as
# update_jit(key, trace, constraints, argdiffs).
update_jit = jax.jit(model.update)

In [22]:
# Advance the full-model trace by one timestep (t = 1) by conditioning on frame t.
# A loop over timesteps (`for t in tqdm(range(10)):`) was sketched here; this is a single step.
t = 1

# Argdiffs for update: every argument except the last (the number of timesteps)
# is unchanged; the last one becomes t.
diffed_args = tuple(
    genjax.Diff(arg, genjax.NoChange) for arg in full_model_args[:-1]
) + (genjax.Diff(t, genjax.UnknownChange),)

# Frame t's observation, addressed by index t under "obs_steps" as a batch of one.
constraints = genjax.choice_map({
    "obs_steps": genjax.indexed_choice_map(
        jnp.array([t]),
        genjax.choice_map({
            "observed_rgb": jnp.array([rgbs_resized[t]]),
            "observed_depth": jnp.array([xyzs[t, ..., 2]]),
        }),
    )
})

# GenerativeFunction.update returns (new_trace, weight, retdiff, discard). Unpack it
# so that `trace` is the updated Trace rather than the whole tuple (later cells call
# trace.get_retval()).
# NOTE(review): the last run of this cell raised a TypeError from a lax.cond inside
# model.update (one branch returned Diff-wrapped Poses, the other plain Poses). That
# originates in model2's update path, not in this cell — investigate there.
trace, update_weight, update_retdiff, update_discard = update_jit(
    key, trace_, constraints, diffed_args
)

TypeError: true_fun and false_fun output must have same type structure, got PyTreeDef((*, (CustomNode(Pose[None], [CustomNode(Diff[('primal', 'tangent'), (), ()], [*, CustomNode(_NoChange[(), (), ()], [])]), CustomNode(Diff[('primal', 'tangent'), (), ()], [*, CustomNode(_NoChange[(), (), ()], [])])]), [CustomNode(Pose[None], [CustomNode(Diff[('primal', 'tangent'), (), ()], [*, CustomNode(_NoChange[(), (), ()], [])]), CustomNode(Diff[('primal', 'tangent'), (), ()], [*, CustomNode(_NoChange[(), (), ()], [])])])]), *, *)) and PyTreeDef((*, (CustomNode(Pose[None], [*, *]), [CustomNode(Pose[None], [*, *])]), *, *)).

In [None]:
# Log observed and rendered RGB / depth images to Rerun for side-by-side comparison.
# FIXME: observed_rgb, observed_depth, rendered_rgb and rendered_depth are not defined
# in any visible cell (the last run raised NameError); they presumably came from a
# deleted cell. The observed images could be rgbs_resized[t] and xyzs[t, ..., 2]; the
# rendered ones must come from the renderer/trace. Define them before running this cell.
rr.log("/rgb", rr.Image(observed_rgb))
rr.log("/depth", rr.DepthImage(observed_depth))
rr.log("/rgb/rendered", rr.Image(rendered_rgb))
rr.log("/depth/rendered", rr.DepthImage(rendered_depth))

NameError: name 'observed_rgb' is not defined

In [None]:
# FIXME: object_library is not defined in any visible cell (left over from an earlier session).
# Presumably ranges[0] is the (start, length) span of object 0's faces; the next cell slices faces with it.
start, length = object_library.ranges[0]

In [None]:
# Shape of the face slice for object 0 (presumably its faces); the output shown is from an earlier run.
object_library.faces[start:start+length].shape

(147456, 3)

In [None]:
# Log object 0 from object_library to Rerun as a timeless mesh
# (vertex positions, face indices, vertex colors).
# FIXME: object_library is not defined in any visible cell.
vp, f, vc = object_library.get_object(0)
rr.log(
    "/foo",  # plain string: the original f-string had no placeholders
    rr.Mesh3D(
        vertex_positions=vp,
        indices=f,
        vertex_colors=vc,
    ),
    timeless=True,
)

# Scratch work (exploratory cells; some depend on state from earlier sessions)

In [None]:
# Scratch: unpack a 5-tuple retval (voxel_present, depth, colors, resolutions,
# point_centers) and flatten depth and colors.
# NOTE(review): `trace` is now bound by the update cell above, and the object model's
# retval is unpacked as (vertices, faces, vertex_colors) earlier in this notebook, so
# this unpacking is presumably stale — confirm before running.
(voxel_present, depth, colors, resolutions, point_centers) = trace.get_retval()
depth_flat = depth.reshape(-1)
colors_flat = colors.reshape(-1, 3)

In [None]:
# Scratch: inspect the shape of `resolutions` (output below is from an earlier run).
resolutions.shape

(3072,)

In [None]:
# Scratch: inspect the shape of `voxel_present` (output below is from an earlier run).
voxel_present.shape

(48, 64)

In [None]:
# Scratch: build a mesh from the point centers, flattened colors, per-point
# resolutions and the flattened voxel-presence mask.
voxel_mask_flat = voxel_present.reshape(-1)
mesh_inputs = (point_centers, colors_flat, resolutions, voxel_mask_flat)
vertices, faces, vertex_colors, face_colors = b3d.utils.make_mesh_from_point_cloud_and_resolution_2(
    *mesh_inputs
)

In [None]:
# Scratch: unpack a 3-tuple retval (voxel_present, depth, colors).
# NOTE(review): a different arity than the unpacking cells above/below; presumably from
# an older model version — confirm against the current retval of `trace`.
voxel_present, depth, colors = trace.get_retval()

In [None]:
def f(m):
    """Unwrap a genjax Mask: its value where the flag is set, otherwise -1.0."""
    return m.match(lambda: -1.0, lambda x: x)


# Apply `f` elementwise over both leading axes of the masked `depth` array.
jax.vmap(jax.vmap(f))(depth)

Array([[ 8.2909977e-03,  7.7206809e-03,  5.9520635e-03, ...,
         4.7618030e-03,  9.8284325e-03, -1.0000000e+00],
       [ 1.9602785e-03,  1.8960666e-03, -1.0000000e+00, ...,
         3.5102842e-03,  1.7511351e-03,  6.8051419e-03],
       [ 3.4411047e-03, -1.0000000e+00,  8.8439239e-03, ...,
         4.7888327e-03,  2.6385877e-03, -1.0000000e+00],
       ...,
       [-1.0000000e+00,  4.2594587e-03,  3.0224035e-03, ...,
         2.0125713e-03,  7.3168082e-03,  7.4895606e-03],
       [ 8.3255786e-03,  7.2128675e-03,  2.3591798e-04, ...,
         6.9076382e-04,  4.4467375e-03,  8.7554250e-03],
       [ 2.0493641e-03, -1.0000000e+00,  3.4507955e-03, ...,
        -1.0000000e+00,  2.1509863e-03,  5.5744122e-03]], dtype=float32)

In [None]:
# Scratch: inspect genjax.Mask.match (used in the cell above); remove from the final notebook.
genjax.Mask.match

BoundMethod(
  __func__=<wrapped function match>,
  __self__=Mask(flag=i32[48,64], value=f32[48,64])
)

In [None]:
# Scratch: unpack a 4-tuple retval (vertices, faces, vertex_colors, face_colors).
# NOTE(review): the object model's retval is unpacked as 3 values earlier in this
# notebook; this arity is presumably stale — confirm against the current retval of `trace`.
vertices, faces, vertex_colors, face_colors = trace.get_retval()