# Bare-Bones Backpropagation
> Demonstrating the simplest possible backpropagation implementation, with all the clutter removed.

- toc: true
- badges: true
- comments: true
- author: Charlie Blake
- categories: [neural-networks, backpropagation]
- image: images/blog/simple-backprop/viz.png

## Making Backprop Simple

The first few times I came across backpropagation I struggled to get a feel for what was going on. It wasn't enough just
to follow the equations: I couldn't visualise the operations and updates, especially the mysterious backwards pass.

If someone had shown me back then how simple it is to implement the key part of the training algorithm, the part that figures out how to update the weights, it would have helped me a lot. I understood how the chain rule worked and its relevance here, but I didn't have a picture of it in my head. I got bogged down in the matrix notation and PyTorch tensors and lost sight of what was really going on.

So this is a simple-as-possible backprop implementation, to clear up that confusion. I don't go into the maths; I assume the reader already knows what's going on in theory, but doesn't have a great feel for what happens in practice.

This can also serve as a reference for how to implement this from scratch in a clean way. Enjoy!


## Setup

First things first, there's some setup to do. This isn't a tutorial on data loading, so I'm just going to paste some
code for loading up our dataset and we can ignore the details. The only thing worth noting is that we'll be using the
classic *MNIST* dataset:

In [13]:
#collapse-hide
import math
import torch
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.nn.functional import one_hot
from functools import reduce
import altair as alt
import pandas as pd

batch_sz = 64
train = DataLoader(MNIST('data/', train=True, download=True,
                         transform=torchvision.transforms.Compose([
                             torchvision.transforms.ToTensor(),
                             lambda x: torch.flatten(x),
                         ])
                        ), batch_size=batch_sz, shuffle=True)

This `train` dataloader can be iterated over and returns minibatches of shape `(batch_sz, input_dim)`.

In our case, that is `(64, 784)`, since each `28*28` image is flattened to 784 values.
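
To make those shapes concrete without downloading anything, here is a small sketch that swaps MNIST for random tensors of the same shape. The `fake_images` and `fake_labels` names are made up for illustration. Note that the final batch can be smaller than `batch_sz` when the dataset size isn't a multiple of it.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_sz = 64
# Stand-in for MNIST: 200 random "images" already flattened to 784 values,
# with integer labels 0-9. This is illustrative data, not the real dataset.
fake_images = torch.rand(200, 28 * 28)
fake_labels = torch.randint(0, 10, (200,))
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=batch_sz)

# Each iteration yields an (inputs, labels) pair with the batch dimension first
shapes = [(tuple(X.shape), tuple(y.shape)) for X, y in loader]
for x_shape, y_shape in shapes:
    print(x_shape, y_shape)
```

With 200 examples we get three full batches of 64 and one leftover batch of 8.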

## Layer-by-layer

We'll construct our neural network by making classes for each layer. We'll use a standard setup of: linear layer, ReLU, linear layer, softmax; plus a cross-entropy loss.

For each layer class we require two methods:

- **`__call__(self, x)`**: implements the forward pass. `__call__` allows us to feed an input through the layer by treating the initialised layer object as a function. For example: `relu_layer = ReLU(); output = relu_layer(input)`.


- **`bp_input(self, grad)`**: implements the backward pass, allowing us to backpropagate the gradient through the layer. The `grad` parameter is a matrix of partial derivatives of the loss, with respect to the data sent from the given layer to the next. As such, it is a `(batch_sz, out)` matrix. The job of `bp_input` is to return a `(batch_sz, in)` matrix to pass on to the preceding layer (the next stop in the backward pass), by multiplying `grad` by the derivative of the forward pass with respect to the _input_ (or an equivalent operation).

There are two other methods we sometimes wish to implement for different layers:

- **`__init__(self, ...)`**: initialises the layer, e.g. weights.


- **`bp_param(self, grad)`**: the "final stop" of the backwards pass. Only applicable for layers with trainable weights. Similar to `bp_input`, but calculates the derivative with respect to the _weights_ of the layer. Should return a matrix with the same shape as the weights (`self.W`) to be updated.

> Important: The key point to recall when visualising this is that when we have a batch dimension it is always the first dimension, for both the forward and backward pass. This makes everything much simpler!
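
As a rough illustration of that interface, here is a toy layer that just multiplies its input by a constant. `Scale` is a hypothetical layer invented for this sketch and isn't used in the network below; it only shows the two required methods and the batch-first shapes.

```python
import torch

class Scale:
    """Hypothetical layer (not part of the network): multiplies its input by a constant."""
    def __init__(self, factor): self.factor = factor

    def __call__(self, X):       # forward: (batch_sz, in) -> (batch_sz, out), here in == out
        return X * self.factor

    def bp_input(self, grad):    # backward: (batch_sz, out) -> (batch_sz, in)
        # d(out)/d(in) is just `factor`, so the incoming gradient is scaled by it
        return grad * self.factor

layer = Scale(3.0)
X = torch.ones(4, 2)             # batch of 4 examples, 2 features each
out = layer(X)                   # treat the layer object as a function
grad_in = layer.bp_input(torch.ones_like(out))
print(out.shape, grad_in.shape)
```

Since this layer has no trainable weights, it needs no `bp_param`.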

### Linear Layer

Let's start with the linear layer. We do the following:
1. We start by initialising the weights (in this case using the Xavier initialisation).
2. We then implement the call method. Rather than adding an explicit bias, we append a vector of ones to the layer's input (this is equivalent, and makes backprop simpler).
3. Backpropagation with respect to the input is just right multiplication by the transpose of the weight matrix (adjusted to remove the added 1s column).
4. Backpropagation with respect to the weights is left multiplication by the transpose of the input matrix.
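
If the ones-column trick in step 2 feels like sleight of hand, this small sketch checks that it gives the same output as a conventional weights-plus-bias layer. The sizes and variable names here are arbitrary choices for the demo.

```python
import torch

torch.manual_seed(0)
X = torch.randn(5, 3)                  # (batch_sz=5, in=3)
W = torch.randn(3 + 1, 2)              # (in+1, out=2); the last row acts as the bias

# Forward pass as described: append a column of ones, then do one matmul
X_aug = torch.cat([X, torch.ones(X.shape[0], 1)], dim=1)   # (5, 4)
out_trick = X_aug @ W

# Conventional form: weights and bias kept separate
out_explicit = X @ W[:-1] + W[-1]

same = torch.allclose(out_trick, out_explicit, atol=1e-6)
print(same)
```

The last row of `W` plays the role of the bias, which is why `bp_input` later drops the final column of its result.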

In [2]:
class LinearLayer:
    def __init__(self, in_sz, out_sz): self.W = self._xavier_init(in_sz + 1, out_sz)  # (in+1, out)
    def _xavier_init(self, i, o): return torch.Tensor(i, o).uniform_(-1, 1) * math.sqrt(6./(i + o))
    
    def __call__(self, X):  # (batch_sz, in)
        self.X = torch.cat([X, torch.ones(X.shape[0], 1)], dim=1)  # (batch_sz, in+1)
        return self.X @ self.W  # (batch_sz, in+1) @ (in+1, out) = (batch_sz, out)
    
    def bp_input(self, grad): return (grad @ self.W.T)[:,:-1]  # (batch_sz, out) @ (out, in+1) = (batch_sz, in+1); drop the ones column -> (batch_sz, in)
    def bp_param(self, grad): return self.X.T @ grad  # (in+1, batch_sz) @ (batch_sz, out) = (in+1, out)
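
One way to convince yourself that these two one-liners are right is to compare them against PyTorch's autograd. This is only a sanity-check sketch with made-up sizes and a random stand-in for the upstream gradient; it repeats the layer's maths inline rather than using the class.

```python
import torch

torch.manual_seed(0)
X = torch.randn(5, 3, requires_grad=True)        # (batch_sz, in)
W = torch.randn(4, 2, requires_grad=True)        # (in+1, out)
grad = torch.randn(5, 2)                         # pretend upstream gradient, (batch_sz, out)

X_aug = torch.cat([X, torch.ones(5, 1)], dim=1)  # (batch_sz, in+1)
out = X_aug @ W

# Autograd's answer: backpropagate `grad` through `out`
out.backward(grad)

# The hand-written rules from LinearLayer
with torch.no_grad():
    manual_dX = (grad @ W.T)[:, :-1]             # bp_input
    manual_dW = X_aug.T @ grad                   # bp_param

input_ok = torch.allclose(manual_dX, X.grad, atol=1e-6)
param_ok = torch.allclose(manual_dW, W.grad, atol=1e-6)
print(input_ok, param_ok)
```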

### ReLU Layer

Some non-linearity is a must! Bring on the ReLU function. The implementation is pretty obvious here. `clamp()` is doing all the work.

In [3]:
class ReLU:
    def __call__(self, X):
        self.X = X
        return X.clamp(min=0)  # (batch_sz, in)
    
    def bp_input(self, grad): return grad * (self.X > 0).float()  # (batch_sz, in)
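
To see the backward pass as a simple mask, here is a tiny sketch on a single hand-picked example (the numbers are arbitrary). Units that were switched off in the forward pass let no gradient through; with the strict `> 0` test used above, that includes an input of exactly zero.

```python
import torch

X = torch.tensor([[-2.0, 0.0, 3.0]])          # one example, three units
grad = torch.tensor([[10.0, 20.0, 30.0]])     # pretend upstream gradient

out = X.clamp(min=0)                          # forward pass: negatives become 0
mask = (X > 0).float()                        # 1 where the unit was active, 0 elsewhere
grad_in = grad * mask                         # backward pass: gradient survives only where active
print(out.tolist(), grad_in.tolist())
```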

### Softmax & Cross Entropy Loss

What? Both at once, why would you do this??

This is quite common, and I can justify it in two ways:

1. This layer and loss often go together, so why not put them in one layer? This saves us from having to do two separate forward and backward propagation steps.
2. I won't prove it here, but it turns out that the derivative of the loss with respect to the input to the softmax is much simpler than the two intermediate derivative operations, and bypasses the numerical stability issues that arise when we do the exponential and the logarithm. Phew!

The downside here is that if we're just doing _inference_ then we only want the softmax output. But for the purposes of this tutorial we only really care about training. So this will do just fine!

There's a trick in the second line of the softmax implementation: it turns out that subtracting each row's maximum from the softmax input keeps the output the same, but the intermediate values are more numerically stable. How neat!
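
Here's a quick sketch of why that shift matters. The two helper functions are written just for this comparison and are not part of the layer class; the logits are arbitrary.

```python
import torch

def softmax_naive(X):
    exps = torch.exp(X)
    return exps / exps.sum(dim=1, keepdim=True)

def softmax_stable(X):
    exps = torch.exp(X - X.amax(dim=1, keepdim=True))   # largest entry in each row becomes 0
    return exps / exps.sum(dim=1, keepdim=True)

small = torch.tensor([[1.0, 2.0, 3.0]])
large = small + 1000.0            # same gaps between logits, much larger values

print(softmax_naive(small), softmax_stable(small))   # the two versions agree
print(softmax_naive(large))                          # exp overflows to inf, and inf / inf is nan
print(softmax_stable(large))                         # unaffected by the shift
```

Because softmax only depends on the differences between logits, shifting a whole row by a constant changes nothing mathematically, but it keeps every exponent at or below zero.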

Finally, we examine the backprop step. It's so simple! The initial `grad` value passed in is just a matrix of ones, so our starting gradient for backprop is the difference between our predicted output vector and the actual one-hot encoded label. This is so intuitive and wonderful.

> Tip: This is exactly the same derivative as when we don't use a softmax layer and apply an MSE loss (i.e. the regression case), assuming the usual factor of 1/2 in the loss. We can thus think of softmax + cross entropy as a way of getting to the same underlying backprop, but in the classification case.
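
The regression half of that comparison is easy to check with autograd. This sketch assumes the squared error is summed and carries a factor of 1/2; without that factor the gradient would be twice as large. The numbers are arbitrary.

```python
import torch

pred = torch.tensor([[0.2, 0.7, 0.1]], requires_grad=True)   # raw outputs, no softmax
target = torch.tensor([[0.0, 1.0, 0.0]])

# Squared error with a factor of 1/2 (an assumption made for this sketch)
loss = 0.5 * ((pred - target) ** 2).sum()
loss.backward()

difference = (pred - target).detach()
print(pred.grad, difference)      # the gradient is just prediction minus target
```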

In [4]:
class SoftmaxCrossEntropyLoss:  # (batch_sz, in=out) for all dims in this layer
    def __call__(self, X, Y):
        self.Y = Y
        self.Y_prob = self._softmax(X)
        self.loss = self._cross_entropy_loss(Y, self.Y_prob)
        return self.Y_prob, self.loss
    
    def _softmax(self, X):
        self.X = X
        X_adj = X - X.amax(dim=1, keepdim=True)
        exps = torch.exp(X_adj)
        return exps / exps.sum(axis=1, keepdim=True)
    
    def _cross_entropy_loss(self, Y, Y_prob): return (-Y * torch.log(Y_prob)).sum(axis=1).mean()
    
    def bp_input(self, grad): return (self.Y_prob - self.Y) * grad  # with grad all ones: gradient of the batch-summed loss w.r.t. the softmax input
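
I said I wouldn't prove the "predicted minus actual" result, but we can at least check it numerically against autograd. This sketch uses made-up logits and labels. It also shows one subtlety: the simple difference matches the gradient of the loss *summed* over the batch, whereas the reported loss above is the batch *mean*, which differs by a factor of `1 / batch_sz`. In practice that constant just gets folded into the learning rate.

```python
import torch
from torch.nn.functional import one_hot

torch.manual_seed(0)
X = torch.randn(4, 3, requires_grad=True)        # logits, (batch_sz=4, classes=3)
Y = one_hot(torch.tensor([0, 2, 1, 2]), 3).float()

Y_prob = torch.softmax(X, dim=1)
per_example = (-Y * torch.log(Y_prob)).sum(dim=1)

# Gradient of the loss summed over the batch
(grad_sum,) = torch.autograd.grad(per_example.sum(), X, retain_graph=True)
# Gradient of the loss averaged over the batch
(grad_mean,) = torch.autograd.grad(per_example.mean(), X)

manual = (Y_prob - Y).detach()                   # what bp_input returns when grad is all ones
sum_ok = torch.allclose(grad_sum, manual, atol=1e-6)
mean_ok = torch.allclose(grad_mean, manual / X.shape[0], atol=1e-6)
print(sum_ok, mean_ok)
```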

## Putting it all together

Let's bring these layers together in a class: our `NeuralNet` implementation.

The `evaluate()` function does two things. Firstly, it runs the forward pass by chaining the `__call__()` functions, to generate label probabilities. Secondly, it uses the labels passed to it to calculate the loss and percentage correctly predicted.

> Note: for this simplified example we don't have a pure inference function, but we could add one with a small change to `SoftmaxCrossEntropyLoss`.

The `gradient_descent()` function then gets the matrix of updates for each weight matrix and applies the update. The key bit here is how `backprop()` works. Going backwards through the computation graph we chain the backprop with respect to input methods. Then for each weighted layer we want to update, we apply the backprop with respect to parameters method to the relevant gradient.

In [5]:
class NeuralNet:
    def __init__(self, input_size=28*28, hidden_size=32, output_size=10, alpha=0.001):
        self.alpha = alpha
        self.z1 = LinearLayer(input_size, hidden_size)
        self.a1 = ReLU()
        self.z2 = LinearLayer(hidden_size, output_size)
        self.loss = SoftmaxCrossEntropyLoss()
    
    def evaluate(self, X, Y):
        out = self.z2(self.a1(self.z1(X)))
        correct = torch.eq(out.argmax(axis=1), Y).double().mean()
        Y_prob, loss = self.loss(out, one_hot(Y, 10))
        return Y_prob, correct, loss
    
    def gradient_descent(self):
        delta_W1, delta_W2 = self.backprop()
        self.z1.W -= self.alpha * delta_W1
        self.z2.W -= self.alpha * delta_W2
    
    def backprop(self):
        d_out = torch.ones(*self.loss.Y.shape)
        d_z2 = self.loss.bp_input(d_out)
        d_a1 = self.z2.bp_input(d_z2)
        d_z1 = self.a1.bp_input(d_a1)

        d_w2 = self.z2.bp_param(d_z2)
        d_w1 = self.z1.bp_param(d_z1)
        return d_w1, d_w2
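
As a sanity check on the whole chain, here is a sketch that repeats the forward and backward pass inline on a tiny, made-up network and compares the resulting weight gradients with autograd. It mirrors the steps in `evaluate()` and `backprop()` rather than calling the classes, and uses the batch-summed loss as the reference.

```python
import torch
from torch.nn.functional import one_hot

torch.manual_seed(0)
batch_sz, n_in, n_hidden, n_out = 6, 5, 4, 3     # tiny sizes chosen for the demo
X = torch.randn(batch_sz, n_in)
Y = one_hot(torch.randint(0, n_out, (batch_sz,)), n_out).float()
W1 = torch.randn(n_in + 1, n_hidden, requires_grad=True)
W2 = torch.randn(n_hidden + 1, n_out, requires_grad=True)
ones = torch.ones(batch_sz, 1)

# Forward pass, mirroring evaluate()
X1 = torch.cat([X, ones], dim=1)
z1 = X1 @ W1
a1 = z1.clamp(min=0)
X2 = torch.cat([a1, ones], dim=1)
z2 = X2 @ W2

# Reference gradients from autograd, using cross-entropy summed over the batch
loss_sum = -(Y * torch.log_softmax(z2, dim=1)).sum()
loss_sum.backward()

# Backward pass by hand, mirroring backprop()
with torch.no_grad():
    d_z2 = torch.softmax(z2, dim=1) - Y          # loss.bp_input
    d_a1 = (d_z2 @ W2.T)[:, :-1]                 # z2.bp_input
    d_z1 = d_a1 * (z1 > 0).float()               # a1.bp_input
    d_w2 = X2.T @ d_z2                           # z2.bp_param
    d_w1 = X1.T @ d_z1                           # z1.bp_param

w1_ok = torch.allclose(d_w1, W1.grad, rtol=1e-4, atol=1e-5)
w2_ok = torch.allclose(d_w2, W2.grad, rtol=1e-4, atol=1e-5)
print(w1_ok, w2_ok)
```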

## Training the model

We're almost there! I won't go into this bit too much because this tutorial isn't about training loops, but it's all very standard here.

We break the training data into minibatches and train on them over 10 epochs. The evaluation metrics plotted are those recorded during regular training.

> Warning: these results are only on the training set! In practice we should *always* plot performance on a test set, but we don't want to clutter the tutorial with this extra detail.
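
For anyone who does want the held-out numbers, a helper along these lines would do it. This is only a sketch: `evaluate_on` is a hypothetical function that isn't used elsewhere in this post, and the stand-in model and batches exist purely so the example runs on its own.

```python
import torch

def evaluate_on(model, loader):
    """Average accuracy and loss over `loader`, without updating any weights.

    Hypothetical helper: assumes `model.evaluate(X, y)` returns
    (probabilities, fraction_correct, mean_loss), like NeuralNet.evaluate.
    """
    total_correct, total_loss, n = 0.0, 0.0, 0
    for X, y in loader:
        _, batch_correct, batch_loss = model.evaluate(X, y)
        # Weight each batch by its size so a short final batch isn't over-counted
        total_correct += float(batch_correct) * len(y)
        total_loss += float(batch_loss) * len(y)
        n += len(y)
    return total_correct / n, total_loss / n

class AlwaysZero:
    """Stand-in model for the demo: predicts class 0 for every example."""
    def evaluate(self, X, y):
        return None, (y == 0).double().mean(), torch.tensor(1.0)

batches = [(torch.zeros(4, 784), torch.tensor([0, 0, 1, 1])),
           (torch.zeros(2, 784), torch.tensor([0, 0]))]
acc, loss = evaluate_on(AlwaysZero(), batches)
print(acc, loss)
```

With the real model, the loader would be built the same way as `train` but with `train=False` passed to `MNIST`, and `evaluate_on(model, test)` would be called once per epoch.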

In [14]:
#collapse-hide
model = NeuralNet()
stats = {'correct': [], 'loss': [], 'epoch': []}
for epoch in range(10):
    correct, loss = 0, 0
    for i, (X, y) in enumerate(train):
        y_prob, batch_correct, batch_loss = model.evaluate(X, y)
        model.gradient_descent()
        
        correct += batch_correct / len(train)
        loss += batch_loss / len(train)
    stats['correct'].append(correct.item())
    stats['loss'].append(loss.item())
    stats['epoch'].append(epoch)
    print(f'epoch: {epoch} | correct: {correct:.2f}, loss: {loss:.2f}')

base = alt.Chart(pd.DataFrame.from_dict(stats)).mark_line() \
          .encode(alt.X('epoch', axis=alt.Axis(title='epoch')))
line1 = base.mark_line(stroke='#5276A7', interpolate='monotone') \
            .encode(alt.Y('loss'   , axis=alt.Axis(title='Loss'   , titleColor='#5276A7'), scale=alt.Scale(domain=[0.0, max(stats['loss'   ])])), tooltip='loss'   )
line2 = base.mark_line(stroke='#57A44C', interpolate='monotone') \
            .encode(alt.Y('correct', axis=alt.Axis(title='Correct', titleColor='#57A44C'), scale=alt.Scale(domain=[min(stats['correct']), 1.0])), tooltip='correct')
alt.layer(line1, line2).resolve_scale(y = 'independent')

epoch: 0 | correct: 0.87, loss: 0.47
epoch: 1 | correct: 0.92, loss: 0.28
epoch: 2 | correct: 0.93, loss: 0.23
epoch: 3 | correct: 0.94, loss: 0.20
epoch: 4 | correct: 0.95, loss: 0.18
epoch: 5 | correct: 0.95, loss: 0.16
epoch: 6 | correct: 0.96, loss: 0.15
epoch: 7 | correct: 0.96, loss: 0.14
epoch: 8 | correct: 0.96, loss: 0.13
epoch: 9 | correct: 0.96, loss: 0.12


Our results are great! After 10 epochs I'm getting a whopping 96% correct on the training set.

And we've implemented this all from scratch, backprop included, using only 4 notebook cells' worth of code. I hope training a neural network now seems less complex 😊