# Vector-valued function

Let $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$ be the vector-valued function given by
$$
f(x) = (f_1(x), f_2(x)), \quad f_1(x) = 3 x_1^3 - x_2^2 + x_3, \quad f_2(x) = \sin(x_1)\, x_2 ,
$$
i.e. $n=3$ and $m=2$. The Jacobian $D_f \in \mathbb{R}^{m \times n}$, given by $D_f = \bigl(\frac{\partial}{\partial x_j}f_i(x) \bigr)_{i, j}$, has the entries
$$
\frac{\partial}{\partial x_1}f_1(x) = 9 x_1^2, \quad 
\frac{\partial}{\partial x_2}f_1(x) = - 2 x_2, \quad
\frac{\partial}{\partial x_3}f_1(x) = 1, \quad
\frac{\partial}{\partial x_1}f_2(x) = \cos(x_1)\, x_2, \quad 
\frac{\partial}{\partial x_2}f_2(x) = \sin(x_1), \quad
\frac{\partial}{\partial x_3}f_2(x) = 0 .
$$

In [114]:
import jax.numpy as jnp
from jax import jacfwd, jacrev 
from jax import vmap
from jax import random

# Single fixed PRNG key (seed 0) reused by the input-generation cells below.
# JAX keys are not advanced by sampling: every random.normal(key, ...) call
# with this same key and shape returns the same numbers.
key = random.PRNGKey(0)

In [186]:
# domain, codomain dimension and batch size
# Problem sizes: f maps R^n -> R^m and is evaluated on batches of points.
n, m = 3, 2  # input (domain) and output (codomain) dimensions
batch_size = 1_000  # number of sample points in the batched experiments

In [116]:
def check_jacobian(x, jac_y, rtol=1e-05, atol=1e-08):
    """Check one Jacobian of ``f`` at the point ``x`` against the analytic formula.

    Args:
        x: indexable input point; entries ``x[0]``, ``x[1]`` are used.
        jac_y: candidate Jacobian, expected shape (2, 3).
        rtol, atol: tolerances forwarded to ``jnp.allclose`` (defaults match
            ``jnp.isclose``).

    Returns:
        bool: True iff ``jac_y`` has the reference shape (2, 3) and matches it
        element-wise within tolerance.
    """
    jac_y = jnp.asarray(jac_y)
    jac_y_ref = jnp.array([
        [9 * x[0]**2, -2 * x[1], 1.0],
        [jnp.cos(x[0]) * x[1], jnp.sin(x[0]), 0.0],
    ])
    # Compare shapes first: allclose broadcasts, so e.g. a stack of several
    # (2, 3) matrices would otherwise be accepted as a single match.
    if jac_y.shape != jac_y_ref.shape:
        return False
    return bool(jnp.allclose(jac_y, jac_y_ref, rtol=rtol, atol=atol))

1. Vector input variable (no batch)

In [117]:
def f(x):
    """Evaluate f(x) = (3 x1^3 - x2^2 + x3, sin(x1) x2) at a single point.

    Args:
        x: indexable input (e.g. a 1-D array); only x[0], x[1], x[2] are read.

    Returns:
        jnp.ndarray with the two output components.
    """
    x1, x2, x3 = x[0], x[1], x[2]
    first = 3 * x1**3 - x2**2 + x3
    second = jnp.sin(x1) * x2
    return jnp.array([first, second])
    
def jac_f(x, mode="fwd"):
    """Return the Jacobian of ``f`` at the single point ``x``.

    Args:
        x: input point of shape (n,).
        mode: ``"fwd"`` (default) uses forward-mode ``jacfwd``; ``"rev"`` uses
            reverse-mode ``jacrev``. Both produce the same matrix; forward mode
            is usually cheaper for tall Jacobians, reverse mode for wide ones.

    Returns:
        Array of shape (m, n) with entries d f_i / d x_j.

    Raises:
        ValueError: if ``mode`` is neither ``"fwd"`` nor ``"rev"``.
    """
    if mode == "fwd":
        return jacfwd(f)(x)
    if mode == "rev":
        return jacrev(f)(x)
    raise ValueError(f"mode must be 'fwd' or 'rev', got {mode!r}")

In [118]:
# initialize input
x = random.normal(key, (n,))

# compute jacobian
jac_y = jac_f(x)

# check jacobian 
check_jacobian(x, jac_y)

True

2. Multiple scalar input variables (no batch)

In [119]:
def g(x1, x2, x3):
    """Evaluate the same map as ``f``, with the coordinates as separate arguments.

    Returns:
        jnp.ndarray holding (3 x1^3 - x2^2 + x3, sin(x1) x2).
    """
    components = (
        3 * x1**3 - x2**2 + x3,
        jnp.sin(x1) * x2,
    )
    return jnp.array(components)

def jac_g(x1, x2, x3):
    """Differentiate ``g`` with respect to each of its three arguments.

    Returns a tuple ``(dg/dx1, dg/dx2, dg/dx3)``; for scalar arguments each
    entry has the shape of ``g``'s output, i.e. it is one Jacobian column.
    """
    differentiate_all = jacfwd(g, argnums=(0, 1, 2))
    return differentiate_all(x1, x2, x3)

In [120]:
# initialize input: draw the three scalars from ONE call so they differ.
# JAX's PRNG is stateless, so three random.normal(key) calls with the same
# key would return the same value three times (x1 == x2 == x3), which would
# hide any mix-up between the partial derivatives w.r.t. x1, x2 and x3.
x1, x2, x3 = random.normal(key, (3,))

