In [1]:
from jax import numpy as jnp, random
from jax import jacfwd, jacrev, jvp, vjp, vmap

# Vector-valued function

Let $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$ be the vector-valued function given by
$$
f(x) = (f_1(x), f_2(x)), \quad f_1(x) = 3 x_1^3 - x_2^2 + x_3, \quad f_2(x) = \sin(x_1)\,x_2,
$$
i.e. $n=3$ and $m=2$. The Jacobian $D_f \in \mathbb{R}^{m \times n}$ given by $D_f = \bigl(\frac{\partial}{\partial x_j}f_i(x) \bigr)_{i, j}$ has the entries
$$
\frac{\partial}{\partial x_1}f_1(x) = 9 x_1^2, \quad 
\frac{\partial}{\partial x_2}f_1(x) = - 2 x_2, \quad
\frac{\partial}{\partial x_3}f_1(x) = 1, \quad
\frac{\partial}{\partial x_1}f_2(x) = \cos(x_1)\,x_2, \quad 
\frac{\partial}{\partial x_2}f_2(x) = \sin(x_1), \quad
\frac{\partial}{\partial x_3}f_2(x) = 0 .
$$

In [2]:
# single PRNG key from which all random inputs below are drawn
key = random.PRNGKey(0)

# input dimension n, output dimension m, and number of samples per batch
n, m = 3, 2
batch_size = 10_000

## Compute Jacobians

In [3]:
def check_jacobian(x, jac_y, rtol=1e-05, atol=1e-08):
    """Compare a computed Jacobian of f at a single point with the analytic one.

    Args:
        x: the input point; only x[0] and x[1] enter the reference Jacobian.
        jac_y: candidate Jacobian, expected to have shape (2, 3).
        rtol, atol: tolerances forwarded to jnp.isclose (defaults match jnp.isclose).

    Returns:
        True iff jac_y has the reference shape and every entry is close to it.
    """
    jac_y_ref = jnp.array([
        [9 * x[0]**2, -2 * x[1], 1.0],
        [jnp.cos(x[0]) * x[1], jnp.sin(x[0]), 0.0],
    ])
    # jnp.isclose broadcasts: without this guard a transposed jac_y raises and a
    # jac_y with an extra leading axis of size 1 would silently compare equal
    if jnp.shape(jac_y) != jac_y_ref.shape:
        return False
    return bool(jnp.isclose(jac_y, jac_y_ref, rtol=rtol, atol=atol).all())

1. Vector input variable (no batch)

In [4]:
def f(x):
    """Evaluate f(x) = (3 x_1^3 - x_2^2 + x_3, sin(x_1) x_2) on a vector x."""
    x_a, x_b, x_c = x[0], x[1], x[2]
    first = 3 * x_a**3 - x_b**2 + x_c
    second = jnp.sin(x_a) * x_b
    return jnp.array([first, second])

In [5]:
# initialize input
x = random.normal(key, (n,))

# compute jacobian
#%time jac_y = jacfwd(f)(x)
%time jac_y = jacrev(f)(x)

# check jacobian 
check_jacobian(x, jac_y)

CPU times: user 277 ms, sys: 4.24 ms, total: 281 ms
Wall time: 280 ms


True

2. Multiple scalar input variables (no batch)

In [6]:
def g(x1, x2, x3):
    """Same map as f, but taking the three coordinates as separate scalars."""
    first = 3 * x1**3 - x2**2 + x3
    second = jnp.sin(x1) * x2
    return jnp.array([first, second])

def jac_g(x1, x2, x3):
    """Forward-mode Jacobian of g with respect to all three scalar arguments.

    Returns a tuple with one entry per argument (dg/dx1, dg/dx2, dg/dx3);
    each entry is one column of the Jacobian.
    """
    jac_fn = jacfwd(g, argnums=(0, 1, 2))
    return jac_fn(x1, x2, x3)

In [7]:
# draw three independent scalars; calling random.normal(key) three times with the
# same key would make x1 == x2 == x3, and the check below could then not detect
# swapped partial derivatives
key_x1, key_x2, key_x3 = random.split(key, 3)
x1 = random.normal(key_x1)
x2 = random.normal(key_x2)
x3 = random.normal(key_x3)

# jac_g returns one Jacobian column per argument; jnp.array stacks them as rows,
# so transpose to obtain the (m, n) Jacobian
jac_y = jnp.array(jac_g(x1, x2, x3)).T

# check jacobian 
check_jacobian(jnp.array([x1, x2, x3]), jac_y)

True

In [8]:
def check_jacobian_vect(x, jac_y, rtol=1e-05, atol=1e-08):
    """Compare a batch of computed Jacobians of f with the analytic ones.

    Args:
        x: batch of input points; row b is one point, and only columns 0 and 1
            enter the reference Jacobian.
        jac_y: candidate Jacobians, expected shape (x.shape[0], 2, 3).
        rtol, atol: tolerances forwarded to jnp.isclose (defaults match jnp.isclose).

    Returns:
        True iff jac_y has the reference shape and every entry is close to it.
    """
    # derive the batch size from x instead of the global batch_size, so the
    # check works for any number of samples
    x0, x1 = x[:, 0], x[:, 1]
    ones = jnp.ones_like(x0)
    zeros = jnp.zeros_like(x0)
    row_f1 = jnp.stack([9 * x0**2, -2 * x1, ones], axis=-1)                # (batch, 3)
    row_f2 = jnp.stack([jnp.cos(x0) * x1, jnp.sin(x0), zeros], axis=-1)    # (batch, 3)
    jac_y_ref = jnp.stack([row_f1, row_f2], axis=1)                        # (batch, 2, 3)

    # jnp.isclose broadcasts, so reject mismatched shapes explicitly
    if jnp.shape(jac_y) != jac_y_ref.shape:
        return False
    return bool(jnp.isclose(jac_y, jac_y_ref, rtol=rtol, atol=atol).all())

