In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.functional import jacobian

# Backpropagation through a Linear Layer of a Neural Network using PyTorch

We will follow the teachings of Justin Johnson in his [CS231n course](https://cs231n.stanford.edu/handouts/linear-backprop.pdf) and derive backpropagation for a linear layer. We assume a batch of $N$ records, each with $D$ features. The linear layer transforms a $D$-dimensional input into an $M$-dimensional output. The weight matrix $W$ has size $D \times M$ and the bias vector $b$ has size $M$. The input to the layer is $X$ of size $N \times D$, the output is $Z$ of size $N \times M$, and the loss $L$ is a scalar.

Mathematical notation:

$$
\begin{align*}
X &\in \mathbb{R}^{N \times D} \\
W &\in \mathbb{R}^{D \times M} \\
b &\in \mathbb{R}^{M} \\
Z = XW + b &\in \mathbb{R}^{N \times M} \\
L &\in \mathbb{R}
\end{align*}
$$



To focus on backpropagation through the linear layer alone, we assume that $\frac{\partial L}{\partial Z}$ is given to us. We will derive $\frac{\partial L}{\partial X}$, $\frac{\partial L}{\partial W}$ and $\frac{\partial L}{\partial b}$. Note that $\frac{\partial L}{\partial Z}$ has size $N \times M$, because $Z$ has size $N \times M$ and $L$ is a scalar.

Components of $\frac{\partial L}{\partial Z}$ are as follows:

$$
\begin{align*}
\frac{\partial L}{\partial Z} &= \begin{bmatrix}
\frac{\partial L}{\partial Z_{11}} & \frac{\partial L}{\partial Z_{12}} & \cdots & \frac{\partial L}{\partial Z_{1M}} \\
\frac{\partial L}{\partial Z_{21}} & \frac{\partial L}{\partial Z_{22}} & \cdots & \frac{\partial L}{\partial Z_{2M}} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial L}{\partial Z_{N1}} & \frac{\partial L}{\partial Z_{N2}} & \cdots & \frac{\partial L}{\partial Z_{NM}}
\end{bmatrix}
\end{align*}
$$

Our goal is to use $\frac{\partial L}{\partial Z}$ to calculate $\frac{\partial L}{\partial X}$, $\frac{\partial L}{\partial W}$ and $\frac{\partial L}{\partial b}$. Since $L$ is a scalar, $\frac{\partial L}{\partial X}$, $\frac{\partial L}{\partial W}$ and $\frac{\partial L}{\partial b}$ have the same sizes as $X$ ($N \times D$), $W$ ($D \times M$) and $b$ ($M$) respectively.

By the chain rule, we know that:


$$
\begin{equation*}
\begin{aligned}
\frac{\partial L}{\partial X} &= \frac{\partial L}{\partial Z}  \frac{\partial Z}{\partial X} \\
\frac{\partial L}{\partial W} &= \frac{\partial L}{\partial Z}  \frac{\partial Z}{\partial W} \\
\frac{\partial L}{\partial b} &= \frac{\partial L}{\partial Z}  \frac{\partial Z}{\partial b}
\end{aligned}
\end{equation*}
$$



::: {.callout-note}
- $\frac{\partial Z}{\partial X}$ is a Jacobian tensor of size $(N \times M) \times (N \times D)$
- $\frac{\partial Z}{\partial W}$ is a Jacobian tensor of size $(N \times M) \times (D \times M)$
- $\frac{\partial Z}{\partial b}$ is a Jacobian tensor of size $(N \times M) \times M$
:::

In a typical network you may see $N$ (batch size) = 64, $D$ (input features) = 4096 and $M$ (output features) = 4096. That means $\frac{\partial Z}{\partial X}$ is a tensor of size $64 \times 4096 \times 64 \times 4096$: about 68.7 billion numbers, or roughly 256 GiB at 32 bits per float. Materializing it naively would be very slow and memory intensive.

However, for most common neural network layers we can compute the product $\frac{\partial L}{\partial Z} \frac{\partial Z}{\partial X}$ without ever forming the Jacobian $\frac{\partial Z}{\partial X}$ explicitly. We will see how to do this in the next section.
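To make the memory argument concrete, here is a small back-of-the-envelope sketch in plain Python. It assumes 4 bytes per float32 entry and compares the full Jacobian against the three tensors the shortcut actually needs ($\frac{\partial L}{\partial Z}$, $W$ and the result $\frac{\partial L}{\partial X}$). The variable names are illustrative only.

```python
# Rough size of the full Jacobian dZ/dX for the example dimensions in the text.
N, D, M = 64, 4096, 4096          # batch size, input features, output features
bytes_per_float32 = 4             # assumption: every entry is a 32-bit float

# One Jacobian entry per (Z_ij, X_kl) pair.
jacobian_elements = (N * M) * (N * D)
jacobian_bytes = jacobian_elements * bytes_per_float32

# The shortcut only touches dL/dZ (N x M), W (D x M) and dL/dX (N x D).
shortcut_elements = N * M + D * M + N * D
shortcut_bytes = shortcut_elements * bytes_per_float32

print(f"Jacobian entries: {jacobian_elements:,}")
print(f"Jacobian memory:  {jacobian_bytes / 2**30:.0f} GiB")
print(f"Shortcut memory:  {shortcut_bytes / 2**20:.0f} MiB")
```

The gap of several orders of magnitude is the reason the derivations that follow never build the Jacobian.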



## Deriving $\frac{\partial L}{\partial X}$
---

We will start by deriving $\frac{\partial L}{\partial X}$. Since $L$ is a scalar, $\frac{\partial L}{\partial X}$ has the same size as $X$ ($N \times D$).

$$
X = \begin{bmatrix}
X_{11} & X_{12} & \cdots & X_{1D} \\
X_{21} & X_{22} & \cdots & X_{2D} \\
\vdots & \vdots & \ddots & \vdots \\
X_{N1} & X_{N2} & \cdots & X_{ND}
\end{bmatrix}
\implies 
\frac{\partial L}{\partial X} = \begin{bmatrix}
\frac{\partial L}{\partial X_{11}} & \frac{\partial L}{\partial X_{12}} & \cdots & \frac{\partial L}{\partial X_{1D}} \\
\frac{\partial L}{\partial X_{21}} & \frac{\partial L}{\partial X_{22}} & \cdots & \frac{\partial L}{\partial X_{2D}} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial L}{\partial X_{N1}} & \frac{\partial L}{\partial X_{N2}} & \cdots & \frac{\partial L}{\partial X_{ND}}
\end{bmatrix}
$$

Let's start by calculating $\frac{\partial L}{\partial X_{mn}}$ for a single element $X_{mn}$ of $X$, using the chain rule. Note that $\frac{\partial L}{\partial X_{mn}}$ is a scalar.

By chain rule
$$ 
\frac{\partial L}{\partial X_{mn}}  = \frac{\partial L}{\partial Z} \frac{\partial Z}{\partial X_{mn}}
$$
NB:

- $\frac{\partial L}{\partial Z}$ is of size $N \times M$.
- $\frac{\partial Z}{\partial X_{mn}}$ is also of size $N \times M$ (one entry per element of $Z$), so the product above stands for the sum of their elementwise products.


Now note 

$$
\begin{equation*}
\begin{aligned}
Z &= XW + b \\
Z_{ij} &= \sum_{k=1}^{D} x_{ik} w_{kj} + b_j \\
\implies \\
\frac{\partial Z_{ij}}{\partial x_{mn}} &= w_{nj} \text{ if } i = m \text{ else } 0 \\
\frac{\partial Z_{ij}}{\partial w_{mn}} &= x_{im} \text{ if } j = n \text{ else } 0 \\
\end{aligned}
\end{equation*}
$$



This $\implies$

$$
\begin{equation*}
\begin{aligned}
\frac{\partial L}{\partial X_{mn}} &= \sum_{i=1}^{N} \sum_{j=1}^{M} \frac{\partial L}{\partial Z_{ij}} \frac{\partial Z_{ij}}{\partial x_{mn}} \\
&= \sum_{i=1}^{N} \sum_{j=1}^{M} \frac{\partial L}{\partial Z_{ij}} \, w_{nj} \, \mathbb{1}[i = m] \\
&= \sum_{j=1}^{M} \frac{\partial L}{\partial Z_{mj}} w_{nj} \\
\end{aligned}
\end{equation*}
$$


Notice that this is nothing but the dot product of the $m$-th row of $\frac{\partial L}{\partial Z}$ and the $n$-th row of $W$ (i.e. the $n$-th column of $W^T$).

::: {.callout-note}
 So given $Z = X @ W + b$ and $\frac{\partial L}{\partial Z}$ we can calculate $\frac{\partial L}{\partial X}$ as follows:
$$
\begin{equation*}
\begin{aligned}
\frac{\partial L}{\partial X} &= \frac{\partial L}{\partial Z} W^T
\end{aligned}
\end{equation*}
$$
:::
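As a quick sanity check of this result, the sketch below asks autograd for the same quantity. It assumes small illustrative sizes and a random stand-in for $\frac{\partial L}{\partial Z}$; passing it as `grad_outputs` to `torch.autograd.grad` makes autograd compute the product with the Jacobian without materializing it.

```python
import torch

torch.manual_seed(0)
N, D, M = 4, 3, 2                      # illustrative sizes
X = torch.randn(N, D, requires_grad=True)
W = torch.randn(D, M)
b = torch.randn(M)
dL_dZ = torch.randn(N, M)              # stand-in for the upstream gradient

Z = X @ W + b
# grad_outputs plays the role of dL/dZ; autograd returns dL/dX.
(dL_dX_autograd,) = torch.autograd.grad(Z, X, grad_outputs=dL_dZ)

# The closed-form result from the derivation.
dL_dX_formula = dL_dZ @ W.T

print(dL_dX_formula.shape)
print(torch.allclose(dL_dX_autograd, dL_dX_formula, atol=1e-6))
```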

Similarly we will now derive $\frac{\partial L}{\partial W}$ 

## Deriving $\frac{\partial L}{\partial W}$
---

We will start by deriving $\frac{\partial L}{\partial W}$. Note $\frac{\partial L}{\partial W}$ has the same size as $W$ ($D \times M$).


$$
W = \begin{bmatrix}
W_{11} & W_{12} & \cdots & W_{1M} \\
W_{21} & W_{22} & \cdots & W_{2M} \\
\vdots & \vdots & \ddots & \vdots \\
W_{D1} & W_{D2} & \cdots & W_{DM}
\end{bmatrix}
$$

$\implies$

$$
\frac{\partial L}{\partial W} = \begin{bmatrix}
\frac{\partial L}{\partial W_{11}} & \frac{\partial L}{\partial W_{12}} & \cdots & \frac{\partial L}{\partial W_{1M}} \\
\frac{\partial L}{\partial W_{21}} & \frac{\partial L}{\partial W_{22}} & \cdots & \frac{\partial L}{\partial W_{2M}} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial L}{\partial W_{D1}} & \frac{\partial L}{\partial W_{D2}} & \cdots & \frac{\partial L}{\partial W_{DM}}
\end{bmatrix}
$$

Let's start by calculating $\frac{\partial L}{\partial W_{mn}}$ for a single element $W_{mn}$ of $W$, using the chain rule. Note that $\frac{\partial L}{\partial W_{mn}}$ is a scalar.

By chain rule
$$
\begin{equation*}
\begin{aligned}
\frac{\partial L}{\partial W_{mn}} &= \frac{\partial L}{\partial Z} \frac{\partial Z}{\partial W_{mn}}
\end{aligned}
\end{equation*}
$$

Now note

$$
\begin{equation*}
\begin{aligned}
Z &= XW + b \\
Z_{ij} &= \sum_{k=1}^{D} x_{ik} w_{kj} + b_j \\
\implies &\\
\frac{\partial Z_{ij}}{\partial w_{mn}} &= x_{im} \text{ if } j = n \text{ else } 0 \\
\end{aligned}
\end{equation*}
$$

This $\implies$

$$
\begin{equation*}
\begin{aligned}
\frac{\partial L}{\partial W_{mn}} &= \sum_{i=1}^{N} \sum_{j=1}^{M} \frac{\partial L}{\partial Z_{ij}} \frac{\partial Z_{ij}}{\partial W_{mn}} \\
&= \sum_{i=1}^{N} \sum_{j=1}^{M} \frac{\partial L}{\partial Z_{ij}} \frac{\partial Z_{ij}}{\partial w_{mn}} \\
&= \sum_{i=1}^{N} \frac{\partial L}{\partial Z_{in}} x_{im} \\
&= \sum_{i=1}^{N} x_{im} \frac{\partial L}{\partial Z_{in}} \\
\end{aligned}
\end{equation*}
$$

Observe that this is nothing but the dot product of the $m$-th column of $X$ (or the $m$-th row of $X^T$) and the $n$-th column of $\frac{\partial L}{\partial Z}$.


::: {.callout-note}
 So given $Z = X @ W + b$ and $\frac{\partial L}{\partial Z}$ we can calculate $\frac{\partial L}{\partial W}$ as follows:
$$
\begin{equation*}
\begin{aligned}
\frac{\partial L}{\partial W} &= X^T \frac{\partial L}{\partial Z}
\end{aligned}
\end{equation*}
$$
:::
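The element-wise sum and the matrix form can be compared directly. The following is a minimal sketch, assuming small illustrative sizes and random values, that evaluates $\sum_i x_{im} \frac{\partial L}{\partial Z_{in}}$ with explicit loops and checks it against $X^T \frac{\partial L}{\partial Z}$.

```python
import torch

torch.manual_seed(0)
N, D, M = 4, 3, 2                      # illustrative sizes
X = torch.randn(N, D)
dL_dZ = torch.randn(N, M)              # stand-in for the upstream gradient

# Element-wise version: dL/dW[m, n] = sum_i X[i, m] * dL_dZ[i, n]
dL_dW_loops = torch.zeros(D, M)
for m in range(D):
    for n in range(M):
        for i in range(N):
            dL_dW_loops[m, n] += X[i, m] * dL_dZ[i, n]

# Matrix version from the derivation.
dL_dW_matmul = X.T @ dL_dZ

print(torch.allclose(dL_dW_loops, dL_dW_matmul, atol=1e-6))
```

The loops are only there to mirror the index notation; the matrix product is what you would use in practice.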


## Deriving $\frac{\partial L}{\partial b}$

We will derive $\frac{\partial L}{\partial b}$. Note $\frac{\partial L}{\partial b}$ has the same size as $b$ ($M$).
$$
b = \begin{bmatrix}
b_{1} \\
b_{2} \\
\vdots \\
b_{M}
\end{bmatrix}
$$

$\implies$


$$
\frac{\partial L}{\partial b} = \begin{bmatrix}
\frac{\partial L}{\partial b_{1}} \\
\frac{\partial L}{\partial b_{2}} \\
\vdots \\
\frac{\partial L}{\partial b_{M}}
\end{bmatrix}
$$

Let's start by calculating $\frac{\partial L}{\partial b_{m}}$ for a single element $b_{m}$ of $b$, using the chain rule. Note that $\frac{\partial L}{\partial b_{m}}$ is a scalar.

By chain rule
$$
\begin{equation*}
\begin{aligned}
\frac{\partial L}{\partial b_{m}} &= \frac{\partial L}{\partial Z} \frac{\partial Z}{\partial b_{m}}
\end{aligned}
\end{equation*}
$$

Now note

$$
\begin{equation*}
\begin{aligned}
Z &= XW + b \\
Z_{ij} &= \sum_{k=1}^{D} x_{ik} w_{kj} + b_j \\
\implies &\\
\frac{\partial Z_{ij}}{\partial b_{m}} &= 1 \text{ if } j = m \text{ else } 0 \\
\end{aligned}
\end{equation*}
$$

This $\implies$

$$
\begin{equation*}
\begin{aligned}
\frac{\partial L}{\partial b_{m}} &= \sum_{i=1}^{N} \sum_{j=1}^{M} \frac{\partial L}{\partial Z_{ij}} \frac{\partial Z_{ij}}{\partial b_{m}} \\
&= \sum_{i=1}^{N} \sum_{j=1}^{M} \frac{\partial L}{\partial Z_{ij}} \, \mathbb{1}[j = m] \\
&= \sum_{i=1}^{N} \frac{\partial L}{\partial Z_{im}} \\
\end{aligned}
\end{equation*}
$$

Observe that this is nothing but the sum of the $m$-th column of $\frac{\partial L}{\partial Z}$.

::: {.callout-note}
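So given $Z = X @ W + b$ and $\frac{\partial L}{\partial Z}$ we can calculate $\frac{\partial L}{\partial b}$ as follows, where $\frac{\partial L}{\partial Z_{i}}$ denotes the $i$-th row of $\frac{\partial L}{\partial Z}$: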
$$
\begin{equation*}
\begin{aligned}
\frac{\partial L}{\partial b} &= \sum_{i=1}^{N} \frac{\partial L}{\partial Z_{i}}
\end{aligned}
\end{equation*}
$$
:::
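The bias gradient sums over the batch because `b` is broadcast to every row of $Z$. A small sketch, again assuming illustrative sizes and a random upstream gradient, compares autograd's answer with the column sum and with the equivalent product of a ones vector and $\frac{\partial L}{\partial Z}$.

```python
import torch

torch.manual_seed(0)
N, D, M = 4, 3, 2                      # illustrative sizes
X = torch.randn(N, D)
W = torch.randn(D, M)
b = torch.randn(M, requires_grad=True)
dL_dZ = torch.randn(N, M)              # stand-in for the upstream gradient

Z = X @ W + b                          # b is broadcast across the N rows
(dL_db_autograd,) = torch.autograd.grad(Z, b, grad_outputs=dL_dZ)

dL_db_formula = dL_dZ.sum(dim=0)       # column sums, as derived
dL_db_ones = torch.ones(N) @ dL_dZ     # the same sum written as 1^T (dL/dZ)

print(torch.allclose(dL_db_autograd, dL_db_formula, atol=1e-6))
print(torch.allclose(dL_db_formula, dL_db_ones, atol=1e-6))
```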

Now we will go back to PyTorch and verify our derivations.

In [9]:
## a quick setup in pytorch
N = 5; D = 3; M = 2
X = torch.randn(N, D, requires_grad=False) ## disabling requires grad as we will do it all manually
W = torch.randn(D, M, requires_grad=False)
b = torch.randn(M, requires_grad=False)
#Z = X @ W + b
def linear_transform(X, W, b):
    res = X @ W + b
    print(f"Output Shape: {res.shape}")
    return res

dL_dZ = torch.randn(N, M, requires_grad=False)

## now we are interested in dL_dW, dL_db and dL_dX

jacobian_output = jacobian(linear_transform, (X, W, b))
jacob_dZ_dX, jacob_dZ_dW, jacob_dZ_db = jacobian_output
print(jacob_dZ_dX.shape, jacob_dZ_dW.shape, jacob_dZ_db.shape)

#dL_dZ @ jacob_dZ_dX

Output Shape: torch.Size([5, 2])
torch.Size([5, 2, 5, 3]) torch.Size([5, 2, 3, 2]) torch.Size([5, 2, 2])


In [10]:
## Now apply the derived formulas

dl_dw = X.T @ dL_dZ     # (D, N) @ (N, M) -> (D, M)
dl_dx = dL_dZ @ W.T     # (N, M) @ (M, D) -> (N, D)
dl_db = dL_dZ.sum(dim=0)
dl_dw, dl_dx, dl_db


(tensor([[ 0.0938,  1.3802],
         [-3.0765,  3.4443],
         [ 1.4643,  0.1143]]),
 tensor([[-0.2672,  0.1566, -0.2825],
         [-0.1333,  0.6751, -0.9973],
         [-1.0372,  1.0880, -1.7853],
         [-0.5966,  0.4912, -0.8339],
         [ 0.5652, -0.4775,  0.8074]]),
 tensor([ 2.8203, -2.7928]))

In [11]:
import torch
from torch.autograd.functional import jacobian


# Derived gradients using the formulas
dl_dw = X.T @ dL_dZ
dl_db = dL_dZ.sum(dim=0)
dl_dx = dL_dZ @ W.T

# Using the Jacobian to compute the gradients
dl_dx_jacobian = torch.einsum('ij,ijkl->kl', dL_dZ, jacob_dZ_dX)
dl_dw_jacobian = torch.einsum('ij,ijkl->kl', dL_dZ, jacob_dZ_dW)
dl_db_jacobian = torch.einsum('ij,ijk->k', dL_dZ, jacob_dZ_db)

# Compare the results
print("Manual gradient dL/dX: \n", dl_dx)
print("Jacobian-based dL/dX: \n", dl_dx_jacobian)
print("Manual gradient dL/dW: \n", dl_dw)
print("Jacobian-based dL/dW: \n", dl_dw_jacobian)
print("Manual gradient dL/db: \n", dl_db)
print("Jacobian-based dL/db: \n", dl_db_jacobian)

# Check if they are equal
print("dL/dX close: ", torch.allclose(dl_dx, dl_dx_jacobian))
print("dL/dW close: ", torch.allclose(dl_dw, dl_dw_jacobian))
print("dL/db close: ", torch.allclose(dl_db, dl_db_jacobian))


Manual gradient dL/dX: 
 tensor([[-0.2672,  0.1566, -0.2825],
        [-0.1333,  0.6751, -0.9973],
        [-1.0372,  1.0880, -1.7853],
        [-0.5966,  0.4912, -0.8339],
        [ 0.5652, -0.4775,  0.8074]])
Jacobian-based dL/dX: 
 tensor([[-0.2672,  0.1566, -0.2825],
        [-0.1333,  0.6751, -0.9973],
        [-1.0372,  1.0880, -1.7853],
        [-0.5966,  0.4912, -0.8339],
        [ 0.5652, -0.4775,  0.8074]])
Manual gradient dL/dW: 
 tensor([[ 0.0938,  1.3802],
        [-3.0765,  3.4443],
        [ 1.4643,  0.1143]])
Jacobian-based dL/dW: 
 tensor([[ 0.0938,  1.3802],
        [-3.0765,  3.4443],
        [ 1.4643,  0.1143]])
Manual gradient dL/db: 
 tensor([ 2.8203, -2.7928])
Jacobian-based dL/db: 
 tensor([ 2.8203, -2.7928])
dL/dX close:  True
dL/dW close:  True
dL/db close:  True
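
The three formulas are all a backward pass needs. As an illustration, the sketch below wraps them in a custom `torch.autograd.Function` and checks them numerically with `torch.autograd.gradcheck`. The class name `LinearFn` is hypothetical, and this is a minimal sketch rather than how `nn.Linear` is implemented.

```python
import torch

class LinearFn(torch.autograd.Function):
    """Hypothetical linear layer whose backward pass uses the derived formulas."""

    @staticmethod
    def forward(ctx, X, W, b):
        # Keep X and W around; the backward pass needs both.
        ctx.save_for_backward(X, W)
        return X @ W + b

    @staticmethod
    def backward(ctx, dL_dZ):
        X, W = ctx.saved_tensors
        dL_dX = dL_dZ @ W.T          # (N, M) @ (M, D) -> (N, D)
        dL_dW = X.T @ dL_dZ          # (D, N) @ (N, M) -> (D, M)
        dL_db = dL_dZ.sum(dim=0)     # sum over the batch -> (M,)
        return dL_dX, dL_dW, dL_db

torch.manual_seed(0)
N, D, M = 5, 3, 2
# gradcheck compares against finite differences, so it expects float64 inputs.
X = torch.randn(N, D, dtype=torch.double, requires_grad=True)
W = torch.randn(D, M, dtype=torch.double, requires_grad=True)
b = torch.randn(M, dtype=torch.double, requires_grad=True)

gradcheck_passed = torch.autograd.gradcheck(LinearFn.apply, (X, W, b))
print(gradcheck_passed)
```

`gradcheck` raises an error if the analytical and numerical gradients disagree, so a returned `True` means all three formulas agree with finite differences on this input.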


## A quick note on jacobian product for gradient backpropagation for multivariate functions
---
Let $F$ be a function from $\mathbb{R}^{m \times n}$ to $\mathbb{R}^{p \times q}$ and let $G$ be a function from $\mathbb{R}^{p \times q}$ to $\mathbb{R}^{r \times s}$. Their composition $H = G \circ F$ is a map from $\mathbb{R}^{m \times n}$ to $\mathbb{R}^{r \times s}$. Write $Y = F(X)$, let $J_1$ be the Jacobian of $F$ with respect to $X$, and let $J_2$ be the Jacobian of $G$ with respect to $Y$. Then the Jacobian of $H$ with respect to $X$ is given by:

$$
\begin{equation*}
\begin{aligned}
H(X) &= G(F(X)) \\
\frac{\partial H}{\partial X} &= \frac{\partial G}{\partial Y} \frac{\partial F}{\partial X} = J_2 J_1 \\
\end{aligned}
\end{equation*}
$$

Where $J_1 = \frac{\partial F}{\partial X}$ and $J_2 = \frac{\partial G}{\partial Y}$. 

Note the dimensions of the Jacobians, using the same (output) $\times$ (input) layout as before. $J_1$ is of size $(p \times q) \times (m \times n)$ and $J_2$ is of size $(r \times s) \times (p \times q)$. The product contracts the shared $(p \times q)$ indices, so the Jacobian of $H$ with respect to $X$ is of size $(r \times s) \times (m \times n)$.
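One way to read this product is to flatten each group of indices, which turns the Jacobians into ordinary matrices. The sketch below assumes two hypothetical maps `F` and `G` with made-up weights `A` and `B`, computes their Jacobians with `torch.autograd.functional.jacobian` (which places output indices first), and checks that a plain matrix product of the flattened Jacobians reproduces the Jacobian of the composition.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
A = torch.randn(4, 3)            # hypothetical weights used by F
B = torch.randn(3, 2)            # hypothetical weights used by G

def F(X):                        # maps (5, 4) -> (5, 3)
    return torch.tanh(X @ A)

def G(Y):                        # maps (5, 3) -> (5, 2)
    return Y @ B

def H(X):                        # composition, maps (5, 4) -> (5, 2)
    return G(F(X))

X = torch.randn(5, 4)
Y = F(X)

J1 = jacobian(F, X)              # shape (5, 3, 5, 4): output indices first
J2 = jacobian(G, Y)              # shape (5, 2, 5, 3)
JH = jacobian(H, X)              # shape (5, 2, 5, 4)

# Flatten (output, input) index groups so each Jacobian is a 2-D matrix.
J1_mat = J1.reshape(5 * 3, 5 * 4)
J2_mat = J2.reshape(5 * 2, 5 * 3)
JH_from_product = (J2_mat @ J1_mat).reshape(5, 2, 5, 4)

print(torch.allclose(JH, JH_from_product, atol=1e-5))
```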

In [12]:
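## let's do it with autograd now
## note: in this cell (p, q) and (r, s) are the shapes of the weight matrices,
## so F maps (m, n) -> (m, q) and G maps (m, q) -> (m, s)
m,n = 5, 4
p,q = 4, 3
r, s = 3, 2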

x = torch.randn(m, n, requires_grad=False)
W_1 = torch.randn(p, q, requires_grad=False)
W_2 = torch.randn(r, s, requires_grad=False)

def linear_transform(x, W):
    return x @ W

def linear_transform_G(x, W1, W2):
    out_1 = linear_transform(x, W1)
    return linear_transform(out_1, W2)



out_1 = linear_transform(x, W_1)
out_2 = linear_transform(out_1, W_2)

jacobian_F = jacobian(linear_transform, (x, W_1))
jacobian_G = jacobian(linear_transform, (out_1, W_2))
jacobian_H = jacobian(linear_transform_G, (x, W_1, W_2))

## lets print out the shapes
print(f"jacob_F: {jacobian_F[0].shape}")
print(f"jacob_G: {jacobian_G[0].shape}")
print(f"jacob_H: {jacobian_H[0].shape}")

jacob_H_manual = torch.einsum('abij,ijxy->abxy', jacobian_G[0], jacobian_F[0])
print(jacob_H_manual.shape)
torch.allclose(jacobian_H[0], jacob_H_manual)

jacob_F: torch.Size([5, 3, 5, 4])
jacob_G: torch.Size([5, 2, 5, 3])
jacob_H: torch.Size([5, 2, 5, 4])
torch.Size([5, 2, 5, 4])


True