This notebook follows the official JAX tutorial: https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Hyperparameters

In [3]:
# Randomly initialize the weights and biases of one dense layer.
def random_layer_params(m, n, key, scale=1e-2):
    """Draw scaled normal parameters for a single layer.

    Args:
        m: Second dimension of the weight matrix.
        n: First dimension of the weight matrix and length of the bias vector.
        key: JAX PRNG key; it is split once so that weights and biases are
            drawn from independent keys.
        scale: Factor applied to every sampled value.

    Returns:
        A ``(weights, biases)`` tuple with shapes ``(n, m)`` and ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases

# Initialize every layer of a fully-connected neural network.
def init_network_params(sizes, key):
    """Build the parameter list for a fully-connected network.

    Args:
        sizes: Sequence of layer widths; each consecutive pair
            ``(sizes[i], sizes[i + 1])`` defines one layer.
        key: JAX PRNG key that is split into per-layer keys.

    Returns:
        A list with one ``random_layer_params`` result per consecutive
        pair in ``sizes``.
    """
    # len(sizes) keys are generated although only len(sizes) - 1 layers
    # exist; zip stops at the shortest input, so the surplus key is unused.
    layer_keys = random.split(key, len(sizes))
    layer_dims = zip(sizes[:-1], sizes[1:])

    network = []
    for (fan_in, fan_out), layer_key in zip(layer_dims, layer_keys):
        network.append(random_layer_params(fan_in, fan_out, layer_key))
    return network

# Layer widths, first to last. The second hidden width previously read 521,
# a transposed-digit typo for 512 (the width of the first hidden layer).
layer_sizes = [784, 512, 512, 10]
# Training settings; presumably consumed by a training loop defined later
# in the notebook (not visible in this section).
step_size = 0.01
num_epochs = 8
batch_size = 128
# Equals the final entry of layer_sizes.
n_targets = 10
# A fixed PRNG seed keeps the initial parameters reproducible across runs.
params = init_network_params(layer_sizes, random.PRNGKey(0))

# Auto-batching predictions

Let us first define our prediction function. Note that we’re defining this for a *single* image example. We’re going to use JAX’s `vmap` function to automatically handle mini-batches, with no performance penalty.

In [4]:
from jax.scipy.special import logsumexp

def relu(x):
    """Rectified linear unit: the elementwise maximum of 0 and ``x``."""
    rectified = jnp.maximum(0, x)
    return rectified

def predict(params, image):
    """Run one example through the network and return normalized log-scores.

    Args:
        params: Sequence of ``(weights, biases)`` pairs, one per layer.
        image: Input for a single example (not a batch).

    Returns:
        The final layer's logits minus their ``logsumexp``.
    """
    # Every layer except the last is an affine map followed by ReLU.
    hidden = image
    for weights, biases in params[:-1]:
        hidden = relu(jnp.dot(weights, hidden) + biases)

    # The last layer is affine only; no ReLU is applied to its output.
    out_weights, out_biases = params[-1]
    logits = jnp.dot(out_weights, hidden) + out_biases
    log_normalizer = logsumexp(logits)
    return logits - log_normalizer