# Checkpointing

Your task is to implement checkpointing for an MLP using NumPy.

You are free to use the MLP and backpropagation implementation that you developed during the lab sessions.

The key takeaway from this task is that checkpointing lets us trade extra computation (recomputing parts of the forward pass) for a smaller memory footprint during the backward pass. The memory needed for the backward pass is often a major bottleneck when training large networks. In plain English, we accept slightly longer training in exchange for saving some of our GPU's precious memory.

## What is checkpointing?

The aim of checkpointing is to save the forward result (activation) of only every $n$-th layer (e.g. every 2nd layer), instead of every layer's forward result as in plain backpropagation, and to use these checkpoints to recompute the missing parts of the forward pass during the backward pass. Checkpoint activations are kept in memory after the forward pass, while the remaining activations are recomputed at most once. After being recomputed, the non-checkpoint activations are kept in memory only until they are no longer required.
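
To make the schedule concrete, here is a minimal sketch. It assumes weight layers are numbered from 0 and that activation index $l$ denotes the input to weight layer $l$ (index 0 is the network input), which matches the convention used by `NetworkWithCheckpointing` later in this notebook. The helper names `checkpoint_indices` and `recompute_segments` are hypothetical. The sketch lists which activations are kept after the forward pass and which layer ranges are recomputed during the backward pass, last segment first.

```python
from typing import List, Tuple


def checkpoint_indices(num_weight_layers: int, n: int) -> List[int]:
    """Activation indices kept after the forward pass (0, n, 2n, ...).

    Index l is the input to weight layer l; index 0 is the network input.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    return list(range(0, num_weight_layers, n))


def recompute_segments(num_weight_layers: int, n: int) -> List[Tuple[int, int]]:
    """(start, end) layer ranges recomputed in the backward pass, last segment first."""
    starts = checkpoint_indices(num_weight_layers, n)
    ends = starts[1:] + [num_weight_layers]
    return list(reversed(list(zip(starts, ends))))


print(checkpoint_indices(7, 3))   # [0, 3, 6]
print(recompute_segments(7, 3))   # [(6, 7), (3, 6), (0, 3)]
```

Because the segments partition the layers, every weight layer is evaluated once more during the backward pass, so under this scheme the extra cost is at most one additional forward pass.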

# What should be done

1. Take the implementation of an MLP trained with backpropagation. Analyze how much memory the algorithm uses as a function of the number of hidden layers.

2. Implement a class `NetworkWithCheckpointing` that inherits from the `Network` class defined during the lab sessions by:
    a) implementing a method `forward_between_checkpoints` that recomputes the forward pass of the network starting from one of the checkpointed layers;
    b) overriding the method `backprop` so that it stores only the checkpointed activations, recomputes the other activations with `forward_between_checkpoints`, and keeps them in memory only until they are no longer needed.

3. Train your network with checkpointing on MNIST. Compare its running time and memory usage with those of the network without checkpointing.
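
For item 1, a rough count helps. Let $L$ be the number of weight layers (hidden layers + 1). Plain backpropagation keeps the input and all $L$ layer outputs, so its activation memory grows linearly, $O(L)$. With a checkpoint every $n$-th layer, about $\lceil L/n \rceil$ checkpoints stay alive for the whole backward pass, plus at most $n$ recomputed activations of the current segment, i.e. roughly $\lceil L/n \rceil + n$ arrays. This is minimised near $n \approx \sqrt{L}$, which gives $O(\sqrt{L})$. The sketch below (hypothetical helper names) evaluates this estimate. It counts activation arrays rather than bytes (layer widths differ, and every array also scales with the batch size), ignores weights and gradients, and assumes a segment is freed before the previous one is recomputed.

```python
import math


def peak_activations_plain(num_weight_layers: int) -> int:
    # Plain backprop keeps the input plus the output of every weight layer.
    return num_weight_layers + 1


def peak_activations_checkpointed(num_weight_layers: int, n: int) -> int:
    # ceil(L / n) checkpoints live for the whole backward pass,
    # plus at most n freshly recomputed activations of the current segment.
    num_checkpoints = math.ceil(num_weight_layers / n)
    return num_checkpoints + n


def best_n(num_weight_layers: int) -> int:
    # Brute-force the spacing that minimises the estimate (first minimum wins).
    return min(range(1, num_weight_layers + 1),
               key=lambda n: peak_activations_checkpointed(num_weight_layers, n))


for L in (3, 100):
    n = best_n(L)
    print(L, peak_activations_plain(L), n, peak_activations_checkpointed(L, n))
# 3 4 1 4
# 100 101 10 20
```

For $L = 3$ the estimate is the same for both variants, while for $L = 100$ it drops from 101 arrays to 20.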


# Implement Checkpointing for an MLP

In [5]:
import numpy as np

In [8]:
!wget -O mnist.npz https://s3.amazonaws.com/img-datasets/mnist.npz

--2022-11-06 21:42:23--  https://s3.amazonaws.com/img-datasets/mnist.npz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.62.72
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.62.72|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11490434 (11M) [application/octet-stream]
Saving to: ‘mnist.npz’


2022-11-06 21:42:35 (997 KB/s) - ‘mnist.npz’ saved [11490434/11490434]



In [9]:
# Load MNIST: flatten the 28x28 images into 784-dim vectors scaled to [0, 1] and one-hot encode the labels

def load_mnist(path='mnist.npz'):
    with np.load(path) as f:
        x_train, _y_train = f['x_train'], f['y_train']
        x_test, _y_test = f['x_test'], f['y_test']

    x_train = x_train.reshape(-1, 28 * 28) / 255.
    x_test = x_test.reshape(-1, 28 * 28) / 255.

    y_train = np.zeros((_y_train.shape[0], 10))
    y_train[np.arange(_y_train.shape[0]), _y_train] = 1

    y_test = np.zeros((_y_test.shape[0], 10))
    y_test[np.arange(_y_test.shape[0]), _y_test] = 1

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_mnist()

In [61]:
!pip install psutil


In [62]:
import time
import os
import psutil


def sigmoid(z):
    return 1.0/(1.0+np.exp(-z))

class Network(object):
    def __init__(self, sizes):
        # initialize biases and weights with random normal distr.
        # weights are indexed by target node first
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.weights = [np.random.randn(y, x) 
                        for x, y in zip(sizes[:-1], sizes[1:])]
    def feedforward(self, a):
        # Run the network on a single case
        for b, w in zip(self.biases, self.weights):
            a = sigmoid(np.dot(w, a)+b)
        return a
    
    def update_mini_batch(self, x_mini_batch, y_mini_batch, eta):
        # Update networks weights and biases by applying a single step
        # of gradient descent using backpropagation to compute the gradient.
        # The gradient is computed for a mini_batch.
        # eta is the learning rate
        nabla_b, nabla_w = self.backprop(np.expand_dims(x_mini_batch, axis=-1), np.expand_dims(y_mini_batch, axis=-1))
        self.weights = [w-(eta)*nw 
                        for w, nw in zip(self.weights, nabla_w)]
        self.biases = [b-(eta)*nb 
                       for b, nb in zip(self.biases, nabla_b)]
        
    def backprop(self, x_batch, y_batch):
        # For a mini-batch (x_batch of shape (batch, sizes[0], 1), y_batch of shape
        # (batch, sizes[-1], 1)) return a tuple of lists: gradients over biases and
        # over weights, averaged over the batch.
        
        # First initialize the list of gradient arrays
        delta_nabla_b = [None for _ in range(len(self.biases))]
        delta_nabla_w = [None for _ in range(len(self.weights))]
        
        # Then go forward, remembering each layer's input (before)
        # and its sigmoid output (after) in two lists
        before = []
        after = []

        for l in range(self.num_layers - 1):
            before.append(x_batch)
            x_batch = sigmoid(self.weights[l] @ x_batch + self.biases[l])
            after.append(x_batch)
        
        # Now go backward from the final cost applying backpropagation
        d_loss_act = after[self.num_layers - 2] - y_batch
        for l in range(self.num_layers - 2, -1, -1):
            delta_nabla_b[l] = d_loss_act * after[l] * (1 - after[l])
            delta_nabla_w[l] = delta_nabla_b[l] @ np.transpose(before[l], (0, 2, 1))
            d_loss_act = self.weights[l].T @ delta_nabla_b[l]

        
        return [b.mean(axis=0) for b in delta_nabla_b], [w.mean(axis=0) for w in delta_nabla_w]

    def evaluate(self, x_test_data, y_test_data):
        # Count the number of correct answers for test_data
        test_results = [(np.argmax(self.feedforward(x_test_data[i].reshape(784,1))), np.argmax(y_test_data[i]))
                        for i in range(len(x_test_data))]
        # return accuracy
        return np.mean([int(x == y) for (x, y) in test_results])
    
    def cost_derivative(self, output_activations, y):
        return (output_activations-y) 
    
    def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
        process = psutil.Process(os.getpid())
        memusage = []
        updatetime = []
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                t0 = time.time()
                memusage.append(process.memory_info().rss)
                x_mini_batch = x_train[i*mini_batch_size:(i*mini_batch_size + mini_batch_size)] 
                y_mini_batch = y_train[i*mini_batch_size:(i*mini_batch_size + mini_batch_size)] 
                self.update_mini_batch(x_mini_batch, y_mini_batch, eta)
                updatetime.append(time.time() - t0)
            if test_data:
                print("Epoch: {0}, Accuracy: {1}".format(j, self.evaluate(x_test, y_test)))
            else:
                print("Epoch: {0}".format(j))
        return memusage, updatetime
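
`SGD` records the process RSS just before each update, which reflects the overall process footprint and the allocator's behaviour rather than the peak reached inside `backprop`. A possible alternative, sketched here, is to measure the traced peak of a single call with `tracemalloc`. This assumes Python 3.9+ (for `tracemalloc.reset_peak`) and a NumPy version that reports its array buffers to `tracemalloc`, which recent releases do. `measure_peak_bytes` and `allocate` are hypothetical helpers, not part of the lab code.

```python
import tracemalloc
from typing import Any, Callable, Tuple

import numpy as np


def measure_peak_bytes(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Tuple[Any, int]:
    """Call fn(*args, **kwargs); return (result, peak bytes allocated above the baseline)."""
    started_here = not tracemalloc.is_tracing()
    if started_here:
        tracemalloc.start()
    tracemalloc.reset_peak()  # requires Python 3.9+
    baseline, _ = tracemalloc.get_traced_memory()
    try:
        result = fn(*args, **kwargs)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        # Only stop tracing if this helper started it.
        if started_here:
            tracemalloc.stop()
    return result, peak - baseline


def allocate(n: int) -> float:
    # Temporary array of n float64 values (8 bytes each), freed on return.
    a = np.ones(n)
    return float(a.sum())


total, peak = measure_peak_bytes(allocate, 1_000_000)
print(total, peak >= 8_000_000)
```

In this notebook it could be called as `measure_peak_bytes(network.backprop, xb, yb)` with `xb = np.expand_dims(x_train[:100], axis=-1)` and `yb = np.expand_dims(y_train[:100], axis=-1)`, once for each network, to compare the peak memory of a single backward pass.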

In [63]:
from typing import List


class NetworkWithCheckpointing(Network):

    def __init__(self, sizes, checkpoint_every_nth_layer: int = 2):
        super().__init__(sizes)
        if checkpoint_every_nth_layer < 1:
            raise ValueError("checkpoint_every_nth_layer must be a positive integer")
        self.checkpoint_every_nth_layer = checkpoint_every_nth_layer
        # Index l is the input to weight layer l (index 0 is the network input).
        # Checkpoints are 0, n, 2n, ... below the number of weight layers.
        self.checkpoints_layer_idx = list(range(0, self.num_layers - 1, checkpoint_every_nth_layer))

    def forward_between_checkpoints(self, a, layer_idx_start, layer_idx_end) -> List[np.ndarray]:
        """Recompute the forward pass from a checkpoint.

        `a` is the checkpointed activation, i.e. the input to layer `layer_idx_start`.
        Returns [a, output of layer layer_idx_start, ..., output of layer layer_idx_end - 1].
        """
        after = [a]
        for i in range(layer_idx_start, layer_idx_end):
            a = sigmoid(self.weights[i] @ a + self.biases[i])
            after.append(a)

        return after

    def backprop(self, x, y):

        delta_nabla_b = [None for _ in range(len(self.biases))]
        delta_nabla_w = [None for _ in range(len(self.weights))]
    
        # go forward, saving only the checkpointed activations
        a = x
        checkpoints = []
        checkpoint_set = set(self.checkpoints_layer_idx)
        for l in range(self.checkpoints_layer_idx[-1]):
            if l in checkpoint_set:
                checkpoints.append(a)
            a = sigmoid(self.weights[l] @ a + self.biases[l])
        # activation at the last checkpoint; the layers after it are recomputed
        # during the backward pass, so there is no need to compute them here
        checkpoints.append(a)

        # go backward, recomputing each segment's activations with 'forward_between_checkpoints'
        d_loss_act = None  # derivative of the loss with respect to the current activation
        stop_layer_idx = self.num_layers - 1
        for j in reversed(range(len(self.checkpoints_layer_idx))):
            checkpoint_activation = checkpoints[j]
            checkpoint_layer_idx = self.checkpoints_layer_idx[j]
            activations = self.forward_between_checkpoints(
                checkpoint_activation,
                checkpoint_layer_idx,
                stop_layer_idx
            )
            stop_layer_idx = checkpoint_layer_idx

            for i in reversed(range(len(activations) - 1)):
                before, after = activations[i], activations[i + 1]
                l_idx = checkpoint_layer_idx + i
                # the first segment processed ends at the output layer: initialize 'd_loss_act'
                if d_loss_act is None:
                    d_loss_act = after - y

                delta_nabla_b[l_idx] = d_loss_act * after * (1 - after)
                delta_nabla_w[l_idx] = delta_nabla_b[l_idx] @ np.transpose(before, (0, 2, 1))
                d_loss_act = self.weights[l_idx].T @ delta_nabla_b[l_idx]

            # release this segment before the previous one is recomputed
            del activations, before, after

        return [b.mean(axis=0) for b in delta_nabla_b], [w.mean(axis=0) for w in delta_nabla_w]
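
A cheap way to check the override is to compare its gradients with those of plain backpropagation on random data. The sketch below is a self-contained, functional re-implementation written only for this check: it does not import the notebook's classes, and all names in it are hypothetical. It uses the same sigmoid layers and quadratic cost as `Network.backprop`, batches of column vectors of shape `(batch, width, 1)`, and skips the layers after the last checkpoint in the forward pass because they are recomputed in the backward pass anyway.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def layer_backward(d_loss_act, before, after, w):
    # Same formulas as Network.backprop: returns (bias grad, weight grad, dL/d input).
    delta = d_loss_act * after * (1 - after)
    grad_w = delta @ np.transpose(before, (0, 2, 1))
    return delta, grad_w, w.T @ delta


def backprop_plain(weights, biases, x, y):
    acts = [x]
    for w, b in zip(weights, biases):
        acts.append(sigmoid(w @ acts[-1] + b))
    grads_b, grads_w = [None] * len(weights), [None] * len(weights)
    d = acts[-1] - y  # quadratic cost
    for l in reversed(range(len(weights))):
        grads_b[l], grads_w[l], d = layer_backward(d, acts[l], acts[l + 1], weights[l])
    return grads_b, grads_w


def backprop_checkpointed(weights, biases, x, y, n):
    num_layers = len(weights)
    starts = list(range(0, num_layers, n))
    # Forward pass: keep only the activations at checkpoint indices 0, n, 2n, ...
    checkpoints, a = [], x
    for l in range(starts[-1]):
        if l % n == 0:
            checkpoints.append(a)
        a = sigmoid(weights[l] @ a + biases[l])
    checkpoints.append(a)
    grads_b, grads_w = [None] * num_layers, [None] * num_layers
    d, end = None, num_layers
    for start, ckpt in zip(reversed(starts), reversed(checkpoints)):
        # Recompute this segment's activations from its checkpoint.
        seg = [ckpt]
        for l in range(start, end):
            seg.append(sigmoid(weights[l] @ seg[-1] + biases[l]))
        if d is None:
            d = seg[-1] - y  # the last segment ends at the network output
        for l in reversed(range(start, end)):
            i = l - start
            grads_b[l], grads_w[l], d = layer_backward(d, seg[i], seg[i + 1], weights[l])
        del seg  # free the segment before recomputing the previous one
        end = start
    return grads_b, grads_w


rng = np.random.default_rng(0)
sizes = [6, 5, 4, 4, 3, 2]
weights = [rng.standard_normal((o, i)) for i, o in zip(sizes[:-1], sizes[1:])]
biases = [rng.standard_normal((o, 1)) for o in sizes[1:]]
x = rng.standard_normal((8, sizes[0], 1))  # batch of 8 column vectors
y = rng.standard_normal((8, sizes[-1], 1))

ref_b, ref_w = backprop_plain(weights, biases, x, y)
for n in (1, 2, 3, 5, 10):
    ck_b, ck_w = backprop_checkpointed(weights, biases, x, y, n)
    print(n, all(np.allclose(p, q) for p, q in zip(ref_b + ref_w, ck_b + ck_w)))
```

Since both versions perform the same floating-point operations, the gradients should agree for every spacing `n`, including `n = 1` (every activation is a checkpoint) and an `n` larger than the number of layers (only the input is kept).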


In [82]:
network = Network([784,40,30,10])
memusage1, updatetimes1 = network.SGD((x_train, y_train), epochs=10, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.7997
Epoch: 1, Accuracy: 0.8595
Epoch: 2, Accuracy: 0.8808
Epoch: 3, Accuracy: 0.8952
Epoch: 4, Accuracy: 0.9043
Epoch: 5, Accuracy: 0.9089
Epoch: 6, Accuracy: 0.9132
Epoch: 7, Accuracy: 0.9178
Epoch: 8, Accuracy: 0.9201
Epoch: 9, Accuracy: 0.9216


In [83]:
import plotly.express as px
from itertools import accumulate
import operator as op


cumupdatetimes1 = list(accumulate(updatetimes1, op.add))

px.line({"mem usage (MB)": [m / 1e6 for m in memusage1], "n_update": list(range(len(memusage1)))}, x="n_update", y="mem usage (MB)", title='Process memory (RSS) before each weight update, without checkpointing').show()
px.line({"cumulative update time (s)": cumupdatetimes1, "n_update": list(range(len(cumupdatetimes1)))}, x="n_update", y="cumulative update time (s)", title='Cumulative update time, without checkpointing').show()

In [84]:
network = NetworkWithCheckpointing([784,40,30,10], checkpoint_every_nth_layer=2)
memusage2, updatetimes2 = network.SGD((x_train, y_train), epochs=10, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.8091
Epoch: 1, Accuracy: 0.8663
Epoch: 2, Accuracy: 0.8859
Epoch: 3, Accuracy: 0.8991
Epoch: 4, Accuracy: 0.9066
Epoch: 5, Accuracy: 0.9117
Epoch: 6, Accuracy: 0.9172
Epoch: 7, Accuracy: 0.9209
Epoch: 8, Accuracy: 0.9241
Epoch: 9, Accuracy: 0.9274


In [85]:
cumupdatetimes2 = list(accumulate(updatetimes2, op.add))

px.line({"mem usage (MB)": [m / 1e6 for m in memusage2], "n_update": list(range(len(memusage2)))}, x="n_update", y="mem usage (MB)", title='Process memory (RSS) before each weight update, with checkpointing').show()
px.line({"cumulative update time (s)": cumupdatetimes2, "n_update": list(range(len(cumupdatetimes2)))}, x="n_update", y="cumulative update time (s)", title='Cumulative update time, with checkpointing').show()

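Comparing the plots of the two networks, we can observe the memory / computation time trade-off: the checkpointed network spends extra time recomputing part of the forward pass during every backward pass, in exchange for keeping fewer activations in memory.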
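
To put numbers on the comparison instead of reading them off the plots, one option is to summarise the lists returned by `SGD`. The sketch assumes per-update times in seconds and RSS samples in bytes, as recorded above; `compare_runs` is a hypothetical helper. The median is used so that occasional slow updates (e.g. the first ones) do not dominate. With only three weight layers, checkpointing every second layer keeps the activations at indices 0 and 2, so the memory saving for `[784, 40, 30, 10]` is expected to be small; a deeper network should make the trade-off easier to see.

```python
import statistics
from typing import Dict, Sequence


def compare_runs(times_plain: Sequence[float], times_ckpt: Sequence[float],
                 mem_plain: Sequence[int], mem_ckpt: Sequence[int]) -> Dict[str, float]:
    """Summarise two runs recorded by SGD (per-update seconds, RSS bytes)."""
    median_plain = statistics.median(times_plain)
    median_ckpt = statistics.median(times_ckpt)
    return {
        "median_update_s_plain": median_plain,
        "median_update_s_ckpt": median_ckpt,
        # > 1.0 means the checkpointed network is slower per update
        "time_ratio": median_ckpt / median_plain,
        "max_rss_mb_plain": max(mem_plain) / 1e6,
        "max_rss_mb_ckpt": max(mem_ckpt) / 1e6,
    }
```

In this notebook it would be called as `compare_runs(updatetimes1, updatetimes2, memusage1, memusage2)`.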