# MNIST Classification with Neural Networks - from scratch with JAX

In [None]:
# When running on Google Colab, clone the course repository and put its
# notebooks folder on the import path so that `import datasets` below resolves.
if 'google.colab' in str(get_ipython()):
    !git clone https://github.com/vincentadam87/intro_to_jax.git
    import sys  
    sys.path.insert(0,'/content/intro_to_jax/notebooks')

import datasets  # check you have the datasets.py module next to this notebook
from matplotlib import pyplot as plt

import time

import numpy as np
import numpy.random as npr

from jax import jit, grad
from jax.scipy.special import logsumexp
import jax.numpy as jnp

## What are deep neural networks?


Libraries such as TensorFlow and JAX are popular because they make it easy to fit a popular class of models: deep neural networks.

A deep neural network is a model made of a combination of simple artificial neurons, which are an abstraction of what a biological neuron is.

### Artificial neuron 
A neuron receives many inputs and produces an output in a non-linear fashion.
$$ y = \phi( W {\bf x} + b )$$
where $\phi$ is a non-linearity such as the sigmoid function.

__[Remark]__ bold characters represent vectors or matrices

### Hidden layer

A neural network typically combines these in a sequence, so the output of some neurons serves as input to other neurons. Here is an example of two neurons where the output $h_1$ of neuron $1$ is the input of neuron $2$.
Letter $h$ is used to refer to a `hidden` layer
$$ h_1 = \phi(W_1 {\bf x} + b_1) $$
$$ y = \phi(W_2 h_1 + b_2) $$

### Wide networks

In practice, a layer may be `wide` and consist of many neurons sharing the same input
$$ {\bf h}_1 = \phi(W_1 {\bf x} + b_1) $$


### Deep networks 

And a neural network may be `deep` and consist of many layers
$$ {\bf h}_1 = \phi(W_1 {\bf x} + b_1) $$
$$ \vdots $$
$$ {\bf h}_K = \phi(W_K {\bf h}_{K-1} + b_K) $$
$$ y = \phi(W_{K+1} {\bf h}_K + b_{K+1}) $$




# Task : Image Classification

Let's look at the data

In [None]:

# the data loader provided splits both the images and labels into
# a training and test set
# NOTE(review): the rest of the notebook treats each image as a flat vector
# (reshaped to 28x28 for plotting) and each label as a one-hot vector
# (argmax recovers the class) -- confirm against datasets.py
train_images, train_labels, test_images, test_labels = datasets.mnist()


__Question__

* Can you have a look at the data? What are the shapes of the images (`train_images`) and of the labels (`train_labels`)?

* Can you plot some images and associated label?

* Do you think logistic regression would work for these complex inputs?


In [None]:

# print a few images and check the size/shape of the data
# <SOLUTION
# shape of the array holding the training images
print(train_images.shape)
# 4x4 grid showing the first 16 training images
fig, axarr = plt.subplots(4,4, figsize=(5,5))
for i, ax in enumerate(axarr.flatten()):
    # each image is reshaped to 28x28 pixels for display
    ax.imshow(train_images[i].reshape(28,28))
    # argmax turns the label vector back into a class index
    ax.set_title("label = %i" %np.argmax(train_labels[i]))
fig.tight_layout()
# SOLUTION>


We want to construct a general model of the form

\begin{align} 
{\bf h}_1 &= \phi(W_1 {\bf x} + b_1)\quad \text{  x is a 28 by 28 image}\\
&\vdots \\
{\bf h}_{K-1} &= \phi(W_{K-1} {\bf h}_{K-2} + b_{K-1}) \quad\text{  the network has a total of K layers} \\
p &= \phi(W_{K} {\bf h}_{K-1} + b_{K}) \quad \text{  p is a vector of class probability of size 10}
\end{align}



In [None]:
# The forward model of the neural network from image x to class label probability


