# Speed!
XLA is the compiler backend that JAX builds on. It also powers TensorFlow, and PyTorch when running on TPUs. FAST!

In [1]:
from jax import jit
import jax.numpy as jnp
from jax import random

In [2]:
# A fixed seed keeps the sampled matrix identical from run to run.
key = random.PRNGKey(0)
# 5000 x 5000 standard-normal draws; used as the benchmark input below.
x = random.normal(key, shape=(5000, 5000))

In [3]:
def f(x):
    """Apply the update ``y <- y - 0.1 * y + 3`` ten times, then crop the result.

    Args:
        x: Array with at least two dimensions (the result is indexed as
            ``[:100, :100]``).

    Returns:
        The leading ``[:100, :100]`` corner of the iterated array.
    """
    n_steps = 10
    shrink = 0.1
    offset = 3.0

    acc = x
    for _step in range(n_steps):
        # Keep the "acc - shrink * acc + offset" evaluation order: folding it
        # into 0.9 * acc would change floating-point rounding.
        acc = acc - shrink * acc + offset

    top_left = acc[:100, :100]
    return top_left

# Eager (un-jitted) call of f on x; the cell displays the returned slice.
f(x)

DeviceArray([[19.498985, 19.82703 , 19.105713, ..., 20.288704, 19.505972,
              18.800804],
             [20.183136, 19.432299, 19.18563 , ..., 19.554996, 18.913944,
              19.772003],
             [19.303692, 19.688906, 19.3761  , ..., 19.901508, 19.094027,
              19.131971],
             ...,
             [19.463682, 19.39614 , 19.29148 , ..., 19.64818 , 20.130577,
              19.868395],
             [19.425179, 19.687124, 19.9568  , ..., 19.83506 , 19.73038 ,
              19.514133],
             [19.437508, 19.607405, 19.633265, ..., 19.108871, 20.188248,
              19.94832 ]], dtype=float32)

In [4]:
# Wrap f with jit to obtain a compiled version with the same call signature.
g = jit(f)
# First call of g; its displayed result can be compared with the eager f(x) output.
g(x)

DeviceArray([[19.498985, 19.82703 , 19.105713, ..., 20.288704, 19.505972,
              18.800804],
             [20.183136, 19.432299, 19.18563 , ..., 19.554996, 18.913944,
              19.772003],
             [19.303692, 19.688906, 19.3761  , ..., 19.901508, 19.094027,
              19.131971],
             ...,
             [19.463682, 19.39614 , 19.29148 , ..., 19.64818 , 20.130577,
              19.868395],
             [19.425179, 19.687124, 19.9568  , ..., 19.83506 , 19.73038 ,
              19.514133],
             [19.437508, 19.607405, 19.633265, ..., 19.108871, 20.188248,
              19.94832 ]], dtype=float32)

## Microbenchmarks!

In [5]:
# Time the eager version. block_until_ready() makes each loop wait for the result;
# JAX dispatches work asynchronously, so without it the call could return early.
%timeit f(x).block_until_ready()

7.02 ms ± 2.17 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
# Time the jitted version with 100 loops per run. g(x) was already evaluated once
# when g was created, so compilation should not be part of the timed loops.
# NOTE(review): much of the speed-up presumably comes from XLA fusing the ten steps
# and only computing the returned [:100, :100] slice -- confirm before quoting the ratio.
%timeit -n 100 g(x).block_until_ready()

33.9 µs ± 1.48 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
from jax import grad

# Third derivative of tanh evaluated at 1.0, with jit interleaved between the
# grad calls to show that the two transformations compose.
grad(jit(grad(jit(grad(jnp.tanh)))))(1.0)

DeviceArray(0.6216267, dtype=float32)

In [9]:
%timeit grad(jit(grad(jit(grad(jnp.tanh)))))(1.0).block_until_ready()

76 ms ± 208 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


See also: [JAX NeurIPS 2020 demo notebook](https://github.com/google/jax/blob/main/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb)