In [119]:
### Preliminaries ###

import jax.numpy as jnp
import jax
import os
import trimesh
import b3d
from b3d import Pose
import rerun as rr
import genjax
from tqdm import tqdm
import demos.differentiable_renderer.likelihood_debugging.demo_utils as du
import demos.differentiable_renderer.likelihood_debugging.model as m
import demos.differentiable_renderer.likelihood_debugging.likelihoods as l
import b3d.differentiable_renderer as r
import matplotlib.pyplot as plt
import numpy as np


In [142]:
# Start a rerun recording and connect it to the viewer at the address below.
RERUN_APP_ID = "test2"
RERUN_VIEWER_ADDR = "127.0.0.1:8812"
rr.init(RERUN_APP_ID)
rr.connect(RERUN_VIEWER_ADDR)

In [121]:
# Reload demo_utils so that edits made to the module on disk are picked up
# without restarting the kernel.
import importlib
importlib.reload(du)

<module 'demos.differentiable_renderer.likelihood_debugging.demo_utils' from '/home/georgematheos/b3d/demos/differentiable_renderer/likelihood_debugging/demo_utils.py'>

In [122]:
# Unpack everything the demo helper returns. The names suggest: the renderer,
# the observed RGB-D frames with ground-truth rotations, the patch mesh
# (vertices, faces, vertex colors) with its world pose X_WP, and the camera
# pose X_WC.
# NOTE(review): the meaning of each element is inferred from the names used
# here -- confirm against du.get_renderer_boxdata_and_patch.
(
    renderer,
    (observed_rgbds, gt_rots),
    ((patch_vertices_P, patch_faces, patch_vertex_colors), X_WP),
    X_WC
) = du.get_renderer_boxdata_and_patch()

# NOTE(review): these four values are passed positionally; what each one
# controls is defined by DifferentiableRendererHyperparams and is not visible
# in this notebook -- consider passing them by keyword.
hyperparams = r.DifferentiableRendererHyperparams(
    3, 1e-5, 1e-2, -1
)

# Likelihood constants. The lambda below reads these globals each time it is
# called, so rebinding them in a later cell changes what the likelihood sees.
depth_scale = 0.0001
color_scale = 0.002
mindepth = -1.0
maxdepth = 2.0
# The lambda turns (weights, rgbds) into the full positional argument tuple:
# image height/width, the two inputs, the scales as 1-tuples, and the depth
# bounds.
# NOTE(review): presumably ArgMap applies this lambda before calling the
# wrapped distribution, and the boolean list flags arguments per position --
# confirm both in the likelihoods module.
likelihood = l.ArgMap(
    l.ImageDistFromPixelDist(
        l.uniform_multilaplace_mixture,
        [True, True, False, False, False, False]
    ),
    lambda weights, rgbds: ( renderer.height, renderer.width,
                            weights, rgbds, (depth_scale,), (color_scale,), mindepth, maxdepth )
)

# Build the generative model used by every cell below.
model = m.single_object_model_factory(
    renderer,
    likelihood,
    hyperparams
)

In [123]:

### Generate image samples ###

def generate_image(key):
    """Draw one sample from the model with the patch and camera poses fixed.

    Args:
        key: JAX PRNG key passed to ``model.importance``.

    Returns:
        The first element of the trace's return value (indexed below as an
        image whose last axis holds RGB in channels 0-2 and depth in 3).
    """
    constraints = genjax.choice_map({ "pose": X_WP, "camera_pose": X_WC })
    model_args = (patch_vertices_P, patch_faces, patch_vertex_colors, ())
    # The importance weight is not needed here; only the sampled trace is.
    sampled_trace, _ = model.importance(key, constraints, model_args)
    retval = sampled_trace.get_retval()
    return retval[0]
