# MNIST Classification with Convolutional Neural Networks - with JAX

We start with a few imports

In [None]:
import time
import itertools

import numpy.random as npr
import matplotlib.pyplot as plt

import jax.numpy as jnp
from jax import jit, grad, random
rng = random.PRNGKey(0)

from jax.example_libraries import optimizers
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu, LogSoftmax, Conv, Flatten, Identity


# When running inside Google Colab, clone the course repository and put its
# notebooks folder at the front of sys.path (presumably so that the `datasets`
# module imported right after this block can be found -- confirm).
if 'google.colab' in str(get_ipython()):
    !git clone https://github.com/vincentadam87/intro_to_jax.git
    import sys  
    sys.path.insert(0,'/content/intro_to_jax/notebooks')
import datasets

Setting up a Convolutional Neural Network (CNN) model with stax 

In [None]:
# Build the CNN with stax. `stax.serial` chains the layers and returns a pair:
#   * init_random_params(rng, input_shape): creates random initial weights and biases
#   * predict(params, x): the forward model, mapping images x to class log-probabilities
#     (the last layer is LogSoftmax)

init_random_params, predict = stax.serial(
    Conv(out_chan=16, filter_shape=(3, 3), padding="SAME"), Relu,
    Conv(out_chan=16, filter_shape=(3, 3), padding="SAME"), Relu,
    Flatten,
    Dense(10),
    LogSoftmax,
)

# Note: the parameters are stored as one entry per layer
# (a weight/bias pair for Conv and Dense layers).

As before, we need to declare the loss, accuracy and update functions.

In [None]:
# the loss of the classifier: the negative mean log-likelihood of the observations

def loss(params, batch):
    """Return the negative mean log-likelihood of a batch.

    Args:
        params: network parameters accepted by `predict`.
        batch: an `(inputs, targets)` pair; `targets` is multiplied elementwise
            with the predicted log-probabilities (presumably one-hot labels --
            confirm against the data loader).

    Returns:
        A scalar: minus the batch average of the per-example sum of
        `predict(params, inputs) * targets`.
    """
    images, labels = batch
    log_probs = predict(params, images)
    per_example = (log_probs * labels).sum(axis=1)
    return -per_example.mean()

# a more interpretable metric to evaluate how well we do: the classification accuracy
@jit
def accuracy(params, batch):
    """Return the fraction of examples whose predicted class matches the target.

    Args:
        params: network parameters accepted by `predict`.
        batch: an `(inputs, targets)` pair; the class of each example is taken
            as the argmax of its row, for both targets and network outputs.

    Returns:
        The mean of the elementwise comparison between predicted and target classes.
    """
    images, labels = batch
    scores = predict(params, images)
    hits = jnp.argmax(scores, axis=1) == jnp.argmax(labels, axis=1)
    return jnp.mean(hits)

In [None]:
# the update method used by the training loop

# We do not hand-code the parameter update here: it is delegated to the optimizer.
# `get_params` and `opt_update` are created in a later cell (by `optimizers.sgd`),
# so that cell must be run before `update` is called for the first time.

@jit
def update(i, opt_state, batch):
    """Perform one optimization step and return the new optimizer state.

    Args:
        i: step counter forwarded to `opt_update`.
        opt_state: current optimizer state (holds the parameters).
        batch: an `(inputs, targets)` pair passed on to `loss`.

    Returns:
        The optimizer state after applying the gradient of `loss`.
    """
    current_params = get_params(opt_state)
    gradients = grad(loss)(current_params, batch)
    return opt_update(i, gradients, opt_state)

Now we prepare the main learning routine: first the training hyperparameters, then the data.

In [None]:
# setting up training parameters
step_size = 0.001  # step size handed to the SGD optimizer
num_epochs = 10  # number of passes of the outer training loop
batch_size = 128  # number of examples per mini-batch


In [None]:
# loading the data and preprocessing (note that here we don't flatten the images)

train_images, train_labels, test_images, test_labels = datasets.mnist(flatten_images=False)
num_train = train_images.shape[0]

# number of mini-batches per epoch: a trailing partial batch counts as one extra batch
num_complete_batches = num_train // batch_size
leftover = num_train % batch_size
num_batches = num_complete_batches + (1 if leftover else 0)

print(train_images.shape)

In [None]:
# creating a data stream to easily get the batches
def data_stream():
    """Yield `(images, labels)` mini-batches from the training set, forever.

    Each pass draws a fresh permutation of the `num_train` training indices from
    a generator seeded with 0, then walks through it in consecutive slices of
    `batch_size` indices (the last slice of a pass may be shorter).
    """
    shuffler = npr.RandomState(0)
    while True:
        order = shuffler.permutation(num_train)
        for start in range(0, num_batches * batch_size, batch_size):
            chosen = order[start:start + batch_size]
            yield train_images[chosen], train_labels[chosen]


Now the main training loop

In [None]:

# we initiate a SGD optimizer which provides the update function
opt_init, opt_update, get_params = optimizers.sgd(step_size)


batches = data_stream()

# init_random_params returns (output_shape, params); only the parameters are used below.
# Named variables are used instead of `_` throughout this cell: in IPython `_` is the
# "last output" variable, and rebinding it at notebook top level disables that feature.
output_shape, init_params = init_random_params(rng, (-1, 28, 28, 1))
opt_state = opt_init(init_params)
itercount = itertools.count()  # global step counter passed to the optimizer

# now iterate over epoch
# for each epoch, iterate over batches
# and update the weights
# at the end of the epoch print the training and test accuracies

# <SOLUTION
print("\nStarting training...")
for epoch in range(num_epochs):
    start_time = time.time()
    for batch_num in range(num_batches):
        # report progress every 50 batches
        if batch_num % 50 == 0:
            print('epoch %d/%d'%(epoch, num_epochs), 'batch: %d/%d'%(batch_num, num_batches))
        opt_state = update(next(itercount), opt_state, next(batches))
    epoch_time = time.time() - start_time

    # evaluate the current parameters on the full training and test sets
    params = get_params(opt_state)
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
    print(f"Training set accuracy {train_acc}")
    print(f"Test set accuracy {test_acc}")
    
# SOLUTION>


# Interpreting the convolutional weights

Convolutional neural networks were inspired by receptive fields in visual area V1 of the brain.

We know that the visual system consists of a sequence of feature extractors,
starting from simple line and edge detectors.

Convolutional neural networks learn similar strategies from data!
Let's look at the learned feature extractors

In [None]:
# plot the weights of the first layer (the one directly applied to the input image)

# hint: the weights of the first layer are in params[0][0]

# <SOLUTION
# one panel per filter on a 4x4 grid; `axarr.flat` walks the grid row by row,
# so panel (i, j) shows filter number i*4 + j
fig, axarr = plt.subplots(4,4)
for n, ax in enumerate(axarr.flat):
    ax.imshow(params[0][0][:, :, 0, n])
plt.show()
# SOLUTION>

### Questions
* Can you see the edge detectors?
* Play with different architectures: can you get a better accuracy?