# Real functions
Consider a real-valued function of a real variable $f: \mathbb{R} \rightarrow \mathbb{R}$ given by
$$
f(x) = \frac{1}{2}x^2.
$$
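Its derivative is
$$
\frac{d}{d x}f(x) = x,
$$
and its higher derivatives are $\frac{d^2}{dx^2}f(x) = 1$ and $\frac{d^k}{dx^k}f(x) = 0$ for $k \geq 3$. The cells below compute these with `jax.grad` and check them against these analytic values.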

In [2]:
import jax.numpy as jnp
from jax import grad, vmap, random

# Fixed seed: every "Restart & Run All" reproduces the outputs shown below.
SEED = 0
key = random.PRNGKey(SEED)

In [3]:
# parameters: number of samples evaluated in the batched (vmap) sections
batch_size = 1_000_000

1. No batch input (a single scalar $x$)

In [4]:
def f(x):
    """Return f(x) = x**2 / 2 (applied elementwise to array inputs)."""
    return 0.5 * x**2

def der1_f(x):
    """Return the first derivative f'(x) computed by ``jax.grad``.

    Analytically f'(x) = x. ``grad`` needs a scalar-valued function, so wrap
    this in ``vmap`` to evaluate it over a batch of inputs.
    """
    df_dx = grad(f)
    return df_dx(x)

In [5]:
# initialize a single scalar input
x = random.normal(key)

# compute the gradient with autodiff
grad_y = der1_f(x)

# sanity check against the analytical derivative f'(x) = x
assert jnp.allclose(grad_y, x), "autodiff gradient disagrees with f'(x) = x"

x, grad_y

(Array(-0.20584226, dtype=float32), Array(-0.20584226, dtype=float32))

2. Batch input (vectorized with `vmap`)

In [6]:
def f_vect(inputs):
    """Apply ``f`` independently to each entry along axis 0 of ``inputs``."""
    batched_f = vmap(f)  # in_axes=0 / out_axes=0 are vmap's defaults
    return batched_f(inputs)

def der1_f_vect(inputs):
    """Apply ``der1_f`` independently to each entry along axis 0 of ``inputs``."""
    batched_der1 = vmap(der1_f)  # map over the leading axis, stack results on axis 0
    return batched_der1(inputs)

In [7]:
x = random.normal(key, shape=(batch_size,))  # batch_size standard-normal samples

In [41]:
%time

# compute gradients
grad_y = der1_f_vect(x)

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 4.05 µs


In [42]:
# every batched gradient should equal its input, since f'(x) = x
assert jnp.allclose(grad_y, x), "batched autodiff gradient disagrees with f'(x) = x"
x, grad_y

(Array([ 1.99376   ,  0.20781846, -0.34406224, ...,  0.03467206,
         0.7103182 ,  0.1965587 ], dtype=float32),
 Array([ 1.99376   ,  0.20781846, -0.34406224, ...,  0.03467206,
         0.7103182 ,  0.1965587 ], dtype=float32))

3. $k$-th derivatives

In [43]:
def der2_f(x):
    """Return the second derivative f''(x) via nested ``jax.grad`` (analytically 1)."""
    return grad(grad(f))(x)

def der3_f(x):
    """Return the third derivative f'''(x) via triply nested ``jax.grad`` (analytically 0)."""
    return grad(grad(grad(f)))(x)

In [44]:
# initialize input
x = random.normal(key)

# input, first, second and third derivative
#x, grad(f)(x), grad(grad(f))(x), grad(grad(grad((f))))(x)
x, der1_f(x), der2_f(x), der3_f(x)

(Array(-0.20584226, dtype=float32),
 Array(-0.20584226, dtype=float32),
 Array(1., dtype=float32),
 Array(0., dtype=float32))

4. $k$-th derivatives (with batch input)

In [45]:
def der2_f_vect(inputs):
    """Apply ``der2_f`` independently to each entry along axis 0 of ``inputs``."""
    batched_der2 = vmap(der2_f)  # vmap defaults: in_axes=0, out_axes=0
    return batched_der2(inputs)

In [46]:
x = random.normal(key, shape=(batch_size,))  # restore the batch input (section 3 rebound x to a scalar)

In [47]:
der1_y = der1_f_vect(x)
der2_y = der2_f_vect(x)

# analytical values for every sample: f'(x) = x and f''(x) = 1
assert jnp.allclose(der1_y, x)
assert jnp.allclose(der2_y, 1.0)

x, der1_y, der2_y

(Array([ 1.99376   ,  0.20781846, -0.34406224, ...,  0.03467206,
         0.7103182 ,  0.1965587 ], dtype=float32),
 Array([ 1.99376   ,  0.20781846, -0.34406224, ...,  0.03467206,
         0.7103182 ,  0.1965587 ], dtype=float32),
 Array([1., 1., 1., ..., 1., 1., 1.], dtype=float32))