# Draw 100 independent images (one PRNG key each) and log them to rerun on
# the "image_sample" timeline.
key = jax.random.PRNGKey(0)
sample_keys = jax.random.split(key, 100)
images = jax.vmap(generate_image)(sample_keys)
for i, image in enumerate(images):
    rgb = image[:, :, :3]
    depth = image[:, :, 3]
    rr.set_time_sequence("image_sample", i)
    rr.log("/image_sample/rgb", rr.Image(rgb))
    rr.log("/image_sample/depth", rr.DepthImage(depth))



In [124]:
# Frame legend: W = world; C = camera; P0 / P20 = patch at time 0 / 20;
# B0 / B20 = box at time 0 / 20.
# As used below, X_A_B is the pose of frame B expressed in frame A, and
# R_A_B is the corresponding rotation.

# Ground-truth box rotations at frames 0 and 20.
R_W_B0 = b3d.Rot.from_matrix(gt_rots[0])
R_W_B20 = b3d.Rot.from_matrix(gt_rots[20])

# Box poses built from the rotation quaternion only.
# NOTE(review): presumably Pose.from_quat leaves the translation at zero,
# i.e. the box rotates about the world origin -- confirm.
X_W_B0 = Pose.from_quat(R_W_B0.as_quat())
X_W_B20 = Pose.from_quat(R_W_B20.as_quat())

# Pose of the box relative to the patch, measured at frame 0.
X_W_P0 = X_WP
X_P0_B0 = X_W_P0.inv().compose(X_W_B0)
X_P_B = X_P0_B0 # This should remain constant

# Box at frame 20 expressed in the frame-0 patch frame (not used further in
# this cell).
X_P0_B20 = X_W_P0.inv().compose(X_W_B20)
# X_W_B20 = X_W_P0.compose(X_P0_B20)
# Patch pose at frame 20: X_W_P20 = X_W_B20 * X_B_P, using the constant
# patch-to-box offset computed above.
X_W_P20 = X_W_B20.compose(X_P_B.inv())

In [125]:
# Get frame 19: the same construction as for frame 20, using gt_rots[19].
R_W_B19 = b3d.Rot.from_matrix(gt_rots[19])
X_W_B19 = Pose.from_quat(R_W_B19.as_quat())
# Box at frame 19 expressed in the frame-0 patch frame (not used further in
# this cell).
X_P0_B19 = X_W_P0.inv().compose(X_W_B19)
# Patch pose at frame 19, obtained from the constant patch-to-box offset X_P_B.
X_W_P19 = X_W_B19.compose(X_P_B.inv())

In [126]:
# Perturbed starting pose: the frame-20 patch position, with the frame-20
# patch rotation composed with an extra 0.9 rad rotation about the y axis.
# The rotation is bound to `rot_init` rather than `r`: `r` is the import
# alias for b3d.differentiable_renderer, and rebinding it makes the setup
# cell (r.DifferentiableRendererHyperparams) fail when it is re-run.
rot_init = X_W_P20.rot * b3d.Rot.from_euler("y", 0.9, degrees=False)
pose_init = Pose.from_vec(jnp.concatenate([
    X_W_P20._position, rot_init.as_quat()
]))

In [127]:
### Score a candidate patch pose (position + quaternion) against observed frame 20 ###
def importance_from_pos_quat(pos, quat):
    """Run ``model.importance`` with the patch pose fixed to (pos, quat).

    The camera pose is constrained to ``X_WC`` and the observation to
    ``observed_rgbds[20]``. The PRNG key is the notebook-global ``key`` as
    bound when this function runs (or is traced).

    Args:
        pos: patch position, concatenated in front of ``quat``.
        quat: patch orientation quaternion.

    Returns:
        ``(trace, weight)`` exactly as returned by ``model.importance``.
    """
    candidate_pose = Pose.from_vec(jnp.concatenate([pos, quat]))
    constraints = genjax.choice_map({
        "pose": candidate_pose,
        "camera_pose": X_WC,
        "observed_rgbd": observed_rgbds[20]
    })
    model_args = (patch_vertices_P, patch_faces, patch_vertex_colors, ())
    return model.importance(key, constraints, model_args)

