# Inverse Compose with JAX

This notebook runs the inverse-compose benchmark with JAX, on either CPU or GPU. We use the [jaxlie](https://brentyi.github.io/jaxlie) library for Lie group operations. For each pose–point pair we compute the resulting point (the inverse of the pose applied to the point) and the Jacobian of that point with respect to the pose (in the pose's tangent space), batched over large numbers of poses and points.

See [the paper](https://symforce.org/paper) for more information.

In [None]:
import time

import jax
import jaxlie
import numpy as onp
from jax import numpy as np

In [None]:
# Set to CPU
# Comment out to use GPU/TPU
jax.config.update("jax_platform_name", "cpu")

In [None]:
# Print the platform (e.g. "cpu", "gpu") that JAX is using.
# `jax.default_backend()` is the public API for this; it replaces the
# deprecated internal `jax.lib.xla_bridge.get_backend().platform`.
# An explicit `print` also makes this work when run as a plain script.
print(jax.default_backend())

In [None]:
def time_func(f, key, calls):
    """Return the mean wall-clock seconds per call of ``f(key)`` over ``calls`` calls.

    Before each subsequent call, ``key`` is replaced by a fresh key from
    ``jax.random.split``, so each call receives a different random key. That
    split is included in the measured time.

    JAX dispatches work asynchronously, so a call can return before its
    computation has finished. To time the actual computation rather than just
    the dispatch, every leaf of ``f``'s result that has a ``block_until_ready``
    method is waited on before the next call.

    Args:
        f: Callable taking a JAX PRNG key. Its result may be any pytree.
        key: Initial JAX PRNG key.
        calls: Number of timed calls. Must be at least 1.

    Returns:
        Average elapsed seconds per call.

    Raises:
        ValueError: If ``calls`` is less than 1.
    """
    if calls < 1:
        raise ValueError(f"calls must be a positive integer, got {calls!r}")

    def _block(leaf):
        # Wait for device arrays; pass through leaves without the method
        # (e.g. plain Python scalars).
        wait = getattr(leaf, "block_until_ready", None)
        return wait() if callable(wait) else leaf

    start = time.perf_counter()
    for _ in range(calls):
        jax.tree_util.tree_map(_block, f(key))
        _, key = jax.random.split(key)
    end = time.perf_counter()
    return (end - start) / calls

In [None]:
# Helpful source and documentation references:
# https://github.com/brentyi/jaxlie/blob/9f177f2640641c38782ec1dc07709a41ea7713ea/jaxlie/manifold/_manifold_helpers.py
# https://brentyi.github.io/jaxlie/vmap_usage/


@jax.jit
@jax.vmap
def inverse_compose(world_T_body, point):
    """Apply the inverse of ``world_T_body`` to ``point`` and differentiate it.

    Vectorized with ``jax.vmap`` over the leading axis of both arguments and
    compiled with ``jax.jit``.

    Returns:
        ``(body_point, jacobian)``. ``body_point`` is
        ``world_T_body.inverse().apply(point)``. ``jacobian`` is its derivative
        with respect to a tangent-space perturbation of the pose. It is formed
        by the chain rule through the pose's raw storage parameters.
    """

    def transform(storage):
        # The result as a function of the raw SE3 storage parameters
        # (not of the tangent space).
        return jaxlie.SE3(storage).inverse().apply(point)

    storage = world_T_body.parameters()

    # d(result)/d(storage). Forward mode is used: the author found jacfwd
    # to be better than jacrev for this function.
    result_D_storage = jax.jacfwd(transform)(storage)

    # d(storage)/d(tangent delta), as provided by jaxlie's manifold helpers.
    storage_D_tangent = jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(world_T_body)

    return transform(storage), result_D_storage @ storage_D_tangent

In [None]:
key = jax.random.PRNGKey(42)

# Columns: problem size, then seconds per call for the full call ("total"),
# for input perturbation only ("overhead"), their difference ("net"),
# and the net time per pose/point pair.
print(f"{'N':>10}   {'total':>10} {'overhead':>10} {'net':>10} {'net/N':>10}")

# Largest problem first: N = 10**7, 10**6, ..., 10**1.
for N in (10**exponent for exponent in range(7, 0, -1)):
    # Identity poses: wxyz_xyz storage with w = 1 and every other entry 0.
    storage = np.zeros((N, 7))
    storage = storage.at[:, 0].set(1)
    poses = jaxlie.SE3(wxyz_xyz=storage)

    points = jax.random.normal(key, (N, 3))
    _, key = jax.random.split(key)

    def random_inverse_compose(key):
        # Perturb one coordinate so each timed call sees a fresh input.
        points_new = points.at[0, 0].set(jax.random.normal(key))
        return inverse_compose(poses, points_new)

    def random_ninverse_compose(key):
        # Same perturbation without inverse_compose; its time is subtracted
        # from the full call's time as overhead.
        points_new = points.at[0, 0].set(jax.random.normal(key))
        return points_new

    # Warm up (compile) both timed paths and wait for their results. JAX
    # dispatches asynchronously, so an unblocked warm-up would leak its
    # compile/compute time into the first timed call.
    jax.tree_util.tree_map(
        lambda leaf: leaf.block_until_ready(), random_inverse_compose(key)
    )
    random_ninverse_compose(key).block_until_ready()

    t = time_func(random_inverse_compose, key, 10)

    _, key = jax.random.split(key)
    t2 = time_func(random_ninverse_compose, key, 10)

    print(f"{N:>10}   {t:10.5} {t2:10.5} {t - t2:10.5} {(t - t2) / N:10.5}")