# Robot Localization in JAX

This notebook runs the robot localization benchmark in JAX. It only computes the linearization (residuals and their Jacobians); the optimization loop itself is not implemented. Linearizations are computed for many independent problems at once, as a single batch.

The experiment can run on either the CPU or the GPU.

See [the paper](https://symforce.org/paper) for more information.

In [None]:
import itertools
import time

import jax
import jaxlie
import numpy as onp
from jax import numpy as np

In [None]:
# Force JAX to run on the CPU.
# Comment out the line below to let JAX use its default accelerator (GPU/TPU) instead.
jax.config.update("jax_platform_name", "cpu")

In [None]:
# Show (as this cell's output) the platform JAX is running on, e.g. "cpu" or "gpu".
# jax.default_backend() is the public replacement for the deprecated
# jax.lib.xla_bridge.get_backend().platform and returns the same string.
jax.default_backend()

In [None]:
def between(a, b):
    """Return the relative group element ``a^{-1} * b``.

    Args:
        a: Group element (e.g. a ``jaxlie.SE3``) of the reference frame.
        b: Group element of the same type.
    """
    a_inverse = a.inverse()
    return a_inverse @ b


def local_coordinates(a, b):
    """Return the tangent vector ``log(a^{-1} * b)`` that takes ``a`` to ``b``."""
    relative = between(a, b)
    return relative.log()


# https://github.com/brentyi/jaxlie/blob/9f177f2640641c38782ec1dc07709a41ea7713ea/jaxlie/manifold/_manifold_helpers.py
def matching_residual(world_T_body, world_t_landmark, body_t_landmark, sigma):
    """Linearize one landmark-matching residual around ``world_T_body``.

    The residual is ``(world_T_body * body_t_landmark - world_t_landmark) / sigma``.

    Args:
        world_T_body: ``jaxlie.SE3`` pose of the body in the world frame.
        world_t_landmark: Landmark position in the world frame (assumed shape (3,)).
        body_t_landmark: Observed landmark position in the body frame (assumed shape (3,)).
        sigma: Standard deviation used to whiten the residual.

    Returns:
        ``(residual, J)``, where ``J`` is the Jacobian of the residual with respect to a
        tangent-space perturbation of ``world_T_body`` (via jaxlie's rplus helper).
    """

    def whitened_error(pose_storage):
        predicted_world_t_landmark = jaxlie.SE3(pose_storage).apply(body_t_landmark)
        return (predicted_world_t_landmark - world_t_landmark) / sigma

    pose_storage = world_T_body.parameters()
    # Chain rule: d(residual)/d(tangent) = d(residual)/d(storage) @ d(storage)/d(tangent).
    # TODO: fwd or reverse here? (3 outputs vs. 7 storage inputs favours jacrev in principle.)
    d_residual_d_storage = jax.jacfwd(whitened_error)(pose_storage)
    d_storage_d_tangent = jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(world_T_body)
    return whitened_error(pose_storage), d_residual_d_storage @ d_storage_d_tangent


def odometry_residual(world_T_a, world_T_b, a_T_b, diagonal_sigmas):
    """Linearize one relative-pose (odometry) residual between two poses.

    The residual is
    ``local_coordinates(between(world_T_a, world_T_b), a_T_b) / diagonal_sigmas``:
    the tangent-space discrepancy between the estimated and the measured relative
    pose, whitened component-wise.

    Args:
        world_T_a: ``jaxlie.SE3`` pose of frame ``a``.
        world_T_b: ``jaxlie.SE3`` pose of frame ``b``.
        a_T_b: Measured relative pose (``jaxlie.SE3``).
        diagonal_sigmas: Per-component standard deviations of the residual.

    Returns:
        ``(residual, J)``, where ``J`` (shape (6, 12)) is the Jacobian with respect to the
        stacked tangent perturbations of ``world_T_a`` and ``world_T_b``.
    """

    def whitened_error(stacked_storage):
        estimated_a_T_b = between(
            jaxlie.SE3(stacked_storage[:7]), jaxlie.SE3(stacked_storage[7:])
        )
        return local_coordinates(estimated_a_T_b, a_T_b) / diagonal_sigmas

    # Both poses' 7-parameter storage vectors, concatenated into one 14-vector.
    stacked_storage = np.concatenate((world_T_a.parameters(), world_T_b.parameters()))

    # d(storage)/d(tangent) is block-diagonal: each pose's storage depends only on its
    # own 6-dim tangent perturbation.
    zero_block = np.zeros((7, 6))
    d_storage_d_tangent = np.block(
        [
            [jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(world_T_a), zero_block],
            [zero_block, jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(world_T_b)],
        ]
    )

    # TODO: fwd or reverse here?
    d_residual_d_storage = jax.jacfwd(whitened_error)(stacked_storage)
    return whitened_error(stacked_storage), d_residual_d_storage @ d_storage_d_tangent

In [None]:
# https://brentyi.github.io/jaxlie/vmap_usage/
# Map the single-landmark residual over every landmark seen from one pose:
# the pose and sigma are shared, both landmark arrays are batched on axis 0.
pose_matching_residual = jax.vmap(
    matching_residual,
    in_axes=(None, 0, 0, None),
    out_axes=(0, 0),
)
# Then map over every pose: the world-frame landmarks and sigma are shared.
full_matching_residual = jax.vmap(
    pose_matching_residual,
    in_axes=(0, None, 0, None),
    out_axes=(0, 0),
)
# Map the odometry residual over consecutive pose pairs; the sigmas are shared.
full_odometry_residual = jax.vmap(odometry_residual, in_axes=(0, 0, 0, None))

In [None]:
def problem_linearization(
    poses,
    odometry_relative_pose_measurements,
    world_t_landmark,
    body_t_landmark,
    measurement_sigma,
    odometry_diagonal_sigmas,
):
    """Linearize every residual of a single localization problem.

    Args:
        poses: ``jaxlie.SE3`` holding one pose per time step (leading axis = pose index).
        odometry_relative_pose_measurements: ``jaxlie.SE3`` holding one measured relative
            pose per pair of consecutive poses.
        world_t_landmark: Landmark positions in the world frame, one row per landmark.
        body_t_landmark: Landmark observations, indexed ``[pose, landmark]``.
        measurement_sigma: Standard deviation of the landmark-matching residuals.
        odometry_diagonal_sigmas: Per-component standard deviations of the odometry residuals.

    Returns:
        ``(matching_b, matching_J, odometry_b, odometry_J)``: the stacked residuals and
        Jacobians of the landmark-matching and odometry factors.
    """
    matching_b, matching_J = full_matching_residual(
        poses, world_t_landmark, body_t_landmark, measurement_sigma
    )
    odometry_b, odometry_J = full_odometry_residual(
        jaxlie.SE3(wxyz_xyz=poses.wxyz_xyz[:-1]),  # poses 0 .. N-2
        jaxlie.SE3(wxyz_xyz=poses.wxyz_xyz[1:]),  # poses 1 .. N-1
        odometry_relative_pose_measurements,
        odometry_diagonal_sigmas,
    )
    return matching_b, matching_J, odometry_b, odometry_J


# Batch over independent problems (axis 0 of every array argument); both sigmas are shared
# by all problems. A single outer jit compiles the whole batched computation, so the
# single-problem function above deliberately carries no jit of its own.
problem_linearization = jax.jit(jax.vmap(problem_linearization, (0, 0, 0, 0, None, None)))

In [None]:
def time_func(f, key, calls):
    """Return the mean wall-clock time, in seconds, of ``calls`` invocations of ``f``.

    ``f`` receives a fresh PRNG key on each call, so successive calls see different
    inputs. JAX dispatches work asynchronously, so every result is passed to
    ``jax.block_until_ready``. Without that, the timer would only measure how long it
    takes to *enqueue* the computation, not to run it.

    Args:
        f: Callable taking a PRNG key and returning a pytree of JAX arrays.
        key: Initial PRNG key.
        calls: Number of timed invocations; must be positive.

    Raises:
        ValueError: If ``calls`` is not positive.
    """
    if calls < 1:
        raise ValueError(f"calls must be positive, got {calls}")
    start = time.perf_counter()
    for _ in range(calls):
        jax.block_until_ready(f(key))
        _, key = jax.random.split(key)
    end = time.perf_counter()
    return (end - start) / calls

In [None]:
key = jax.random.PRNGKey(42)

# Parameter sweep. Batch sizes run largest-first; the commented-out lists are wider
# sweeps over problem size.
nbatches = reversed([1, 1e1, 1e2, 1e3, 1e4, 1e5])
# nposes = reversed([3, 10, 1e2, 1e3])
# nlandmarks = reversed([5, 20, 1e2, 1e3, 1e4])
nposes = [5]
nlandmarks = [20]

# Columns: batch size, poses per problem, landmarks per problem, mean seconds per call
# with linearization (t), mean seconds of the baseline without it (t2), their difference,
# and that difference divided by the batch size.
print(
    f"{'batch':>8}   {'poses':>8}   {'landmark':>8}   {'t [s]':>10}   {'t2 [s]':>10}   {'t - t2':>10}   {'per item':>10}"
)
for BATCH, NUM_POSES, NUM_LANDMARKS in itertools.product(nbatches, nposes, nlandmarks):
    BATCH = int(BATCH)
    NUM_POSES = int(NUM_POSES)
    NUM_LANDMARKS = int(NUM_LANDMARKS)

    # Identity poses. jaxlie stores SE3 as wxyz_xyz, so setting component 0 of the last
    # axis to 1 yields the unit quaternion for every pose. (Indexing with [:, 0] would
    # instead set all 7 parameters of pose 0 and leave every other pose with an invalid
    # all-zero quaternion.)
    storage = np.zeros((BATCH, NUM_POSES, 7))
    storage = storage.at[..., 0].set(1)
    poses = jaxlie.SE3(wxyz_xyz=storage)

    # One (identity) odometry measurement per pair of consecutive poses.
    odometry_relative_pose_measurements = jaxlie.SE3(wxyz_xyz=poses.wxyz_xyz[..., :-1, :])
    world_t_landmark = jax.random.normal(key, (BATCH, NUM_LANDMARKS, 3))
    _, key = jax.random.split(key)
    body_t_landmark = jax.random.normal(key, (BATCH, NUM_POSES, NUM_LANDMARKS, 3))
    _, key = jax.random.split(key)
    measurement_sigma = 1
    odometry_sigmas = np.ones(6)

    def random_linearization(key):
        """Perturb one landmark coordinate, then linearize the whole batch."""
        world_t_landmark_new = world_t_landmark.at[0, 0].set(jax.random.normal(key))
        return problem_linearization(
            poses,
            odometry_relative_pose_measurements,
            world_t_landmark_new,
            body_t_landmark,
            measurement_sigma,
            odometry_sigmas,
        )

    def random_ninearization(key):
        """Baseline: the same landmark perturbation, without the linearization."""
        world_t_landmark_new = world_t_landmark.at[0, 0].set(jax.random.normal(key))
        return world_t_landmark_new

    # Warm up both functions before timing: this compiles problem_linearization and the
    # eager scatter / random ops, and waits for the results. Otherwise compilation and
    # still-running warm-up work would leak into the first timed call of one measurement
    # but not the other, biasing t - t2.
    jax.block_until_ready(random_linearization(key))
    jax.block_until_ready(random_ninearization(key))

    t = time_func(random_linearization, key, 10)

    _, key = jax.random.split(key)
    t2 = time_func(random_ninearization, key, 10)

    print(
        f"{BATCH:>8}   {NUM_POSES:>8}   {NUM_LANDMARKS:>8}   {t:10.5}   {t2:10.5}   {t - t2:10.5}   {(t - t2)/BATCH:10.5}"
    )