<a href="https://colab.research.google.com/github/shaunaknn/cfd-tutorials/blob/main/Jax_intro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap
from jax import random

In [2]:
np.zeros(10)

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [3]:
jnp.zeros(10)

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [4]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

[ 1.6226422   2.0252647  -0.43359444 -0.07861735  0.1760909  -0.97208923
 -0.49529874  0.4943786   0.6643493  -0.9501635 ]


In [5]:
size = 3000
x = random.normal(key, (size,size), dtype = jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

1.23 s ± 125 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
y = np.random.rand(size, size)
%timeit np.dot(y, y.T)

1.21 s ± 309 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
from jax._src.api import block_until_ready
x = np.random.normal(size = (size,size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

1.01 s ± 207 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
from jax import device_put

x = np.random.normal(size = (size,size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

915 ms ± 20.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
@jit
def selu(x, alpha = 1.67, lmbda = 1.05):
  return lmbda * jnp.where(x>0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (10000000,))
%timeit selu(x).block_until_ready()

24 ms ± 835 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

33.6 ms ± 8.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
@jit
def selu(x, alpha = 1.67, lmbda = 1.05):
  return lmbda * jnp.where(x>0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (10000000,))
%timeit selu(x).block_until_ready()

41.4 ms ± 2.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
from jax import config
config.update("jax_enable_x64", True)

from numpy.matrixlib import defmatrix

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(100.)
y_small = jnp.linspace(0,3,100)

derivative_fn = grad(sum_logistic)  # define a function which is the gradient of sum_logistic
%timeit derivative_fn(x_small)
print(derivative_fn(x_small))

3.47 ms ± 268 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
[2.50000000e-01 1.96611933e-01 1.04993585e-01 4.51766597e-02
 1.76627062e-02 6.64805667e-03 2.46650929e-03 9.10221180e-04
 3.35237671e-04 1.23379350e-04 4.53958077e-05 1.67011429e-05
 6.14413685e-06 2.26031919e-06 8.31527336e-07 3.05902133e-07
 1.12535149e-07 4.13993738e-08 1.52299793e-08 5.60279637e-09
 2.06115361e-09 7.58256042e-10 2.78946809e-10 1.02618796e-10
 3.77513454e-11 1.38879439e-11 5.10908903e-12 1.87952882e-12
 6.91440011e-13 2.54366565e-13 9.35762297e-14 3.44247711e-14
 1.26641655e-14 4.65888615e-15 1.71390843e-15 6.30511676e-16
 2.31952283e-16 8.53304763e-17 3.13913279e-17 1.15482242e-17
 4.24835426e-18 1.56288219e-18 5.74952226e-19 2.11513104e-19
 7.78113224e-20 2.86251858e-20 1.05306174e-20 3.87399763e-21
 1.42516408e-21 5.24288566e-22 1.92874985e-22 7.09547416e-23
 2.61027907e-23 9.60268005e-24 3.53262857e-24 1.29958143e-24
 4.78089288e-25 1.75879220e-25 6.47023493e-26 2.38026641e-26
 8.75651076e-27 3

In [13]:
def first_finite_differences(f, x):
  eps = 1e-6
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

%timeit first_finite_differences(sum_logistic, x_small)
print(first_finite_differences(sum_logistic, x_small))
# x = jnp.arange(10)
# for v in jnp.eye(len(x)):
#   print(v)
# jnp.array([v for v in jnp.eye(len(x))])


32.4 ms ± 815 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
[2.49999999e-01 1.96611936e-01 1.04993582e-01 4.51766624e-02
 1.76626997e-02 6.64805810e-03 2.46650700e-03 9.10219455e-04
 3.35241168e-04 1.23378641e-04 4.53965754e-05 1.66977543e-05
 6.14619466e-06 2.25952590e-06 8.31335001e-07 3.05533376e-07
 1.13686838e-07 4.26325641e-08 1.42108547e-08 7.10542736e-09
 0.00000000e+00 7.10542736e-09 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0

In [14]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.03532558051623558


In [15]:
from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

In [16]:
mat = random.normal(key, (150,100))
batched_x = random.normal(key, (10,100))

def apply_matrix(v):
  return jnp.dot(mat, v)

A = random.normal(key, (4,4))
x = jnp.arange(4)
y = jnp.arange(4)
z = jnp.stack([x,y],1) # if axis = 0, jnp.dot(A,z) will have an error
print(x)
print(A)
print(jnp.dot(A,x))  # jnp.dot is matrix multiplication. If x is a vector, it represents Ax
print(A*x)           # A*x: A = [a1,a2,a3,a4] A*x = [a1*x1, a2*x2, a3*x3, a4*x4]

print(z)
print(jnp.dot(A,z))
print(A @ z)         # A @ z is matrix multiplication

[0 1 2 3]
[[-0.20584214 -0.78476578  1.81608667  0.18784401]
 [ 0.08086788 -0.37211079  1.19016372  0.33864229]
 [ 0.08482584 -0.87181784  1.05451609 -1.5594979 ]
 [ 0.36753958  2.51635215  0.25856516 -0.28371043]]
[ 3.41093961  3.02414353 -3.44127934  2.18235118]
[[-0.         -0.78476578  3.63217335  0.56353204]
 [ 0.         -0.37211079  2.38032745  1.01592687]
 [ 0.         -0.87181784  2.10903219 -4.6784937 ]
 [ 0.          2.51635215  0.51713031 -0.85113128]]
[[0 0]
 [1 1]
 [2 2]
 [3 3]]
[[ 3.41093961  3.41093961]
 [ 3.02414353  3.02414353]
 [-3.44127934 -3.44127934]
 [ 2.18235118  2.18235118]]
[[ 3.41093961  3.41093961]
 [ 3.02414353  3.02414353]
 [-3.44127934 -3.44127934]
 [ 2.18235118  2.18235118]]


In [17]:
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
  return jnp.dot(mat, v)

#apply_matrix(batched_x) #result in an error

In [18]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched], 0)

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
1.39 ms ± 397 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
75.2 µs ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [20]:
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
91.8 µs ± 5.54 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