3. Vector input variable (batch)

In [9]:
# draw batch_size input points (one per row) and evaluate f on each of them
x = random.normal(key, shape=(batch_size, n))
y = vmap(f, in_axes=0)(x)

In [10]:
# compute gradients
#%time jac_y = vmap(jacfwd(f))(x)
#%time jac_y = jacfwd(vmap(f))(x)
%time jac_y = vmap(jacrev(f))(x)
#%time jac_y = jacrev(vmap(f))(x)

CPU times: user 164 ms, sys: 0 ns, total: 164 ms
Wall time: 163 ms


In [11]:
check_jacobian_vect(x=x, jac_y=jac_y)

True

4. Multiple scalar input variables (batch input)

In [12]:
def g_vect(x1, x2, x3):
    """Apply g element-wise over the leading axis of x1, x2 and x3."""
    batched_g = vmap(g)  # in_axes=0, out_axes=0 are vmap's defaults
    return batched_g(x1, x2, x3)

def jac_g_vect(x1, x2, x3):
    """Apply jac_g over the leading axis of x1, x2 and x3.

    Returns the tuple produced by jac_g, with a leading batch axis added to each entry.
    """
    batched_jac = vmap(jac_g)  # in_axes=0, out_axes=0 are vmap's defaults
    return batched_jac(x1, x2, x3)

In [13]:
x = random.normal(key, shape=(batch_size, n))
# split the batch column-wise into three length-batch_size arrays
x1, x2, x3 = (x[:, col] for col in range(3))

In [14]:
%time jac_y = jnp.stack((jac_g_vect(x1, x2, x3))).swapaxes(0, 1).swapaxes(1, 2)

CPU times: user 153 ms, sys: 8.31 ms, total: 161 ms
Wall time: 160 ms


In [15]:
check_jacobian_vect(x=x, jac_y=jac_y)

True

# Compute Jacobian-Vector Products (JVP)

1. Without batch

In [16]:
# draw one input point x and one tangent direction v from independent subkeys
key1, key2 = random.split(key, num=2)
x = random.normal(key1, shape=(n,))
v = random.normal(key2, shape=(n,))

# jvp evaluates f(x) and the directional derivative D_f(x) v in one pass
primals, tangents = jvp(f, (x,), (v,))

# `primals` should coincide with f(x)
x, v, f(x), primals, tangents

(Array([ 0.13893168,  0.509335  , -0.53116107], dtype=float32),
 Array([ 1.1378784 , -1.2209548 , -0.59153634], dtype=float32),
 Array([-0.7825382 ,  0.07053534], dtype=float32),
 Array([-0.7825382 ,  0.07053534], dtype=float32),
 Array([0.84988374, 0.4048928 ], dtype=float32))

2. With batch

In [17]:
# one input point and one tangent direction per row
key1, key2 = random.split(key, num=2)
x = random.normal(key1, shape=(batch_size, n))
v = random.normal(key2, shape=(batch_size, n))

# map the per-sample JVP over the rows of x and v; f is closed over rather than
# passed through vmap as an unmapped argument
primals, tangents = vmap(lambda x_i, v_i: jvp(f, (x_i,), (v_i,)))(x, v)

primals.shape, tangents.shape

((10000, 2), (10000, 2))

# Compute Vector-Jacobian Products (VJP)

1. Without batch

In [18]:
# draw an input point x and a covector u (one entry per output of f) from independent subkeys
key1, key2 = random.split(key, num=2)
x = random.normal(key1, shape=(n,))
u = random.normal(key2, shape=(m,))

# vjp returns f(x) together with the pullback function at x
y, vjp_fun = vjp(f, x)

# the pullback returns a tuple with one entry per primal argument of f (here only x)
v = vjp_fun(u)

x, u, f(x), y, v

(Array([ 0.13893168,  0.509335  , -0.53116107], dtype=float32),
 Array([ 0.19307722, -0.52678293], dtype=float32),
 Array([-0.7825382 ,  0.07053534], dtype=float32),
 Array([-0.7825382 ,  0.07053534], dtype=float32),
 (Array([-0.23218267, -0.2696336 ,  0.19307722], dtype=float32),))

2. With batch

In [19]:
# initialize input: a batch of points and one covector per point
key1, key2 = random.split(key, 2)
x = random.normal(key1, (batch_size, n))
u = random.normal(key2, (batch_size, m))


def vjp_single(x_i, u_i):
    """Evaluate f at x_i and pull back u_i; returns (f(x_i), u_i^T D_f(x_i))."""
    y_i, pullback = vjp(f, x_i)
    (v_i,) = pullback(u_i)  # one cotangent per primal argument of f
    return y_i, v_i


# Run the forward and backward pass inside the same vmap, so each pullback only
# ever sees its own sample's residuals. The alternative (returning the pullback
# closure from one vmap and mapping it in a second vmap) is easy to get wrong.
y, v = vmap(vjp_single)(x, u)

# sanity check against explicit per-sample Jacobians: v_i = D_f(x_i)^T u_i
assert jnp.allclose(v, jnp.einsum('bmn,bm->bn', vmap(jacrev(f))(x), u), atol=1e-5)

y.shape[1] == m, y.shape, v.shape

(True, (10000, 2), (10000, 3))