In [1]:
import jax.numpy as jnp
import numpy as np
from jax import device_put
from jax import grad, jit, vmap
from jax import random

In [2]:
# Root PRNG key; the fixed seed makes every draw derived from it reproducible.
key = random.PRNGKey(0)
x = random.normal(key, (10,))
# NOTE(review): the recorded run of this cell failed with a cuDNN initialization
# error whose log reports only ~6 MB of ~8.5 GB GPU memory free. Presumably the
# GPU was held by another process (e.g. another JAX kernel preallocating memory).
# This looks like an environment problem, not a bug in this cell -- re-run on a
# free GPU before sharing the notebook.
x

2023-09-16 21:31:49.231042: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-09-16 21:31:49.231100: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 6422528 bytes free, 8514043904 bytes total.
2023-09-16 21:31:49.231143: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:453] Possibly insufficient driver version: 510.85.2


XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

In [3]:
# Side length of the square matrices used by the matmul benchmarks below.
size = 3000
# Derive a fresh key instead of reusing `key` (JAX PRNG keys must not be reused).
# fold_in is deterministic, so re-running this cell yields the same matrix.
matrix_key = random.fold_in(key, 1)
x = random.normal(matrix_key, (size, size), dtype=jnp.float32)

In [4]:
%timeit jnp.dot(x, x.T).block_until_ready()

10.1 ms ± 4.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
# Host-side (NumPy) float32 matrix of the same size. jnp.dot must copy it to
# the device on every call in the next cell.
# A seeded Generator keeps the data reproducible; sampling float32 directly
# avoids materializing a temporary float64 array.
x = np.random.default_rng(0).standard_normal((size, size), dtype=np.float32)

In [6]:
%timeit jnp.dot(x, x.T).block_until_ready()

26.2 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

8.89 ms ± 77.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Returns ``lmbda * x`` where ``x > 0`` and ``lmbda * alpha * (exp(x) - 1)``
    elsewhere.

    Args:
        x: Array-like input.
        alpha: Scale of the exponential (non-positive) branch.
        lmbda: Overall output scale.

    Returns:
        A JAX array with the same shape as ``x``.

    NOTE(review): the defaults look like rounded forms of the SELU paper's
    constants (alpha ~= 1.6733, lambda ~= 1.0507); kept as-is for compatibility.
    """
    # jnp.where evaluates BOTH branches. For large positive x, exp(x) overflows
    # to inf in the unused branch, and reverse-mode AD then produces 0 * inf = NaN.
    # Feed the exponential branch 0 wherever it is not selected.
    safe_x = jnp.where(x > 0, 0.0, x)
    # expm1 avoids the cancellation in exp(x) - 1 for x close to 0.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))

In [9]:
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

431 µs ± 48.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

110 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
