# Pytorch Backprop (Advanced)

## Softmax
Let's first focus on the softmax function which is a $x \to y: R^{n} \rightarrow R^{n}$ mapping where any output element depends on all input elements, and its Jacobian matrix is somewhat non-trivial. These make it a good example function to study backprop:

In [240]:
import torch
from torch.nn.functional import softmax

x = torch.randn((5,), requires_grad=True) # a test input vector
y = softmax(x, dim=-1)
print(x)
print(y)

y_grad = torch.randn((5,), requires_grad=True) 

y.backward(y_grad, retain_graph=True)
assert torch.allclose(x.grad, y.grad_fn(y_grad))
print(x.grad)
print(torch.autograd.grad(y, x, y_grad)) # an alternative way to get v.grad

tensor([ 1.0040, -1.5440,  1.4854,  0.8328,  0.3945], requires_grad=True)
tensor([0.2449, 0.0192, 0.3964, 0.2064, 0.1331], grad_fn=<SoftmaxBackward0>)
tensor([-0.2268, -0.0136,  0.1423, -0.1827,  0.2808])
(tensor([-0.2268, -0.0136,  0.1423, -0.1827,  0.2808]),)


Assume the hypotheticcal scaler loss function (at the root of our computation graph) is $L$, the output value `x.grad` $\nabla_x^T L $, according to the [chain rule](../math/chain_rule_and_jacobian.ipynb), is derived from $\nabla_{y}^T L \cdot J_x y$ where  $\nabla_{y} L$ is just `y_grad`.

## Softmax (Batched)

We can also compute the "batched" version:

In [241]:
X = torch.randn((3, 5), requires_grad=False) # a test **matrix**
Y_grad = torch.randn((3, 5), requires_grad=False) 

# let's set its first row to be identical to our above vector, just make it easier to observe the differences.
X[0, :] = x.detach() 
Y_grad[0, :] = y_grad.detach()

X.requires_grad = True
Y_grad.requires_grad = True

# now perform the batched version:
Y = softmax(X, dim=-1)
print(X)
print(Y)
Y.backward(Y_grad)
print(X.grad)

assert torch.allclose(X.grad[0, :], x.grad)

tensor([[ 1.0040, -1.5440,  1.4854,  0.8328,  0.3945],
        [ 1.2526,  1.3910, -0.9862, -0.6285,  0.0700],
        [ 0.3329,  0.6859,  1.6928, -1.2989,  0.0233]], requires_grad=True)
tensor([[0.2449, 0.0192, 0.3964, 0.2064, 0.1331],
        [0.3685, 0.4232, 0.0393, 0.0562, 0.1129],
        [0.1380, 0.1964, 0.5375, 0.0270, 0.1012]], grad_fn=<SoftmaxBackward0>)
tensor([[-0.2268, -0.0136,  0.1423, -0.1827,  0.2808],
        [-0.1579,  0.3427, -0.0167, -0.0944, -0.0736],
        [-0.0404,  0.3583, -0.3998,  0.0679,  0.0140]])


## VJP and JVP
A VJP (Vector–Jacobian Product) is simply a term to describe the math form:
$$
\underbrace{\mathbf{v}^\top}_{\text{cotangents}}
\underbrace{\frac{\partial f(\mathbf{x})}{\partial \mathbf{x}}}_{\substack{\text{Jacobian of }f \\ \text{w.r.t. primals}}}
$$

similarly, JVP is
$$
\underbrace{\frac{\partial f(\mathbf{x})}{\partial \mathbf{x}}}_{\substack{\text{Jacobian of }f \\ \text{w.r.t. primals}}}
\underbrace{\mathbf{u}}_{\text{tangents}}
$$

In [242]:
y2, vjp_fn = torch.func.vjp(softmax, x) # torch.func.vjp(func, *primals)
assert torch.allclose(y, y2)

In [243]:
vjp_fn(y_grad) # y_grad is the "v"

(tensor([-0.2268, -0.0136,  0.1423, -0.1827,  0.2808],
        grad_fn=<SoftmaxBackwardDataBackward0>),)

However, `vjp_fn` cannot take in batched y_grad, i.e., Y_grad:

In [244]:
try:
    vjp_fn(Y_grad)  # error!
except Exception as e:
    print(e)

Mismatch in shape: grad_output[0] has a shape of torch.Size([3, 5]) and output[0] has a shape of torch.Size([5]).


A simple solution is to loop and apply (i.e., `vmap`):

In [245]:
batched_vjp_fn = torch.func.vmap(vjp_fn)
vmap_X_grad = batched_vjp_fn(Y_grad)[0].detach()
vmap_X_grad

tensor([[-0.2268, -0.0136,  0.1423, -0.1827,  0.2808],
        [ 0.0644,  0.0288,  0.1057, -0.2041,  0.0053],
        [-0.1223,  0.0310, -0.3767,  0.4771, -0.0091]])

In [246]:
vmap_X_grad == torch.vstack([
    vjp_fn(Y_grad[0, :])[0].detach(),
    vjp_fn(Y_grad[1, :])[0].detach(),
    vjp_fn(Y_grad[2, :])[0].detach(),
])

tensor([[True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True]])