# Checkpointing

Your task is to implement checkpointing for an MLP using NumPy.

You are free to use the implementation of an MLP and the backpropagation algorithm that you developed during lab sessions.

The key takeaway from this task is that checkpointing trades extra computation for lower memory use: some activations are recomputed during the backward pass instead of being stored, which reduces the memory needed for backpropagation, often a major bottleneck when training large networks. In plain English, we can slightly increase the time required for training our network to save some of our GPU's precious memory.

## What is checkpointing?

The aim of checkpointing is to save only every $n$-th layer's forward result (e.g. every 2nd layer's), instead of every layer's forward result as in plain backpropagation, and to use these checkpoints to recompute the missing activations during the backward pass. Checkpoint layers are kept in memory after the forward pass, while the remaining activations are recomputed at most once. After being recomputed, the non-checkpoint layers are kept in memory until they are no longer required.
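
To make the selection rule concrete, here is a minimal sketch that lists which activations would be kept for a given depth. It assumes the convention used later in this notebook: the network input counts as checkpoint 0, and the output of the last layer is never stored as a checkpoint because it is still available when the backward pass starts. The function name `checkpoint_indices` is hypothetical.

```python
def checkpoint_indices(num_layers, interval):
    """Return the indices of the activations kept after the forward pass.

    Index 0 is the network input; index i is the output of layer i.
    The output of the last layer (index num_layers) is excluded
    (assumption: it is still held in a local variable at that point).
    """
    if interval <= 0:
        raise ValueError("interval must be positive")
    return list(range(0, num_layers, interval))


# 4 layers, checkpoint every 2nd layer
print(checkpoint_indices(4, 2))
# 7 layers, checkpoint every 3rd layer
print(checkpoint_indices(7, 3))
```

Every activation whose index is not in the returned list has to be recomputed from the nearest earlier checkpoint during the backward pass.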

# What should be done

1. Take the implementation of an MLP trained with backpropagation. Analyze how the memory used by the algorithm depends on the number of hidden layers.

2. Implement a class `NetworkWithCheckpointing` that inherits from the `Network` class defined during lab sessions by:
    a) implementing a method `forward_between_checkpoints` that recomputes the forward pass of the network starting from one of the checkpointed layers
    b) overriding the method `backprop` so that it stores only the checkpointed layers, recomputes the other activations with `forward_between_checkpoints`, and keeps them in memory only until they are no longer needed.

3. Train your network with checkpointing on MNIST. Compare running times and memory usage with those of the network without checkpointing.
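
For the comparison in step 3, one possible way to collect both numbers is sketched below. It uses only the standard-library `time.perf_counter` and `tracemalloc`; the helper `measure` and the stand-in `toy_workload` are hypothetical names, not part of the lab code. `tracemalloc` only sees allocations that are reported to it (recent NumPy versions report their array buffers), so the peak should be treated as an estimate.

```python
import time
import tracemalloc

import numpy as np


def measure(fn, *args, **kwargs):
    """Run fn once and return (result, elapsed_seconds, peak_traced_bytes)."""
    tracemalloc.start()
    start = time.perf_counter()
    try:
        result = fn(*args, **kwargs)
        elapsed = time.perf_counter() - start
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return result, elapsed, peak


def toy_workload(n):
    # Stand-in for one training step: allocate a matrix and square it.
    a = np.ones((n, n))
    return a @ a


_, seconds, peak = measure(toy_workload, 200)
print(f"time: {seconds:.4f} s, peak traced memory: {peak / 1e6:.2f} MB")
```

In the notebook, `toy_workload` would be replaced by a call such as one `backprop` on a fixed mini-batch, run once for each network variant under the same conditions.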


# Implement Checkpointing for an MLP

In [1]:
import numpy as np

In [2]:
!wget -O mnist.npz https://s3.amazonaws.com/img-datasets/mnist.npz

In [3]:
# Let's read the mnist dataset

def load_mnist(path='mnist.npz'):
    with np.load(path) as f:
        x_train, _y_train = f['x_train'], f['y_train']
        x_test, _y_test = f['x_test'], f['y_test']

    x_train = x_train.reshape(-1, 28 * 28) / 255.
    x_test = x_test.reshape(-1, 28 * 28) / 255.

    y_train = np.zeros((_y_train.shape[0], 10))
    y_train[np.arange(_y_train.shape[0]), _y_train] = 1

    y_test = np.zeros((_y_test.shape[0], 10))
    y_test[np.arange(_y_test.shape[0]), _y_test] = 1

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_mnist()

In [4]:
def sigmoid(z):
    return 1.0/(1.0+np.exp(-z))

def sigmoid_prime(z):
    # Derivative of the sigmoid
    return sigmoid(z)*(1-sigmoid(z))

class Network(object):
    def __init__(self, sizes):
        # initialize biases and weights with random normal distr.
        # weights are indexed by target node first
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.weights = [np.random.randn(y, x) 
                        for x, y in zip(sizes[:-1], sizes[1:])]

    def feedforward(self, a):
        for b, w in zip(self.biases, self.weights):
            a = sigmoid(np.dot(w, a)+b)
        return a
    
    def update_mini_batch(self, x_mini_batch, y_mini_batch, eta):
        delta_nabla_b, delta_nabla_w = self.backprop(x_mini_batch.T,y_mini_batch.T)
        self.weights = [w-(eta/len(x_mini_batch))*nw 
                        for w, nw in zip(self.weights, delta_nabla_w)]
        self.biases = [b-(eta/len(x_mini_batch))*nb 
                       for b, nb in zip(self.biases, delta_nabla_b)]
        
    def cost_derivative(self, output_activations, y):
        return (output_activations-y) 

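    def backprop(self, x, y):
        # Gradients of the cost with respect to biases and weights, one entry per layer
        delta_nabla_b = [np.zeros_like(p) for p in self.biases]
        delta_nabla_w = [np.zeros_like(p) for p in self.weights]

        # Forward pass: store the input and the activation of every layer
        gs = [x]
        for b, w in zip(self.biases, self.weights):
            x = sigmoid(np.dot(w, x) + b)
            gs.append(x)

        # Gradient of the cost with respect to the network output
        delta_g = self.cost_derivative(x, y)

        # Backward pass: i = 1 is the last layer, i = 2 the one before it, etc.
        for i in range(1, self.num_layers):
            # Sigmoid derivative expressed through the stored activation: g * (1 - g)
            delta_f = delta_g * gs[-i] * (1 - gs[-i])

            delta_nabla_w[-i] = delta_f @ gs[-(i + 1)].T
            # Sum over the mini-batch dimension, keeping a column vector
            delta_nabla_b[-i] = np.sum(delta_f, axis=1, keepdims=True)

            delta_g = self.weights[-i].T @ delta_f

        return delta_nabla_b, delta_nabla_w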

    def evaluate(self, x_test_data, y_test_data):
        # Count the number of correct answers for test_data
        test_results = [(np.argmax(self.feedforward(x_test_data[i].reshape(784,1))), np.argmax(y_test_data[i]))
                        for i in range(len(x_test_data))]
        # return accuracy
        return np.mean([int(x == y) for (x, y) in test_results])
    
    def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[i*mini_batch_size:(i*mini_batch_size + mini_batch_size)] 
                y_mini_batch = y_train[i*mini_batch_size:(i*mini_batch_size + mini_batch_size)] 
                self.update_mini_batch(x_mini_batch, y_mini_batch, eta)
            if test_data:
                print("Epoch: {0}, Accuracy: {1}".format(j, self.evaluate(x_test, y_test)))
            else:
                print("Epoch: {0}".format(j))


network = Network([784,100,30,20,10])
network.SGD((x_train, y_train), epochs=10, mini_batch_size=100, eta=3., test_data=(x_test, y_test))


Epoch: 0, Accuracy: 0.7828
Epoch: 1, Accuracy: 0.8629
Epoch: 2, Accuracy: 0.8856
Epoch: 3, Accuracy: 0.8972
Epoch: 4, Accuracy: 0.9067
Epoch: 5, Accuracy: 0.912
Epoch: 6, Accuracy: 0.9169
Epoch: 7, Accuracy: 0.9207
Epoch: 8, Accuracy: 0.9246
Epoch: 9, Accuracy: 0.9278


In [5]:
class NetworkWithCheckpointing(Network):

    def __init__(self, sizes, checkpoint_interval: int = 2, *args, **kwargs):
        super().__init__(sizes, *args, **kwargs)
        assert checkpoint_interval > 0, "checkpoint_interval must be positive"
        self.checkpoint_interval = checkpoint_interval

    def forward_between_checkpoints(self, checkpoints_g, checkpoint_idx_start, layer_idx_start, layer_idx_end):
        # Recompute the activations of layers layer_idx_start..layer_idx_end-1,
        # starting from the stored checkpoint number checkpoint_idx_start.
        a = checkpoints_g[checkpoint_idx_start]
        gs = [a]

        for b, w in zip(self.biases[layer_idx_start:layer_idx_end], self.weights[layer_idx_start:layer_idx_end]):
            a = sigmoid(np.dot(w, a) + b)
            gs.append(a)

        # Return the activations in backward order: last layer first, checkpoint last
        gs.reverse()
        return gs

    def backprop(self, x, y):
        delta_nabla_b = [np.zeros_like(p) for p in self.biases]
        delta_nabla_w = [np.zeros_like(p) for p in self.weights]

        # Forward pass: keep only the checkpointed activations.
        # 1st checkpoint is at index=0 (the input), 2nd at index=checkpoint_interval,
        # 3rd at index=2*checkpoint_interval etc.
        checkpoints_g = [x]
        for i, (b, w) in enumerate(zip(self.biases, self.weights), start=1):
            x = sigmoid(np.dot(w, x) + b)
            # no checkpoint on the last layer, its output stays in the variable x
            if i % self.checkpoint_interval == 0 and i != len(self.biases):
                checkpoints_g.append(x)

        delta_g = self.cost_derivative(x, y)
        current_layer = len(self.biases)
        current_checkpoint = len(checkpoints_g) - 1
        # Backward pass: process one segment per checkpoint, from the last to the first
        while current_layer > 0:
            # Recompute the activations between this checkpoint and current_layer;
            # gs[0] is the output of current_layer, gs[-1] is the checkpoint itself
            gs = self.forward_between_checkpoints(
                checkpoints_g, current_checkpoint,
                current_checkpoint * self.checkpoint_interval, current_layer)

            for j in range(len(gs) - 1):
                delta_f = delta_g * gs[j] * (1 - gs[j])

                delta_nabla_w[current_layer - 1] = delta_f @ gs[j + 1].T
                delta_nabla_b[current_layer - 1] = np.sum(delta_f, axis=1, keepdims=True)

                delta_g = self.weights[current_layer - 1].T @ delta_f
                current_layer -= 1

            # gs is replaced in the next iteration, which releases the recomputed activations
            current_checkpoint -= 1

        return delta_nabla_b, delta_nabla_w

network = NetworkWithCheckpointing([784,200,30,20,10], 2)
network.SGD((x_train, y_train), epochs=10, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.7871
Epoch: 1, Accuracy: 0.862
Epoch: 2, Accuracy: 0.8883
Epoch: 3, Accuracy: 0.9025
Epoch: 4, Accuracy: 0.9105
Epoch: 5, Accuracy: 0.9158
Epoch: 6, Accuracy: 0.9201
Epoch: 7, Accuracy: 0.9219
Epoch: 8, Accuracy: 0.9243
Epoch: 9, Accuracy: 0.9259
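
Similar accuracy curves suggest that the checkpointed version trains correctly, but a more direct check is to compare gradients: with identical weights and inputs, both variants should return the same values. The following is a self-contained sketch of such a check, assuming sigmoid layers and the cost derivative `output - y` as in the classes above. It uses standalone functions with hypothetical names (`forward_segment`, `layer_backward`, `backprop_plain`, `backprop_checkpointed`) rather than the notebook's classes, so it can be run in isolation.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def forward_segment(a, weights, biases):
    """Return [a, f1(a), f2(f1(a)), ...] for the given consecutive layers."""
    acts = [a]
    for w, b in zip(weights, biases):
        a = sigmoid(w @ a + b)
        acts.append(a)
    return acts


def layer_backward(delta_g, a_out, a_in, w):
    """One backward step through a sigmoid layer with input a_in and output a_out."""
    delta_f = delta_g * a_out * (1.0 - a_out)
    grad_w = delta_f @ a_in.T
    grad_b = delta_f.sum(axis=1, keepdims=True)
    return grad_w, grad_b, w.T @ delta_f


def backprop_plain(weights, biases, x, y):
    acts = forward_segment(x, weights, biases)  # stores every activation
    delta_g = acts[-1] - y
    grads_w, grads_b = [None] * len(weights), [None] * len(weights)
    for l in reversed(range(len(weights))):
        grads_w[l], grads_b[l], delta_g = layer_backward(
            delta_g, acts[l + 1], acts[l], weights[l])
    return grads_w, grads_b


def backprop_checkpointed(weights, biases, x, y, interval):
    n = len(weights)
    checkpoints = {0: x}  # activation index -> stored activation
    a = x
    for l in range(n):
        a = sigmoid(weights[l] @ a + biases[l])
        if (l + 1) % interval == 0 and (l + 1) != n:
            checkpoints[l + 1] = a
    delta_g = a - y
    grads_w, grads_b = [None] * n, [None] * n
    end = n
    for start in sorted(checkpoints, reverse=True):
        # Recompute only the activations of layers start..end-1
        acts = forward_segment(checkpoints[start], weights[start:end], biases[start:end])
        for l in reversed(range(start, end)):
            grads_w[l], grads_b[l], delta_g = layer_backward(
                delta_g, acts[l - start + 1], acts[l - start], weights[l])
        end = start
    return grads_w, grads_b


rng = np.random.default_rng(0)
sizes = [8, 6, 5, 4, 3, 2]  # small toy network, not the MNIST one
weights = [rng.standard_normal((m, k)) for k, m in zip(sizes[:-1], sizes[1:])]
biases = [rng.standard_normal((m, 1)) for m in sizes[1:]]
x = rng.standard_normal((sizes[0], 4))
y = rng.standard_normal((sizes[-1], 4))

plain_w, plain_b = backprop_plain(weights, biases, x, y)
for interval in (1, 2, 3):
    ckpt_w, ckpt_b = backprop_checkpointed(weights, biases, x, y, interval)
    same = all(np.allclose(p, c) for p, c in zip(plain_w + plain_b, ckpt_w + ckpt_b))
    print(f"interval={interval}: gradients match = {same}")
```

The same idea could be applied to the notebook's classes by copying `weights` and `biases` from a `Network` instance into a `NetworkWithCheckpointing` instance and comparing the outputs of their `backprop` methods on one mini-batch.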


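N = total number of layers

**Memory usage:**

Decreased roughly by a factor of `checkpoint_interval` compared to keeping all activations in memory, as long as `checkpoint_interval` is small relative to N / `checkpoint_interval`. The sum below is smallest when `checkpoint_interval` is close to sqrt(N).

Layers kept in memory without checkpoints: N

Layers kept in memory with checkpoints: N / `checkpoint_interval` + `checkpoint_interval` (the stored checkpoints plus the activations of one recomputed segment)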
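
The estimate above can be tabulated for a range of intervals. This is a small sketch of the counting argument only, not a measurement; the function name `stored_activations` is hypothetical, and rounding N / `checkpoint_interval` up is an assumption made to handle depths that are not a multiple of the interval.

```python
import math


def stored_activations(num_layers, interval):
    """Approximate peak number of layer activations held at once.

    Checkpoints kept after the forward pass (about num_layers / interval)
    plus the activations of one recomputed segment (at most interval).
    """
    checkpoints = math.ceil(num_layers / interval)
    segment = min(interval, num_layers)
    return checkpoints + segment


num_layers = 100
for interval in (1, 2, 5, 10, 20, 50, 100):
    print(interval, stored_activations(num_layers, interval))

# Search for the interval with the smallest estimate
best = min(range(1, num_layers + 1),
           key=lambda k: stored_activations(num_layers, k))
print("best interval:", best, "| sqrt(N):", math.isqrt(num_layers))
```

Very small and very large intervals both give little benefit: with an interval of 1 every layer is a checkpoint, and with an interval of N the single recomputed segment is the whole network.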

**Running time:**

The major difference is that in the checkpoint version `backprop` effectively has to compute the forward pass a second time, in `checkpoint_interval`-sized chunks, so the total work grows by roughly one extra forward pass (about 50% in the operation count below).

Operations without checkpoints: N forward pass ops + N backprop ops

Operations with checkpoints: 2N forward pass ops + N backprop ops
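
How large the slowdown is depends on how expensive one backward step is compared to one forward step. The sketch below turns the operation count above into a ratio; the function name `relative_training_cost` and its parameter are hypothetical. A value of 1.0 reproduces the count above, while 2.0 is a rough figure for dense layers, where the backward step needs two matrix products (one for the weight gradient, one for the gradient passed to the previous layer) and the forward step needs one.

```python
def relative_training_cost(backward_per_forward=1.0):
    """Cost of checkpointed backprop relative to plain backprop.

    backward_per_forward is the cost of one layer's backward step,
    measured in units of one layer's forward step.
    """
    plain = 1.0 + backward_per_forward         # forward once + backward
    checkpointed = 2.0 + backward_per_forward  # forward twice + backward
    return checkpointed / plain


print(relative_training_cost(1.0))  # matches the op count above
print(relative_training_cost(2.0))  # rough dense-layer estimate
```

These ratios ignore everything outside `backprop`, such as the weight update and evaluation, so the slowdown measured over a full epoch would be expected to be somewhat smaller.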