# Checkpointing

Your task is to implement checkpointing for an MLP using NumPy.

You are free to use the MLP implementation and the backpropagation algorithm that you developed during the lab sessions.

The key takeaway from this task is that checkpointing lets us trade extra forward-pass computation for a lower memory requirement in the backward pass, which is often a major bottleneck when training large networks. In plain English, we can slightly increase the time required to train our network in order to save some of our GPU's precious memory.

## What is checkpointing?

The aim of checkpointing is to save the forward result of only every $n$-th layer (e.g. every 2nd layer), instead of every layer as in plain backpropagation, and to use these checkpoints to recompute the missing parts of the forward pass during the backward pass. Checkpointed activations are kept in memory after the forward pass, while the remaining activations are recomputed at most once. Once recomputed, the non-checkpointed activations are kept in memory until they are no longer required.
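
As a small illustration of which layers end up as checkpoints, the sketch below lists the checkpointed layer indices for a given spacing. The helper name `checkpoint_layers` and the 1-based layer numbering (with layer 0 being the input) are assumptions made for this example, not part of the lab code.

```python
def checkpoint_layers(num_layers, n):
    """Return the 1-based indices of the layers whose outputs are kept.

    Hypothetical helper: layer 0 is the input, which is always available,
    so only layers 1..num_layers are considered.
    """
    return [i for i in range(1, num_layers + 1) if i % n == 0]


# With 6 layers and n = 2, every second layer output is stored.
print(checkpoint_layers(6, 2))
# With n = 4, a single layer output is stored and the rest are recomputed.
print(checkpoint_layers(6, 4))
```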

# What should be done

1. Take the implementation of an MLP trained with backpropagation. Analyze how the memory used by the algorithm depends on the number of hidden layers.

2. Implement a class `NetworkWithCheckpointing` that inherits from the `Network` class defined during lab sessions by:
    a) implementing a method `forward_between_checkpoints` that recomputes the forward pass of the network starting from one of the checkpointed layers
    b) overriding the method `backprop` so that it stores only the checkpointed layers, recomputes the other activations with `forward_between_checkpoints`, and keeps them in memory only until they are no longer needed.

3. Train your network with checkpointing on MNIST. Compare its running time and memory usage with those of the network without checkpointing.


# Implement Checkpointing for a MLP

In [200]:
import random
import numpy as np
from torchvision import datasets, transforms

In [201]:
!wget -O mnist.npz https://s3.amazonaws.com/img-datasets/mnist.npz

--2022-11-20 16:31:42--  https://s3.amazonaws.com/img-datasets/mnist.npz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.64.254, 52.217.202.128, 54.231.202.0, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.64.254|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11490434 (11M) [application/octet-stream]
Saving to: ‘mnist.npz’


2022-11-20 16:31:43 (95.1 MB/s) - ‘mnist.npz’ saved [11490434/11490434]



In [202]:
# Let's read the mnist dataset

def load_mnist(path='mnist.npz'):
    with np.load(path) as f:
        x_train, _y_train = f['x_train'], f['y_train']
        x_test, _y_test = f['x_test'], f['y_test']

    x_train = x_train.reshape(-1, 28 * 28) / 255.
    x_test = x_test.reshape(-1, 28 * 28) / 255.

    y_train = np.zeros((_y_train.shape[0], 10))
    y_train[np.arange(_y_train.shape[0]), _y_train] = 1

    y_test = np.zeros((_y_test.shape[0], 10))
    y_test[np.arange(_y_test.shape[0]), _y_test] = 1

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_mnist()

In [203]:
def sigmoid(z):
    return 1.0/(1.0+np.exp(-z))

This version of backpropagation was taken from the Google Colab notebook for lab 3. The only change is the added option to turn off logging in the `SGD` function.

In [204]:
class Network(object):
    def __init__(self, sizes):
        # initialize biases and weights with random normal distr.
        # weights are indexed by target node first
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.weights = [np.random.randn(y, x) 
                        for x, y in zip(sizes[:-1], sizes[1:])]
    def feedforward(self, a):
        # Run the network on a batch
        a = a.T
        for b, w in zip(self.biases, self.weights):
            a = sigmoid(np.matmul(w, a)+b)
        return a
    
    def update_mini_batch(self, mini_batch, eta):
        # Update networks weights and biases by applying a single step
        # of gradient descent using backpropagation to compute the gradient.
        # The gradient is computed for a mini_batch which is as in tensorflow API.
        # eta is the learning rate      
        nabla_b, nabla_w = self.backprop(mini_batch[0].T,mini_batch[1].T)
            
        self.weights = [w-(eta/len(mini_batch[0]))*nw 
                        for w, nw in zip(self.weights, nabla_w)]
        self.biases = [b-(eta/len(mini_batch[0]))*nb 
                       for b, nb in zip(self.biases, nabla_b)]
        
    def backprop(self, x, y):
        # For a mini-batch (x, y), with one example per column, return a pair of lists.
        # The first contains gradients over biases, the second over weights.
        g = x
        gs = [g] # list to store all the gs, layer by layer
        fs = [] # list to store all the fs, layer by layer
        for b, w in zip(self.biases, self.weights):
            f = np.dot(w, g)+b
            fs.append(f)
            g = sigmoid(f)
            gs.append(g)
        # backward pass <- both steps at once
        dLdg = self.cost_derivative(gs[-1], y)
        dLdfs = []
        for w,g in reversed(list(zip(self.weights,gs[1:]))):
            dLdf = np.multiply(dLdg,np.multiply(g,1-g))
            dLdfs.append(dLdf)
            dLdg = np.matmul(w.T, dLdf)
        
        dLdWs = [np.matmul(dLdf,g.T) for dLdf,g in zip(reversed(dLdfs),gs[:-1])] 
        dLdBs = [np.sum(dLdf,axis=1).reshape(dLdf.shape[0],1) for dLdf in reversed(dLdfs)] 
        return (dLdBs,dLdWs)

    def evaluate(self, test_data):
        # Count the number of correct answers for test_data
        pred = np.argmax(self.feedforward(test_data[0]),axis=0)
        corr = np.argmax(test_data[1],axis=1).T
        return np.mean(pred==corr)
    
    def cost_derivative(self, output_activations, y):
        return (output_activations-y) 
    
    def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None, log=True):
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                y_mini_batch = y_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                self.update_mini_batch((x_mini_batch, y_mini_batch), eta)
            if test_data and log:
                print("Epoch: {0}, Accuracy: {1}".format(j, self.evaluate((x_test, y_test))))
            elif log:
                print("Epoch: {0}".format(j))


## Memory analysis
The standard backpropagation algorithm stores all activations computed during the forward pass. This version of the algorithm is vectorized, meaning that a single call to `backprop` processes a whole mini-batch of data. The activations of all layers therefore take
$$O\left(B \cdot \sum_{i=1}^{L} s_i\right)$$
memory, where $B$ is the mini-batch size, $s_i$ is the size of layer $i$, and $L$ is the number of layers after the input layer. In other words, it is the sum of the sizes of all layers except the first one, times the size of the mini-batch. Memory therefore depends on both the number of layers and the size of each layer.
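
To make the estimate concrete, the following sketch evaluates the expression for two of the architectures used later in this notebook. The function name `stored_activation_bytes` and the assumption of 8-byte `float64` values are introduced here for illustration only. Note that `backprop` above also keeps the pre-activations `fs`, which would roughly double this figure.

```python
def stored_activation_bytes(sizes, minibatch_size, bytes_per_value=8):
    """Estimate bytes held by the post-sigmoid activations of one mini-batch.

    Assumes float64 values (8 bytes each) and skips the input layer,
    mirroring the sum over all layers except the first one.
    """
    return minibatch_size * sum(sizes[1:]) * bytes_per_value


shallow = stored_activation_bytes([784, 30, 30, 30, 30, 10], 100)
deep = stored_activation_bytes([784] + [30] * 50 + [10], 100)
print("shallow network:", shallow, "bytes")
print("deep network:", deep, "bytes")
```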

# Checkpointing implementation

In [205]:
class NetworkWithCheckpointing(Network):
    def __init__(self, sizes, checkpoint_every_nth_layer: int = 0, *args, **kwargs):
        super().__init__(sizes, *args, **kwargs)
        self.checkpoint_every_nth_layer = checkpoint_every_nth_layer

    def forward_between_checkpoints(self, a, checkpoint_idx_start, layer_idx_end):
        after_activation = [a]

        for layer_idx in range(checkpoint_idx_start, layer_idx_end):
          layer_product = self.weights[layer_idx] @ after_activation[-1] + self.biases[layer_idx]
          layer_product_sig = sigmoid(layer_product)

          after_activation.append(layer_product_sig)

        return after_activation

    def backprop(self, x, y):
        # if self.checkpoint_every_nth_layer is equal to 0 solution falls back to backward solution implemented in super class.
        if self.checkpoint_every_nth_layer == 0:
          return super().backprop(x, y)

        delta_nabla_b = [np.zeros_like(p) for p in self.biases]
        delta_nabla_w = [np.zeros_like(p) for p in self.weights]

        g = x
        # Activations list stores activations on checkpoints
        activations = [g] 

        for layer_idx in range(self.num_layers - 1):
          f = np.dot(self.weights[layer_idx], g) + self.biases[layer_idx]
          g = sigmoid(f)

          # We append activation to list only every self.checkpoint_every_nth_layer
          if (layer_idx + 1) % self.checkpoint_every_nth_layer == 0:
            activations.append(g)

      
        # Preparation to checkpointed backpropagation
        prev_checkpoint_number = -1
        cur_checkpoint_idx = self.num_layers - 1
        prev_checkpoint_idx = self.num_layers - ((self.num_layers - 1) % self.checkpoint_every_nth_layer) - 1
        position_in_chunk = -1

        # Activations between the last checkpoint and the end of the network are restored.
        restored_chunk = self.forward_between_checkpoints(activations[prev_checkpoint_number], prev_checkpoint_idx, cur_checkpoint_idx - 1)
        # The last activation of the network doesn't need to be restored, because it is still available from the forward pass.
        restored_chunk.append(g)
        dLdg = self.cost_derivative(g, y)

        # Backpropagating over network layers
        for layer_idx in range(1, self.num_layers):
          # If current layer is one of the checkpointed ones activations between it and previous checkpoint are restored.
          if (self.num_layers - layer_idx) % self.checkpoint_every_nth_layer == 0:
            prev_checkpoint_number -= 1
            cur_checkpoint_idx = prev_checkpoint_idx - 1
            prev_checkpoint_idx -= self.checkpoint_every_nth_layer
            position_in_chunk = -1

            restored_chunk = self.forward_between_checkpoints(activations[prev_checkpoint_number], prev_checkpoint_idx, cur_checkpoint_idx)
            # Last layer doesn't need to be computed because it is checkpointed.
            restored_chunk.append(activations[prev_checkpoint_number + 1])
            

          # This part is very similar to standard backpropagation,
          # instead of global list of activations there is a list of activations between current checkpoints.

          # The last activation in the chunk can be removed from the list, as it will not be used later.
          g = restored_chunk.pop()
          
          dLdf = np.multiply(dLdg,np.multiply(g,1-g))
          dLdg = np.matmul(self.weights[-layer_idx].T, dLdf)

          delta_nabla_w[-layer_idx] = np.matmul(dLdf, restored_chunk[-1].T)
          delta_nabla_b[-layer_idx] = np.sum(dLdf, axis=1).reshape(-1, 1)
          
          position_in_chunk -= 1
        
        return (delta_nabla_b, delta_nabla_w)

## Memory analysis
In this version of the backpropagation algorithm, the forward pass saves only every $n$-th activation instead of every single one. During the backward pass, the missing activations are recomputed from the saved ones. As a result less memory is used, but the running time is longer. Because the forward pass saves only every $n$-th layer, the checkpoints take
$$O\left(B \cdot \sum_{i=1}^{\lfloor L/n \rfloor} s_{ni}\right)$$
memory, where $n$ is the `checkpoint_every_nth_layer` parameter, $B$ is the mini-batch size, $s_i$ is the size of layer $i$, and $L$ is the number of layers after the input layer. During the backward pass, the activations between two adjacent checkpoints are restored using
$$O\left(B \cdot \sum_{i=c_j}^{c_{j+1}} s_{i}\right)$$
memory, where $c_j$ and $c_{j+1}$ are the indices of the checkpoints. After backpropagation crosses a checkpoint, the corresponding restored activations can be deleted. So, on top of the stored checkpoints, the peak memory usage during the backward pass is
$$O\left(B \cdot \max_{j}\sum_{i=c_j}^{c_{j+1}} s_{i}\right)$$
Depending on the shape of the network and the number of layers, this can be significantly less than in the implementation without checkpointing.
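
The two terms of this analysis can be evaluated for a concrete architecture with a short sketch. The helper `checkpoint_memory_units` is hypothetical, and it makes one simplifying assumption: it counts only the activations strictly between two checkpoints as restored, since the endpoints are already held as checkpoints (or as the network input and output). Values are activation units per example, so multiply by the mini-batch size and the bytes per value to get bytes.

```python
def checkpoint_memory_units(sizes, n):
    """Return (stored checkpoint units, peak restored units) per example.

    Assumption: layer 0 is the input and every n-th later layer is a
    checkpoint. Restored units count only layers strictly inside a segment.
    """
    num_layers = len(sizes) - 1
    checkpoints = list(range(n, num_layers + 1, n))
    stored = sum(sizes[i] for i in checkpoints)

    # Segment boundaries: the input, every checkpoint, and the output layer.
    bounds = [0] + checkpoints
    if bounds[-1] != num_layers:
        bounds.append(num_layers)

    peak_restored = max(
        sum(sizes[i] for i in range(lo + 1, hi))
        for lo, hi in zip(bounds[:-1], bounds[1:])
    )
    return stored, peak_restored


deep = [784] + [30] * 50 + [10]
for n in (1, 2, 10):
    print(n, checkpoint_memory_units(deep, n))
```

With `n = 1` every layer is a checkpoint and nothing has to be restored, which corresponds to plain backpropagation.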

# Accuracy comparison
To check whether the implementation is correct, we can train four networks (one standard and three with different checkpoint spacings) starting from the same initial weights. Since checkpointing should not change the gradients that backpropagation computes, the accuracies of all four networks should be identical.

In [206]:
architecture = [784, 30, 30, 30, 30, 10]
epochs = 5

network = Network(architecture)
weights = network.weights
biases = network.biases
print("Standard network without checkpointing:")
network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 1)
network.weights = weights
network.biases = biases
print("Checkpointing every layer:")
network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 2)
network.weights = weights
network.biases = biases
print("Checkpointing every second layer:")
network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 4)
network.weights = weights
network.biases = biases
print("Checkpointing every 4 layers:")
network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

Standard network without checkpointing:
Epoch: 0, Accuracy: 0.7747
Epoch: 1, Accuracy: 0.8523
Epoch: 2, Accuracy: 0.8748
Epoch: 3, Accuracy: 0.8903
Epoch: 4, Accuracy: 0.9025
Checkpointing every layer:
Epoch: 0, Accuracy: 0.7747
Epoch: 1, Accuracy: 0.8523
Epoch: 2, Accuracy: 0.8748
Epoch: 3, Accuracy: 0.8903
Epoch: 4, Accuracy: 0.9025
Checkpointing every second layer:
Epoch: 0, Accuracy: 0.7747
Epoch: 1, Accuracy: 0.8523
Epoch: 2, Accuracy: 0.8748
Epoch: 3, Accuracy: 0.8903
Epoch: 4, Accuracy: 0.9025
Checkpointing every 4 layers:
Epoch: 0, Accuracy: 0.7747
Epoch: 1, Accuracy: 0.8523
Epoch: 2, Accuracy: 0.8748
Epoch: 3, Accuracy: 0.8903
Epoch: 4, Accuracy: 0.9025
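
The identical accuracies are expected because recomputing a layer from a checkpoint repeats exactly the same operations as the original forward pass. The standalone sketch below checks this on a tiny random network. The layer sizes, the seed, and the choice of layer 2 as the checkpoint are arbitrary assumptions for illustration, and the code does not use the `Network` classes from this notebook.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


rng = np.random.default_rng(0)
sizes = [8, 5, 5, 5, 3]
weights = [rng.standard_normal((y, x)) for x, y in zip(sizes[:-1], sizes[1:])]
biases = [rng.standard_normal((y, 1)) for y in sizes[1:]]
x = rng.standard_normal((sizes[0], 4))  # 4 examples, one per column

# Full forward pass that stores every activation.
stored = [x]
for w, b in zip(weights, biases):
    stored.append(sigmoid(w @ stored[-1] + b))

# Recompute layers 3 and 4 starting from the checkpoint at layer 2.
a = stored[2]
recomputed = []
for w, b in zip(weights[2:], biases[2:]):
    a = sigmoid(w @ a + b)
    recomputed.append(a)

# The recomputed activations should match the ones stored earlier.
matches = all(np.allclose(r, s) for r, s in zip(recomputed, stored[3:]))
print("recomputed activations match:", matches)
```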


# Running time comparison
To keep the output readable, accuracy logging is disabled during the running time comparison.

### Test on a simple, small architecture
The standard implementation and checkpointing every layer should have similar running times (small differences can be present due to slight differences in implementation). Since the network has only one hidden layer, there should be no significant difference between checkpointing every 2 and every 4 layers (or more).

In [207]:
architecture = [784, 100, 10]
epochs = 10

network = Network(architecture)
print("Standard network without checkpointing:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 1)
print("Checkpointing every layer:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 2)
print("Checkpointing every second layer:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 4)
print("Checkpointing every 4 layers:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

Standard network without checkpointing:
14.2 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Checkpointing every layer:
15.9 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Checkpointing every second layer:
20.6 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Checkpointing every 4 layers:
21 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


### Test on longer network
Again, there should be no significant difference between the standard network and the network with a checkpoint at every layer. On a network of this size, the difference in time between checkpointing every second and every fourth layer should start to show.



In [208]:
architecture = [784, 30, 30, 30, 30, 30, 30, 10]
epochs = 10

network = Network(architecture)
print("Standard network without checkpointing:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 1)
print("Checkpointing every layer:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 2)
print("Checkpointing every second layer:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 4)
print("Checkpointing every 4 layers:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

Standard network without checkpointing:
10.8 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Checkpointing every layer:
11.8 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Checkpointing every second layer:
15.1 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Checkpointing every 4 layers:
16.3 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


### Test on very long network


In [209]:
architecture = [784] + ([30] * 50) + [10]
epochs = 10

network = Network(architecture)
print("Standard network without checkpointing:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 1)
print("Checkpointing every layer:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 2)
print("Checkpointing every second layer:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 4)
print("Checkpointing every 4 layers:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = 10)
print("Checkpointing every 10 layers:")
%timeit -n1 -r1 network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)

Standard network without checkpointing:
58.2 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Checkpointing every layer:
1min 3s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Checkpointing every second layer:
1min 16s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Checkpointing every 4 layers:
1min 22s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Checkpointing every 10 layers:
1min 26s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
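
One rough way to read these timings is to count how many layer evaluations the backward pass has to repeat for a given checkpoint spacing. The sketch below does this under the assumption that each segment between checkpoints recomputes all of its layers except the last one, whose output is already available. The function name `recomputed_layers` is hypothetical, and the count ignores all other overheads.

```python
def recomputed_layers(num_layers, n):
    """Count the layer evaluations repeated during the backward pass.

    Assumption: a segment of k layers between two stored activations
    needs k - 1 recomputations, because its final output is already kept.
    """
    full_segments, remainder = divmod(num_layers, n)
    extra = full_segments * (n - 1)
    if remainder:
        # Trailing segment between the last checkpoint and the output.
        extra += remainder - 1
    return extra


# The very long network above has 51 layers after the input.
for n in (1, 2, 4, 10):
    print(f"checkpoint every {n}: {recomputed_layers(51, n)} extra layer evaluations")
```

The count grows with the spacing but levels off below the total number of layers, which is in line with the diminishing increase in running time between spacings 4 and 10.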


# Memory usage
Using `tracemalloc`, we can compare the memory usage of the different networks. Since checkpointing is mostly used to fit bigger networks on devices with less memory, we are most interested in peak memory usage, and that is the statistic we compare here. All values reported below are in bytes.

In [212]:
import tracemalloc

def check_memory(architecture, checkpoint_every_nth_layer):
  tracemalloc.start()
  if checkpoint_every_nth_layer == 0:
    network = Network(architecture)
  else:
    network = NetworkWithCheckpointing(architecture, checkpoint_every_nth_layer = checkpoint_every_nth_layer)

  network.SGD((x_train, y_train), epochs=epochs, mini_batch_size=100, eta=3., test_data=(x_test, y_test), log=False)
  peak = tracemalloc.get_traced_memory()[1]
  tracemalloc.stop()
  return peak
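
The same measurement pattern can be tried in isolation. The sketch below wraps it in a hypothetical helper `peak_bytes` and applies it to a function that allocates a temporary array and then drops it. It assumes a NumPy build that reports its array allocations to `tracemalloc`.

```python
import tracemalloc

import numpy as np


def peak_bytes(fn):
    """Run fn() and return the peak number of bytes traced while it ran."""
    tracemalloc.start()
    try:
        fn()
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak


def allocate_and_drop():
    # 1000 x 1000 float64 values occupy 8,000,000 bytes while alive.
    temp = np.ones((1000, 1000))
    return float(temp.sum())


peak = peak_bytes(allocate_and_drop)
print("peak traced bytes:", peak)
```

The peak still reflects the temporary array even though it has been freed by the time the measurement is read, which is why peak usage rather than current usage is the relevant figure for checkpointing.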

### Small network
On a small network the difference in memory usage is not very big, but it can still be seen.

In [213]:
architecture = [784, 30, 30, 30, 30, 30, 10]
epochs = 2


standard_peak = check_memory(architecture, 0)
print("Standard network memory usage peak: ", standard_peak)

every_2_peak = check_memory(architecture, 2)
print("Checkpoint every 2 layers memory usage peak: ", every_2_peak)

every_4_peak = check_memory(architecture, 4)
print("Checkpoint every 4 layers memory usage peak: ", every_4_peak)

every_10_peak = check_memory(architecture, 10)
print("Checkpoint every 10 layers memory usage peak: ", every_10_peak)

print("Memory saving when using checkpoint every 2 layers:", standard_peak - every_2_peak)
print("Memory saving when using checkpoint every 4 layers:", standard_peak - every_4_peak)
print("Memory saving when using checkpoint every 10 layers:", standard_peak - every_10_peak)

Standard network memory usage peak:  1557491
Checkpoint every 2 layers memory usage peak:  1408902
Checkpoint every 4 layers memory usage peak:  1372008
Checkpoint every 10 layers memory usage peak:  1347150
Memory saving when using checkpoint every 2 layers: 148589
Memory saving when using checkpoint every 4 layers: 185483
Memory saving when using checkpoint every 10 layers: 210341


### Bigger network
Memory usage differences are easier to see on bigger networks, like the one below (the network is very 'long' and 'thin' so that it runs faster).
Checkpointing does reduce memory usage, although the difference between the standard network and checkpointing every second layer is significantly bigger than that between checkpointing every second layer and checkpointing every fourth layer.

In [214]:
architecture = [784] + ([30] * 50) + [10]
epochs = 2


standard_peak = check_memory(architecture, 0)
print("Standard network memory usage peak: ", standard_peak)

every_2_peak = check_memory(architecture, 2)
print("Checkpoint every 2 layers memory usage peak: ", every_2_peak)

every_4_peak = check_memory(architecture, 4)
print("Checkpoint every 4 layers memory usage peak: ", every_4_peak)

every_10_peak = check_memory(architecture, 10)
print("Checkpoint every 10 layers memory usage peak: ", every_10_peak)

print("Memory saving when using checkpoint every 2 layers:", standard_peak - every_2_peak)
print("Memory saving when using checkpoint every 4 layers:", standard_peak - every_4_peak)
print("Memory saving when using checkpoint every 10 layers:", standard_peak - every_10_peak)

Standard network memory usage peak:  6012939
Checkpoint every 2 layers memory usage peak:  2660051
Checkpoint every 4 layers memory usage peak:  2347372
Checkpoint every 10 layers memory usage peak:  2177360
Memory saving when using checkpoint every 2 layers: 3352888
Memory saving when using checkpoint every 4 layers: 3665567
Memory saving when using checkpoint every 10 layers: 3835579
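
To put these absolute numbers in perspective, the short sketch below turns the peaks printed for the bigger network into relative savings. The figures are copied from the output above, and the variable names are chosen for this example only.

```python
standard_peak = 6012939  # bytes, standard network on the bigger architecture
peaks = {2: 2660051, 4: 2347372, 10: 2177360}  # bytes, by checkpoint spacing

# Fraction of the standard peak that each checkpoint spacing saves.
relative_saving = {
    n: (standard_peak - peak) / standard_peak for n, peak in peaks.items()
}

for n, fraction in relative_saving.items():
    print(f"checkpoint every {n} layers: {fraction:.1%} lower peak memory")
```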
