In [1]:
import numpy as onp
import jaxlie
from jax import numpy as np
import jax
import time

In [2]:
# Force JAX onto the CPU backend.
# Delete (or comment out) the next line to let JAX pick a GPU/TPU instead.
jax.config.update("jax_platform_name", "cpu")

In [3]:
# Display the platform name of the backend JAX is using (public API; the old
# `jax.lib.xla_bridge.get_backend()` path is deprecated).
jax.default_backend()

'cpu'

In [4]:
def time_func(f, key, calls):
    """Return the mean wall-clock seconds per call of ``f`` over ``calls`` calls.

    ``f`` is called with ``key`` first; after every call the key is split and
    the second half is passed to the next call, so each call receives a
    different PRNG key.

    JAX dispatches work asynchronously, so the result of every call is
    explicitly waited on before timing continues; without this the timer may
    mostly measure dispatch overhead instead of the computation itself.

    Args:
        f: Callable taking a single PRNG key. Its return value may be any
            pytree (an array, a tuple of arrays, ...) or ``None``.
        key: Initial ``jax.random`` PRNG key.
        calls: Number of timed calls; must be positive.

    Returns:
        Average elapsed seconds per call (the key splitting is included).

    Raises:
        ValueError: If ``calls`` is not positive.
    """
    if calls <= 0:
        raise ValueError(f"calls must be positive, got {calls}")
    start = time.perf_counter()
    for _ in range(calls):
        result = f(key)
        # Block on every array in the (possibly nested) result so the measured
        # time covers the actual computation, not just its dispatch.
        for leaf in jax.tree_util.tree_leaves(result):
            if hasattr(leaf, "block_until_ready"):
                leaf.block_until_ready()
        _, key = jax.random.split(key)
    end = time.perf_counter()
    return (end - start) / calls

In [5]:
# https://github.com/brentyi/jaxlie/blob/9f177f2640641c38782ec1dc07709a41ea7713ea/jaxlie/manifold/_manifold_helpers.py
# https://brentyi.github.io/jaxlie/vmap_usage/

@jax.jit
@jax.vmap
def inverse_compose(world_T_body, point):
    """Apply the inverse of ``world_T_body`` to ``point`` and differentiate it.

    Vectorized with ``jax.vmap`` over the leading axis of both arguments and
    compiled with ``jax.jit``.

    Returns:
        ``(transformed_point, J)``. ``J`` is the Jacobian of the transformed
        point with respect to the raw SE3 parameter storage, right-multiplied
        by ``jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(world_T_body)``
        (presumably d(parameters)/d(tangent delta); see jaxlie's manifold
        helpers linked above).
    """
    params = world_T_body.parameters()

    def apply_inverse(raw_params):
        return jaxlie.SE3(raw_params).inverse().apply(point)

    # Forward-mode differentiation was found to be the better choice here.
    d_out_d_params = jax.jacfwd(apply_inverse)(params)
    d_params_d_delta = jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(world_T_body)
    jacobian = d_out_d_params @ d_params_d_delta
    return apply_inverse(params), jacobian

In [6]:
# Benchmark `inverse_compose` for batch sizes 10**7 down to 10**1.
# Printed columns: N, seconds per call of inverse_compose (t), seconds per call
# of a baseline that only builds the perturbed input (t2), t - t2, and
# (t - t2) / N (approximate per-element cost of inverse_compose).
CALLS_PER_SIZE = 10

key = jax.random.PRNGKey(42)

for N in (10**k for k in range(7, 0, -1)):
    # Identity poses: quaternion (w=1, x=y=z=0) with zero translation.
    storage = np.zeros((N, 7))
    storage = storage.at[:, 0].set(1)
    poses = jaxlie.SE3(wxyz_xyz=storage)

    # Split before sampling so no key is both consumed and split again.
    key, points_key = jax.random.split(key)
    points = jax.random.normal(points_key, (N, 3))

    # Warm-up: compiles inverse_compose for this batch shape. Wait for it to
    # finish so its (asynchronous) execution does not leak into the timings.
    for leaf in jax.tree_util.tree_leaves(inverse_compose(poses, points)):
        leaf.block_until_ready()

    def random_inverse_compose(key):
        # Perturb one entry so every call sees a different input.
        points_new = points.at[0, 0].set(jax.random.normal(key))
        return inverse_compose(poses, points_new)

    key, timing_key = jax.random.split(key)
    t = time_func(random_inverse_compose, timing_key, CALLS_PER_SIZE)

    def random_ninverse_compose(key):
        # Baseline: identical input construction, without inverse_compose.
        points_new = points.at[0, 0].set(jax.random.normal(key))
        return points_new

    key, timing_key = jax.random.split(key)
    t2 = time_func(random_ninverse_compose, timing_key, CALLS_PER_SIZE)

    print(f"{N:>10}   {t:10.5} {t2:10.5} {t - t2:10.5} {(t - t2) / N:10.5}")

  10000000       5.5745   0.032849     5.5416 5.5416e-07
   1000000      0.55822  0.0024404    0.55578 5.5578e-07
    100000     0.057941  0.0013382   0.056602 5.6602e-07
     10000    0.0075167  0.0013319  0.0061848 6.1848e-07
      1000    0.0033099  0.0013599    0.00195   1.95e-06
       100    0.0029165  0.0013437  0.0015728 1.5728e-05
        10    0.0030606  0.0013474  0.0017132 0.00017132
