# PyTorch Backprop (Advanced)

## Softmax
Let's first focus on the softmax function, a mapping $x \to y: R^{n} \rightarrow R^{n}$ in which every output element depends on every input element, so its Jacobian matrix is dense rather than diagonal. This makes it a good example function for studying backprop:

In [47]:
import torch
from torch.nn.functional import softmax

x = torch.randn((5,), requires_grad=True) # a test input vector
y = softmax(x, dim=-1)
print(x)
print(y)

y_grad = torch.randn((5,), requires_grad=False) 

y.backward(y_grad, retain_graph=True)
print(x.grad)
#assert torch.allclose(x.grad, y.grad_fn(y_grad)) # same as calling y.grad_fn()
#assert torch.allclose(x.grad, torch.autograd.grad(y, x, y_grad)[0]) # an alternative way to get x.grad

tensor([ 1.8477,  0.8658,  1.0925,  1.8025, -0.5439], requires_grad=True)
tensor([0.3458, 0.1295, 0.1625, 0.3305, 0.0316], grad_fn=<SoftmaxBackward0>)
tensor([ 0.0801,  0.0278,  0.2022, -0.3059, -0.0042])


Assume the hypothetical scalar loss function (at the root of our computation graph) is $L$. According to the [chain rule](../math/chain_rule_and_jacobian.ipynb), the output value `x.grad`, i.e. $\nabla_x^T L$, is computed as $\nabla_{y}^T L \cdot J_x y$, where $\nabla_{y} L$ is just `y_grad`.
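The chain-rule product above can be checked numerically. The sketch below assumes the standard closed form of the softmax Jacobian, $J_{ij} = y_i(\delta_{ij} - y_j)$, which is not stated in the text above. The names `x_demo`, `y_demo`, and `g_demo` are hypothetical stand-ins for `x`, `y`, and `y_grad`, chosen so the notebook's own variables are left untouched.

```python
import torch
from torch.nn.functional import softmax

torch.manual_seed(0)
x_demo = torch.randn(5, requires_grad=True)  # stand-in for x
y_demo = softmax(x_demo, dim=-1)
g_demo = torch.randn(5)                      # stand-in for y_grad

# Closed-form softmax Jacobian (assumed): J[i, j] = y_i * (delta_ij - y_j)
y_val = y_demo.detach()
J = torch.diag(y_val) - torch.outer(y_val, y_val)

# Autograd result versus the explicit row-vector-times-Jacobian product
y_demo.backward(g_demo)
manual_grad = g_demo @ J
assert torch.allclose(x_demo.grad, manual_grad, atol=1e-6)
```

Autograd never materializes `J` here. It evaluates the product directly, which is why the explicit matrix is useful only as a check.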

## Softmax (Batched)

We can also compute a "batched" version. In the backprop context, batching here means keeping the same (non-batched) input vector $x$ and supplying a batch of `y_grad` vectors. If the inputs $x$ were batched and differed from one another, the Jacobian values $\left. J_{x}y \right|_{x = x_1, x_2, ..., x_B}$ would differ per sample, so a single Jacobian could not be reused across the batch, which gives worse utilization of the hardware:

In [48]:
Y_grad = torch.randn((3, 5), requires_grad=False)  # a test **batch**
# Set the first row to our vector from above, to make the differences easier to observe.
Y_grad[0, :] = y_grad.detach()
print(Y_grad)

# y = softmax(x, dim=-1)  # same x, same y: no need to recompute

x.grad.zero_()
y.backward(Y_grad[0, :], retain_graph=True)
print(x.grad) # same as the old x.grad

x.grad.zero_()
y.backward(Y_grad[1, :], retain_graph=True)
print(x.grad)

x.grad.zero_()
y.backward(Y_grad[2, :], retain_graph=True)
print(x.grad)

tensor([[-0.0335, -0.0504,  0.9793, -1.1906, -0.3986],
        [ 0.1053,  0.5013, -0.1292, -1.2227,  1.3505],
        [-0.2068, -1.4909,  0.5869,  1.4332, -1.1067]])
tensor([ 0.0801,  0.0278,  0.2022, -0.3059, -0.0042])
tensor([ 0.1336,  0.1013,  0.0247, -0.3112,  0.0516])
tensor([-0.1647, -0.2280,  0.0516,  0.3846, -0.0435])
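Because the input vector is shared, every row of the batch is multiplied by the same Jacobian, so the per-row backward passes should collapse into one matrix product. The sketch below illustrates this with hypothetical names (`x_b`, `y_b`, `G`) and again assumes the closed-form softmax Jacobian $J_{ij} = y_i(\delta_{ij} - y_j)$.

```python
import torch
from torch.nn.functional import softmax

torch.manual_seed(0)
x_b = torch.randn(5, requires_grad=True)  # one shared input vector
y_b = softmax(x_b, dim=-1)
G = torch.randn(3, 5)                     # hypothetical batch of output gradients

# Reference: one backward pass per row of G
rows = [torch.autograd.grad(y_b, x_b, g, retain_graph=True)[0] for g in G]
looped = torch.stack(rows)

# Single matrix product with the shared Jacobian (closed form, assumed)
y_val = y_b.detach()
J = torch.diag(y_val) - torch.outer(y_val, y_val)
batched = G @ J

assert torch.allclose(looped, batched, atol=1e-6)
```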


## VJP and JVP
A VJP (Vector–Jacobian Product) is simply the name for the following mathematical form:
$$
\underbrace{\mathbf{v}^\top}_{\text{cotangents}}
\underbrace{\frac{\partial f(\mathbf{x})}{\partial \mathbf{x}}}_{\substack{\text{Jacobian of }f \\ \text{w.r.t. primals}}}
$$

Similarly, a JVP (Jacobian–Vector Product) is:
$$
\underbrace{\frac{\partial f(\mathbf{x})}{\partial \mathbf{x}}}_{\substack{\text{Jacobian of }f \\ \text{w.r.t. primals}}}
\underbrace{\mathbf{u}}_{\text{tangents}}
$$
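As a small numeric illustration of both forms, the sketch below compares `torch.func.vjp` and `torch.func.jvp` against explicit products with the Jacobian. The helper `f` and the tensors `x0`, `v`, and `u` are hypothetical names, and the closed-form softmax Jacobian $J_{ij} = y_i(\delta_{ij} - y_j)$ is an assumption used only for checking.

```python
import torch
from torch.nn.functional import softmax

def f(t):
    # softmax with an explicit dim, so it can be passed as a one-argument function
    return softmax(t, dim=-1)

torch.manual_seed(0)
x0 = torch.randn(5)  # primal
v = torch.randn(5)   # cotangent (multiplies the Jacobian from the left)
u = torch.randn(5)   # tangent (multiplies the Jacobian from the right)

# Closed-form softmax Jacobian (assumed): J[i, j] = y_i * (delta_ij - y_j)
y0 = f(x0)
J = torch.diag(y0) - torch.outer(y0, y0)

_, vjp_fn_demo = torch.func.vjp(f, x0)
(vjp_out,) = vjp_fn_demo(v)                  # v^T J
_, jvp_out = torch.func.jvp(f, (x0,), (u,))  # J u

assert torch.allclose(vjp_out, v @ J, atol=1e-6)
assert torch.allclose(jvp_out, J @ u, atol=1e-6)
```

For softmax specifically, this Jacobian is symmetric, so a VJP and a JVP with the same vector give the same numbers. That is a property of softmax, not of VJPs and JVPs in general.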

In [49]:
y2, vjp_fn = torch.func.vjp(softmax, x) # torch.func.vjp(func, *primals)
assert torch.allclose(y, y2)

In [50]:
vjp_fn(y_grad) # y_grad is the "v"

(tensor([ 0.0801,  0.0278,  0.2022, -0.3059, -0.0042],
        grad_fn=<SoftmaxBackwardDataBackward0>),)

However, `vjp_fn` cannot take a batched `y_grad`, i.e., `Y_grad`:

In [51]:
try:
    vjp_fn(Y_grad)  # error!
except Exception as e:
    print(e)

Mismatch in shape: grad_output[0] has a shape of torch.Size([3, 5]) and output[0] has a shape of torch.Size([5]).


A simple solution is to apply `vjp_fn` to each row, which is what `vmap` does (in vectorized form rather than as a Python loop):

In [52]:
batched_vjp_fn = torch.func.vmap(vjp_fn)
vmap_X_grad = batched_vjp_fn(Y_grad)[0].detach()
assert torch.allclose(vmap_X_grad, torch.vstack([
    vjp_fn(Y_grad[0, :])[0].detach(),
    vjp_fn(Y_grad[1, :])[0].detach(),
    vjp_fn(Y_grad[2, :])[0].detach(),
]))
vmap_X_grad

tensor([[ 0.0801,  0.0278,  0.2022, -0.3059, -0.0042],
        [ 0.1336,  0.1013,  0.0247, -0.3112,  0.0516],
        [-0.1647, -0.2280,  0.0516,  0.3846, -0.0435]])

where each row matches the `x.grad` obtained from the corresponding single-vector `y.backward` call above.

Note that this vmap-vjp composition, applied to the rows of an identity matrix instead of `Y_grad`, is conveniently wrapped into the `jacrev` function, which stands for "Jacobian-Reverse":

In [55]:
torch.func.jacrev(softmax)(x)

tensor([[ 0.2262, -0.0448, -0.0562, -0.1143, -0.0109],
        [-0.0448,  0.1128, -0.0210, -0.0428, -0.0041],
        [-0.0562, -0.0210,  0.1361, -0.0537, -0.0051],
        [-0.1143, -0.0428, -0.0537,  0.2213, -0.0105],
        [-0.0109, -0.0041, -0.0051, -0.0105,  0.0306]],
       grad_fn=<ViewBackward0>)
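The relationship between `jacrev` and the vmap-vjp composition can be sketched by hand: feeding the rows of an identity matrix through a vmapped VJP function recovers the Jacobian row by row. The names `f`, `x0`, and `vjp_fn_demo` below are hypothetical, and this illustrates the idea rather than describing how `jacrev` is implemented internally.

```python
import torch
from torch.nn.functional import softmax

def f(t):
    # softmax with an explicit dim, so it can be passed as a one-argument function
    return softmax(t, dim=-1)

torch.manual_seed(0)
x0 = torch.randn(5)
_, vjp_fn_demo = torch.func.vjp(f, x0)

# Row i of the identity is the basis vector e_i, and e_i^T J is row i of J
basis = torch.eye(5)
(J_manual,) = torch.func.vmap(vjp_fn_demo)(basis)

J_ref = torch.func.jacrev(f)(x0)
assert torch.allclose(J_manual, J_ref, atol=1e-6)
```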