In [11]:
import timeit
from functools import partial

import jax
import jax.numpy as jnp

In [12]:
@partial(jax.jit, static_argnums=(2,), static_argnames=("optimize",))
def contract(
    Hv: jax.Array,
    q: jax.Array,
    optimize="auto",
) -> jax.Array:
    """Contract two copies of ``Hv`` with three copies of ``q`` in a single einsum.

    Evaluates::

        nl[m] = sum_{i, j, k, l} Hv[m, i, j] * Hv[k, l, j] * q[i] * q[k] * q[l]

    Args:
        Hv: Rank-3 array. Because its first two axes are each contracted
            against ``q`` (indices ``i``, ``k``, ``l``), both must have the
            same length as ``q``; the last axis (``j``) is only contracted
            between the two copies of ``Hv``.
        q: Rank-1 array.
        optimize: Contraction-path setting forwarded unchanged to
            ``jnp.einsum``. It is a *static* jit argument (declared both by
            position and by name, so positional and keyword call sites are
            treated the same), hence it must be hashable and every distinct
            value triggers its own trace/compile.

    Returns:
        Rank-1 array indexed by ``m`` (the first axis of ``Hv``).
    """
    return jnp.einsum(
        "m i j, k l j, i, k, l -> m",
        Hv,
        Hv,
        q,
        q,
        q,
        optimize=optimize,
    )

In [14]:
# Problem size and reproducible random inputs (fixed PRNG seeds 0 and 1).
n_modes = 100
Hv = jax.random.normal(jax.random.PRNGKey(0), shape=(n_modes,) * 3)
q = jax.random.normal(jax.random.PRNGKey(1), shape=(n_modes,))

# Warm-up: `optimize` is a static jit argument, so call once per strategy to
# get every variant traced and compiled before anything is timed.
for _strategy in ("auto", "greedy", "optimal", "eager"):
    _ = contract(Hv, q, optimize=_strategy)

In [30]:
%timeit -n 10 -r 1000 contract(Hv, q, optimize="auto")
%timeit -n 10 -r 1000 contract(Hv, q, optimize="greedy")
%timeit -n 10 -r 1000 contract(Hv, q, optimize="optimal")
%timeit -n 10 -r 1000 contract(Hv, q, optimize="eager")

The slowest run took 12.98 times longer than the fastest. This could mean that an intermediate result is being cached.
57.6 μs ± 65.1 μs per loop (mean ± std. dev. of 1000 runs, 10 loops each)
The slowest run took 17.44 times longer than the fastest. This could mean that an intermediate result is being cached.
63.9 μs ± 81 μs per loop (mean ± std. dev. of 1000 runs, 10 loops each)
The slowest run took 16.96 times longer than the fastest. This could mean that an intermediate result is being cached.
58.1 μs ± 82.7 μs per loop (mean ± std. dev. of 1000 runs, 10 loops each)
The slowest run took 11.77 times longer than the fastest. This could mean that an intermediate result is being cached.
60.2 μs ± 76.4 μs per loop (mean ± std. dev. of 1000 runs, 10 loops each)
