# Matrix Multiplication, with Jax

This runs a large number of batched matrix multiplications with Jax.  Jax does have support for sparse matrices, but for the matrix sizes we're interested in they aren't a good performance tradeoff.

This can run the benchmark both on CPU and GPU.

See [the paper](https://symforce.org/paper) for more information.

In [None]:
import time

import jax
import numpy as onp
from jax import numpy as np

In [None]:
# Force JAX to run on the CPU platform.
# Comment out the line below to let JAX use a GPU/TPU instead.
jax.config.update("jax_platform_name", "cpu")

In [None]:
# Print the platform (CPU/GPU/TPU) we're using.
# Uses the public jax.default_backend() API rather than reaching into the
# internal jax.lib.xla_bridge module.
jax.default_backend()

In [None]:
def time_func(f, key, calls):
    """Return the mean wall-clock time, in seconds, of one call to ``f``.

    ``f`` is invoked ``calls`` times, each time with a different PRNG key
    obtained by repeatedly splitting ``key``.

    JAX dispatches work asynchronously, so ``f(key)`` may return before the
    computation it launched has finished.  Every result is therefore passed
    to ``jax.block_until_ready`` so the measured interval covers the actual
    computation and not merely the cost of enqueueing it.

    Args:
        f: Callable taking a single JAX PRNG key.
        key: Initial JAX PRNG key.
        calls: Number of timed calls; must be positive.

    Returns:
        Average seconds per call.  The per-iteration key split is inside the
        timed region and is included in this figure.
    """
    start = time.perf_counter()
    for _ in range(calls):
        # Wait for the result to be materialised before moving on.
        jax.block_until_ready(f(key))
        _, key = jax.random.split(key)
    end = time.perf_counter()
    return (end - start) / calls

In [None]:
# Benchmark batched matrix multiplication for batch sizes 10**7 down to 10**1
# (largest first).  Each printed row is: N, mean time of the matmul variant,
# mean time of the baseline variant, their difference, and that difference
# divided by N (all times in seconds).
key = jax.random.PRNGKey(42)

# The per-matrix shapes are identical for every batch size.
mat_size_m = 20
mat_size_n = 15

for exponent in range(7, 0, -1):
    N = 10**exponent

    A = jax.random.normal(key, (N, mat_size_m, mat_size_n))
    _, key = jax.random.split(key)
    B = jax.random.normal(key, (N, mat_size_n, mat_size_m))
    _, key = jax.random.split(key)

    # Untimed warm-up so one-time costs for this shape are not attributed to
    # the first timed call.  JAX dispatches asynchronously, so wait for it to
    # finish; otherwise it could still be running during the timed section.
    np.matmul(A, B).block_until_ready()

    def matmul(key):
        # Overwrite one entry with a fresh random value on every call
        # (presumably so no two calls see identical inputs), then multiply.
        A_new = A.at[0, 0, 0].set(jax.random.normal(key))
        # Block so the measured time covers the computation itself, not just
        # its asynchronous dispatch.
        return np.matmul(A_new, B).block_until_ready()

    t = time_func(matmul, key, 10)

    def notmul(key):
        # Baseline: the same single-entry update without the multiplication,
        # so that t - t2 isolates the cost of the matmul.
        A_new = A.at[0, 0, 0].set(jax.random.normal(key))
        return A_new.block_until_ready()

    _, key = jax.random.split(key)
    t2 = time_func(notmul, key, 10)

    print(f"{N:>10}   {t:10.5} {t2:10.5} {t - t2:10.5} {(t - t2) / N:10.5}")