# Checkpointing

Your task is to implement checkpointing for an MLP using NumPy.

You are free to use the implementation of an MLP and the backpropagation algorithm that you developed during the lab sessions.

The key takeaway from this task is that checkpointing trades extra computation for lower memory use: by recomputing parts of the forward pass during the backward pass, we reduce the memory needed to store activations, which is often a major bottleneck when training large networks. In plain English, we slightly increase the time required to train our network in order to save some of our GPU's precious memory.

## What is checkpointing?

The aim of checkpointing is to save the forward result of only every $n$-th layer (e.g. every 2nd layer), instead of every layer as in plain backpropagation, and to use these checkpoints to recompute the missing parts of the forward pass during the backward pass. Checkpointed layers are kept in memory after the forward pass, while the remaining activations are recomputed at most once. After being recomputed, the non-checkpointed layers are kept in memory until they are no longer required.
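
As a small illustration of which layers end up being stored, the sketch below lists the checkpointed layer indices for a given $n$. It assumes that non-input layers are numbered from 1 and that layer $i$ is a checkpoint when $i$ is divisible by $n$; the helper name `checkpoint_indices` is hypothetical and not part of the lab code.

```python
def checkpoint_indices(num_non_input_layers, n):
    """Return the 1-based indices of layers whose forward result is kept.

    Assumption: layer i is a checkpoint when i is divisible by n.
    """
    n = max(n, 1)  # n <= 1 means every layer is stored (no checkpointing)
    return [i for i in range(1, num_non_input_layers + 1) if i % n == 0]


# Example: 9 non-input layers, checkpoint every 3rd layer.
kept = checkpoint_indices(9, 3)
print(kept)  # [3, 6, 9]
print(9 - len(kept), "layers must be recomputed during the backward pass")
```

With these numbers, three forward results stay in memory after the forward pass and the other six are recomputed segment by segment.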

# What should be done

1. Take the implementation of an MLP trained with backpropagation. Analyze how the memory used by the algorithm depends on the number of hidden layers.

2. Implement a class `NetworkWithCheckpointing` that inherits from the `Network` class defined during lab sessions by:
    a) implementing a method `forward_between_checkpoints` that recomputes the forward pass of the network starting from one of the checkpointed layers
    b) overriding the method `backprop` so that it stores only the checkpointed layers, recomputes the remaining activations with `forward_between_checkpoints`, and keeps them in memory only until they are no longer needed.

3. Train your network with checkpointing on MNIST. Compare its running time and memory usage with those of the network without checkpointing.


# Implement Checkpointing for an MLP

1. Take the implementation of an MLP trained with backpropagation. Analyze how the memory used by the algorithm depends on the number of hidden layers.

In [1]:
import numpy as np
import time

In [2]:
# Let's read the mnist dataset

def load_mnist(path='mnist.npz'):
    with np.load(path) as f:
        x_train, _y_train = f['x_train'], f['y_train']
        x_test, _y_test = f['x_test'], f['y_test']

    x_train = x_train.reshape(-1, 28 * 28) / 255.
    x_test = x_test.reshape(-1, 28 * 28) / 255.

    y_train = np.zeros((_y_train.shape[0], 10))
    y_train[np.arange(_y_train.shape[0]), _y_train] = 1

    y_test = np.zeros((_y_test.shape[0], 10))
    y_test[np.arange(_y_test.shape[0]), _y_test] = 1

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_mnist()

In [3]:
def sigmoid(z):
    return 1.0/(1.0+np.exp(-z))

def sigmoid_prime(z):
    # Derivative of the sigmoid
    return sigmoid(z)*(1-sigmoid(z))

class Network(object):
    def __init__(self, sizes):
        # initialize biases and weights with random normal distr.
        # weights are indexed by target node first
        self.num_layers = len(sizes)
        self.sizes = sizes
        np.random.seed(17)
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        np.random.seed(18)
        self.weights = [np.random.randn(y, x)
                        for x, y in zip(sizes[:-1], sizes[1:])]

    def feedforward(self, a):
        # Run the network on a single case
        for b, w in zip(self.biases, self.weights):
            a = sigmoid(np.dot(w, a) + b)
        return a

    def update_mini_batch(self, x_mini_batch, y_mini_batch, eta):
        # Update networks weights and biases by applying a single step
        # of gradient descent using backpropagation to compute the gradient.
        # The gradient is computed for a mini_batch.
        # eta is the learning rate
        nabla_b = [np.zeros(b.shape) for b in self.biases]
        nabla_w = [np.zeros(w.shape) for w in self.weights]
        for x, y in zip(x_mini_batch, y_mini_batch):
            delta_nabla_b, delta_nabla_w = self.backprop(x.reshape(784, 1), y.reshape(10, 1))
            nabla_b = [nb + dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
            nabla_w = [nw + dnw for nw, dnw in zip(nabla_w, delta_nabla_w)]
        self.weights = [w - (eta / len(x_mini_batch)) * nw
                        for w, nw in zip(self.weights, nabla_w)]
        self.biases = [b - (eta / len(x_mini_batch)) * nb
                       for b, nb in zip(self.biases, nabla_b)]

    def backprop(self, x, y):
        # For a single input (x,y) return a tuple of lists.
        # First contains gradients over biases, second over weights.

        # First initialize the list of gradient arrays
        delta_nabla_b = [np.zeros_like(p) for p in self.biases]
        delta_nabla_w = [np.zeros_like(p) for p in self.weights]

        # Then go forward remembering all values before and after activations
        # in two other array lists
        activation = x
        activations = [x]
        zs = []
        for b, w in zip(self.biases, self.weights):
            z = w @ activation + b
            zs.append(z)
            activation = sigmoid(z)
            activations.append(activation)

        # Now go backward from the final cost applying backpropagation
        for l in range(1, self.num_layers):
            if l == 1:
                # chain rule: dC/db = dC/da * da/dz * dz/db = delta * dz/db
                # dC/dw = dC/da * da/dz * dz/dw = delta * dz/dw
                delta = self.cost_derivative(activations[-1], y) * sigmoid_prime(zs[-1])
            else:
                # dC/dz_2 = dC/dz_1 * dz_1/da_2 * da_2/dz_2
                delta = self.weights[-l + 1].transpose() @ delta * sigmoid_prime(zs[-l])
            delta_nabla_b[-l] = delta
            delta_nabla_w[-l] = delta @ activations[-l - 1].transpose()
        return delta_nabla_b, delta_nabla_w

    def evaluate(self, x_test_data, y_test_data):
        # Count the number of correct answers for test_data
        test_results = [(np.argmax(self.feedforward(x_test_data[i].reshape(784, 1))), np.argmax(y_test_data[i]))
                        for i in range(len(x_test_data))]
        # return accuracy
        return np.mean([int(x == y) for (x, y) in test_results])

    def cost_derivative(self, output_activations, y):
        return output_activations - y

    def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[i * mini_batch_size:(i * mini_batch_size + mini_batch_size)]
                y_mini_batch = y_train[i * mini_batch_size:(i * mini_batch_size + mini_batch_size)]
                self.update_mini_batch(x_mini_batch, y_mini_batch, eta)
            if test_data:
                print("Epoch: {0}, Accuracy: {1}".format(j, self.evaluate(x_test, y_test)))
            else:
                print("Epoch: {0}".format(j))


Let's have a look at the `backprop` function. After the forward pass we keep in memory the vectors of values before and after activation for each non-input layer (plus the input itself, stored as the first activation). We also store a list `delta_nabla_b` whose total size equals the number of biases of the network and a list `delta_nabla_w` whose total size equals the number of weights of the network. All in all, for a network with k non-input layers, each containing at most n neurons, we use extra space equal to roughly 2n*k + biases.size() + weights.size(), so the memory taken by the stored layer values grows linearly with the number of layers.
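
To make this count concrete, the sketch below tallies how many floating-point values one call to `backprop` holds for a given list of layer sizes. The function name `plain_backprop_storage` is hypothetical, and the layer sizes are only an example, chosen to match the network trained later in this notebook.

```python
def plain_backprop_storage(sizes):
    """Count the floats held by one call to plain backprop for layer sizes `sizes`."""
    zs = sum(sizes[1:])       # one pre-activation vector per non-input layer
    activations = sum(sizes)  # the input plus one activation vector per non-input layer
    grad_b = sum(sizes[1:])   # delta_nabla_b has the same shapes as the biases
    grad_w = sum(x * y for x, y in zip(sizes[:-1], sizes[1:]))  # same shapes as the weights
    return {"zs": zs, "activations": activations, "grad_b": grad_b, "grad_w": grad_w}


sizes = [784, 30, 30, 30, 30, 30, 30, 30, 30, 10]
storage = plain_backprop_storage(sizes)
print(storage)
```

For a narrow network like this one the gradient lists dominate, but they do not grow with the number of stored layer values; checkpointing only reduces the `zs` and `activations` part.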


2. Implement a class `NetworkWithCheckpointing` that inherits from the `Network` class defined during lab sessions by:
    a) implementing a method `forward_between_checkpoints` that recomputes the forward pass of the network starting from one of the checkpointed layers
    b) overriding the method `backprop` so that it stores only the checkpointed layers, recomputes the remaining activations with `forward_between_checkpoints`, and keeps them in memory only until they are no longer needed.


In [4]:
class NetworkWithCheckpointing(Network):

    def __init__(self, sizes, checkpoint_every_nth_layer: int = 0, *args, **kwargs):
        super().__init__(sizes, *args, **kwargs)
        self.n = max(checkpoint_every_nth_layer, 1) # no checkpointing if input is 0 or smaller

    def forward_between_checkpoints(self, a, checkpoint_idx_start, layer_idx_end, checkpoint=False):
        """
        Given activations of a layer, the index of a starting layer and the index of the ending
        layer, compute forward pass and save values before activation in a dict (layer index -> values).
        If checkpoint is True, only every n-th value is saved
        """
        zs = {}
        i = checkpoint_idx_start
        for b, w in zip(self.biases[checkpoint_idx_start - 1:layer_idx_end],
                        self.weights[checkpoint_idx_start - 1:layer_idx_end]):
            z = w @ a + b
            a = sigmoid(z)
            if not checkpoint or i % self.n == 0:
                zs[i] = z
            i += 1
        return zs

    def backprop(self, x, y):
        def _compute_z(i, i_to_z, i_to_z_recomputed):
            """ Given index of a target layer, a dictionary of checkpoint values and a dictionary
            of values computed from last checkpoint, compute the values of the target layer before activation"""
            if i in i_to_z.keys():
                new_z = i_to_z[i]
                del i_to_z[i]
            elif i in i_to_z_recomputed:
                new_z = i_to_z_recomputed[i]
            else:
                start_i = max(i_to_z.keys()) if len(i_to_z) else 0
                start_value = x if start_i == 0 else sigmoid(i_to_z[start_i])
                i_to_z_recomputed = self.forward_between_checkpoints(start_value, start_i + 1, i)
                new_z = i_to_z_recomputed[i]
            return new_z, i_to_z, i_to_z_recomputed

        delta_nabla_b = [np.zeros_like(p) for p in self.biases]
        delta_nabla_w = [np.zeros_like(p) for p in self.weights]

        zs = self.forward_between_checkpoints(x, 1, self.num_layers - 1, True)
        zs_recomputed = {}
        z, zs, zs_recomputed = _compute_z(self.num_layers - 1, zs, zs_recomputed)
        for l in range(1, self.num_layers):
            if l == 1:
                delta = self.cost_derivative(sigmoid(z), y) * sigmoid_prime(z)
            else:
                delta = self.weights[-l + 1].transpose() @ delta * sigmoid_prime(z)
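            if l != self.num_layers - 1:
                # fetch (or recompute) the pre-activation of the layer below
                next_z, zs, zs_recomputed = _compute_z(self.num_layers - l - 1, zs, zs_recomputed)
                next_a = sigmoid(next_z)
                # z is only reassigned here, so a network without hidden layers
                # no longer hits an undefined next_z
                z = next_z
            else:
                # the layer below is the input, so there is no pre-activation to fetch
                next_a = x
            delta_nabla_b[-l] = delta
            delta_nabla_w[-l] = delta @ next_a.transpose()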
        return delta_nabla_b, delta_nabla_w

3. Train your network with checkpointing on MNIST. Compare its running time and memory usage with those of the network without checkpointing.

Let's first analyse memory usage. As in the network without checkpointing, we store delta_nabla_b and delta_nabla_w, whose sizes equal the numbers of biases and weights respectively. For a checkpointing parameter c and k non-input layers, we keep in memory at most k/c checkpointed vectors of values before activation. Moreover, at any time the number of recomputed vectors of values before activation kept in memory is bounded by c. All in all, for a network with k non-input layers, each containing at most n neurons, and a checkpointing parameter c, we use extra space roughly equal to c*n + n*k/c + biases.size() + weights.size().
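
The bound above also suggests how to choose c. Treating c as a real number and ignoring the gradient terms, which do not depend on c, a short calculation gives the minimizer. This is a sketch that assumes every non-input layer has exactly n neurons:

```latex
f(c) = c\,n + \frac{n\,k}{c}, \qquad
f'(c) = n - \frac{n\,k}{c^{2}} = 0 \;\Longrightarrow\; c = \sqrt{k}, \qquad
f(\sqrt{k}) = 2n\sqrt{k}.
```

Since $f''(c) = 2nk/c^{3} > 0$, this is a minimum. For example, with k = 9 and n = 30 it gives c = 3 and about 180 stored values, compared with 2n*k = 540 for plain backpropagation.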

Let's now look at the running times. Since running times should be similar for different epochs, we will only run one epoch.

In [5]:
network = Network([784, 30, 30, 30, 30, 30, 30, 30, 30, 10])
st = time.time()
network.SGD((x_train, y_train), epochs=1, mini_batch_size=100, eta=3., test_data=(x_test, y_test))
et = time.time()
print(f"Running time: {et-st}")

Epoch: 0, Accuracy: 0.34
Running time: 52.688164472579956


In [7]:
times = []
for x in range(2, 6):
    print(f"---Network with checkpoints every {x} layers---")
    network = NetworkWithCheckpointing([784, 30, 30, 30, 30, 30, 30, 30, 30, 10], x)
    st = time.time()
    network.SGD((x_train, y_train), epochs=1, mini_batch_size=100, eta=3., test_data=(x_test, y_test))
    et = time.time()
    times.append(et-st)
    print(f"Running time: {et-st}")

---Network with checkpoints every 2 layers---
Epoch: 0, Accuracy: 0.34
Running time: 65.82348823547363
---Network with checkpoints every 3 layers---
Epoch: 0, Accuracy: 0.34
Running time: 65.10165405273438
---Network with checkpoints every 4 layers---
Epoch: 0, Accuracy: 0.34
Running time: 67.4464328289032
---Network with checkpoints every 5 layers---
Epoch: 0, Accuracy: 0.34
Running time: 64.5918242931366


In the above example the running times are very similar for all checkpointing parameters in [2, 5]: between about 64.6 s and 67.4 s, which is roughly 23% to 28% more than the 52.7 s of the regular network. In general we expect the running time with checkpointing to be at most twice that of the regular network, because with checkpointing we compute the values of each layer at most twice (once in the forward pass and once when recomputing from the closest checkpoint).
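
The "at most twice" argument can be checked by counting layer evaluations. The sketch below simulates the lookup order used in `_compute_z` (checkpoint first, then the most recently recomputed segment, otherwise recompute from the closest remaining checkpoint). It counts only how often a layer's forward values are computed and ignores all other costs; the function name `count_layer_evaluations` is hypothetical.

```python
def count_layer_evaluations(num_layers, n):
    """Count forward evaluations of layers during one checkpointed backprop call."""
    k = num_layers - 1   # non-input layers, indexed 1..k
    n = max(n, 1)
    evaluations = k      # the initial forward pass visits every layer once
    checkpoints = {i for i in range(1, k + 1) if i % n == 0}
    recomputed = set()
    for i in range(k, 0, -1):  # the backward pass walks from the last layer down
        if i in checkpoints:
            checkpoints.discard(i)  # a checkpoint is dropped once it has been used
        elif i not in recomputed:
            start = max(checkpoints) if checkpoints else 0
            recomputed = set(range(start + 1, i + 1))  # recompute one whole segment
            evaluations += len(recomputed)
    return evaluations


for n in range(1, 6):
    print(n, count_layer_evaluations(10, n))
```

For the 10-layer network used above, no setting exceeds 18 evaluations, i.e. twice the 9 evaluations of a plain forward pass. The measured slowdown is smaller than these counts alone would suggest, presumably because the backward computations are the same in both networks and the timed section also includes the evaluation on the test set.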