def predict(params, inputs):
    """
    Forward pass of the deep network: map input images to class log probabilities.

    Every layer except the last computes ``tanh(activations . w + b)``.
    The last layer is affine and is followed by a log-softmax, so that the
    exponentiated outputs sum to one over the classes.

    :param params: the model parameters, a sequence of ``(w, b)`` pairs, one per layer
    :param inputs: the input images, either a batch (one flat image per row)
        or a single flat image
    
    :return: the predicted log probabilities of each class, log p(y=i|x),
        with the classes along the last axis
    """
    
    # sequential call of each layer
    # the output of a layer is the input of the next layer
    activations = inputs
    for w, b in params[:-1]:
        # outputs = ... # your code here
        outputs = jnp.dot(activations, w) + b
        # activations = ... your code here
        # <SOLUTION
        activations = jnp.tanh(outputs)
        # SOLUTION>
        
    # the final layer is treated separately
    final_w, final_b = params[-1]
    # outputs are the log probabilities of each class (logits = log p(y|x))
    # (exercise placeholder, overwritten by the solution just below)
    logits = ... # your code here
    # <SOLUTION
    pre_logits = jnp.dot(activations, final_w) + final_b
    # normalise over the class axis, which is always the last one; using -1
    # instead of 1 makes this work for a single image as well as for a batch
    logits = pre_logits - logsumexp(pre_logits, axis=-1, keepdims=True)
    # SOLUTION>
    return logits


In [None]:
# Creating random initial weights and biases for a neural network

def init_random_params(layer_sizes, scale=0.1, rng=None):
    """
    Draw Gaussian initial weights and biases for every layer of the network.

    :param layer_sizes: list of the input/output size of the layers of the network
    :param scale: magnitude of the random noise initialisation
    :param rng: Random Number Generator, an object to sample random variables;
        when omitted, a fresh ``RandomState(0)`` is created on every call so
        that repeated calls return the same initial parameters
    
    :return: a list with one ``(weight, bias)`` tuple per layer, the weight of
        shape ``(m, n)`` and the bias of shape ``(n,)`` for consecutive sizes m, n
    """
    if rng is None:
        # A generator used as a default argument would be built once and
        # shared by all calls, so its state would advance between calls.
        rng = npr.RandomState(0)

    return [(scale * rng.randn(m, n), scale * rng.randn(n))
            for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

### The classification loss

We are here doing multiclass classification.

The loss is the negative log-likelihood of the labels, averaged over the $N$ examples of a batch:
$$ l(w) = -\frac{1}{N}\sum_n \log\,p(y=y_n|x=x_n) $$

We are working with a one-hot encoding of the labels (class = 2 $\to$ `y=[0,0,1,0,0,0,0,0,0,0]`).

With this encoding, the loss is simpler to implement as follows:

$$ l(w) = -\frac{1}{N}\sum_n \sum_c \log\,p(y=c|x=x_n) * \delta(y_n = c)$$


In [None]:
# the loss of the classifier: the negative mean log likelihood of the observations

def loss(params, batch):
    """
    Negative log likelihood of the labels under the network, averaged over the batch.

    :param params: the parameters of the neural network
    :param batch: a pair (input images, associated labels)

    :return: a scalar number, the loss to be minimized
    """
    images, labels = batch
    # your code here
    # <SOLUTION
    log_probs = predict(params, images)
    # weighting the log probabilities by the labels and summing over axis 1
    # gives one log likelihood value per example
    per_example_log_lik = jnp.sum(log_probs * labels, axis=1)
    return -jnp.mean(per_example_log_lik)
    # SOLUTION>

In [None]:
# the update method to run gradient descent

@jit
def _sgd_step(params, batch, step_size):
    """Compiled gradient-descent step; the step size is an explicit traced input."""
    grads = grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]


def update(params, batch, step_size=None):
    """
    Perform one step of gradient descent on the loss of a batch.

    :param params: the parameters of the neural network, a list of (w, b) pairs
    :param batch: a pair (input images, associated labels)
    :param step_size: the gradient step size; when omitted, the current value
        of the global ``learning_rate`` is used

    :return: the updated parameters, a list of (w, b) pairs
    """
    if step_size is None:
        # Read the global on every call and hand it to the compiled step as
        # an argument: a global read inside a jit-compiled function is baked
        # in at trace time, so later changes to it would be ignored.
        step_size = learning_rate
    return _sgd_step(params, batch, step_size)

In [None]:
# a more interpretable metric to evaluate how well we do

