# Real functions
Consider a real-valued function of a real variable $f: \mathbb{R} \rightarrow \mathbb{R}$ given by
$$
f(x) = \frac{1}{2}{x}^2. \quad
$$
Its derivative is
$$
\frac{d}{d x}f(x) = x. 
$$

In [115]:
import jax.numpy as jnp
from jax import grad
from jax import vmap
from jax import random

key = random.PRNGKey(0)

In [116]:
# parameters
batch_size = 3#int(1e5)

1. no batch input

In [117]:
def f(x):
    return 0.5 * x**2

def der1_f(x):
    return grad(f)(x)

In [118]:
# initialize input
x = random.normal(key)

# compute gradients
grad_y = grad_f(x)

x, grad_y

(Array(-0.20584226, dtype=float32), Array(-0.20584226, dtype=float32))

2. Batch input

In [119]:
def f_vect(inputs):
    return vmap(f)(inputs)
    
def der1_f_vect(inputs):
    return vmap(der1_f)(inputs)

In [120]:
x = random.normal(key, (batch_size,))

In [121]:
%time

# compute gradients
grad_y = grad_f_vect(x)

CPU times: user 1e+03 ns, sys: 0 ns, total: 1e+03 ns
Wall time: 2.86 µs


In [122]:
x, grad_y

(Array([ 1.8160863 , -0.48262316,  0.33988908], dtype=float32),
 Array([ 1.8160863 , -0.48262316,  0.33988908], dtype=float32))

3. k-th derivatives

In [123]:
def der2_f(x):
    #return grad(der1_f)(x)
    return grad(grad(f))(x)

def der3_f(x):
    #return grad(der2_f)(x)
    return grad(grad(grad(f)))(x)

In [124]:
# initialize input
x = random.normal(key)

# input, first, second and third derivative
#x, grad(f)(x), grad(grad(f))(x), grad(grad(grad((f))))(x)
x, der1_f(x), der2_f(x), der3_f(x)

(Array(-0.20584226, dtype=float32),
 Array(-0.20584226, dtype=float32),
 Array(1., dtype=float32),
 Array(0., dtype=float32))

4. k-th derivatives (with batch input)

In [125]:
def der2_f_vect(inputs):
    return vmap(der2_f(inputs))

In [126]:
x = random.normal(key, (batch_size,))

In [127]:
x, der1_f_vect(x), der2_f_vect(x) 

TypeError: Gradient only defined for scalar-output functions. Output had shape: (3,).

In [98]:
x, vmap(grad(f)(x))#, vmap(grad(grad(f)))(x)

TypeError: Gradient only defined for scalar-output functions. Output had shape: (3,).