# Multivariate derivatives


In [1]:
import jax
import jax.numpy as jnp


from pnfindiff import central, differentiate, gradient, differentiate_along_axis

Let's define a vector-valued function $f: R^{d_{in}} \rightarrow R^{d_{out}}$ (below, $d_{in} = 4$ and $d_{out} = 2$).

In [2]:
def f(z):
    """Map a vector ``z`` to the pair ``(z . z, cos(z . z))``.

    For a 1-D input of length ``d_in`` the result has shape ``(2,)``.
    """
    # Compute the inner product once and reuse it for both output components.
    squared_norm = jnp.dot(z, z)
    return jnp.stack([squared_norm, jnp.cos(squared_norm)])


f_batched = jax.vmap(f)
d_in, d_out = 4, 2

# Some point x in R^{d_in}
x = jnp.arange(1.0, 1.0 + d_in)
assert f(x).shape == (d_out,)

# f is vector-valued, so its derivative is the Jacobian,
# which takes values in R^{d_out x d_in} (one row per output component).
df = jax.jacfwd(f)
assert df(x).shape == (d_out, d_in)



In [3]:
# Build a central finite-difference scheme together with its 1-D grid `xs`.
# NOTE(review): `dx` is presumably the grid spacing -- confirm against pnfindiff.
scheme, xs = central(dx=0.01)

# The grid is expected to hold exactly `n` points.
n = 3
assert xs.shape == (n,)

In [6]:
# Attempt to build a multivariate (gradient) FD scheme from the 1-D central
# scheme and check the shapes of the arrays involved.
# NOTE(review): the recorded output of this cell is
#   TypeError: gradient() got an unexpected keyword argument 'shape_input'
# so in that run nothing below this call was executed. Check the keyword
# names against the installed pnfindiff version before relying on this cell.
scheme, xs_full, labels = gradient(
    central(dx=0.01), shape_input=(d_in,), shape_output=(d_out,)
)

# Expected meaning of each axis of `xs_full`, in order.
assert labels == (
    "shape_output_differential",
    "shape_input_differential",
    "shape_input_values",
    "fd_weights",
)
# Scratch notes on the intended axis layout, kept as a bare string literal
# (it has no effect when the cell runs). The notes are unfinished.
r"""

xs[0, 1, 2, 3]: 
the 3rd FD weight, 
the 2nd dimension of _one_ input to the function (likely to be zero),
along the first input-dimension axis (i.e. the 1st partial derivative is approximated)
and the 0th element of the output. In other words, xs[a, b] is a (c,d) array providing the FD grid for 

    .. math:: \frac{\partial}{\partial_b}f_a(x)

jnp.array(
    [
        [
            [
                [
                
                ]
                for k in input_values
            ]
            for i in input_diff
        ]
        for o in output_diff
    ]
)
for o in output_diff:
    for i in input_diff:
        
"""


# (output_dim_idx, input_dim_idx, input_value_idx, fd_weight_idx)
assert xs_full.shape == (d_out, d_in, d_in, n)

# (output_dim_idx, input_dim_idx, fd_weight_idx)
# NOTE(review): `fxs` is not defined in any visible cell -- presumably the
# function evaluated on the grid; confirm where it should come from.
assert fxs.shape == (d_out, d_in, n)

# (output_dim_idx, input_dim_idx)
# NOTE(review): `dfxs` is not defined in any visible cell either -- presumably
# the resulting derivative approximation; confirm.
assert dfxs.shape == (d_out, d_in)

TypeError: gradient() got an unexpected keyword argument 'shape_input'

In [None]:
# NOTE(review): disabled draft, kept for reference only; this cell executes nothing.
# Before re-enabling: the draft uses the `pnfindiff.` module prefix and a
# variable `d`, neither of which is bound in the visible cells (only
# `from pnfindiff import ...` and `d_in` / `d_out` are) -- TODO confirm.
# k = 2
# scheme, xs = pnfindiff.gradient(*pnfindiff.central(dx=0.1, order_method=k), dim=3)
# assert xs.shape == (k+1, d, k+1)

# xs_shifted = x[..., None] + xs
# assert xs_shifted.shape == (k+1, d, k+1)

# dfx, _ = pnfindiff.differentiate(f(xs_shifted), scheme=scheme)

# assert dfx.shape == df(x).shape == (d,)
# assert jnp.allclose(dfx, df(x), rtol=1e-4, atol=1e-4)

In [21]:
# Toy dimensions for experimenting with the 4-D grid layout.
d_in = 6
d_out = 7
fd_weights = jnp.ones(3)

# Build the (d_out, d_in, d_in, 3) array from the inside out:
# repeat the weight vector for every input value, ...
weights_per_value = jnp.stack([fd_weights] * d_in)
# ... repeat that block for every input (differentiation) direction, ...
weights_per_direction = jnp.stack([weights_per_value] * d_in)
# ... and finally repeat for every output component.
xs = jnp.stack([weights_per_direction] * d_out)

In [22]:
# Display the shape of the stacked array; the recorded output is (7, 6, 6, 3).
xs.shape

(7, 6, 6, 3)

In [24]:
# Repeat the FD weight vector d_in times along a new leading axis
# (the recorded output has 6 rows of 3 weights each).
jnp.stack([fd_weights] * d_in)

DeviceArray([[1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.]], dtype=float32)

In [50]:
# A single row of FD weights, i.e. an array of shape (1, 3).
fd_weights = jnp.array([[3, 4, 5]])

# Append d_in rows below the weights row (zeros in the recorded output);
# the columns are not padded.
jnp.pad(fd_weights, pad_width=((0, d_in), (0, 0)))

DeviceArray([[3, 4, 5],
             [0, 0, 0],
             [0, 0, 0],
             [0, 0, 0],
             [0, 0, 0],
             [0, 0, 0],
             [0, 0, 0]], dtype=int32)

In [55]:
# For every row index, pad `fd_weights` with `row` rows above and
# `d_in - row - 1` rows below, then stack the padded blocks along a new
# leading axis.
# NOTE(review): assumes `fd_weights` is the (1, 3) array and `d_in` the value
# set in earlier cells -- execution counts are non-sequential, so confirm.
shifted_blocks = []
for row in range(d_in):
    rows_above, rows_below = row, d_in - row - 1
    shifted_blocks.append(
        jnp.pad(fd_weights, pad_width=((rows_above, rows_below), (0, 0)))
    )
X = jnp.stack(shifted_blocks)