def weight_from_pos_quat(pos, quat):
    """Return only the importance weight for the pose (pos, quat) at frame 20.

    This scalar-valued wrapper is what gets differentiated with ``jax.grad``.
    """
    _, pose_weight = importance_from_pos_quat(pos, quat)
    return pose_weight

In [128]:
# Gradient of the frame-20 weight with respect to both position (arg 0) and
# quaternion (arg 1), compiled once with jit.
_weight_grad_pos_quat = jax.grad(weight_from_pos_quat, argnums=(0, 1))
grad_jitted_2 = jax.jit(_weight_grad_pos_quat)
# Evaluate at the perturbed initial pose; the cell displays the two gradients.
grad_jitted_2(pose_init._position, pose_init._quaternion)

(Array([1999.2793 , 3784.2354 , -159.25781], dtype=float32),
 Array([-430.0289 , -205.71703, -388.1051 ,  157.40657], dtype=float32))

In [129]:
import optax

In [130]:
### This fits position decently ###

# pos = pose_init._position - jnp.array([0.01, 0.015, -0.007])
# quat = pose_init._quaternion
# Start from the frame-19 position and the frame-20 orientation.
params = {
    "pos": X_W_P19._position,
    "quat": X_W_P20.rot.as_quat()
}
optimizer = optax.adam(learning_rate=2e-4, b1=0.7)
opt_state = optimizer.init(params)
key = jax.random.PRNGKey(40)
for i in range(40):
    pos, quat = params["pos"], params["quat"]
    grad_pos, grad_quat = grad_jitted_2(pos, quat)
    # optax applies descent-style updates, so the weight gradient is negated
    # to climb it. The quaternion gradient is replaced by zeros in this cell.
    ascent_grads = {"pos": -grad_pos, "quat": jnp.zeros_like(grad_quat)}
    updates, opt_state = optimizer.update(ascent_grads, opt_state)
    params = optax.apply_updates(params, updates)
    # Re-score the updated pose and log the trace and weight for this step.
    tr, weight = importance_from_pos_quat(params["pos"], params["quat"])
    rr.set_time_sequence("adam-3", i)
    m.rr_log_trace(tr, renderer)
    rr.log("weight", rr.Scalar(weight))

In [131]:
## Looks pretty good... ##

# pos = pose_init._position - jnp.array([0.01, 0.015, -0.007])
# quat = pose_init._quaternion
# Start from the full frame-19 patch pose and fit it to frame 20, using a
# separate Adam optimizer (and learning rate) for position and orientation.
params = {
    "pos": X_W_P19._position,
    "quat": X_W_P19._quaternion
}
optimizer_pos = optax.adam(learning_rate=2e-4, b1=0.7)
optimizer_quat = optax.adam(learning_rate=4e-3)
# Each state is created by the optimizer that will consume it. Previously
# both came from the unrelated global `optimizer` left over from the
# position-only cell, which only worked while that cell had been run and
# happened to hold an optimizer with a compatible state.
opt_state_pos = optimizer_pos.init(params["pos"])
opt_state_quat = optimizer_quat.init(params["quat"])
key = jax.random.PRNGKey(40)
for i in range(30):
    pos, quat = params["pos"], params["quat"]
    grad_pos, grad_quat = grad_jitted_2(pos, quat)
    # Negate the gradients: optax applies descent-style updates, and the goal
    # is to increase the importance weight.
    updates_pos, opt_state_pos = optimizer_pos.update(-grad_pos, opt_state_pos)
    updates_quat, opt_state_quat = optimizer_quat.update(-grad_quat, opt_state_quat)
    params["pos"] = optax.apply_updates(params["pos"], updates_pos)
    params["quat"] = optax.apply_updates(params["quat"], updates_quat)
    # Re-score the updated pose and log the trace and weight for this step.
    tr, weight = importance_from_pos_quat(params["pos"], params["quat"])
    rr.set_time_sequence("adam-19to20-2", i)
    m.rr_log_trace(tr, renderer)
    rr.log("weight", rr.Scalar(weight))

