<center><img src='https://drive.google.com/uc?id=1_utx_ZGclmCwNttSe40kYA6VHzNocdET' height="60"></center>

AI TECH - Academy of Innovative Applications of Digital Technologies (Akademia Innowacyjnych Zastosowań Technologii Cyfrowych). Operational Programme Digital Poland 2014-2020
<hr>

<center><img src='https://drive.google.com/uc?id=1BXZ0u3562N_MqCLcekI-Ens77Kk4LpPm'></center>

<center>
Project co-financed by the European Union under the European Regional Development Fund
Operational Programme Digital Poland 2014-2020,
Priority Axis 3 "Digital competences of society", Measure 3.2 "Innovative solutions for digital activation"
Project title: "Akademia Innowacyjnych Zastosowań Technologii Cyfrowych (AI Tech)"
    </center>

# Laboratory Scenario 2 - Backpropagation and Gradient Checkpointing

In this lab scenario, you are given an implementation of a simple neural network, and your goal is to implement the backpropagation procedure for this network.  
More precisely, the network takes as input a tensor $x$ of shape `(MINI_BATCH_SIZE, 28*28)`, where each element of the batch is a flattened image of shape `(28, 28)`.  
In exercise 1, you can assume that elements of the minibatch are fed to the network one by one (as tensors of shape `(1, 28*28)` for a single image and `(1, 10)` for its one-hot encoded class).  
In exercise 2, you are asked to make the backpropagation work without this assumption.  
In exercise 3, you will implement a technique called gradient checkpointing, which reduces the amount of memory used to store activations for backpropagation.

In [None]:
import random
import numpy as np
from torchvision import datasets, transforms
from typing import List, Any, Tuple, Optional
from numpy.typing import NDArray

In [None]:
!wget -O mnist.npz https://s3.amazonaws.com/img-datasets/mnist.npz

In [None]:
# Let's read the mnist dataset


def load_mnist(path="mnist.npz"):
    with np.load(path) as f:
        x_train, _y_train = f["x_train"], f["y_train"]
        x_test, _y_test = f["x_test"], f["y_test"]

    x_train = x_train.reshape(-1, 28 * 28) / 255.0
    x_test = x_test.reshape(-1, 28 * 28) / 255.0

    y_train = np.zeros((_y_train.shape[0], 10))
    y_train[np.arange(_y_train.shape[0]), _y_train] = 1

    y_test = np.zeros((_y_test.shape[0], 10))
    y_test[np.arange(_y_test.shape[0]), _y_test] = 1

    return (x_train, y_train), (x_test, y_test)


(x_train, y_train), (x_test, y_test) = load_mnist()

## Exercise 1

In this exercise, your task is to fill in the gaps in this code by implementing the backpropagation algorithm.
Once done, you can run the network on the MNIST example and see how it performs.  
Feel free to play with the parameters. Your model should achieve 90%+ accuracy after a few epochs.  

Before you start, note a few things:
+ `backprop` is the function that you need to implement.
+ `update_mini_batch` calls `backprop` to get the gradients for the network parameters.
+ The derivative of the loss with respect to the network output is already computed by `cost_derivative`.
+ Your goal is to compute $\frac{d L\left(\text{model}(x), y\right)}{d p}$ for each parameter $p$ of the network (every weight and every bias).
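
For reference, here is one way to write out the quantities involved. It assumes the loss is the quadratic cost $L = \frac{1}{2}\lVert a_n - y \rVert^2$, an assumption consistent with `cost_derivative` returning `output_activations - y`. The symbols $a_l$, $z_l$ and $\delta_l$ are notation introduced here and do not appear in the code. With row-vector activations, $a_0 = x$, $z_l = a_{l-1} W_l + b_l$ and $a_l = \sigma(z_l)$ for layers $l = 1, \dots, n$:

$$
\begin{aligned}
\delta_n &= (a_n - y) \odot \sigma'(z_n), \\
\delta_l &= \left(\delta_{l+1} W_{l+1}^{\top}\right) \odot \sigma'(z_l), \qquad l = n-1, \dots, 1, \\
\frac{\partial L}{\partial W_l} &= a_{l-1}^{\top} \delta_l, \\
\frac{\partial L}{\partial b_l} &= \delta_l .
\end{aligned}
$$

For a single example, $a_{l-1}$ has shape `(1, sizes[l-1])` and $\delta_l$ has shape `(1, sizes[l])`. The weight gradient therefore has the same shape as $W_l$. The bias gradient must be squeezed to shape `(sizes[l],)` to match `self.biases`.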


## Exercise 2 (Optional)

Implement a "fully vectorized" version, i.e. one using matrix operations instead of going over examples one by one within a minibatch.
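
The idea behind vectorization can be seen on a single linear layer. Summing per-example outer products gives the same weight gradient as one matrix product over the whole minibatch. The sketch below is only an illustration on random data. The names `X`, `D`, `loop_dw` and `vec_dw` are made up here, and `D` stands in for the per-example gradients with respect to the layer's pre-activations.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 4))  # 5 examples, 4 input features
D = rng.normal(size=(5, 3))  # stand-in gradients w.r.t. pre-activations, 3 outputs

# Per-example accumulation, as update_mini_batch does in exercise 1
loop_dw = np.zeros((4, 3))
loop_db = np.zeros(3)
for x_i, d_i in zip(X, D):
    loop_dw += x_i.reshape(-1, 1) @ d_i.reshape(1, -1)  # outer product
    loop_db += d_i

# The same sums expressed as matrix operations over the whole batch
vec_dw = X.T @ D
vec_db = D.sum(axis=0)

print(np.allclose(loop_dw, vec_dw), np.allclose(loop_db, vec_db))
```

Both gradients are summed over the batch, which matches how `update_mini_batch` divides by `len(x_mini_batch)` afterwards.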

## Help required?
At the end of this notebook, we show how you can utilize `JAX` to check whether you implemented the derivative computation correctly.
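
A finite-difference check is a common alternative that needs only NumPy. The sketch below is not part of the original lab. `numerical_grad` is a hypothetical helper, the single-layer setup is a toy example, and the quadratic loss is an assumption consistent with `cost_derivative`.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def numerical_grad(f, p, eps=1e-6):
    # Central-difference estimate of d f() / d p; p is perturbed in place
    grad = np.zeros_like(p)
    for idx in np.ndindex(p.shape):
        old = p[idx]
        p[idx] = old + eps
        f_plus = f()
        p[idx] = old - eps
        f_minus = f()
        p[idx] = old  # restore the original value
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
x = rng.normal(size=(1, 4))
y = np.array([[0.0, 1.0, 0.0]])
w = rng.normal(size=(4, 3))
b = rng.normal(size=(3,))


def loss():
    # Assumed quadratic loss of a single sigmoid layer
    return 0.5 * np.sum((sigmoid(x @ w + b) - y) ** 2)


# Analytic gradients of the same toy layer
a = sigmoid(x @ w + b)
delta = (a - y) * a * (1 - a)
dw = x.T @ delta
db = delta.sum(axis=0)

err_w = np.max(np.abs(dw - numerical_grad(loss, w)))
err_b = np.max(np.abs(db - numerical_grad(loss, b)))
print(f"max error dw = {err_w:.2e}, db = {err_b:.2e}")
```

The loop evaluates the loss twice per parameter, so this check is only practical for very small networks.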


In [None]:
def sigmoid(z: NDArray[float]):
    return 1.0 / (1.0 + np.exp(-z))


def sigmoid_prime(z: NDArray[float]):
    # Derivative of the sigmoid
    return sigmoid(z) * (1 - sigmoid(z))


class Network(object):
    def __init__(self, sizes: List[int]):
        # initialize biases and weights with random normal distr.
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.biases = [np.random.randn(y) for y in sizes[1:]]
        self.weights = [np.random.randn(x, y) for x, y in zip(sizes[:-1], sizes[1:])]

    def feedforward(self, a: NDArray[float]) -> NDArray[float]:
        # Run the network on a single case
        for b, w in zip(self.biases, self.weights):
            a = sigmoid(a @ w + b)

        return a

    def update_mini_batch(
        self, x_mini_batch: NDArray[float], y_mini_batch: NDArray[float], eta: float
    ) -> None:
        # Update network weights and biases by applying a single step
        # of gradient descent using backpropagation to compute the gradient.
        # The gradient is computed for a mini_batch.
        # eta is the learning rate
        nabla_b = [np.zeros(b.shape) for b in self.biases]
        nabla_w = [np.zeros(w.shape) for w in self.weights]
        for x, y in zip(x_mini_batch, y_mini_batch):
            delta_nabla_b, delta_nabla_w = self.backprop(
                x.reshape(1, 784), y.reshape(1, 10)
            )
            nabla_b = [nb + dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
            nabla_w = [nw + dnw for nw, dnw in zip(nabla_w, delta_nabla_w)]

        self.weights = [
            w - (eta / len(x_mini_batch)) * nw for w, nw in zip(self.weights, nabla_w)
        ]
        self.biases = [
            b - (eta / len(x_mini_batch)) * nb for b, nb in zip(self.biases, nabla_b)
        ]

    def backprop(
        self, x: NDArray[float], y: NDArray[float]
    ) -> Tuple[List[NDArray[float]], List[NDArray[float]]]:
        # For a single input (x,y) return a tuple of lists.
        # First contains gradients over biases, second over weights.

        assert len(x.shape) == 2  # batch, features
        assert len(y.shape) == 2  # batch, classes
        assert x.shape[0] == y.shape[0]

        # First initialize the list of gradient arrays
        delta_nabla_b = []
        delta_nabla_w = []

        # Then go forward remembering each layer input and value
        # before sigmoid activation
        # TODO
        ###{
        pass
        ###}

        # Now go backward from the final cost applying backpropagation
        # hint: you can use reversed(list(zip(a, b, ...)))
        # TODO
        ###{
        pass
        ###}

        # Check shapes
        delta_nabla_b = list(delta_nabla_b)
        delta_nabla_w = list(delta_nabla_w)
        assert len(delta_nabla_b) == len(self.biases), (
            len(delta_nabla_b),
            len(self.biases),
        )
        assert len(delta_nabla_w) == len(self.weights), (
            len(delta_nabla_w),
            len(self.weights),
        )
        for lid in range(len(self.weights)):
            assert delta_nabla_b[lid].shape == self.biases[lid].shape, (
                delta_nabla_b[lid].shape,
                self.biases[lid].shape,
            )
            assert delta_nabla_w[lid].shape == self.weights[lid].shape, (
                delta_nabla_w[lid].shape,
                self.weights[lid].shape,
            )

        return delta_nabla_b, delta_nabla_w

    def evaluate(
        self, x_test_data: NDArray[float], y_test_data: NDArray[float]
    ) -> float:
        # Count the number of correct answers for test_data
        test_results = [
            (
                np.argmax(self.feedforward(x_test_data[i].reshape(1, 784)), axis=-1),
                np.argmax(y_test_data[i], axis=-1),
            )
            for i in range(len(x_test_data))
        ]
        # return accuracy
        return np.mean([int((x == y).item()) for (x, y) in test_results]).item()

    def cost_derivative(
        self, output_activations: NDArray[float], y: NDArray[float]
    ) -> NDArray[float]:
        assert output_activations.shape == y.shape, (output_activations.shape, y.shape)
        return output_activations - y

    def SGD(
        self,
        training_data: Tuple[NDArray[float], NDArray[float]],
        epochs: int,
        mini_batch_size: int,
        eta: float,
        test_data: Optional[Tuple[NDArray[float], NDArray[float]]] = None,
    ) -> None:
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[
                    i * mini_batch_size : (i * mini_batch_size + mini_batch_size)
                ]
                y_mini_batch = y_train[
                    i * mini_batch_size : (i * mini_batch_size + mini_batch_size)
                ]
                self.update_mini_batch(x_mini_batch, y_mini_batch, eta)
            if test_data:
                print(
                    "Epoch: {0}, Accuracy: {1}".format(j, self.evaluate(x_test, y_test))
                )
            else:
                print("Epoch: {0}".format(j))


network = Network([784, 30, 10])
network.SGD(
    (x_train, y_train),
    epochs=10,
    mini_batch_size=100,
    eta=3.0,
    test_data=(x_test, y_test),
)

## Exercise 3 (Optional)

Standard backpropagation stores the outputs of all layers, which can consume a large share of scarce GPU memory.
Instead, one can store only the inputs of a few selected layers (checkpoints) and recompute the remaining activations when they are needed.  
Your task is to complete the code below to implement backpropagation with checkpoints.
To keep things simple, use 1-example mini-batches (or, if you are bored, vectorize the code below).
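
To see which layer ranges get recomputed, the sketch below reproduces the index arithmetic used by `NetworkWithCheckpoints` in the code that follows. `checkpoint_segments` is a hypothetical standalone helper written only for this illustration. It returns the half-open ranges `[start, end)` in the order the backward pass visits them.

```python
from typing import List, Tuple


def checkpoint_segments(num_layers: int, checkpoints: List[int]) -> List[Tuple[int, int]]:
    # Mirrors __init__: layer 0 and the last layer are always checkpoints
    cps = sorted(set([0] + checkpoints + [num_layers - 1]))
    # Mirrors update_mini_batch: the last segment extends past the final layer
    ends = cps[1:][:-1] + [cps[-1] + 1]
    return list(reversed(list(zip(cps[:-1], ends))))


# Network [784, 20, 15, 10, 10] has 4 layers; one extra checkpoint at layer 2
segments = checkpoint_segments(num_layers=4, checkpoints=[2])
print(segments)  # [(2, 4), (0, 2)]
```

Each segment starts at a layer whose input was stored during the forward pass, so `feedforward_between_layers` can recompute everything inside the segment from that stored input.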

In [None]:
class NetworkWithCheckpoints(object):
    def __init__(self, sizes: List[int], checkpoints: List[int]):
        # initialize biases and weights with random normal distr.
        # weights have shape (inputs, outputs), i.e. they are indexed by source node first
        self.num_layers = len(sizes) - 1
        self.sizes = sizes
        self.checkpoints = list(
            sorted(list(set([0] + checkpoints + [self.num_layers - 1])))
        )
        self.biases = [np.random.randn(y) for y in sizes[1:]]
        self.weights = [np.random.randn(x, y) for x, y in zip(sizes[:-1], sizes[1:])]

    def feedforward(self, a: NDArray[float]) -> NDArray[float]:
        # Run the network on a single case
        for b, w in zip(self.biases, self.weights):
            a = sigmoid(a @ w + b)
        return a

    def feedforward_with_checkpoints(
        self, x: NDArray[float]
    ) -> Tuple[List[NDArray[float]], List[NDArray[float]], NDArray[float]]:
        # Runs network on a single case, memorizing the inputs of layers included in checkpoints.
        # Notice that gs (outputs of non-linearities) are shifted by one
        layer_input = []
        before_act = []
        for i, (w, b) in enumerate(zip(self.weights, self.biases)):
            f = x @ w + b
            g = sigmoid(f)
            if i in self.checkpoints:
                layer_input.append(x)
                before_act.append(f)
            else:
                layer_input.append(None)
                before_act.append(None)
            x = g
        return layer_input, before_act, x

    def feedforward_between_layers(
        self,
        first_layer: int,
        last_layer: int,
        acc_layer_input: List[NDArray[float]],
        acc_before_act: List[NDArray[float]],
    ) -> None:
        # feedforward input acc_layer_input[first_layer] for layers [first_layer, last_layer)
        # memorizing their outputs in respective indexes of acc_layer_input, acc_before_act
        # that is for a layer lid \in {first_layer, ..., last_layer-1}
        # acc_layer_input[lid] is the layer input (before linear projection)
        # acc_before_act[lid] is the input to sigmoid activation, that is
        # acc_before_act[lid] = acc_layer_input[lid] @ self.weights[lid] + self.biases[lid]
        # TODO
        ###{
        pass
        ###}

    def backprop_between_layers(
        self,
        start: int,
        end: int,
        acc_layer_input: List[NDArray[float]],
        acc_before_act: List[NDArray[float]],
        dLdg: NDArray[float],
    ) -> Tuple[List[NDArray[float]], List[NDArray[float]], NDArray[float]]:
        # compute the gradients for layers [start, end)
        # dLdg is a gradient with respect to the output (nonlinearity) of layer[end-1]
        # return the updated dLdg so that it is the gradient with respect to acc_layer_input[start]
        # that is the input of layer[start] (in other words output of layer[start - 1])
        dLdWs = []
        dLdBs = []

        # TODO
        ###{
        pass
        ###}

        # Checking shapes
        dLdWs = list(dLdWs)
        dLdBs = list(dLdBs)
        assert len(dLdWs) == len(dLdBs), (len(dLdWs), len(dLdBs))
        assert len(dLdWs) == end - start, (len(dLdWs), start, end)

        for lid in range(start, end):
            assert dLdWs[lid - start].shape == self.weights[lid].shape, (
                dLdWs[lid - start].shape,
                self.weights[lid].shape,
            )
            assert dLdBs[lid - start].shape == self.biases[lid].shape, (
                dLdBs[lid - start].shape,
                self.biases[lid].shape,
            )

        return dLdWs, dLdBs, dLdg

    def update_mini_batch(
        self, x_mini_batch: NDArray[float], y_mini_batch: NDArray[float], eta: float
    ) -> None:
        # Update network weights and biases by applying a single step
        # of gradient descent using backpropagation with checkpoints to compute the gradient.
        # For this exercise, we assume 1 element mini_batch
        # eta is the learning rate
        x_mini_batch = x_mini_batch.reshape(1, -1)  # batch, features
        y_mini_batch = y_mini_batch.reshape(1, -1)

        layer_input, before_act, output = self.feedforward_with_checkpoints(
            x_mini_batch
        )
        dLdg = self.cost_derivative(output, y_mini_batch)
        for start, end in reversed(
            list(
                zip(
                    self.checkpoints[:-1],
                    self.checkpoints[1:][:-1] + [self.checkpoints[-1] + 1],
                )
            )
        ):
            # those copies are inefficient, but we do them to keep indexing simple
            acc_layer_input = layer_input.copy()
            acc_before_act = before_act.copy()
            self.feedforward_between_layers(start, end, acc_layer_input, acc_before_act)
            nabla_w, nabla_b, dLdg = self.backprop_between_layers(
                start, end, acc_layer_input, acc_before_act, dLdg
            )
            self.weights[start:end] = [
                w - eta * dw for w, dw in zip(self.weights[start:end], nabla_w)
            ]
            self.biases[start:end] = [
                b - eta * db for b, db in zip(self.biases[start:end], nabla_b)
            ]

    def evaluate(
        self, x_test_data: NDArray[float], y_test_data: NDArray[float]
    ) -> float:
        # Count the number of correct answers for test_data
        test_results = [
            (
                np.argmax(self.feedforward(x_test_data[i].reshape(1, 784)), axis=-1),
                np.argmax(y_test_data[i], axis=-1),
            )
            for i in range(len(x_test_data))
        ]
        # return accuracy
        return np.mean([int((x == y).item()) for (x, y) in test_results]).item()

    def cost_derivative(
        self, output_activations: NDArray[float], y: NDArray[float]
    ) -> NDArray[float]:
        return output_activations - y

    def SGD(
        self,
        training_data: Tuple[NDArray[float], NDArray[float]],
        epochs: int,
        mini_batch_size: int,
        eta: float,
        test_data: Optional[Tuple[NDArray[float], NDArray[float]]] = None,
    ):
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[
                    i * mini_batch_size : (i * mini_batch_size + mini_batch_size)
                ]
                y_mini_batch = y_train[
                    i * mini_batch_size : (i * mini_batch_size + mini_batch_size)
                ]
                self.update_mini_batch(x_mini_batch, y_mini_batch, eta)
            if test_data:
                print(
                    "Epoch: {0}, Accuracy: {1}".format(j, self.evaluate(x_test, y_test))
                )
            else:
                print("Epoch: {0}".format(j))

In [None]:
## Debug your solution
# correctly implemented checkpointing should give similar results to the non-checkpointed network when seeds are fixed

np.random.seed(42)
network = NetworkWithCheckpoints([784, 20, 15, 10, 10], checkpoints=[2])
network.SGD(
    (x_train, y_train),
    epochs=1,
    mini_batch_size=1,
    eta=0.02,
    test_data=(x_test, y_test),
)

np.random.seed(42)
network = Network([784, 20, 15, 10, 10])
network.SGD(
    (x_train, y_train),
    epochs=1,
    mini_batch_size=1,
    eta=0.02,
    test_data=(x_test, y_test),
)

In [None]:
network = NetworkWithCheckpoints([784, 30, 30, 10], checkpoints=[1])
network.SGD(
    (x_train, y_train),
    epochs=5,
    mini_batch_size=1,
    eta=0.05,
    test_data=(x_test, y_test),
)  # per-example descent is really slow, try vectorizing it!
# Just so you know, un-vectorized version takes about 25-35s per epoch

# JAX Playground (Optional)
JAX is a framework for numerical computing and automatic differentiation with NumPy-like syntax, which makes it convenient for building neural networks.  
In this course, we will use PyTorch instead of JAX, but for this lab scenario, JAX can help us test our gradient computation.  
Let's give it a try.  

In [None]:
!pip3 install jax

In [None]:
import jax
import jax.numpy as jnp


def sigmoid(z: jax.Array):
    return 1.0 / (1.0 + jnp.exp(-z))


def sigmoid_prime(z: jax.Array):
    return sigmoid(z) * (1 - sigmoid(z))


key = jax.random.key(42)

key, subkey = jax.random.split(key)
w = jax.random.normal(subkey, (5, 5))
key, subkey = jax.random.split(key)
b = jax.random.normal(subkey, (5,))
x = jnp.arange(5, dtype=w.dtype).reshape(1, 5)


# Define a JAX function.
# We emphasize: a function, not a procedure, i.e. its result depends
# only on its arguments and it has no side effects.
# There are more requirements for writing good JAX code,
# but this is just an example (see https://jax.readthedocs.io/en/latest/tutorials.html)
def forward(x: jax.Array, w: jax.Array, b: jax.Array) -> Tuple[jax.Array, jax.Array]:
    f = x @ w + b
    g = sigmoid(f)
    loss = g.sum()
    return loss, g


# this will calculate gradient for first, second, and third argument
# has_aux tells that in addition to loss our function returns something else
forward_backward = jax.value_and_grad(fun=forward, argnums=[0, 1, 2], has_aux=True)


def manual_backward(x, w, b):
    ## TODO
    ###{
    pass
    ###}
    return dx, dw, db


(loss, res), grad = forward_backward(x, w, b)
jax_dx, jax_dw, jax_db = grad
dx, dw, db = manual_backward(x, w, b)

print(
    f"""
diff dx = {jnp.mean(jnp.abs(jax_dx - dx))}
diff dw = {jnp.mean(jnp.abs(jax_dw - dw))}
diff db = {jnp.mean(jnp.abs(jax_db - db))}
emach = {np.finfo(dx.dtype).eps}
"""
)
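
A smaller sanity check in the same spirit is to compare the closed-form `sigmoid_prime` with the derivative JAX obtains automatically. This is a minimal sketch that assumes JAX is installed. `auto_prime` and `max_diff` are names introduced here.

```python
import jax
import jax.numpy as jnp


def sigmoid(z):
    return 1.0 / (1.0 + jnp.exp(-z))


def sigmoid_prime(z):
    # Closed-form derivative, as defined earlier in this notebook
    return sigmoid(z) * (1 - sigmoid(z))


# jax.grad differentiates a scalar-valued function of a scalar input;
# jax.vmap applies that derivative to every element of a vector.
auto_prime = jax.vmap(jax.grad(sigmoid))

z = jnp.linspace(-3.0, 3.0, 7)
max_diff = jnp.max(jnp.abs(auto_prime(z) - sigmoid_prime(z)))
print(f"max |autodiff - closed form| = {max_diff}")
```

The difference should be on the order of floating-point rounding error for the default 32-bit floats.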