# Real-valued functions (scalar functions)

Consider a real-valued function of several real variables, i.e. a scalar function $f: \mathbb{R}^n \rightarrow \mathbb{R}$. As a concrete example, take $n=2$ and
$$
f(x_1, x_2) = \frac{1}{2}x_1^2 + x_2.
$$
Its partial derivatives are
$$
\frac{\partial}{\partial x_1}f(x_1, x_2) = x_1, \quad
\frac{\partial}{\partial x_2}f(x_1, x_2) = 1,
$$
so the gradient is $\nabla f(x_1, x_2) = (x_1, 1)$. The cells below compute this gradient with JAX and compare it against the analytic expression.

In [18]:
from jax import numpy as jnp, random
from jax import grad, vmap

In [19]:
# PRNG key reused by the cells below to draw random inputs
key = random.PRNGKey(0)

# problem sizes: input dimension n and number of batched samples
n = 2
batch_size = 10**6

In [20]:
def check_gradient(x, grad_y):
    """Return True if ``grad_y`` matches the analytic gradient of f at ``x``.

    For f(x1, x2) = x1**2 / 2 + x2 the gradient is (x1, 1), so the reference
    is built from ``x[0]`` and the constant 1.

    Args:
        x: point at which the gradient was evaluated; only ``x[0]`` is read.
        grad_y: gradient to check, compared element-wise with ``jnp.isclose``.

    Returns:
        A plain Python ``bool``.
    """
    grad_y_ref = jnp.array([x[0], 1.0])
    # `.all()` reduces to a 0-d array; `bool(...)` turns it into a Python bool.
    return bool(jnp.isclose(grad_y, grad_y_ref).all())

1. Vector input variable (no batch)

In [21]:
def f(x):
    """Evaluate x[0]**2 / 2 + x[1] using the first two entries of ``x``."""
    head, tail = x[0], x[1]
    return 0.5 * head**2 + tail

def grad_f(x):
    """Return the gradient of ``f`` at ``x``, computed with ``jax.grad``."""
    df = grad(f)
    return df(x)

In [22]:
# draw one random point in R^n
x = random.normal(key, shape=(n,))

# gradient of f at that point
grad_y = grad_f(x)

# compare against the analytic gradient (x_1, 1)
check_gradient(x, grad_y)

True

2. Multiple scalar input variables (no batch)

In [23]:
def g(x1, x2):
    """Evaluate x1**2 / 2 + x2 with the two inputs passed separately."""
    half_square = 0.5 * x1**2
    return half_square + x2

def grad_g(x1, x2):
    """Return the tuple (dg/dx1, dg/dx2) evaluated at (x1, x2)."""
    both_partials = grad(g, argnums=(0, 1))
    return both_partials(x1, x2)

In [24]:
# initialize input: split the key so x1 and x2 are independent draws
# (calling random.normal twice with the same key returns the same value)
key_x1, key_x2 = random.split(key)
x1 = random.normal(key_x1)
x2 = random.normal(key_x2)

# compute gradients
grad_y = jnp.array(grad_g(x1, x2))

# check gradient
check_gradient(jnp.array([x1, x2]), grad_y)

True

3. Vector input variable (batch input)

In [25]:
def check_gradient_vect(x, grad_y):
    """Return True if a batch of gradients matches the analytic gradient of f.

    Row ``i`` of the reference is ``(x[i, 0], 1)``, the gradient of
    f(x1, x2) = x1**2 / 2 + x2 at the point ``x[i]``.

    Args:
        x: batch of points, one per row (read as ``x[:, 0]``).
        grad_y: batch of gradients; expected shape ``(x.shape[0], 2)``.

    Returns:
        A plain Python ``bool``.
    """
    first = x[:, 0]
    # Size the ones column from the batch itself rather than the global
    # ``batch_size``, so batches of any length can be checked.
    grad_y_ref = jnp.stack((first, jnp.ones_like(first)), axis=1)
    return bool(jnp.isclose(grad_y, grad_y_ref).all())

In [26]:
def f_vect(inputs):
    """Evaluate ``f`` over the leading axis of ``inputs`` (one point per row)."""
    # vmap's defaults (in_axes=0, out_axes=0) map over the leading axis
    batched_f = vmap(f)
    return batched_f(inputs)

def grad_f_vect(inputs):
    """Evaluate ``grad_f`` over the leading axis of ``inputs``."""
    batched_grad = vmap(grad_f)
    return batched_grad(inputs)

In [27]:
# batch_size random points in R^n, one per row
x = random.normal(key, shape=(batch_size, n))

In [28]:
# compute gradients
%time grad_y = grad_f_vect(x)

CPU times: user 21.2 ms, sys: 4.58 ms, total: 25.8 ms
Wall time: 13.9 ms


In [29]:
# every row of grad_y should equal (x[i, 0], 1)
check_gradient_vect(x=x, grad_y=grad_y)

True

4. Multiple scalar input variables (batch input)

In [30]:
def g_vect(x1, x2):
    """Evaluate ``g`` element-wise over the leading axis of ``x1`` and ``x2``."""
    # a bare vmap maps axis 0 of every positional argument and stacks on axis 0
    batched_g = vmap(g)
    return batched_g(x1, x2)

def grad_g_vect(x1, x2):
    """Batched ``grad_g``: returns (dg/dx1, dg/dx2), each with one entry per sample."""
    batched_grad = vmap(grad_g)
    return batched_grad(x1, x2)

In [31]:
x = random.normal(key, shape=(batch_size, n))
# split the columns into the two per-argument batches
x1, x2 = x[:, 0], x[:, 1]

In [32]:
%time grad_y = jnp.vstack((grad_g_vect(x1, x2))).T

CPU times: user 16.3 ms, sys: 0 ns, total: 16.3 ms
Wall time: 9.71 ms


In [33]:
# every row of grad_y should equal (x[i, 0], 1)
check_gradient_vect(x=x, grad_y=grad_y)

True