# Example of multivariate derivatives via univariate automatic differentiation

The general formula is given by Eqs. (13) and (17) in the paper of Andreas Griewank, Jean Utke, and Andrea Walther, "Evaluating higher derivative tensors by forward propagation of univariate Taylor series", *Math. Comp.* **69** (2000), 1117-1130, [https://doi.org/10.1090/S0025-5718-00-01120-0](https://doi.org/10.1090/S0025-5718-00-01120-0).

According to these formulas, a mixed partial derivative of total order $d=|\mathbf{i}|=i_1+i_2+\dots+i_n$ of a function $f(x_1,x_2,\dots,x_n)$ of $n$ variables can be computed as:

$$
\frac{\partial^{|\mathbf{i}|} f}{\partial x_1^{i_1}\partial x_2^{i_2}\cdots\partial x_n^{i_n}} = \frac{1}{d!}\sum_{|\mathbf{c}|=d}\frac{\partial^d f}{\partial t^d}\Big|_{\mathbf{c}}\cdot g_{\mathbf{i},\mathbf{c}}.
$$

Here, $\frac{\partial^d f}{\partial t^d}\Big|_{\mathbf{c}}$ is the $d$-th derivative with respect to $t$ of $f(x_1+tc_1,x_2+tc_2,\dots,x_n+tc_n)$ evaluated at $t=0$, i.e., the $d$-th order directional derivative of $f$ along the (unnormalized) vector $\mathbf{c}$. The direction vectors $\mathbf{c}$ run over all integer vectors with $c_i\in\{0,\dots,d\}$ ($i=1,\dots,n$) and $|\mathbf{c}|=\sum_i c_i=d$; there are $\binom{n+d-1}{d}$ of them.
For example, for $n=3$ and $d=2$ the $\binom{4}{2}=6$ direction vectors are $(2,0,0)$, $(0,2,0)$, $(0,0,2)$, $(1,1,0)$, $(1,0,1)$, $(0,1,1)$.
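
The direction set can also be enumerated without filtering all $(d+1)^n$ candidate tuples: each $\mathbf{c}$ with $|\mathbf{c}|=d$ corresponds to a multiset of $d$ variable indices, so `itertools.combinations_with_replacement` yields each direction exactly once, $\binom{n+d-1}{d}$ in total. The helper `directions` below is a hypothetical, standard-library-only sketch. It lists the directions in a different order than the filtered product used later, which does not matter as long as the coefficients $g_{\mathbf{i},\mathbf{c}}$ and the univariate derivatives use the same order.

```python
import itertools
import math
from collections import Counter


def directions(n: int, d: int) -> list[tuple[int, ...]]:
    """All c in {0, ..., d}^n with c_1 + ... + c_n = d (hypothetical helper)."""
    dirs = []
    # each multiset of d variable indices from {0, ..., n-1} gives one direction:
    # c_j counts how often index j occurs in the multiset
    for combo in itertools.combinations_with_replacement(range(n), d):
        counts = Counter(combo)
        dirs.append(tuple(counts[j] for j in range(n)))
    return dirs


c3 = directions(3, 2)
print(c3)  # [(2, 0, 0), (1, 1, 0), (1, 0, 1), (0, 2, 0), (0, 1, 1), (0, 0, 2)]
print(len(c3) == math.comb(3 + 2 - 1, 2))  # True
```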

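The coefficients $g_{\mathbf{i},\mathbf{c}}$ depend only on the multi-index $\mathbf{i}$ of the partial derivative and on the direction vector $\mathbf{c}$, and are computed as:

$$
g_{\mathbf{i},\mathbf{c}} = \sum_{0<\mathbf{k}\leq\mathbf{i}}(-1)^{|\mathbf{i}-\mathbf{k}|}{\mathbf{i}\choose\mathbf{k}}{d\mathbf{k}/|\mathbf{k}|\choose\mathbf{c}}\left(|\mathbf{k}|/d\right)^{|\mathbf{i}|}
$$

The sum runs over all integer multi-indices $\mathbf{k}\neq\mathbf{0}$ with $0\le k_j\le i_j$ for every $j$. Multi-index binomial coefficients are products of scalar ones, ${\mathbf{a}\choose\mathbf{b}}=\prod_{j=1}^n{a_j\choose b_j}$. Since the entries of $d\mathbf{k}/|\mathbf{k}|$ are in general not integers, ${d\mathbf{k}/|\mathbf{k}|\choose\mathbf{c}}$ uses the generalized binomial coefficient ${a\choose b}=a(a-1)\cdots(a-b+1)/b!$.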
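
As a sanity check on the normalization (a short derivation, not reproduced from the paper), consider a pure derivative $\mathbf{i}=d\,\mathbf{e}_m$, where $\mathbf{e}_m$ is the $m$-th unit vector, i.e., all $d$ derivatives are taken with respect to $x_m$. Then only $\mathbf{k}=k\,\mathbf{e}_m$ with $k=1,\dots,d$ appears in the sum, $d\mathbf{k}/|\mathbf{k}|=d\,\mathbf{e}_m$, and ${d\mathbf{e}_m\choose\mathbf{c}}$ vanishes for every $|\mathbf{c}|=d$ except $\mathbf{c}=d\,\mathbf{e}_m$, where it equals $1$. Using the identity $\sum_{k=0}^{d}(-1)^{d-k}{d\choose k}k^d=d!$ and the chain rule along $\mathbf{c}=d\,\mathbf{e}_m$:

$$
g_{d\mathbf{e}_m,\,d\mathbf{e}_m} = \sum_{k=1}^{d}(-1)^{d-k}{d\choose k}\left(\frac{k}{d}\right)^{d} = \frac{d!}{d^d},
\qquad
\frac{\partial^d f}{\partial t^d}\Big|_{d\mathbf{e}_m} = d^d\,\frac{\partial^d f}{\partial x_m^d}.
$$

Hence $\sum_{|\mathbf{c}|=d}g_{\mathbf{i},\mathbf{c}}\,\frac{\partial^d f}{\partial t^d}\big|_{\mathbf{c}} = d!\,\frac{\partial^d f}{\partial x_m^d}$, so the weighted sum has to be divided by $d!$, which the code below does with `factorial(d)`.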

```python
import itertools
from typing import List

import jax
import numpy as np
from jax import numpy as jnp
from jax.experimental import jet
from scipy.special import binom, comb, factorial

jax.config.update("jax_enable_x64", True)
```

Define the target function $f$, the expansion point $\mathbf{x}_0$, and the total derivative order $d$.

```python
# test function f(x) = cos(x_1) * sin(x_2) * x_2 (x[0] is x_1, x[1] is x_2)
func = lambda x: jnp.cos(x[0]) * jnp.sin(x[1]) * x[1]

# expansion point x0 = (0.3, 0.3)
x0 = np.array((0.3, 0.3), dtype=np.float64)

d = 8  # total derivative order
```

Generate all direction vectors $\mathbf{c}$ with nonnegative integer entries and $|\mathbf{c}|=d$ by filtering the $(d+1)^n$ candidate tuples.

```python
c = np.array(
    [
        elem
        for elem in itertools.product(*[range(0, d + 1) for _ in range(len(x0))])
        if np.sum(elem) == d
    ]
)

print("derivative order:", d)
print("directions:\n", c)
```

```
derivative order: 8
directions:
 [[0 8]
 [1 7]
 [2 6]
 [3 5]
 [4 4]
 [5 3]
 [6 2]
 [7 1]
 [8 0]]
```


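Compute the univariate derivatives $\frac{\partial^d f}{\partial t^d}\big|_{\mathbf{c}}$ along every direction $\mathbf{c}$ with $|\mathbf{c}|=d$, using Taylor-mode automatic differentiation (`jax.experimental.jet`).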

```python
df_dt = np.zeros(len(c), dtype=np.float64)
c_float = jnp.asarray(c, dtype=jnp.float64)  # directions as float64 input series

for i in range(len(c)):
    # input curve x(t) = x0 + t*c: first-order series term c, higher-order terms zero
    series = (c_float[i],) + (np.zeros(len(x0)),) * (d - 1)
    # jet.jet returns a tuple: f(x0), [df/dt, d^2 f/dt^2, ..., d^d f/dt^d]
    # we are interested only in the last element, d^d f/dt^d
    _, (*_, df_dt[i]) = jet.jet(func, (x0,), (series,))
    print(f"direction: {c[i]}, d^d f/dt^d: {df_dt[i]}")
```

```
direction: [0 8], d^d f/dt^d: -121075230.22449228
direction: [1 7], d^d f/dt^d: -54881157.282575674
direction: [2 6], d^d f/dt^d: -40512748.34065934
direction: [3 5], d^d f/dt^d: -33198675.398743037
direction: [4 4], d^d f/dt^d: -26272698.456826724
direction: [5 3], d^d f/dt^d: -19347745.51491043
direction: [6 2], d^d f/dt^d: -12294792.572994135
direction: [7 1], d^d f/dt^d: -4382703.631077848
direction: [8 0], d^d f/dt^d: 1420969.3108384246
```
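
The `jet` values can be cross-checked against a slower route. The sketch below (the helper names `dir_deriv_nested_grad` and `dir_deriv_jet` are hypothetical) differentiates the scalar function $t\mapsto f(\mathbf{x}_0+t\mathbf{c})$ by nesting `jax.grad` $d$ times and compares the result with the same kind of `jet.jet` call as above. It also illustrates that `jet` returns plain derivatives in $t$, not Taylor coefficients divided by $k!$: for the pure direction $\mathbf{c}=(0,8)$, the output above equals $8^8\,\partial^8 f/\partial x_2^8$.

```python
import jax
import jax.numpy as jnp
from jax.experimental import jet

jax.config.update("jax_enable_x64", True)


def dir_deriv_nested_grad(f, x0, c, d):
    """d-th derivative of t -> f(x0 + t*c) at t = 0 via d nested jax.grad calls."""
    h = lambda t: f(x0 + t * c)
    for _ in range(d):
        h = jax.grad(h)
    return h(0.0)


def dir_deriv_jet(f, x0, c, d):
    """The same quantity in Taylor mode, with input series (c, 0, ..., 0) of length d."""
    series = (c,) + (jnp.zeros_like(c),) * (d - 1)
    _, terms = jet.jet(f, (x0,), (series,))
    return terms[-1]  # terms[k-1] is the k-th derivative in t (no 1/k! factor)


f_test = lambda x: jnp.cos(x[0]) * jnp.sin(x[1]) * x[1]
x_pt = jnp.array([0.3, 0.3])
v = jnp.array([1.0, 3.0])
print(dir_deriv_nested_grad(f_test, x_pt, v, 4), dir_deriv_jet(f_test, x_pt, v, 4))
```

Nested `jax.grad` is fine for a cross-check at low order, but the traced computation grows quickly with $d$, which is what Taylor mode avoids.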


Define a function that computes the coefficients $g_{\mathbf{i},\mathbf{c}}$ from the formula above.

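```python
def g_ic(i: List[int], c: List[int], d: int) -> float:
    """Coefficient g_{i,c}; multi-index binomials are products over components."""
    i = np.asarray(i)
    c = np.asarray(c)
    sum_i = int(np.sum(i))
    # all multi-indices 0 <= k <= i; the k = 0 term contributes 0 whenever |i| > 0
    k_ind = np.array(list(itertools.product(*[range(0, k + 1) for k in i])))
    sum_k = np.sum(k_ind, axis=-1)
    # sign factor (-1)^{|i - k|}
    fac1 = (-1.0) ** np.sum(i[None, :] - k_ind, axis=-1)
    # power factor (|k| / d)^{|i|}
    fac2 = np.ones(len(k_ind)) if d == 0 else (sum_k / d) ** sum_i
    # scaled multi-indices d * k / |k| (k itself when |k| = 0)
    x = [d / s * k if s > 0 else k for k, s in zip(k_ind, sum_k)]
    # binom(i, k) * binom(d k / |k|, c); scipy's binom accepts non-integer arguments
    fac3 = np.array([np.prod(comb(i, k) * binom(x_, c)) for k, x_ in zip(k_ind, x)])
    return float(np.sum(fac1 * fac2 * fac3))
```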
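
Because $g_{\mathbf{i},\mathbf{c}}$ is an alternating sum, the floating-point version above can lose accuracy to cancellation for larger $d$. A possible alternative, sketched here with the standard library only (the names `gbinom` and `g_ic_exact` are hypothetical, and this is not the paper's implementation), evaluates the same formula exactly with `fractions.Fraction`, using the generalized binomial coefficient for the non-integer upper arguments $d k_j/|\mathbf{k}|$.

```python
import itertools
import math
from fractions import Fraction


def gbinom(a: Fraction, b: int) -> Fraction:
    """Generalized binomial coefficient a(a-1)...(a-b+1)/b! for integer b >= 0."""
    out = Fraction(1)
    for j in range(b):
        out *= (a - j) / (j + 1)
    return out


def g_ic_exact(i, c, d):
    """Exact g_{i,c}: sum over 0 < k <= i of the terms in the formula above."""
    abs_i = sum(i)
    total = Fraction(0)
    for k in itertools.product(*(range(ij + 1) for ij in i)):
        abs_k = sum(k)
        if abs_k == 0:
            continue  # k = 0 is excluded from the sum
        term = Fraction((-1) ** (abs_i - abs_k))  # |i - k| = |i| - |k| since k <= i
        for ij, kj, cj in zip(i, k, c):
            # binom(i_j, k_j) * binom(d k_j / |k|, c_j)
            term *= math.comb(ij, kj) * gbinom(Fraction(d * kj, abs_k), cj)
        term *= Fraction(abs_k, d) ** abs_i
        total += term
    return total


print(g_ic_exact((1, 1), (1, 1), 2), g_ic_exact((1, 1), (2, 0), 2))  # 1 -1/4
print(g_ic_exact((0, 8), (0, 8), 8) == Fraction(math.factorial(8), 8**8))  # True
```

For $d=2$ and $\mathbf{i}=(1,1)$ the coefficients are $1$, $-1/4$, $-1/4$ for $\mathbf{c}=(1,1),(2,0),(0,2)$; with $\frac{\partial^2 f}{\partial t^2}\big|_{(1,1)}=f_{11}+2f_{12}+f_{22}$, $\frac{\partial^2 f}{\partial t^2}\big|_{(2,0)}=4f_{11}$ and $\frac{\partial^2 f}{\partial t^2}\big|_{(0,2)}=4f_{22}$, dividing the weighted sum by $2!$ leaves exactly $f_{12}$.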

Compute the partial derivatives $\frac{\partial^{|\mathbf{i}|} f}{\partial x_1^{i_1}\partial x_2^{i_2}\cdots\partial x_n^{i_n}}$ for all multi-indices $\mathbf{i}$ with $|\mathbf{i}|=d$.

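```python
# all multi-indices i with |i| = d (same enumeration as for the directions c)
i_list = [
    elem
    for elem in itertools.product(*[range(0, d + 1) for _ in range(len(x0))])
    if np.sum(elem) == d
]

fac_d = factorial(d)
for i in i_list:
    g = np.array([g_ic(i, c_, d) for c_ in c])  # g_{i,c} for every direction c
    deriv_i = g @ df_dt / fac_d  # (1/d!) * sum_c g_{i,c} * d^d f/dt^d |_c
    print(i, deriv_i)
```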

```
(0, 8) -7.2166460886295
(1, 7) 0.6960217188304038
(2, 6) -5.391310473716733
(3, 5) 0.5213573337380131
(4, 4) -3.5659748588064093
(5, 3) 0.34669294864648864
(6, 2) -1.740639243900105
(7, 1) 0.17202856355383234
(8, 0) 0.08469637100925592
```
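
The closed-form coefficients are one way to undo the mixing of partials along each direction. By the multinomial theorem, $\frac{\partial^d f}{\partial t^d}\big|_{\mathbf{c}}=\sum_{|\mathbf{j}|=d}\frac{d!}{j_1!\cdots j_n!}\,c_1^{j_1}\cdots c_n^{j_n}\,\frac{\partial^d f}{\partial x_1^{j_1}\cdots\partial x_n^{j_n}}$, which is a square linear system with $\binom{n+d-1}{d}$ unknowns. For small problems it can also be solved numerically, as in the NumPy sketch below. This is a swapped-in alternative, not the method of the paper; the helper `multinomial_weight` is hypothetical, and the $d^d f/dt^d$ values are copied from the `jet` output above.

```python
import math

import numpy as np

order = 8
# d^8 f/dt^8 along (0, 8), (1, 7), ..., (8, 0), copied from the jet output above
dd_f = np.array([
    -121075230.22449228, -54881157.282575674, -40512748.34065934,
    -33198675.398743037, -26272698.456826724, -19347745.51491043,
    -12294792.572994135, -4382703.631077848, 1420969.3108384246,
])
dirs = [(a, order - a) for a in range(order + 1)]  # directions c, same order as above
js = [(a, order - a) for a in range(order + 1)]    # multi-indices j of the unknown partials


def multinomial_weight(c, j):
    """|j|!/(j_1! ... j_n!) * c_1^j_1 * ... * c_n^j_n (hypothetical helper)."""
    coef = math.factorial(sum(j)) // math.prod(math.factorial(x) for x in j)
    return float(coef * math.prod(ci**ji for ci, ji in zip(c, j)))


# M[c, j] = weight of the partial with multi-index j in the d-th derivative along c
M = np.array([[multinomial_weight(c, j) for j in js] for c in dirs])
partials = np.linalg.solve(M, dd_f)
for j, p in zip(js, partials):
    print(j, p)
```

A generic solve costs $O(N^3)$ for $N=\binom{n+d-1}{d}$ and depends on the conditioning of the matrix, whereas $g_{\mathbf{i},\mathbf{c}}$ gives each partial directly.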


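For comparison, the same partial derivatives can be computed by nesting `jax.jacfwd` calls $d$ times. For each multi-index this materializes the full derivative tensor with $n^d$ entries, from which the required entry is then selected.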

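```python
def nested_jacfwd(x0, ind):
    """Partial derivative with multi-index `ind` via |ind| nested jax.jacfwd calls."""
    f = func
    for _ in range(sum(ind)):
        f = jax.jacfwd(f)
    # tensor index: variable j repeated ind[j] times, e.g. (1, 7) -> (0, 1, 1, 1, 1, 1, 1, 1)
    idx = sum([(j,) * o for j, o in enumerate(ind)], start=tuple())
    return f(x0)[idx]


for i in i_list:
    deriv_i = nested_jacfwd(x0, i)
    print(i, deriv_i)
```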

```
(0, 8) -7.216646088629457
(1, 7) 0.6960217188253812
(2, 6) -5.391310473719779
(3, 5) 0.5213573337350594
(4, 4) -3.565974858810101
(5, 3) 0.3466929486447378
(6, 2) -1.7406392439004228
(7, 1) 0.17202856355441612
(8, 0) 0.08469637100925528
```
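
Since this particular test function factorizes, $f=\cos(x_1)\cdot x_2\sin(x_2)$, its partials also have a closed form, which gives a third check that involves no automatic differentiation. By the Leibniz rule, $\frac{d^b}{dx^b}\left[x\sin x\right]=x\sin^{(b)}(x)+b\sin^{(b-1)}(x)$ (higher derivatives of $x$ vanish), together with $\sin^{(m)}(x)=\sin(x+m\pi/2)$ and $\cos^{(m)}(x)=\cos(x+m\pi/2)$. The function name `analytic_partial` below is hypothetical and valid only for this $f$.

```python
import numpy as np


def analytic_partial(i1: int, i2: int, x1: float, x2: float) -> float:
    """Partial derivative of order (i1, i2) of f = cos(x1) * sin(x2) * x2 (this f only)."""
    # i1-th derivative of cos: cos^(m)(x) = cos(x + m*pi/2)
    dcos = np.cos(x1 + i1 * np.pi / 2)
    # i2-th derivative of x2*sin(x2) by the Leibniz rule; (x2)'' = 0 leaves two terms
    dxsin = x2 * np.sin(x2 + i2 * np.pi / 2)
    if i2 > 0:
        dxsin += i2 * np.sin(x2 + (i2 - 1) * np.pi / 2)
    return float(dcos * dxsin)


for a in range(9):
    print((a, 8 - a), analytic_partial(a, 8 - a, 0.3, 0.3))
```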