# compute jacobian: jac_g returns one (m,) array per argument (the columns),
# so stacking gives (n, m) and the transpose gives the (m, n) Jacobian
jac_y = jnp.array(jac_g(x1, x2, x3)).T

# check jacobian against the analytic formula (the notebook displays the bool)
check_jacobian(jnp.array([x1, x2, x3]), jac_y)

True

In [177]:
def check_jacobian_vect(x, jac_y, rtol=1e-05, atol=1e-08):
    """Check a batch of Jacobians of ``f`` against the analytic formula.

    The reference shape is derived from ``x`` itself, ``(batch, 2, d)`` with
    ``(batch, d) = x.shape``, so any batch size works. Columns beyond the
    third stay zero because ``f`` only depends on the first three coordinates.

    Args:
        x: batch of input points, shape (batch, d) with d >= 3.
        jac_y: candidate Jacobians, expected shape (batch, 2, d).
        rtol, atol: tolerances forwarded to ``jnp.allclose`` (defaults match
            ``jnp.isclose``).

    Returns:
        bool: True iff ``jac_y`` has the reference shape and matches it
        element-wise within tolerance.
    """
    x = jnp.asarray(x)
    jac_y = jnp.asarray(jac_y)
    batch, dim = x.shape
    x0, x1 = x[:, 0], x[:, 1]

    # Start from zeros so every entry not set below (d f_2 / d x_3 and any
    # columns past the third) is explicitly zero.
    jac_y_ref = jnp.zeros((batch, 2, dim))
    jac_y_ref = jac_y_ref.at[:, 0, 0].set(9 * x0**2)
    jac_y_ref = jac_y_ref.at[:, 0, 1].set(-2 * x1)
    jac_y_ref = jac_y_ref.at[:, 0, 2].set(1.0)
    jac_y_ref = jac_y_ref.at[:, 1, 0].set(jnp.cos(x0) * x1)
    jac_y_ref = jac_y_ref.at[:, 1, 1].set(jnp.sin(x0))

    # Compare shapes first: allclose broadcasts, so a mis-shaped input could
    # otherwise be silently accepted or raise an unrelated error.
    if jac_y.shape != jac_y_ref.shape:
        return False
    return bool(jnp.allclose(jac_y, jac_y_ref, rtol=rtol, atol=atol))

3. Vector input variable (batch)

In [178]:
def f_vect(inputs):
    """Evaluate ``f`` on every row of ``inputs`` (vectorised with ``vmap``)."""
    batched_f = vmap(f)
    return batched_f(inputs)
    
def jac_f_vect(inputs):
    """Return the per-point Jacobians of ``f`` for a batch of points.

    ``vmap`` maps ``jac_f`` over the leading axis of ``inputs``, so each
    batch element gets its own (m, n) Jacobian.
    """
    batched_jacobian = vmap(jac_f)
    return batched_jacobian(inputs)

def jac_f_vect2(inputs, per_sample=False):
    """Differentiate the batched map ``f_vect`` as one single function.

    Unlike ``jac_f_vect`` (which vmaps the per-point Jacobian), ``jacfwd`` is
    applied to the whole batched function. For inputs of shape (batch, n) the
    full Jacobian therefore has shape (batch, m, batch, n), which grows
    quadratically with the batch size. Entry [b, i, c, j] is zero whenever
    b != c, because row b of the output does not depend on row c of the input.

    Args:
        inputs: batch of points, shape (batch, n).
        per_sample: if True, collapse the cross-sample axis and return the
            per-point Jacobians of shape (batch, m, n), like ``jac_f_vect``.
            Defaults to False, which returns the full tensor.
    """
    full = jacfwd(f_vect)(inputs)
    if not per_sample:
        return full
    # Off-diagonal blocks are exactly zero, so summing over the second batch
    # axis extracts the diagonal blocks J[b, :, b, :].
    return full.sum(axis=2)

In [187]:
# batch of random input points, shape (batch_size, n)
x = random.normal(key, (batch_size, n))
# corresponding outputs of f for every row, shape (batch_size, m)
y = f_vect(x)

In [188]:
%time

# compute gradients
jac_y = jac_f_vect(x)

CPU times: user 2 µs, sys: 1 µs, total: 3 µs
Wall time: 5.25 µs


In [189]:
# compare the batched Jacobians with the analytic reference (displays a bool)
check_jacobian_vect(x, jac_y)

True

4. Multiple scalar input variables (batch input)

In [191]:
def g_vect(x1, x2, x3):
    """Evaluate ``g`` over the shared leading (batch) axis of x1, x2 and x3."""
    batched_g = vmap(g)
    return batched_g(x1, x2, x3)

def jac_g_vect(x1, x2, x3):
    """Apply ``jac_g`` to every batch element of x1, x2, x3.

    Returns a tuple of three arrays; entry k holds the derivative of ``g``
    w.r.t. its (k+1)-th argument for every batch element.
    """
    batched_jac_g = vmap(jac_g)
    return batched_jac_g(x1, x2, x3)

In [192]:
# batch of random points, then one 1-D array per coordinate column
x = random.normal(key, (batch_size, n))
x1, x2, x3 = x[:, 0], x[:, 1], x[:, 2]

In [193]:
%time
jac_y = jnp.stack((jac_g_vect(x1, x2, x3))).swapaxes(0, 1).swapaxes(1, 2)

CPU times: user 1e+03 ns, sys: 1e+03 ns, total: 2 µs
Wall time: 2.86 µs


In [194]:
# compare the stacked per-argument derivatives with the analytic reference
check_jacobian_vect(x, jac_y)

True