In [1]:
import numpy as np
import jax
from time import time

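## TPU

Timings for a 5000×5000 `float32` matrix product compiled with `jax.jit`. The first call of a jitted function includes tracing and XLA compilation. Later calls with the same input shape and dtype reuse the compiled executable, so only those reflect steady-state throughput.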

In [2]:
@jax.jit
def f_tpu(x):
    return x @ x

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
x1 = jax.random.normal(key, (5000, 5000))
x2 = jax.random.normal(subkey, (5000, 5000))

%timeit -n1 -r1 f_tpu(x1).block_until_ready()
%timeit -n1 -r1 f_tpu(x2).block_until_ready()

672 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
12.7 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


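## CPU (numpy)

NumPy's `@` dispatches to the BLAS library NumPy was built against. NumPy draws `float64` samples by default while JAX defaults to `float32`, so a fair comparison needs matching dtypes.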

In [3]:
def f_cpu(x):
    return x @ x

x1 = np.random.normal(size=(5000, 5000))
x2 = np.random.normal(size=(5000, 5000))

%timeit -n1 -r1 f_cpu(x1)
%timeit -n1 -r1 f_cpu(x2)

271 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
212 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


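## CPU (JAX `jit`, CPU backend)

This is the same jitted product, with the inputs committed to the host CPU via `jax.device_put`, so XLA compiles and runs it for the CPU backend. As above, the first call includes compilation.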

In [4]:
@jax.jit
def f_jnp(x):
    return x @ x

cpu = jax.devices('cpu')[0]

x1 = jax.random.normal(key, (5000, 5000))
x1 = jax.device_put(x1, cpu)
x2 = jax.random.normal(subkey, (5000, 5000))
x2 = jax.device_put(x2, cpu)

%timeit -n1 -r1 f_jnp(x1).block_until_ready()
%timeit -n1 -r1 f_jnp(x2).block_until_ready()

210 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
99.6 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
