# PyTorch Backprop (Advanced)

## Softmax
Let's first focus on the softmax function, a mapping $x \mapsto y: R^{n} \rightarrow R^{n}$ in which every output element depends on all input elements, so its Jacobian matrix is non-trivial. This makes it a good example function for studying backprop:

In [270]:
import torch
from torch.nn.functional import softmax

x = torch.randn((5,), requires_grad=True) # a test input vector
y = softmax(x, dim=-1)
print(x)
print(y)

y_grad = torch.randn((5,), requires_grad=False) 

y.backward(y_grad, retain_graph=True)
print(x.grad)
#assert torch.allclose(x.grad, y.grad_fn(y_grad)) # same as calling y.grad_fn()
#assert torch.allclose(x.grad, torch.autograd.grad(y, x, y_grad)[0]) # an alternative way to get x.grad

tensor([ 0.9403, -0.7540, -0.5940, -1.6701,  0.7946], requires_grad=True)
tensor([0.4279, 0.0786, 0.0922, 0.0315, 0.3698], grad_fn=<SoftmaxBackward0>)
tensor([ 0.1915, -0.0174,  0.0615,  0.0018, -0.2374])


Assume the hypothetical scalar loss function (at the root of our computation graph) is $L$. The output value `x.grad`, i.e. $\nabla_x^T L$, is then given by the [chain rule](../math/chain_rule_and_jacobian.ipynb) as $\nabla_{y}^T L \cdot J_x y$, where $\nabla_{y} L$ is just `y_grad`.
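
To make the chain-rule statement concrete, the following sketch builds the softmax Jacobian from its closed form, $J_x y = \mathrm{diag}(y) - y y^T$ (a standard result that this notebook does not derive), and checks that `v @ J` reproduces what `backward` writes into `.grad`. The names `x_demo`, `y_demo`, `v_demo`, and `J_analytic` are introduced here for illustration only.

```python
import torch
from torch.nn.functional import softmax

torch.manual_seed(0)
x_demo = torch.randn(5, requires_grad=True)
y_demo = softmax(x_demo, dim=-1)
v_demo = torch.randn(5)  # plays the role of y_grad, i.e. the gradient of L w.r.t. y

# Closed-form softmax Jacobian: J[i, j] = y_i * (delta_ij - y_j)
y_val = y_demo.detach()
J_analytic = torch.diag(y_val) - torch.outer(y_val, y_val)

# Autograd result: backward() stores the vector-Jacobian product in x_demo.grad
y_demo.backward(v_demo)

# Manual result: row vector times Jacobian
manual_grad = v_demo @ J_analytic
print(torch.allclose(x_demo.grad, manual_grad, atol=1e-6))
```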

## Softmax (Naively Batched)

We can also compute a "batched" version. In the backprop context used here, "batch" means the same (non-batched) input vector $x$ combined with a batch of `y_grad` vectors. (If the inputs $x$ were batched and different from one another, the values of the Jacobian matrix $\left. J_{x}y \right|_{x = x_1, x_2, ..., x_B}$ would differ from sample to sample.)

In [271]:
Y_grad = torch.randn((3, 5), requires_grad=False)  # a test **batch**
# Set its first row to the vector used above, which makes the differences easier to observe.
Y_grad[0, :] = y_grad.detach()
print(Y_grad)

# y = softmax(x, dim=-1) # same x, same y, so there is no need to recompute it

x.grad.zero_()
y.backward(Y_grad[0, :], retain_graph=True)
print(x.grad) # same as the old x.grad

x.grad.zero_()
y.backward(Y_grad[1, :], retain_graph=True)
print(x.grad)

x.grad.zero_()
y.backward(Y_grad[2, :], retain_graph=True)
print(x.grad)

tensor([[ 0.1229, -0.5459,  0.3422, -0.2690, -0.9665],
        [-0.7317,  0.7580, -1.4002, -0.1251, -1.5846],
        [-1.2540,  0.4095, -0.0615, -1.0414,  0.9484]])
tensor([ 0.1915, -0.0174,  0.0615,  0.0018, -0.2374])
tensor([ 0.1031,  0.1360, -0.0394,  0.0267, -0.2263])
tensor([-0.4544,  0.0473,  0.0120, -0.0267,  0.4218])
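
The `x.grad.zero_()` calls above are needed because `backward` accumulates into `x.grad`. As a hedged alternative sketch, `torch.autograd.grad` returns the gradient instead of accumulating it, so the same loop can be written without resetting any state. The names `x_demo`, `y_demo`, `Y_grad_demo`, and `rows` are hypothetical stand-ins for the notebook's variables.

```python
import torch
from torch.nn.functional import softmax

torch.manual_seed(0)
x_demo = torch.randn(5, requires_grad=True)
y_demo = softmax(x_demo, dim=-1)
Y_grad_demo = torch.randn(3, 5)  # three different upstream gradients for the same x

# One backward pass per row; retain_graph=True keeps the graph alive for the next pass.
rows = [
    torch.autograd.grad(y_demo, x_demo, grad_outputs=v, retain_graph=True)[0]
    for v in Y_grad_demo
]
X_grad_demo = torch.stack(rows)
print(X_grad_demo.shape)

# torch.autograd.grad does not write into the .grad attribute of the leaf tensor.
print(x_demo.grad is None)
```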


## VJP- and JVP-Batched
A VJP (vector-Jacobian product) is simply the name for the following expression:
$$
\underbrace{\mathbf{v}^\top}_{\text{cotangents}}
\underbrace{\frac{\partial f(\mathbf{x})}{\partial \mathbf{x}}}_{\substack{\text{Jacobian of }f \\ \text{w.r.t. primals}}}
$$

Similarly, a JVP (Jacobian-vector product) is
$$
\underbrace{\frac{\partial f(\mathbf{x})}{\partial \mathbf{x}}}_{\substack{\text{Jacobian of }f \\ \text{w.r.t. primals}}}
\underbrace{\mathbf{u}}_{\text{tangents}}
$$
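
A small sketch may help fix the shapes. It uses a hypothetical linear map `f(x) = A @ x` from $R^3$ to $R^2$, whose Jacobian is just `A`, so both products can be checked by hand. Note that the cotangent `v` has the shape of the output, while the tangent `u` has the shape of the input. The names `A`, `f`, `x0`, `v`, and `u` are illustrative only.

```python
import torch

A = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])  # Jacobian of f below, shape (2, 3)

def f(x):
    return A @ x  # maps R^3 -> R^2

x0 = torch.tensor([1.0, 0.0, -1.0])  # primal (input point)
v = torch.tensor([1.0, -1.0])        # cotangent: same shape as the output
u = torch.tensor([0.5, 0.5, 0.0])    # tangent: same shape as the input

# VJP: v^T A = [1-4, 2-5, 3-6] = [-3, -3, -3]
_, vjp_fn = torch.func.vjp(f, x0)
vjp_out = vjp_fn(v)[0]

# JVP: A u = [0.5+1.0, 2.0+2.5] = [1.5, 4.5]
_, jvp_out = torch.func.jvp(f, (x0,), (u,))

print(vjp_out)
print(jvp_out)
```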

In [272]:
y2, vjp_fn = torch.func.vjp(softmax, x) # torch.func.vjp(func, *primals)
assert torch.allclose(y, y2)

In [273]:
vjp_fn(y_grad) # y_grad is the "v"

(tensor([ 0.1915, -0.0174,  0.0615,  0.0018, -0.2374],
        grad_fn=<SoftmaxBackwardDataBackward0>),)

However, `vjp_fn` cannot take a batched `y_grad`, i.e., `Y_grad`:

In [274]:
try:
    vjp_fn(Y_grad)  # error!
except Exception as e:
    print(e)

Mismatch in shape: grad_output[0] has a shape of torch.Size([3, 5]) and output[0] has a shape of torch.Size([5]).


A simple solution is to apply `vjp_fn` to each row, which is what `vmap` does in vectorized form:

In [275]:
batched_vjp_fn = torch.func.vmap(vjp_fn)
vmap_X_grad = batched_vjp_fn(Y_grad)[0].detach()
assert torch.allclose(vmap_X_grad, torch.vstack([
    vjp_fn(Y_grad[0, :])[0].detach(),
    vjp_fn(Y_grad[1, :])[0].detach(),
    vjp_fn(Y_grad[2, :])[0].detach(),
]))
vmap_X_grad

tensor([[ 0.1915, -0.0174,  0.0615,  0.0018, -0.2374],
        [ 0.1031,  0.1360, -0.0394,  0.0267, -0.2263],
        [-0.4544,  0.0473,  0.0120, -0.0267,  0.4218]])

where each row matches exactly the non-batched `x.grad` obtained after the corresponding `y.backward` call above.

Because the forward pass runs only once (inside `vjp`) and `vmap` vectorizes the backward computation instead of looping over it in Python, this batched version typically runs faster.$^{1}$
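
The speed claim can be checked locally with a rough timing sketch like the one below. It is an assumption-laden micro-benchmark, not a rigorous measurement: the sizes (256 cotangents of length 512) are arbitrary, the names `x_demo`, `V_demo`, and `vjp_fn_demo` are hypothetical, and the timings depend on hardware and PyTorch version. Only the equality of the two results is guaranteed.

```python
import time

import torch
from torch.nn.functional import softmax

torch.manual_seed(0)
x_demo = torch.randn(512)
V_demo = torch.randn(256, 512)  # 256 cotangent vectors for the same input

# The forward pass happens once, here.
_, vjp_fn_demo = torch.func.vjp(lambda t: softmax(t, dim=-1), x_demo)

# Python loop: one VJP call per cotangent.
start = time.perf_counter()
looped = torch.stack([vjp_fn_demo(v)[0] for v in V_demo])
loop_seconds = time.perf_counter() - start

# vmap: a single vectorized VJP call over the whole batch.
start = time.perf_counter()
batched = torch.func.vmap(vjp_fn_demo)(V_demo)[0]
vmap_seconds = time.perf_counter() - start

print(f"loop: {loop_seconds:.4f}s  vmap: {vmap_seconds:.4f}s")
print(torch.allclose(looped, batched, atol=1e-6))
```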

Similarly, for `jvp` (note that, unlike `vjp`, PyTorch's `jvp` takes its primals and tangents as tuples):

In [276]:
# reset
x.grad.zero_()
y = softmax(x, dim=-1)
y_grad = torch.randn((5,), requires_grad=False)
y.backward(y_grad)
print(x.grad)

tensor([ 0.1941, -0.0545, -0.1106, -0.0096, -0.0194])


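$$
\begin{aligned}
& \text{y-grad}^T \cdot J_x y(x) = \text{x-grad}^T \\
\Rightarrow\quad & J^T_x y(x) \cdot \text{y-grad} = \text{x-grad}
\end{aligned}
$$

Note that a JVP computes $J_x y(x) \cdot u$ rather than $J^T_x y(x) \cdot u$. The two coincide in the comparison below only because the softmax Jacobian, $\mathrm{diag}(y) - y y^T$, is symmetric.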

In [277]:
# Move vjp here for easy comparison:
# y2, vjp_fn = torch.func.vjp(softmax, x) # torch.func.vjp(func, *primals)
y3, jvp_out = torch.func.jvp(softmax, primals=(x, ), tangents=(y_grad, )) # torch.func.jvp(func, primals, tangents)
assert torch.allclose(y, y3)
assert torch.allclose(x.grad, jvp_out)
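
JVP and VJP agree for softmax because its Jacobian is symmetric, which is not true in general. The sketch below uses `cumsum`, chosen here purely as an illustration: its Jacobian is a lower-triangular matrix of ones, so feeding the same vector `w` in as a tangent and as a cotangent gives different results. The names `f_cumsum`, `x0`, and `w` are hypothetical.

```python
import torch

def f_cumsum(t):
    # Jacobian is [[1, 0, 0], [1, 1, 0], [1, 1, 1]], which is not symmetric.
    return torch.cumsum(t, dim=0)

x0 = torch.tensor([1.0, 2.0, 3.0])
w = torch.tensor([1.0, 10.0, 100.0])  # used both as tangent and as cotangent

# JVP: J w = [1, 1+10, 1+10+100] = [1, 11, 111]
_, jvp_out = torch.func.jvp(f_cumsum, (x0,), (w,))

# VJP: w^T J = [1+10+100, 10+100, 100] = [111, 110, 100]
_, vjp_fn = torch.func.vjp(f_cumsum, x0)
vjp_out = vjp_fn(w)[0]

print(jvp_out)
print(vjp_out)
```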

## Jacobian

The manipulations above all hide the actual Jacobian matrix from view. How do we retrieve it? Simple: construct an identity matrix and extract the Jacobian row by row:

$$
\begin{bmatrix}
e_1 \\
e_2 \\
\vdots \\
e_5
\end{bmatrix}
\cdot [ J_x \; y(x) ]_{5 \times 5} = E_{5 \times 5} \cdot [ J_x \; y(x) ]_{5 \times 5} = [ J_x \; y(x) ]_{5 \times 5}
$$

In [278]:
eye = torch.eye(5, 5)
J = batched_vjp_fn(eye)[0].detach()
print(J)

tensor([[ 0.2448, -0.0336, -0.0395, -0.0135, -0.1582],
        [-0.0336,  0.0724, -0.0073, -0.0025, -0.0291],
        [-0.0395, -0.0073,  0.0837, -0.0029, -0.0341],
        [-0.0135, -0.0025, -0.0029,  0.0305, -0.0116],
        [-0.1582, -0.0291, -0.0341, -0.0116,  0.2331]])


Let's verify whether this matrix is indeed the Jacobian matrix:

In [279]:
x.grad.zero_()
y = softmax(x, dim=-1)
y_grad = torch.randn((5,), requires_grad=False)
y.backward(y_grad)
assert torch.allclose(x.grad, y_grad @ J)

The vjp-vmap composition above can also be replaced by the `jacrev` function, which computes the Jacobian using reverse-mode autodiff:

In [280]:
torch.func.jacrev(softmax)(x)

tensor([[ 0.2448, -0.0336, -0.0395, -0.0135, -0.1582],
        [-0.0336,  0.0724, -0.0073, -0.0025, -0.0291],
        [-0.0395, -0.0073,  0.0837, -0.0029, -0.0341],
        [-0.0135, -0.0025, -0.0029,  0.0305, -0.0116],
        [-0.1582, -0.0291, -0.0341, -0.0116,  0.2331]],
       grad_fn=<ViewBackward0>)

Similarly, there is `jacfwd`, which uses forward-mode AD and is implemented as a jvp-vmap composition.

In [281]:
torch.func.jacfwd(softmax)(x)

tensor([[ 0.2448, -0.0336, -0.0395, -0.0135, -0.1582],
        [-0.0336,  0.0724, -0.0073, -0.0025, -0.0291],
        [-0.0395, -0.0073,  0.0837, -0.0029, -0.0341],
        [-0.0135, -0.0025, -0.0029,  0.0305, -0.0116],
        [-0.1582, -0.0291, -0.0341, -0.0116,  0.2331]],
       grad_fn=<ViewBackward0>)
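
The jvp-vmap composition can be written out by hand, mirroring the vjp-vmap construction used earlier. This is a sketch of the idea under illustrative names (`f_softmax`, `x_demo`, `push_jvp`), not the actual source of `jacfwd`: each one-hot tangent $e_j$ yields column $j$ of the Jacobian, and `out_dims=1` stacks those results as columns.

```python
import torch
from torch.nn.functional import softmax

def f_softmax(t):
    return softmax(t, dim=-1)

torch.manual_seed(0)
x_demo = torch.randn(5)

def push_jvp(tangent):
    # J @ tangent; for a one-hot tangent e_j this is column j of the Jacobian.
    return torch.func.jvp(f_softmax, (x_demo,), (tangent,))[1]

# vmap feeds the rows of the identity in as tangents;
# out_dims=1 places the result for e_j in column j of the output.
J_cols = torch.func.vmap(push_jvp, out_dims=1)(torch.eye(5))

J_fwd = torch.func.jacfwd(f_softmax)(x_demo)
print(torch.allclose(J_cols, J_fwd, atol=1e-6))
```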

`jacfwd` and `jacrev` can be substituted for each other, but they have different performance characteristics:$^1$
> In reverse-mode AD, we are computing the jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we are computing it column-by-column. The Jacobian matrix has M rows and N columns, so if it is taller or wider one way we may prefer the method that deals with fewer rows or columns.
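
As a sketch of the row/column trade-off, consider a hypothetical function from $R^4$ to $R^{64}$, whose Jacobian is tall (64 rows, 4 columns). Both transforms return the same matrix; conceptually, reverse mode needs one VJP per row and forward mode one JVP per column, so forward mode has less work to do here. The names `W_tall`, `f_tall`, and `x_in` are illustrative, and no timing is claimed.

```python
import torch

torch.manual_seed(0)
W_tall = torch.randn(64, 4)  # fixed weights; f_tall maps R^4 -> R^64

def f_tall(t):
    return torch.tanh(W_tall @ t)

x_in = torch.randn(4)

J_rev = torch.func.jacrev(f_tall)(x_in)  # reverse mode: conceptually one VJP per row (64)
J_fwd = torch.func.jacfwd(f_tall)(x_in)  # forward mode: conceptually one JVP per column (4)

print(J_rev.shape, J_fwd.shape)  # both have M=64 rows and N=4 columns
print(torch.allclose(J_rev, J_fwd, atol=1e-5))
```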

## Reference
1. https://docs.pytorch.org/functorch/stable/notebooks/jacobians_hessians.html