In [132]:
### Try patch tracking ###

In [133]:
### Score a candidate patch pose (position + quaternion) against the observed frame at `timestep` ###
def importance_from_pos_quat_v3(pos, quat, timestep):
    """Run ``model.importance`` with the patch pose fixed, for a chosen frame.

    The camera pose is constrained to ``X_WC`` and the observation to
    ``observed_rgbds[timestep]``. The PRNG key is the notebook-global ``key``
    as bound when this function runs (or is traced).

    Args:
        pos: patch position, concatenated in front of ``quat``.
        quat: patch orientation quaternion.
        timestep: index into ``observed_rgbds`` selecting the observed frame.

    Returns:
        ``(trace, weight)`` exactly as returned by ``model.importance``.
    """
    candidate_pose = Pose.from_vec(jnp.concatenate([pos, quat]))
    constraints = genjax.choice_map({
        "pose": candidate_pose,
        "camera_pose": X_WC,
        "observed_rgbd": observed_rgbds[timestep]
    })
    model_args = (patch_vertices_P, patch_faces, patch_vertex_colors, ())
    return model.importance(key, constraints, model_args)

def weight_from_pos_quat_v3(pos, quat, timestep):
    """Return only the importance weight for the pose (pos, quat) at ``timestep``.

    This scalar-valued wrapper is what gets differentiated with ``jax.grad``.
    """
    _, pose_weight = importance_from_pos_quat_v3(pos, quat, timestep)
    return pose_weight

In [134]:
# Gradient of the per-frame weight with respect to position (arg 0) and
# quaternion (arg 1); `timestep` (arg 2) is passed through undifferentiated.
_weight_grad_pos_quat_v3 = jax.grad(weight_from_pos_quat_v3, argnums=(0, 1))
grad_jitted_3 = jax.jit(_weight_grad_pos_quat_v3)

In [148]:
# Separate Adam optimizers for position and quaternion, read as globals by
# the jitted tracking kernel defined in this cell. Rebinding these names
# replaces any same-named optimizers created by earlier cells.
optimizer_pos = optax.adam(learning_rate=1e-4, b1=0.7)
optimizer_quat = optax.adam(learning_rate=4e-3)

@jax.jit
def optimizer_kernel(st, i):
    """Take one Adam step on position and quaternion for a fixed frame.

    Shaped as a ``jax.lax.scan`` body: it takes a carry and a per-step input
    and returns the new carry plus a per-step output.

    Args:
        st: carry tuple ``(opt_state_pos, opt_state_quat, pos, quat, timestep)``.
        i: per-step scan input; unused.

    Returns:
        ``(new_carry, (pos, quat))`` where ``new_carry`` has the same layout
        as ``st`` and ``(pos, quat)`` are the updated parameters.
    """
    state_pos, state_quat, cur_pos, cur_quat, timestep = st
    d_pos, d_quat = grad_jitted_3(cur_pos, cur_quat, timestep)
    # Gradients are negated so that the optax update increases the weight.
    step_pos, state_pos = optimizer_pos.update(-d_pos, state_pos)
    step_quat, state_quat = optimizer_quat.update(-d_quat, state_quat)
    new_pos = optax.apply_updates(cur_pos, step_pos)
    new_quat = optax.apply_updates(cur_quat, step_quat)
    new_carry = (state_pos, state_quat, new_pos, new_quat, timestep)
    return new_carry, (new_pos, new_quat)

