# Benchmarking JAX Autodiff (`grad`) and JIT Compilation (`jit`)

In [1]:

import jax.numpy as jnp
from jax import grad, jit


# Test input: three vertices in R^3, one per row -- the unit vectors
# e_x, e_y, e_z. They form an equilateral triangle with side sqrt(2),
# so its circumradius is sqrt(2/3) ~= 0.8165.
x = jnp.eye(3)

def r_f(x):
    """Return the circumradius of the triangle whose vertices are rows 0-2 of x.

    Each of ``x[0, :]``, ``x[1, :]`` and ``x[2, :]`` is one vertex; only
    dot products are used, so the vertices may have any (equal) length.

    Law of sines: R = |x2 - x0| / (2 sin(theta)), where theta is the
    interior angle at the middle vertex x1. sin^2(theta) is derived as
    1 - cos^2(theta) from dot products, so the only square root is the
    final one.

    Degenerate triangles are not guarded against. Collinear vertices make
    sin^2(theta) (near) zero, giving inf or a huge value. A vertex that
    coincides with x1 gives 0/0, i.e. nan.
    """
    def sqnorm(v):
        return jnp.dot(v, v)

    p0, p1, p2 = x[0, :], x[1, :], x[2, :]
    edge_a = p0 - p1      # middle vertex -> first vertex
    edge_b = p2 - p1      # middle vertex -> last vertex
    opposite = p2 - p0    # side facing the angle at the middle vertex

    cos_sq = jnp.dot(edge_a, edge_b) ** 2 / (sqnorm(edge_a) * sqnorm(edge_b))
    sin_sq = 1 - cos_sq

    radius_sq = sqnorm(opposite) / sin_sq / 4
    return jnp.sqrt(radius_sq)


# JIT-compiled version of r_f. Wrapping r_f directly, instead of keeping a
# second hand-copied body, guarantees that both benchmarked variants evaluate
# exactly the same expression and can never drift apart.
r_f_jit = jit(r_f)


# Benchmarking
print("Benchmarking r_f(x)")
%timeit r_f(x)

print("Benchmarking r_f_jit(x):")
%timeit r_f_jit(x)


print("Benchmarking grad(r_f)(x)")
%timeit grad(r_f)(x)

print("Benchmarking grad(r_f_jit)(x)")
%timeit grad(r_f_jit)(x)

print("Benchmarking jit(grad(r_f_jit))(x)")
%timeit jit(grad(r_f_jit))(x)

print("Benchmarking grad_jit = jit(grad(r_f_jit))(x)")
g1 = jit(grad(r_f_jit))(x)
%timeit g1

print("Benchmarking grad_jit = jit(grad(r_f))(x)")
g2 = jit(grad(r_f))(x)
%timeit g2


Benchmarking r_f(x)
1.31 ms ± 109 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Benchmarking r_f_jit(x):
7.9 μs ± 453 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Benchmarking grad(r_f)(x)
13 ms ± 844 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Benchmarking grad(r_f_jit)(x)
1.61 ms ± 315 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Benchmarking jit(grad(r_f_jit))(x)
93.5 ms ± 6.37 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Benchmarking grad_jit = jit(grad(r_f_jit))(x)
17.4 ns ± 0.116 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)
Benchmarking grad_jit = jit(grad(r_f))(x)
17 ns ± 0.541 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)