def accuracy(params, batch):
    """
    Fraction of the examples of a batch whose class is predicted correctly.

    :param params: the parameters of the neural network
    :param batch: a pair (input images, associated labels)

    :return: the mean of the per-example "prediction equals target" indicator
    """
    # <SOLUTION
    images, labels = batch
    # the class is the index of the largest entry along axis 1,
    # both for the labels and for the predicted log probabilities
    true_idx = jnp.argmax(labels, axis=1)
    guessed_idx = jnp.argmax(predict(params, images), axis=1)
    is_correct = guessed_idx == true_idx
    return jnp.mean(is_correct)
    # SOLUTION>


# Now the main learning routine

## Let's start with Logistic regression

Logistic regression corresponds to a 1 layer neural network

\begin{align} 
p &= \phi(W_{1} {\bf x} + b_{1}) \quad \text{  p is a vector of class probability of size 10}
\end{align}

In this code this corresponds to `layer_sizes = [784, 10]`

Later we'll add layers easily: `layer_sizes = [784, 1024, ..., 1024 , 10]`

In [None]:
# setting up the network structure
# number of examples in each batch produced by the data stream
batch_size = 128
# 784 = 28 * 28 input pixels, 10 output classes
layer_sizes = [784, 10]  # this is logistic regression!
# layer_sizes = [784, 1024, 1024, 10] # try that later for deeper networks

# setting up the training parameters
# step size of the gradient descent update
learning_rate = 0.001
num_epochs = 30 # number of passes through the whole dataset

In [None]:
# We are going to do stochastic gradient descent and evaluate the loss on batches of data

# number of training examples
num_train = train_images.shape[0]
# number of full batches, and number of examples left over after them
num_complete_batches, leftover = divmod(num_train, batch_size)
# the leftover examples, if any, form one extra (smaller) batch
num_batches = num_complete_batches + bool(leftover)

In [None]:
# creating a data stream to easily get the batches
def data_stream():
    """
    Endless generator of training batches.

    Each pass draws a new permutation of the training indices from a
    generator seeded with 0 and yields ``num_batches`` consecutive slices
    of it as (images, labels) pairs.
    """
    shuffler = npr.RandomState(0)
    while True:
        order = shuffler.permutation(num_train)
        for batch_number in range(num_batches):
            first = batch_number * batch_size
            chosen = order[first:first + batch_size]
            yield train_images[chosen], train_labels[chosen]


Now the main training loop

In [None]:

# preparing the data stream
batches = data_stream()


# initializing the parameters
params = init_random_params(layer_sizes)

# iterate over epochs (an epoch is one full sweep through the entire dataset)
# for each epoch, iterate over all batches
# print the accuracy after each epoch


# <SOLUTION
for epoch in range(num_epochs):
    # the timer covers only the parameter updates, not the accuracy evaluation below
    start_time = time.time()
    # iterating over all batches in the dataset
    for _ in range(num_batches):
        # one gradient step on the next batch of the stream
        params = update(params, next(batches))
    epoch_time = time.time() - start_time

    # accuracy is evaluated on the whole training set and the whole test set
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
    print(f"Training set accuracy {train_acc}")
    print(f"Test set accuracy {test_acc}")
# SOLUTION>


# BONUS: Interpreting the learned features

Let's have a look at the weights of the network (for logistic regression)
Since the weights for each class have the size of an image, they can be plotted and interpreted


In [None]:
# plot the feature image (the w) for each class 

# <SOLUTION
# extract the weights from the params
# params[0] is the (weight, bias) pair of the first layer; the reshape to
# (28, 28, 10) only works when that layer maps 784 inputs to 10 outputs,
# i.e. for the logistic regression setting layer_sizes = [784, 10]
weights = params[0][0].reshape(28,28,10)
# the biases are extracted but not used in the plot below
biases = params[0][1]

# plot the weights
# 2x5 grid: one panel per output class
fig, axarr = plt.subplots(2,5, figsize=(8, 5))
for i, ax in enumerate(axarr.flatten()):
    # weights[..., i] is the 28x28 weight image feeding output i
    ax.imshow(weights[..., i])
    ax.set_title("feature %i" %i)
fig.tight_layout()

# SOLUTION>

## BONUS:  Deeper networks

Now let's go deeper!

* Can you re-run this code after changing the neural network architecture (adding layers)?

just change `layer_sizes` in the code you have (maybe note down the previous final accuracy before!)


* What happens to the accuracy?