In [149]:
def _unfold_n_steps(st, n_steps):
    """Scan ``optimizer_kernel`` ``n_steps`` times and return the final carry.

    ``n_steps`` must be a plain Python int, because it fixes the scan length
    when the calling function is traced. The per-step outputs are discarded.
    """
    final_st, _ = jax.lax.scan(optimizer_kernel, st, jnp.arange(n_steps))
    return final_st


# The fixed-length entry points below keep their original names and
# signatures; they now differ only in the step count handed to the helper.

@jax.jit
def unfold_20_steps(st):
    """Run 20 optimizer steps from carry ``st`` and return the final carry."""
    return _unfold_n_steps(st, 20)

@jax.jit
def unfold_40_steps(st):
    """Run 40 optimizer steps from carry ``st`` and return the final carry."""
    return _unfold_n_steps(st, 40)

@jax.jit
def unfold_100_steps(st):
    """Run 100 optimizer steps from carry ``st`` and return the final carry."""
    return _unfold_n_steps(st, 100)

@jax.jit
def unfold_300_steps(st):
    """Run 300 optimizer steps from carry ``st`` and return the final carry."""
    return _unfold_n_steps(st, 300)

@jax.jit
def unfold_150_steps(st):
    """Run 150 optimizer steps from carry ``st`` and return the final carry."""
    return _unfold_n_steps(st, 150)

In [150]:
# Build each optimizer state from the optimizer that will consume it, rather
# than from the unrelated global `optimizer` left over from an earlier cell.
opt_state_pos = optimizer_pos.init(X_WP._position)
opt_state_quat = optimizer_quat.init(X_WP._quaternion)

In [151]:
### THIS CELL SUCCESSFULLY TRACKS THE OBJECT THROUGH THE WHOLE VIDEO!! ###

# Track the patch frame by frame: each frame starts from the previous frame's
# fitted pose and runs 100 Adam steps against that frame's observation.
pos = X_WP._position
quat = X_WP._quaternion
for timestep in range(30):
    # Fresh Adam state per frame, created by the optimizers that
    # optimizer_kernel actually uses. This replaces the stale global
    # `optimizer`, and the pre-loop init that was always overwritten here.
    opt_state_pos = optimizer_pos.init(pos)
    opt_state_quat = optimizer_quat.init(quat)
    (opt_state_pos, opt_state_quat, pos, quat, _) = unfold_100_steps(
        (opt_state_pos, opt_state_quat, pos, quat, timestep)
    )
    # Re-score the fitted pose to get a trace to visualise for this frame.
    tr, weight = importance_from_pos_quat_v3(pos, quat, timestep)
    rr.set_time_sequence("tracking-frameONLY-7", timestep)
    m.rr_log_trace(tr, renderer)

In [139]:
# Same tracking scheme as the scan-based cell, but stepping the kernel from
# Python so that every optimizer step can be logged to rerun.
pos = X_WP._position
quat = X_WP._quaternion
N_STEPS = 80
for timestep in range(10):
    # Fresh Adam state per frame, created by the optimizers that
    # optimizer_kernel actually uses. This replaces the stale global
    # `optimizer`, and the pre-loop init that was always overwritten here.
    opt_state_pos = optimizer_pos.init(pos)
    opt_state_quat = optimizer_quat.init(quat)
    for i in range(N_STEPS):
        (opt_state_pos, opt_state_quat, pos, quat, _), _ = optimizer_kernel(
            (opt_state_pos, opt_state_quat, pos, quat, timestep), i
        )
        # Log every step on one global timeline across all frames.
        tr, weight = importance_from_pos_quat_v3(pos, quat, timestep)
        rr.set_time_sequence("full_seq-10", i + timestep * N_STEPS)
        m.rr_log_trace(tr, renderer)
        rr.log("weight", rr.Scalar(weight))
    # Log the final fitted pose for this frame on a per-frame timeline.
    tr, weight = importance_from_pos_quat_v3(pos, quat, timestep)
    rr.set_time_sequence("tracking-frame-10", timestep)
    m.rr_log_trace(tr, renderer)