# Introduction to automatic differentiation

Automatic differentiation is a technique to automatically compute the derivative of a function, where the function is written as a composition of elementary sub-functions that we know how to differentiate in closed form. It is a cornerstone in modern machine learning, and can be useful in any science where coding is involved.

In practice, assume that we have some Python function
```python
def func(x):
    ...
    return y
```
where `x` is a NumPy array of size $p$, `y` is a NumPy array of size $q$, and `...` is a succession of operations that we can differentiate in closed form. Then, `func` implements a differentiable function
$$f:\mathbb{R}^p\to \mathbb{R}^q.$$

Two quantities of interest for $f$ are **Jacobian-Vector-Products** (JVP's) and **Vector-Jacobian-Products** (VJP's), which can be computed automatically from the code of `func`.

Letting $J(x)\in\mathbb{R}^{q\times p}$ be the Jacobian matrix of $f$ at $x$, the Jacobian-vector product of $f$ at $x$ with a vector $v\in \mathbb{R}^p$ is simply $J(x)v\in\mathbb{R}^q$.

The Vector-Jacobian-Product with a vector $v'\in\mathbb{R}^q$ is $J(x)^Tv'\in\mathbb{R}^p$.

**Key remark**: automatic differentiation computes these JVP's and VJP's efficiently, at roughly the same cost as evaluating the function `func` itself! It does not first form the big matrix $J(x)$ and then multiply it by a vector; it is much more efficient than that.
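To make the two definitions concrete, the sketch below writes the full Jacobian of a tiny function by hand and forms $J(x)v$ and $J(x)^Tv'$ from it. The function `f_small` and its hand-written `jacobian` are illustrative assumptions, not part of the notebook. Building $J(x)$ explicitly is only done here to illustrate the definitions; it is precisely what automatic differentiation avoids.

```python
import numpy as np

def f_small(x):
    # Hypothetical example: f maps R^3 (p = 3) to R^2 (q = 2)
    return np.array([x[0] * x[1], np.sin(x[2])])

def jacobian(x):
    # Jacobian of f_small written by hand, of shape (q, p) = (2, 3)
    return np.array([
        [x[1], x[0], 0.0],
        [0.0, 0.0, np.cos(x[2])],
    ])

x = np.array([1.0, 2.0, 0.5])
v = np.array([1.0, -1.0, 2.0])    # vector in R^p (input space)
v_prime = np.array([0.3, -0.7])   # vector in R^q (output space)

J = jacobian(x)
jvp_value = J @ v            # Jacobian-vector product, lives in R^q
vjp_value = J.T @ v_prime    # vector-Jacobian product, lives in R^p
print(jvp_value.shape, vjp_value.shape)
```

Note that the JVP has the size of the output while the VJP has the size of the input.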


**Question 1:** In machine learning, we are often interested in computing the gradient of cost functions, i.e. functions that go from $\mathbb{R}^p$ to $\mathbb{R}$. Can you compute the gradient of $f$ using a JVP? Using a VJP? Which one is the most convenient?

## Forward mode automatic differentiation

Forward mode automatic differentiation is a method to compute JVP's. Even though in machine learning we usually want to compute VJP's, we begin with forward mode automatic differentiation because it is conceptually more natural.

The idea of forward mode automatic differentiation is simply to follow the same steps as in the code of the function, while applying the rules of differentiation at each step.

For instance, we consider the following toy example of `func` (it is just an arbitrary function):

In [50]:
import numpy as np

In [51]:
d = 3
x = np.array([1., 2., -1.])
A = np.array([[0., 2., -1.], [3., 1.5, 2.]])

def func(x):
    # x is a 3-dimensional array
    y = x ** 2
    z = np.sin(y)
    t = A @ z
    return y + z, t

func(x)

(array([1.84147098, 3.2431975 , 1.84147098]),
 array([-2.35507598,  3.07215118]))

This function takes as input a vector of size $3$ and outputs two vectors of size $3$ and $2$ respectively.

We can compute "automatically" its JVP with a vector $v \in \mathbb{R}^3$ by following the rules of differentiation. 

To do so by hand, we copy and paste the code, and add the missing differentiation operations.

In [52]:
def jvp(x, v):
    # x is a 3-dimensional array
    
    y = x ** 2
    vy = 2 * x * v  # this is dy / dx * v
    
    z = np.sin(y)
    vz = np.cos(y) * vy  # this is dz / dx * v
    
    t = A @ z
    vt = A @ vz # this is dt / dx * v
    return vy + vz, vt  # this is df / dx * v

In [53]:
v = np.array([1., 3., 4.])

jvp(x, v)

(array([  3.08060461,   4.15627655, -12.32241845]),
 array([-11.36502845, -17.16860823]))

Importantly, forward-mode automatic differentiation works by keeping track at each step of the variation of the current variable with respect to the **input**.

We can check that these computations are correct, simply by comparing with an approximation of the JVP, given by:

$$
J(x)v \simeq \frac{f(x+\varepsilon v) - f(x)}{\varepsilon}
$$
for $\varepsilon$ a small scalar.

In [54]:
def approximate_jvp(x, v, eps=1e-7):
    return tuple((a - b) / eps for a, b in zip(func(x + eps * v), func(x)))


approximate_jvp(x, v)

(array([  3.0806046 ,   4.15628231, -12.32241868]),
 array([-11.3650169 , -17.16860493]))

In [55]:
jvp(x, v)

(array([  3.08060461,   4.15627655, -12.32241845]),
 array([-11.36502845, -17.16860823]))

We see that we have done a good job: both computations lead to almost the same result. 

A natural question is now: why not simply implement the JVP with the approximation technique?
The main reason is that this technique is not exact; its error is driven by the choice of $\varepsilon$. Automatic differentiation, on the other hand, is exact up to machine precision.
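The effect of $\varepsilon$ can be observed directly. The sketch below repeats the toy `func` and `jvp` from above and measures the gap between the finite-difference approximation and the exact JVP for a few values of `eps`; the chosen values and the `errors` dictionary are assumptions made for illustration.

```python
import numpy as np

A = np.array([[0., 2., -1.], [3., 1.5, 2.]])

def func(x):
    y = x ** 2
    z = np.sin(y)
    return y + z, A @ z

def jvp(x, v):
    # Forward-mode JVP written by hand, as in the cell above
    y = x ** 2
    vy = 2 * x * v
    vz = np.cos(y) * vy
    return vy + vz, A @ vz

def approximate_jvp(x, v, eps):
    return tuple((a - b) / eps for a, b in zip(func(x + eps * v), func(x)))

x = np.array([1., 2., -1.])
v = np.array([1., 3., 4.])
exact = np.concatenate(jvp(x, v))

errors = {}
for eps in [1e-1, 1e-4, 1e-7, 1e-13]:
    approx = np.concatenate(approximate_jvp(x, v, eps))
    errors[eps] = np.max(np.abs(approx - exact))
    print(f"eps={eps:.0e}  max abs error={errors[eps]:.2e}")
```

A large `eps` suffers from truncation error, while a very small `eps` typically suffers from floating-point round-off, so the approximation always requires tuning a parameter that the exact JVP does not have.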

You are now tasked to use the same method to implement the JVP in a simple two-layer neural network, defined by:

$$ f(x) = A\sigma(Bx),$$

where $B \in\mathbb{R}^{n\times p}$, $A\in\mathbb{R}^{q\times n}$ and $\sigma:\mathbb{R}^n\to\mathbb{R}^n$ is a non-linear function, defined by $\sigma(y_1, \dots, y_n) = (y_1^2, \dots, y_n^2)$.

We give below the code of the function.

Beware that it takes as inputs `x, B, A`, hence the vector `v` in the JVP will contain 3 variables.

In [56]:
def two_layers(x, B, A):
    y = B @ x
    z = y ** 2
    return A @ z

In [57]:
p = 2
n = 4
q = 3

x = np.random.randn(p)
B = np.random.randn(n, p)
A = np.random.randn(q, n)

two_layers(x, B, A)

array([0.44313108, 0.19279774, 0.1266412 ])

Now, fill in the blanks in the next function to compute the JVP with automatic differentiation; and check that this is correct with approximate differentiation.

In [58]:
def two_layers_jvp(x, B, A, v):
    vx, vB, vA = v  # v is a tuple
    
    y = B @ x
    vy = # TODO
    
    z = y ** 2
    vz = # TODO
    return # TODO

In [59]:
vx = np.random.randn(p)
vB = np.random.randn(n, p)
vA = np.random.randn(q, n)

v = (vx, vB, vA)
two_layers_jvp(x, B, A, v)

array([ 2.39764278, -6.37349448, -0.50283813])

In [60]:
def approximate_jvp(x, B, A, v, eps=1e-7):
    vx, vB, vA = v
    return (two_layers(x + eps * vx, B + eps * vB, A + eps * vA) - two_layers(x, B, A)) / eps

In [61]:
approximate_jvp(x, B, A, v)  # this should be close to the output of your two_layers_jvp

array([ 2.39764268, -6.37349406, -0.50283775])
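For readers who want to compare their answer, here is one possible way to fill in the blanks, checked against finite differences. This is a sketch and not necessarily the notebook's reference solution; the seeded generator `rng` is an assumption introduced so that the example is reproducible.

```python
import numpy as np

def two_layers(x, B, A):
    y = B @ x
    z = y ** 2
    return A @ z

def two_layers_jvp(x, B, A, v):
    vx, vB, vA = v  # v is a tuple of directions for x, B and A

    y = B @ x
    vy = vB @ x + B @ vx  # product rule on y = B @ x

    z = y ** 2
    vz = 2 * y * vy  # chain rule on z = y ** 2
    return vA @ z + A @ vz  # product rule on the output A @ z

rng = np.random.default_rng(0)
p, n, q = 2, 4, 3
x = rng.standard_normal(p)
B = rng.standard_normal((n, p))
A = rng.standard_normal((q, n))
v = (rng.standard_normal(p), rng.standard_normal((n, p)), rng.standard_normal((q, n)))

# Finite-difference approximation of the same JVP
eps = 1e-7
vx, vB, vA = v
approx = (two_layers(x + eps * vx, B + eps * vB, A + eps * vA) - two_layers(x, B, A)) / eps
exact = two_layers_jvp(x, B, A, v)
print(exact)
print(approx)
```

Since both `B` and `x` are inputs here, the line `y = B @ x` contributes two terms to `vy`, one for each factor.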

Forward-mode automatic differentiation is rarely used in machine learning: for a function $f:\mathbb{R}^p\to \mathbb{R}$, one JVP only returns the scalar $\nabla f(x)^T v$ for a single vector $v$, so recovering the full gradient takes $p$ JVP's (one per coordinate direction).

For real-valued functions, what is used in practice is backward automatic differentiation.
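The sketch below illustrates why forward mode is inconvenient for gradients. The cost function `scalar_func` and its hand-written `scalar_jvp` are hypothetical examples chosen for illustration; the gradient is assembled by calling the JVP once per canonical basis vector.

```python
import numpy as np

def scalar_func(x):
    # Hypothetical cost function from R^p to R
    return np.sum(np.sin(x ** 2))

def scalar_jvp(x, v):
    # Forward mode by hand: returns grad f(x)^T v, a single number
    y = x ** 2
    vy = 2 * x * v
    vz = np.cos(y) * vy  # derivative of z = sin(y)
    return np.sum(vz)    # derivative of the final sum

x = np.array([1., 2., -1.])
p = x.size

# Each JVP with a basis vector e_i recovers one entry of the gradient,
# so the full gradient costs p passes through the function.
gradient = np.array([scalar_jvp(x, e_i) for e_i in np.eye(p)])
print(gradient)
```

With $p$ in the millions, as is common for neural networks, this means millions of passes through the function for a single gradient.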

## Backward automatic differentiation (or backpropagation for real-valued functions)

The main idea is the following:
- Forward mode automatic differentiation tracks the derivative of each variable with respect to the **input**.
- Backward mode automatic differentiation tracks the derivative of the **output** with respect to each variable.

Hence, backward-mode computes things in a backwards order: it goes backward through the computational graph of the function.

The key point here, which might seem a bit strange the first time you see it, is that the classical rules of differentiation also apply backwards.
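To make this concrete, consider a single line $z = g(y)$ of the code. The notation below is an assumption introduced for illustration: $\dot y$ denotes the derivative of $y$ with respect to the input, applied to the direction $v$, and $\bar y$ denotes the derivative of the output with respect to $y$, applied to the vector $v'$. The chain rule gives one rule for each direction:

$$
\text{forward: } \dot z = \frac{\partial g}{\partial y}(y)\,\dot y, \qquad \text{backward: } \bar y = \left(\frac{\partial g}{\partial y}(y)\right)^T \bar z.
$$

For instance, when $g$ is the sine applied elementwise, its Jacobian is diagonal with entries $\cos(y_i)$, so both rules reduce to an elementwise multiplication by $\cos(y)$. When a variable is used in several later lines, the backward contributions coming from each of them are summed.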

Let us take the same example as before, and try to compute a Vector-Jacobian Product.

In [62]:
d = 3
x = np.array([1., 2., -1.])
A = np.array([[0., 2., -1.], [3., 1.5, 2.]])

def func(x):
    # x is a 3-dimensional array
    y = x ** 2
    z = np.sin(y)
    t = A @ z
    return y + z, t

func(x)

(array([1.84147098, 3.2431975 , 1.84147098]),
 array([-2.35507598,  3.07215118]))

In [63]:
def vjp(x, v):
    # x is a 3-dimensional array
    # the function outputs two arrays, hence now v is two arrays
    vu, vt = v
    
    # We first run the function code, without differentiating anything. We keep each intermediate variable in memory!
    y = x ** 2
    z = np.sin(y)
    t = A @ z
    output_u = y + z
    output_t = t
    # Now, we "backprop" through the computational graph, starting from the end.
    # Here, I comment each line of code and write the corresponding differentiation below.
    
    # output_t = t
    vt = vt  # This is just the derivative of output_t in the direction vt
    
    # output_u = y + z
    vy = vu.copy()  # this is d output / dy in the direction (vu, vt). By "in the direction", we just mean the VJP.
    vz = vu.copy()  # this is d output/ dz in the direction (vu, vt).
    
    # t = A @ z
    vz += A.T @ vt  # we have already seen z in the graph, so we accumulate the derivatives.
    
    # z = np.sin(y)
    vy += np.cos(y) * vz  # remember that vy is the derivative of the output w.r.t. y at the current point of the graph.
    # the previous equation can be rewritten as doutput / dy = doutput/dz * dz/dy, and we have doutput/dz = vz, and from the line z = sin(y), we also have dz/dy = cos(y).
    
    # y = x ** 2
    vx = 2 * x * vy  # same as just above. This might be counter intuitive, so take your time to reflect on this.
    return vx  # we have finished traversing the computational graph!

In [64]:
vu = np.array([1., 0.1, 3.])
vt = np.array([0.4, -0.3])
v = vu, vt
vjp(x, v)

array([ 2.10806046, -0.77655852, -8.16120922])

We can check that this is correct by comparing two scalar products: the VJP against a vector `v2`, and `v` against the JVP in the direction `v2`.

In [65]:
v2 = np.array([2., 3., 0.5])

np.dot(vjp(x, v), v2)

-2.194159242052528

In [66]:
sum(np.dot(a, b) for a, b in zip(jvp(x, v2), v))

-2.1941592420525278
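The reason these two numbers agree is the definition of the transpose. Writing $v'=(v_u, v_t)$ for the output-side vector and $v_2\in\mathbb{R}^3$ for an input-side vector (the symbols $v'$, $v_u$, $v_t$ are notation introduced here to mirror the variables `v`, `vu`, `vt` of the code):

$$
\langle J(x)^T v', v_2\rangle = (v')^T J(x)\, v_2 = \langle v', J(x) v_2\rangle.
$$

The left-hand side is `np.dot(vjp(x, v), v2)`. The right-hand side sums the dot products of the two components of `jvp(x, v2)` with `vu` and `vt`, since the output of `func` is a pair of arrays.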

Now, it is your turn to implement this method to compute a VJP through the two-layer neural network:

In [67]:
def two_layers(x, B, A):
    y = B @ x
    z = y ** 2
    return A @ z

In [68]:
p = 2
n = 4
q = 3

x = np.random.randn(p)
B = np.random.randn(n, p)
A = np.random.randn(q, n)

two_layers(x, B, A)

array([ 3.39956728,  0.35436586, -0.69567035])

In [69]:
v = np.random.randn(q)

In [70]:
def two_layers_vjp(x, B, A, v):
    # This should output vx, vB, vA.
    y = B @ x
    z = y ** 2
    output = A @ z
    # Now, backprop through the computations:
    # output = A @ z
    vz = # TODO
    vA = # TODO
    
    # z = y ** 2
    vy = # TODO
    # y = B @ x
    vB = # TODO
    vx = # TODO
    return # TODO

In [71]:
two_layers_vjp(x, B, A, v)
vx = np.random.randn(p)
vB = np.random.randn(n, p)
vA = np.random.randn(q, n)

v2 = (vx, vB, vA)

In [72]:
sum(np.sum(a * b) for a, b in zip(two_layers_vjp(x, B, A, v), v2))

-18.80730340627934

In [73]:
np.sum(two_layers_jvp(x, B, A, v2) * v)

-18.807303406279335

These numbers should be the same!
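For readers who want to compare their answer, here is one possible way to fill in the blanks of `two_layers_vjp`. It is a sketch and not necessarily the notebook's reference solution. To stay independent of the JVP exercise, the check uses a finite-difference JVP on the right-hand side; the seeded generator `rng` and the names `lhs`, `rhs` are assumptions made for illustration.

```python
import numpy as np

def two_layers(x, B, A):
    y = B @ x
    z = y ** 2
    return A @ z

def two_layers_vjp(x, B, A, v):
    # Forward pass, keeping the intermediate variables
    y = B @ x
    z = y ** 2
    # Backward pass, from the last line to the first
    # output = A @ z
    vz = A.T @ v
    vA = np.outer(v, z)
    # z = y ** 2
    vy = 2 * y * vz
    # y = B @ x
    vB = np.outer(vy, x)
    vx = B.T @ vy
    return vx, vB, vA

rng = np.random.default_rng(0)
p, n, q = 2, 4, 3
x = rng.standard_normal(p)
B = rng.standard_normal((n, p))
A = rng.standard_normal((q, n))
v = rng.standard_normal(q)
v2 = (rng.standard_normal(p), rng.standard_normal((n, p)), rng.standard_normal((q, n)))

# Scalar product of the VJP with v2
lhs = sum(np.sum(a * b) for a, b in zip(two_layers_vjp(x, B, A, v), v2))

# Scalar product of v with a finite-difference JVP in the direction v2
eps = 1e-7
vx2, vB2, vA2 = v2
fd_jvp = (two_layers(x + eps * vx2, B + eps * vB2, A + eps * vA2) - two_layers(x, B, A)) / eps
rhs = np.sum(fd_jvp * v)
print(lhs, rhs)
```

The outer products come from the fact that the derivative of a matrix-vector product with respect to the matrix pairs the incoming vector on the output side with the vector on the input side.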

Now, imagine that we add a loss function on top of it, for instance defined as $f(x, B, A) = \sum_{i=1}^q [\texttt{twolayers}(x, B, A)]_i^2$.

In [74]:
def loss(x, B, A):
    z = two_layers(x, B, A)
    return np.sum(z ** 2)

Using the previous VJP function `two_layers_vjp`, compute the gradient of the function with respect to (x, B, A):

In [75]:
def grad(x, B, A):
    # forward computations
    z = two_layers(x, B, A)
    output = np.sum(z ** 2)
    # Backward computations:
    doutputdz = # TODO
    return # TODO

In [76]:
grad(x, B, A)

(array([ 15.82332274, -48.38062266]),
 array([[-29.52150855,   4.26020931],
        [-22.3080378 ,   3.21924302],
        [-20.95312919,   3.02371798],
        [ 42.21605038,  -6.09214162]]),
 array([[ 4.04937186, 42.46766857,  2.60546593,  5.59341797],
        [ 0.42210053,  4.42676691,  0.27158991,  0.58304961],
        [-0.82864309, -8.69037012, -0.53316944, -1.14460892]]))

We will check below that this is correct.
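A finite-difference check is one way to do so without any extra library. The block below pairs one possible completion of `two_layers_vjp` and `grad` (assumptions, not necessarily the notebook's reference solution) with a central-difference approximation of the gradient with respect to `x`; the seeded generator `rng` is also an assumption.

```python
import numpy as np

def two_layers(x, B, A):
    y = B @ x
    z = y ** 2
    return A @ z

def two_layers_vjp(x, B, A, v):
    # One possible completion of the VJP exercise
    y = B @ x
    z = y ** 2
    vz = A.T @ v
    vA = np.outer(v, z)
    vy = 2 * y * vz
    vB = np.outer(vy, x)
    vx = B.T @ vy
    return vx, vB, vA

def loss(x, B, A):
    return np.sum(two_layers(x, B, A) ** 2)

def grad(x, B, A):
    # Forward computations
    z = two_layers(x, B, A)
    # Backward computations: derivative of sum(z ** 2) with respect to z
    doutputdz = 2 * z
    return two_layers_vjp(x, B, A, doutputdz)

rng = np.random.default_rng(0)
p, n, q = 2, 4, 3
x = rng.standard_normal(p)
B = rng.standard_normal((n, p))
A = rng.standard_normal((q, n))

gx, gB, gA = grad(x, B, A)

# Central differences along each coordinate of x
eps = 1e-6
fd_gx = np.array([
    (loss(x + eps * e_i, B, A) - loss(x - eps * e_i, B, A)) / (2 * eps)
    for e_i in np.eye(p)
])
print(gx)
print(fd_gx)
```

The gradient only needs one VJP, seeded with the derivative of the loss with respect to the network output.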

**Question:** What is the cost of computing this gradient, compared to the cost of computing the loss function itself?


We see here the main advantage of reverse-mode automatic differentiation: it is an exact method to compute the gradient of a function, at roughly the same cost as computing the function itself!

## Drawback

The main drawback of backpropagation is that we need to keep every intermediate state of the code in memory, so it is costly in terms of memory. Forward-mode automatic differentiation does not have this problem: variables that will not be used again can be discarded as the evaluation proceeds. However, forward mode is impractical for computing gradients.

Hence, in deep learning, it is much more costly in terms of memory to train a neural network, where we have to use backpropagation, compared to just computing the output of the network without backprop.

## Autodiff in practice

So far, we have coded automatic differentiation by hand; this is of course highly impractical. Thankfully, modern deep learning frameworks like PyTorch or JAX have an automatic differentiation engine, which means that they can compute JVP's, VJP's and gradients using only the code of the function that you provide. What happens under the hood is essentially what we have coded above, but it is hidden from the user.

We will demonstrate it using JAX.

JAX has a syntax very close to that of NumPy:

In [77]:
import jax
import jax.numpy as jnp

In [78]:
d = 3
x = jnp.array([1., 2., -1.])
A = jnp.array([[0., 2., -1.], [3., 1.5, 2.]])

def func(x):
    # x is a 3-dimensional array
    y = x ** 2
    z = jnp.sin(y)  # need to put jnp instead of np
    t = A @ z
    return y + z, t

func(x)

(Array([1.841471 , 3.2431974, 1.841471 ], dtype=float32),
 Array([-2.3550758,  3.0721512], dtype=float32))

To compute the JVP of the function, simply use `jax.jvp`:

In [79]:
v = np.array([1., 3., 4.])

jax.jvp(func, (x,), (v,))[1]  # the first component is just func(x)

(Array([  3.0806046,   4.1562767, -12.322418 ], dtype=float32),
 Array([-11.365028, -17.168608], dtype=float32))

This should match what we have done above:

In [80]:
jvp(x, v)

(Array([  3.0806046,   4.1562767, -12.322418 ], dtype=float32),
 Array([-11.365028, -17.168608], dtype=float32))
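JAX also exposes VJP's, through `jax.vjp`. The sketch below reuses the toy function and the vectors `vu`, `vt` from the hand-written `vjp` example. It assumes JAX's default 32-bit floats, so the result should agree with the hand-written VJP only up to single precision.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[0., 2., -1.], [3., 1.5, 2.]])

def func(x):
    y = x ** 2
    z = jnp.sin(y)
    t = A @ z
    return y + z, t

x = jnp.array([1., 2., -1.])
vu = jnp.array([1., 0.1, 3.])
vt = jnp.array([0.4, -0.3])

# jax.vjp returns the value of func at x, together with a function that maps
# an output-side vector (same structure as the output of func) to a tuple
# containing one input-side vector per argument of func.
outputs, vjp_fun = jax.vjp(func, x)
(vx,) = vjp_fun((vu, vt))
print(vx)
```

The forward pass is run once by `jax.vjp`, and `vjp_fun` can then be called with several output-side vectors without recomputing it.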

Let us now focus on computing gradients in JAX.

In [81]:
p = 2
n = 4
q = 3

x = np.random.randn(p)
B = np.random.randn(n, p)
A = np.random.randn(q, n)

two_layers(x, B, A)

def loss(x, B, A):
    z = two_layers(x, B, A)
    return jnp.sum(z ** 2)

The `jax.grad` function computes exactly this gradient. By default, it differentiates only with respect to the first argument, but we can tell it to differentiate with respect to specific arguments using the `argnums` option:

In [82]:
grad = jax.grad(loss, argnums=(0, 1, 2))  # differentiate with respect to x, B and A

`grad` is now a function, that computes the gradient.

In [83]:
grad(x, B, A)

(Array([-126.91913,  -75.35017], dtype=float32),
 Array([[ 36.090748 ,   3.246026 ],
        [ 69.692474 ,   6.2681875],
        [ 22.129803 ,   1.9903693],
        [-49.829563 ,  -4.481704 ]], dtype=float32),
 Array([[102.664856  ,  86.8243    ,   4.415657  ,  20.893446  ],
        [-10.290272  ,  -8.702547  ,  -0.44258875,  -2.0941854 ],
        [ 83.49165   ,  70.609406  ,   3.5910094 ,  16.991484  ]],      dtype=float32))

Therefore, automatic differentiation is already implemented in JAX!