<a href="https://colab.research.google.com/github/tomwphillips/jax-experiments/blob/master/MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MNIST classification with JAX
[JAX](https://github.com/google/jax) gives you automatic differentiation and GPU/TPU acceleration for Python and NumPy code.

Here I've had a play around with it (largely following the [MNIST example](https://github.com/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb) in the JAX repo). `jax.vmap` vectorizes a function over a batch dimension (used here to batch predictions), `jax.grad` takes the derivative of the loss function with respect to the parameters, and `jax.jit` compiles the training step.

Pretty neat. I like how lightweight it is compared to TensorFlow.

In [0]:
import jax
import jax.numpy as np
import jax.random as random
from jax.scipy.special import logsumexp
import tensorflow_datasets as tfds

# Load MNIST

In [0]:
# batch_size=-1 loads each split as one full-size batch; with_info=True also returns
# the dataset metadata (used below to look up the number of label classes).
mnist, info = tfds.load('mnist', with_info=True, batch_size=-1)
mnist = tfds.as_numpy(mnist)  # convert the loaded splits to numpy arrays
train = mnist['train']
test = mnist['test']

In [0]:
tuple(train[field].shape for field in ('image', 'label'))

((60000, 28, 28, 1), (60000,))

In [0]:
tuple(test[field].shape for field in ('image', 'label'))

((10000, 28, 28, 1), (10000,))

In [0]:
# Flatten every (28, 28, 1) image into a single row of 784 pixels.
# Pixel values are left exactly as loaded (no rescaling is applied here).
train_images, test_images = (
    split['image'].reshape(len(split['image']), -1) for split in (train, test)
)

assert train_images.ndim == test_images.ndim == 2

In [0]:
def one_hot_encode(x, k, dtype=np.float32):
  """One-hot encode integer labels `x` into `k` classes.

  Works for labels of any shape (scalar, 1-D batch, or higher rank). Each label
  becomes a length-`k` vector with a 1 at the label's index and 0 elsewhere;
  a label outside [0, k) produces an all-zero vector.

  Args:
    x: integer label(s); anything `np.asarray` accepts, e.g. an array of shape (N,).
    k: number of classes.
    dtype: dtype of the returned array.

  Returns:
    Array of shape `x.shape + (k,)` with dtype `dtype`.
  """
  labels = np.asarray(x)
  # Compare each label against 0..k-1 along a new trailing axis.
  return np.array(labels[..., None] == np.arange(k), dtype)

num_classes = info.features['label'].num_classes
assert num_classes == 10  # MNIST digits 0-9

# One-hot targets for each split, one column per class.
train_labels, test_labels = (
    one_hot_encode(split['label'], num_classes) for split in (train, test)
)

assert train_labels.shape[-1] == test_labels.shape[-1] == num_classes

# Initialize weights and biases for each layer of the network

In [0]:
def initialize_layer_parameters(input_units, output_units, key, scale=1e-2):
  """Randomly initialize the weights and bias of one dense layer.

  Args:
    input_units: number of inputs to the layer (columns of `weights`).
    output_units: number of outputs (rows of `weights`, length of `bias`).
    key: JAX PRNG key; split into separate keys for the weights and the bias.
    scale: multiplier applied to the standard-normal draws (default 1e-2).

  Returns:
    (weights, bias) with shapes (output_units, input_units) and (output_units,),
    laid out so that `np.dot(weights, x) + bias` maps an input vector `x` to the
    layer's outputs.
  """
  weight_key, bias_key = random.split(key)
  weights = scale * random.normal(weight_key, (output_units, input_units))
  bias = scale * random.normal(bias_key, (output_units,))
  return weights, bias


# Keep the original (misspelled) name as an alias so existing callers keep working.
initalize_layer_parameters = initialize_layer_parameters

In [0]:
# Shape check: a layer mapping 32 inputs to 16 outputs, one bias per output.
weights, bias = initalize_layer_parameters(32, 16, key=random.PRNGKey(0))
assert (weights.shape, bias.shape) == ((16, 32), (16,))

In [0]:
def initialize_all_parameters(layer_units, key):
  """Create a (weights, bias) pair for every consecutive pair of layer sizes.

  Args:
    layer_units: layer widths, input first, e.g. [784, 512, 10].
    key: JAX PRNG key. It is split into len(layer_units) sub-keys and one
      sub-key is given to each layer (the last sub-key is unused).

  Returns:
    List of len(layer_units) - 1 (weights, bias) tuples, input layer first.
  """
  layer_keys = random.split(key, len(layer_units))
  layer_shapes = zip(layer_units[:-1], layer_units[1:])
  parameters = []
  for (fan_in, fan_out), layer_key in zip(layer_shapes, layer_keys):
    parameters.append(initalize_layer_parameters(fan_in, fan_out, layer_key))
  return parameters

In [0]:
# Three layer sizes define two (weights, bias) pairs.
assert len(initialize_all_parameters(layer_units=[64, 32, 12], key=random.PRNGKey(0))) == 2

# Forward pass

## Define for a single image

In [0]:
def relu(x):
  """Rectified linear unit: elementwise max(0, x)."""
  return np.maximum(0, x)


def log_softmax(x):
  """Log-softmax over the last axis: x - log(sum(exp(x))).

  Uses `logsumexp` for numerical stability. Returns LOG-probabilities;
  exponentiate to get probabilities that sum to 1 along the last axis. Reducing
  over the last axis (rather than all axes) keeps each row of a 2-D input
  normalized independently; for 1-D input the two are the same.
  """
  return x - logsumexp(x, axis=-1, keepdims=True)


# Backward-compatible alias. Despite its name, this has always returned
# log-probabilities, not probabilities.
softmax = log_softmax


def predict(parameters, image):
  """Forward pass of the MLP for a single, flattened image.

  Args:
    parameters: list of (weights, bias) pairs, one per layer, with weights of
      shape (output_units, input_units).
    image: a single 1-D input vector. Batches are not supported directly
      (`np.dot(weights, image)` fails on a 2-D batch); wrap with `jax.vmap`.

  Returns:
    Log-probabilities over the output units. Hidden layers use ReLU and the
    final layer uses log-softmax.
  """
  activations = image
  # ReLU for every hidden layer, log-softmax for the output layer.
  activation_functions = [relu] * (len(parameters) - 1) + [log_softmax]

  for (weights, bias), activation in zip(parameters, activation_functions):
    activations = activation(np.dot(weights, activations) + bias)

  return activations

In [0]:
# Sanity-check one forward pass: a random 784-value "image" through a 784-512-10 network.
parameters = initialize_all_parameters([784, 512, 10], random.PRNGKey(0))
key = random.PRNGKey(52)
image = random.normal(key, shape=(784,))
prediction = predict(parameters, image)
prediction.shape

(10,)

## Use jax.vmap to batch `predict`

`predict` won't work on a batch of images:

In [0]:
# 128 random "images": predict expects a single 784-vector, so this is expected to fail.
images = random.normal(random.PRNGKey(0), shape=(128, 784))

try:
  predictions = predict(parameters, images)
except TypeError as err:
  print(err)  # expected: shape mismatch inside np.dot

Incompatible shapes for dot: got (512, 784) and (128, 784).


In [0]:
# in_axes for (parameters, image): None = reuse the same parameters for every example,
# 0 = map over the leading (batch) axis of the images.
predict_batch = jax.vmap(predict, (None, 0))
predictions = predict_batch(parameters, images)
predictions.shape

(128, 10)

# Training loop using `jax.grad` and `jax.jit`

In [0]:
def loss(parameters, images, targets):
  """Categorical cross-entropy, summed (not averaged) over the batch.

  `predict_batch` returns log-probabilities (the network's last activation is
  x - logsumexp(x)), so -sum(targets * log_probs) with one-hot `targets` is the
  negative log-likelihood of the true classes. Because it is a sum, its
  magnitude grows with the number of examples in the batch.
  """
  log_probs = predict_batch(parameters, images)
  return -np.sum(log_probs * targets)


def accuracy(parameters, images, targets):
  """Fraction of examples whose highest-scoring class matches the one-hot target."""
  log_probs = predict_batch(parameters, images)
  predicted_class = np.argmax(log_probs, axis=1)
  true_class = np.argmax(targets, axis=1)
  return np.mean(predicted_class == true_class)

In [0]:
@jax.jit
def update_parameters(parameters, images, targets, learning_rate):
  """Take one plain gradient-descent step on `loss`.

  Returns a new list of (weights, bias) tuples; each entry is moved by
  -learning_rate times its gradient. The inputs are not modified.
  """
  grads = jax.grad(loss)(parameters, images, targets)
  updated = []
  for layer_params, layer_grads in zip(parameters, grads):
    weights, bias = layer_params
    d_weights, d_bias = layer_grads
    updated.append((weights - learning_rate * d_weights, bias - learning_rate * d_bias))
  return updated

## Batch gradient descent

In [0]:
key = random.PRNGKey(52)
layers = [784, 512, 256, 10]
parameters = initialize_all_parameters(layers, key)
epochs = 100
log_every = 10  # epochs between progress reports
# Presumably this small because `loss` sums (rather than averages) over all
# 60,000 training images in each step.
learning_rate = 1e-7

for epoch in range(epochs):
  # Full-batch step: one gradient update on the entire training set per epoch.
  parameters = update_parameters(parameters, train_images, train_labels, learning_rate)

  # Loss and accuracy each run a forward pass over the whole training set,
  # so only compute them on epochs that are actually reported.
  if (epoch + 1) % log_every == 0:
    epoch_loss = loss(parameters, train_images, train_labels)
    epoch_accuracy = accuracy(parameters, train_images, train_labels)
    print(f"Epoch: {epoch + 1}\tLoss: {epoch_loss:.2E}\t Accuracy: {epoch_accuracy:.2f}")

Epoch: 10	Loss: 4.91E+04	 Accuracy: 0.82
Epoch: 20	Loss: 3.42E+04	 Accuracy: 0.85
Epoch: 30	Loss: 3.00E+04	 Accuracy: 0.84
Epoch: 40	Loss: 2.37E+04	 Accuracy: 0.89
Epoch: 50	Loss: 3.34E+04	 Accuracy: 0.81
Epoch: 60	Loss: 1.96E+04	 Accuracy: 0.91
Epoch: 70	Loss: 1.83E+04	 Accuracy: 0.91
Epoch: 80	Loss: 1.73E+04	 Accuracy: 0.92
Epoch: 90	Loss: 1.66E+04	 Accuracy: 0.92
Epoch: 100	Loss: 1.60E+04	 Accuracy: 0.92


## Mini-batch gradient descent

In [0]:
def get_training_data(batch_size=128):
  """Return an iterable of (X, y) numpy batches from the MNIST training split.

  Each item is an (images, labels) tuple of up to `batch_size` examples.
  No shuffling is applied in this function.
  """
  dataset = tfds.load(
      'mnist',
      split='train',
      batch_size=batch_size,
      as_supervised=True,  # (X, y) tuples instead of a feature dict
  )
  return tfds.as_numpy(dataset.prefetch(1))

In [0]:
# Same architecture and seed as the full-batch run, but one update per mini-batch.
key = random.PRNGKey(52)
layers = [784, 512, 256, 10]
parameters = initialize_all_parameters(layers, key)
epochs = 5
# 100x the full-batch rate; each step's summed loss now covers one mini-batch
# rather than the full training set.
learning_rate = 1e-5

for epoch in range(epochs):
  for images, labels in get_training_data():
    # Flatten the images and one-hot the labels to match the full-dataset arrays.
    images, labels = images.reshape(len(images), -1), one_hot_encode(labels, num_classes)
    parameters = update_parameters(parameters, images, labels, learning_rate)

  # After each pass over the data, evaluate on the full train and test sets.
  epoch_loss = loss(parameters, train_images, train_labels)
  epoch_train_accuracy = accuracy(parameters, train_images, train_labels)
  epoch_test_accuracy = accuracy(parameters, test_images, test_labels)

  print(f"Epoch: {epoch + 1}\t Loss: {epoch_loss:.2E}\t Accuracy: {epoch_train_accuracy:.2f}\t Test accuracy: {epoch_test_accuracy:.2f}")

Epoch: 1	 Loss: 1.61E+04	 Accuracy: 0.92	 Test accuracy: 0.92
Epoch: 2	 Loss: 1.22E+04	 Accuracy: 0.94	 Test accuracy: 0.94
Epoch: 3	 Loss: 9.91E+03	 Accuracy: 0.95	 Test accuracy: 0.95
Epoch: 4	 Loss: 8.35E+03	 Accuracy: 0.96	 Test accuracy: 0.96
Epoch: 5	 Loss: 7.21E+03	 Accuracy: 0.97	 Test accuracy: 0.96
