# Backpropagation in DAGs

This notebook provides an introduction to neural network learning using backpropagation. We discuss backpropagation in the context of directed acyclic computational graphs. Our main result is that a single training step (consisting of a forward and a backward pass through the network) has time and memory complexity linear in the network size. In the last section, we take a closer look at the implementation of `.backward` in PyTorch.

## Gradient descent on the loss surface

```{margin}
**Loss surfaces from samples**
```

Recall that a neural network measures its error using a **loss function** $\mathcal L.$ In theory, we assume the existence of an underlying distribution over the input and output data, and hence an expected value for the loss. In practice, we approximate this expected loss by the average loss of the network on a dataset $\mathcal X$ of $N$ points $\mathbf x_i$, each sampled independently from this distribution:
  
$$\mathcal L(\mathcal X, \mathbf w) = \frac{1}{N}\sum_{i=1}^N \mathcal L(\mathbf x_i, \mathbf w).$$ 

We can imagine this as forming a surface in $\mathbb R^d \times \mathbb R$, where $d$ is the number of parameters of the network, with the current parameter setting corresponding to the point $(\mathbf w, \mathcal L(\mathcal X, \mathbf w))$ on this surface. Note that the empirical loss surface generally varies with $\mathcal X$, but we expect these surfaces to be similar for large $N = |\mathcal X|$, assuming the points are sampled from the same distribution. The network learns by finding parameters that minimize the loss, typically through an iterative process. In practice, we update the weights using variants of **gradient descent**, characterized by the update rule

$$\mathbf w \leftarrow \mathbf w - \epsilon \nabla_{\mathbf w} \mathcal L$$ 

where $-\nabla_{\mathbf w} \mathcal L$ is the direction of steepest descent and $\epsilon > 0$ is a step size called the **learning rate**. Since the empirical loss is an average, its gradient is the average of the gradients $\nabla_\mathbf w \mathcal L (\mathbf x_i, \mathbf w)$, each evaluated at the point $\mathbf w$ on the loss surface generated by the data point $\mathbf x_i.$
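As a minimal sketch of this update rule, assume a hypothetical least-squares loss $\mathcal L(\mathbf x_i, \mathbf w) = \tfrac12 (\mathbf w^\top \mathbf x_i - y_i)^2$ on synthetic data (not a model used elsewhere in this notebook). The gradient of the empirical loss is computed as the average of per-example gradients and then used for plain gradient descent:

```python
import numpy as np

# Hypothetical synthetic regression data: targets generated by a known w_true
rng = np.random.default_rng(0)
N, d = 100, 3
X = rng.normal(size=(N, d))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true

def per_example_grad(w, x_i, y_i):
    # gradient of 0.5 * (w @ x_i - y_i)**2 with respect to w
    return (w @ x_i - y_i) * x_i

def dataset_grad(w):
    # gradient of the empirical loss = average of the per-example gradients
    return np.mean([per_example_grad(w, X[i], y[i]) for i in range(N)], axis=0)

def dataset_loss(w):
    return 0.5 * np.mean((X @ w - y) ** 2)

w = np.zeros(d)
eps = 0.1  # learning rate
for step in range(500):
    w = w - eps * dataset_grad(w)  # w <- w - eps * grad L

print(dataset_loss(w), w)
```

Averaging the per-example gradients is only for illustration; the same vector is obtained in one shot as `X.T @ (X @ w - y) / N`.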

```{margin}
**The need for efficient backprop.**
```

Since $\nabla_\mathbf w \mathcal L$ consists of one partial derivative for each weight in the network, it can easily have millions of components (or even billions for state-of-the-art models). How do we compute these derivatives efficiently? As discussed above, the gradient must be evaluated at the current state of the network, so we first perform a forward pass that computes the values of all nodes, given $\mathbf w$, up to the final node. This is followed by a backward pass in which we compute every partial derivative through a clever use of the chain rule, proceeding recursively backward through the layers of the network. Both passes can be implemented so that no value is computed twice.

## Backpropagation in DAGs

```{margin}
**Forward pass** 
```

Note that a neural network can be modelled as a **directed acyclic graph** (DAG) of compute and parameter nodes that implements a function $f$; this graph can be extended to also compute the loss value for each training example and parameter setting. In computing $f(\mathbf x),$ the values of the nodes are calculated from bottom to top (i.e. in topological order), storing every value in its node so that no known value has to be recomputed. Assuming each activation and each arithmetic operation between weights and node or input values takes constant time, one forward pass takes $\mathcal O(V + E)$ calculations, where $V$ is the number of compute nodes and $E$ is the number of parameters of the network, i.e. the network size. This is also the memory complexity of the whole operation.
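The forward pass can be sketched in plain Python as follows. The DAG, its node names, and the dictionary representation are hypothetical illustrations; `graphlib.TopologicalSorter` (Python 3.9+) supplies a bottom-to-top order. Each compute node is evaluated exactly once and reads each incoming edge once, which is where the $\mathcal O(V + E)$ count comes from.

```python
from graphlib import TopologicalSorter
import math

# Hypothetical DAG: leaves (input and parameters) have known values;
# each compute node maps to (function, list of input node names).
leaves = {"x": 2.0, "w1": 0.5, "w2": -1.0}
compute = {
    "z1": (lambda x, w1: w1 * x, ["x", "w1"]),
    "z2": (lambda x, w2: w2 * x, ["x", "w2"]),
    "a1": (math.tanh, ["z1"]),
    "out": (lambda a1, z2: a1 + z2, ["a1", "z2"]),
}

def forward(leaves, compute):
    values = dict(leaves)  # cache: every node value is stored once
    deps = {name: inputs for name, (_, inputs) in compute.items()}
    for name in TopologicalSorter(deps).static_order():
        if name in values:  # leaf node, value already known
            continue
        fn, inputs = compute[name]
        # all inputs come earlier in topological order, so this is a lookup
        values[name] = fn(*(values[i] for i in inputs))
    return values

values = forward(leaves, compute)
print(values["out"])  # tanh(0.5 * 2) + (-1 * 2)
```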

```{margin}
**Backward pass** 
```

During the backward pass, we divide the gradients into two groups: (1) **local gradients** ${\frac{\partial u}{\partial w}}$, which measure how a compute node $u$ changes when an adjacent node $w$ feeding into it is perturbed, and (2) **backpropagated gradients** of the form ${\frac{\partial{\mathcal L}}{\partial u}}$ for a node ${u}.$ Our goal is to calculate the gradient of the top-most node with respect to the leaves of the graph (i.e. nodes with zero fan-in).


Backpropagation proceeds by **induction**. (1) For the base step, $\frac{\partial{\mathcal L}}{\partial \mathcal L} = 1$ for the node that computes the loss value. This value is stored. (2) For the inductive step, suppose ${\frac{\partial{\mathcal L}}{\partial u}}$ is stored for each compute node $u$ in the layer above, i.e. each node $u$ that takes $w$ as a direct input. Then, after computing the local gradients ${\frac{\partial{u}}{\partial w}}$, the backpropagated gradient ${\frac{\partial{\mathcal L}}{\partial w}}$ for the compute node $w$ can be calculated via the chain rule:

$${ \frac{\partial\mathcal L}{\partial w} } = \sum_{ {u} }\left( {{\frac{\partial\mathcal L}{\partial u}}} \right)\left( {{\frac{\partial{u}}{\partial w}}} \right).$$


This continues the "flow" of gradients to the current layer. The process ends at nodes with zero fan-in. Note that the partial derivatives are evaluated at the current network state, using the node values stored during the forward pass, which precedes the backward pass. Analogously, each backpropagated gradient is stored in its compute node for use by the layer below. On the other hand, there is no need to store local gradients; these are computed as needed. Since each weight enters the graph as a leaf feeding into compute nodes, the same chain rule step gives its gradient from the backpropagated gradients of those compute nodes. Hence, it suffices to compute all gradients with respect to compute nodes to get all gradients with respect to the weights of the network.
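The induction can be sketched on a small hypothetical DAG in which node `y` has fan-out 2. Each compute node is given (by assumption) a forward function and a function returning its local gradients. The backward pass visits nodes in reverse topological order, so that $\frac{\partial \mathcal L}{\partial u}$ is already stored for every node $u$ consuming $w$ when $\frac{\partial \mathcal L}{\partial w}$ is computed as the chain-rule sum.

```python
from graphlib import TopologicalSorter
import math

# Hypothetical DAG: y = w * x,  p = y**2,  q = sin(y),  L = p + q  (y has fan-out 2)
# Each compute node: (forward fn, input names, fn returning d(node)/d(input) per input)
leaves = {"x": 1.5, "w": 0.8}
nodes = {
    "y": (lambda x, w: w * x, ["x", "w"], lambda x, w: [w, x]),
    "p": (lambda y: y ** 2, ["y"], lambda y: [2 * y]),
    "q": (math.sin, ["y"], lambda y: [math.cos(y)]),
    "L": (lambda p, q: p + q, ["p", "q"], lambda p, q: [1.0, 1.0]),
}

def forward(leaves, nodes):
    """Evaluate nodes bottom to top, storing every value."""
    deps = {n: ins for n, (_, ins, _) in nodes.items()}
    order = list(TopologicalSorter(deps).static_order())
    values = dict(leaves)
    for n in order:
        if n not in values:
            fn, ins, _ = nodes[n]
            values[n] = fn(*(values[i] for i in ins))
    return values, order

def backward(values, order, nodes, loss="L"):
    """Compute dL/dw for every node w by induction in reverse topological order."""
    consumers = {n: [] for n in order}  # nodes u that take w as a direct input
    for u, (_, ins, _) in nodes.items():
        for w in set(ins):
            consumers[w].append(u)
    grads = {loss: 1.0}                 # base step: dL/dL = 1
    for w in reversed(order):
        if w == loss:
            continue
        total = 0.0
        for u in consumers[w]:          # inductive step: sum over u of dL/du * du/dw
            _, ins, local = nodes[u]
            local_grads = local(*(values[i] for i in ins))  # at stored forward values
            total += grads[u] * sum(g for name, g in zip(ins, local_grads) if name == w)
        grads[w] = total
    return grads

values, order = forward(leaves, nodes)
grads = backward(values, order, nodes)
print(grads["w"], grads["x"])
```

Only the backpropagated gradients are kept in `grads`; local gradients are recomputed on demand from the stored forward values, as described above.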
    

```{figure} ../img/backprop-compgraph.png
---
width: 35em
name: backprop-compgraph
---
Backprop on a generic computational graph with fan-out > 1 on node <code>y</code>. Each backpropagated gradient is stored in the corresponding node. To calculate its backpropagated gradient, node <code>y</code> has to sum over the two incoming gradients, which can be implemented as an inner product of the incoming backpropagated gradients with the corresponding local gradients.
```


<br>

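**Backpropagation algorithm.** Suppose each compute node `u` implements `u.backward()`, which *sends* its backpropagated gradient $\partial \mathcal L / \partial u$, multiplied by the corresponding local gradients, to all its parent nodes, i.e. the nodes in the layer below, where the contributions are accumulated in `.grad`. We can now write the complete algorithm, assuming `compute` lists the compute nodes in topological order (each node after its inputs):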

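```python
def Forward():
    # evaluate compute nodes in topological order; each node stores its value
    for c in compute:
        c.forward()

def Backward(loss):
    # reset the gradients accumulated during the previous step
    for c in compute: c.grad = 0
    for c in params:  c.grad = 0
    for c in inputs:  c.grad = 0
    # base step: dL/dL = 1
    loss.grad = 1

    # inductive step: in reverse topological order, c.grad is complete
    # before c sends gradients to its inputs
    for c in compute[::-1]:
        c.backward()

def SGD(eta):
    # gradient descent update with learning rate eta
    for w in params:
        w.value -= eta * w.grad
```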
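To make the algorithm concrete, here is a minimal runnable sketch. The node classes `Leaf`, `Mul`, `Sub`, and `Square`, and the one-parameter toy problem $\mathcal L = (w x - t)^2$, are hypothetical. Each `backward` multiplies the node's `.grad` by its local gradients and accumulates the result into its inputs, which handles fan-out > 1 automatically.

```python
class Leaf:
    """Input or parameter node: holds a value and an accumulated gradient."""
    def __init__(self, value):
        self.value, self.grad = value, 0.0

class Node:
    """Compute node: stores its inputs, its forward value, and dL/d(node) in .grad."""
    def __init__(self, *inputs):
        self.inputs = inputs
        self.value, self.grad = 0.0, 0.0

class Mul(Node):  # u = a * b
    def forward(self):
        a, b = self.inputs
        self.value = a.value * b.value
    def backward(self):
        a, b = self.inputs
        a.grad += self.grad * b.value  # du/da = b
        b.grad += self.grad * a.value  # du/db = a

class Sub(Node):  # u = a - b
    def forward(self):
        a, b = self.inputs
        self.value = a.value - b.value
    def backward(self):
        a, b = self.inputs
        a.grad += self.grad  # du/da = 1
        b.grad -= self.grad  # du/db = -1

class Square(Node):  # u = a ** 2
    def forward(self):
        (a,) = self.inputs
        self.value = a.value ** 2
    def backward(self):
        (a,) = self.inputs
        a.grad += self.grad * 2 * a.value  # du/da = 2a

# Toy graph: L = (w * x - t)^2 with input x = 2 and target t = 3
w = Leaf(0.0)
x, t = Leaf(2.0), Leaf(3.0)
m = Mul(w, x)
r = Sub(m, t)
loss = Square(r)

params, inputs = [w], [x, t]
compute = [m, r, loss]  # topological order

def Forward():
    for c in compute:
        c.forward()

def Backward(loss):
    for c in compute + params + inputs:
        c.grad = 0.0
    loss.grad = 1.0
    for c in compute[::-1]:
        c.backward()

def SGD(eta):
    for p in params:
        p.value -= eta * p.grad

for step in range(100):
    Forward()
    Backward(loss)
    SGD(0.05)

print(w.value)  # approaches t / x = 1.5
```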

```{margin}
**Backpropagation equations for MLPs**
```

Consider a dense, fully connected neural network, which is a computational DAG. Let ${z_k}^{[t]} = \sum_l {w_{kl}}^{[t]}{a_l}^{[t-1]}$ and $\mathbf a^{[t]} = \phi^{[t]}(\mathbf z^{[t]})$, with components ${a_k}^{[t]}$ and ${z_k}^{[t]}$, be the values of the compute nodes at the $t$-th layer of the network. The backpropagated gradients for the compute nodes of the current layer are given by
    
$$\begin{aligned}
        \dfrac{\partial \mathcal L}{\partial {a_j}^{[t]}} 
        &= \sum_{k}\dfrac{\partial \mathcal L}{\partial {z_k}^{[t+1]}} \dfrac{\partial {z_k}^{[t+1]}}{\partial {a_j}^{[t]}} = \sum_{k}\dfrac{\partial \mathcal L}{\partial {z_k}^{[t+1]}} {w_{kj}}^{[t+1]}
    \end{aligned}$$

and

$$\begin{aligned}
    \dfrac{\partial \mathcal L}{\partial {z_j}^{[t]}} 
    &= \sum_{k}\dfrac{\partial \mathcal L}{\partial {a_k}^{[t]}} \dfrac{\partial {a_k}^{[t]}}{\partial {z_j}^{[t]}}.
\end{aligned}$$

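For elementwise activations such as ReLU, $\partial {a_k}^{[t]} / \partial {z_j}^{[t]} = 0$ for $k \neq j$, so this sum reduces to a single term; this is not the case for softmax, where each output depends on every input. Similarly, the backpropagated gradients for the parameter nodes are given by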

$$\begin{aligned}
    \dfrac{\partial \mathcal L}{\partial {w_{kl}}^{[t]}} 
    &= \dfrac{\partial \mathcal L}{\partial {z_k}^{[t]}} \dfrac{\partial {z_k}^{[t]}}{\partial {w_{kl}}^{[t]}} = \dfrac{\partial \mathcal L}{\partial {z_k}^{[t]}} {a_l}^{[t-1]}.
\end{aligned}$$

Backpropagated gradients for compute nodes are stored until the weights are updated, e.g. the values $\frac{\partial \mathcal L}{\partial {z_k}^{[t+1]}}$ stored in the compute nodes of layer $t+1$ are retrieved to compute the gradients in layer $t$. On the other hand, the local gradients $\frac{\partial {a_k}^{[t]}}{\partial {z_j}^{[t]}}$ are computed directly using autodifferentiation and evaluated at the current network state obtained during the forward pass.
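The MLP equations above can be checked numerically. The following NumPy sketch assumes a hypothetical two-layer network with $\tanh$ activations (so $\partial a_k / \partial z_j$ is diagonal) and a squared-error loss, and compares one backpropagated weight gradient against a central finite difference.

```python
import numpy as np

rng = np.random.default_rng(0)
sizes = [3, 4, 2]  # hypothetical widths: input, hidden, output
Ws = [rng.normal(size=(sizes[t + 1], sizes[t])) for t in range(len(sizes) - 1)]
x = rng.normal(size=sizes[0])
target = rng.normal(size=sizes[-1])

def forward(Ws, x):
    """Return pre-activations z^[t] and activations a^[t] (acts[0] is the input)."""
    zs, acts = [], [x]
    for W in Ws:
        z = W @ acts[-1]          # z_k = sum_l w_kl a_l
        zs.append(z)
        acts.append(np.tanh(z))   # a = phi(z), elementwise
    return zs, acts

def loss_fn(a_out):
    return 0.5 * np.sum((a_out - target) ** 2)

def backward(Ws, zs, acts):
    """Backpropagate dL/da and dL/dz layer by layer; return dL/dW for each layer."""
    grads = [None] * len(Ws)
    dL_da = acts[-1] - target                      # gradient of the squared-error loss
    for t in reversed(range(len(Ws))):
        dL_dz = dL_da * (1 - np.tanh(zs[t]) ** 2)  # elementwise phi: sum over k has one term
        grads[t] = np.outer(dL_dz, acts[t])        # dL/dw_kl = dL/dz_k * a_l (previous layer)
        dL_da = Ws[t].T @ dL_dz                    # dL/da_j = sum_k dL/dz_k * w_kj
    return grads

zs, acts = forward(Ws, x)
grads = backward(Ws, zs, acts)

# central finite difference on a single weight entry, for comparison
h = 1e-6
Wp = [W.copy() for W in Ws]; Wp[0][1, 2] += h
Wm = [W.copy() for W in Ws]; Wm[0][1, 2] -= h
fd = (loss_fn(forward(Wp, x)[1][-1]) - loss_fn(forward(Wm, x)[1][-1])) / (2 * h)
print(grads[0][1, 2], fd)
```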

<br>

We highlight two important properties of the algorithm that make it the practical choice for training huge neural networks:

* **Modularity.** Each backpropagated gradient depends only on nodes in the layer above, which makes the computation modular: e.g. we can connect DAG subnetworks with possibly distinct architectures by only connecting the nodes exposed between them.

<br>

* **Bottleneck and complexity.** Assuming each local derivative takes constant time to compute, the backward pass requires $\mathcal O(V + E)$ computations: $\mathcal O(E)$ for the chain-rule products of local and backpropagated gradients along the weighted edges, where $E$ is the number of weights, and $\mathcal O(V)$ for the activation gradients, where $V$ is the number of nodes. Thus, efficient autodifferentiation and fast matrix multiplication, e.g. on specialized hardware for parallelism, are crucial for making backpropagation efficient. Since each weight and each backpropagated gradient must be stored during the backward pass, the memory complexity is likewise linear in the network size.

```{figure} ../img/backprop-compgraph2.png
---
width: 35em
name: backprop-compgraph2
---
Backprop with weights for a single layer neural network with sigmoid activation and cross-entropy loss. Observe the gradient flowing from node <code>L</code> to the node <code>w0</code>.
```

```{figure} ../img/backprop-compgraph3.png
---
width: 25em
name: backprop-compgraph3
---
Backprop with weights for a single layer neural network with sigmoid activation and cross-entropy loss. Local gradients require the current values of the nodes, while backpropagated gradients are accessed from the layer above. Node <code>u</code>, which has fan-in > 1, applies the chain rule to the backpropagated gradients.
```

## Autodifferentiation with PyTorch autograd

The `autograd` package provides automatic differentiation by building computational graphs on the fly every time we pass data through our model. Autograd records which tensors were combined through which operations to produce the output, which allows us to take derivatives of ordinary imperative code. This functionality is consistent with the linear time and memory requirements outlined in the discussion above.
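A small sketch of differentiating through imperative code: the function `f` below is a hypothetical example with a data-dependent loop and branch. Autograd records only the operations that actually run, so the gradient is that of the executed path (here, repeated doubling followed by a sum).

```python
import torch

def f(x):
    # ordinary Python control flow; the graph is built as the code executes
    y = x
    while y.norm() < 10:
        y = y * 2
    if y.sum() > 0:
        return y.sum()
    return -y.sum()

x = torch.tensor([0.5, 1.0, -0.25], requires_grad=True)
out = f(x)
out.backward()
print(out.grad_fn, x.grad)  # x.grad = 2**k for each entry, k = number of doublings
```

For this input the loop runs four times (the norm grows from about 1.15 to about 18.3), so every entry of `x.grad` equals 16.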

<br>

**Backward for scalars.** Let $y = \mathbf x^\top \mathbf x = \sum_i {x_i}^2.$ In this example, we initialize a tensor `x`, which initially has no gradient (`x.grad` is `None`). Calling backward on `y` stores the gradient $\nabla_{\mathbf x} y = 2\mathbf x$ in the leaf tensor `x`.

```python
import torch

x = torch.arange(4, dtype=torch.float, requires_grad=True)
y = x @ x  # inner product x^T x; transposing a 1D tensor is a no-op

y.backward()
(x.grad == 2*x).all()
# tensor(True)
```

**Backward for vectors.** Let $\mathbf y = g(\mathbf x)$ and let $\mathbf v$ be a vector of the same length as $\mathbf y.$ Then `y.backward(v)` computes the vector-Jacobian product

$$\sum_i v_i \left(\frac{\partial y_i}{\partial x_j}\right)$$ 

for each index $j$ of `x`, resulting in a vector of the same length as `x` that is accumulated into `x.grad`. Note that the terms in parentheses are the local gradients in backprop. Hence, if `v` contains the backpropagated gradients of `y`, i.e. $v_i = \frac{\partial \mathcal{L} }{\partial y_i}$, this operation gives us the backpropagated gradients $\frac{\partial \mathcal{L} }{\partial x_j}$ with respect to `x`.

```python
import torch

x = torch.rand(size=(4,), dtype=torch.float, requires_grad=True)
v = torch.rand(size=(2,), dtype=torch.float)
y = x[:2]

# Computing the Jacobian by hand
J = torch.tensor(
    [[1, 0, 0, 0],
     [0, 1, 0, 0]], dtype=torch.float
)

# Confirming the above formula
y.backward(v)
(x.grad == v @ J).all()
# tensor(True)
```
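The slicing example has a constant Jacobian. As a further check with a nonlinear map, the sketch below compares `y.backward(v)` against the full Jacobian returned by `torch.autograd.functional.jacobian`; the map `g` and the vector `v` are arbitrary choices for illustration. Note that `backward` accumulates into `x.grad`, so calling it a second time without zeroing the gradient would double the result.

```python
import torch

def g(x):
    # hypothetical nonlinear map from R^3 to R^2
    return torch.stack([x[0] * x[1], torch.sin(x[2]) + x[0] ** 2])

x = torch.tensor([1.0, 2.0, 0.5], requires_grad=True)
v = torch.tensor([0.3, -1.2])  # plays the role of dL/dy

y = g(x)
y.backward(v)  # accumulates v^T J into x.grad without forming J

J = torch.autograd.functional.jacobian(g, x.detach())  # full 2x3 Jacobian, for comparison only
print(torch.allclose(x.grad, v @ J))
```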

**Locally disabling gradient tracking.** To stop PyTorch from building computational graphs, we can put the code inside a `with torch.no_grad()` block. In this mode, the result of every computation will have `requires_grad=False`, even when the inputs have `requires_grad=True`. 
<br><br>
Another method is to use the `.detach()` method, which returns a new tensor detached from the current graph that shares the same storage with the original one. In-place modifications on either of them will be seen by both and may trigger errors in autograd's correctness checks. Disabling gradient computation is useful when computing values whose gradients will not be backpropagated into the network, e.g. accuracy.
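Both mechanisms can be seen in the short sketch below; the weights, data, and labels are random placeholders, and the accuracy computation stands in for any quantity that should not be backpropagated.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)  # placeholder weights
x = torch.randn(5, 3)                   # placeholder batch
labels = torch.tensor([1, 0, 1, 1, 0])  # placeholder labels

# 1) no_grad: no graph is recorded and results do not require grad
with torch.no_grad():
    scores = x @ w
    accuracy = ((scores > 0).long() == labels).float().mean()
print(scores.requires_grad, accuracy.requires_grad)  # False False

# 2) detach: a new tensor cut from the graph that shares storage with the original
a = x @ w
z = a.detach()
print(a.requires_grad, z.requires_grad)  # True False
print(z.data_ptr() == a.data_ptr())      # True
```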