# PyTorch Backprop

Let's start with an example:

In [None]:
import math
import torch

import matplotlib.pyplot as plt

a = torch.linspace(0., 2. * math.pi, steps=25, requires_grad=True) # leaf node
b = torch.sin(a)
b.retain_grad() # b is a non-leaf node, so we must call retain_grad() explicitly; gradients of non-leaf nodes are not kept by default.
plt.plot(a.detach(), b.detach(), label="b")

c = torch.cos(a)
c.retain_grad()
plt.plot(a.detach(), c.detach(), label="c")

d = b + c
d.retain_grad()
plt.plot(a.detach(), d.detach(), label="d")
plt.legend()

out = d.sum() # root node
out.retain_grad()
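
To make the leaf / non-leaf distinction concrete, here is a minimal sketch (the names `x`, `y`, `z` are illustrative and not part of the example above). It assumes a PyTorch version that exposes `Tensor.retains_grad`; only the leaf and the tensor that called `retain_grad()` end up with a populated `.grad`.

```python
import torch

x = torch.linspace(0., 1., steps=5, requires_grad=True)  # leaf: created directly by the user
y = torch.sin(x)   # non-leaf: result of an operation, gradient not kept
z = torch.cos(x)   # non-leaf, but we ask autograd to keep its gradient
z.retain_grad()

(y + z).sum().backward()

print(x.is_leaf, y.is_leaf, z.is_leaf)   # leaf status of each tensor
print(y.retains_grad, z.retains_grad)    # whether .grad will be populated for non-leaves
print(x.grad is not None)                # leaves that require grad always get .grad
print(z.grad)                            # d(sum)/dz, kept thanks to retain_grad()
```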

The *computation graph* (a DAG) looks like:

In [None]:
import graphviz
dot = graphviz.Digraph()
dot.node('a', 'a = [0, 2pi]')
dot.node('b', 'b = sin(a)')
dot.node('c', 'c = cos(a)')
dot.node('d', 'd = b + c')
dot.node('o', 'out = sum(d)')
dot.edges(['ab', 'ac', 'bd', 'cd', 'do'])
dot

## grad_fn, backward, and grad
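A tensor's `grad_fn` is the function node that computes the local gradient of the operation that produced it, and its `next_functions` property points back to the parent node(s) in the computation graph. A leaf tensor such as `a` has no `grad_fn`: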

In [None]:
print(a.grad_fn)

In [None]:
b.grad_fn

In [None]:
# note that grad_fn `SinBackward0` is a built-in (non-python) function:
import inspect
inspect.getmro(b.grad_fn.__class__)

In [None]:
c.grad_fn

In [None]:
d.grad_fn

In [None]:
out.grad_fn

In [None]:
out.grad_fn.next_functions # only a single parent

In [None]:
d.grad_fn.next_functions # two parents

In [None]:
d.grad_fn.next_functions[1][0].next_functions # right parent
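
Rather than following `next_functions` by hand, the whole backward graph can be printed with a small recursive walk. This is only a sketch: `walk` is a hypothetical helper, not a PyTorch API, and it rebuilds the same expression on a fresh tensor `x` so it can run on its own.

```python
import math
import torch

def walk(fn, depth=0, lines=None):
    """Collect one indented line per grad_fn node, depth-first."""
    if lines is None:
        lines = []
    if fn is None:  # inputs that do not require grad have no node
        return lines
    lines.append("  " * depth + type(fn).__name__)
    # Each entry of next_functions is a (node, index) pair; only the node is needed here.
    for parent, _idx in fn.next_functions:
        walk(parent, depth + 1, lines)
    return lines

x = torch.linspace(0., 2. * math.pi, steps=25, requires_grad=True)
graph_root = (torch.sin(x) + torch.cos(x)).sum().grad_fn
graph_lines = walk(graph_root)
print("\n".join(graph_lines))
```

The leaf shows up as an `AccumulateGrad` node, once per path that reaches it, because this walk does not deduplicate shared parents.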

Before we can read `grad`, we need to "back propagate" (a typical implementation walks the graph backward and accumulates gradients in reverse topological order):

In [None]:
out.backward(retain_graph=True) # retain_graph=True keeps the computation graph so that we can call backward() again (see later)
# Because `out` is a scalar, PyTorch implicitly assumes the argument `gradient=tensor(1.)`.
# This is convenient because `out` is most often a loss value, and the derivative of the loss w.r.t. itself is 1.

In [None]:
print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(out.grad)

In [None]:
plt.plot(a.detach(), b.detach(), label="b")
plt.plot(a.detach(), b.grad.detach(), label="b.grad")
plt.plot(a.detach(), a.grad.detach(), label="a.grad")
eps = 0.05 # small vertical offset so the analytic curve does not hide a.grad
plt.plot(a.detach(), torch.cos(a).detach() - torch.sin(a).detach() + eps, label="cos(a) - sin(a)")
plt.legend()

If we write this out as equations:
$$
\begin{align}
\text{out} &= \sum_i d_i \\
d_i &= b_i + c_i \\
b_i &= \sin(a_i) \\
c_i &= \cos(a_i) \\
\end{align}
$$
therefore
$$
\begin{align}
\text{b[i].grad} &= \frac{\partial}{\partial b_i}\sum_{j}d_j = 1 \\
\text{a[i].grad} &=  1 \cdot \frac{\partial}{\partial a_i} (\sin a_i + \cos a_i) = \cos a_i - \sin a_i
\end{align}
$$
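
The derivation can also be checked numerically instead of by eye. The sketch below rebuilds the same computation on fresh tensors (`x`, `y`, `total` are illustrative names) and compares autograd's result with the closed-form expressions; the tolerance `atol=1e-6` is an arbitrary choice for float32.

```python
import math
import torch

x = torch.linspace(0., 2. * math.pi, steps=25, requires_grad=True)
y = torch.sin(x)
y.retain_grad()                      # keep the gradient of the non-leaf y
total = (y + torch.cos(x)).sum()
total.backward()

# Closed-form gradient from the derivation: cos(x) - sin(x)
expected = torch.cos(x.detach()) - torch.sin(x.detach())
x_matches = torch.allclose(x.grad, expected, atol=1e-6)
y_matches = torch.allclose(y.grad, torch.ones_like(y))
print(x_matches, y_matches)
```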

Note that we can also call `backward()` from an intermediate node:

In [None]:
# clear the accumulated gradients
a.grad.zero_()
b.grad.zero_()
c.grad.zero_()
d.grad.zero_()
out.grad.zero_()
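
The reason for zeroing is that `backward()` adds to `.grad` rather than overwriting it. A minimal sketch on a separate toy tensor (the names are illustrative):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

(x ** 2).sum().backward()
first = x.grad.clone()     # gradient of sum(x^2) is 2 * x

(x ** 2).sum().backward()  # a fresh graph, but the same leaf
second = x.grad.clone()    # the new gradient was added to the old one

x.grad.zero_()             # reset in place, as in the cell above
(x ** 2).sum().backward()
third = x.grad.clone()     # back to a single contribution

print(first, second, third)
```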

In [None]:
d.backward(gradient=torch.ones(25))

We get the same gradients for `a` and `b` as before:

In [None]:
plt.plot(a.detach(), b.detach(), label="b")
plt.plot(a.detach(), b.grad.detach(), label="b.grad")
plt.plot(a.detach(), a.grad.detach(), label="a.grad")
eps = 0.05 # small vertical offset so the analytic curve does not hide a.grad
plt.plot(a.detach(), torch.cos(a).detach() - torch.sin(a).detach() + eps, label="cos(a) - sin(a)")
plt.legend()

However, there is a minor difference: because `out` was not part of this `backward()` call, nothing was propagated to it, so its gradient is still zero:

In [None]:
print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(out.grad)
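
The `gradient` argument passed to `d.backward()` plays the role of the upstream gradient: autograd multiplies it into the chain rule (a vector-Jacobian product). The sketch below uses fresh tensors and a made-up upstream vector `v` to suggest that `y.backward(gradient=v)` matches differentiating the scalar `sum(v * y)`; the tolerance is an arbitrary choice.

```python
import math
import torch

x = torch.linspace(0., 2. * math.pi, steps=25, requires_grad=True)
y = torch.sin(x) + torch.cos(x)      # non-scalar, so backward() needs `gradient`

v = torch.arange(25, dtype=x.dtype)  # hypothetical upstream gradient d(loss)/dy
y.backward(gradient=v, retain_graph=True)  # keep the graph for the second pass
vjp = x.grad.clone()

# Same quantity via an explicit scalar: loss = sum_i v_i * y_i
x.grad.zero_()
(v * y).sum().backward()
same = torch.allclose(vjp, x.grad, atol=1e-5)
print(same)
```

With `v = torch.ones(25)` this reduces to the `d.backward(gradient=torch.ones(25))` call used earlier, which is why it reproduces the gradients of `out = d.sum()`.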

## Finally, a simple "torch.Tensor.backward()"
To dig a little deeper, take a look at the toy Tensor class in [backward.py](./backward.py), which shows how the backward method works in a nutshell.
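
For readers without the file at hand, the idea can be sketched in a few lines. The `Value` class below is a hypothetical scalar-only illustration written for this note; it is not the contents of `backward.py` and not how PyTorch is implemented. It records each node's parents and local derivatives, orders the nodes with a DFS postorder, and applies the chain rule in reverse.

```python
import math

class Value:
    """Hypothetical scalar node; not the Tensor class from backward.py."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self.parents = parents          # nodes this value was computed from
        self.local_grads = local_grads  # d(self)/d(parent), one per parent

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def sin(self):
        return Value(math.sin(self.data), (self,), (math.cos(self.data),))

    def cos(self):
        return Value(math.cos(self.data), (self,), (-math.sin(self.data),))

    def backward(self):
        order, seen = [], set()

        def visit(node):  # DFS postorder: parents are appended before the node
            if id(node) in seen:
                return
            seen.add(id(node))
            for parent in node.parents:
                visit(parent)
            order.append(node)

        visit(self)
        self.grad = 1.0  # d(self)/d(self)
        for node in reversed(order):  # reverse topological order
            for parent, local in zip(node.parents, node.local_grads):
                parent.grad += node.grad * local  # chain rule, accumulated

x = Value(0.5)
y = x.sin() + x.cos()
y.backward()
print(x.grad, math.cos(0.5) - math.sin(0.5))
```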

Note that real-world, production implementations (like PyTorch) replace the "DFS postorder" topological ordering with "Kahn's algorithm", because the latter produces a deterministic ordering and lets vertices be executed, in batches, as soon as they are ready. Furthermore, Kahn's algorithm naturally includes cycle detection.
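
As a rough illustration of that idea, here is Kahn's algorithm applied to the example graph from the beginning. It is a sketch over a plain dictionary, not PyTorch's engine: `kahn_backward_order` and the `inputs_of` mapping are hypothetical names. Each node waits until all of its consumers have been processed, and a node that never becomes ready signals a cycle.

```python
from collections import deque

def kahn_backward_order(inputs_of, root):
    """Order nodes so that each one comes after every node that consumes it.

    inputs_of maps a node to the nodes it was computed from (a hypothetical
    plain-dict graph, not PyTorch's internal representation).
    """
    # Count, for every node reachable from root, how many consumers it has.
    pending = {root: 0}
    stack = [root]
    while stack:
        node = stack.pop()
        for parent in inputs_of.get(node, ()):
            if parent not in pending:
                pending[parent] = 0
                stack.append(parent)
            pending[parent] += 1

    ready = deque([root])
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for parent in inputs_of.get(node, ()):
            pending[parent] -= 1
            if pending[parent] == 0:  # all consumers have sent their gradient
                ready.append(parent)

    if len(order) != len(pending):  # some node never became ready
        raise ValueError("cycle detected")
    return order

graph = {"out": ["d"], "d": ["b", "c"], "b": ["a"], "c": ["a"], "a": []}
order = kahn_backward_order(graph, "out")
print(order)  # ['out', 'd', 'b', 'c', 'a']
```

Here `b` and `c` become ready at the same time, which is the point where an engine could process them together, while `a` has to wait for both.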